routes_generate_test.go 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962
  1. package server
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "io"
  7. "net/http"
  8. "strings"
  9. "sync"
  10. "testing"
  11. "time"
  12. "github.com/gin-gonic/gin"
  13. "github.com/google/go-cmp/cmp"
  14. "github.com/ollama/ollama/api"
  15. "github.com/ollama/ollama/discover"
  16. "github.com/ollama/ollama/fs/ggml"
  17. "github.com/ollama/ollama/llm"
  18. )
// mockRunner is a test double for an llm.LlamaServer. It records the most
// recent completion request, replays a canned completion response, and can be
// redirected through CompletionFn for streaming scenarios.
type mockRunner struct {
	// Embedded so mockRunner satisfies llm.LlamaServer without stubbing
	// every method; unimplemented calls panic if exercised.
	llm.LlamaServer

	// CompletionRequest is only valid until the next call to Completion
	llm.CompletionRequest

	// CompletionResponse is the canned reply delivered when CompletionFn
	// is nil.
	llm.CompletionResponse

	// CompletionFn, when non-nil, fully replaces the canned-response
	// behavior of Completion (used to simulate streamed chunks).
	CompletionFn func(context.Context, llm.CompletionRequest, func(llm.CompletionResponse)) error
}
  26. func (m *mockRunner) Completion(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
  27. m.CompletionRequest = r
  28. if m.CompletionFn != nil {
  29. return m.CompletionFn(ctx, r, fn)
  30. }
  31. fn(m.CompletionResponse)
  32. return nil
  33. }
  34. func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
  35. for range strings.Fields(s) {
  36. tokens = append(tokens, len(tokens))
  37. }
  38. return
  39. }
  40. func newMockServer(mock *mockRunner) func(discover.GpuInfoList, string, *ggml.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
  41. return func(_ discover.GpuInfoList, _ string, _ *ggml.GGML, _, _ []string, _ api.Options, _ int) (llm.LlamaServer, error) {
  42. return mock, nil
  43. }
  44. }
// TestGenerateChat exercises the /api/chat handler end to end against a
// mockRunner: model creation, request validation, template rendering of
// messages (system prompts included), and tool-call parsing in both
// non-streaming and streaming modes.
//
// NOTE(review): subtests mutate the shared mock between runs, so their
// order is significant.
func TestGenerateChat(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Canned completion with non-zero counters so the handler's
	// duration/count fields can be asserted as > 0 later.
	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         "stop",
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	// Scheduler wired with buffered channels and stubs so no real model is
	// ever loaded; loadFn immediately hands back a runner backed by mock.
	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mock),
			getGpuFn:      discover.GetGPUInfo,
			getCpuFn:      discover.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) {
				// add small delay to simulate loading
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mock,
				}
			},
		},
	}

	go s.sched.Run(context.TODO())

	// Minimal llama-architecture GGUF metadata/tensors so the created
	// model passes the chat capability check.
	_, digest := createBinFile(t, ggml.KV{
		"general.architecture":          "llama",
		"llama.block_count":             uint32(1),
		"llama.context_length":          uint32(8192),
		"llama.embedding_length":        uint32(4096),
		"llama.attention.head_count":    uint32(32),
		"llama.attention.head_count_kv": uint32(8),
		"tokenizer.ggml.tokens":         []string{""},
		"tokenizer.ggml.scores":         []float32{0},
		"tokenizer.ggml.token_type":     []int32{0},
	}, []ggml.Tensor{
		{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
	})

	// Create the "test" model with a chat template that renders roles,
	// contents, tools, and tool calls.
	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model: "test",
		Files: map[string]string{"file.gguf": digest},
		Template: `
{{- if .Tools }}
{{ .Tools }}
{{ end }}
{{- range .Messages }}
{{- .Role }}: {{ .Content }}
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}
{{ end }}`,
		Stream: &stream,
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	// A nil body must be rejected before any model lookup.
	t.Run("missing body", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, nil)
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	// An empty request (no model name) gets the same validation error.
	t.Run("missing model", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{})
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	// A bert-architecture model lacks the chat capability and must be
	// rejected with an explicit error.
	t.Run("missing capabilities chat", func(t *testing.T) {
		_, digest := createBinFile(t, ggml.KV{
			"general.architecture": "bert",
			"bert.pooling_type":    uint32(0),
		}, []ggml.Tensor{})
		w := createRequest(t, s.CreateHandler, api.CreateRequest{
			Model:  "bert",
			Files:  map[string]string{"bert.gguf": digest},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Fatalf("expected status 200, got %d", w.Code)
		}

		w = createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "bert",
		})

		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	// A request with no messages just loads the model; the handler replies
	// done with reason "load".
	t.Run("load model", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test",
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		var actual api.ChatResponse
		if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
			t.Fatal(err)
		}

		if actual.Model != "test" {
			t.Errorf("expected model test, got %s", actual.Model)
		}

		if !actual.Done {
			t.Errorf("expected done true, got false")
		}

		if actual.DoneReason != "load" {
			t.Errorf("expected done reason load, got %s", actual.DoneReason)
		}
	})

	// checkChatResponse decodes body and asserts the standard shape of a
	// successful chat response: model name, done/stop, assistant message
	// with the expected content, and all counters/durations non-zero.
	checkChatResponse := func(t *testing.T, body io.Reader, model, content string) {
		t.Helper()

		var actual api.ChatResponse
		if err := json.NewDecoder(body).Decode(&actual); err != nil {
			t.Fatal(err)
		}

		if actual.Model != model {
			t.Errorf("expected model test, got %s", actual.Model)
		}

		if !actual.Done {
			t.Errorf("expected done false, got true")
		}

		if actual.DoneReason != "stop" {
			t.Errorf("expected done reason stop, got %s", actual.DoneReason)
		}

		if diff := cmp.Diff(actual.Message, api.Message{
			Role:    "assistant",
			Content: content,
		}); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		if actual.PromptEvalCount == 0 {
			t.Errorf("expected prompt eval count > 0, got 0")
		}

		if actual.PromptEvalDuration == 0 {
			t.Errorf("expected prompt eval duration > 0, got 0")
		}

		if actual.EvalCount == 0 {
			t.Errorf("expected eval count > 0, got 0")
		}

		if actual.EvalDuration == 0 {
			t.Errorf("expected eval duration > 0, got 0")
		}

		if actual.LoadDuration == 0 {
			t.Errorf("expected load duration > 0, got 0")
		}

		if actual.TotalDuration == 0 {
			t.Errorf("expected total duration > 0, got 0")
		}
	}

	mock.CompletionResponse.Content = "Hi!"

	// Basic single-message chat: prompt should render via the template.
	t.Run("messages", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test",
			Messages: []api.Message{
				{Role: "user", Content: "Hello!"},
			},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "user: Hello!\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkChatResponse(t, w.Body, "test", "Hi!")
	})

	// Derive "test-system" from "test" with a baked-in system prompt.
	w = createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:  "test-system",
		From:   "test",
		System: "You are a helpful assistant.",
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	// The model-level system prompt is prepended when no system message is
	// supplied by the request.
	t.Run("messages with model system", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test-system",
			Messages: []api.Message{
				{Role: "user", Content: "Hello!"},
			},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkChatResponse(t, w.Body, "test-system", "Hi!")
	})

	mock.CompletionResponse.Content = "Abra kadabra!"

	// A request-supplied system message replaces the model's own.
	t.Run("messages with system", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test-system",
			Messages: []api.Message{
				{Role: "system", Content: "You can perform magic tricks."},
				{Role: "user", Content: "Hello!"},
			},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
	})

	// With a mid-conversation system message, the model's system prompt
	// still leads and the interleaved one is kept in place.
	t.Run("messages with interleaved system", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test-system",
			Messages: []api.Message{
				{Role: "user", Content: "Hello!"},
				{Role: "assistant", Content: "I can help you with that."},
				{Role: "system", Content: "You can perform magic tricks."},
				{Role: "user", Content: "Help me write tests."},
			},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
	})

	// Tool-call round trip without streaming: the model emits a single
	// JSON blob that must be parsed into a structured ToolCall.
	t.Run("messages with tools (non-streaming)", func(t *testing.T) {
		// NOTE: w here is the recorder from the "test-system" create above.
		if w.Code != http.StatusOK {
			t.Fatalf("failed to create test-system model: %d", w.Code)
		}

		tools := []api.Tool{
			{
				Type: "function",
				Function: api.ToolFunction{
					Name:        "get_weather",
					Description: "Get the current weather",
					Parameters: struct {
						Type       string   `json:"type"`
						Required   []string `json:"required"`
						Properties map[string]struct {
							Type        string   `json:"type"`
							Description string   `json:"description"`
							Enum        []string `json:"enum,omitempty"`
						} `json:"properties"`
					}{
						Type:     "object",
						Required: []string{"location"},
						Properties: map[string]struct {
							Type        string   `json:"type"`
							Description string   `json:"description"`
							Enum        []string `json:"enum,omitempty"`
						}{
							"location": {
								Type:        "string",
								Description: "The city and state",
							},
							"unit": {
								Type: "string",
								Enum: []string{"celsius", "fahrenheit"},
							},
						},
					},
				},
			},
		}

		mock.CompletionResponse = llm.CompletionResponse{
			Content:            `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
			Done:               true,
			DoneReason:         "done",
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		}

		streamRequest := true

		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test-system",
			Messages: []api.Message{
				{Role: "user", Content: "What's the weather in Seattle?"},
			},
			Tools:  tools,
			Stream: &streamRequest,
		})

		if w.Code != http.StatusOK {
			// Surface the handler's error payload for debugging.
			var errResp struct {
				Error string `json:"error"`
			}
			if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil {
				t.Logf("Failed to decode error response: %v", err)
			} else {
				t.Logf("Error response: %s", errResp.Error)
			}
		}

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		var resp api.ChatResponse
		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
			t.Fatal(err)
		}

		if resp.Message.ToolCalls == nil {
			t.Error("expected tool calls, got nil")
		}

		expectedToolCall := api.ToolCall{
			Function: api.ToolCallFunction{
				Name: "get_weather",
				Arguments: api.ToolCallFunctionArguments{
					"location": "Seattle, WA",
					"unit":     "celsius",
				},
			},
		}

		if diff := cmp.Diff(resp.Message.ToolCalls[0], expectedToolCall); diff != "" {
			t.Errorf("tool call mismatch (-got +want):\n%s", diff)
		}
	})

	// Tool-call round trip with streaming: the JSON arrives split across
	// three chunks and must be reassembled into one ToolCall on the final
	// (Done) response.
	t.Run("messages with tools (streaming)", func(t *testing.T) {
		tools := []api.Tool{
			{
				Type: "function",
				Function: api.ToolFunction{
					Name:        "get_weather",
					Description: "Get the current weather",
					Parameters: struct {
						Type       string   `json:"type"`
						Required   []string `json:"required"`
						Properties map[string]struct {
							Type        string   `json:"type"`
							Description string   `json:"description"`
							Enum        []string `json:"enum,omitempty"`
						} `json:"properties"`
					}{
						Type:     "object",
						Required: []string{"location"},
						Properties: map[string]struct {
							Type        string   `json:"type"`
							Description string   `json:"description"`
							Enum        []string `json:"enum,omitempty"`
						}{
							"location": {
								Type:        "string",
								Description: "The city and state",
							},
							"unit": {
								Type: "string",
								Enum: []string{"celsius", "fahrenheit"},
							},
						},
					},
				},
			},
		}

		// Simulate streaming response with multiple chunks
		var wg sync.WaitGroup
		wg.Add(1)

		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
			defer wg.Done()

			// Send chunks with small delays to simulate streaming
			responses := []llm.CompletionResponse{
				{
					Content:            `{"name":"get_`,
					Done:               false,
					PromptEvalCount:    1,
					PromptEvalDuration: 1,
				},
				{
					Content:            `weather","arguments":{"location":"Seattle`,
					Done:               false,
					PromptEvalCount:    2,
					PromptEvalDuration: 1,
				},
				{
					Content:            `, WA","unit":"celsius"}}`,
					Done:               true,
					DoneReason:         "tool_call",
					PromptEvalCount:    3,
					PromptEvalDuration: 1,
				},
			}

			for _, resp := range responses {
				select {
				case <-ctx.Done():
					return ctx.Err()
				default:
					fn(resp)
					time.Sleep(10 * time.Millisecond) // Small delay between chunks
				}
			}

			return nil
		}

		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test-system",
			Messages: []api.Message{
				{Role: "user", Content: "What's the weather in Seattle?"},
			},
			Tools:  tools,
			Stream: &stream,
		})

		wg.Wait()

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		// Read and validate the streamed responses
		decoder := json.NewDecoder(w.Body)
		var finalToolCall api.ToolCall

		for {
			var resp api.ChatResponse
			if err := decoder.Decode(&resp); err == io.EOF {
				break
			} else if err != nil {
				t.Fatal(err)
			}

			if resp.Done {
				if len(resp.Message.ToolCalls) != 1 {
					t.Errorf("expected 1 tool call in final response, got %d", len(resp.Message.ToolCalls))
				}
				finalToolCall = resp.Message.ToolCalls[0]
			}
		}

		expectedToolCall := api.ToolCall{
			Function: api.ToolCallFunction{
				Name: "get_weather",
				Arguments: api.ToolCallFunctionArguments{
					"location": "Seattle, WA",
					"unit":     "celsius",
				},
			},
		}

		if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
			t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
		}
	})
}
// TestGenerate exercises the /api/generate handler end to end against a
// mockRunner: model creation, request validation, prompt/template/system
// rendering, suffix (infill) templates, and raw mode.
//
// NOTE(review): subtests mutate the shared mock between runs, so their
// order is significant.
func TestGenerate(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Canned completion with non-zero counters so the handler's
	// duration/count fields can be asserted as > 0 later.
	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         "stop",
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	// Scheduler wired with buffered channels and stubs so no real model is
	// ever loaded; loadFn immediately hands back a runner backed by mock.
	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mock),
			getGpuFn:      discover.GetGPUInfo,
			getCpuFn:      discover.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) {
				// add small delay to simulate loading
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mock,
				}
			},
		},
	}

	go s.sched.Run(context.TODO())

	// Minimal llama-architecture GGUF metadata/tensors so the created
	// model passes the generate capability check.
	_, digest := createBinFile(t, ggml.KV{
		"general.architecture":          "llama",
		"llama.block_count":             uint32(1),
		"llama.context_length":          uint32(8192),
		"llama.embedding_length":        uint32(4096),
		"llama.attention.head_count":    uint32(32),
		"llama.attention.head_count_kv": uint32(8),
		"tokenizer.ggml.tokens":         []string{""},
		"tokenizer.ggml.scores":         []float32{0},
		"tokenizer.ggml.token_type":     []int32{0},
	}, []ggml.Tensor{
		{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
	})

	// Create the "test" model with a simple System/User/Assistant template.
	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model: "test",
		Files: map[string]string{"file.gguf": digest},
		Template: `
{{- if .System }}System: {{ .System }} {{ end }}
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
{{- if .Response }}Assistant: {{ .Response }} {{ end }}
`,
		Stream: &stream,
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	// Unlike /api/chat, a nil body here falls through to a model lookup
	// for the empty name, yielding a 404.
	t.Run("missing body", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, nil)
		if w.Code != http.StatusNotFound {
			t.Errorf("expected status 404, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	// An empty request gets the same 404 for the unnamed model.
	t.Run("missing model", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
		if w.Code != http.StatusNotFound {
			t.Errorf("expected status 404, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	// A bert-architecture model lacks the generate capability and must be
	// rejected with an explicit error.
	t.Run("missing capabilities generate", func(t *testing.T) {
		_, digest := createBinFile(t, ggml.KV{
			"general.architecture": "bert",
			"bert.pooling_type":    uint32(0),
		}, []ggml.Tensor{})

		w := createRequest(t, s.CreateHandler, api.CreateRequest{
			Model:  "bert",
			Files:  map[string]string{"file.gguf": digest},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Fatalf("expected status 200, got %d", w.Code)
		}

		w = createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model: "bert",
		})

		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	// Supplying a Suffix requires a template with a .Suffix branch, which
	// the plain "test" model does not have.
	t.Run("missing capabilities suffix", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test",
			Prompt: "def add(",
			Suffix: " return c",
		})

		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support insert"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	// A request with no prompt just loads the model; the handler replies
	// done with reason "load".
	t.Run("load model", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model: "test",
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		var actual api.GenerateResponse
		if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
			t.Fatal(err)
		}

		if actual.Model != "test" {
			t.Errorf("expected model test, got %s", actual.Model)
		}

		if !actual.Done {
			t.Errorf("expected done true, got false")
		}

		if actual.DoneReason != "load" {
			t.Errorf("expected done reason load, got %s", actual.DoneReason)
		}
	})

	// checkGenerateResponse decodes body and asserts the standard shape of
	// a successful generate response: model name, done/stop, expected
	// content, non-nil context, and all counters/durations non-zero.
	checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) {
		t.Helper()

		var actual api.GenerateResponse
		if err := json.NewDecoder(body).Decode(&actual); err != nil {
			t.Fatal(err)
		}

		if actual.Model != model {
			t.Errorf("expected model test, got %s", actual.Model)
		}

		if !actual.Done {
			t.Errorf("expected done false, got true")
		}

		if actual.DoneReason != "stop" {
			t.Errorf("expected done reason stop, got %s", actual.DoneReason)
		}

		if actual.Response != content {
			t.Errorf("expected response %s, got %s", content, actual.Response)
		}

		if actual.Context == nil {
			t.Errorf("expected context not nil")
		}

		if actual.PromptEvalCount == 0 {
			t.Errorf("expected prompt eval count > 0, got 0")
		}

		if actual.PromptEvalDuration == 0 {
			t.Errorf("expected prompt eval duration > 0, got 0")
		}

		if actual.EvalCount == 0 {
			t.Errorf("expected eval count > 0, got 0")
		}

		if actual.EvalDuration == 0 {
			t.Errorf("expected eval duration > 0, got 0")
		}

		if actual.LoadDuration == 0 {
			t.Errorf("expected load duration > 0, got 0")
		}

		if actual.TotalDuration == 0 {
			t.Errorf("expected total duration > 0, got 0")
		}
	}

	mock.CompletionResponse.Content = "Hi!"

	// Basic prompt: rendered through the model's default template.
	t.Run("prompt", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test",
			Prompt: "Hello!",
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkGenerateResponse(t, w.Body, "test", "Hi!")
	})

	// Derive "test-system" from "test" with a baked-in system prompt.
	w = createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:  "test-system",
		From:   "test",
		System: "You are a helpful assistant.",
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	// The model-level system prompt is rendered ahead of the user prompt.
	t.Run("prompt with model system", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-system",
			Prompt: "Hello!",
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkGenerateResponse(t, w.Body, "test-system", "Hi!")
	})

	mock.CompletionResponse.Content = "Abra kadabra!"

	// A request-supplied System overrides the model's system prompt.
	t.Run("prompt with system", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-system",
			Prompt: "Hello!",
			System: "You can perform magic tricks.",
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
	})

	// A request-supplied Template overrides the model's template entirely.
	t.Run("prompt with template", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-system",
			Prompt: "Help me write tests.",
			System: "You can perform magic tricks.",
			Template: `{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
	})

	// Create "test-suffix" with an infill-style template that branches on
	// whether a Suffix is present.
	w = createRequest(t, s.CreateHandler, api.CreateRequest{
		Model: "test-suffix",
		Template: `{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
{{- else }}{{ .Prompt }}
{{- end }}`,
		From: "test",
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	// With a Suffix, the infill branch of the template is used.
	t.Run("prompt with suffix", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-suffix",
			Prompt: "def add(",
			Suffix: " return c",
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	// Without a Suffix, the else branch passes the prompt through as-is.
	t.Run("prompt without suffix", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-suffix",
			Prompt: "def add(",
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	// Raw mode bypasses templating entirely; the prompt is sent verbatim.
	t.Run("raw", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-system",
			Prompt: "Help me write tests.",
			Raw:    true,
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})
}