|
@@ -0,0 +1,290 @@
|
|
|
|
+package server
|
|
|
|
+
|
|
|
|
+import (
|
|
|
|
+ "bytes"
|
|
|
|
+ "context"
|
|
|
|
+ "encoding/json"
|
|
|
|
+ "fmt"
|
|
|
|
+ "net/http"
|
|
|
|
+ "net/http/httptest"
|
|
|
|
+ "strings"
|
|
|
|
+ "testing"
|
|
|
|
+ "time"
|
|
|
|
+
|
|
|
|
+ "github.com/gin-gonic/gin"
|
|
|
|
+ "github.com/google/go-cmp/cmp"
|
|
|
|
+
|
|
|
|
+ "github.com/ollama/ollama/api"
|
|
|
|
+ "github.com/ollama/ollama/discover"
|
|
|
|
+ "github.com/ollama/ollama/llm"
|
|
|
|
+)
|
|
|
|
+
|
|
|
|
+func TestTokenize(t *testing.T) {
|
|
|
|
+ gin.SetMode(gin.TestMode)
|
|
|
|
+
|
|
|
|
+ mock := mockRunner{
|
|
|
|
+ CompletionResponse: llm.CompletionResponse{
|
|
|
|
+ Done: true,
|
|
|
|
+ DoneReason: "stop",
|
|
|
|
+ PromptEvalCount: 1,
|
|
|
|
+ PromptEvalDuration: 1,
|
|
|
|
+ EvalCount: 1,
|
|
|
|
+ EvalDuration: 1,
|
|
|
|
+ },
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ s := Server{
|
|
|
|
+ sched: &Scheduler{
|
|
|
|
+ pendingReqCh: make(chan *LlmRequest, 1),
|
|
|
|
+ finishedReqCh: make(chan *LlmRequest, 1),
|
|
|
|
+ expiredCh: make(chan *runnerRef, 1),
|
|
|
|
+ unloadedCh: make(chan any, 1),
|
|
|
|
+ loaded: make(map[string]*runnerRef),
|
|
|
|
+ newServerFn: newMockServer(&mock),
|
|
|
|
+ getGpuFn: discover.GetGPUInfo,
|
|
|
|
+ getCpuFn: discover.GetCPUInfo,
|
|
|
|
+ reschedDelay: 250 * time.Millisecond,
|
|
|
|
+ loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
|
|
|
|
+ // add small delay to simulate loading
|
|
|
|
+ time.Sleep(time.Millisecond)
|
|
|
|
+ req.successCh <- &runnerRef{
|
|
|
|
+ llama: &mock,
|
|
|
|
+ }
|
|
|
|
+ },
|
|
|
|
+ },
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ go s.sched.Run(context.TODO())
|
|
|
|
+
|
|
|
|
+ t.Run("missing body", func(t *testing.T) {
|
|
|
|
+ w := httptest.NewRecorder()
|
|
|
|
+ r := httptest.NewRequest(http.MethodPost, "/api/tokenize", nil)
|
|
|
|
+ s.TokenizeHandler(w, r)
|
|
|
|
+
|
|
|
|
+ if w.Code != http.StatusBadRequest {
|
|
|
|
+ t.Errorf("expected status 400, got %d", w.Code)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if diff := cmp.Diff(w.Body.String(), "missing request body\n"); diff != "" {
|
|
|
|
+ t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
|
|
+ }
|
|
|
|
+ })
|
|
|
|
+
|
|
|
|
+ t.Run("missing model", func(t *testing.T) {
|
|
|
|
+ w := httptest.NewRecorder()
|
|
|
|
+ r := httptest.NewRequest(http.MethodPost, "/api/tokenize", strings.NewReader("{}"))
|
|
|
|
+ s.TokenizeHandler(w, r)
|
|
|
|
+
|
|
|
|
+ if w.Code != http.StatusBadRequest {
|
|
|
|
+ t.Errorf("expected status 400, got %d", w.Code)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if diff := cmp.Diff(w.Body.String(), "missing `text` for tokenization\n"); diff != "" {
|
|
|
|
+ t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
|
|
+ }
|
|
|
|
+ })
|
|
|
|
+ t.Run("tokenize text", func(t *testing.T) {
|
|
|
|
+ // First create the model
|
|
|
|
+ w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
|
|
|
+ Model: "test",
|
|
|
|
+ Modelfile: fmt.Sprintf(`FROM %s`, createBinFile(t, llm.KV{
|
|
|
|
+ "general.architecture": "llama",
|
|
|
|
+ "llama.block_count": uint32(1),
|
|
|
|
+ "llama.context_length": uint32(8192),
|
|
|
|
+ "llama.embedding_length": uint32(4096),
|
|
|
|
+ "llama.attention.head_count": uint32(32),
|
|
|
|
+ "llama.attention.head_count_kv": uint32(8),
|
|
|
|
+ "tokenizer.ggml.tokens": []string{""},
|
|
|
|
+ "tokenizer.ggml.scores": []float32{0},
|
|
|
|
+ "tokenizer.ggml.token_type": []int32{0},
|
|
|
|
+ }, []llm.Tensor{
|
|
|
|
+ {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
|
|
+ {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
|
|
+ })),
|
|
|
|
+ })
|
|
|
|
+ if w.Code != http.StatusOK {
|
|
|
|
+ t.Fatalf("failed to create model: %d", w.Code)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Now test tokenization
|
|
|
|
+ body, err := json.Marshal(api.TokenizeRequest{
|
|
|
|
+ Model: "test",
|
|
|
|
+ Text: "Hello world how are you",
|
|
|
|
+ })
|
|
|
|
+ if err != nil {
|
|
|
|
+ t.Fatalf("failed to marshal request: %v", err)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ w = httptest.NewRecorder()
|
|
|
|
+ r := httptest.NewRequest(http.MethodPost, "/api/tokenize", bytes.NewReader(body))
|
|
|
|
+ r.Header.Set("Content-Type", "application/json")
|
|
|
|
+ s.TokenizeHandler(w, r)
|
|
|
|
+
|
|
|
|
+ if w.Code != http.StatusOK {
|
|
|
|
+ t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ var resp api.TokenizeResponse
|
|
|
|
+ if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
|
|
|
+ t.Errorf("failed to decode response: %v", err)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Our mock tokenizer creates sequential tokens based on word count
|
|
|
|
+ expected := []int{0, 1, 2, 3, 4}
|
|
|
|
+ if diff := cmp.Diff(resp.Tokens, expected); diff != "" {
|
|
|
|
+ t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
|
|
+ }
|
|
|
|
+ })
|
|
|
|
+
|
|
|
|
+ t.Run("tokenize empty text", func(t *testing.T) {
|
|
|
|
+ body, err := json.Marshal(api.TokenizeRequest{
|
|
|
|
+ Model: "test",
|
|
|
|
+ Text: "",
|
|
|
|
+ })
|
|
|
|
+ if err != nil {
|
|
|
|
+ t.Fatalf("failed to marshal request: %v", err)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ w := httptest.NewRecorder()
|
|
|
|
+ r := httptest.NewRequest(http.MethodPost, "/api/tokenize", bytes.NewReader(body))
|
|
|
|
+ r.Header.Set("Content-Type", "application/json")
|
|
|
|
+ s.TokenizeHandler(w, r)
|
|
|
|
+
|
|
|
|
+ if w.Code != http.StatusBadRequest {
|
|
|
|
+ t.Errorf("expected status 400, got %d", w.Code)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if diff := cmp.Diff(w.Body.String(), "missing `text` for tokenization\n"); diff != "" {
|
|
|
|
+ t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
|
|
+ }
|
|
|
|
+ })
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func TestDetokenize(t *testing.T) {
|
|
|
|
+ gin.SetMode(gin.TestMode)
|
|
|
|
+
|
|
|
|
+ mock := mockRunner{
|
|
|
|
+ CompletionResponse: llm.CompletionResponse{
|
|
|
|
+ Done: true,
|
|
|
|
+ DoneReason: "stop",
|
|
|
|
+ PromptEvalCount: 1,
|
|
|
|
+ PromptEvalDuration: 1,
|
|
|
|
+ EvalCount: 1,
|
|
|
|
+ EvalDuration: 1,
|
|
|
|
+ },
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ s := Server{
|
|
|
|
+ sched: &Scheduler{
|
|
|
|
+ pendingReqCh: make(chan *LlmRequest, 1),
|
|
|
|
+ finishedReqCh: make(chan *LlmRequest, 1),
|
|
|
|
+ expiredCh: make(chan *runnerRef, 1),
|
|
|
|
+ unloadedCh: make(chan any, 1),
|
|
|
|
+ loaded: make(map[string]*runnerRef),
|
|
|
|
+ newServerFn: newMockServer(&mock),
|
|
|
|
+ getGpuFn: discover.GetGPUInfo,
|
|
|
|
+ getCpuFn: discover.GetCPUInfo,
|
|
|
|
+ reschedDelay: 250 * time.Millisecond,
|
|
|
|
+ loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
|
|
|
|
+ // add small delay to simulate loading
|
|
|
|
+ time.Sleep(time.Millisecond)
|
|
|
|
+ req.successCh <- &runnerRef{
|
|
|
|
+ llama: &mock,
|
|
|
|
+ }
|
|
|
|
+ },
|
|
|
|
+ },
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ go s.sched.Run(context.TODO())
|
|
|
|
+
|
|
|
|
+ t.Run("detokenize tokens", func(t *testing.T) {
|
|
|
|
+ // Create the model first
|
|
|
|
+ w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
|
|
|
+ Model: "test",
|
|
|
|
+ Modelfile: fmt.Sprintf(`FROM %s`, createBinFile(t, llm.KV{
|
|
|
|
+ "general.architecture": "llama",
|
|
|
|
+ "llama.block_count": uint32(1),
|
|
|
|
+ "llama.context_length": uint32(8192),
|
|
|
|
+ "llama.embedding_length": uint32(4096),
|
|
|
|
+ "llama.attention.head_count": uint32(32),
|
|
|
|
+ "llama.attention.head_count_kv": uint32(8),
|
|
|
|
+ "tokenizer.ggml.tokens": []string{""},
|
|
|
|
+ "tokenizer.ggml.scores": []float32{0},
|
|
|
|
+ "tokenizer.ggml.token_type": []int32{0},
|
|
|
|
+ }, []llm.Tensor{
|
|
|
|
+ {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
|
|
+ {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
|
|
+ })),
|
|
|
|
+ Stream: &stream,
|
|
|
|
+ })
|
|
|
|
+ if w.Code != http.StatusOK {
|
|
|
|
+ t.Fatalf("failed to create model: %d - %s", w.Code, w.Body.String())
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ body, err := json.Marshal(api.DetokenizeRequest{
|
|
|
|
+ Model: "test",
|
|
|
|
+ Tokens: []int{0, 1, 2, 3, 4},
|
|
|
|
+ })
|
|
|
|
+ if err != nil {
|
|
|
|
+ t.Fatalf("failed to marshal request: %v", err)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ w = httptest.NewRecorder()
|
|
|
|
+ r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
|
|
|
|
+ r.Header.Set("Content-Type", "application/json")
|
|
|
|
+ s.DetokenizeHandler(w, r)
|
|
|
|
+
|
|
|
|
+ if w.Code != http.StatusOK {
|
|
|
|
+ t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ var resp api.DetokenizeResponse
|
|
|
|
+ if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
|
|
|
+ t.Errorf("failed to decode response: %v", err)
|
|
|
|
+ }
|
|
|
|
+ })
|
|
|
|
+
|
|
|
|
+ t.Run("detokenize empty tokens", func(t *testing.T) {
|
|
|
|
+ body, err := json.Marshal(api.DetokenizeRequest{
|
|
|
|
+ Model: "test",
|
|
|
|
+ })
|
|
|
|
+ if err != nil {
|
|
|
|
+ t.Fatalf("failed to marshal request: %v", err)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ w := httptest.NewRecorder()
|
|
|
|
+ r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
|
|
|
|
+ r.Header.Set("Content-Type", "application/json")
|
|
|
|
+ s.DetokenizeHandler(w, r)
|
|
|
|
+
|
|
|
|
+ if w.Code != http.StatusBadRequest {
|
|
|
|
+ t.Errorf("expected status 400, got %d", w.Code)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if diff := cmp.Diff(w.Body.String(), "missing tokens for detokenization\n"); diff != "" {
|
|
|
|
+ t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
|
|
+ }
|
|
|
|
+ })
|
|
|
|
+
|
|
|
|
+ t.Run("detokenize missing model", func(t *testing.T) {
|
|
|
|
+ body, err := json.Marshal(api.DetokenizeRequest{
|
|
|
|
+ Tokens: []int{0, 1, 2},
|
|
|
|
+ })
|
|
|
|
+ if err != nil {
|
|
|
|
+ t.Fatalf("failed to marshal request: %v", err)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ w := httptest.NewRecorder()
|
|
|
|
+ r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
|
|
|
|
+ r.Header.Set("Content-Type", "application/json")
|
|
|
|
+ s.DetokenizeHandler(w, r)
|
|
|
|
+
|
|
|
|
+ if w.Code != http.StatusNotFound {
|
|
|
|
+ t.Errorf("expected status 404, got %d", w.Code)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if diff := cmp.Diff(w.Body.String(), "model '' not found\n"); diff != "" {
|
|
|
|
+ t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
|
|
+ }
|
|
|
|
+ })
|
|
|
|
+}
|