cmd_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. package cmd
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "io"
  7. "net/http"
  8. "net/http/httptest"
  9. "os"
  10. "path/filepath"
  11. "strings"
  12. "testing"
  13. "github.com/google/go-cmp/cmp"
  14. "github.com/spf13/cobra"
  15. "github.com/ollama/ollama/api"
  16. )
  17. func TestShowInfo(t *testing.T) {
  18. t.Run("bare details", func(t *testing.T) {
  19. var b bytes.Buffer
  20. if err := showInfo(&api.ShowResponse{
  21. Details: api.ModelDetails{
  22. Family: "test",
  23. ParameterSize: "7B",
  24. QuantizationLevel: "FP16",
  25. },
  26. }, &b); err != nil {
  27. t.Fatal(err)
  28. }
  29. expect := ` Model
  30. architecture test
  31. parameters 7B
  32. quantization FP16
  33. `
  34. if diff := cmp.Diff(expect, b.String()); diff != "" {
  35. t.Errorf("unexpected output (-want +got):\n%s", diff)
  36. }
  37. })
  38. t.Run("bare model info", func(t *testing.T) {
  39. var b bytes.Buffer
  40. if err := showInfo(&api.ShowResponse{
  41. ModelInfo: map[string]any{
  42. "general.architecture": "test",
  43. "general.parameter_count": float64(7_000_000_000),
  44. "test.context_length": float64(0),
  45. "test.embedding_length": float64(0),
  46. },
  47. Details: api.ModelDetails{
  48. Family: "test",
  49. ParameterSize: "7B",
  50. QuantizationLevel: "FP16",
  51. },
  52. }, &b); err != nil {
  53. t.Fatal(err)
  54. }
  55. expect := ` Model
  56. architecture test
  57. parameters 7B
  58. context length 0
  59. embedding length 0
  60. quantization FP16
  61. `
  62. if diff := cmp.Diff(expect, b.String()); diff != "" {
  63. t.Errorf("unexpected output (-want +got):\n%s", diff)
  64. }
  65. })
  66. t.Run("parameters", func(t *testing.T) {
  67. var b bytes.Buffer
  68. if err := showInfo(&api.ShowResponse{
  69. Details: api.ModelDetails{
  70. Family: "test",
  71. ParameterSize: "7B",
  72. QuantizationLevel: "FP16",
  73. },
  74. Parameters: `
  75. stop never
  76. stop gonna
  77. stop give
  78. stop you
  79. stop up
  80. temperature 99`,
  81. }, &b); err != nil {
  82. t.Fatal(err)
  83. }
  84. expect := ` Model
  85. architecture test
  86. parameters 7B
  87. quantization FP16
  88. Parameters
  89. stop never
  90. stop gonna
  91. stop give
  92. stop you
  93. stop up
  94. temperature 99
  95. `
  96. if diff := cmp.Diff(expect, b.String()); diff != "" {
  97. t.Errorf("unexpected output (-want +got):\n%s", diff)
  98. }
  99. })
  100. t.Run("project info", func(t *testing.T) {
  101. var b bytes.Buffer
  102. if err := showInfo(&api.ShowResponse{
  103. Details: api.ModelDetails{
  104. Family: "test",
  105. ParameterSize: "7B",
  106. QuantizationLevel: "FP16",
  107. },
  108. ProjectorInfo: map[string]any{
  109. "general.architecture": "clip",
  110. "general.parameter_count": float64(133_700_000),
  111. "clip.vision.embedding_length": float64(0),
  112. "clip.vision.projection_dim": float64(0),
  113. },
  114. }, &b); err != nil {
  115. t.Fatal(err)
  116. }
  117. expect := ` Model
  118. architecture test
  119. parameters 7B
  120. quantization FP16
  121. Projector
  122. architecture clip
  123. parameters 133.70M
  124. embedding length 0
  125. dimensions 0
  126. `
  127. if diff := cmp.Diff(expect, b.String()); diff != "" {
  128. t.Errorf("unexpected output (-want +got):\n%s", diff)
  129. }
  130. })
  131. t.Run("system", func(t *testing.T) {
  132. var b bytes.Buffer
  133. if err := showInfo(&api.ShowResponse{
  134. Details: api.ModelDetails{
  135. Family: "test",
  136. ParameterSize: "7B",
  137. QuantizationLevel: "FP16",
  138. },
  139. System: `You are a pirate!
  140. Ahoy, matey!
  141. Weigh anchor!
  142. `,
  143. }, &b); err != nil {
  144. t.Fatal(err)
  145. }
  146. expect := ` Model
  147. architecture test
  148. parameters 7B
  149. quantization FP16
  150. System
  151. You are a pirate!
  152. Ahoy, matey!
  153. `
  154. if diff := cmp.Diff(expect, b.String()); diff != "" {
  155. t.Errorf("unexpected output (-want +got):\n%s", diff)
  156. }
  157. })
  158. t.Run("license", func(t *testing.T) {
  159. var b bytes.Buffer
  160. license, err := os.ReadFile(filepath.Join("..", "LICENSE"))
  161. if err != nil {
  162. t.Fatal(err)
  163. }
  164. if err := showInfo(&api.ShowResponse{
  165. Details: api.ModelDetails{
  166. Family: "test",
  167. ParameterSize: "7B",
  168. QuantizationLevel: "FP16",
  169. },
  170. License: string(license),
  171. }, &b); err != nil {
  172. t.Fatal(err)
  173. }
  174. expect := ` Model
  175. architecture test
  176. parameters 7B
  177. quantization FP16
  178. License
  179. MIT License
  180. Copyright (c) Ollama
  181. `
  182. if diff := cmp.Diff(expect, b.String()); diff != "" {
  183. t.Errorf("unexpected output (-want +got):\n%s", diff)
  184. }
  185. })
  186. }
  187. func TestDeleteHandler(t *testing.T) {
  188. stopped := false
  189. mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  190. if r.URL.Path == "/api/delete" && r.Method == http.MethodDelete {
  191. var req api.DeleteRequest
  192. if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
  193. http.Error(w, err.Error(), http.StatusBadRequest)
  194. return
  195. }
  196. if req.Name == "test-model" {
  197. w.WriteHeader(http.StatusOK)
  198. } else {
  199. w.WriteHeader(http.StatusNotFound)
  200. }
  201. return
  202. }
  203. if r.URL.Path == "/api/generate" && r.Method == http.MethodPost {
  204. var req api.GenerateRequest
  205. if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
  206. http.Error(w, err.Error(), http.StatusBadRequest)
  207. return
  208. }
  209. if req.Model == "test-model" {
  210. w.WriteHeader(http.StatusOK)
  211. if err := json.NewEncoder(w).Encode(api.GenerateResponse{
  212. Done: true,
  213. }); err != nil {
  214. http.Error(w, err.Error(), http.StatusInternalServerError)
  215. }
  216. stopped = true
  217. return
  218. } else {
  219. w.WriteHeader(http.StatusNotFound)
  220. if err := json.NewEncoder(w).Encode(api.GenerateResponse{
  221. Done: false,
  222. }); err != nil {
  223. http.Error(w, err.Error(), http.StatusInternalServerError)
  224. }
  225. }
  226. }
  227. }))
  228. t.Setenv("OLLAMA_HOST", mockServer.URL)
  229. t.Cleanup(mockServer.Close)
  230. cmd := &cobra.Command{}
  231. cmd.SetContext(context.TODO())
  232. if err := DeleteHandler(cmd, []string{"test-model"}); err != nil {
  233. t.Fatalf("DeleteHandler failed: %v", err)
  234. }
  235. if !stopped {
  236. t.Fatal("Model was not stopped before deletion")
  237. }
  238. err := DeleteHandler(cmd, []string{"test-model-not-found"})
  239. if err == nil || !strings.Contains(err.Error(), "unable to stop existing running model \"test-model-not-found\"") {
  240. t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err)
  241. }
  242. }
  243. func TestGetModelfileName(t *testing.T) {
  244. tests := []struct {
  245. name string
  246. modelfileName string
  247. fileExists bool
  248. expectedName string
  249. expectedErr error
  250. }{
  251. {
  252. name: "no modelfile specified, no modelfile exists",
  253. modelfileName: "",
  254. fileExists: false,
  255. expectedName: "",
  256. expectedErr: os.ErrNotExist,
  257. },
  258. {
  259. name: "no modelfile specified, modelfile exists",
  260. modelfileName: "",
  261. fileExists: true,
  262. expectedName: "Modelfile",
  263. expectedErr: nil,
  264. },
  265. {
  266. name: "modelfile specified, no modelfile exists",
  267. modelfileName: "crazyfile",
  268. fileExists: false,
  269. expectedName: "crazyfile",
  270. expectedErr: os.ErrNotExist,
  271. },
  272. {
  273. name: "modelfile specified, modelfile exists",
  274. modelfileName: "anotherfile",
  275. fileExists: true,
  276. expectedName: "anotherfile",
  277. expectedErr: nil,
  278. },
  279. }
  280. for _, tt := range tests {
  281. t.Run(tt.name, func(t *testing.T) {
  282. cmd := &cobra.Command{
  283. Use: "fakecmd",
  284. }
  285. cmd.Flags().String("file", "", "path to modelfile")
  286. var expectedFilename string
  287. if tt.fileExists {
  288. tempDir, err := os.MkdirTemp("", "modelfiledir")
  289. defer os.RemoveAll(tempDir)
  290. if err != nil {
  291. t.Fatalf("temp modelfile dir creation failed: %v", err)
  292. }
  293. var fn string
  294. if tt.modelfileName != "" {
  295. fn = tt.modelfileName
  296. } else {
  297. fn = "Modelfile"
  298. }
  299. tempFile, err := os.CreateTemp(tempDir, fn)
  300. if err != nil {
  301. t.Fatalf("temp modelfile creation failed: %v", err)
  302. }
  303. expectedFilename = tempFile.Name()
  304. err = cmd.Flags().Set("file", expectedFilename)
  305. if err != nil {
  306. t.Fatalf("couldn't set file flag: %v", err)
  307. }
  308. } else {
  309. if tt.modelfileName != "" {
  310. expectedFilename = tt.modelfileName
  311. err := cmd.Flags().Set("file", tt.modelfileName)
  312. if err != nil {
  313. t.Fatalf("couldn't set file flag: %v", err)
  314. }
  315. }
  316. }
  317. actualFilename, actualErr := getModelfileName(cmd)
  318. if actualFilename != expectedFilename {
  319. t.Errorf("expected filename: '%s' actual filename: '%s'", expectedFilename, actualFilename)
  320. }
  321. if tt.expectedErr != os.ErrNotExist {
  322. if actualErr != tt.expectedErr {
  323. t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
  324. }
  325. } else {
  326. if !os.IsNotExist(actualErr) {
  327. t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
  328. }
  329. }
  330. })
  331. }
  332. }
  333. func TestPushHandler(t *testing.T) {
  334. tests := []struct {
  335. name string
  336. modelName string
  337. serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
  338. expectedError string
  339. expectedOutput string
  340. }{
  341. {
  342. name: "successful push",
  343. modelName: "test-model",
  344. serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
  345. "/api/push": func(w http.ResponseWriter, r *http.Request) {
  346. if r.Method != http.MethodPost {
  347. t.Errorf("expected POST request, got %s", r.Method)
  348. }
  349. var req api.PushRequest
  350. if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
  351. http.Error(w, err.Error(), http.StatusBadRequest)
  352. return
  353. }
  354. if req.Name != "test-model" {
  355. t.Errorf("expected model name 'test-model', got %s", req.Name)
  356. }
  357. // Simulate progress updates
  358. responses := []api.ProgressResponse{
  359. {Status: "preparing manifest"},
  360. {Digest: "sha256:abc123456789", Total: 100, Completed: 50},
  361. {Digest: "sha256:abc123456789", Total: 100, Completed: 100},
  362. }
  363. for _, resp := range responses {
  364. if err := json.NewEncoder(w).Encode(resp); err != nil {
  365. http.Error(w, err.Error(), http.StatusInternalServerError)
  366. return
  367. }
  368. w.(http.Flusher).Flush()
  369. }
  370. },
  371. },
  372. expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n",
  373. },
  374. {
  375. name: "unauthorized push",
  376. modelName: "unauthorized-model",
  377. serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
  378. "/api/push": func(w http.ResponseWriter, r *http.Request) {
  379. w.Header().Set("Content-Type", "application/json")
  380. w.WriteHeader(http.StatusUnauthorized)
  381. err := json.NewEncoder(w).Encode(map[string]string{
  382. "error": "access denied",
  383. })
  384. if err != nil {
  385. t.Fatal(err)
  386. }
  387. },
  388. },
  389. expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
  390. },
  391. }
  392. for _, tt := range tests {
  393. t.Run(tt.name, func(t *testing.T) {
  394. mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  395. if handler, ok := tt.serverResponse[r.URL.Path]; ok {
  396. handler(w, r)
  397. return
  398. }
  399. http.Error(w, "not found", http.StatusNotFound)
  400. }))
  401. defer mockServer.Close()
  402. t.Setenv("OLLAMA_HOST", mockServer.URL)
  403. cmd := &cobra.Command{}
  404. cmd.Flags().Bool("insecure", false, "")
  405. cmd.SetContext(context.TODO())
  406. // Redirect stderr to capture progress output
  407. oldStderr := os.Stderr
  408. r, w, _ := os.Pipe()
  409. os.Stderr = w
  410. // Capture stdout for the "Model pushed" message
  411. oldStdout := os.Stdout
  412. outR, outW, _ := os.Pipe()
  413. os.Stdout = outW
  414. err := PushHandler(cmd, []string{tt.modelName})
  415. // Restore stderr
  416. w.Close()
  417. os.Stderr = oldStderr
  418. // drain the pipe
  419. if _, err := io.ReadAll(r); err != nil {
  420. t.Fatal(err)
  421. }
  422. // Restore stdout and get output
  423. outW.Close()
  424. os.Stdout = oldStdout
  425. stdout, _ := io.ReadAll(outR)
  426. if tt.expectedError == "" {
  427. if err != nil {
  428. t.Errorf("expected no error, got %v", err)
  429. }
  430. if tt.expectedOutput != "" {
  431. if got := string(stdout); got != tt.expectedOutput {
  432. t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
  433. }
  434. }
  435. } else {
  436. if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
  437. t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
  438. }
  439. }
  440. })
  441. }
  442. }