package server

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/go-cmp/cmp"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/llm"
)
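
// mockRunner, newMockServer, createRequest, createBinFile, and the stream
// variable are not defined in this file; they are presumably the shared
// helpers defined in the package's other test files (e.g.
// routes_generate_test.go and routes_create_test.go in the ollama server
// package).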
func TestTokenize(t *testing.T) {
	gin.SetMode(gin.TestMode)

	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         "stop",
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mock),
			getGpuFn:      discover.GetGPUInfo,
			getCpuFn:      discover.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
				// Add a small delay to simulate model loading.
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mock,
				}
			},
		},
	}

	go s.sched.Run(context.TODO())

- t.Run("missing body", func(t *testing.T) {
- w := httptest.NewRecorder()
- r := httptest.NewRequest(http.MethodPost, "/api/tokenize", nil)
- s.TokenizeHandler(w, r)
- if w.Code != http.StatusBadRequest {
- t.Errorf("expected status 400, got %d", w.Code)
- }
- if diff := cmp.Diff(w.Body.String(), "missing request body\n"); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- })
- t.Run("missing model", func(t *testing.T) {
- w := httptest.NewRecorder()
- r := httptest.NewRequest(http.MethodPost, "/api/tokenize", strings.NewReader("{}"))
- s.TokenizeHandler(w, r)
- if w.Code != http.StatusBadRequest {
- t.Errorf("expected status 400, got %d", w.Code)
- }
- if diff := cmp.Diff(w.Body.String(), "missing `text` for tokenization\n"); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- })
- t.Run("tokenize text", func(t *testing.T) {
- // First create the model
- w := createRequest(t, s.CreateHandler, api.CreateRequest{
- Model: "test",
- Modelfile: fmt.Sprintf(`FROM %s`, createBinFile(t, llm.KV{
- "general.architecture": "llama",
- "llama.block_count": uint32(1),
- "llama.context_length": uint32(8192),
- "llama.embedding_length": uint32(4096),
- "llama.attention.head_count": uint32(32),
- "llama.attention.head_count_kv": uint32(8),
- "tokenizer.ggml.tokens": []string{""},
- "tokenizer.ggml.scores": []float32{0},
- "tokenizer.ggml.token_type": []int32{0},
- }, []llm.Tensor{
- {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- })),
- })
- if w.Code != http.StatusOK {
- t.Fatalf("failed to create model: %d", w.Code)
- }
- // Now test tokenization
- body, err := json.Marshal(api.TokenizeRequest{
- Model: "test",
- Text: "Hello world how are you",
- })
- if err != nil {
- t.Fatalf("failed to marshal request: %v", err)
- }
- w = httptest.NewRecorder()
- r := httptest.NewRequest(http.MethodPost, "/api/tokenize", bytes.NewReader(body))
- r.Header.Set("Content-Type", "application/json")
- s.TokenizeHandler(w, r)
- if w.Code != http.StatusOK {
- t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
- }
- var resp api.TokenizeResponse
- if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
- t.Errorf("failed to decode response: %v", err)
- }
- // Our mock tokenizer creates sequential tokens based on word count
- expected := []int{0, 1, 2, 3, 4}
- if diff := cmp.Diff(resp.Tokens, expected); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- })
- t.Run("tokenize empty text", func(t *testing.T) {
- body, err := json.Marshal(api.TokenizeRequest{
- Model: "test",
- Text: "",
- })
- if err != nil {
- t.Fatalf("failed to marshal request: %v", err)
- }
- w := httptest.NewRecorder()
- r := httptest.NewRequest(http.MethodPost, "/api/tokenize", bytes.NewReader(body))
- r.Header.Set("Content-Type", "application/json")
- s.TokenizeHandler(w, r)
- if w.Code != http.StatusBadRequest {
- t.Errorf("expected status 400, got %d", w.Code)
- }
- if diff := cmp.Diff(w.Body.String(), "missing `text` for tokenization\n"); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- })
- }
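
// The subtests above rely on mockRunner's tokenizer, which is not shown in
// this file. A minimal version consistent with the expectation of sequential
// token IDs, one per whitespace-separated word, might look like this sketch
// (signature and receiver assumed, not taken from this file):
//
//	func (m *mockRunner) Tokenize(_ context.Context, s string) ([]int, error) {
//		words := strings.Fields(s)
//		tokens := make([]int, len(words))
//		for i := range words {
//			tokens[i] = i // "Hello world how are you" -> [0 1 2 3 4]
//		}
//		return tokens, nil
//	}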
func TestDetokenize(t *testing.T) {
	gin.SetMode(gin.TestMode)

	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         "stop",
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mock),
			getGpuFn:      discover.GetGPUInfo,
			getCpuFn:      discover.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
				// Add a small delay to simulate model loading.
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mock,
				}
			},
		},
	}

	go s.sched.Run(context.TODO())

- t.Run("detokenize tokens", func(t *testing.T) {
- // Create the model first
- w := createRequest(t, s.CreateHandler, api.CreateRequest{
- Model: "test",
- Modelfile: fmt.Sprintf(`FROM %s`, createBinFile(t, llm.KV{
- "general.architecture": "llama",
- "llama.block_count": uint32(1),
- "llama.context_length": uint32(8192),
- "llama.embedding_length": uint32(4096),
- "llama.attention.head_count": uint32(32),
- "llama.attention.head_count_kv": uint32(8),
- "tokenizer.ggml.tokens": []string{""},
- "tokenizer.ggml.scores": []float32{0},
- "tokenizer.ggml.token_type": []int32{0},
- }, []llm.Tensor{
- {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
- })),
- Stream: &stream,
- })
- if w.Code != http.StatusOK {
- t.Fatalf("failed to create model: %d - %s", w.Code, w.Body.String())
- }
- body, err := json.Marshal(api.DetokenizeRequest{
- Model: "test",
- Tokens: []int{0, 1, 2, 3, 4},
- })
- if err != nil {
- t.Fatalf("failed to marshal request: %v", err)
- }
- w = httptest.NewRecorder()
- r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
- r.Header.Set("Content-Type", "application/json")
- s.DetokenizeHandler(w, r)
- if w.Code != http.StatusOK {
- t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
- }
- var resp api.DetokenizeResponse
- if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
- t.Errorf("failed to decode response: %v", err)
- }
- })
- t.Run("detokenize empty tokens", func(t *testing.T) {
- body, err := json.Marshal(api.DetokenizeRequest{
- Model: "test",
- })
- if err != nil {
- t.Fatalf("failed to marshal request: %v", err)
- }
- w := httptest.NewRecorder()
- r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
- r.Header.Set("Content-Type", "application/json")
- s.DetokenizeHandler(w, r)
- if w.Code != http.StatusBadRequest {
- t.Errorf("expected status 400, got %d", w.Code)
- }
- if diff := cmp.Diff(w.Body.String(), "missing tokens for detokenization\n"); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- })
- t.Run("detokenize missing model", func(t *testing.T) {
- body, err := json.Marshal(api.DetokenizeRequest{
- Tokens: []int{0, 1, 2},
- })
- if err != nil {
- t.Fatalf("failed to marshal request: %v", err)
- }
- w := httptest.NewRecorder()
- r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
- r.Header.Set("Content-Type", "application/json")
- s.DetokenizeHandler(w, r)
- if w.Code != http.StatusNotFound {
- t.Errorf("expected status 404, got %d", w.Code)
- }
- if diff := cmp.Diff(w.Body.String(), "model '' not found\n"); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- })
- }
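
// For reference, the handler contract these tests assume (a sketch under
// assumptions, not the actual implementation): TokenizeHandler and
// DetokenizeHandler are plain net/http handlers on Server that validate the
// body, then text/tokens, and only then resolve the model; the real handler
// would also schedule a runner before tokenizing. http.Error appends the
// trailing "\n" the tests expect. Everything below beyond the names the
// tests call is assumed.
//
//	func (s *Server) TokenizeHandler(w http.ResponseWriter, r *http.Request) {
//		if r.Body == nil || r.ContentLength == 0 {
//			http.Error(w, "missing request body", http.StatusBadRequest)
//			return
//		}
//
//		var req api.TokenizeRequest
//		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
//			http.Error(w, err.Error(), http.StatusBadRequest)
//			return
//		}
//
//		if req.Text == "" {
//			http.Error(w, "missing `text` for tokenization", http.StatusBadRequest)
//			return
//		}
//
//		// Resolve req.Model (404 "model '<name>' not found" if absent),
//		// tokenize, and write an api.TokenizeResponse.
//	}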