cmd_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. package cmd
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "io"
  7. "net/http"
  8. "net/http/httptest"
  9. "os"
  10. "strings"
  11. "testing"
  12. "github.com/google/go-cmp/cmp"
  13. "github.com/spf13/cobra"
  14. "github.com/ollama/ollama/api"
  15. )
  16. func TestShowInfo(t *testing.T) {
  17. t.Run("bare details", func(t *testing.T) {
  18. var b bytes.Buffer
  19. if err := showInfo(&api.ShowResponse{
  20. Details: api.ModelDetails{
  21. Family: "test",
  22. ParameterSize: "7B",
  23. QuantizationLevel: "FP16",
  24. },
  25. }, &b); err != nil {
  26. t.Fatal(err)
  27. }
  28. expect := ` Model
  29. architecture test
  30. parameters 7B
  31. quantization FP16
  32. `
  33. if diff := cmp.Diff(expect, b.String()); diff != "" {
  34. t.Errorf("unexpected output (-want +got):\n%s", diff)
  35. }
  36. })
  37. t.Run("bare model info", func(t *testing.T) {
  38. var b bytes.Buffer
  39. if err := showInfo(&api.ShowResponse{
  40. ModelInfo: map[string]any{
  41. "general.architecture": "test",
  42. "general.parameter_count": float64(7_000_000_000),
  43. "test.context_length": float64(0),
  44. "test.embedding_length": float64(0),
  45. },
  46. Details: api.ModelDetails{
  47. Family: "test",
  48. ParameterSize: "7B",
  49. QuantizationLevel: "FP16",
  50. },
  51. }, &b); err != nil {
  52. t.Fatal(err)
  53. }
  54. expect := ` Model
  55. architecture test
  56. parameters 7B
  57. context length 0
  58. embedding length 0
  59. quantization FP16
  60. `
  61. if diff := cmp.Diff(expect, b.String()); diff != "" {
  62. t.Errorf("unexpected output (-want +got):\n%s", diff)
  63. }
  64. })
  65. t.Run("parameters", func(t *testing.T) {
  66. var b bytes.Buffer
  67. if err := showInfo(&api.ShowResponse{
  68. Details: api.ModelDetails{
  69. Family: "test",
  70. ParameterSize: "7B",
  71. QuantizationLevel: "FP16",
  72. },
  73. Parameters: `
  74. stop never
  75. stop gonna
  76. stop give
  77. stop you
  78. stop up
  79. temperature 99`,
  80. }, &b); err != nil {
  81. t.Fatal(err)
  82. }
  83. expect := ` Model
  84. architecture test
  85. parameters 7B
  86. quantization FP16
  87. Parameters
  88. stop never
  89. stop gonna
  90. stop give
  91. stop you
  92. stop up
  93. temperature 99
  94. `
  95. if diff := cmp.Diff(expect, b.String()); diff != "" {
  96. t.Errorf("unexpected output (-want +got):\n%s", diff)
  97. }
  98. })
  99. t.Run("project info", func(t *testing.T) {
  100. var b bytes.Buffer
  101. if err := showInfo(&api.ShowResponse{
  102. Details: api.ModelDetails{
  103. Family: "test",
  104. ParameterSize: "7B",
  105. QuantizationLevel: "FP16",
  106. },
  107. ProjectorInfo: map[string]any{
  108. "general.architecture": "clip",
  109. "general.parameter_count": float64(133_700_000),
  110. "clip.vision.embedding_length": float64(0),
  111. "clip.vision.projection_dim": float64(0),
  112. },
  113. }, &b); err != nil {
  114. t.Fatal(err)
  115. }
  116. expect := ` Model
  117. architecture test
  118. parameters 7B
  119. quantization FP16
  120. Projector
  121. architecture clip
  122. parameters 133.70M
  123. embedding length 0
  124. dimensions 0
  125. `
  126. if diff := cmp.Diff(expect, b.String()); diff != "" {
  127. t.Errorf("unexpected output (-want +got):\n%s", diff)
  128. }
  129. })
  130. t.Run("system", func(t *testing.T) {
  131. var b bytes.Buffer
  132. if err := showInfo(&api.ShowResponse{
  133. Details: api.ModelDetails{
  134. Family: "test",
  135. ParameterSize: "7B",
  136. QuantizationLevel: "FP16",
  137. },
  138. System: `You are a pirate!
  139. Ahoy, matey!
  140. Weigh anchor!
  141. `,
  142. }, &b); err != nil {
  143. t.Fatal(err)
  144. }
  145. expect := ` Model
  146. architecture test
  147. parameters 7B
  148. quantization FP16
  149. System
  150. You are a pirate!
  151. Ahoy, matey!
  152. `
  153. if diff := cmp.Diff(expect, b.String()); diff != "" {
  154. t.Errorf("unexpected output (-want +got):\n%s", diff)
  155. }
  156. })
  157. t.Run("license", func(t *testing.T) {
  158. var b bytes.Buffer
  159. license := "MIT License\nCopyright (c) Ollama\n"
  160. if err := showInfo(&api.ShowResponse{
  161. Details: api.ModelDetails{
  162. Family: "test",
  163. ParameterSize: "7B",
  164. QuantizationLevel: "FP16",
  165. },
  166. License: license,
  167. }, &b); err != nil {
  168. t.Fatal(err)
  169. }
  170. expect := ` Model
  171. architecture test
  172. parameters 7B
  173. quantization FP16
  174. License
  175. MIT License
  176. Copyright (c) Ollama
  177. `
  178. if diff := cmp.Diff(expect, b.String()); diff != "" {
  179. t.Errorf("unexpected output (-want +got):\n%s", diff)
  180. }
  181. })
  182. }
  183. func TestDeleteHandler(t *testing.T) {
  184. stopped := false
  185. mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  186. if r.URL.Path == "/api/delete" && r.Method == http.MethodDelete {
  187. var req api.DeleteRequest
  188. if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
  189. http.Error(w, err.Error(), http.StatusBadRequest)
  190. return
  191. }
  192. if req.Name == "test-model" {
  193. w.WriteHeader(http.StatusOK)
  194. } else {
  195. w.WriteHeader(http.StatusNotFound)
  196. }
  197. return
  198. }
  199. if r.URL.Path == "/api/generate" && r.Method == http.MethodPost {
  200. var req api.GenerateRequest
  201. if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
  202. http.Error(w, err.Error(), http.StatusBadRequest)
  203. return
  204. }
  205. if req.Model == "test-model" {
  206. w.WriteHeader(http.StatusOK)
  207. if err := json.NewEncoder(w).Encode(api.GenerateResponse{
  208. Done: true,
  209. }); err != nil {
  210. http.Error(w, err.Error(), http.StatusInternalServerError)
  211. }
  212. stopped = true
  213. return
  214. } else {
  215. w.WriteHeader(http.StatusNotFound)
  216. if err := json.NewEncoder(w).Encode(api.GenerateResponse{
  217. Done: false,
  218. }); err != nil {
  219. http.Error(w, err.Error(), http.StatusInternalServerError)
  220. }
  221. }
  222. }
  223. }))
  224. t.Setenv("OLLAMA_HOST", mockServer.URL)
  225. t.Cleanup(mockServer.Close)
  226. cmd := &cobra.Command{}
  227. cmd.SetContext(context.TODO())
  228. if err := DeleteHandler(cmd, []string{"test-model"}); err != nil {
  229. t.Fatalf("DeleteHandler failed: %v", err)
  230. }
  231. if !stopped {
  232. t.Fatal("Model was not stopped before deletion")
  233. }
  234. err := DeleteHandler(cmd, []string{"test-model-not-found"})
  235. if err == nil || !strings.Contains(err.Error(), "unable to stop existing running model \"test-model-not-found\"") {
  236. t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err)
  237. }
  238. }
  239. func TestGetModelfileName(t *testing.T) {
  240. tests := []struct {
  241. name string
  242. modelfileName string
  243. fileExists bool
  244. expectedName string
  245. expectedErr error
  246. }{
  247. {
  248. name: "no modelfile specified, no modelfile exists",
  249. modelfileName: "",
  250. fileExists: false,
  251. expectedName: "",
  252. expectedErr: os.ErrNotExist,
  253. },
  254. {
  255. name: "no modelfile specified, modelfile exists",
  256. modelfileName: "",
  257. fileExists: true,
  258. expectedName: "Modelfile",
  259. expectedErr: nil,
  260. },
  261. {
  262. name: "modelfile specified, no modelfile exists",
  263. modelfileName: "crazyfile",
  264. fileExists: false,
  265. expectedName: "crazyfile",
  266. expectedErr: os.ErrNotExist,
  267. },
  268. {
  269. name: "modelfile specified, modelfile exists",
  270. modelfileName: "anotherfile",
  271. fileExists: true,
  272. expectedName: "anotherfile",
  273. expectedErr: nil,
  274. },
  275. }
  276. for _, tt := range tests {
  277. t.Run(tt.name, func(t *testing.T) {
  278. cmd := &cobra.Command{
  279. Use: "fakecmd",
  280. }
  281. cmd.Flags().String("file", "", "path to modelfile")
  282. var expectedFilename string
  283. if tt.fileExists {
  284. tempDir, err := os.MkdirTemp("", "modelfiledir")
  285. defer os.RemoveAll(tempDir)
  286. if err != nil {
  287. t.Fatalf("temp modelfile dir creation failed: %v", err)
  288. }
  289. var fn string
  290. if tt.modelfileName != "" {
  291. fn = tt.modelfileName
  292. } else {
  293. fn = "Modelfile"
  294. }
  295. tempFile, err := os.CreateTemp(tempDir, fn)
  296. if err != nil {
  297. t.Fatalf("temp modelfile creation failed: %v", err)
  298. }
  299. expectedFilename = tempFile.Name()
  300. err = cmd.Flags().Set("file", expectedFilename)
  301. if err != nil {
  302. t.Fatalf("couldn't set file flag: %v", err)
  303. }
  304. } else {
  305. if tt.modelfileName != "" {
  306. expectedFilename = tt.modelfileName
  307. err := cmd.Flags().Set("file", tt.modelfileName)
  308. if err != nil {
  309. t.Fatalf("couldn't set file flag: %v", err)
  310. }
  311. }
  312. }
  313. actualFilename, actualErr := getModelfileName(cmd)
  314. if actualFilename != expectedFilename {
  315. t.Errorf("expected filename: '%s' actual filename: '%s'", expectedFilename, actualFilename)
  316. }
  317. if tt.expectedErr != os.ErrNotExist {
  318. if actualErr != tt.expectedErr {
  319. t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
  320. }
  321. } else {
  322. if !os.IsNotExist(actualErr) {
  323. t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
  324. }
  325. }
  326. })
  327. }
  328. }
  329. func TestPushHandler(t *testing.T) {
  330. tests := []struct {
  331. name string
  332. modelName string
  333. serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
  334. expectedError string
  335. expectedOutput string
  336. }{
  337. {
  338. name: "successful push",
  339. modelName: "test-model",
  340. serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
  341. "/api/push": func(w http.ResponseWriter, r *http.Request) {
  342. if r.Method != http.MethodPost {
  343. t.Errorf("expected POST request, got %s", r.Method)
  344. }
  345. var req api.PushRequest
  346. if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
  347. http.Error(w, err.Error(), http.StatusBadRequest)
  348. return
  349. }
  350. if req.Name != "test-model" {
  351. t.Errorf("expected model name 'test-model', got %s", req.Name)
  352. }
  353. // Simulate progress updates
  354. responses := []api.ProgressResponse{
  355. {Status: "preparing manifest"},
  356. {Digest: "sha256:abc123456789", Total: 100, Completed: 50},
  357. {Digest: "sha256:abc123456789", Total: 100, Completed: 100},
  358. }
  359. for _, resp := range responses {
  360. if err := json.NewEncoder(w).Encode(resp); err != nil {
  361. http.Error(w, err.Error(), http.StatusInternalServerError)
  362. return
  363. }
  364. w.(http.Flusher).Flush()
  365. }
  366. },
  367. },
  368. expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n",
  369. },
  370. {
  371. name: "unauthorized push",
  372. modelName: "unauthorized-model",
  373. serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
  374. "/api/push": func(w http.ResponseWriter, r *http.Request) {
  375. w.Header().Set("Content-Type", "application/json")
  376. w.WriteHeader(http.StatusUnauthorized)
  377. err := json.NewEncoder(w).Encode(map[string]string{
  378. "error": "access denied",
  379. })
  380. if err != nil {
  381. t.Fatal(err)
  382. }
  383. },
  384. },
  385. expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
  386. },
  387. }
  388. for _, tt := range tests {
  389. t.Run(tt.name, func(t *testing.T) {
  390. mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  391. if handler, ok := tt.serverResponse[r.URL.Path]; ok {
  392. handler(w, r)
  393. return
  394. }
  395. http.Error(w, "not found", http.StatusNotFound)
  396. }))
  397. defer mockServer.Close()
  398. t.Setenv("OLLAMA_HOST", mockServer.URL)
  399. cmd := &cobra.Command{}
  400. cmd.Flags().Bool("insecure", false, "")
  401. cmd.SetContext(context.TODO())
  402. // Redirect stderr to capture progress output
  403. oldStderr := os.Stderr
  404. r, w, _ := os.Pipe()
  405. os.Stderr = w
  406. // Capture stdout for the "Model pushed" message
  407. oldStdout := os.Stdout
  408. outR, outW, _ := os.Pipe()
  409. os.Stdout = outW
  410. err := PushHandler(cmd, []string{tt.modelName})
  411. // Restore stderr
  412. w.Close()
  413. os.Stderr = oldStderr
  414. // drain the pipe
  415. if _, err := io.ReadAll(r); err != nil {
  416. t.Fatal(err)
  417. }
  418. // Restore stdout and get output
  419. outW.Close()
  420. os.Stdout = oldStdout
  421. stdout, _ := io.ReadAll(outR)
  422. if tt.expectedError == "" {
  423. if err != nil {
  424. t.Errorf("expected no error, got %v", err)
  425. }
  426. if tt.expectedOutput != "" {
  427. if got := string(stdout); got != tt.expectedOutput {
  428. t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
  429. }
  430. }
  431. } else {
  432. if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
  433. t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
  434. }
  435. }
  436. })
  437. }
  438. }