12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962 |
- package server
- import (
- "bytes"
- "context"
- "encoding/json"
- "io"
- "net/http"
- "strings"
- "sync"
- "testing"
- "time"
- "github.com/gin-gonic/gin"
- "github.com/google/go-cmp/cmp"
- "github.com/ollama/ollama/api"
- "github.com/ollama/ollama/discover"
- "github.com/ollama/ollama/fs/ggml"
- "github.com/ollama/ollama/llm"
- )
// mockRunner is a test double for llm.LlamaServer. It records the most
// recent completion request and replays a canned response, or delegates to
// CompletionFn when that hook is set.
type mockRunner struct {
	llm.LlamaServer

	// CompletionRequest is only valid until the next call to Completion
	llm.CompletionRequest

	// CompletionResponse is emitted by Completion when CompletionFn is nil.
	llm.CompletionResponse

	// CompletionFn, when non-nil, fully overrides Completion's behavior.
	CompletionFn func(context.Context, llm.CompletionRequest, func(llm.CompletionResponse)) error
}
- func (m *mockRunner) Completion(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
- m.CompletionRequest = r
- if m.CompletionFn != nil {
- return m.CompletionFn(ctx, r, fn)
- }
- fn(m.CompletionResponse)
- return nil
- }
- func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
- for range strings.Fields(s) {
- tokens = append(tokens, len(tokens))
- }
- return
- }
- func newMockServer(mock *mockRunner) func(discover.GpuInfoList, string, *ggml.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
- return func(_ discover.GpuInfoList, _ string, _ *ggml.GGML, _, _ []string, _ api.Options, _ int) (llm.LlamaServer, error) {
- return mock, nil
- }
- }
- func TestGenerateChat(t *testing.T) {
- gin.SetMode(gin.TestMode)
- mock := mockRunner{
- CompletionResponse: llm.CompletionResponse{
- Done: true,
- DoneReason: "stop",
- PromptEvalCount: 1,
- PromptEvalDuration: 1,
- EvalCount: 1,
- EvalDuration: 1,
- },
- }
- s := Server{
- sched: &Scheduler{
- pendingReqCh: make(chan *LlmRequest, 1),
- finishedReqCh: make(chan *LlmRequest, 1),
- expiredCh: make(chan *runnerRef, 1),
- unloadedCh: make(chan any, 1),
- loaded: make(map[string]*runnerRef),
- newServerFn: newMockServer(&mock),
- getGpuFn: discover.GetGPUInfo,
- getCpuFn: discover.GetCPUInfo,
- reschedDelay: 250 * time.Millisecond,
- loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) {
- // add small delay to simulate loading
- time.Sleep(time.Millisecond)
- req.successCh <- &runnerRef{
- llama: &mock,
- }
- },
- },
- }
- go s.sched.Run(context.TODO())
- _, digest := createBinFile(t, ggml.KV{
- "general.architecture": "llama",
- "llama.block_count": uint32(1),
- "llama.context_length": uint32(8192),
- "llama.embedding_length": uint32(4096),
- "llama.attention.head_count": uint32(32),
- "llama.attention.head_count_kv": uint32(8),
- "tokenizer.ggml.tokens": []string{""},
- "tokenizer.ggml.scores": []float32{0},
- "tokenizer.ggml.token_type": []int32{0},
- }, []ggml.Tensor{
- {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- {Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- {Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- {Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- {Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- {Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- {Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- {Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- {Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- {Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- })
- w := createRequest(t, s.CreateHandler, api.CreateRequest{
- Model: "test",
- Files: map[string]string{"file.gguf": digest},
- Template: `
- {{- if .Tools }}
- {{ .Tools }}
- {{ end }}
- {{- range .Messages }}
- {{- .Role }}: {{ .Content }}
- {{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
- {{- end }}
- {{ end }}`,
- Stream: &stream,
- })
- if w.Code != http.StatusOK {
- t.Fatalf("expected status 200, got %d", w.Code)
- }
- t.Run("missing body", func(t *testing.T) {
- w := createRequest(t, s.ChatHandler, nil)
- if w.Code != http.StatusBadRequest {
- t.Errorf("expected status 400, got %d", w.Code)
- }
- if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- })
- t.Run("missing model", func(t *testing.T) {
- w := createRequest(t, s.ChatHandler, api.ChatRequest{})
- if w.Code != http.StatusBadRequest {
- t.Errorf("expected status 400, got %d", w.Code)
- }
- if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- })
- t.Run("missing capabilities chat", func(t *testing.T) {
- _, digest := createBinFile(t, ggml.KV{
- "general.architecture": "bert",
- "bert.pooling_type": uint32(0),
- }, []ggml.Tensor{})
- w := createRequest(t, s.CreateHandler, api.CreateRequest{
- Model: "bert",
- Files: map[string]string{"bert.gguf": digest},
- Stream: &stream,
- })
- if w.Code != http.StatusOK {
- t.Fatalf("expected status 200, got %d", w.Code)
- }
- w = createRequest(t, s.ChatHandler, api.ChatRequest{
- Model: "bert",
- })
- if w.Code != http.StatusBadRequest {
- t.Errorf("expected status 400, got %d", w.Code)
- }
- if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- })
- t.Run("load model", func(t *testing.T) {
- w := createRequest(t, s.ChatHandler, api.ChatRequest{
- Model: "test",
- })
- if w.Code != http.StatusOK {
- t.Errorf("expected status 200, got %d", w.Code)
- }
- var actual api.ChatResponse
- if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
- t.Fatal(err)
- }
- if actual.Model != "test" {
- t.Errorf("expected model test, got %s", actual.Model)
- }
- if !actual.Done {
- t.Errorf("expected done true, got false")
- }
- if actual.DoneReason != "load" {
- t.Errorf("expected done reason load, got %s", actual.DoneReason)
- }
- })
- checkChatResponse := func(t *testing.T, body io.Reader, model, content string) {
- t.Helper()
- var actual api.ChatResponse
- if err := json.NewDecoder(body).Decode(&actual); err != nil {
- t.Fatal(err)
- }
- if actual.Model != model {
- t.Errorf("expected model test, got %s", actual.Model)
- }
- if !actual.Done {
- t.Errorf("expected done false, got true")
- }
- if actual.DoneReason != "stop" {
- t.Errorf("expected done reason stop, got %s", actual.DoneReason)
- }
- if diff := cmp.Diff(actual.Message, api.Message{
- Role: "assistant",
- Content: content,
- }); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- if actual.PromptEvalCount == 0 {
- t.Errorf("expected prompt eval count > 0, got 0")
- }
- if actual.PromptEvalDuration == 0 {
- t.Errorf("expected prompt eval duration > 0, got 0")
- }
- if actual.EvalCount == 0 {
- t.Errorf("expected eval count > 0, got 0")
- }
- if actual.EvalDuration == 0 {
- t.Errorf("expected eval duration > 0, got 0")
- }
- if actual.LoadDuration == 0 {
- t.Errorf("expected load duration > 0, got 0")
- }
- if actual.TotalDuration == 0 {
- t.Errorf("expected total duration > 0, got 0")
- }
- }
- mock.CompletionResponse.Content = "Hi!"
- t.Run("messages", func(t *testing.T) {
- w := createRequest(t, s.ChatHandler, api.ChatRequest{
- Model: "test",
- Messages: []api.Message{
- {Role: "user", Content: "Hello!"},
- },
- Stream: &stream,
- })
- if w.Code != http.StatusOK {
- t.Errorf("expected status 200, got %d", w.Code)
- }
- if diff := cmp.Diff(mock.CompletionRequest.Prompt, "user: Hello!\n"); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- checkChatResponse(t, w.Body, "test", "Hi!")
- })
- w = createRequest(t, s.CreateHandler, api.CreateRequest{
- Model: "test-system",
- From: "test",
- System: "You are a helpful assistant.",
- })
- if w.Code != http.StatusOK {
- t.Fatalf("expected status 200, got %d", w.Code)
- }
- t.Run("messages with model system", func(t *testing.T) {
- w := createRequest(t, s.ChatHandler, api.ChatRequest{
- Model: "test-system",
- Messages: []api.Message{
- {Role: "user", Content: "Hello!"},
- },
- Stream: &stream,
- })
- if w.Code != http.StatusOK {
- t.Errorf("expected status 200, got %d", w.Code)
- }
- if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- checkChatResponse(t, w.Body, "test-system", "Hi!")
- })
- mock.CompletionResponse.Content = "Abra kadabra!"
- t.Run("messages with system", func(t *testing.T) {
- w := createRequest(t, s.ChatHandler, api.ChatRequest{
- Model: "test-system",
- Messages: []api.Message{
- {Role: "system", Content: "You can perform magic tricks."},
- {Role: "user", Content: "Hello!"},
- },
- Stream: &stream,
- })
- if w.Code != http.StatusOK {
- t.Errorf("expected status 200, got %d", w.Code)
- }
- if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
- })
- t.Run("messages with interleaved system", func(t *testing.T) {
- w := createRequest(t, s.ChatHandler, api.ChatRequest{
- Model: "test-system",
- Messages: []api.Message{
- {Role: "user", Content: "Hello!"},
- {Role: "assistant", Content: "I can help you with that."},
- {Role: "system", Content: "You can perform magic tricks."},
- {Role: "user", Content: "Help me write tests."},
- },
- Stream: &stream,
- })
- if w.Code != http.StatusOK {
- t.Errorf("expected status 200, got %d", w.Code)
- }
- if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
- })
- t.Run("messages with tools (non-streaming)", func(t *testing.T) {
- if w.Code != http.StatusOK {
- t.Fatalf("failed to create test-system model: %d", w.Code)
- }
- tools := []api.Tool{
- {
- Type: "function",
- Function: api.ToolFunction{
- Name: "get_weather",
- Description: "Get the current weather",
- Parameters: struct {
- Type string `json:"type"`
- Required []string `json:"required"`
- Properties map[string]struct {
- Type string `json:"type"`
- Description string `json:"description"`
- Enum []string `json:"enum,omitempty"`
- } `json:"properties"`
- }{
- Type: "object",
- Required: []string{"location"},
- Properties: map[string]struct {
- Type string `json:"type"`
- Description string `json:"description"`
- Enum []string `json:"enum,omitempty"`
- }{
- "location": {
- Type: "string",
- Description: "The city and state",
- },
- "unit": {
- Type: "string",
- Enum: []string{"celsius", "fahrenheit"},
- },
- },
- },
- },
- },
- }
- mock.CompletionResponse = llm.CompletionResponse{
- Content: `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
- Done: true,
- DoneReason: "done",
- PromptEvalCount: 1,
- PromptEvalDuration: 1,
- EvalCount: 1,
- EvalDuration: 1,
- }
- streamRequest := true
- w := createRequest(t, s.ChatHandler, api.ChatRequest{
- Model: "test-system",
- Messages: []api.Message{
- {Role: "user", Content: "What's the weather in Seattle?"},
- },
- Tools: tools,
- Stream: &streamRequest,
- })
- if w.Code != http.StatusOK {
- var errResp struct {
- Error string `json:"error"`
- }
- if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil {
- t.Logf("Failed to decode error response: %v", err)
- } else {
- t.Logf("Error response: %s", errResp.Error)
- }
- }
- if w.Code != http.StatusOK {
- t.Errorf("expected status 200, got %d", w.Code)
- }
- var resp api.ChatResponse
- if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
- t.Fatal(err)
- }
- if resp.Message.ToolCalls == nil {
- t.Error("expected tool calls, got nil")
- }
- expectedToolCall := api.ToolCall{
- Function: api.ToolCallFunction{
- Name: "get_weather",
- Arguments: api.ToolCallFunctionArguments{
- "location": "Seattle, WA",
- "unit": "celsius",
- },
- },
- }
- if diff := cmp.Diff(resp.Message.ToolCalls[0], expectedToolCall); diff != "" {
- t.Errorf("tool call mismatch (-got +want):\n%s", diff)
- }
- })
- t.Run("messages with tools (streaming)", func(t *testing.T) {
- tools := []api.Tool{
- {
- Type: "function",
- Function: api.ToolFunction{
- Name: "get_weather",
- Description: "Get the current weather",
- Parameters: struct {
- Type string `json:"type"`
- Required []string `json:"required"`
- Properties map[string]struct {
- Type string `json:"type"`
- Description string `json:"description"`
- Enum []string `json:"enum,omitempty"`
- } `json:"properties"`
- }{
- Type: "object",
- Required: []string{"location"},
- Properties: map[string]struct {
- Type string `json:"type"`
- Description string `json:"description"`
- Enum []string `json:"enum,omitempty"`
- }{
- "location": {
- Type: "string",
- Description: "The city and state",
- },
- "unit": {
- Type: "string",
- Enum: []string{"celsius", "fahrenheit"},
- },
- },
- },
- },
- },
- }
- // Simulate streaming response with multiple chunks
- var wg sync.WaitGroup
- wg.Add(1)
- mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
- defer wg.Done()
- // Send chunks with small delays to simulate streaming
- responses := []llm.CompletionResponse{
- {
- Content: `{"name":"get_`,
- Done: false,
- PromptEvalCount: 1,
- PromptEvalDuration: 1,
- },
- {
- Content: `weather","arguments":{"location":"Seattle`,
- Done: false,
- PromptEvalCount: 2,
- PromptEvalDuration: 1,
- },
- {
- Content: `, WA","unit":"celsius"}}`,
- Done: true,
- DoneReason: "tool_call",
- PromptEvalCount: 3,
- PromptEvalDuration: 1,
- },
- }
- for _, resp := range responses {
- select {
- case <-ctx.Done():
- return ctx.Err()
- default:
- fn(resp)
- time.Sleep(10 * time.Millisecond) // Small delay between chunks
- }
- }
- return nil
- }
- w := createRequest(t, s.ChatHandler, api.ChatRequest{
- Model: "test-system",
- Messages: []api.Message{
- {Role: "user", Content: "What's the weather in Seattle?"},
- },
- Tools: tools,
- Stream: &stream,
- })
- wg.Wait()
- if w.Code != http.StatusOK {
- t.Errorf("expected status 200, got %d", w.Code)
- }
- // Read and validate the streamed responses
- decoder := json.NewDecoder(w.Body)
- var finalToolCall api.ToolCall
- for {
- var resp api.ChatResponse
- if err := decoder.Decode(&resp); err == io.EOF {
- break
- } else if err != nil {
- t.Fatal(err)
- }
- if resp.Done {
- if len(resp.Message.ToolCalls) != 1 {
- t.Errorf("expected 1 tool call in final response, got %d", len(resp.Message.ToolCalls))
- }
- finalToolCall = resp.Message.ToolCalls[0]
- }
- }
- expectedToolCall := api.ToolCall{
- Function: api.ToolCallFunction{
- Name: "get_weather",
- Arguments: api.ToolCallFunctionArguments{
- "location": "Seattle, WA",
- "unit": "celsius",
- },
- },
- }
- if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
- t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
- }
- })
- }
- func TestGenerate(t *testing.T) {
- gin.SetMode(gin.TestMode)
- mock := mockRunner{
- CompletionResponse: llm.CompletionResponse{
- Done: true,
- DoneReason: "stop",
- PromptEvalCount: 1,
- PromptEvalDuration: 1,
- EvalCount: 1,
- EvalDuration: 1,
- },
- }
- s := Server{
- sched: &Scheduler{
- pendingReqCh: make(chan *LlmRequest, 1),
- finishedReqCh: make(chan *LlmRequest, 1),
- expiredCh: make(chan *runnerRef, 1),
- unloadedCh: make(chan any, 1),
- loaded: make(map[string]*runnerRef),
- newServerFn: newMockServer(&mock),
- getGpuFn: discover.GetGPUInfo,
- getCpuFn: discover.GetCPUInfo,
- reschedDelay: 250 * time.Millisecond,
- loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) {
- // add small delay to simulate loading
- time.Sleep(time.Millisecond)
- req.successCh <- &runnerRef{
- llama: &mock,
- }
- },
- },
- }
- go s.sched.Run(context.TODO())
- _, digest := createBinFile(t, ggml.KV{
- "general.architecture": "llama",
- "llama.block_count": uint32(1),
- "llama.context_length": uint32(8192),
- "llama.embedding_length": uint32(4096),
- "llama.attention.head_count": uint32(32),
- "llama.attention.head_count_kv": uint32(8),
- "tokenizer.ggml.tokens": []string{""},
- "tokenizer.ggml.scores": []float32{0},
- "tokenizer.ggml.token_type": []int32{0},
- }, []ggml.Tensor{
- {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- {Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- {Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- {Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- {Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- {Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- {Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- {Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- {Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- {Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- })
- w := createRequest(t, s.CreateHandler, api.CreateRequest{
- Model: "test",
- Files: map[string]string{"file.gguf": digest},
- Template: `
- {{- if .System }}System: {{ .System }} {{ end }}
- {{- if .Prompt }}User: {{ .Prompt }} {{ end }}
- {{- if .Response }}Assistant: {{ .Response }} {{ end }}
- `,
- Stream: &stream,
- })
- if w.Code != http.StatusOK {
- t.Fatalf("expected status 200, got %d", w.Code)
- }
- t.Run("missing body", func(t *testing.T) {
- w := createRequest(t, s.GenerateHandler, nil)
- if w.Code != http.StatusNotFound {
- t.Errorf("expected status 404, got %d", w.Code)
- }
- if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- })
- t.Run("missing model", func(t *testing.T) {
- w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
- if w.Code != http.StatusNotFound {
- t.Errorf("expected status 404, got %d", w.Code)
- }
- if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- })
- t.Run("missing capabilities generate", func(t *testing.T) {
- _, digest := createBinFile(t, ggml.KV{
- "general.architecture": "bert",
- "bert.pooling_type": uint32(0),
- }, []ggml.Tensor{})
- w := createRequest(t, s.CreateHandler, api.CreateRequest{
- Model: "bert",
- Files: map[string]string{"file.gguf": digest},
- Stream: &stream,
- })
- if w.Code != http.StatusOK {
- t.Fatalf("expected status 200, got %d", w.Code)
- }
- w = createRequest(t, s.GenerateHandler, api.GenerateRequest{
- Model: "bert",
- })
- if w.Code != http.StatusBadRequest {
- t.Errorf("expected status 400, got %d", w.Code)
- }
- if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- })
- t.Run("missing capabilities suffix", func(t *testing.T) {
- w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
- Model: "test",
- Prompt: "def add(",
- Suffix: " return c",
- })
- if w.Code != http.StatusBadRequest {
- t.Errorf("expected status 400, got %d", w.Code)
- }
- if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support insert"}`); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- })
- t.Run("load model", func(t *testing.T) {
- w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
- Model: "test",
- })
- if w.Code != http.StatusOK {
- t.Errorf("expected status 200, got %d", w.Code)
- }
- var actual api.GenerateResponse
- if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
- t.Fatal(err)
- }
- if actual.Model != "test" {
- t.Errorf("expected model test, got %s", actual.Model)
- }
- if !actual.Done {
- t.Errorf("expected done true, got false")
- }
- if actual.DoneReason != "load" {
- t.Errorf("expected done reason load, got %s", actual.DoneReason)
- }
- })
- checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) {
- t.Helper()
- var actual api.GenerateResponse
- if err := json.NewDecoder(body).Decode(&actual); err != nil {
- t.Fatal(err)
- }
- if actual.Model != model {
- t.Errorf("expected model test, got %s", actual.Model)
- }
- if !actual.Done {
- t.Errorf("expected done false, got true")
- }
- if actual.DoneReason != "stop" {
- t.Errorf("expected done reason stop, got %s", actual.DoneReason)
- }
- if actual.Response != content {
- t.Errorf("expected response %s, got %s", content, actual.Response)
- }
- if actual.Context == nil {
- t.Errorf("expected context not nil")
- }
- if actual.PromptEvalCount == 0 {
- t.Errorf("expected prompt eval count > 0, got 0")
- }
- if actual.PromptEvalDuration == 0 {
- t.Errorf("expected prompt eval duration > 0, got 0")
- }
- if actual.EvalCount == 0 {
- t.Errorf("expected eval count > 0, got 0")
- }
- if actual.EvalDuration == 0 {
- t.Errorf("expected eval duration > 0, got 0")
- }
- if actual.LoadDuration == 0 {
- t.Errorf("expected load duration > 0, got 0")
- }
- if actual.TotalDuration == 0 {
- t.Errorf("expected total duration > 0, got 0")
- }
- }
- mock.CompletionResponse.Content = "Hi!"
- t.Run("prompt", func(t *testing.T) {
- w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
- Model: "test",
- Prompt: "Hello!",
- Stream: &stream,
- })
- if w.Code != http.StatusOK {
- t.Errorf("expected status 200, got %d", w.Code)
- }
- if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- checkGenerateResponse(t, w.Body, "test", "Hi!")
- })
- w = createRequest(t, s.CreateHandler, api.CreateRequest{
- Model: "test-system",
- From: "test",
- System: "You are a helpful assistant.",
- })
- if w.Code != http.StatusOK {
- t.Fatalf("expected status 200, got %d", w.Code)
- }
- t.Run("prompt with model system", func(t *testing.T) {
- w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
- Model: "test-system",
- Prompt: "Hello!",
- Stream: &stream,
- })
- if w.Code != http.StatusOK {
- t.Errorf("expected status 200, got %d", w.Code)
- }
- if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- checkGenerateResponse(t, w.Body, "test-system", "Hi!")
- })
- mock.CompletionResponse.Content = "Abra kadabra!"
- t.Run("prompt with system", func(t *testing.T) {
- w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
- Model: "test-system",
- Prompt: "Hello!",
- System: "You can perform magic tricks.",
- Stream: &stream,
- })
- if w.Code != http.StatusOK {
- t.Errorf("expected status 200, got %d", w.Code)
- }
- if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
- })
- t.Run("prompt with template", func(t *testing.T) {
- w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
- Model: "test-system",
- Prompt: "Help me write tests.",
- System: "You can perform magic tricks.",
- Template: `{{- if .System }}{{ .System }} {{ end }}
- {{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
- {{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
- Stream: &stream,
- })
- if w.Code != http.StatusOK {
- t.Errorf("expected status 200, got %d", w.Code)
- }
- if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
- })
- w = createRequest(t, s.CreateHandler, api.CreateRequest{
- Model: "test-suffix",
- Template: `{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
- {{- else }}{{ .Prompt }}
- {{- end }}`,
- From: "test",
- })
- if w.Code != http.StatusOK {
- t.Fatalf("expected status 200, got %d", w.Code)
- }
- t.Run("prompt with suffix", func(t *testing.T) {
- w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
- Model: "test-suffix",
- Prompt: "def add(",
- Suffix: " return c",
- })
- if w.Code != http.StatusOK {
- t.Errorf("expected status 200, got %d", w.Code)
- }
- if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- })
- t.Run("prompt without suffix", func(t *testing.T) {
- w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
- Model: "test-suffix",
- Prompt: "def add(",
- })
- if w.Code != http.StatusOK {
- t.Errorf("expected status 200, got %d", w.Code)
- }
- if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- })
- t.Run("raw", func(t *testing.T) {
- w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
- Model: "test-system",
- Prompt: "Help me write tests.",
- Raw: true,
- Stream: &stream,
- })
- if w.Code != http.StatusOK {
- t.Errorf("expected status 200, got %d", w.Code)
- }
- if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- })
- }
|