// routes_generate_test.go
  1. package server
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "strings"
  10. "testing"
  11. "time"
  12. "github.com/gin-gonic/gin"
  13. "github.com/google/go-cmp/cmp"
  14. "github.com/ollama/ollama/api"
  15. "github.com/ollama/ollama/gpu"
  16. "github.com/ollama/ollama/llm"
  17. )
// mockRunner is a test double for llm.LlamaServer. It records the most
// recent completion request it receives and replays a canned
// CompletionResponse, so tests can assert on the exact prompt the server
// built without running a real model.
type mockRunner struct {
	llm.LlamaServer

	// CompletionRequest is only valid until the next call to Completion
	llm.CompletionRequest
	llm.CompletionResponse
}
  24. func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
  25. m.CompletionRequest = r
  26. fn(m.CompletionResponse)
  27. return nil
  28. }
  29. func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
  30. for range strings.Fields(s) {
  31. tokens = append(tokens, len(tokens))
  32. }
  33. return
  34. }
  35. func newMockServer(mock *mockRunner) func(gpu.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
  36. return func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
  37. return mock, nil
  38. }
  39. }
  40. func TestGenerateChat(t *testing.T) {
  41. gin.SetMode(gin.TestMode)
  42. mock := mockRunner{
  43. CompletionResponse: llm.CompletionResponse{
  44. Done: true,
  45. DoneReason: "stop",
  46. PromptEvalCount: 1,
  47. PromptEvalDuration: 1,
  48. EvalCount: 1,
  49. EvalDuration: 1,
  50. },
  51. }
  52. s := Server{
  53. sched: &Scheduler{
  54. pendingReqCh: make(chan *LlmRequest, 1),
  55. finishedReqCh: make(chan *LlmRequest, 1),
  56. expiredCh: make(chan *runnerRef, 1),
  57. unloadedCh: make(chan any, 1),
  58. loaded: make(map[string]*runnerRef),
  59. newServerFn: newMockServer(&mock),
  60. getGpuFn: gpu.GetGPUInfo,
  61. getCpuFn: gpu.GetCPUInfo,
  62. reschedDelay: 250 * time.Millisecond,
  63. loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
  64. // add 10ms delay to simulate loading
  65. time.Sleep(10 * time.Millisecond)
  66. req.successCh <- &runnerRef{
  67. llama: &mock,
  68. }
  69. },
  70. },
  71. }
  72. go s.sched.Run(context.TODO())
  73. w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
  74. Model: "test",
  75. Modelfile: fmt.Sprintf(`FROM %s
  76. TEMPLATE """
  77. {{- if .System }}System: {{ .System }} {{ end }}
  78. {{- if .Prompt }}User: {{ .Prompt }} {{ end }}
  79. {{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
  80. `, createBinFile(t, llm.KV{
  81. "general.architecture": "llama",
  82. "llama.block_count": uint32(1),
  83. "llama.context_length": uint32(8192),
  84. "llama.embedding_length": uint32(4096),
  85. "llama.attention.head_count": uint32(32),
  86. "llama.attention.head_count_kv": uint32(8),
  87. "tokenizer.ggml.tokens": []string{""},
  88. "tokenizer.ggml.scores": []float32{0},
  89. "tokenizer.ggml.token_type": []int32{0},
  90. }, []llm.Tensor{
  91. {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  92. {Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  93. {Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  94. {Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  95. {Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  96. {Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  97. {Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  98. {Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  99. {Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  100. {Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  101. {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  102. })),
  103. Stream: &stream,
  104. })
  105. if w.Code != http.StatusOK {
  106. t.Fatalf("expected status 200, got %d", w.Code)
  107. }
  108. t.Run("missing body", func(t *testing.T) {
  109. w := createRequest(t, s.ChatHandler, nil)
  110. if w.Code != http.StatusBadRequest {
  111. t.Errorf("expected status 400, got %d", w.Code)
  112. }
  113. if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
  114. t.Errorf("mismatch (-got +want):\n%s", diff)
  115. }
  116. })
  117. t.Run("missing model", func(t *testing.T) {
  118. w := createRequest(t, s.ChatHandler, api.ChatRequest{})
  119. if w.Code != http.StatusBadRequest {
  120. t.Errorf("expected status 400, got %d", w.Code)
  121. }
  122. if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
  123. t.Errorf("mismatch (-got +want):\n%s", diff)
  124. }
  125. })
  126. t.Run("missing capabilities chat", func(t *testing.T) {
  127. w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
  128. Model: "bert",
  129. Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
  130. "general.architecture": "bert",
  131. "bert.pooling_type": uint32(0),
  132. }, []llm.Tensor{})),
  133. Stream: &stream,
  134. })
  135. if w.Code != http.StatusOK {
  136. t.Fatalf("expected status 200, got %d", w.Code)
  137. }
  138. w = createRequest(t, s.ChatHandler, api.ChatRequest{
  139. Model: "bert",
  140. })
  141. if w.Code != http.StatusBadRequest {
  142. t.Errorf("expected status 400, got %d", w.Code)
  143. }
  144. if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" {
  145. t.Errorf("mismatch (-got +want):\n%s", diff)
  146. }
  147. })
  148. t.Run("load model", func(t *testing.T) {
  149. w := createRequest(t, s.ChatHandler, api.ChatRequest{
  150. Model: "test",
  151. })
  152. if w.Code != http.StatusOK {
  153. t.Errorf("expected status 200, got %d", w.Code)
  154. }
  155. var actual api.ChatResponse
  156. if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
  157. t.Fatal(err)
  158. }
  159. if actual.Model != "test" {
  160. t.Errorf("expected model test, got %s", actual.Model)
  161. }
  162. if !actual.Done {
  163. t.Errorf("expected done true, got false")
  164. }
  165. if actual.DoneReason != "load" {
  166. t.Errorf("expected done reason load, got %s", actual.DoneReason)
  167. }
  168. })
  169. checkChatResponse := func(t *testing.T, body io.Reader, model, content string) {
  170. t.Helper()
  171. var actual api.ChatResponse
  172. if err := json.NewDecoder(body).Decode(&actual); err != nil {
  173. t.Fatal(err)
  174. }
  175. if actual.Model != model {
  176. t.Errorf("expected model test, got %s", actual.Model)
  177. }
  178. if !actual.Done {
  179. t.Errorf("expected done false, got true")
  180. }
  181. if actual.DoneReason != "stop" {
  182. t.Errorf("expected done reason stop, got %s", actual.DoneReason)
  183. }
  184. if diff := cmp.Diff(actual.Message, api.Message{
  185. Role: "assistant",
  186. Content: content,
  187. }); diff != "" {
  188. t.Errorf("mismatch (-got +want):\n%s", diff)
  189. }
  190. if actual.PromptEvalCount == 0 {
  191. t.Errorf("expected prompt eval count > 0, got 0")
  192. }
  193. if actual.PromptEvalDuration == 0 {
  194. t.Errorf("expected prompt eval duration > 0, got 0")
  195. }
  196. if actual.EvalCount == 0 {
  197. t.Errorf("expected eval count > 0, got 0")
  198. }
  199. if actual.EvalDuration == 0 {
  200. t.Errorf("expected eval duration > 0, got 0")
  201. }
  202. if actual.LoadDuration == 0 {
  203. t.Errorf("expected load duration > 0, got 0")
  204. }
  205. if actual.TotalDuration == 0 {
  206. t.Errorf("expected total duration > 0, got 0")
  207. }
  208. }
  209. mock.CompletionResponse.Content = "Hi!"
  210. t.Run("messages", func(t *testing.T) {
  211. w := createRequest(t, s.ChatHandler, api.ChatRequest{
  212. Model: "test",
  213. Messages: []api.Message{
  214. {Role: "user", Content: "Hello!"},
  215. },
  216. Stream: &stream,
  217. })
  218. if w.Code != http.StatusOK {
  219. t.Errorf("expected status 200, got %d", w.Code)
  220. }
  221. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
  222. t.Errorf("mismatch (-got +want):\n%s", diff)
  223. }
  224. checkChatResponse(t, w.Body, "test", "Hi!")
  225. })
  226. w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
  227. Model: "test-system",
  228. Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
  229. })
  230. if w.Code != http.StatusOK {
  231. t.Fatalf("expected status 200, got %d", w.Code)
  232. }
  233. t.Run("messages with model system", func(t *testing.T) {
  234. w := createRequest(t, s.ChatHandler, api.ChatRequest{
  235. Model: "test-system",
  236. Messages: []api.Message{
  237. {Role: "user", Content: "Hello!"},
  238. },
  239. Stream: &stream,
  240. })
  241. if w.Code != http.StatusOK {
  242. t.Errorf("expected status 200, got %d", w.Code)
  243. }
  244. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
  245. t.Errorf("mismatch (-got +want):\n%s", diff)
  246. }
  247. checkChatResponse(t, w.Body, "test-system", "Hi!")
  248. })
  249. mock.CompletionResponse.Content = "Abra kadabra!"
  250. t.Run("messages with system", func(t *testing.T) {
  251. w := createRequest(t, s.ChatHandler, api.ChatRequest{
  252. Model: "test-system",
  253. Messages: []api.Message{
  254. {Role: "system", Content: "You can perform magic tricks."},
  255. {Role: "user", Content: "Hello!"},
  256. },
  257. Stream: &stream,
  258. })
  259. if w.Code != http.StatusOK {
  260. t.Errorf("expected status 200, got %d", w.Code)
  261. }
  262. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
  263. t.Errorf("mismatch (-got +want):\n%s", diff)
  264. }
  265. checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
  266. })
  267. t.Run("messages with interleaved system", func(t *testing.T) {
  268. w := createRequest(t, s.ChatHandler, api.ChatRequest{
  269. Model: "test-system",
  270. Messages: []api.Message{
  271. {Role: "user", Content: "Hello!"},
  272. {Role: "assistant", Content: "I can help you with that."},
  273. {Role: "system", Content: "You can perform magic tricks."},
  274. {Role: "user", Content: "Help me write tests."},
  275. },
  276. Stream: &stream,
  277. })
  278. if w.Code != http.StatusOK {
  279. t.Errorf("expected status 200, got %d", w.Code)
  280. }
  281. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
  282. t.Errorf("mismatch (-got +want):\n%s", diff)
  283. }
  284. checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
  285. })
  286. }
  287. func TestGenerate(t *testing.T) {
  288. gin.SetMode(gin.TestMode)
  289. mock := mockRunner{
  290. CompletionResponse: llm.CompletionResponse{
  291. Done: true,
  292. DoneReason: "stop",
  293. PromptEvalCount: 1,
  294. PromptEvalDuration: 1,
  295. EvalCount: 1,
  296. EvalDuration: 1,
  297. },
  298. }
  299. s := Server{
  300. sched: &Scheduler{
  301. pendingReqCh: make(chan *LlmRequest, 1),
  302. finishedReqCh: make(chan *LlmRequest, 1),
  303. expiredCh: make(chan *runnerRef, 1),
  304. unloadedCh: make(chan any, 1),
  305. loaded: make(map[string]*runnerRef),
  306. newServerFn: newMockServer(&mock),
  307. getGpuFn: gpu.GetGPUInfo,
  308. getCpuFn: gpu.GetCPUInfo,
  309. reschedDelay: 250 * time.Millisecond,
  310. loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
  311. req.successCh <- &runnerRef{
  312. llama: &mock,
  313. }
  314. },
  315. },
  316. }
  317. go s.sched.Run(context.TODO())
  318. w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
  319. Model: "test",
  320. Modelfile: fmt.Sprintf(`FROM %s
  321. TEMPLATE """
  322. {{- if .System }}System: {{ .System }} {{ end }}
  323. {{- if .Prompt }}User: {{ .Prompt }} {{ end }}
  324. {{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
  325. `, createBinFile(t, llm.KV{
  326. "general.architecture": "llama",
  327. "llama.block_count": uint32(1),
  328. "llama.context_length": uint32(8192),
  329. "llama.embedding_length": uint32(4096),
  330. "llama.attention.head_count": uint32(32),
  331. "llama.attention.head_count_kv": uint32(8),
  332. "tokenizer.ggml.tokens": []string{""},
  333. "tokenizer.ggml.scores": []float32{0},
  334. "tokenizer.ggml.token_type": []int32{0},
  335. }, []llm.Tensor{
  336. {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  337. {Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  338. {Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  339. {Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  340. {Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  341. {Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  342. {Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  343. {Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  344. {Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  345. {Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  346. {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  347. })),
  348. Stream: &stream,
  349. })
  350. if w.Code != http.StatusOK {
  351. t.Fatalf("expected status 200, got %d", w.Code)
  352. }
  353. t.Run("missing body", func(t *testing.T) {
  354. w := createRequest(t, s.GenerateHandler, nil)
  355. if w.Code != http.StatusBadRequest {
  356. t.Errorf("expected status 400, got %d", w.Code)
  357. }
  358. if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
  359. t.Errorf("mismatch (-got +want):\n%s", diff)
  360. }
  361. })
  362. t.Run("missing model", func(t *testing.T) {
  363. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
  364. if w.Code != http.StatusBadRequest {
  365. t.Errorf("expected status 400, got %d", w.Code)
  366. }
  367. if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
  368. t.Errorf("mismatch (-got +want):\n%s", diff)
  369. }
  370. })
  371. t.Run("missing capabilities generate", func(t *testing.T) {
  372. w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
  373. Model: "bert",
  374. Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
  375. "general.architecture": "bert",
  376. "bert.pooling_type": uint32(0),
  377. }, []llm.Tensor{})),
  378. Stream: &stream,
  379. })
  380. if w.Code != http.StatusOK {
  381. t.Fatalf("expected status 200, got %d", w.Code)
  382. }
  383. w = createRequest(t, s.GenerateHandler, api.GenerateRequest{
  384. Model: "bert",
  385. })
  386. if w.Code != http.StatusBadRequest {
  387. t.Errorf("expected status 400, got %d", w.Code)
  388. }
  389. if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" {
  390. t.Errorf("mismatch (-got +want):\n%s", diff)
  391. }
  392. })
  393. t.Run("missing capabilities suffix", func(t *testing.T) {
  394. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  395. Model: "test",
  396. Prompt: "def add(",
  397. Suffix: " return c",
  398. })
  399. if w.Code != http.StatusBadRequest {
  400. t.Errorf("expected status 400, got %d", w.Code)
  401. }
  402. if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" {
  403. t.Errorf("mismatch (-got +want):\n%s", diff)
  404. }
  405. })
  406. t.Run("load model", func(t *testing.T) {
  407. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  408. Model: "test",
  409. })
  410. if w.Code != http.StatusOK {
  411. t.Errorf("expected status 200, got %d", w.Code)
  412. }
  413. var actual api.GenerateResponse
  414. if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
  415. t.Fatal(err)
  416. }
  417. if actual.Model != "test" {
  418. t.Errorf("expected model test, got %s", actual.Model)
  419. }
  420. if !actual.Done {
  421. t.Errorf("expected done true, got false")
  422. }
  423. if actual.DoneReason != "load" {
  424. t.Errorf("expected done reason load, got %s", actual.DoneReason)
  425. }
  426. })
  427. checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) {
  428. t.Helper()
  429. var actual api.GenerateResponse
  430. if err := json.NewDecoder(body).Decode(&actual); err != nil {
  431. t.Fatal(err)
  432. }
  433. if actual.Model != model {
  434. t.Errorf("expected model test, got %s", actual.Model)
  435. }
  436. if !actual.Done {
  437. t.Errorf("expected done false, got true")
  438. }
  439. if actual.DoneReason != "stop" {
  440. t.Errorf("expected done reason stop, got %s", actual.DoneReason)
  441. }
  442. if actual.Response != content {
  443. t.Errorf("expected response %s, got %s", content, actual.Response)
  444. }
  445. if actual.Context == nil {
  446. t.Errorf("expected context not nil")
  447. }
  448. if actual.PromptEvalCount == 0 {
  449. t.Errorf("expected prompt eval count > 0, got 0")
  450. }
  451. if actual.PromptEvalDuration == 0 {
  452. t.Errorf("expected prompt eval duration > 0, got 0")
  453. }
  454. if actual.EvalCount == 0 {
  455. t.Errorf("expected eval count > 0, got 0")
  456. }
  457. if actual.EvalDuration == 0 {
  458. t.Errorf("expected eval duration > 0, got 0")
  459. }
  460. if actual.LoadDuration == 0 {
  461. t.Errorf("expected load duration > 0, got 0")
  462. }
  463. if actual.TotalDuration == 0 {
  464. t.Errorf("expected total duration > 0, got 0")
  465. }
  466. }
  467. mock.CompletionResponse.Content = "Hi!"
  468. t.Run("prompt", func(t *testing.T) {
  469. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  470. Model: "test",
  471. Prompt: "Hello!",
  472. Stream: &stream,
  473. })
  474. if w.Code != http.StatusOK {
  475. t.Errorf("expected status 200, got %d", w.Code)
  476. }
  477. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
  478. t.Errorf("mismatch (-got +want):\n%s", diff)
  479. }
  480. checkGenerateResponse(t, w.Body, "test", "Hi!")
  481. })
  482. w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
  483. Model: "test-system",
  484. Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
  485. })
  486. if w.Code != http.StatusOK {
  487. t.Fatalf("expected status 200, got %d", w.Code)
  488. }
  489. t.Run("prompt with model system", func(t *testing.T) {
  490. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  491. Model: "test-system",
  492. Prompt: "Hello!",
  493. Stream: &stream,
  494. })
  495. if w.Code != http.StatusOK {
  496. t.Errorf("expected status 200, got %d", w.Code)
  497. }
  498. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
  499. t.Errorf("mismatch (-got +want):\n%s", diff)
  500. }
  501. checkGenerateResponse(t, w.Body, "test-system", "Hi!")
  502. })
  503. mock.CompletionResponse.Content = "Abra kadabra!"
  504. t.Run("prompt with system", func(t *testing.T) {
  505. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  506. Model: "test-system",
  507. Prompt: "Hello!",
  508. System: "You can perform magic tricks.",
  509. Stream: &stream,
  510. })
  511. if w.Code != http.StatusOK {
  512. t.Errorf("expected status 200, got %d", w.Code)
  513. }
  514. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
  515. t.Errorf("mismatch (-got +want):\n%s", diff)
  516. }
  517. checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
  518. })
  519. t.Run("prompt with template", func(t *testing.T) {
  520. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  521. Model: "test-system",
  522. Prompt: "Help me write tests.",
  523. System: "You can perform magic tricks.",
  524. Template: `{{- if .System }}{{ .System }} {{ end }}
  525. {{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
  526. {{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
  527. Stream: &stream,
  528. })
  529. if w.Code != http.StatusOK {
  530. t.Errorf("expected status 200, got %d", w.Code)
  531. }
  532. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
  533. t.Errorf("mismatch (-got +want):\n%s", diff)
  534. }
  535. checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
  536. })
  537. w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
  538. Model: "test-suffix",
  539. Modelfile: `FROM test
  540. TEMPLATE """{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
  541. {{- else }}{{ .Prompt }}
  542. {{- end }}"""`,
  543. })
  544. if w.Code != http.StatusOK {
  545. t.Fatalf("expected status 200, got %d", w.Code)
  546. }
  547. t.Run("prompt with suffix", func(t *testing.T) {
  548. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  549. Model: "test-suffix",
  550. Prompt: "def add(",
  551. Suffix: " return c",
  552. })
  553. if w.Code != http.StatusOK {
  554. t.Errorf("expected status 200, got %d", w.Code)
  555. }
  556. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
  557. t.Errorf("mismatch (-got +want):\n%s", diff)
  558. }
  559. })
  560. t.Run("prompt without suffix", func(t *testing.T) {
  561. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  562. Model: "test-suffix",
  563. Prompt: "def add(",
  564. })
  565. if w.Code != http.StatusOK {
  566. t.Errorf("expected status 200, got %d", w.Code)
  567. }
  568. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
  569. t.Errorf("mismatch (-got +want):\n%s", diff)
  570. }
  571. })
  572. t.Run("raw", func(t *testing.T) {
  573. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  574. Model: "test-system",
  575. Prompt: "Help me write tests.",
  576. Raw: true,
  577. Stream: &stream,
  578. })
  579. if w.Code != http.StatusOK {
  580. t.Errorf("expected status 200, got %d", w.Code)
  581. }
  582. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
  583. t.Errorf("mismatch (-got +want):\n%s", diff)
  584. }
  585. })
  586. }