name_test.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
  1. package model
  2. import (
  3. "bytes"
  4. "cmp"
  5. "fmt"
  6. "log/slog"
  7. "slices"
  8. "strings"
  9. "testing"
  10. )
  11. type fields struct {
  12. host, namespace, model, tag, build string
  13. digest string
  14. }
  15. func fieldsFromName(p Name) fields {
  16. return fields{
  17. host: p.parts[PartHost],
  18. namespace: p.parts[PartNamespace],
  19. model: p.parts[PartModel],
  20. tag: p.parts[PartTag],
  21. build: p.parts[PartBuild],
  22. digest: p.parts[PartDigest],
  23. }
  24. }
// testNames maps an input name string to the parts the parser is expected to
// produce for it. An empty fields value means every part is expected to be
// empty, which is how the invalid inputs below are expressed.
var testNames = map[string]fields{
	"mistral:latest": {model: "mistral", tag: "latest"},
	"mistral": {model: "mistral"},
	"mistral:30B": {model: "mistral", tag: "30B"},
	"mistral:7b": {model: "mistral", tag: "7b"},
	"mistral:7b+Q4_0": {model: "mistral", tag: "7b", build: "Q4_0"},
	"mistral+KQED": {model: "mistral", build: "KQED"},
	"mistral.x-3:7b+Q4_0": {model: "mistral.x-3", tag: "7b", build: "Q4_0"},
	"mistral:7b+q4_0": {model: "mistral", tag: "7b", build: "q4_0"},
	"llama2": {model: "llama2"},
	"user/model": {namespace: "user", model: "model"},
	"example.com/ns/mistral:7b+Q4_0": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "Q4_0"},
	"example.com/ns/mistral:7b+X": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"},
	// invalid digest
	"mistral:latest@invalid256-": {},
	"mistral:latest@-123": {},
	"mistral:latest@!-123": {},
	"mistral:latest@1-!": {},
	"mistral:latest@": {},
	// resolved (the name carries a digest after "@")
	"x@sha123-1": {model: "x", digest: "sha123-1"},
	"@sha456-2": {digest: "sha456-2"},
	"@@sha123-1": {},
	// preserves case for build
	"x+b": {model: "x", build: "b"},
	// invalid (includes fuzzing trophies)
	" / / : + ": {},
	" / : + ": {},
	" : + ": {},
	" + ": {},
	" : ": {},
	" / ": {},
	" /": {},
	"/ ": {},
	"/": {},
	":": {},
	"+": {},
	// a dot (".") in the namespace is not allowed
	"invalid.com/7b+x": {},
	"invalid:7b+Q4_0:latest": {},
	"in valid": {},
	"invalid/y/z/foo": {},
	"/0": {},
	"0 /0": {},
	"0 /": {},
	"0/": {},
	":/0": {},
	"+0/00000": {},
	"0+.\xf2\x80\xf6\x9d00000\xe5\x99\xe6\xd900\xd90\xa60\x91\xdc0\xff\xbf\x99\xe800\xb9\xdc\xd6\xc300\x970\xfb\xfd0\xe0\x8a\xe1\xad\xd40\x9700\xa80\x980\xdd0000\xb00\x91000\xfe0\x89\x9b\x90\x93\x9f0\xe60\xf7\x84\xb0\x87\xa5\xff0\xa000\x9a\x85\xf6\x85\xfe\xa9\xf9\xe9\xde00\xf4\xe0\x8f\x81\xad\xde00\xd700\xaa\xe000000\xb1\xee0\x91": {},
	"0//0": {},
	"m+^^^": {},
	"file:///etc/passwd": {},
	"file:///etc/passwd:latest": {},
	"file:///etc/passwd:latest+u": {},
	":x": {},
	"+x": {},
	"x+": {},
	// Disallow ("\.+") in any part to prevent path traversal anywhere
	// we convert the name to a path.
	"../etc/passwd": {},
	".../etc/passwd": {},
	"./../passwd": {},
	"./0+..": {},
	// A part of exactly MaxNamePartLen bytes is expected to parse; one byte
	// more is expected to be rejected.
	strings.Repeat("a", MaxNamePartLen): {model: strings.Repeat("a", MaxNamePartLen)},
	strings.Repeat("a", MaxNamePartLen+1): {},
}
  91. // TestConsecutiveDots tests that consecutive dots are not allowed in any
  92. // part, to avoid path traversal. There also are some tests in testNames, but
  93. // this test is more exhaustive and exists to emphasize the importance of
  94. // preventing path traversal.
  95. func TestNameConsecutiveDots(t *testing.T) {
  96. for i := 1; i < 10; i++ {
  97. s := strings.Repeat(".", i)
  98. if i > 1 {
  99. if g := ParseNameFill(s, "").String(); g != "" {
  100. t.Errorf("ParseName(%q) = %q; want empty string", s, g)
  101. }
  102. } else {
  103. if g := ParseNameFill(s, "").String(); g != s {
  104. t.Errorf("ParseName(%q) = %q; want %q", s, g, s)
  105. }
  106. }
  107. }
  108. }
  109. func TestNameParts(t *testing.T) {
  110. var p Name
  111. if w, g := int(PartDigest+1), len(p.Parts()); w != g {
  112. t.Errorf("Parts() = %d; want %d", g, w)
  113. }
  114. }
  115. func TestNamePartString(t *testing.T) {
  116. if g := PartKind(-2).String(); g != "Unknown" {
  117. t.Errorf("Unknown part = %q; want %q", g, "Unknown")
  118. }
  119. for kind, name := range kindNames {
  120. if g := kind.String(); g != name {
  121. t.Errorf("%s = %q; want %q", kind, g, name)
  122. }
  123. }
  124. }
  125. func TestParseName(t *testing.T) {
  126. for baseName, want := range testNames {
  127. for _, prefix := range []string{"", "https://", "http://"} {
  128. // We should get the same results with or without the
  129. // http(s) prefixes
  130. s := prefix + baseName
  131. t.Run(s, func(t *testing.T) {
  132. name := ParseNameFill(s, "")
  133. got := fieldsFromName(name)
  134. if got != want {
  135. t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
  136. }
  137. // test round-trip
  138. if !ParseNameFill(name.String(), "").EqualFold(name) {
  139. t.Errorf("ParseName(%q).String() = %s; want %s", s, name.String(), baseName)
  140. }
  141. })
  142. }
  143. }
  144. }
  145. func TestCompleteWithAndWithoutBuild(t *testing.T) {
  146. cases := []struct {
  147. in string
  148. complete bool
  149. completeNoBuild bool
  150. }{
  151. {"", false, false},
  152. {"incomplete/mistral:7b+x", false, false},
  153. {"incomplete/mistral:7b+Q4_0", false, false},
  154. {"incomplete:7b+x", false, false},
  155. {"complete.com/x/mistral:latest+Q4_0", true, true},
  156. {"complete.com/x/mistral:latest", false, true},
  157. }
  158. for _, tt := range cases {
  159. t.Run(tt.in, func(t *testing.T) {
  160. p := ParseNameFill(tt.in, "")
  161. t.Logf("ParseName(%q) = %#v", tt.in, p)
  162. if g := p.IsComplete(); g != tt.complete {
  163. t.Errorf("Complete(%q) = %v; want %v", tt.in, g, tt.complete)
  164. }
  165. if g := p.IsCompleteNoBuild(); g != tt.completeNoBuild {
  166. t.Errorf("CompleteNoBuild(%q) = %v; want %v", tt.in, g, tt.completeNoBuild)
  167. }
  168. })
  169. }
  170. // Complete uses Parts which returns a slice, but it should be
  171. // inlined when used in Complete, preventing any allocations or
  172. // escaping to the heap.
  173. allocs := testing.AllocsPerRun(1000, func() {
  174. keep(ParseNameFill("complete.com/x/mistral:latest+Q4_0", "").IsComplete())
  175. })
  176. if allocs > 0 {
  177. t.Errorf("Complete allocs = %v; want 0", allocs)
  178. }
  179. }
  180. func TestNameLogValue(t *testing.T) {
  181. cases := []string{
  182. "example.com/library/mistral:latest+Q4_0",
  183. "mistral:latest",
  184. "mistral:7b+Q4_0",
  185. }
  186. for _, s := range cases {
  187. t.Run(s, func(t *testing.T) {
  188. var b bytes.Buffer
  189. log := slog.New(slog.NewTextHandler(&b, nil))
  190. name := ParseNameFill(s, "")
  191. log.Info("", "name", name)
  192. want := fmt.Sprintf("name=%s", name.GoString())
  193. got := b.String()
  194. if !strings.Contains(got, want) {
  195. t.Errorf("expected log output to contain %q; got %q", want, got)
  196. }
  197. })
  198. }
  199. }
  200. func TestNameGoString(t *testing.T) {
  201. cases := []struct {
  202. name string
  203. in string
  204. wantString string
  205. wantGoString string // default is tt.in
  206. }{
  207. {
  208. name: "Complete Name",
  209. in: "example.com/library/mistral:latest+Q4_0",
  210. wantGoString: "example.com/library/mistral:latest+Q4_0@?",
  211. },
  212. {
  213. name: "Short Name",
  214. in: "mistral:latest",
  215. wantGoString: "?/?/mistral:latest+?@?",
  216. },
  217. {
  218. name: "Long Name",
  219. in: "library/mistral:latest",
  220. wantGoString: "?/library/mistral:latest+?@?",
  221. },
  222. {
  223. name: "Case Preserved",
  224. in: "Library/Mistral:Latest",
  225. wantGoString: "?/Library/Mistral:Latest+?@?",
  226. },
  227. {
  228. name: "With digest",
  229. in: "Library/Mistral:Latest@sha256-123456",
  230. wantGoString: "?/Library/Mistral:Latest+?@sha256-123456",
  231. },
  232. }
  233. for _, tt := range cases {
  234. t.Run(tt.name, func(t *testing.T) {
  235. p := ParseNameFill(tt.in, "")
  236. tt.wantGoString = cmp.Or(tt.wantGoString, tt.in)
  237. if g := fmt.Sprintf("%#v", p); g != tt.wantGoString {
  238. t.Errorf("GoString() = %q; want %q", g, tt.wantGoString)
  239. }
  240. })
  241. }
  242. }
  243. func TestDisplayShortest(t *testing.T) {
  244. cases := []struct {
  245. in string
  246. mask string
  247. want string
  248. wantPanic bool
  249. }{
  250. {"example.com/library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
  251. {"example.com/library/mistral:latest+Q4_0", "example.com/_/_:latest", "library/mistral", false},
  252. {"example.com/library/mistral:latest+Q4_0", "", "example.com/library/mistral", false},
  253. {"example.com/library/mistral:latest+Q4_0", "", "example.com/library/mistral", false},
  254. // case-insensitive
  255. {"Example.com/library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
  256. {"example.com/Library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
  257. {"example.com/library/Mistral:latest+Q4_0", "example.com/library/_:latest", "Mistral", false},
  258. {"example.com/library/mistral:Latest+Q4_0", "example.com/library/_:latest", "mistral", false},
  259. {"example.com/library/mistral:Latest+q4_0", "example.com/library/_:latest", "mistral", false},
  260. // invalid mask
  261. {"example.com/library/mistral:latest+Q4_0", "example.com/mistral", "", true},
  262. // DefaultMask
  263. {"registry.ollama.ai/library/mistral:latest+Q4_0", DefaultMask, "mistral", false},
  264. // Auto-Fill
  265. {"x", "example.com/library/_:latest", "x", false},
  266. {"x", "example.com/library/_:latest+Q4_0", "x", false},
  267. {"x/y:z", "a.com/library/_:latest+Q4_0", "x/y:z", false},
  268. {"x/y:z", "a.com/library/_:latest+Q4_0", "x/y:z", false},
  269. }
  270. for _, tt := range cases {
  271. t.Run("", func(t *testing.T) {
  272. defer func() {
  273. if tt.wantPanic {
  274. if recover() == nil {
  275. t.Errorf("expected panic")
  276. }
  277. }
  278. }()
  279. p := ParseNameFill(tt.in, "")
  280. t.Logf("ParseName(%q) = %#v", tt.in, p)
  281. if g := p.DisplayShortest(tt.mask); g != tt.want {
  282. t.Errorf("got = %q; want %q", g, tt.want)
  283. }
  284. })
  285. }
  286. }
  287. func TestParseNameAllocs(t *testing.T) {
  288. allocs := testing.AllocsPerRun(1000, func() {
  289. keep(ParseNameFill("example.com/mistral:7b+Q4_0", ""))
  290. })
  291. if allocs > 0 {
  292. t.Errorf("ParseName allocs = %v; want 0", allocs)
  293. }
  294. }
  295. func BenchmarkParseName(b *testing.B) {
  296. b.ReportAllocs()
  297. for range b.N {
  298. keep(ParseNameFill("example.com/mistral:7b+Q4_0", ""))
  299. }
  300. }
  301. func FuzzParseName(f *testing.F) {
  302. f.Add("example.com/mistral:7b+Q4_0")
  303. f.Add("example.com/mistral:7b+q4_0")
  304. f.Add("example.com/mistral:7b+x")
  305. f.Add("x/y/z:8n+I")
  306. f.Add(":x")
  307. f.Add("@sha256-123456")
  308. f.Add("example.com/mistral:latest+Q4_0@sha256-123456")
  309. f.Add(":@!@")
  310. f.Add("...")
  311. f.Fuzz(func(t *testing.T, s string) {
  312. r0 := ParseNameFill(s, "")
  313. if strings.Contains(s, "..") && !r0.IsZero() {
  314. t.Fatalf("non-zero value for path with '..': %q", s)
  315. }
  316. if !r0.IsValid() && !r0.IsResolved() {
  317. if !r0.EqualFold(Name{}) {
  318. t.Errorf("expected invalid path to be zero value; got %#v", r0)
  319. }
  320. t.Skipf("invalid path: %q", s)
  321. }
  322. for _, p := range r0.Parts() {
  323. if len(p) > MaxNamePartLen {
  324. t.Errorf("part too long: %q", p)
  325. }
  326. }
  327. if !strings.EqualFold(r0.String(), s) {
  328. t.Errorf("String() did not round-trip with case insensitivity: %q\ngot = %q\nwant = %q", s, r0.String(), s)
  329. }
  330. r1 := ParseNameFill(r0.String(), "")
  331. if !r0.EqualFold(r1) {
  332. t.Errorf("round-trip mismatch: %+v != %+v", r0, r1)
  333. }
  334. })
  335. }
  336. func TestFill(t *testing.T) {
  337. cases := []struct {
  338. dst string
  339. src string
  340. want string
  341. }{
  342. {"mistral", "o.com/library/PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
  343. {"o.com/library/mistral", "PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
  344. {"", "o.com/library/mistral:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
  345. }
  346. for _, tt := range cases {
  347. t.Run(tt.dst, func(t *testing.T) {
  348. r := Fill(ParseNameFill(tt.dst, ""), ParseNameFill(tt.src, ""))
  349. if r.String() != tt.want {
  350. t.Errorf("Fill(%q, %q) = %q; want %q", tt.dst, tt.src, r, tt.want)
  351. }
  352. })
  353. }
  354. }
  355. func TestNameStringAllocs(t *testing.T) {
  356. name := ParseNameFill("example.com/ns/mistral:latest+Q4_0", "")
  357. allocs := testing.AllocsPerRun(1000, func() {
  358. keep(name.String())
  359. })
  360. if allocs > 1 {
  361. t.Errorf("String allocs = %v; want 0", allocs)
  362. }
  363. }
  364. func ExampleFill() {
  365. defaults := ParseNameFill("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0", "")
  366. r := Fill(ParseNameFill("mistral", ""), defaults)
  367. fmt.Println(r)
  368. // Output:
  369. // registry.ollama.com/library/mistral:latest+Q4_0
  370. }
  371. func ExampleName_MapHash() {
  372. m := map[uint64]bool{}
  373. // key 1
  374. m[ParseNameFill("mistral:latest+q4", "").MapHash()] = true
  375. m[ParseNameFill("miSTRal:latest+Q4", "").MapHash()] = true
  376. m[ParseNameFill("mistral:LATest+Q4", "").MapHash()] = true
  377. // key 2
  378. m[ParseNameFill("mistral:LATest", "").MapHash()] = true
  379. fmt.Println(len(m))
  380. // Output:
  381. // 2
  382. }
  383. func ExampleName_CompareFold_sort() {
  384. names := []Name{
  385. ParseNameFill("mistral:latest", ""),
  386. ParseNameFill("mistRal:7b+q4", ""),
  387. ParseNameFill("MIstral:7b", ""),
  388. }
  389. slices.SortFunc(names, Name.CompareFold)
  390. for _, n := range names {
  391. fmt.Println(n)
  392. }
  393. // Output:
  394. // MIstral:7b
  395. // mistRal:7b+q4
  396. // mistral:latest
  397. }
  398. func ExampleName_completeAndResolved() {
  399. for _, s := range []string{
  400. "x/y/z:latest+q4_0@sha123-1",
  401. "x/y/z:latest+q4_0",
  402. "@sha123-1",
  403. } {
  404. name := ParseNameFill(s, "")
  405. fmt.Printf("complete:%v resolved:%v digest:%s\n", name.IsComplete(), name.IsResolved(), name.Digest())
  406. }
  407. // Output:
  408. // complete:true resolved:true digest:sha123-1
  409. // complete:true resolved:false digest:
  410. // complete:false resolved:true digest:sha123-1
  411. }
  412. func ExampleName_DisplayShortest() {
  413. name := ParseNameFill("example.com/jmorganca/mistral:latest+Q4_0", "")
  414. fmt.Println(name.DisplayShortest("example.com/jmorganca/_:latest"))
  415. fmt.Println(name.DisplayShortest("example.com/_/_:latest"))
  416. fmt.Println(name.DisplayShortest("example.com/_/_:_"))
  417. fmt.Println(name.DisplayShortest("_/_/_:_"))
  418. // Default
  419. name = ParseNameFill("registry.ollama.ai/library/mistral:latest+Q4_0", "")
  420. fmt.Println(name.DisplayShortest(""))
  421. // Output:
  422. // mistral
  423. // jmorganca/mistral
  424. // jmorganca/mistral:latest
  425. // example.com/jmorganca/mistral:latest
  426. // mistral
  427. }
  428. func keep[T any](v T) T { return v }