routes_generate_test.go 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961
  1. package server
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "io"
  7. "net/http"
  8. "strings"
  9. "sync"
  10. "testing"
  11. "time"
  12. "github.com/gin-gonic/gin"
  13. "github.com/google/go-cmp/cmp"
  14. "github.com/ollama/ollama/api"
  15. "github.com/ollama/ollama/discover"
  16. "github.com/ollama/ollama/llm"
  17. )
// mockRunner is a test double for llm.LlamaServer. It records the most
// recent completion request so tests can inspect the rendered prompt, and
// replays a canned response (or delegates to CompletionFn when set).
type mockRunner struct {
	llm.LlamaServer

	// CompletionRequest is only valid until the next call to Completion
	llm.CompletionRequest

	// CompletionResponse is the canned reply emitted when CompletionFn is nil.
	llm.CompletionResponse

	// CompletionFn, when non-nil, fully replaces the default Completion
	// behavior (used to simulate streaming chunked responses).
	CompletionFn func(context.Context, llm.CompletionRequest, func(llm.CompletionResponse)) error
}
  25. func (m *mockRunner) Completion(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
  26. m.CompletionRequest = r
  27. if m.CompletionFn != nil {
  28. return m.CompletionFn(ctx, r, fn)
  29. }
  30. fn(m.CompletionResponse)
  31. return nil
  32. }
  33. func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
  34. for range strings.Fields(s) {
  35. tokens = append(tokens, len(tokens))
  36. }
  37. return
  38. }
  39. func newMockServer(mock *mockRunner) func(discover.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
  40. return func(gpus discover.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
  41. return mock, nil
  42. }
  43. }
// TestGenerateChat exercises the /api/chat handler end to end against a
// mocked runner: request validation, model capability checks, prompt
// templating (model/system/interleaved system messages), and tool calling
// in both single-response and chunk-streamed modes.
func TestGenerateChat(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Canned completion returned by the mock unless a CompletionFn is set.
	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         "stop",
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	// Server with a hand-wired scheduler whose load path resolves
	// immediately to the mock runner instead of launching a subprocess.
	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mock),
			getGpuFn:      discover.GetGPUInfo,
			getCpuFn:      discover.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
				// add small delay to simulate loading
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mock,
				}
			},
		},
	}

	go s.sched.Run(context.TODO())

	// Build a minimal llama-architecture GGUF blob so model creation and
	// the chat capability check both succeed.
	_, digest := createBinFile(t, llm.KV{
		"general.architecture":          "llama",
		"llama.block_count":             uint32(1),
		"llama.context_length":          uint32(8192),
		"llama.embedding_length":        uint32(4096),
		"llama.attention.head_count":    uint32(32),
		"llama.attention.head_count_kv": uint32(8),
		"tokenizer.ggml.tokens":         []string{""},
		"tokenizer.ggml.scores":         []float32{0},
		"tokenizer.ggml.token_type":     []int32{0},
	}, []llm.Tensor{
		{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
	})

	// Create the "test" model with a chat template that renders role-prefixed
	// messages plus any tool definitions and tool calls.
	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model: "test",
		Files: map[string]string{"file.gguf": digest},
		Template: `
{{- if .Tools }}
{{ .Tools }}
{{ end }}
{{- range .Messages }}
{{- .Role }}: {{ .Content }}
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}
{{ end }}`,
		Stream: &stream,
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	// A nil body must be rejected before any model lookup happens.
	t.Run("missing body", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, nil)
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("missing model", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{})
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	// A bert (embedding-only) model must be refused by the chat endpoint.
	t.Run("missing capabilities chat", func(t *testing.T) {
		_, digest := createBinFile(t, llm.KV{
			"general.architecture": "bert",
			"bert.pooling_type":    uint32(0),
		}, []llm.Tensor{})
		w := createRequest(t, s.CreateHandler, api.CreateRequest{
			Model:  "bert",
			Files:  map[string]string{"bert.gguf": digest},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Fatalf("expected status 200, got %d", w.Code)
		}

		w = createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "bert",
		})

		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	// An empty Messages slice acts as a load/keep-alive request: done=true
	// with reason "load" and no completion performed.
	t.Run("load model", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test",
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		var actual api.ChatResponse
		if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
			t.Fatal(err)
		}

		if actual.Model != "test" {
			t.Errorf("expected model test, got %s", actual.Model)
		}

		if !actual.Done {
			t.Errorf("expected done true, got false")
		}

		if actual.DoneReason != "load" {
			t.Errorf("expected done reason load, got %s", actual.DoneReason)
		}
	})

	// checkChatResponse decodes one ChatResponse from body and validates the
	// message content plus all the timing/count metrics are populated.
	checkChatResponse := func(t *testing.T, body io.Reader, model, content string) {
		t.Helper()

		var actual api.ChatResponse
		if err := json.NewDecoder(body).Decode(&actual); err != nil {
			t.Fatal(err)
		}

		if actual.Model != model {
			t.Errorf("expected model test, got %s", actual.Model)
		}

		if !actual.Done {
			// NOTE(review): message is inverted — this branch fires when Done
			// is false, so it should read "expected done true, got false".
			t.Errorf("expected done false, got true")
		}

		if actual.DoneReason != "stop" {
			t.Errorf("expected done reason stop, got %s", actual.DoneReason)
		}

		if diff := cmp.Diff(actual.Message, api.Message{
			Role:    "assistant",
			Content: content,
		}); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		if actual.PromptEvalCount == 0 {
			t.Errorf("expected prompt eval count > 0, got 0")
		}

		if actual.PromptEvalDuration == 0 {
			t.Errorf("expected prompt eval duration > 0, got 0")
		}

		if actual.EvalCount == 0 {
			t.Errorf("expected eval count > 0, got 0")
		}

		if actual.EvalDuration == 0 {
			t.Errorf("expected eval duration > 0, got 0")
		}

		if actual.LoadDuration == 0 {
			t.Errorf("expected load duration > 0, got 0")
		}

		if actual.TotalDuration == 0 {
			t.Errorf("expected total duration > 0, got 0")
		}
	}

	mock.CompletionResponse.Content = "Hi!"
	t.Run("messages", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test",
			Messages: []api.Message{
				{Role: "user", Content: "Hello!"},
			},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		// The template above renders each message as "role: content\n".
		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "user: Hello!\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkChatResponse(t, w.Body, "test", "Hi!")
	})

	// Derive "test-system" from "test" with a baked-in model system prompt.
	w = createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:  "test-system",
		From:   "test",
		System: "You are a helpful assistant.",
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	t.Run("messages with model system", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test-system",
			Messages: []api.Message{
				{Role: "user", Content: "Hello!"},
			},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		// The model-level system prompt is prepended to the conversation.
		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkChatResponse(t, w.Body, "test-system", "Hi!")
	})

	mock.CompletionResponse.Content = "Abra kadabra!"
	t.Run("messages with system", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test-system",
			Messages: []api.Message{
				{Role: "system", Content: "You can perform magic tricks."},
				{Role: "user", Content: "Hello!"},
			},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		// A request-supplied system message replaces the model's system prompt.
		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
	})

	t.Run("messages with interleaved system", func(t *testing.T) {
		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test-system",
			Messages: []api.Message{
				{Role: "user", Content: "Hello!"},
				{Role: "assistant", Content: "I can help you with that."},
				{Role: "system", Content: "You can perform magic tricks."},
				{Role: "user", Content: "Help me write tests."},
			},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		// When a system message appears mid-conversation, the model's own
		// system prompt still leads and the mid-stream one stays in place.
		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
	})

	t.Run("messages with tools (non-streaming)", func(t *testing.T) {
		if w.Code != http.StatusOK {
			t.Fatalf("failed to create test-system model: %d", w.Code)
		}

		tools := []api.Tool{
			{
				Type: "function",
				Function: api.ToolFunction{
					Name:        "get_weather",
					Description: "Get the current weather",
					Parameters: struct {
						Type       string   `json:"type"`
						Required   []string `json:"required"`
						Properties map[string]struct {
							Type        string   `json:"type"`
							Description string   `json:"description"`
							Enum        []string `json:"enum,omitempty"`
						} `json:"properties"`
					}{
						Type:     "object",
						Required: []string{"location"},
						Properties: map[string]struct {
							Type        string   `json:"type"`
							Description string   `json:"description"`
							Enum        []string `json:"enum,omitempty"`
						}{
							"location": {
								Type:        "string",
								Description: "The city and state",
							},
							"unit": {
								Type: "string",
								Enum: []string{"celsius", "fahrenheit"},
							},
						},
					},
				},
			},
		}

		// Return the whole tool call as one completed JSON payload.
		mock.CompletionResponse = llm.CompletionResponse{
			Content:            `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
			Done:               true,
			DoneReason:         "done",
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		}

		// NOTE(review): despite the subtest name, this sets stream=true —
		// confirm whether a false value was intended here.
		streamRequest := true

		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test-system",
			Messages: []api.Message{
				{Role: "user", Content: "What's the weather in Seattle?"},
			},
			Tools:  tools,
			Stream: &streamRequest,
		})

		if w.Code != http.StatusOK {
			var errResp struct {
				Error string `json:"error"`
			}
			if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil {
				t.Logf("Failed to decode error response: %v", err)
			} else {
				t.Logf("Error response: %s", errResp.Error)
			}
		}

		// NOTE(review): this status check duplicates the one just above.
		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		var resp api.ChatResponse
		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
			t.Fatal(err)
		}

		if resp.Message.ToolCalls == nil {
			t.Error("expected tool calls, got nil")
		}

		expectedToolCall := api.ToolCall{
			Function: api.ToolCallFunction{
				Name: "get_weather",
				Arguments: api.ToolCallFunctionArguments{
					"location": "Seattle, WA",
					"unit":     "celsius",
				},
			},
		}

		if diff := cmp.Diff(resp.Message.ToolCalls[0], expectedToolCall); diff != "" {
			t.Errorf("tool call mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("messages with tools (streaming)", func(t *testing.T) {
		tools := []api.Tool{
			{
				Type: "function",
				Function: api.ToolFunction{
					Name:        "get_weather",
					Description: "Get the current weather",
					Parameters: struct {
						Type       string   `json:"type"`
						Required   []string `json:"required"`
						Properties map[string]struct {
							Type        string   `json:"type"`
							Description string   `json:"description"`
							Enum        []string `json:"enum,omitempty"`
						} `json:"properties"`
					}{
						Type:     "object",
						Required: []string{"location"},
						Properties: map[string]struct {
							Type        string   `json:"type"`
							Description string   `json:"description"`
							Enum        []string `json:"enum,omitempty"`
						}{
							"location": {
								Type:        "string",
								Description: "The city and state",
							},
							"unit": {
								Type: "string",
								Enum: []string{"celsius", "fahrenheit"},
							},
						},
					},
				},
			},
		}

		// Simulate streaming response with multiple chunks
		var wg sync.WaitGroup
		wg.Add(1)

		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
			defer wg.Done()

			// Send chunks with small delays to simulate streaming.
			// The tool-call JSON is deliberately split mid-token across
			// chunks to exercise the server-side accumulator.
			responses := []llm.CompletionResponse{
				{
					Content:            `{"name":"get_`,
					Done:               false,
					PromptEvalCount:    1,
					PromptEvalDuration: 1,
				},
				{
					Content:            `weather","arguments":{"location":"Seattle`,
					Done:               false,
					PromptEvalCount:    2,
					PromptEvalDuration: 1,
				},
				{
					Content:            `, WA","unit":"celsius"}}`,
					Done:               true,
					DoneReason:         "tool_call",
					PromptEvalCount:    3,
					PromptEvalDuration: 1,
				},
			}

			for _, resp := range responses {
				select {
				case <-ctx.Done():
					return ctx.Err()
				default:
					fn(resp)
					time.Sleep(10 * time.Millisecond) // Small delay between chunks
				}
			}

			return nil
		}

		w := createRequest(t, s.ChatHandler, api.ChatRequest{
			Model: "test-system",
			Messages: []api.Message{
				{Role: "user", Content: "What's the weather in Seattle?"},
			},
			Tools:  tools,
			Stream: &stream,
		})

		wg.Wait()

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		// Read and validate the streamed responses: only the final (Done)
		// chunk is expected to carry the fully assembled tool call.
		decoder := json.NewDecoder(w.Body)
		var finalToolCall api.ToolCall

		for {
			var resp api.ChatResponse
			if err := decoder.Decode(&resp); err == io.EOF {
				break
			} else if err != nil {
				t.Fatal(err)
			}

			if resp.Done {
				if len(resp.Message.ToolCalls) != 1 {
					t.Errorf("expected 1 tool call in final response, got %d", len(resp.Message.ToolCalls))
				}
				finalToolCall = resp.Message.ToolCalls[0]
			}
		}

		expectedToolCall := api.ToolCall{
			Function: api.ToolCallFunction{
				Name: "get_weather",
				Arguments: api.ToolCallFunctionArguments{
					"location": "Seattle, WA",
					"unit":     "celsius",
				},
			},
		}

		if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
			t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
		}
	})
}
// TestGenerate exercises the /api/generate handler against a mocked runner:
// request validation, capability checks (generate/insert), prompt and
// template rendering (system prompts, custom templates, fill-in-middle
// suffix templates), and raw mode which bypasses templating entirely.
func TestGenerate(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Canned completion the mock replays for every request.
	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         "stop",
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	// Hand-wired scheduler whose load path resolves straight to the mock.
	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mock),
			getGpuFn:      discover.GetGPUInfo,
			getCpuFn:      discover.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
				// add small delay to simulate loading
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mock,
				}
			},
		},
	}

	go s.sched.Run(context.TODO())

	// Minimal llama-architecture GGUF so the model passes the generate
	// capability check.
	_, digest := createBinFile(t, llm.KV{
		"general.architecture":          "llama",
		"llama.block_count":             uint32(1),
		"llama.context_length":          uint32(8192),
		"llama.embedding_length":        uint32(4096),
		"llama.attention.head_count":    uint32(32),
		"llama.attention.head_count_kv": uint32(8),
		"tokenizer.ggml.tokens":         []string{""},
		"tokenizer.ggml.scores":         []float32{0},
		"tokenizer.ggml.token_type":     []int32{0},
	}, []llm.Tensor{
		{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
	})

	// Create the "test" model with a simple System/User/Assistant template.
	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model: "test",
		Files: map[string]string{"file.gguf": digest},
		Template: `
{{- if .System }}System: {{ .System }} {{ end }}
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
{{- if .Response }}Assistant: {{ .Response }} {{ end }}
`,
		Stream: &stream,
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	// Unlike /api/chat, a missing body surfaces as a 404 model lookup failure.
	t.Run("missing body", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, nil)
		if w.Code != http.StatusNotFound {
			t.Errorf("expected status 404, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("missing model", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
		if w.Code != http.StatusNotFound {
			t.Errorf("expected status 404, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	// An embedding-only bert model must be refused by generate.
	t.Run("missing capabilities generate", func(t *testing.T) {
		_, digest := createBinFile(t, llm.KV{
			"general.architecture": "bert",
			"bert.pooling_type":    uint32(0),
		}, []llm.Tensor{})

		w := createRequest(t, s.CreateHandler, api.CreateRequest{
			Model:  "bert",
			Files:  map[string]string{"file.gguf": digest},
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Fatalf("expected status 200, got %d", w.Code)
		}

		w = createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model: "bert",
		})

		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	// The "test" model's template has no {{ .Suffix }}, so supplying a
	// suffix must be rejected as unsupported insert/fill-in-middle.
	t.Run("missing capabilities suffix", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test",
			Prompt: "def add(",
			Suffix: " return c",
		})

		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support insert"}`); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	// An empty prompt acts as a load/keep-alive request: done=true with
	// reason "load" and no completion performed.
	t.Run("load model", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model: "test",
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		var actual api.GenerateResponse
		if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
			t.Fatal(err)
		}

		if actual.Model != "test" {
			t.Errorf("expected model test, got %s", actual.Model)
		}

		if !actual.Done {
			t.Errorf("expected done true, got false")
		}

		if actual.DoneReason != "load" {
			t.Errorf("expected done reason load, got %s", actual.DoneReason)
		}
	})

	// checkGenerateResponse decodes one GenerateResponse from body and
	// validates content, context, and that all metrics are populated.
	checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) {
		t.Helper()

		var actual api.GenerateResponse
		if err := json.NewDecoder(body).Decode(&actual); err != nil {
			t.Fatal(err)
		}

		if actual.Model != model {
			t.Errorf("expected model test, got %s", actual.Model)
		}

		if !actual.Done {
			// NOTE(review): message is inverted — this branch fires when Done
			// is false, so it should read "expected done true, got false".
			t.Errorf("expected done false, got true")
		}

		if actual.DoneReason != "stop" {
			t.Errorf("expected done reason stop, got %s", actual.DoneReason)
		}

		if actual.Response != content {
			t.Errorf("expected response %s, got %s", content, actual.Response)
		}

		if actual.Context == nil {
			t.Errorf("expected context not nil")
		}

		if actual.PromptEvalCount == 0 {
			t.Errorf("expected prompt eval count > 0, got 0")
		}

		if actual.PromptEvalDuration == 0 {
			t.Errorf("expected prompt eval duration > 0, got 0")
		}

		if actual.EvalCount == 0 {
			t.Errorf("expected eval count > 0, got 0")
		}

		if actual.EvalDuration == 0 {
			t.Errorf("expected eval duration > 0, got 0")
		}

		if actual.LoadDuration == 0 {
			t.Errorf("expected load duration > 0, got 0")
		}

		if actual.TotalDuration == 0 {
			t.Errorf("expected total duration > 0, got 0")
		}
	}

	mock.CompletionResponse.Content = "Hi!"
	t.Run("prompt", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test",
			Prompt: "Hello!",
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkGenerateResponse(t, w.Body, "test", "Hi!")
	})

	// Derive "test-system" from "test" with a baked-in system prompt.
	w = createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:  "test-system",
		From:   "test",
		System: "You are a helpful assistant.",
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	t.Run("prompt with model system", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-system",
			Prompt: "Hello!",
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkGenerateResponse(t, w.Body, "test-system", "Hi!")
	})

	mock.CompletionResponse.Content = "Abra kadabra!"
	t.Run("prompt with system", func(t *testing.T) {
		// A request-level System overrides the model's baked-in system prompt.
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-system",
			Prompt: "Hello!",
			System: "You can perform magic tricks.",
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
	})

	t.Run("prompt with template", func(t *testing.T) {
		// A request-level Template overrides the model's template.
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-system",
			Prompt: "Help me write tests.",
			System: "You can perform magic tricks.",
			Template: `{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}

		checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
	})

	// Create "test-suffix" with a fill-in-middle (PRE/SUF/MID) template so
	// the insert capability check passes.
	w = createRequest(t, s.CreateHandler, api.CreateRequest{
		Model: "test-suffix",
		Template: `{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
{{- else }}{{ .Prompt }}
{{- end }}`,
		From: "test",
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	t.Run("prompt with suffix", func(t *testing.T) {
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-suffix",
			Prompt: "def add(",
			Suffix: " return c",
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("prompt without suffix", func(t *testing.T) {
		// Without a suffix, the template's else branch passes the prompt through.
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-suffix",
			Prompt: "def add(",
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("raw", func(t *testing.T) {
		// Raw mode bypasses templating: the prompt reaches the runner verbatim.
		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  "test-system",
			Prompt: "Help me write tests.",
			Raw:    true,
			Stream: &stream,
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", w.Code)
		}

		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})
}