routes_generate_test.go

package server

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/go-cmp/cmp"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/gpu"
	"github.com/ollama/ollama/llm"
)
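
// mockRunner stands in for llm.LlamaServer in these handler tests: it records
// the most recent CompletionRequest and replays a canned CompletionResponse.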
type mockRunner struct {
	llm.LlamaServer

	// CompletionRequest is only valid until the next call to Completion
	llm.CompletionRequest
	llm.CompletionResponse
}

func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
	m.CompletionRequest = r
	fn(m.CompletionResponse)
	return nil
}

func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
	for range strings.Fields(s) {
		tokens = append(tokens, len(tokens))
	}

	return
}
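
// newMockServer returns a scheduler newServerFn that always hands back the
// shared mock runner instead of starting a real llama server.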
func newMockServer(mock *mockRunner) func(gpu.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
	return func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
		return mock, nil
	}
}
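
// TestGenerateChat exercises the /api/chat handler against a scheduler wired
// to the mock runner. createRequest, createBinFile, and stream are assumed to
// be test helpers and fixtures defined elsewhere in this package.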
func TestGenerateChat(t *testing.T) {
	gin.SetMode(gin.TestMode)

	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         "stop",
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mock),
			getGpuFn:      gpu.GetGPUInfo,
			getCpuFn:      gpu.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
				req.successCh <- &runnerRef{
					llama: &mock,
				}
			},
		},
	}

	go s.sched.Run(context.TODO())

	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
		Name: "test",
		Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """
{{- if .System }}System: {{ .System }} {{ end }}
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
`, createBinFile(t, llm.KV{
			"general.architecture":          "llama",
			"llama.block_count":             uint32(1),
			"llama.context_length":          uint32(8192),
			"llama.embedding_length":        uint32(4096),
			"llama.attention.head_count":    uint32(32),
			"llama.attention.head_count_kv": uint32(8),
			"tokenizer.ggml.tokens":         []string{""},
			"tokenizer.ggml.scores":         []float32{0},
			"tokenizer.ggml.token_type":     []int32{0},
		}, []llm.Tensor{
			{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		})),
		Stream: &stream,
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	t.Run("missing body", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, nil)
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("missing model", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{})
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("missing capabilities", func(t *testing.T) {
		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
			Name: "bert",
			Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
				"general.architecture": "bert",
				"bert.pooling_type":    uint32(0),
			}, []llm.Tensor{})),
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Fatalf("expected status 200, got %d", w.Code)
		}

		w = createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "bert",
		})

		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("load model", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test",
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		var actual api.ChatResponse
		if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
			t.Fatal(err)
		}

		if actual.Model != "test" {
			t.Errorf("expected model test, got %s", actual.Model)
		}

		if !actual.Done {
			t.Errorf("expected done true, got false")
		}

		if actual.DoneReason != "load" {
			t.Errorf("expected done reason load, got %s", actual.DoneReason)
		}
	})
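
	// checkChatResponse decodes a single (non-streaming) chat response body and
	// verifies the assistant message along with the timing and eval counters.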
	checkChatResponse := func(t *testing.T, body io.Reader, model, content string) {
		t.Helper()

		var actual api.ChatResponse
		if err := json.NewDecoder(body).Decode(&actual); err != nil {
			t.Fatal(err)
		}

		if actual.Model != model {
			t.Errorf("expected model %s, got %s", model, actual.Model)
		}

		if !actual.Done {
			t.Errorf("expected done true, got false")
		}

		if actual.DoneReason != "stop" {
			t.Errorf("expected done reason stop, got %s", actual.DoneReason)
		}

		if diff := cmp.Diff(actual.Message, api.Message{
			Role:    "assistant",
			Content: content,
		}); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		if actual.PromptEvalCount == 0 {
			t.Errorf("expected prompt eval count > 0, got 0")
		}

		if actual.PromptEvalDuration == 0 {
			t.Errorf("expected prompt eval duration > 0, got 0")
		}

		if actual.EvalCount == 0 {
			t.Errorf("expected eval count > 0, got 0")
		}

		if actual.EvalDuration == 0 {
			t.Errorf("expected eval duration > 0, got 0")
		}

		if actual.LoadDuration == 0 {
			t.Errorf("expected load duration > 0, got 0")
		}

		if actual.TotalDuration == 0 {
			t.Errorf("expected total duration > 0, got 0")
		}
	}
	mock.CompletionResponse.Content = "Hi!"
	t.Run("messages", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test",
			Messages: []api.Message{
				{Role: "user", Content: "Hello!"},
			},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkChatResponse(t, w.Body, "test", "Hi!")
	})

	w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
		Model:     "test-system",
		Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	t.Run("messages with model system", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test-system",
			Messages: []api.Message{
				{Role: "user", Content: "Hello!"},
			},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkChatResponse(t, w.Body, "test-system", "Hi!")
	})

	mock.CompletionResponse.Content = "Abra kadabra!"
	t.Run("messages with system", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test-system",
			Messages: []api.Message{
				{Role: "system", Content: "You can perform magic tricks."},
				{Role: "user", Content: "Hello!"},
			},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
	})

	t.Run("messages with interleaved system", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test-system",
			Messages: []api.Message{
				{Role: "user", Content: "Hello!"},
				{Role: "assistant", Content: "I can help you with that."},
				{Role: "system", Content: "You can perform magic tricks."},
				{Role: "user", Content: "Help me write tests."},
			},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
	})
}
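
// TestGenerate exercises the /api/generate handler with the same mock runner
// setup, covering prompt, system, template, and raw request variants.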
func TestGenerate(t *testing.T) {
	gin.SetMode(gin.TestMode)

	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         "stop",
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mock),
			getGpuFn:      gpu.GetGPUInfo,
			getCpuFn:      gpu.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
				req.successCh <- &runnerRef{
					llama: &mock,
				}
			},
		},
	}

	go s.sched.Run(context.TODO())

	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
		Name: "test",
		Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """
{{- if .System }}System: {{ .System }} {{ end }}
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
`, createBinFile(t, llm.KV{
			"general.architecture":          "llama",
			"llama.block_count":             uint32(1),
			"llama.context_length":          uint32(8192),
			"llama.embedding_length":        uint32(4096),
			"llama.attention.head_count":    uint32(32),
			"llama.attention.head_count_kv": uint32(8),
			"tokenizer.ggml.tokens":         []string{""},
			"tokenizer.ggml.scores":         []float32{0},
			"tokenizer.ggml.token_type":     []int32{0},
		}, []llm.Tensor{
			{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		})),
		Stream: &stream,
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	t.Run("missing body", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, nil)
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("missing model", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("missing capabilities", func(t *testing.T) {
		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
			Name: "bert",
			Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
				"general.architecture": "bert",
				"bert.pooling_type":    uint32(0),
			}, []llm.Tensor{})),
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Fatalf("expected status 200, got %d", w.Code)
		}

		w = createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model: "bert",
		})

		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("load model", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model: "test",
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		var actual api.GenerateResponse
		if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
			t.Fatal(err)
		}

		if actual.Model != "test" {
			t.Errorf("expected model test, got %s", actual.Model)
		}

		if !actual.Done {
			t.Errorf("expected done true, got false")
		}

		if actual.DoneReason != "load" {
			t.Errorf("expected done reason load, got %s", actual.DoneReason)
		}
	})
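
	// checkGenerateResponse decodes a single (non-streaming) generate response
	// body and verifies the response text, context, and timing and eval counters.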
	checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) {
		t.Helper()

		var actual api.GenerateResponse
		if err := json.NewDecoder(body).Decode(&actual); err != nil {
			t.Fatal(err)
		}

		if actual.Model != model {
			t.Errorf("expected model %s, got %s", model, actual.Model)
		}

		if !actual.Done {
			t.Errorf("expected done true, got false")
		}

		if actual.DoneReason != "stop" {
			t.Errorf("expected done reason stop, got %s", actual.DoneReason)
		}

		if actual.Response != content {
			t.Errorf("expected response %s, got %s", content, actual.Response)
		}

		if actual.Context == nil {
			t.Errorf("expected context not nil")
		}

		if actual.PromptEvalCount == 0 {
			t.Errorf("expected prompt eval count > 0, got 0")
		}

		if actual.PromptEvalDuration == 0 {
			t.Errorf("expected prompt eval duration > 0, got 0")
		}

		if actual.EvalCount == 0 {
			t.Errorf("expected eval count > 0, got 0")
		}

		if actual.EvalDuration == 0 {
			t.Errorf("expected eval duration > 0, got 0")
		}

		if actual.LoadDuration == 0 {
			t.Errorf("expected load duration > 0, got 0")
		}

		if actual.TotalDuration == 0 {
			t.Errorf("expected total duration > 0, got 0")
		}
	}
	mock.CompletionResponse.Content = "Hi!"
	t.Run("prompt", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test",
			Prompt: "Hello!",
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkGenerateResponse(t, w.Body, "test", "Hi!")
	})

	w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
		Model:     "test-system",
		Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	t.Run("prompt with model system", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-system",
			Prompt: "Hello!",
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkGenerateResponse(t, w.Body, "test-system", "Hi!")
	})

	mock.CompletionResponse.Content = "Abra kadabra!"
	t.Run("prompt with system", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-system",
			Prompt: "Hello!",
			System: "You can perform magic tricks.",
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
	})

	t.Run("prompt with template", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-system",
			Prompt: "Help me write tests.",
			System: "You can perform magic tricks.",
			Template: `{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
	})

	t.Run("raw", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-system",
			Prompt: "Help me write tests.",
			Raw:    true,
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})
}