// routes_generate_test.go
  1. package server
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "strings"
  10. "testing"
  11. "time"
  12. "github.com/gin-gonic/gin"
  13. "github.com/google/go-cmp/cmp"
  14. "github.com/ollama/ollama/api"
  15. "github.com/ollama/ollama/discover"
  16. "github.com/ollama/ollama/fileutils"
  17. "github.com/ollama/ollama/runners"
  18. )
// mockRunner is a test double for runners.LLMServer. It records the most
// recent completion request and replays a canned completion response.
type mockRunner struct {
	runners.LLMServer

	// CompletionRequest is only valid until the next call to Completion
	runners.CompletionRequest

	// CompletionResponse is returned verbatim on every Completion call.
	runners.CompletionResponse
}
  25. func (m *mockRunner) Completion(_ context.Context, r runners.CompletionRequest, fn func(r runners.CompletionResponse)) error {
  26. m.CompletionRequest = r
  27. fn(m.CompletionResponse)
  28. return nil
  29. }
  30. func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
  31. for range strings.Fields(s) {
  32. tokens = append(tokens, len(tokens))
  33. }
  34. return
  35. }
  36. func newMockServer(mock *mockRunner) func(discover.GpuInfoList, string, *fileutils.GGML, []string, []string, api.Options, int) (runners.LLMServer, error) {
  37. return func(gpus discover.GpuInfoList, model string, ggml *fileutils.GGML, projectors, system []string, opts api.Options, numParallel int) (runners.LLMServer, error) {
  38. return mock, nil
  39. }
  40. }
  41. func TestGenerateChat(t *testing.T) {
  42. gin.SetMode(gin.TestMode)
  43. mock := mockRunner{
  44. CompletionResponse: runners.CompletionResponse{
  45. Done: true,
  46. DoneReason: "stop",
  47. PromptEvalCount: 1,
  48. PromptEvalDuration: 1,
  49. EvalCount: 1,
  50. EvalDuration: 1,
  51. },
  52. }
  53. s := Server{
  54. sched: &Scheduler{
  55. pendingReqCh: make(chan *LlmRequest, 1),
  56. finishedReqCh: make(chan *LlmRequest, 1),
  57. expiredCh: make(chan *runnerRef, 1),
  58. unloadedCh: make(chan any, 1),
  59. loaded: make(map[string]*runnerRef),
  60. newServerFn: newMockServer(&mock),
  61. getGpuFn: discover.GetGPUInfo,
  62. getCpuFn: discover.GetCPUInfo,
  63. reschedDelay: 250 * time.Millisecond,
  64. loadFn: func(req *LlmRequest, ggml *fileutils.GGML, gpus discover.GpuInfoList, numParallel int) {
  65. // add small delay to simulate loading
  66. time.Sleep(time.Millisecond)
  67. req.successCh <- &runnerRef{
  68. llama: &mock,
  69. }
  70. },
  71. },
  72. }
  73. go s.sched.Run(context.TODO())
  74. w := createRequest(t, s.CreateHandler, api.CreateRequest{
  75. Model: "test",
  76. Modelfile: fmt.Sprintf(`FROM %s
  77. TEMPLATE """
  78. {{- if .System }}System: {{ .System }} {{ end }}
  79. {{- if .Prompt }}User: {{ .Prompt }} {{ end }}
  80. {{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
  81. `, createBinFile(t, fileutils.KV{
  82. "general.architecture": "llama",
  83. "llama.block_count": uint32(1),
  84. "llama.context_length": uint32(8192),
  85. "llama.embedding_length": uint32(4096),
  86. "llama.attention.head_count": uint32(32),
  87. "llama.attention.head_count_kv": uint32(8),
  88. "tokenizer.ggml.tokens": []string{""},
  89. "tokenizer.ggml.scores": []float32{0},
  90. "tokenizer.ggml.token_type": []int32{0},
  91. }, []fileutils.Tensor{
  92. {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  93. {Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  94. {Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  95. {Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  96. {Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  97. {Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  98. {Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  99. {Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  100. {Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  101. {Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  102. {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  103. })),
  104. Stream: &stream,
  105. })
  106. if w.Code != http.StatusOK {
  107. t.Fatalf("expected status 200, got %d", w.Code)
  108. }
  109. t.Run("missing body", func(t *testing.T) {
  110. w := createRequest(t, s.ChatHandler, nil)
  111. if w.Code != http.StatusBadRequest {
  112. t.Errorf("expected status 400, got %d", w.Code)
  113. }
  114. if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
  115. t.Errorf("mismatch (-got +want):\n%s", diff)
  116. }
  117. })
  118. t.Run("missing model", func(t *testing.T) {
  119. w := createRequest(t, s.ChatHandler, api.ChatRequest{})
  120. if w.Code != http.StatusBadRequest {
  121. t.Errorf("expected status 400, got %d", w.Code)
  122. }
  123. if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
  124. t.Errorf("mismatch (-got +want):\n%s", diff)
  125. }
  126. })
  127. t.Run("missing capabilities chat", func(t *testing.T) {
  128. w := createRequest(t, s.CreateHandler, api.CreateRequest{
  129. Model: "bert",
  130. Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, fileutils.KV{
  131. "general.architecture": "bert",
  132. "bert.pooling_type": uint32(0),
  133. }, []fileutils.Tensor{})),
  134. Stream: &stream,
  135. })
  136. if w.Code != http.StatusOK {
  137. t.Fatalf("expected status 200, got %d", w.Code)
  138. }
  139. w = createRequest(t, s.ChatHandler, api.ChatRequest{
  140. Model: "bert",
  141. })
  142. if w.Code != http.StatusBadRequest {
  143. t.Errorf("expected status 400, got %d", w.Code)
  144. }
  145. if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" {
  146. t.Errorf("mismatch (-got +want):\n%s", diff)
  147. }
  148. })
  149. t.Run("load model", func(t *testing.T) {
  150. w := createRequest(t, s.ChatHandler, api.ChatRequest{
  151. Model: "test",
  152. })
  153. if w.Code != http.StatusOK {
  154. t.Errorf("expected status 200, got %d", w.Code)
  155. }
  156. var actual api.ChatResponse
  157. if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
  158. t.Fatal(err)
  159. }
  160. if actual.Model != "test" {
  161. t.Errorf("expected model test, got %s", actual.Model)
  162. }
  163. if !actual.Done {
  164. t.Errorf("expected done true, got false")
  165. }
  166. if actual.DoneReason != "load" {
  167. t.Errorf("expected done reason load, got %s", actual.DoneReason)
  168. }
  169. })
  170. checkChatResponse := func(t *testing.T, body io.Reader, model, content string) {
  171. t.Helper()
  172. var actual api.ChatResponse
  173. if err := json.NewDecoder(body).Decode(&actual); err != nil {
  174. t.Fatal(err)
  175. }
  176. if actual.Model != model {
  177. t.Errorf("expected model test, got %s", actual.Model)
  178. }
  179. if !actual.Done {
  180. t.Errorf("expected done false, got true")
  181. }
  182. if actual.DoneReason != "stop" {
  183. t.Errorf("expected done reason stop, got %s", actual.DoneReason)
  184. }
  185. if diff := cmp.Diff(actual.Message, api.Message{
  186. Role: "assistant",
  187. Content: content,
  188. }); diff != "" {
  189. t.Errorf("mismatch (-got +want):\n%s", diff)
  190. }
  191. if actual.PromptEvalCount == 0 {
  192. t.Errorf("expected prompt eval count > 0, got 0")
  193. }
  194. if actual.PromptEvalDuration == 0 {
  195. t.Errorf("expected prompt eval duration > 0, got 0")
  196. }
  197. if actual.EvalCount == 0 {
  198. t.Errorf("expected eval count > 0, got 0")
  199. }
  200. if actual.EvalDuration == 0 {
  201. t.Errorf("expected eval duration > 0, got 0")
  202. }
  203. if actual.LoadDuration == 0 {
  204. t.Errorf("expected load duration > 0, got 0")
  205. }
  206. if actual.TotalDuration == 0 {
  207. t.Errorf("expected total duration > 0, got 0")
  208. }
  209. }
  210. mock.CompletionResponse.Content = "Hi!"
  211. t.Run("messages", func(t *testing.T) {
  212. w := createRequest(t, s.ChatHandler, api.ChatRequest{
  213. Model: "test",
  214. Messages: []api.Message{
  215. {Role: "user", Content: "Hello!"},
  216. },
  217. Stream: &stream,
  218. })
  219. if w.Code != http.StatusOK {
  220. t.Errorf("expected status 200, got %d", w.Code)
  221. }
  222. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
  223. t.Errorf("mismatch (-got +want):\n%s", diff)
  224. }
  225. checkChatResponse(t, w.Body, "test", "Hi!")
  226. })
  227. w = createRequest(t, s.CreateHandler, api.CreateRequest{
  228. Model: "test-system",
  229. Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
  230. })
  231. if w.Code != http.StatusOK {
  232. t.Fatalf("expected status 200, got %d", w.Code)
  233. }
  234. t.Run("messages with model system", func(t *testing.T) {
  235. w := createRequest(t, s.ChatHandler, api.ChatRequest{
  236. Model: "test-system",
  237. Messages: []api.Message{
  238. {Role: "user", Content: "Hello!"},
  239. },
  240. Stream: &stream,
  241. })
  242. if w.Code != http.StatusOK {
  243. t.Errorf("expected status 200, got %d", w.Code)
  244. }
  245. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
  246. t.Errorf("mismatch (-got +want):\n%s", diff)
  247. }
  248. checkChatResponse(t, w.Body, "test-system", "Hi!")
  249. })
  250. mock.CompletionResponse.Content = "Abra kadabra!"
  251. t.Run("messages with system", func(t *testing.T) {
  252. w := createRequest(t, s.ChatHandler, api.ChatRequest{
  253. Model: "test-system",
  254. Messages: []api.Message{
  255. {Role: "system", Content: "You can perform magic tricks."},
  256. {Role: "user", Content: "Hello!"},
  257. },
  258. Stream: &stream,
  259. })
  260. if w.Code != http.StatusOK {
  261. t.Errorf("expected status 200, got %d", w.Code)
  262. }
  263. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
  264. t.Errorf("mismatch (-got +want):\n%s", diff)
  265. }
  266. checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
  267. })
  268. t.Run("messages with interleaved system", func(t *testing.T) {
  269. w := createRequest(t, s.ChatHandler, api.ChatRequest{
  270. Model: "test-system",
  271. Messages: []api.Message{
  272. {Role: "user", Content: "Hello!"},
  273. {Role: "assistant", Content: "I can help you with that."},
  274. {Role: "system", Content: "You can perform magic tricks."},
  275. {Role: "user", Content: "Help me write tests."},
  276. },
  277. Stream: &stream,
  278. })
  279. if w.Code != http.StatusOK {
  280. t.Errorf("expected status 200, got %d", w.Code)
  281. }
  282. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
  283. t.Errorf("mismatch (-got +want):\n%s", diff)
  284. }
  285. checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
  286. })
  287. }
  288. func TestGenerate(t *testing.T) {
  289. gin.SetMode(gin.TestMode)
  290. mock := mockRunner{
  291. CompletionResponse: runners.CompletionResponse{
  292. Done: true,
  293. DoneReason: "stop",
  294. PromptEvalCount: 1,
  295. PromptEvalDuration: 1,
  296. EvalCount: 1,
  297. EvalDuration: 1,
  298. },
  299. }
  300. s := Server{
  301. sched: &Scheduler{
  302. pendingReqCh: make(chan *LlmRequest, 1),
  303. finishedReqCh: make(chan *LlmRequest, 1),
  304. expiredCh: make(chan *runnerRef, 1),
  305. unloadedCh: make(chan any, 1),
  306. loaded: make(map[string]*runnerRef),
  307. newServerFn: newMockServer(&mock),
  308. getGpuFn: discover.GetGPUInfo,
  309. getCpuFn: discover.GetCPUInfo,
  310. reschedDelay: 250 * time.Millisecond,
  311. loadFn: func(req *LlmRequest, ggml *fileutils.GGML, gpus discover.GpuInfoList, numParallel int) {
  312. // add small delay to simulate loading
  313. time.Sleep(time.Millisecond)
  314. req.successCh <- &runnerRef{
  315. llama: &mock,
  316. }
  317. },
  318. },
  319. }
  320. go s.sched.Run(context.TODO())
  321. w := createRequest(t, s.CreateHandler, api.CreateRequest{
  322. Model: "test",
  323. Modelfile: fmt.Sprintf(`FROM %s
  324. TEMPLATE """
  325. {{- if .System }}System: {{ .System }} {{ end }}
  326. {{- if .Prompt }}User: {{ .Prompt }} {{ end }}
  327. {{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
  328. `, createBinFile(t, fileutils.KV{
  329. "general.architecture": "llama",
  330. "llama.block_count": uint32(1),
  331. "llama.context_length": uint32(8192),
  332. "llama.embedding_length": uint32(4096),
  333. "llama.attention.head_count": uint32(32),
  334. "llama.attention.head_count_kv": uint32(8),
  335. "tokenizer.ggml.tokens": []string{""},
  336. "tokenizer.ggml.scores": []float32{0},
  337. "tokenizer.ggml.token_type": []int32{0},
  338. }, []fileutils.Tensor{
  339. {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  340. {Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  341. {Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  342. {Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  343. {Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  344. {Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  345. {Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  346. {Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  347. {Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  348. {Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  349. {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
  350. })),
  351. Stream: &stream,
  352. })
  353. if w.Code != http.StatusOK {
  354. t.Fatalf("expected status 200, got %d", w.Code)
  355. }
  356. t.Run("missing body", func(t *testing.T) {
  357. w := createRequest(t, s.GenerateHandler, nil)
  358. if w.Code != http.StatusNotFound {
  359. t.Errorf("expected status 404, got %d", w.Code)
  360. }
  361. if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
  362. t.Errorf("mismatch (-got +want):\n%s", diff)
  363. }
  364. })
  365. t.Run("missing model", func(t *testing.T) {
  366. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
  367. if w.Code != http.StatusNotFound {
  368. t.Errorf("expected status 404, got %d", w.Code)
  369. }
  370. if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
  371. t.Errorf("mismatch (-got +want):\n%s", diff)
  372. }
  373. })
  374. t.Run("missing capabilities generate", func(t *testing.T) {
  375. w := createRequest(t, s.CreateHandler, api.CreateRequest{
  376. Model: "bert",
  377. Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, fileutils.KV{
  378. "general.architecture": "bert",
  379. "bert.pooling_type": uint32(0),
  380. }, []fileutils.Tensor{})),
  381. Stream: &stream,
  382. })
  383. if w.Code != http.StatusOK {
  384. t.Fatalf("expected status 200, got %d", w.Code)
  385. }
  386. w = createRequest(t, s.GenerateHandler, api.GenerateRequest{
  387. Model: "bert",
  388. })
  389. if w.Code != http.StatusBadRequest {
  390. t.Errorf("expected status 400, got %d", w.Code)
  391. }
  392. if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" {
  393. t.Errorf("mismatch (-got +want):\n%s", diff)
  394. }
  395. })
  396. t.Run("missing capabilities suffix", func(t *testing.T) {
  397. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  398. Model: "test",
  399. Prompt: "def add(",
  400. Suffix: " return c",
  401. })
  402. if w.Code != http.StatusBadRequest {
  403. t.Errorf("expected status 400, got %d", w.Code)
  404. }
  405. if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" {
  406. t.Errorf("mismatch (-got +want):\n%s", diff)
  407. }
  408. })
  409. t.Run("load model", func(t *testing.T) {
  410. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  411. Model: "test",
  412. })
  413. if w.Code != http.StatusOK {
  414. t.Errorf("expected status 200, got %d", w.Code)
  415. }
  416. var actual api.GenerateResponse
  417. if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
  418. t.Fatal(err)
  419. }
  420. if actual.Model != "test" {
  421. t.Errorf("expected model test, got %s", actual.Model)
  422. }
  423. if !actual.Done {
  424. t.Errorf("expected done true, got false")
  425. }
  426. if actual.DoneReason != "load" {
  427. t.Errorf("expected done reason load, got %s", actual.DoneReason)
  428. }
  429. })
  430. checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) {
  431. t.Helper()
  432. var actual api.GenerateResponse
  433. if err := json.NewDecoder(body).Decode(&actual); err != nil {
  434. t.Fatal(err)
  435. }
  436. if actual.Model != model {
  437. t.Errorf("expected model test, got %s", actual.Model)
  438. }
  439. if !actual.Done {
  440. t.Errorf("expected done false, got true")
  441. }
  442. if actual.DoneReason != "stop" {
  443. t.Errorf("expected done reason stop, got %s", actual.DoneReason)
  444. }
  445. if actual.Response != content {
  446. t.Errorf("expected response %s, got %s", content, actual.Response)
  447. }
  448. if actual.Context == nil {
  449. t.Errorf("expected context not nil")
  450. }
  451. if actual.PromptEvalCount == 0 {
  452. t.Errorf("expected prompt eval count > 0, got 0")
  453. }
  454. if actual.PromptEvalDuration == 0 {
  455. t.Errorf("expected prompt eval duration > 0, got 0")
  456. }
  457. if actual.EvalCount == 0 {
  458. t.Errorf("expected eval count > 0, got 0")
  459. }
  460. if actual.EvalDuration == 0 {
  461. t.Errorf("expected eval duration > 0, got 0")
  462. }
  463. if actual.LoadDuration == 0 {
  464. t.Errorf("expected load duration > 0, got 0")
  465. }
  466. if actual.TotalDuration == 0 {
  467. t.Errorf("expected total duration > 0, got 0")
  468. }
  469. }
  470. mock.CompletionResponse.Content = "Hi!"
  471. t.Run("prompt", func(t *testing.T) {
  472. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  473. Model: "test",
  474. Prompt: "Hello!",
  475. Stream: &stream,
  476. })
  477. if w.Code != http.StatusOK {
  478. t.Errorf("expected status 200, got %d", w.Code)
  479. }
  480. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
  481. t.Errorf("mismatch (-got +want):\n%s", diff)
  482. }
  483. checkGenerateResponse(t, w.Body, "test", "Hi!")
  484. })
  485. w = createRequest(t, s.CreateHandler, api.CreateRequest{
  486. Model: "test-system",
  487. Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
  488. })
  489. if w.Code != http.StatusOK {
  490. t.Fatalf("expected status 200, got %d", w.Code)
  491. }
  492. t.Run("prompt with model system", func(t *testing.T) {
  493. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  494. Model: "test-system",
  495. Prompt: "Hello!",
  496. Stream: &stream,
  497. })
  498. if w.Code != http.StatusOK {
  499. t.Errorf("expected status 200, got %d", w.Code)
  500. }
  501. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
  502. t.Errorf("mismatch (-got +want):\n%s", diff)
  503. }
  504. checkGenerateResponse(t, w.Body, "test-system", "Hi!")
  505. })
  506. mock.CompletionResponse.Content = "Abra kadabra!"
  507. t.Run("prompt with system", func(t *testing.T) {
  508. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  509. Model: "test-system",
  510. Prompt: "Hello!",
  511. System: "You can perform magic tricks.",
  512. Stream: &stream,
  513. })
  514. if w.Code != http.StatusOK {
  515. t.Errorf("expected status 200, got %d", w.Code)
  516. }
  517. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
  518. t.Errorf("mismatch (-got +want):\n%s", diff)
  519. }
  520. checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
  521. })
  522. t.Run("prompt with template", func(t *testing.T) {
  523. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  524. Model: "test-system",
  525. Prompt: "Help me write tests.",
  526. System: "You can perform magic tricks.",
  527. Template: `{{- if .System }}{{ .System }} {{ end }}
  528. {{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
  529. {{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
  530. Stream: &stream,
  531. })
  532. if w.Code != http.StatusOK {
  533. t.Errorf("expected status 200, got %d", w.Code)
  534. }
  535. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
  536. t.Errorf("mismatch (-got +want):\n%s", diff)
  537. }
  538. checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
  539. })
  540. w = createRequest(t, s.CreateHandler, api.CreateRequest{
  541. Model: "test-suffix",
  542. Modelfile: `FROM test
  543. TEMPLATE """{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
  544. {{- else }}{{ .Prompt }}
  545. {{- end }}"""`,
  546. })
  547. if w.Code != http.StatusOK {
  548. t.Fatalf("expected status 200, got %d", w.Code)
  549. }
  550. t.Run("prompt with suffix", func(t *testing.T) {
  551. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  552. Model: "test-suffix",
  553. Prompt: "def add(",
  554. Suffix: " return c",
  555. })
  556. if w.Code != http.StatusOK {
  557. t.Errorf("expected status 200, got %d", w.Code)
  558. }
  559. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
  560. t.Errorf("mismatch (-got +want):\n%s", diff)
  561. }
  562. })
  563. t.Run("prompt without suffix", func(t *testing.T) {
  564. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  565. Model: "test-suffix",
  566. Prompt: "def add(",
  567. })
  568. if w.Code != http.StatusOK {
  569. t.Errorf("expected status 200, got %d", w.Code)
  570. }
  571. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
  572. t.Errorf("mismatch (-got +want):\n%s", diff)
  573. }
  574. })
  575. t.Run("raw", func(t *testing.T) {
  576. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  577. Model: "test-system",
  578. Prompt: "Help me write tests.",
  579. Raw: true,
  580. Stream: &stream,
  581. })
  582. if w.Code != http.StatusOK {
  583. t.Errorf("expected status 200, got %d", w.Code)
  584. }
  585. if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
  586. t.Errorf("mismatch (-got +want):\n%s", diff)
  587. }
  588. })
  589. }