routes_test.go 20 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754
  1. package server
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/binary"
  6. "encoding/json"
  7. "fmt"
  8. "io"
  9. "io/fs"
  10. "math"
  11. "math/rand/v2"
  12. "net"
  13. "net/http"
  14. "net/http/httptest"
  15. "os"
  16. "path/filepath"
  17. "sort"
  18. "strings"
  19. "testing"
  20. "unicode"
  21. "github.com/ollama/ollama/api"
  22. "github.com/ollama/ollama/fs/ggml"
  23. "github.com/ollama/ollama/openai"
  24. "github.com/ollama/ollama/server/internal/cache/blob"
  25. "github.com/ollama/ollama/server/internal/client/ollama"
  26. "github.com/ollama/ollama/types/model"
  27. "github.com/ollama/ollama/version"
  28. )
  // createTestFile writes a minimal GGUF-style file into a per-test temporary
// directory, links it into $OLLAMA_MODELS/blobs under a name derived from its
// digest, and returns the file's path together with that digest. The test
// fails immediately if OLLAMA_MODELS is unset or any step errors.
//
// The file holds, little-endian: the magic "GGUF", uint32(3), and two uint64
// zeros (presumably the version, tensor count and metadata key/value count of
// a GGUF v3 header — confirm against the GGUF spec if this matters).
func createTestFile(t *testing.T, name string) (string, string) {
	t.Helper()

	modelDir := os.Getenv("OLLAMA_MODELS")
	if modelDir == "" {
		t.Fatalf("OLLAMA_MODELS not specified")
	}

	f, err := os.CreateTemp(t.TempDir(), name)
	if err != nil {
		t.Fatalf("failed to create temp file: %v", err)
	}
	// Safety net for the t.Fatal paths below; the explicit Close further down
	// is the one whose error is actually checked.
	defer f.Close()

	header := []any{
		[]byte("GGUF"),
		uint32(3),
		uint64(0),
		uint64(0),
	}
	for _, field := range header {
		if err := binary.Write(f, binary.LittleEndian, field); err != nil {
			t.Fatalf("failed to write to file: %v", err)
		}
	}

	// Rewind so the digest covers everything written above.
	if _, err := f.Seek(0, io.SeekStart); err != nil {
		t.Fatal(err)
	}
	digest, _ := GetSHA256Digest(f)
	if err := f.Close(); err != nil {
		t.Fatal(err)
	}

	// Blob files are named "sha256-<hex>" while the digest reads "sha256:<hex>".
	blobName := "sha256-" + strings.TrimPrefix(digest, "sha256:")
	if err := createLink(f.Name(), filepath.Join(modelDir, "blobs", blobName)); err != nil {
		t.Fatal(err)
	}

	return f.Name(), digest
}
  // equalStringSlices reports whether a and b have the same length and hold the
// same string at every index. A nil slice and an empty slice compare equal.
func equalStringSlices(a, b []string) bool {
	same := len(a) == len(b)
	for i := 0; same && i < len(a); i++ {
		same = a[i] == b[i]
	}
	return same
}
  81. type panicTransport struct{}
  82. func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {
  83. panic("unexpected RoundTrip call")
  84. }
  85. var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}}
  // TestRoutes exercises the router returned by Server.GenerateRoutes through a
// real HTTP server (httptest). All cases share one OLLAMA_MODELS directory and
// run sequentially in the order listed, so a case may rely on models created
// by an earlier one (e.g. "openai list models with tags" expects the
// "test-model" created by "Tags Handler (yes tags)").
func TestRoutes(t *testing.T) {
	type testCase struct {
		Name     string
		Method   string
		Path     string
		Setup    func(t *testing.T, req *http.Request)
		Expected func(t *testing.T, resp *http.Response)
	}

	// createTestModel registers a model called name, backed by a minimal GGUF
	// blob and a fixed parameter set (seed, top_p, stop) that the
	// "Show Model Handler" case checks for.
	createTestModel := func(t *testing.T, name string) {
		t.Helper()

		_, digest := createTestFile(t, "ollama-model")

		fn := func(resp api.ProgressResponse) {
			t.Logf("Status: %s", resp.Status)
		}

		r := api.CreateRequest{
			Name:  name,
			Files: map[string]string{"test.gguf": digest},
			Parameters: map[string]any{
				"seed":  42,
				"top_p": 0.9,
				"stop":  []string{"foo", "bar"},
			},
		}

		modelName := model.ParseName(name)

		baseLayers, err := ggufLayers(digest, fn)
		if err != nil {
			t.Fatalf("failed to create model: %v", err)
		}

		if err := createModel(r, modelName, baseLayers, fn); err != nil {
			t.Fatal(err)
		}
	}

	testCases := []testCase{
		{
			Name:   "Version Handler",
			Method: http.MethodGet,
			Path:   "/api/version",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json; charset=utf-8" {
					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
				}
				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}
				expectedBody := fmt.Sprintf(`{"version":"%s"}`, version.Version)
				if string(body) != expectedBody {
					t.Errorf("expected body %s, got %s", expectedBody, string(body))
				}
			},
		},
		{
			Name:   "Tags Handler (no tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json; charset=utf-8" {
					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
				}
				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				var modelList api.ListResponse

				err = json.Unmarshal(body, &modelList)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				// The list must be present and empty (not JSON null).
				if modelList.Models == nil || len(modelList.Models) != 0 {
					t.Errorf("expected empty model list, got %v", modelList.Models)
				}
			},
		},
		{
			Name:   "openai empty list",
			Method: http.MethodGet,
			Path:   "/v1/models",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json" {
					t.Errorf("expected content type application/json, got %s", contentType)
				}
				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				var modelList openai.ListCompletion
				err = json.Unmarshal(body, &modelList)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				if modelList.Object != "list" || len(modelList.Data) != 0 {
					t.Errorf("expected empty model list, got %v", modelList.Data)
				}
			},
		},
		{
			Name:   "Tags Handler (yes tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "test-model")
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json; charset=utf-8" {
					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
				}
				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				if strings.Contains(string(body), "expires_at") {
					t.Errorf("response body should not contain 'expires_at'")
				}

				var modelList api.ListResponse
				err = json.Unmarshal(body, &modelList)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				if len(modelList.Models) != 1 || modelList.Models[0].Name != "test-model:latest" {
					t.Errorf("expected model 'test-model:latest', got %v", modelList.Models)
				}
			},
		},
		{
			Name:   "Delete Model Handler",
			Method: http.MethodDelete,
			Path:   "/api/delete",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "model_to_delete")

				deleteReq := api.DeleteRequest{
					Name: "model_to_delete",
				}
				jsonData, err := json.Marshal(deleteReq)
				if err != nil {
					t.Fatalf("failed to marshal delete request: %v", err)
				}

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				if resp.StatusCode != http.StatusOK {
					t.Errorf("expected status code 200, got %d", resp.StatusCode)
				}

				// Verify the model was deleted. This must look up the same
				// name Setup created and deleted; checking any other name
				// would pass whether or not the delete worked. Because
				// os.IsNotExist(nil) is false, a model that still exists
				// also fails this check.
				_, err := GetModel("model_to_delete")
				if !os.IsNotExist(err) {
					t.Errorf("expected model to be deleted, got error %v", err)
				}
			},
		},
		{
			Name:   "Delete Non-existent Model",
			Method: http.MethodDelete,
			Path:   "/api/delete",
			Setup: func(t *testing.T, req *http.Request) {
				deleteReq := api.DeleteRequest{
					Name: "non_existent_model",
				}
				jsonData, err := json.Marshal(deleteReq)
				if err != nil {
					t.Fatalf("failed to marshal delete request: %v", err)
				}

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				if resp.StatusCode != http.StatusNotFound {
					t.Errorf("expected status code 404, got %d", resp.StatusCode)
				}
				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				var errorResp map[string]string
				err = json.Unmarshal(body, &errorResp)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				if !strings.Contains(errorResp["error"], "not found") {
					t.Errorf("expected error message to contain 'not found', got %s", errorResp["error"])
				}
			},
		},
		{
			Name:   "openai list models with tags",
			Method: http.MethodGet,
			Path:   "/v1/models",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json" {
					t.Errorf("expected content type application/json, got %s", contentType)
				}
				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				var modelList openai.ListCompletion
				err = json.Unmarshal(body, &modelList)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				if len(modelList.Data) != 1 || modelList.Data[0].Id != "test-model:latest" || modelList.Data[0].OwnedBy != "library" {
					t.Errorf("expected model 'test-model:latest' owned by 'library', got %v", modelList.Data)
				}
			},
		},
		{
			Name:   "Create Model Handler",
			Method: http.MethodPost,
			Path:   "/api/create",
			Setup: func(t *testing.T, req *http.Request) {
				_, digest := createTestFile(t, "ollama-model")
				stream := false
				createReq := api.CreateRequest{
					Name:   "t-bone",
					Files:  map[string]string{"test.gguf": digest},
					Stream: &stream,
				}
				jsonData, err := json.Marshal(createReq)
				if err != nil {
					t.Fatalf("failed to marshal create request: %v", err)
				}

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json" {
					t.Errorf("expected content type application/json, got %s", contentType)
				}
				_, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}
				if resp.StatusCode != http.StatusOK {
					t.Errorf("expected status code 200, got %d", resp.StatusCode)
				}

				// m, not model, so the imported model package stays visible.
				m, err := GetModel("t-bone")
				if err != nil {
					t.Fatalf("failed to get model: %v", err)
				}
				if m.ShortName != "t-bone:latest" {
					t.Errorf("expected model name 't-bone:latest', got %s", m.ShortName)
				}
			},
		},
		{
			Name:   "Copy Model Handler",
			Method: http.MethodPost,
			Path:   "/api/copy",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "hamshank")
				copyReq := api.CopyRequest{
					Source:      "hamshank",
					Destination: "beefsteak",
				}
				jsonData, err := json.Marshal(copyReq)
				if err != nil {
					t.Fatalf("failed to marshal copy request: %v", err)
				}

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				m, err := GetModel("beefsteak")
				if err != nil {
					t.Fatalf("failed to get model: %v", err)
				}
				if m.ShortName != "beefsteak:latest" {
					t.Errorf("expected model name 'beefsteak:latest', got %s", m.ShortName)
				}
			},
		},
		{
			Name:   "Show Model Handler",
			Method: http.MethodPost,
			Path:   "/api/show",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "show-model")
				showReq := api.ShowRequest{Model: "show-model"}
				jsonData, err := json.Marshal(showReq)
				if err != nil {
					t.Fatalf("failed to marshal show request: %v", err)
				}
				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json; charset=utf-8" {
					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
				}

				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				var showResp api.ShowResponse
				err = json.Unmarshal(body, &showResp)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				// Parameters is a newline-separated listing; collapse runs of
				// whitespace in each line and sort so the comparison does not
				// depend on column alignment or ordering.
				var params []string
				paramsSplit := strings.Split(showResp.Parameters, "\n")
				for _, p := range paramsSplit {
					params = append(params, strings.Join(strings.Fields(p), " "))
				}
				sort.Strings(params)
				expectedParams := []string{
					"seed 42",
					"stop \"bar\"",
					"stop \"foo\"",
					"top_p 0.9",
				}
				if !equalStringSlices(params, expectedParams) {
					t.Errorf("expected parameters %v, got %v", expectedParams, params)
				}

				// JSON numbers decode to float64 inside map[string]any.
				paramCount, ok := showResp.ModelInfo["general.parameter_count"].(float64)
				if !ok {
					t.Fatalf("expected parameter count to be a float64, got %T", showResp.ModelInfo["general.parameter_count"])
				}
				if math.Abs(paramCount) > 1e-9 {
					t.Errorf("expected parameter count to be 0, got %f", paramCount)
				}
			},
		},
		{
			Name: "openai retrieve model handler",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "show-model")
			},
			Method: http.MethodGet,
			Path:   "/v1/models/show-model",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json" {
					t.Errorf("expected content type application/json, got %s", contentType)
				}

				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				var retrieveResp api.RetrieveModelResponse
				err = json.Unmarshal(body, &retrieveResp)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				if retrieveResp.Id != "show-model" || retrieveResp.OwnedBy != "library" {
					t.Errorf("expected model 'show-model' owned by 'library', got %v", retrieveResp)
				}
			},
		},
	}

	modelsDir := t.TempDir()
	t.Setenv("OLLAMA_MODELS", modelsDir)

	c, err := blob.Open(modelsDir)
	if err != nil {
		t.Fatalf("failed to open models dir: %v", err)
	}

	rc := &ollama.Registry{
		// This is a temporary measure to allow us to move forward,
		// surfacing any code contacting ollama.com we do not intend
		// to.
		//
		// Currently, this only handles DELETE /api/delete, which
		// should not make any contact with the ollama.com registry, so
		// be clear about that.
		//
		// Tests that do need to contact the registry here, will be
		// consumed into our new server/api code packages and removed
		// from here.
		HTTPClient: panicOnRoundTrip,
	}

	s := &Server{}
	router, err := s.GenerateRoutes(c, rc)
	if err != nil {
		t.Fatalf("failed to generate routes: %v", err)
	}

	httpSrv := httptest.NewServer(router)
	t.Cleanup(httpSrv.Close)

	for _, tc := range testCases {
		t.Run(tc.Name, func(t *testing.T) {
			u := httpSrv.URL + tc.Path
			req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
			if err != nil {
				t.Fatalf("failed to create request: %v", err)
			}

			if tc.Setup != nil {
				tc.Setup(t, req)
			}

			resp, err := httpSrv.Client().Do(req)
			if err != nil {
				t.Fatalf("failed to do request: %v", err)
			}
			defer resp.Body.Close()

			if tc.Expected != nil {
				tc.Expected(t, resp)
			}
		})
	}
}
  // casingShuffle returns s with each rune independently upper- or lower-cased
// with equal probability, drawing one rand.N(2) value per rune in order.
// Runes without case are left unchanged.
func casingShuffle(s string) string {
	return strings.Map(func(r rune) rune {
		if rand.N(2) != 0 {
			return unicode.ToLower(r)
		}
		return unicode.ToUpper(r)
	}, s)
}
  // TestManifestCaseSensitivity checks that model names are matched without
// regard to letter case. Each create, pull, and copy request below uses a
// previously unused random casing of "example/namespace/model:tag"; after each
// one the manifests directory must still hold exactly one manifest, stored
// under the casing used by the first create. A final push with yet another
// casing must report success.
func TestManifestCaseSensitivity(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	// r stands in for the remote registry: connections made through
	// testMakeRequestDialContext (set below) are dialed to it. It answers every
	// request with 200 and an empty JSON object.
	r := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		io.WriteString(w, `{}`) //nolint:errcheck
	}))
	defer r.Close()

	// name returns a random casing of fqmn that it has not returned before.
	nameUsed := make(map[string]bool)
	name := func() string {
		const fqmn = "example/namespace/model:tag"
		for {
			v := casingShuffle(fqmn)
			if nameUsed[v] {
				continue
			}
			nameUsed[v] = true
			return v
		}
	}

	wantStableName := name()

	t.Logf("stable name: %s", wantStableName)

	// checkManifestList tests that there is strictly one manifest in the
	// models directory, and that the manifest is for the model under test.
	checkManifestList := func() {
		t.Helper()

		mandir := filepath.Join(os.Getenv("OLLAMA_MODELS"), "manifests/")
		var entries []string

		t.Logf("dir entries:")
		fsys := os.DirFS(mandir)
		err := fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error {
			if err != nil {
				return err
			}
			t.Logf(" %s", fs.FormatDirEntry(d))
			if d.IsDir() {
				return nil
			}
			// WalkDir over os.DirFS(mandir) already reports slash-separated
			// paths relative to mandir, so no prefix needs stripping.
			entries = append(entries, path)
			return nil
		})
		if err != nil {
			t.Fatalf("failed to walk directory: %v", err)
		}

		if len(entries) != 1 {
			t.Errorf("len(got) = %d, want 1", len(entries))
			return // do not use Fatal so following steps run
		}

		g := entries[0] // raw path
		g = filepath.ToSlash(g)
		// Filepath uses OS separators; normalize before comparing.
		w := model.ParseName(wantStableName).Filepath()
		w = filepath.ToSlash(w)
		if g != w {
			t.Errorf("\ngot: %s\nwant: %s", g, w)
		}
	}

	checkOK := func(w *httptest.ResponseRecorder) {
		t.Helper()
		if w.Code != http.StatusOK {
			t.Errorf("code = %d, want 200", w.Code)
			t.Logf("body: %s", w.Body.String())
		}
	}

	var s Server

	testMakeRequestDialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
		var d net.Dialer
		return d.DialContext(ctx, "tcp", r.Listener.Addr().String())
	}
	t.Cleanup(func() { testMakeRequestDialContext = nil })

	// stream is not declared in this function; presumably a package-level
	// bool defined in another test file of this package.
	t.Logf("creating")
	_, digest := createBinFile(t, nil, nil)
	checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
		// Start with the stable name, and later use a case-shuffled
		// version.
		Name:   wantStableName,
		Files:  map[string]string{"test.gguf": digest},
		Stream: &stream,
	}))
	checkManifestList()

	t.Logf("creating (again)")
	checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
		Name:   name(),
		Files:  map[string]string{"test.gguf": digest},
		Stream: &stream,
	}))
	checkManifestList()

	t.Logf("pulling")
	checkOK(createRequest(t, s.PullHandler, api.PullRequest{
		Name:     name(),
		Stream:   &stream,
		Insecure: true,
	}))
	checkManifestList()

	t.Logf("copying")
	checkOK(createRequest(t, s.CopyHandler, api.CopyRequest{
		Source:      name(),
		Destination: name(),
	}))
	checkManifestList()

	t.Logf("pushing")
	rr := createRequest(t, s.PushHandler, api.PushRequest{
		Model:    name(),
		Insecure: true,
		Username: "alice",
		Password: "x",
	})
	checkOK(rr)
	if !strings.Contains(rr.Body.String(), `"status":"success"`) {
		t.Errorf("got = %q, want success", rr.Body.String())
	}
}
  // TestShow creates a model from two GGUF files, one with architecture "test"
// and one typed "projector" with architecture "clip". It then checks that
// ShowHandler reports the first under ModelInfo and the second under
// ProjectorInfo.
func TestShow(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	var s Server

	_, digest1 := createBinFile(t, ggml.KV{"general.architecture": "test"}, nil)
	_, digest2 := createBinFile(t, ggml.KV{"general.type": "projector", "general.architecture": "clip"}, nil)

	created := createRequest(t, s.CreateHandler, api.CreateRequest{
		Name:  "show-model",
		Files: map[string]string{"model.gguf": digest1, "projector.gguf": digest2},
	})
	// Fail fast on a broken create instead of on a confusing show failure.
	if created.Code != http.StatusOK {
		t.Fatalf("create: expected status code 200, actual %d: %s", created.Code, created.Body.String())
	}

	w := createRequest(t, s.ShowHandler, api.ShowRequest{
		Model: "show-model",
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", w.Code)
	}

	var resp api.ShowResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatal(err)
	}

	if got := resp.ModelInfo["general.architecture"]; got != "test" {
		t.Fatalf("expected model architecture to be 'test', but got %v", got)
	}

	if got := resp.ProjectorInfo["general.architecture"]; got != "clip" {
		t.Fatalf("expected projector architecture to be 'clip', but got %v", got)
	}
}
  // TestNormalize checks that, for each input, the vector returned by normalize
// either has unit L2 norm (squared sum within 1e-6 of 1) or is all zeros. All
// zeros is the only acceptable result for the zero vector, which cannot be
// scaled to unit length.
func TestNormalize(t *testing.T) {
	type testCase struct {
		input []float32
	}

	testCases := []testCase{
		{input: []float32{1}},
		{input: []float32{0, 1, 2, 3}},
		{input: []float32{0.1, 0.2, 0.3}},
		{input: []float32{-0.1, 0.2, 0.3, -0.4}},
		{input: []float32{0, 0, 0}},
	}

	// isNormalized reports whether vec has unit L2 norm within 1e-6, or is
	// exactly zero. The tolerance test is written as `<=` so that a NaN sum
	// fails it. With the inverted form `math.Abs(sum-1) > 1e-6`, NaN compares
	// false and a NaN-filled vector (e.g. 0/0 from the zero vector) would be
	// accepted.
	isNormalized := func(vec []float32) bool {
		var sum float64
		for _, v := range vec {
			sum += float64(v * v)
		}
		if math.Abs(sum-1) <= 1e-6 {
			return true
		}
		return sum == 0
	}

	for _, tc := range testCases {
		// Format the input before calling normalize, in case normalize reuses
		// its argument's storage for the result.
		input := fmt.Sprint(tc.input)
		t.Run(input, func(t *testing.T) {
			normalized := normalize(tc.input)
			if !isNormalized(normalized) {
				t.Errorf("Vector %v is not normalized: got %v", input, normalized)
			}
		})
	}
}