routes_test.go 19 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720
  1. package server
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/binary"
  6. "encoding/json"
  7. "fmt"
  8. "io"
  9. "io/fs"
  10. "math"
  11. "math/rand/v2"
  12. "net"
  13. "net/http"
  14. "net/http/httptest"
  15. "os"
  16. "path/filepath"
  17. "sort"
  18. "strings"
  19. "testing"
  20. "unicode"
  21. "github.com/ollama/ollama/api"
  22. "github.com/ollama/ollama/fs/ggml"
  23. "github.com/ollama/ollama/openai"
  24. "github.com/ollama/ollama/types/model"
  25. "github.com/ollama/ollama/version"
  26. )
  27. func createTestFile(t *testing.T, name string) (string, string) {
  28. t.Helper()
  29. modelDir := os.Getenv("OLLAMA_MODELS")
  30. if modelDir == "" {
  31. t.Fatalf("OLLAMA_MODELS not specified")
  32. }
  33. f, err := os.CreateTemp(t.TempDir(), name)
  34. if err != nil {
  35. t.Fatalf("failed to create temp file: %v", err)
  36. }
  37. defer f.Close()
  38. err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
  39. if err != nil {
  40. t.Fatalf("failed to write to file: %v", err)
  41. }
  42. err = binary.Write(f, binary.LittleEndian, uint32(3))
  43. if err != nil {
  44. t.Fatalf("failed to write to file: %v", err)
  45. }
  46. err = binary.Write(f, binary.LittleEndian, uint64(0))
  47. if err != nil {
  48. t.Fatalf("failed to write to file: %v", err)
  49. }
  50. err = binary.Write(f, binary.LittleEndian, uint64(0))
  51. if err != nil {
  52. t.Fatalf("failed to write to file: %v", err)
  53. }
  54. // Calculate sha256 sum of file
  55. if _, err := f.Seek(0, 0); err != nil {
  56. t.Fatal(err)
  57. }
  58. digest, _ := GetSHA256Digest(f)
  59. if err := f.Close(); err != nil {
  60. t.Fatal(err)
  61. }
  62. if err := createLink(f.Name(), filepath.Join(modelDir, "blobs", fmt.Sprintf("sha256-%s", strings.TrimPrefix(digest, "sha256:")))); err != nil {
  63. t.Fatal(err)
  64. }
  65. return f.Name(), digest
  66. }
  67. // equalStringSlices checks if two slices of strings are equal.
  68. func equalStringSlices(a, b []string) bool {
  69. if len(a) != len(b) {
  70. return false
  71. }
  72. for i := range a {
  73. if a[i] != b[i] {
  74. return false
  75. }
  76. }
  77. return true
  78. }
// Test_Routes exercises the router built by Server.GenerateRoutes through a
// real httptest server, running one subtest per testCase.
//
// All cases share a single server and a single OLLAMA_MODELS temp directory.
// They run sequentially in slice order, so later cases observe models
// created or deleted by earlier ones. For example, "openai list models with
// tags" expects exactly the "test-model" created by "Tags Handler (yes tags)",
// after "model-to-delete" has been removed. Reordering cases can change their
// outcome.
func Test_Routes(t *testing.T) {
	// testCase is one HTTP round trip. Setup (optional) runs after the request
	// is built and may create models and attach a request body; Expected
	// (optional) inspects the response.
	type testCase struct {
		Name     string
		Method   string
		Path     string
		Setup    func(t *testing.T, req *http.Request)
		Expected func(t *testing.T, resp *http.Response)
	}

	// createTestModel registers a model called name, backed by a minimal GGUF
	// blob from createTestFile. It uses the fixed parameters (seed, top_p,
	// stop) that "Show Model Handler" later asserts on.
	createTestModel := func(t *testing.T, name string) {
		t.Helper()

		_, digest := createTestFile(t, "ollama-model")

		fn := func(resp api.ProgressResponse) {
			t.Logf("Status: %s", resp.Status)
		}

		r := api.CreateRequest{
			Name:  name,
			Files: map[string]string{"test.gguf": digest},
			Parameters: map[string]any{
				"seed":  42,
				"top_p": 0.9,
				"stop":  []string{"foo", "bar"},
			},
		}

		modelName := model.ParseName(name)

		// Build the model directly with ggufLayers/createModel instead of
		// going through the /api/create route.
		baseLayers, err := ggufLayers(digest, fn)
		if err != nil {
			t.Fatalf("failed to create model: %v", err)
		}

		if err := createModel(r, modelName, baseLayers, fn); err != nil {
			t.Fatal(err)
		}
	}

	testCases := []testCase{
		{
			Name:   "Version Handler",
			Method: http.MethodGet,
			Path:   "/api/version",
			Setup: func(t *testing.T, req *http.Request) {
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json; charset=utf-8" {
					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
				}
				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}
				// The body must be exactly this JSON, byte for byte.
				expectedBody := fmt.Sprintf(`{"version":"%s"}`, version.Version)
				if string(body) != expectedBody {
					t.Errorf("expected body %s, got %s", expectedBody, string(body))
				}
			},
		},
		{
			// Runs before any model is created in this test.
			Name:   "Tags Handler (no tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json; charset=utf-8" {
					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
				}
				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				var modelList api.ListResponse

				err = json.Unmarshal(body, &modelList)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				// Models must decode to a non-nil empty slice, i.e. the JSON
				// carries [] rather than null or no field at all.
				if modelList.Models == nil || len(modelList.Models) != 0 {
					t.Errorf("expected empty model list, got %v", modelList.Models)
				}
			},
		},
		{
			Name:   "openai empty list",
			Method: http.MethodGet,
			Path:   "/v1/models",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json" {
					t.Errorf("expected content type application/json, got %s", contentType)
				}
				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				var modelList openai.ListCompletion
				err = json.Unmarshal(body, &modelList)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				if modelList.Object != "list" || len(modelList.Data) != 0 {
					t.Errorf("expected empty model list, got %v", modelList.Data)
				}
			},
		},
		{
			// Creates "test-model", which stays in the store for later cases.
			Name:   "Tags Handler (yes tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "test-model")
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json; charset=utf-8" {
					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
				}
				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				// The list response must not carry an expires_at field.
				if strings.Contains(string(body), "expires_at") {
					t.Errorf("response body should not contain 'expires_at'")
				}

				var modelList api.ListResponse
				err = json.Unmarshal(body, &modelList)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				// The untagged name is reported with the ":latest" tag.
				if len(modelList.Models) != 1 || modelList.Models[0].Name != "test-model:latest" {
					t.Errorf("expected model 'test-model:latest', got %v", modelList.Models)
				}
			},
		},
		{
			Name:   "Delete Model Handler",
			Method: http.MethodDelete,
			Path:   "/api/delete",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "model-to-delete")

				deleteReq := api.DeleteRequest{
					Name: "model-to-delete",
				}
				jsonData, err := json.Marshal(deleteReq)
				if err != nil {
					t.Fatalf("failed to marshal delete request: %v", err)
				}

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				if resp.StatusCode != http.StatusOK {
					t.Errorf("expected status code 200, got %d", resp.StatusCode)
				}

				// Verify the model was deleted: GetModel must now fail with a
				// not-exist error.
				_, err := GetModel("model-to-delete")
				if err == nil || !os.IsNotExist(err) {
					t.Errorf("expected model to be deleted, got error %v", err)
				}
			},
		},
		{
			Name:   "Delete Non-existent Model",
			Method: http.MethodDelete,
			Path:   "/api/delete",
			Setup: func(t *testing.T, req *http.Request) {
				deleteReq := api.DeleteRequest{
					Name: "non-existent-model",
				}
				jsonData, err := json.Marshal(deleteReq)
				if err != nil {
					t.Fatalf("failed to marshal delete request: %v", err)
				}

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				if resp.StatusCode != http.StatusNotFound {
					t.Errorf("expected status code 404, got %d", resp.StatusCode)
				}
				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				// Error responses are a JSON object with an "error" string.
				var errorResp map[string]string
				err = json.Unmarshal(body, &errorResp)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				if !strings.Contains(errorResp["error"], "not found") {
					t.Errorf("expected error message to contain 'not found', got %s", errorResp["error"])
				}
			},
		},
		{
			// Depends on earlier cases: "test-model" exists and
			// "model-to-delete" has already been removed.
			Name:   "openai list models with tags",
			Method: http.MethodGet,
			Path:   "/v1/models",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json" {
					t.Errorf("expected content type application/json, got %s", contentType)
				}
				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				var modelList openai.ListCompletion
				err = json.Unmarshal(body, &modelList)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				if len(modelList.Data) != 1 || modelList.Data[0].Id != "test-model:latest" || modelList.Data[0].OwnedBy != "library" {
					t.Errorf("expected model 'test-model:latest' owned by 'library', got %v", modelList.Data)
				}
			},
		},
		{
			// Goes through the /api/create route (non-streaming) rather than
			// createTestModel.
			Name:   "Create Model Handler",
			Method: http.MethodPost,
			Path:   "/api/create",
			Setup: func(t *testing.T, req *http.Request) {
				_, digest := createTestFile(t, "ollama-model")
				stream := false
				createReq := api.CreateRequest{
					Name:   "t-bone",
					Files:  map[string]string{"test.gguf": digest},
					Stream: &stream,
				}
				jsonData, err := json.Marshal(createReq)
				if err != nil {
					t.Fatalf("failed to marshal create request: %v", err)
				}

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json" {
					t.Errorf("expected content type application/json, got %s", contentType)
				}
				_, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				if resp.StatusCode != http.StatusOK {
					t.Errorf("expected status code 200, got %d", resp.StatusCode)
				}

				// Note: `model` shadows the model package within this closure.
				model, err := GetModel("t-bone")
				if err != nil {
					t.Fatalf("failed to get model: %v", err)
				}

				if model.ShortName != "t-bone:latest" {
					t.Errorf("expected model name 't-bone:latest', got %s", model.ShortName)
				}
			},
		},
		{
			Name:   "Copy Model Handler",
			Method: http.MethodPost,
			Path:   "/api/copy",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "hamshank")
				copyReq := api.CopyRequest{
					Source:      "hamshank",
					Destination: "beefsteak",
				}
				jsonData, err := json.Marshal(copyReq)
				if err != nil {
					t.Fatalf("failed to marshal copy request: %v", err)
				}

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				// Only the store is checked; the HTTP status is not asserted.
				model, err := GetModel("beefsteak")
				if err != nil {
					t.Fatalf("failed to get model: %v", err)
				}
				if model.ShortName != "beefsteak:latest" {
					t.Errorf("expected model name 'beefsteak:latest', got %s", model.ShortName)
				}
			},
		},
		{
			Name:   "Show Model Handler",
			Method: http.MethodPost,
			Path:   "/api/show",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "show-model")
				showReq := api.ShowRequest{Model: "show-model"}
				jsonData, err := json.Marshal(showReq)
				if err != nil {
					t.Fatalf("failed to marshal show request: %v", err)
				}
				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json; charset=utf-8" {
					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
				}
				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				var showResp api.ShowResponse
				err = json.Unmarshal(body, &showResp)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				// Parameters is newline-separated text. Each line's whitespace
				// is collapsed and the lines are sorted, so column padding and
				// line order are not part of the assertion. The list-valued
				// "stop" parameter appears as one line per element.
				var params []string
				paramsSplit := strings.Split(showResp.Parameters, "\n")
				for _, p := range paramsSplit {
					params = append(params, strings.Join(strings.Fields(p), " "))
				}
				sort.Strings(params)
				expectedParams := []string{
					"seed 42",
					"stop \"bar\"",
					"stop \"foo\"",
					"top_p 0.9",
				}
				if !equalStringSlices(params, expectedParams) {
					t.Errorf("expected parameters %v, got %v", expectedParams, params)
				}

				// JSON numbers decode into float64 inside the ModelInfo map;
				// the header-only test GGUF has no tensors, so the count is 0.
				paramCount, ok := showResp.ModelInfo["general.parameter_count"].(float64)
				if !ok {
					t.Fatalf("expected parameter count to be a float64, got %T", showResp.ModelInfo["general.parameter_count"])
				}
				if math.Abs(paramCount) > 1e-9 {
					t.Errorf("expected parameter count to be 0, got %f", paramCount)
				}
			},
		},
		{
			// Re-creates "show-model" (already created by the previous case).
			Name: "openai retrieve model handler",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "show-model")
			},
			Method: http.MethodGet,
			Path:   "/v1/models/show-model",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json" {
					t.Errorf("expected content type application/json, got %s", contentType)
				}
				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				var retrieveResp api.RetrieveModelResponse
				err = json.Unmarshal(body, &retrieveResp)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				if retrieveResp.Id != "show-model" || retrieveResp.OwnedBy != "library" {
					t.Errorf("expected model 'show-model' owned by 'library', got %v", retrieveResp)
				}
			},
		},
	}

	// The env var must be set before any case runs: createTestFile reads it
	// to locate the blob store.
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	s := &Server{}
	router := s.GenerateRoutes()

	httpSrv := httptest.NewServer(router)
	t.Cleanup(httpSrv.Close)

	// Subtests run one at a time (no t.Parallel), preserving the ordering the
	// cases rely on.
	for _, tc := range testCases {
		t.Run(tc.Name, func(t *testing.T) {
			u := httpSrv.URL + tc.Path
			req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
			if err != nil {
				t.Fatalf("failed to create request: %v", err)
			}

			// Setup sees the already-built request so it can attach a body.
			if tc.Setup != nil {
				tc.Setup(t, req)
			}

			resp, err := httpSrv.Client().Do(req)
			if err != nil {
				t.Fatalf("failed to do request: %v", err)
			}
			defer resp.Body.Close()

			if tc.Expected != nil {
				tc.Expected(t, resp)
			}
		})
	}
}
  458. func casingShuffle(s string) string {
  459. rr := []rune(s)
  460. for i := range rr {
  461. if rand.N(2) == 0 {
  462. rr[i] = unicode.ToUpper(rr[i])
  463. } else {
  464. rr[i] = unicode.ToLower(rr[i])
  465. }
  466. }
  467. return string(rr)
  468. }
  469. func TestManifestCaseSensitivity(t *testing.T) {
  470. t.Setenv("OLLAMA_MODELS", t.TempDir())
  471. r := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  472. w.WriteHeader(http.StatusOK)
  473. io.WriteString(w, `{}`) //nolint:errcheck
  474. }))
  475. defer r.Close()
  476. nameUsed := make(map[string]bool)
  477. name := func() string {
  478. const fqmn = "example/namespace/model:tag"
  479. for {
  480. v := casingShuffle(fqmn)
  481. if nameUsed[v] {
  482. continue
  483. }
  484. nameUsed[v] = true
  485. return v
  486. }
  487. }
  488. wantStableName := name()
  489. t.Logf("stable name: %s", wantStableName)
  490. // checkManifestList tests that there is strictly one manifest in the
  491. // models directory, and that the manifest is for the model under test.
  492. checkManifestList := func() {
  493. t.Helper()
  494. mandir := filepath.Join(os.Getenv("OLLAMA_MODELS"), "manifests/")
  495. var entries []string
  496. t.Logf("dir entries:")
  497. fsys := os.DirFS(mandir)
  498. err := fs.WalkDir(fsys, ".", func(path string, info fs.DirEntry, err error) error {
  499. if err != nil {
  500. return err
  501. }
  502. t.Logf(" %s", fs.FormatDirEntry(info))
  503. if info.IsDir() {
  504. return nil
  505. }
  506. path = strings.TrimPrefix(path, mandir)
  507. entries = append(entries, path)
  508. return nil
  509. })
  510. if err != nil {
  511. t.Fatalf("failed to walk directory: %v", err)
  512. }
  513. if len(entries) != 1 {
  514. t.Errorf("len(got) = %d, want 1", len(entries))
  515. return // do not use Fatal so following steps run
  516. }
  517. g := entries[0] // raw path
  518. g = filepath.ToSlash(g)
  519. w := model.ParseName(wantStableName).Filepath()
  520. w = filepath.ToSlash(w)
  521. if g != w {
  522. t.Errorf("\ngot: %s\nwant: %s", g, w)
  523. }
  524. }
  525. checkOK := func(w *httptest.ResponseRecorder) {
  526. t.Helper()
  527. if w.Code != http.StatusOK {
  528. t.Errorf("code = %d, want 200", w.Code)
  529. t.Logf("body: %s", w.Body.String())
  530. }
  531. }
  532. var s Server
  533. testMakeRequestDialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
  534. var d net.Dialer
  535. return d.DialContext(ctx, "tcp", r.Listener.Addr().String())
  536. }
  537. t.Cleanup(func() { testMakeRequestDialContext = nil })
  538. t.Logf("creating")
  539. _, digest := createBinFile(t, nil, nil)
  540. checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
  541. // Start with the stable name, and later use a case-shuffled
  542. // version.
  543. Name: wantStableName,
  544. Files: map[string]string{"test.gguf": digest},
  545. Stream: &stream,
  546. }))
  547. checkManifestList()
  548. t.Logf("creating (again)")
  549. checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
  550. Name: name(),
  551. Files: map[string]string{"test.gguf": digest},
  552. Stream: &stream,
  553. }))
  554. checkManifestList()
  555. t.Logf("pulling")
  556. checkOK(createRequest(t, s.PullHandler, api.PullRequest{
  557. Name: name(),
  558. Stream: &stream,
  559. Insecure: true,
  560. }))
  561. checkManifestList()
  562. t.Logf("copying")
  563. checkOK(createRequest(t, s.CopyHandler, api.CopyRequest{
  564. Source: name(),
  565. Destination: name(),
  566. }))
  567. checkManifestList()
  568. t.Logf("pushing")
  569. rr := createRequest(t, s.PushHandler, api.PushRequest{
  570. Model: name(),
  571. Insecure: true,
  572. Username: "alice",
  573. Password: "x",
  574. })
  575. checkOK(rr)
  576. if !strings.Contains(rr.Body.String(), `"status":"success"`) {
  577. t.Errorf("got = %q, want success", rr.Body.String())
  578. }
  579. }
  580. func TestShow(t *testing.T) {
  581. t.Setenv("OLLAMA_MODELS", t.TempDir())
  582. var s Server
  583. _, digest1 := createBinFile(t, ggml.KV{"general.architecture": "test"}, nil)
  584. _, digest2 := createBinFile(t, ggml.KV{"general.type": "projector", "general.architecture": "clip"}, nil)
  585. createRequest(t, s.CreateHandler, api.CreateRequest{
  586. Name: "show-model",
  587. Files: map[string]string{"model.gguf": digest1, "projector.gguf": digest2},
  588. })
  589. w := createRequest(t, s.ShowHandler, api.ShowRequest{
  590. Name: "show-model",
  591. })
  592. if w.Code != http.StatusOK {
  593. t.Fatalf("expected status code 200, actual %d", w.Code)
  594. }
  595. var resp api.ShowResponse
  596. if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
  597. t.Fatal(err)
  598. }
  599. if resp.ModelInfo["general.architecture"] != "test" {
  600. t.Fatal("Expected model architecture to be 'test', but got", resp.ModelInfo["general.architecture"])
  601. }
  602. if resp.ProjectorInfo["general.architecture"] != "clip" {
  603. t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"])
  604. }
  605. }
  606. func TestNormalize(t *testing.T) {
  607. type testCase struct {
  608. input []float32
  609. }
  610. testCases := []testCase{
  611. {input: []float32{1}},
  612. {input: []float32{0, 1, 2, 3}},
  613. {input: []float32{0.1, 0.2, 0.3}},
  614. {input: []float32{-0.1, 0.2, 0.3, -0.4}},
  615. {input: []float32{0, 0, 0}},
  616. }
  617. isNormalized := func(vec []float32) (res bool) {
  618. sum := 0.0
  619. for _, v := range vec {
  620. sum += float64(v * v)
  621. }
  622. if math.Abs(sum-1) > 1e-6 {
  623. return sum == 0
  624. } else {
  625. return true
  626. }
  627. }
  628. for _, tc := range testCases {
  629. t.Run("", func(t *testing.T) {
  630. normalized := normalize(tc.input)
  631. if !isNormalized(normalized) {
  632. t.Errorf("Vector %v is not normalized", tc.input)
  633. }
  634. })
  635. }
  636. }