routes_generate_test.go

package server

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/go-cmp/cmp"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/llm"
)
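
// mockRunner stands in for a real llm.LlamaServer. It records the most
// recent CompletionRequest and replies with a canned CompletionResponse,
// or defers to CompletionFn when one is set.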
type mockRunner struct {
	llm.LlamaServer

	// CompletionRequest is only valid until the next call to Completion
	llm.CompletionRequest
	llm.CompletionResponse
	CompletionFn func(context.Context, llm.CompletionRequest, func(llm.CompletionResponse)) error
}

func (m *mockRunner) Completion(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
	m.CompletionRequest = r
	if m.CompletionFn != nil {
		return m.CompletionFn(ctx, r, fn)
	}

	fn(m.CompletionResponse)
	return nil
}
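
// Tokenize returns one synthetic token per whitespace-separated field,
// which is enough for the prompt-length accounting these tests rely on.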
func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
	for range strings.Fields(s) {
		tokens = append(tokens, len(tokens))
	}

	return
}
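
// newMockServer produces a scheduler newServerFn that ignores its
// arguments and always hands back the shared mock runner.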
func newMockServer(mock *mockRunner) func(discover.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
	return func(gpus discover.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
		return mock, nil
	}
}
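
// TestGenerateChat drives the /api/chat handler end to end using the mock
// runner. stream, createRequest, and createBinFile are test helpers shared
// with the other route tests in this package.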
func TestGenerateChat(t *testing.T) {
	gin.SetMode(gin.TestMode)

	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         "stop",
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mock),
			getGpuFn:      discover.GetGPUInfo,
			getCpuFn:      discover.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
				// add small delay to simulate loading
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mock,
				}
			},
		},
	}

	go s.sched.Run(context.TODO())
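
	// Register a minimal llama-architecture model whose template renders
	// each message as "role: content" and inlines any tool calls as JSON.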
	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model: "test",
		Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """
{{- if .Tools }}
{{ .Tools }}
{{ end }}
{{- range .Messages }}
{{- .Role }}: {{ .Content }}
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}
{{ end }}"""
`, createBinFile(t, llm.KV{
			"general.architecture":          "llama",
			"llama.block_count":             uint32(1),
			"llama.context_length":          uint32(8192),
			"llama.embedding_length":        uint32(4096),
			"llama.attention.head_count":    uint32(32),
			"llama.attention.head_count_kv": uint32(8),
			"tokenizer.ggml.tokens":         []string{""},
			"tokenizer.ggml.scores":         []float32{0},
			"tokenizer.ggml.token_type":     []int32{0},
		}, []llm.Tensor{
			{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		})),
		Stream: &stream,
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}
  118. t.Run("missing body", func(t *testing.T) {
  119. w := createRequest(t, s.ChatHandler, nil)
  120. if w.Code != http.StatusBadRequest {
  121. t.Errorf("expected status 400, got %d", w.Code)
  122. }
  123. if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
  124. t.Errorf("mismatch (-got +want):\n%s", diff)
  125. }
  126. })
  127. t.Run("missing model", func(t *testing.T) {
  128. w := createRequest(t, s.ChatHandler, api.ChatRequest{})
  129. if w.Code != http.StatusBadRequest {
  130. t.Errorf("expected status 400, got %d", w.Code)
  131. }
  132. if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
  133. t.Errorf("mismatch (-got +want):\n%s", diff)
  134. }
  135. })
  136. t.Run("missing capabilities chat", func(t *testing.T) {
  137. w := createRequest(t, s.CreateHandler, api.CreateRequest{
  138. Model: "bert",
  139. Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
  140. "general.architecture": "bert",
  141. "bert.pooling_type": uint32(0),
  142. }, []llm.Tensor{})),
  143. Stream: &stream,
  144. })
  145. if w.Code != http.StatusOK {
  146. t.Fatalf("expected status 200, got %d", w.Code)
  147. }
  148. w = createRequest(t, s.ChatHandler, api.ChatRequest{
  149. Model: "bert",
  150. })
  151. if w.Code != http.StatusBadRequest {
  152. t.Errorf("expected status 400, got %d", w.Code)
  153. }
  154. if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" {
  155. t.Errorf("mismatch (-got +want):\n%s", diff)
  156. }
  157. })
  158. t.Run("load model", func(t *testing.T) {
  159. w := createRequest(t, s.ChatHandler, api.ChatRequest{
  160. Model: "test",
  161. })
  162. if w.Code != http.StatusOK {
  163. t.Errorf("expected status 200, got %d", w.Code)
  164. }
  165. var actual api.ChatResponse
  166. if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
  167. t.Fatal(err)
  168. }
  169. if actual.Model != "test" {
  170. t.Errorf("expected model test, got %s", actual.Model)
  171. }
  172. if !actual.Done {
  173. t.Errorf("expected done true, got false")
  174. }
  175. if actual.DoneReason != "load" {
  176. t.Errorf("expected done reason load, got %s", actual.DoneReason)
  177. }
  178. })
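
	// checkChatResponse decodes a single ChatResponse from body and asserts
	// the assistant message, done reason, and non-zero timing counters.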
	checkChatResponse := func(t *testing.T, body io.Reader, model, content string) {
		t.Helper()

		var actual api.ChatResponse
		if err := json.NewDecoder(body).Decode(&actual); err != nil {
			t.Fatal(err)
		}

		if actual.Model != model {
			t.Errorf("expected model %s, got %s", model, actual.Model)
		}

		if !actual.Done {
			t.Errorf("expected done true, got false")
		}

		if actual.DoneReason != "stop" {
			t.Errorf("expected done reason stop, got %s", actual.DoneReason)
		}

		if diff := cmp.Diff(actual.Message, api.Message{
			Role:    "assistant",
			Content: content,
		}); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		if actual.PromptEvalCount == 0 {
			t.Errorf("expected prompt eval count > 0, got 0")
		}

		if actual.PromptEvalDuration == 0 {
			t.Errorf("expected prompt eval duration > 0, got 0")
		}

		if actual.EvalCount == 0 {
			t.Errorf("expected eval count > 0, got 0")
		}

		if actual.EvalDuration == 0 {
			t.Errorf("expected eval duration > 0, got 0")
		}

		if actual.LoadDuration == 0 {
			t.Errorf("expected load duration > 0, got 0")
		}

		if actual.TotalDuration == 0 {
			t.Errorf("expected total duration > 0, got 0")
		}
	}
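
	// Give the mock real content so subsequent requests exercise the full
	// completion path rather than the load-only path.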
	mock.CompletionResponse.Content = "Hi!"
	t.Run("messages", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test",
			Messages: []api.Message{
				{Role: "user", Content: "Hello!"},
			},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "user: Hello!\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkChatResponse(t, w.Body, "test", "Hi!")
	})

	w = createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:     "test-system",
		Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	t.Run("messages with model system", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test-system",
			Messages: []api.Message{
				{Role: "user", Content: "Hello!"},
			},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkChatResponse(t, w.Body, "test-system", "Hi!")
	})

	mock.CompletionResponse.Content = "Abra kadabra!"
	t.Run("messages with system", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test-system",
			Messages: []api.Message{
				{Role: "system", Content: "You can perform magic tricks."},
				{Role: "user", Content: "Hello!"},
			},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
	})

	t.Run("messages with interleaved system", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test-system",
			Messages: []api.Message{
				{Role: "user", Content: "Hello!"},
				{Role: "assistant", Content: "I can help you with that."},
				{Role: "system", Content: "You can perform magic tricks."},
				{Role: "user", Content: "Help me write tests."},
			},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
	})
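
	// The non-streaming variant returns the entire tool call in a single
	// completion chunk.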
  296. t.Run("messages with tools (non-streaming)", func(t *testing.T) {
  297. if w.Code != http.StatusOK {
  298. t.Fatalf("failed to create test-system model: %d", w.Code)
  299. }
  300. tools := []api.Tool{
  301. {
  302. Type: "function",
  303. Function: api.ToolFunction{
  304. Name: "get_weather",
  305. Description: "Get the current weather",
  306. Parameters: struct {
  307. Type string `json:"type"`
  308. Required []string `json:"required"`
  309. Properties map[string]struct {
  310. Type string `json:"type"`
  311. Description string `json:"description"`
  312. Enum []string `json:"enum,omitempty"`
  313. } `json:"properties"`
  314. }{
  315. Type: "object",
  316. Required: []string{"location"},
  317. Properties: map[string]struct {
  318. Type string `json:"type"`
  319. Description string `json:"description"`
  320. Enum []string `json:"enum,omitempty"`
  321. }{
  322. "location": {
  323. Type: "string",
  324. Description: "The city and state",
  325. },
  326. "unit": {
  327. Type: "string",
  328. Enum: []string{"celsius", "fahrenheit"},
  329. },
  330. },
  331. },
  332. },
  333. },
  334. }
  335. mock.CompletionResponse = llm.CompletionResponse{
  336. Content: `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
  337. Done: true,
  338. DoneReason: "done",
  339. PromptEvalCount: 1,
  340. PromptEvalDuration: 1,
  341. EvalCount: 1,
  342. EvalDuration: 1,
  343. }
  344. streamRequest := true
  345. w := createRequest(t, s.ChatHandler, api.ChatRequest{
  346. Model: "test-system",
  347. Messages: []api.Message{
  348. {Role: "user", Content: "What's the weather in Seattle?"},
  349. },
  350. Tools: tools,
  351. Stream: &streamRequest,
  352. })
  353. if w.Code != http.StatusOK {
  354. var errResp struct {
  355. Error string `json:"error"`
  356. }
  357. if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil {
  358. t.Logf("Failed to decode error response: %v", err)
  359. } else {
  360. t.Logf("Error response: %s", errResp.Error)
  361. }
  362. }
  363. if w.Code != http.StatusOK {
  364. t.Errorf("expected status 200, got %d", w.Code)
  365. }
  366. var resp api.ChatResponse
  367. if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
  368. t.Fatal(err)
  369. }
  370. if resp.Message.ToolCalls == nil {
  371. t.Error("expected tool calls, got nil")
  372. }
  373. expectedToolCall := api.ToolCall{
  374. Function: api.ToolCallFunction{
  375. Name: "get_weather",
  376. Arguments: api.ToolCallFunctionArguments{
  377. "location": "Seattle, WA",
  378. "unit": "celsius",
  379. },
  380. },
  381. }
  382. if diff := cmp.Diff(resp.Message.ToolCalls[0], expectedToolCall); diff != "" {
  383. t.Errorf("tool call mismatch (-got +want):\n%s", diff)
  384. }
  385. })
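
	// The streaming variant splits the same tool-call JSON across three
	// chunks to verify the handler reassembles the arguments correctly.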
  386. t.Run("messages with tools (streaming)", func(t *testing.T) {
  387. tools := []api.Tool{
  388. {
  389. Type: "function",
  390. Function: api.ToolFunction{
  391. Name: "get_weather",
  392. Description: "Get the current weather",
  393. Parameters: struct {
  394. Type string `json:"type"`
  395. Required []string `json:"required"`
  396. Properties map[string]struct {
  397. Type string `json:"type"`
  398. Description string `json:"description"`
  399. Enum []string `json:"enum,omitempty"`
  400. } `json:"properties"`
  401. }{
  402. Type: "object",
  403. Required: []string{"location"},
  404. Properties: map[string]struct {
  405. Type string `json:"type"`
  406. Description string `json:"description"`
  407. Enum []string `json:"enum,omitempty"`
  408. }{
  409. "location": {
  410. Type: "string",
  411. Description: "The city and state",
  412. },
  413. "unit": {
  414. Type: "string",
  415. Enum: []string{"celsius", "fahrenheit"},
  416. },
  417. },
  418. },
  419. },
  420. },
  421. }
  422. // Simulate streaming response with multiple chunks
  423. var wg sync.WaitGroup
  424. wg.Add(1)
  425. mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
  426. defer wg.Done()
  427. // Send chunks with small delays to simulate streaming
  428. responses := []llm.CompletionResponse{
  429. {
  430. Content: `{"name":"get_`,
  431. Done: false,
  432. PromptEvalCount: 1,
  433. PromptEvalDuration: 1,
  434. },
  435. {
  436. Content: `weather","arguments":{"location":"Seattle`,
  437. Done: false,
  438. PromptEvalCount: 2,
  439. PromptEvalDuration: 1,
  440. },
  441. {
  442. Content: `, WA","unit":"celsius"}}`,
  443. Done: true,
  444. DoneReason: "tool_call",
  445. PromptEvalCount: 3,
  446. PromptEvalDuration: 1,
  447. },
  448. }
  449. for _, resp := range responses {
  450. select {
  451. case <-ctx.Done():
  452. return ctx.Err()
  453. default:
  454. fn(resp)
  455. time.Sleep(10 * time.Millisecond) // Small delay between chunks
  456. }
  457. }
  458. return nil
  459. }
  460. w := createRequest(t, s.ChatHandler, api.ChatRequest{
  461. Model: "test-system",
  462. Messages: []api.Message{
  463. {Role: "user", Content: "What's the weather in Seattle?"},
  464. },
  465. Tools: tools,
  466. Stream: &stream,
  467. })
  468. wg.Wait()
  469. if w.Code != http.StatusOK {
  470. t.Errorf("expected status 200, got %d", w.Code)
  471. }
  472. // Read and validate the streamed responses
  473. decoder := json.NewDecoder(w.Body)
  474. var finalToolCall api.ToolCall
  475. for {
  476. var resp api.ChatResponse
  477. if err := decoder.Decode(&resp); err == io.EOF {
  478. break
  479. } else if err != nil {
  480. t.Fatal(err)
  481. }
  482. if resp.Done {
  483. if len(resp.Message.ToolCalls) != 1 {
  484. t.Errorf("expected 1 tool call in final response, got %d", len(resp.Message.ToolCalls))
  485. }
  486. finalToolCall = resp.Message.ToolCalls[0]
  487. }
  488. }
  489. expectedToolCall := api.ToolCall{
  490. Function: api.ToolCallFunction{
  491. Name: "get_weather",
  492. Arguments: api.ToolCallFunctionArguments{
  493. "location": "Seattle, WA",
  494. "unit": "celsius",
  495. },
  496. },
  497. }
  498. if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
  499. t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
  500. }
  501. })
  502. }
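
// TestGenerate exercises the /api/generate handler against the same mock
// runner and scheduler wiring used by TestGenerateChat.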
func TestGenerate(t *testing.T) {
	gin.SetMode(gin.TestMode)

	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         "stop",
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mock),
			getGpuFn:      discover.GetGPUInfo,
			getCpuFn:      discover.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
				// add small delay to simulate loading
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mock,
				}
			},
		},
	}

	go s.sched.Run(context.TODO())
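
	// Register a minimal model whose template renders System/User/Assistant
	// turns inline.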
	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model: "test",
		Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """
{{- if .System }}System: {{ .System }} {{ end }}
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
`, createBinFile(t, llm.KV{
			"general.architecture":          "llama",
			"llama.block_count":             uint32(1),
			"llama.context_length":          uint32(8192),
			"llama.embedding_length":        uint32(4096),
			"llama.attention.head_count":    uint32(32),
			"llama.attention.head_count_kv": uint32(8),
			"tokenizer.ggml.tokens":         []string{""},
			"tokenizer.ggml.scores":         []float32{0},
			"tokenizer.ggml.token_type":     []int32{0},
		}, []llm.Tensor{
			{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		})),
		Stream: &stream,
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}
  571. t.Run("missing body", func(t *testing.T) {
  572. w := createRequest(t, s.GenerateHandler, nil)
  573. if w.Code != http.StatusNotFound {
  574. t.Errorf("expected status 404, got %d", w.Code)
  575. }
  576. if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
  577. t.Errorf("mismatch (-got +want):\n%s", diff)
  578. }
  579. })
  580. t.Run("missing model", func(t *testing.T) {
  581. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
  582. if w.Code != http.StatusNotFound {
  583. t.Errorf("expected status 404, got %d", w.Code)
  584. }
  585. if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
  586. t.Errorf("mismatch (-got +want):\n%s", diff)
  587. }
  588. })
  589. t.Run("missing capabilities generate", func(t *testing.T) {
  590. w := createRequest(t, s.CreateHandler, api.CreateRequest{
  591. Model: "bert",
  592. Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
  593. "general.architecture": "bert",
  594. "bert.pooling_type": uint32(0),
  595. }, []llm.Tensor{})),
  596. Stream: &stream,
  597. })
  598. if w.Code != http.StatusOK {
  599. t.Fatalf("expected status 200, got %d", w.Code)
  600. }
  601. w = createRequest(t, s.GenerateHandler, api.GenerateRequest{
  602. Model: "bert",
  603. })
  604. if w.Code != http.StatusBadRequest {
  605. t.Errorf("expected status 400, got %d", w.Code)
  606. }
  607. if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" {
  608. t.Errorf("mismatch (-got +want):\n%s", diff)
  609. }
  610. })
  611. t.Run("missing capabilities suffix", func(t *testing.T) {
  612. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  613. Model: "test",
  614. Prompt: "def add(",
  615. Suffix: " return c",
  616. })
  617. if w.Code != http.StatusBadRequest {
  618. t.Errorf("expected status 400, got %d", w.Code)
  619. }
  620. if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support insert"}`); diff != "" {
  621. t.Errorf("mismatch (-got +want):\n%s", diff)
  622. }
  623. })
  624. t.Run("load model", func(t *testing.T) {
  625. w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
  626. Model: "test",
  627. })
  628. if w.Code != http.StatusOK {
  629. t.Errorf("expected status 200, got %d", w.Code)
  630. }
  631. var actual api.GenerateResponse
  632. if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
  633. t.Fatal(err)
  634. }
  635. if actual.Model != "test" {
  636. t.Errorf("expected model test, got %s", actual.Model)
  637. }
  638. if !actual.Done {
  639. t.Errorf("expected done true, got false")
  640. }
  641. if actual.DoneReason != "load" {
  642. t.Errorf("expected done reason load, got %s", actual.DoneReason)
  643. }
  644. })
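
	// checkGenerateResponse mirrors checkChatResponse for GenerateResponse
	// payloads, additionally asserting that a context is returned.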
	checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) {
		t.Helper()

		var actual api.GenerateResponse
		if err := json.NewDecoder(body).Decode(&actual); err != nil {
			t.Fatal(err)
		}

		if actual.Model != model {
			t.Errorf("expected model %s, got %s", model, actual.Model)
		}

		if !actual.Done {
			t.Errorf("expected done true, got false")
		}

		if actual.DoneReason != "stop" {
			t.Errorf("expected done reason stop, got %s", actual.DoneReason)
		}

		if actual.Response != content {
			t.Errorf("expected response %s, got %s", content, actual.Response)
		}

		if actual.Context == nil {
			t.Errorf("expected context not nil")
		}

		if actual.PromptEvalCount == 0 {
			t.Errorf("expected prompt eval count > 0, got 0")
		}

		if actual.PromptEvalDuration == 0 {
			t.Errorf("expected prompt eval duration > 0, got 0")
		}

		if actual.EvalCount == 0 {
			t.Errorf("expected eval count > 0, got 0")
		}

		if actual.EvalDuration == 0 {
			t.Errorf("expected eval duration > 0, got 0")
		}

		if actual.LoadDuration == 0 {
			t.Errorf("expected load duration > 0, got 0")
		}

		if actual.TotalDuration == 0 {
			t.Errorf("expected total duration > 0, got 0")
		}
	}

	mock.CompletionResponse.Content = "Hi!"
	t.Run("prompt", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test",
			Prompt: "Hello!",
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkGenerateResponse(t, w.Body, "test", "Hi!")
	})

	w = createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:     "test-system",
		Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	t.Run("prompt with model system", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-system",
			Prompt: "Hello!",
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkGenerateResponse(t, w.Body, "test-system", "Hi!")
	})

	mock.CompletionResponse.Content = "Abra kadabra!"
	t.Run("prompt with system", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-system",
			Prompt: "Hello!",
			System: "You can perform magic tricks.",
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
	})

	t.Run("prompt with template", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-system",
			Prompt: "Help me write tests.",
			System: "You can perform magic tricks.",
			Template: `{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
	})
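
	// A fill-in-the-middle template: when a Suffix is present the prompt is
	// wrapped in <PRE>/<SUF>/<MID> markers.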
	w = createRequest(t, s.CreateHandler, api.CreateRequest{
		Model: "test-suffix",
		Modelfile: `FROM test
TEMPLATE """{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
{{- else }}{{ .Prompt }}
{{- end }}"""`,
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	t.Run("prompt with suffix", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-suffix",
			Prompt: "def add(",
			Suffix: " return c",
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("prompt without suffix", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-suffix",
			Prompt: "def add(",
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("raw", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-system",
			Prompt: "Help me write tests.",
			Raw:    true,
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})
}