routes_generate_test.go 20 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714
  1. package server
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "strings"
  10. "testing"
  11. "time"
  12. "github.com/gin-gonic/gin"
  13. "github.com/google/go-cmp/cmp"
  14. "github.com/ollama/ollama/api"
  15. "github.com/ollama/ollama/gpu"
  16. "github.com/ollama/ollama/llm"
  17. )
// mockRunner is a test double for llm.LlamaServer. It records the most
// recent completion request and replays a canned completion response,
// letting handler tests inspect the exact prompt that was built.
type mockRunner struct {
	llm.LlamaServer

	// CompletionRequest is only valid until the next call to Completion
	llm.CompletionRequest

	// CompletionResponse is passed verbatim to the callback on every
	// Completion call; tests mutate Content between requests.
	llm.CompletionResponse
}
// Completion records r for later inspection by the test, then immediately
// invokes fn once with the canned CompletionResponse. It never fails.
func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
	m.CompletionRequest = r
	fn(m.CompletionResponse)
	return nil
}
  29. func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
  30. for range strings.Fields(s) {
  31. tokens = append(tokens, len(tokens))
  32. }
  33. return
  34. }
  35. func newMockServer(mock *mockRunner) func(gpu.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
  36. return func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
  37. return mock, nil
  38. }
  39. }
  40. func TestGenerateChat(t *testing.T) {
  41. gin.SetMode(gin.TestMode)
  42. mock := mockRunner{
  43. CompletionResponse: llm.CompletionResponse{
  44. Done: true,
  45. DoneReason: "stop",
  46. PromptEvalCount: 1,
  47. PromptEvalDuration: 1,
  48. EvalCount: 1,
  49. EvalDuration: 1,
  50. },
  51. }
  52. s := Server{
  53. sched: &Scheduler{
  54. pendingReqCh: make(chan *LlmRequest, 1),
  55. finishedReqCh: make(chan *LlmRequest, 1),
  56. expiredCh: make(chan *runnerRef, 1),
  57. unloadedCh: make(chan any, 1),
  58. loaded: make(map[string]*runnerRef),
  59. newServerFn: newMockServer(&mock),
  60. getGpuFn: gpu.GetGPUInfo,
  61. getCpuFn: gpu.GetCPUInfo,
  62. reschedDelay: 250 * time.Millisecond,
  63. loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
  64. // add small delay to simulate loading
  65. time.Sleep(time.Millisecond)
  66. req.successCh <- &runnerRef{
  67. llama: &mock,
  68. }
  69. },
  70. },
  71. }
  72. go s.sched.Run(context.TODO())
  73. w := createRequest(t, s.CreateHandler, api.CreateRequest{
  74. Model: "test",
  75. Modelfile: fmt.Sprintf(`FROM %s
  76. TEMPLATE """
  77. {{- if .System }}System: {{ .System }} {{ end }}
  78. {{- if .Prompt }}User: {{ .Prompt }} {{ end }}
  79. {{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
  80. `, createBinFile(t, llm.KV{
  81. "general.architecture": "llama",
  82. "llama.block_count": uint32(1),
  83. "llama.context_length": uint32(8192),
  84. "llama.embedding_length": uint32(4096),
  85. "llama.attention.head_count": uint32(32),
  86. "llama.attention.head_count_kv": uint32(8),
  87. "tokenizer.ggml.tokens": []string{""},
  88. "tokenizer.ggml.scores": []float32{0},
  89. "tokenizer.ggml.token_type": []int32{0},
  90. }, []llm.Tensor{
  91. {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  92. {Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  93. {Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  94. {Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  95. {Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  96. {Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  97. {Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  98. {Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  99. {Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  100. {Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  101. {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  102. })),
  103. Stream: &stream,
  104. })
  105. if w.Code != http.StatusOK {
  106. t.Fatalf("expected status 200, got %d", w.Code)
  107. }
  108. t.Run("missing body", func(t *testing.T) {
  109. w := createRequest(t, s.ChatHandler, nil)
  110. if w.Code != http.StatusBadRequest {
  111. t.Errorf("expected status 400, got %d", w.Code)
  112. }
  113. if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
  114. t.Errorf("mismatch (-got +want):\n%s", diff)
  115. }
  116. })
  117. t.Run("missing model", func(t *testing.T) {
  118. w := createRequest(t, s.ChatHandler, api.ChatRequest{})
  119. if w.Code != http.StatusBadRequest {
  120. t.Errorf("expected status 400, got %d", w.Code)
  121. }
  122. if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
  123. t.Errorf("mismatch (-got +want):\n%s", diff)
  124. }
  125. })
  126. t.Run("missing capabilities chat", func(t *testing.T) {
  127. w := createRequest(t, s.CreateHandler, api.CreateRequest{
  128. Model: "bert",
  129. Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
  130. "general.architecture": "bert",
  131. "bert.pooling_type": uint32(0),
  132. }, []llm.Tensor{})),
  133. Stream: &stream,
  134. })
  135. if w.Code != http.StatusOK {
  136. t.Fatalf("expected status 200, got %d", w.Code)
  137. }
  138. w = createRequest(t, s.ChatHandler, api.ChatRequest{
  139. Model: "bert",
  140. })
  141. if w.Code != http.StatusBadRequest {
  142. t.Errorf("expected status 400, got %d", w.Code)
  143. }
  144. if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" {
  145. t.Errorf("mismatch (-got +want):\n%s", diff)
  146. }
  147. })
  148. t.Run("load model", func(t *testing.T) {
  149. w := createRequest(t, s.ChatHandler, api.ChatRequest{
  150. Model: "test",
  151. })
  152. if w.Code != http.StatusOK {
  153. t.Errorf("expected status 200, got %d", w.Code)
  154. }
  155. var actual api.ChatResponse
  156. if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
  157. t.Fatal(err)
  158. }
  159. if actual.Model != "test" {
  160. t.Errorf("expected model test, got %s", actual.Model)
  161. }
  162. if !actual.Done {
  163. t.Errorf("expected done true, got false")
  164. }
  165. if actual.DoneReason != "load" {
  166. t.Errorf("expected done reason load, got %s", actual.DoneReason)
  167. }
  168. })
  169. checkChatResponse := func(t *testing.T, body io.Reader, model, content string) {
  170. t.Helper()
  171. var actual api.ChatResponse
  172. if err := json.NewDecoder(body).Decode(&actual); err != nil {
  173. t.Fatal(err)
  174. }
  175. if actual.Model != model {
  176. t.Errorf("expected model test, got %s", actual.Model)
  177. }
  178. if !actual.Done {
  179. t.Errorf("expected done false, got true")
  180. }
  181. if actual.DoneReason != "stop" {
  182. t.Errorf("expected done reason stop, got %s", actual.DoneReason)
  183. }
  184. if diff := cmp.Diff(actual.Message, api.Message{
  185. Role: "assistant",
  186. Content: content,
  187. }); diff != "" {
  188. t.Errorf("mismatch (-got +want):\n%s", diff)
  189. }
  190. if actual.PromptEvalCount == 0 {
  191. t.Errorf("expected prompt eval count > 0, got 0")
  192. }
  193. if actual.PromptEvalDuration == 0 {
  194. t.Errorf("expected prompt eval duration > 0, got 0")
  195. }
  196. if actual.EvalCount == 0 {
  197. t.Errorf("expected eval count > 0, got 0")
  198. }
  199. if actual.EvalDuration == 0 {
  200. t.Errorf("expected eval duration > 0, got 0")
  201. }
  202. if actual.LoadDuration == 0 {
  203. t.Errorf("expected load duration > 0, got 0")
  204. }
  205. if actual.TotalDuration == 0 {
  206. t.Errorf("expected total duration > 0, got 0")
  207. }
  208. }
  209. mock.CompletionResponse.Content = "Hi!"
  210. t.Run("messages", func(t *testing.T) {
  211. w := createRequest(t, s.ChatHandler, api.ChatRequest{
  212. Model: "test",
  213. Messages: []api.Message{
  214. {Role: "user", Content: "Hello!"},
  215. },
  216. Stream: &stream,
  217. })
  218. if w.Code != http.StatusOK {
  219. t.Errorf("expected status 200, got %d", w.Code)
  220. }
  221. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
  222. t.Errorf("mismatch (-got +want):\n%s", diff)
  223. }
  224. checkChatResponse(t, w.Body, "test", "Hi!")
  225. })
  226. w = createRequest(t, s.CreateHandler, api.CreateRequest{
  227. Model: "test-system",
  228. Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
  229. })
  230. if w.Code != http.StatusOK {
  231. t.Fatalf("expected status 200, got %d", w.Code)
  232. }
  233. t.Run("messages with model system", func(t *testing.T) {
  234. w := createRequest(t, s.ChatHandler, api.ChatRequest{
  235. Model: "test-system",
  236. Messages: []api.Message{
  237. {Role: "user", Content: "Hello!"},
  238. },
  239. Stream: &stream,
  240. })
  241. if w.Code != http.StatusOK {
  242. t.Errorf("expected status 200, got %d", w.Code)
  243. }
  244. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
  245. t.Errorf("mismatch (-got +want):\n%s", diff)
  246. }
  247. checkChatResponse(t, w.Body, "test-system", "Hi!")
  248. })
  249. mock.CompletionResponse.Content = "Abra kadabra!"
  250. t.Run("messages with system", func(t *testing.T) {
  251. w := createRequest(t, s.ChatHandler, api.ChatRequest{
  252. Model: "test-system",
  253. Messages: []api.Message{
  254. {Role: "system", Content: "You can perform magic tricks."},
  255. {Role: "user", Content: "Hello!"},
  256. },
  257. Stream: &stream,
  258. })
  259. if w.Code != http.StatusOK {
  260. t.Errorf("expected status 200, got %d", w.Code)
  261. }
  262. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
  263. t.Errorf("mismatch (-got +want):\n%s", diff)
  264. }
  265. checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
  266. })
  267. t.Run("messages with interleaved system", func(t *testing.T) {
  268. w := createRequest(t, s.ChatHandler, api.ChatRequest{
  269. Model: "test-system",
  270. Messages: []api.Message{
  271. {Role: "user", Content: "Hello!"},
  272. {Role: "assistant", Content: "I can help you with that."},
  273. {Role: "system", Content: "You can perform magic tricks."},
  274. {Role: "user", Content: "Help me write tests."},
  275. },
  276. Stream: &stream,
  277. })
  278. if w.Code != http.StatusOK {
  279. t.Errorf("expected status 200, got %d", w.Code)
  280. }
  281. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
  282. t.Errorf("mismatch (-got +want):\n%s", diff)
  283. }
  284. checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
  285. })
  286. }
  287. func TestGenerate(t *testing.T) {
  288. gin.SetMode(gin.TestMode)
  289. mock := mockRunner{
  290. CompletionResponse: llm.CompletionResponse{
  291. Done: true,
  292. DoneReason: "stop",
  293. PromptEvalCount: 1,
  294. PromptEvalDuration: 1,
  295. EvalCount: 1,
  296. EvalDuration: 1,
  297. },
  298. }
  299. s := Server{
  300. sched: &Scheduler{
  301. pendingReqCh: make(chan *LlmRequest, 1),
  302. finishedReqCh: make(chan *LlmRequest, 1),
  303. expiredCh: make(chan *runnerRef, 1),
  304. unloadedCh: make(chan any, 1),
  305. loaded: make(map[string]*runnerRef),
  306. newServerFn: newMockServer(&mock),
  307. getGpuFn: gpu.GetGPUInfo,
  308. getCpuFn: gpu.GetCPUInfo,
  309. reschedDelay: 250 * time.Millisecond,
  310. loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
  311. // add small delay to simulate loading
  312. time.Sleep(time.Millisecond)
  313. req.successCh <- &runnerRef{
  314. llama: &mock,
  315. }
  316. },
  317. },
  318. }
  319. go s.sched.Run(context.TODO())
  320. w := createRequest(t, s.CreateHandler, api.CreateRequest{
  321. Model: "test",
  322. Modelfile: fmt.Sprintf(`FROM %s
  323. TEMPLATE """
  324. {{- if .System }}System: {{ .System }} {{ end }}
  325. {{- if .Prompt }}User: {{ .Prompt }} {{ end }}
  326. {{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
  327. `, createBinFile(t, llm.KV{
  328. "general.architecture": "llama",
  329. "llama.block_count": uint32(1),
  330. "llama.context_length": uint32(8192),
  331. "llama.embedding_length": uint32(4096),
  332. "llama.attention.head_count": uint32(32),
  333. "llama.attention.head_count_kv": uint32(8),
  334. "tokenizer.ggml.tokens": []string{""},
  335. "tokenizer.ggml.scores": []float32{0},
  336. "tokenizer.ggml.token_type": []int32{0},
  337. }, []llm.Tensor{
  338. {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  339. {Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  340. {Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  341. {Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  342. {Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  343. {Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  344. {Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  345. {Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  346. {Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  347. {Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  348. {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  349. })),
  350. Stream: &stream,
  351. })
  352. if w.Code != http.StatusOK {
  353. t.Fatalf("expected status 200, got %d", w.Code)
  354. }
  355. t.Run("missing body", func(t *testing.T) {
  356. w := createRequest(t, s.GenerateHandler, nil)
  357. if w.Code != http.StatusBadRequest {
  358. t.Errorf("expected status 400, got %d", w.Code)
  359. }
  360. if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
  361. t.Errorf("mismatch (-got +want):\n%s", diff)
  362. }
  363. })
  364. t.Run("missing model", func(t *testing.T) {
  365. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
  366. if w.Code != http.StatusBadRequest {
  367. t.Errorf("expected status 400, got %d", w.Code)
  368. }
  369. if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
  370. t.Errorf("mismatch (-got +want):\n%s", diff)
  371. }
  372. })
  373. t.Run("missing capabilities generate", func(t *testing.T) {
  374. w := createRequest(t, s.CreateHandler, api.CreateRequest{
  375. Model: "bert",
  376. Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
  377. "general.architecture": "bert",
  378. "bert.pooling_type": uint32(0),
  379. }, []llm.Tensor{})),
  380. Stream: &stream,
  381. })
  382. if w.Code != http.StatusOK {
  383. t.Fatalf("expected status 200, got %d", w.Code)
  384. }
  385. w = createRequest(t, s.GenerateHandler, api.GenerateRequest{
  386. Model: "bert",
  387. })
  388. if w.Code != http.StatusBadRequest {
  389. t.Errorf("expected status 400, got %d", w.Code)
  390. }
  391. if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" {
  392. t.Errorf("mismatch (-got +want):\n%s", diff)
  393. }
  394. })
  395. t.Run("missing capabilities suffix", func(t *testing.T) {
  396. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  397. Model: "test",
  398. Prompt: "def add(",
  399. Suffix: " return c",
  400. })
  401. if w.Code != http.StatusBadRequest {
  402. t.Errorf("expected status 400, got %d", w.Code)
  403. }
  404. if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" {
  405. t.Errorf("mismatch (-got +want):\n%s", diff)
  406. }
  407. })
  408. t.Run("load model", func(t *testing.T) {
  409. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  410. Model: "test",
  411. })
  412. if w.Code != http.StatusOK {
  413. t.Errorf("expected status 200, got %d", w.Code)
  414. }
  415. var actual api.GenerateResponse
  416. if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
  417. t.Fatal(err)
  418. }
  419. if actual.Model != "test" {
  420. t.Errorf("expected model test, got %s", actual.Model)
  421. }
  422. if !actual.Done {
  423. t.Errorf("expected done true, got false")
  424. }
  425. if actual.DoneReason != "load" {
  426. t.Errorf("expected done reason load, got %s", actual.DoneReason)
  427. }
  428. })
  429. checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) {
  430. t.Helper()
  431. var actual api.GenerateResponse
  432. if err := json.NewDecoder(body).Decode(&actual); err != nil {
  433. t.Fatal(err)
  434. }
  435. if actual.Model != model {
  436. t.Errorf("expected model test, got %s", actual.Model)
  437. }
  438. if !actual.Done {
  439. t.Errorf("expected done false, got true")
  440. }
  441. if actual.DoneReason != "stop" {
  442. t.Errorf("expected done reason stop, got %s", actual.DoneReason)
  443. }
  444. if actual.Response != content {
  445. t.Errorf("expected response %s, got %s", content, actual.Response)
  446. }
  447. if actual.Context == nil {
  448. t.Errorf("expected context not nil")
  449. }
  450. if actual.PromptEvalCount == 0 {
  451. t.Errorf("expected prompt eval count > 0, got 0")
  452. }
  453. if actual.PromptEvalDuration == 0 {
  454. t.Errorf("expected prompt eval duration > 0, got 0")
  455. }
  456. if actual.EvalCount == 0 {
  457. t.Errorf("expected eval count > 0, got 0")
  458. }
  459. if actual.EvalDuration == 0 {
  460. t.Errorf("expected eval duration > 0, got 0")
  461. }
  462. if actual.LoadDuration == 0 {
  463. t.Errorf("expected load duration > 0, got 0")
  464. }
  465. if actual.TotalDuration == 0 {
  466. t.Errorf("expected total duration > 0, got 0")
  467. }
  468. }
  469. mock.CompletionResponse.Content = "Hi!"
  470. t.Run("prompt", func(t *testing.T) {
  471. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  472. Model: "test",
  473. Prompt: "Hello!",
  474. Stream: &stream,
  475. })
  476. if w.Code != http.StatusOK {
  477. t.Errorf("expected status 200, got %d", w.Code)
  478. }
  479. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
  480. t.Errorf("mismatch (-got +want):\n%s", diff)
  481. }
  482. checkGenerateResponse(t, w.Body, "test", "Hi!")
  483. })
  484. w = createRequest(t, s.CreateHandler, api.CreateRequest{
  485. Model: "test-system",
  486. Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
  487. })
  488. if w.Code != http.StatusOK {
  489. t.Fatalf("expected status 200, got %d", w.Code)
  490. }
  491. t.Run("prompt with model system", func(t *testing.T) {
  492. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  493. Model: "test-system",
  494. Prompt: "Hello!",
  495. Stream: &stream,
  496. })
  497. if w.Code != http.StatusOK {
  498. t.Errorf("expected status 200, got %d", w.Code)
  499. }
  500. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
  501. t.Errorf("mismatch (-got +want):\n%s", diff)
  502. }
  503. checkGenerateResponse(t, w.Body, "test-system", "Hi!")
  504. })
  505. mock.CompletionResponse.Content = "Abra kadabra!"
  506. t.Run("prompt with system", func(t *testing.T) {
  507. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  508. Model: "test-system",
  509. Prompt: "Hello!",
  510. System: "You can perform magic tricks.",
  511. Stream: &stream,
  512. })
  513. if w.Code != http.StatusOK {
  514. t.Errorf("expected status 200, got %d", w.Code)
  515. }
  516. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
  517. t.Errorf("mismatch (-got +want):\n%s", diff)
  518. }
  519. checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
  520. })
  521. t.Run("prompt with template", func(t *testing.T) {
  522. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  523. Model: "test-system",
  524. Prompt: "Help me write tests.",
  525. System: "You can perform magic tricks.",
  526. Template: `{{- if .System }}{{ .System }} {{ end }}
  527. {{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
  528. {{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
  529. Stream: &stream,
  530. })
  531. if w.Code != http.StatusOK {
  532. t.Errorf("expected status 200, got %d", w.Code)
  533. }
  534. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
  535. t.Errorf("mismatch (-got +want):\n%s", diff)
  536. }
  537. checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
  538. })
  539. w = createRequest(t, s.CreateHandler, api.CreateRequest{
  540. Model: "test-suffix",
  541. Modelfile: `FROM test
  542. TEMPLATE """{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
  543. {{- else }}{{ .Prompt }}
  544. {{- end }}"""`,
  545. })
  546. if w.Code != http.StatusOK {
  547. t.Fatalf("expected status 200, got %d", w.Code)
  548. }
  549. t.Run("prompt with suffix", func(t *testing.T) {
  550. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  551. Model: "test-suffix",
  552. Prompt: "def add(",
  553. Suffix: " return c",
  554. })
  555. if w.Code != http.StatusOK {
  556. t.Errorf("expected status 200, got %d", w.Code)
  557. }
  558. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
  559. t.Errorf("mismatch (-got +want):\n%s", diff)
  560. }
  561. })
  562. t.Run("prompt without suffix", func(t *testing.T) {
  563. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  564. Model: "test-suffix",
  565. Prompt: "def add(",
  566. })
  567. if w.Code != http.StatusOK {
  568. t.Errorf("expected status 200, got %d", w.Code)
  569. }
  570. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
  571. t.Errorf("mismatch (-got +want):\n%s", diff)
  572. }
  573. })
  574. t.Run("raw", func(t *testing.T) {
  575. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  576. Model: "test-system",
  577. Prompt: "Help me write tests.",
  578. Raw: true,
  579. Stream: &stream,
  580. })
  581. if w.Code != http.StatusOK {
  582. t.Errorf("expected status 200, got %d", w.Code)
  583. }
  584. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
  585. t.Errorf("mismatch (-got +want):\n%s", diff)
  586. }
  587. })
  588. }