routes_generate_test.go

package server

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/go-cmp/cmp"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/llm"
)

type mockRunner struct {
	llm.LlamaServer

	// CompletionRequest is only valid until the next call to Completion
	llm.CompletionRequest
	llm.CompletionResponse
	CompletionFn func(context.Context, llm.CompletionRequest, func(llm.CompletionResponse)) error
}

func (m *mockRunner) Completion(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
	m.CompletionRequest = r
	if m.CompletionFn != nil {
		return m.CompletionFn(ctx, r, fn)
	}

	fn(m.CompletionResponse)
	return nil
}

func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
	for range strings.Fields(s) {
		tokens = append(tokens, len(tokens))
	}

	return
}

func (mockRunner) Detokenize(_ context.Context, tokens []int) (string, error) {
	var strs []string
	for _, t := range tokens {
		strs = append(strs, fmt.Sprint(t))
	}

	return strings.Join(strs, " "), nil
}
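
// Note: the mock tokenizer maps each whitespace-delimited word to its index,
// so a prompt's token count equals its word count, and detokenizing returns
// the indices joined by spaces. Illustration only (not part of the original
// file):
//
//	var m mockRunner
//	tokens, _ := m.Tokenize(context.TODO(), "a b c") // [0 1 2]
//	text, _ := m.Detokenize(context.TODO(), tokens)  // "0 1 2"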

func newMockServer(mock *mockRunner) func(discover.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
	return func(gpus discover.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
		return mock, nil
	}
}
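
// The tests below also rely on the package-level `stream` flag and the
// `createRequest` helper, both defined in this package's other test files
// rather than here. A minimal sketch of plausible definitions, for reference
// only (the real implementations may differ, and `net/http/httptest` would
// need to be imported):
//
//	var stream bool = false
//
//	func createRequest(t *testing.T, fn func(*gin.Context), body any) *httptest.ResponseRecorder {
//		t.Helper()
//
//		var b bytes.Buffer
//		if err := json.NewEncoder(&b).Encode(body); err != nil {
//			t.Fatal(err)
//		}
//
//		w := httptest.NewRecorder()
//		c, _ := gin.CreateTestContext(w)
//		c.Request = &http.Request{Body: io.NopCloser(&b)}
//		fn(c)
//		return w
//	}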

func TestGenerateChat(t *testing.T) {
	gin.SetMode(gin.TestMode)

	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         "stop",
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mock),
			getGpuFn:      discover.GetGPUInfo,
			getCpuFn:      discover.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
				// add small delay to simulate loading
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mock,
				}
			},
		},
	}
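
	// The scheduler above is wired with in-memory channels and a stubbed
	// loadFn, so no real model is ever loaded; completions are served by
	// the mockRunner.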

	go s.sched.Run(context.TODO())

	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model: "test",
		Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """
{{- if .Tools }}
{{ .Tools }}
{{ end }}
{{- range .Messages }}
{{- .Role }}: {{ .Content }}
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}
{{ end }}"""
`, createBinFile(t, llm.KV{
			"general.architecture":          "llama",
			"llama.block_count":             uint32(1),
			"llama.context_length":          uint32(8192),
			"llama.embedding_length":        uint32(4096),
			"llama.attention.head_count":    uint32(32),
			"llama.attention.head_count_kv": uint32(8),
			"tokenizer.ggml.tokens":         []string{""},
			"tokenizer.ggml.scores":         []float32{0},
			"tokenizer.ggml.token_type":     []int32{0},
		}, []llm.Tensor{
			{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		})),
		Stream: &stream,
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	t.Run("missing body", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, nil)
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("missing model", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{})
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("missing capabilities chat", func(t *testing.T) {
		w := createRequest(t, s.CreateHandler, api.CreateRequest{
			Model: "bert",
			Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
				"general.architecture": "bert",
				"bert.pooling_type":    uint32(0),
			}, []llm.Tensor{})),
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Fatalf("expected status 200, got %d", w.Code)
		}

		w = createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "bert",
		})

		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("load model", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test",
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		var actual api.ChatResponse
		if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
			t.Fatal(err)
		}

		if actual.Model != "test" {
			t.Errorf("expected model test, got %s", actual.Model)
		}

		if !actual.Done {
			t.Errorf("expected done true, got false")
		}

		if actual.DoneReason != "load" {
			t.Errorf("expected done reason load, got %s", actual.DoneReason)
		}
	})

	checkChatResponse := func(t *testing.T, body io.Reader, model, content string) {
		t.Helper()

		var actual api.ChatResponse
		if err := json.NewDecoder(body).Decode(&actual); err != nil {
			t.Fatal(err)
		}

		if actual.Model != model {
			t.Errorf("expected model %s, got %s", model, actual.Model)
		}

		if !actual.Done {
			t.Errorf("expected done true, got false")
		}

		if actual.DoneReason != "stop" {
			t.Errorf("expected done reason stop, got %s", actual.DoneReason)
		}

		if diff := cmp.Diff(actual.Message, api.Message{
			Role:    "assistant",
			Content: content,
		}); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		if actual.PromptEvalCount == 0 {
			t.Errorf("expected prompt eval count > 0, got 0")
		}

		if actual.PromptEvalDuration == 0 {
			t.Errorf("expected prompt eval duration > 0, got 0")
		}

		if actual.EvalCount == 0 {
			t.Errorf("expected eval count > 0, got 0")
		}

		if actual.EvalDuration == 0 {
			t.Errorf("expected eval duration > 0, got 0")
		}

		if actual.LoadDuration == 0 {
			t.Errorf("expected load duration > 0, got 0")
		}

		if actual.TotalDuration == 0 {
			t.Errorf("expected total duration > 0, got 0")
		}
	}

	mock.CompletionResponse.Content = "Hi!"
	t.Run("messages", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test",
			Messages: []api.Message{
				{Role: "user", Content: "Hello!"},
			},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "user: Hello!\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkChatResponse(t, w.Body, "test", "Hi!")
	})

	w = createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:     "test-system",
		Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	t.Run("messages with model system", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test-system",
			Messages: []api.Message{
				{Role: "user", Content: "Hello!"},
			},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkChatResponse(t, w.Body, "test-system", "Hi!")
	})

	mock.CompletionResponse.Content = "Abra kadabra!"
	t.Run("messages with system", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test-system",
			Messages: []api.Message{
				{Role: "system", Content: "You can perform magic tricks."},
				{Role: "user", Content: "Hello!"},
			},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
	})

	t.Run("messages with interleaved system", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test-system",
			Messages: []api.Message{
				{Role: "user", Content: "Hello!"},
				{Role: "assistant", Content: "I can help you with that."},
				{Role: "system", Content: "You can perform magic tricks."},
				{Role: "user", Content: "Help me write tests."},
			},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
	})

	t.Run("messages with tools (non-streaming)", func(t *testing.T) {
		if w.Code != http.StatusOK {
			t.Fatalf("failed to create test-system model: %d", w.Code)
		}

		tools := []api.Tool{
			{
				Type: "function",
				Function: api.ToolFunction{
					Name:        "get_weather",
					Description: "Get the current weather",
					Parameters: struct {
						Type       string   `json:"type"`
						Required   []string `json:"required"`
						Properties map[string]struct {
							Type        string   `json:"type"`
							Description string   `json:"description"`
							Enum        []string `json:"enum,omitempty"`
						} `json:"properties"`
					}{
						Type:     "object",
						Required: []string{"location"},
						Properties: map[string]struct {
							Type        string   `json:"type"`
							Description string   `json:"description"`
							Enum        []string `json:"enum,omitempty"`
						}{
							"location": {
								Type:        "string",
								Description: "The city and state",
							},
							"unit": {
								Type: "string",
								Enum: []string{"celsius", "fahrenheit"},
							},
						},
					},
				},
			},
		}

		mock.CompletionResponse = llm.CompletionResponse{
			Content:            `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
			Done:               true,
			DoneReason:         "done",
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		}

		streamRequest := true

		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test-system",
			Messages: []api.Message{
				{Role: "user", Content: "What's the weather in Seattle?"},
			},
			Tools:  tools,
			Stream: &streamRequest,
		})

		if w.Code != http.StatusOK {
			var errResp struct {
				Error string `json:"error"`
			}
			if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil {
				t.Logf("Failed to decode error response: %v", err)
			} else {
				t.Logf("Error response: %s", errResp.Error)
			}

			t.Errorf("expected status 200, got %d", w.Code)
		}

		var resp api.ChatResponse
		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
			t.Fatal(err)
		}

		if resp.Message.ToolCalls == nil {
			t.Error("expected tool calls, got nil")
		}

		expectedToolCall := api.ToolCall{
			Function: api.ToolCallFunction{
				Name: "get_weather",
				Arguments: api.ToolCallFunctionArguments{
					"location": "Seattle, WA",
					"unit":     "celsius",
				},
			},
		}

		if diff := cmp.Diff(resp.Message.ToolCalls[0], expectedToolCall); diff != "" {
			t.Errorf("tool call mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("messages with tools (streaming)", func(t *testing.T) {
		tools := []api.Tool{
			{
				Type: "function",
				Function: api.ToolFunction{
					Name:        "get_weather",
					Description: "Get the current weather",
					Parameters: struct {
						Type       string   `json:"type"`
						Required   []string `json:"required"`
						Properties map[string]struct {
							Type        string   `json:"type"`
							Description string   `json:"description"`
							Enum        []string `json:"enum,omitempty"`
						} `json:"properties"`
					}{
						Type:     "object",
						Required: []string{"location"},
						Properties: map[string]struct {
							Type        string   `json:"type"`
							Description string   `json:"description"`
							Enum        []string `json:"enum,omitempty"`
						}{
							"location": {
								Type:        "string",
								Description: "The city and state",
							},
							"unit": {
								Type: "string",
								Enum: []string{"celsius", "fahrenheit"},
							},
						},
					},
				},
			},
		}

		// Simulate streaming response with multiple chunks
		var wg sync.WaitGroup
		wg.Add(1)

		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
			defer wg.Done()

			// Send chunks with small delays to simulate streaming
			responses := []llm.CompletionResponse{
				{
					Content:            `{"name":"get_`,
					Done:               false,
					PromptEvalCount:    1,
					PromptEvalDuration: 1,
				},
				{
					Content:            `weather","arguments":{"location":"Seattle`,
					Done:               false,
					PromptEvalCount:    2,
					PromptEvalDuration: 1,
				},
				{
					Content:            `, WA","unit":"celsius"}}`,
					Done:               true,
					DoneReason:         "tool_call",
					PromptEvalCount:    3,
					PromptEvalDuration: 1,
				},
			}

			for _, resp := range responses {
				select {
				case <-ctx.Done():
					return ctx.Err()
				default:
					fn(resp)
					time.Sleep(10 * time.Millisecond) // Small delay between chunks
				}
			}

			return nil
		}
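
		// With streaming enabled, the handler should accumulate the partial
		// chunks above and surface the parsed tool call only on the final
		// (Done) response, which the decode loop below verifies.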

		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test-system",
			Messages: []api.Message{
				{Role: "user", Content: "What's the weather in Seattle?"},
			},
			Tools:  tools,
			Stream: &stream,
		})

		wg.Wait()

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		// Read and validate the streamed responses
		decoder := json.NewDecoder(w.Body)
		var finalToolCall api.ToolCall

		for {
			var resp api.ChatResponse
			if err := decoder.Decode(&resp); err == io.EOF {
				break
			} else if err != nil {
				t.Fatal(err)
			}

			if resp.Done {
				if len(resp.Message.ToolCalls) != 1 {
					t.Errorf("expected 1 tool call in final response, got %d", len(resp.Message.ToolCalls))
				}
				finalToolCall = resp.Message.ToolCalls[0]
			}
		}

		expectedToolCall := api.ToolCall{
			Function: api.ToolCallFunction{
				Name: "get_weather",
				Arguments: api.ToolCallFunctionArguments{
					"location": "Seattle, WA",
					"unit":     "celsius",
				},
			},
		}

		if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
			t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
		}
	})
}
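
// TestGenerate mirrors the setup in TestGenerateChat but exercises the
// /api/generate handler, whose template renders .System, .Prompt, and
// .Response rather than ranging over .Messages.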
func TestGenerate(t *testing.T) {
	gin.SetMode(gin.TestMode)

	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         "stop",
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mock),
			getGpuFn:      discover.GetGPUInfo,
			getCpuFn:      discover.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
				// add small delay to simulate loading
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mock,
				}
			},
		},
	}

	go s.sched.Run(context.TODO())

	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model: "test",
		Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """
{{- if .System }}System: {{ .System }} {{ end }}
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
`, createBinFile(t, llm.KV{
			"general.architecture":          "llama",
			"llama.block_count":             uint32(1),
			"llama.context_length":          uint32(8192),
			"llama.embedding_length":        uint32(4096),
			"llama.attention.head_count":    uint32(32),
			"llama.attention.head_count_kv": uint32(8),
			"tokenizer.ggml.tokens":         []string{""},
			"tokenizer.ggml.scores":         []float32{0},
			"tokenizer.ggml.token_type":     []int32{0},
		}, []llm.Tensor{
			{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		})),
		Stream: &stream,
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	t.Run("missing body", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, nil)
		if w.Code != http.StatusNotFound {
			t.Errorf("expected status 404, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("missing model", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
		if w.Code != http.StatusNotFound {
			t.Errorf("expected status 404, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("missing capabilities generate", func(t *testing.T) {
		w := createRequest(t, s.CreateHandler, api.CreateRequest{
			Model: "bert",
			Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
				"general.architecture": "bert",
				"bert.pooling_type":    uint32(0),
			}, []llm.Tensor{})),
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Fatalf("expected status 200, got %d", w.Code)
		}

		w = createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model: "bert",
		})

		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("missing capabilities suffix", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test",
			Prompt: "def add(",
			Suffix: " return c",
		})

		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support insert"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("load model", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model: "test",
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		var actual api.GenerateResponse
		if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
			t.Fatal(err)
		}

		if actual.Model != "test" {
			t.Errorf("expected model test, got %s", actual.Model)
		}

		if !actual.Done {
			t.Errorf("expected done true, got false")
		}

		if actual.DoneReason != "load" {
			t.Errorf("expected done reason load, got %s", actual.DoneReason)
		}
	})

	checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) {
		t.Helper()

		var actual api.GenerateResponse
		if err := json.NewDecoder(body).Decode(&actual); err != nil {
			t.Fatal(err)
		}

		if actual.Model != model {
			t.Errorf("expected model %s, got %s", model, actual.Model)
		}

		if !actual.Done {
			t.Errorf("expected done true, got false")
		}

		if actual.DoneReason != "stop" {
			t.Errorf("expected done reason stop, got %s", actual.DoneReason)
		}

		if actual.Response != content {
			t.Errorf("expected response %s, got %s", content, actual.Response)
		}

		if actual.Context == nil {
			t.Errorf("expected context not nil")
		}

		if actual.PromptEvalCount == 0 {
			t.Errorf("expected prompt eval count > 0, got 0")
		}

		if actual.PromptEvalDuration == 0 {
			t.Errorf("expected prompt eval duration > 0, got 0")
		}

		if actual.EvalCount == 0 {
			t.Errorf("expected eval count > 0, got 0")
		}

		if actual.EvalDuration == 0 {
			t.Errorf("expected eval duration > 0, got 0")
		}

		if actual.LoadDuration == 0 {
			t.Errorf("expected load duration > 0, got 0")
		}

		if actual.TotalDuration == 0 {
			t.Errorf("expected total duration > 0, got 0")
		}
	}

	mock.CompletionResponse.Content = "Hi!"
	t.Run("prompt", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test",
			Prompt: "Hello!",
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkGenerateResponse(t, w.Body, "test", "Hi!")
	})

	w = createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:     "test-system",
		Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	t.Run("prompt with model system", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-system",
			Prompt: "Hello!",
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkGenerateResponse(t, w.Body, "test-system", "Hi!")
	})

	mock.CompletionResponse.Content = "Abra kadabra!"
	t.Run("prompt with system", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-system",
			Prompt: "Hello!",
			System: "You can perform magic tricks.",
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
	})

	t.Run("prompt with template", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-system",
			Prompt: "Help me write tests.",
			System: "You can perform magic tricks.",
			Template: `{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
	})

	w = createRequest(t, s.CreateHandler, api.CreateRequest{
		Model: "test-suffix",
		Modelfile: `FROM test
TEMPLATE """{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
{{- else }}{{ .Prompt }}
{{- end }}"""`,
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	t.Run("prompt with suffix", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-suffix",
			Prompt: "def add(",
			Suffix: " return c",
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("prompt without suffix", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-suffix",
			Prompt: "def add(",
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("raw", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-system",
			Prompt: "Help me write tests.",
			Raw:    true,
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})
}