routes_test.go 20 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748
  1. package server
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/binary"
  6. "encoding/json"
  7. "fmt"
  8. "io"
  9. "io/fs"
  10. "math"
  11. "math/rand/v2"
  12. "net"
  13. "net/http"
  14. "net/http/httptest"
  15. "os"
  16. "path/filepath"
  17. "sort"
  18. "strings"
  19. "testing"
  20. "unicode"
  21. "github.com/ollama/ollama/api"
  22. "github.com/ollama/ollama/fs/ggml"
  23. "github.com/ollama/ollama/openai"
  24. "github.com/ollama/ollama/server/internal/client/ollama"
  25. "github.com/ollama/ollama/types/model"
  26. "github.com/ollama/ollama/version"
  27. )
  28. func createTestFile(t *testing.T, name string) (string, string) {
  29. t.Helper()
  30. modelDir := os.Getenv("OLLAMA_MODELS")
  31. if modelDir == "" {
  32. t.Fatalf("OLLAMA_MODELS not specified")
  33. }
  34. f, err := os.CreateTemp(t.TempDir(), name)
  35. if err != nil {
  36. t.Fatalf("failed to create temp file: %v", err)
  37. }
  38. defer f.Close()
  39. err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
  40. if err != nil {
  41. t.Fatalf("failed to write to file: %v", err)
  42. }
  43. err = binary.Write(f, binary.LittleEndian, uint32(3))
  44. if err != nil {
  45. t.Fatalf("failed to write to file: %v", err)
  46. }
  47. err = binary.Write(f, binary.LittleEndian, uint64(0))
  48. if err != nil {
  49. t.Fatalf("failed to write to file: %v", err)
  50. }
  51. err = binary.Write(f, binary.LittleEndian, uint64(0))
  52. if err != nil {
  53. t.Fatalf("failed to write to file: %v", err)
  54. }
  55. // Calculate sha256 sum of file
  56. if _, err := f.Seek(0, 0); err != nil {
  57. t.Fatal(err)
  58. }
  59. digest, _ := GetSHA256Digest(f)
  60. if err := f.Close(); err != nil {
  61. t.Fatal(err)
  62. }
  63. if err := createLink(f.Name(), filepath.Join(modelDir, "blobs", fmt.Sprintf("sha256-%s", strings.TrimPrefix(digest, "sha256:")))); err != nil {
  64. t.Fatal(err)
  65. }
  66. return f.Name(), digest
  67. }
  // equalStringSlices reports whether a and b have the same length and hold the
  // same string at every index. A nil slice and an empty slice compare equal.
  func equalStringSlices(a, b []string) bool {
  	mismatch := len(a) != len(b)
  	// Stop at the first differing element; skipped entirely on length mismatch.
  	for i := 0; !mismatch && i < len(a); i++ {
  		mismatch = a[i] != b[i]
  	}
  	return !mismatch
  }
  // panicTransport is an http.RoundTripper whose RoundTrip always panics.
  type panicTransport struct{}

  // RoundTrip panics unconditionally; it never returns a response or an error.
  func (*panicTransport) RoundTrip(*http.Request) (*http.Response, error) {
  	panic("unexpected RoundTrip call")
  }

  // panicOnRoundTrip is an HTTP client that panics as soon as it is asked to
  // send a request, making any unintended outbound call fail loudly.
  var panicOnRoundTrip = &http.Client{Transport: new(panicTransport)}
  85. func TestRoutes(t *testing.T) {
  86. type testCase struct {
  87. Name string
  88. Method string
  89. Path string
  90. Setup func(t *testing.T, req *http.Request)
  91. Expected func(t *testing.T, resp *http.Response)
  92. }
  93. createTestModel := func(t *testing.T, name string) {
  94. t.Helper()
  95. _, digest := createTestFile(t, "ollama-model")
  96. fn := func(resp api.ProgressResponse) {
  97. t.Logf("Status: %s", resp.Status)
  98. }
  99. r := api.CreateRequest{
  100. Name: name,
  101. Files: map[string]string{"test.gguf": digest},
  102. Parameters: map[string]any{
  103. "seed": 42,
  104. "top_p": 0.9,
  105. "stop": []string{"foo", "bar"},
  106. },
  107. }
  108. modelName := model.ParseName(name)
  109. baseLayers, err := ggufLayers(digest, fn)
  110. if err != nil {
  111. t.Fatalf("failed to create model: %v", err)
  112. }
  113. if err := createModel(r, modelName, baseLayers, fn); err != nil {
  114. t.Fatal(err)
  115. }
  116. }
  117. testCases := []testCase{
  118. {
  119. Name: "Version Handler",
  120. Method: http.MethodGet,
  121. Path: "/api/version",
  122. Setup: func(t *testing.T, req *http.Request) {
  123. },
  124. Expected: func(t *testing.T, resp *http.Response) {
  125. contentType := resp.Header.Get("Content-Type")
  126. if contentType != "application/json; charset=utf-8" {
  127. t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
  128. }
  129. body, err := io.ReadAll(resp.Body)
  130. if err != nil {
  131. t.Fatalf("failed to read response body: %v", err)
  132. }
  133. expectedBody := fmt.Sprintf(`{"version":"%s"}`, version.Version)
  134. if string(body) != expectedBody {
  135. t.Errorf("expected body %s, got %s", expectedBody, string(body))
  136. }
  137. },
  138. },
  139. {
  140. Name: "Tags Handler (no tags)",
  141. Method: http.MethodGet,
  142. Path: "/api/tags",
  143. Expected: func(t *testing.T, resp *http.Response) {
  144. contentType := resp.Header.Get("Content-Type")
  145. if contentType != "application/json; charset=utf-8" {
  146. t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
  147. }
  148. body, err := io.ReadAll(resp.Body)
  149. if err != nil {
  150. t.Fatalf("failed to read response body: %v", err)
  151. }
  152. var modelList api.ListResponse
  153. err = json.Unmarshal(body, &modelList)
  154. if err != nil {
  155. t.Fatalf("failed to unmarshal response body: %v", err)
  156. }
  157. if modelList.Models == nil || len(modelList.Models) != 0 {
  158. t.Errorf("expected empty model list, got %v", modelList.Models)
  159. }
  160. },
  161. },
  162. {
  163. Name: "openai empty list",
  164. Method: http.MethodGet,
  165. Path: "/v1/models",
  166. Expected: func(t *testing.T, resp *http.Response) {
  167. contentType := resp.Header.Get("Content-Type")
  168. if contentType != "application/json" {
  169. t.Errorf("expected content type application/json, got %s", contentType)
  170. }
  171. body, err := io.ReadAll(resp.Body)
  172. if err != nil {
  173. t.Fatalf("failed to read response body: %v", err)
  174. }
  175. var modelList openai.ListCompletion
  176. err = json.Unmarshal(body, &modelList)
  177. if err != nil {
  178. t.Fatalf("failed to unmarshal response body: %v", err)
  179. }
  180. if modelList.Object != "list" || len(modelList.Data) != 0 {
  181. t.Errorf("expected empty model list, got %v", modelList.Data)
  182. }
  183. },
  184. },
  185. {
  186. Name: "Tags Handler (yes tags)",
  187. Method: http.MethodGet,
  188. Path: "/api/tags",
  189. Setup: func(t *testing.T, req *http.Request) {
  190. createTestModel(t, "test-model")
  191. },
  192. Expected: func(t *testing.T, resp *http.Response) {
  193. contentType := resp.Header.Get("Content-Type")
  194. if contentType != "application/json; charset=utf-8" {
  195. t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
  196. }
  197. body, err := io.ReadAll(resp.Body)
  198. if err != nil {
  199. t.Fatalf("failed to read response body: %v", err)
  200. }
  201. if strings.Contains(string(body), "expires_at") {
  202. t.Errorf("response body should not contain 'expires_at'")
  203. }
  204. var modelList api.ListResponse
  205. err = json.Unmarshal(body, &modelList)
  206. if err != nil {
  207. t.Fatalf("failed to unmarshal response body: %v", err)
  208. }
  209. if len(modelList.Models) != 1 || modelList.Models[0].Name != "test-model:latest" {
  210. t.Errorf("expected model 'test-model:latest', got %v", modelList.Models)
  211. }
  212. },
  213. },
  214. {
  215. Name: "Delete Model Handler",
  216. Method: http.MethodDelete,
  217. Path: "/api/delete",
  218. Setup: func(t *testing.T, req *http.Request) {
  219. createTestModel(t, "model_to_delete")
  220. deleteReq := api.DeleteRequest{
  221. Name: "model_to_delete",
  222. }
  223. jsonData, err := json.Marshal(deleteReq)
  224. if err != nil {
  225. t.Fatalf("failed to marshal delete request: %v", err)
  226. }
  227. req.Body = io.NopCloser(bytes.NewReader(jsonData))
  228. },
  229. Expected: func(t *testing.T, resp *http.Response) {
  230. if resp.StatusCode != http.StatusOK {
  231. t.Errorf("expected status code 200, got %d", resp.StatusCode)
  232. }
  233. // Verify the model was deleted
  234. _, err := GetModel("model-to-delete")
  235. if err == nil || !os.IsNotExist(err) {
  236. t.Errorf("expected model to be deleted, got error %v", err)
  237. }
  238. },
  239. },
  240. {
  241. Name: "Delete Non-existent Model",
  242. Method: http.MethodDelete,
  243. Path: "/api/delete",
  244. Setup: func(t *testing.T, req *http.Request) {
  245. deleteReq := api.DeleteRequest{
  246. Name: "non_existent_model",
  247. }
  248. jsonData, err := json.Marshal(deleteReq)
  249. if err != nil {
  250. t.Fatalf("failed to marshal delete request: %v", err)
  251. }
  252. req.Body = io.NopCloser(bytes.NewReader(jsonData))
  253. },
  254. Expected: func(t *testing.T, resp *http.Response) {
  255. if resp.StatusCode != http.StatusNotFound {
  256. t.Errorf("expected status code 404, got %d", resp.StatusCode)
  257. }
  258. body, err := io.ReadAll(resp.Body)
  259. if err != nil {
  260. t.Fatalf("failed to read response body: %v", err)
  261. }
  262. var errorResp map[string]string
  263. err = json.Unmarshal(body, &errorResp)
  264. if err != nil {
  265. t.Fatalf("failed to unmarshal response body: %v", err)
  266. }
  267. if !strings.Contains(errorResp["error"], "not found") {
  268. t.Errorf("expected error message to contain 'not found', got %s", errorResp["error"])
  269. }
  270. },
  271. },
  272. {
  273. Name: "openai list models with tags",
  274. Method: http.MethodGet,
  275. Path: "/v1/models",
  276. Expected: func(t *testing.T, resp *http.Response) {
  277. contentType := resp.Header.Get("Content-Type")
  278. if contentType != "application/json" {
  279. t.Errorf("expected content type application/json, got %s", contentType)
  280. }
  281. body, err := io.ReadAll(resp.Body)
  282. if err != nil {
  283. t.Fatalf("failed to read response body: %v", err)
  284. }
  285. var modelList openai.ListCompletion
  286. err = json.Unmarshal(body, &modelList)
  287. if err != nil {
  288. t.Fatalf("failed to unmarshal response body: %v", err)
  289. }
  290. if len(modelList.Data) != 1 || modelList.Data[0].Id != "test-model:latest" || modelList.Data[0].OwnedBy != "library" {
  291. t.Errorf("expected model 'test-model:latest' owned by 'library', got %v", modelList.Data)
  292. }
  293. },
  294. },
  295. {
  296. Name: "Create Model Handler",
  297. Method: http.MethodPost,
  298. Path: "/api/create",
  299. Setup: func(t *testing.T, req *http.Request) {
  300. _, digest := createTestFile(t, "ollama-model")
  301. stream := false
  302. createReq := api.CreateRequest{
  303. Name: "t-bone",
  304. Files: map[string]string{"test.gguf": digest},
  305. Stream: &stream,
  306. }
  307. jsonData, err := json.Marshal(createReq)
  308. if err != nil {
  309. t.Fatalf("failed to marshal create request: %v", err)
  310. }
  311. req.Body = io.NopCloser(bytes.NewReader(jsonData))
  312. },
  313. Expected: func(t *testing.T, resp *http.Response) {
  314. contentType := resp.Header.Get("Content-Type")
  315. if contentType != "application/json" {
  316. t.Errorf("expected content type application/json, got %s", contentType)
  317. }
  318. _, err := io.ReadAll(resp.Body)
  319. if err != nil {
  320. t.Fatalf("failed to read response body: %v", err)
  321. }
  322. if resp.StatusCode != http.StatusOK { // Updated line
  323. t.Errorf("expected status code 200, got %d", resp.StatusCode)
  324. }
  325. model, err := GetModel("t-bone")
  326. if err != nil {
  327. t.Fatalf("failed to get model: %v", err)
  328. }
  329. if model.ShortName != "t-bone:latest" {
  330. t.Errorf("expected model name 't-bone:latest', got %s", model.ShortName)
  331. }
  332. },
  333. },
  334. {
  335. Name: "Copy Model Handler",
  336. Method: http.MethodPost,
  337. Path: "/api/copy",
  338. Setup: func(t *testing.T, req *http.Request) {
  339. createTestModel(t, "hamshank")
  340. copyReq := api.CopyRequest{
  341. Source: "hamshank",
  342. Destination: "beefsteak",
  343. }
  344. jsonData, err := json.Marshal(copyReq)
  345. if err != nil {
  346. t.Fatalf("failed to marshal copy request: %v", err)
  347. }
  348. req.Body = io.NopCloser(bytes.NewReader(jsonData))
  349. },
  350. Expected: func(t *testing.T, resp *http.Response) {
  351. model, err := GetModel("beefsteak")
  352. if err != nil {
  353. t.Fatalf("failed to get model: %v", err)
  354. }
  355. if model.ShortName != "beefsteak:latest" {
  356. t.Errorf("expected model name 'beefsteak:latest', got %s", model.ShortName)
  357. }
  358. },
  359. },
  360. {
  361. Name: "Show Model Handler",
  362. Method: http.MethodPost,
  363. Path: "/api/show",
  364. Setup: func(t *testing.T, req *http.Request) {
  365. createTestModel(t, "show-model")
  366. showReq := api.ShowRequest{Model: "show-model"}
  367. jsonData, err := json.Marshal(showReq)
  368. if err != nil {
  369. t.Fatalf("failed to marshal show request: %v", err)
  370. }
  371. req.Body = io.NopCloser(bytes.NewReader(jsonData))
  372. },
  373. Expected: func(t *testing.T, resp *http.Response) {
  374. contentType := resp.Header.Get("Content-Type")
  375. if contentType != "application/json; charset=utf-8" {
  376. t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
  377. }
  378. body, err := io.ReadAll(resp.Body)
  379. if err != nil {
  380. t.Fatalf("failed to read response body: %v", err)
  381. }
  382. var showResp api.ShowResponse
  383. err = json.Unmarshal(body, &showResp)
  384. if err != nil {
  385. t.Fatalf("failed to unmarshal response body: %v", err)
  386. }
  387. var params []string
  388. paramsSplit := strings.Split(showResp.Parameters, "\n")
  389. for _, p := range paramsSplit {
  390. params = append(params, strings.Join(strings.Fields(p), " "))
  391. }
  392. sort.Strings(params)
  393. expectedParams := []string{
  394. "seed 42",
  395. "stop \"bar\"",
  396. "stop \"foo\"",
  397. "top_p 0.9",
  398. }
  399. if !equalStringSlices(params, expectedParams) {
  400. t.Errorf("expected parameters %v, got %v", expectedParams, params)
  401. }
  402. paramCount, ok := showResp.ModelInfo["general.parameter_count"].(float64)
  403. if !ok {
  404. t.Fatalf("expected parameter count to be a float64, got %T", showResp.ModelInfo["general.parameter_count"])
  405. }
  406. if math.Abs(paramCount) > 1e-9 {
  407. t.Errorf("expected parameter count to be 0, got %f", paramCount)
  408. }
  409. },
  410. },
  411. {
  412. Name: "openai retrieve model handler",
  413. Setup: func(t *testing.T, req *http.Request) {
  414. createTestModel(t, "show-model")
  415. },
  416. Method: http.MethodGet,
  417. Path: "/v1/models/show-model",
  418. Expected: func(t *testing.T, resp *http.Response) {
  419. contentType := resp.Header.Get("Content-Type")
  420. if contentType != "application/json" {
  421. t.Errorf("expected content type application/json, got %s", contentType)
  422. }
  423. body, err := io.ReadAll(resp.Body)
  424. if err != nil {
  425. t.Fatalf("failed to read response body: %v", err)
  426. }
  427. var retrieveResp api.RetrieveModelResponse
  428. err = json.Unmarshal(body, &retrieveResp)
  429. if err != nil {
  430. t.Fatalf("failed to unmarshal response body: %v", err)
  431. }
  432. if retrieveResp.Id != "show-model" || retrieveResp.OwnedBy != "library" {
  433. t.Errorf("expected model 'show-model' owned by 'library', got %v", retrieveResp)
  434. }
  435. },
  436. },
  437. }
  438. modelsDir := t.TempDir()
  439. t.Setenv("OLLAMA_MODELS", modelsDir)
  440. rc := &ollama.Registry{
  441. // This is a temporary measure to allow us to move forward,
  442. // surfacing any code contacting ollama.com we do not intended
  443. // to.
  444. //
  445. // Currently, this only handles DELETE /api/delete, which
  446. // should not make any contact with the ollama.com registry, so
  447. // be clear about that.
  448. //
  449. // Tests that do need to contact the registry here, will be
  450. // consumed into our new server/api code packages and removed
  451. // from here.
  452. HTTPClient: panicOnRoundTrip,
  453. }
  454. s := &Server{}
  455. router, err := s.GenerateRoutes(rc)
  456. if err != nil {
  457. t.Fatalf("failed to generate routes: %v", err)
  458. }
  459. httpSrv := httptest.NewServer(router)
  460. t.Cleanup(httpSrv.Close)
  461. for _, tc := range testCases {
  462. t.Run(tc.Name, func(t *testing.T) {
  463. u := httpSrv.URL + tc.Path
  464. req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
  465. if err != nil {
  466. t.Fatalf("failed to create request: %v", err)
  467. }
  468. if tc.Setup != nil {
  469. tc.Setup(t, req)
  470. }
  471. resp, err := httpSrv.Client().Do(req)
  472. if err != nil {
  473. t.Fatalf("failed to do request: %v", err)
  474. }
  475. defer resp.Body.Close()
  476. if tc.Expected != nil {
  477. tc.Expected(t, resp)
  478. }
  479. })
  480. }
  481. }
  482. func casingShuffle(s string) string {
  483. rr := []rune(s)
  484. for i := range rr {
  485. if rand.N(2) == 0 {
  486. rr[i] = unicode.ToUpper(rr[i])
  487. } else {
  488. rr[i] = unicode.ToLower(rr[i])
  489. }
  490. }
  491. return string(rr)
  492. }
  493. func TestManifestCaseSensitivity(t *testing.T) {
  494. t.Setenv("OLLAMA_MODELS", t.TempDir())
  495. r := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  496. w.WriteHeader(http.StatusOK)
  497. io.WriteString(w, `{}`) //nolint:errcheck
  498. }))
  499. defer r.Close()
  500. nameUsed := make(map[string]bool)
  501. name := func() string {
  502. const fqmn = "example/namespace/model:tag"
  503. for {
  504. v := casingShuffle(fqmn)
  505. if nameUsed[v] {
  506. continue
  507. }
  508. nameUsed[v] = true
  509. return v
  510. }
  511. }
  512. wantStableName := name()
  513. t.Logf("stable name: %s", wantStableName)
  514. // checkManifestList tests that there is strictly one manifest in the
  515. // models directory, and that the manifest is for the model under test.
  516. checkManifestList := func() {
  517. t.Helper()
  518. mandir := filepath.Join(os.Getenv("OLLAMA_MODELS"), "manifests/")
  519. var entries []string
  520. t.Logf("dir entries:")
  521. fsys := os.DirFS(mandir)
  522. err := fs.WalkDir(fsys, ".", func(path string, info fs.DirEntry, err error) error {
  523. if err != nil {
  524. return err
  525. }
  526. t.Logf(" %s", fs.FormatDirEntry(info))
  527. if info.IsDir() {
  528. return nil
  529. }
  530. path = strings.TrimPrefix(path, mandir)
  531. entries = append(entries, path)
  532. return nil
  533. })
  534. if err != nil {
  535. t.Fatalf("failed to walk directory: %v", err)
  536. }
  537. if len(entries) != 1 {
  538. t.Errorf("len(got) = %d, want 1", len(entries))
  539. return // do not use Fatal so following steps run
  540. }
  541. g := entries[0] // raw path
  542. g = filepath.ToSlash(g)
  543. w := model.ParseName(wantStableName).Filepath()
  544. w = filepath.ToSlash(w)
  545. if g != w {
  546. t.Errorf("\ngot: %s\nwant: %s", g, w)
  547. }
  548. }
  549. checkOK := func(w *httptest.ResponseRecorder) {
  550. t.Helper()
  551. if w.Code != http.StatusOK {
  552. t.Errorf("code = %d, want 200", w.Code)
  553. t.Logf("body: %s", w.Body.String())
  554. }
  555. }
  556. var s Server
  557. testMakeRequestDialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
  558. var d net.Dialer
  559. return d.DialContext(ctx, "tcp", r.Listener.Addr().String())
  560. }
  561. t.Cleanup(func() { testMakeRequestDialContext = nil })
  562. t.Logf("creating")
  563. _, digest := createBinFile(t, nil, nil)
  564. checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
  565. // Start with the stable name, and later use a case-shuffled
  566. // version.
  567. Name: wantStableName,
  568. Files: map[string]string{"test.gguf": digest},
  569. Stream: &stream,
  570. }))
  571. checkManifestList()
  572. t.Logf("creating (again)")
  573. checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
  574. Name: name(),
  575. Files: map[string]string{"test.gguf": digest},
  576. Stream: &stream,
  577. }))
  578. checkManifestList()
  579. t.Logf("pulling")
  580. checkOK(createRequest(t, s.PullHandler, api.PullRequest{
  581. Name: name(),
  582. Stream: &stream,
  583. Insecure: true,
  584. }))
  585. checkManifestList()
  586. t.Logf("copying")
  587. checkOK(createRequest(t, s.CopyHandler, api.CopyRequest{
  588. Source: name(),
  589. Destination: name(),
  590. }))
  591. checkManifestList()
  592. t.Logf("pushing")
  593. rr := createRequest(t, s.PushHandler, api.PushRequest{
  594. Model: name(),
  595. Insecure: true,
  596. Username: "alice",
  597. Password: "x",
  598. })
  599. checkOK(rr)
  600. if !strings.Contains(rr.Body.String(), `"status":"success"`) {
  601. t.Errorf("got = %q, want success", rr.Body.String())
  602. }
  603. }
  604. func TestShow(t *testing.T) {
  605. t.Setenv("OLLAMA_MODELS", t.TempDir())
  606. var s Server
  607. _, digest1 := createBinFile(t, ggml.KV{"general.architecture": "test"}, nil)
  608. _, digest2 := createBinFile(t, ggml.KV{"general.type": "projector", "general.architecture": "clip"}, nil)
  609. createRequest(t, s.CreateHandler, api.CreateRequest{
  610. Name: "show-model",
  611. Files: map[string]string{"model.gguf": digest1, "projector.gguf": digest2},
  612. })
  613. w := createRequest(t, s.ShowHandler, api.ShowRequest{
  614. Name: "show-model",
  615. })
  616. if w.Code != http.StatusOK {
  617. t.Fatalf("expected status code 200, actual %d", w.Code)
  618. }
  619. var resp api.ShowResponse
  620. if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
  621. t.Fatal(err)
  622. }
  623. if resp.ModelInfo["general.architecture"] != "test" {
  624. t.Fatal("Expected model architecture to be 'test', but got", resp.ModelInfo["general.architecture"])
  625. }
  626. if resp.ProjectorInfo["general.architecture"] != "clip" {
  627. t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"])
  628. }
  629. }
  630. func TestNormalize(t *testing.T) {
  631. type testCase struct {
  632. input []float32
  633. }
  634. testCases := []testCase{
  635. {input: []float32{1}},
  636. {input: []float32{0, 1, 2, 3}},
  637. {input: []float32{0.1, 0.2, 0.3}},
  638. {input: []float32{-0.1, 0.2, 0.3, -0.4}},
  639. {input: []float32{0, 0, 0}},
  640. }
  641. isNormalized := func(vec []float32) (res bool) {
  642. sum := 0.0
  643. for _, v := range vec {
  644. sum += float64(v * v)
  645. }
  646. if math.Abs(sum-1) > 1e-6 {
  647. return sum == 0
  648. } else {
  649. return true
  650. }
  651. }
  652. for _, tc := range testCases {
  653. t.Run("", func(t *testing.T) {
  654. normalized := normalize(tc.input)
  655. if !isNormalized(normalized) {
  656. t.Errorf("Vector %v is not normalized", tc.input)
  657. }
  658. })
  659. }
  660. }