Browse Source

add tests

ParthSareen 4 months ago
parent
commit
f0a5f7994b
3 changed files with 299 additions and 2 deletions
  1. 1 2
      server/routes.go
  2. 8 0
      server/routes_generate_test.go
  3. 290 0
      server/routes_tokenization_test.go

+ 1 - 2
server/routes.go

@@ -564,9 +564,8 @@ func (s *Server) TokenizeHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	slog.Info("tokenize request", "text", req.Text, "tokens", req.Text)
 	if req.Text == "" {
-		http.Error(w, "missing text for tokenization", http.StatusBadRequest)
+		http.Error(w, "missing `text` for tokenization", http.StatusBadRequest)
 		return
 	}
 

+ 8 - 0
server/routes_generate_test.go

@@ -46,6 +46,14 @@ func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error
 	return
 }
 
+func (mockRunner) Detokenize(_ context.Context, tokens []int) (string, error) {
+	var strs []string
+	for _, t := range tokens {
+		strs = append(strs, fmt.Sprint(t))
+	}
+	return strings.Join(strs, " "), nil
+}
+
 func newMockServer(mock *mockRunner) func(discover.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
 	return func(gpus discover.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
 		return mock, nil

+ 290 - 0
server/routes_tokenization_test.go

@@ -0,0 +1,290 @@
+package server
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/google/go-cmp/cmp"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/discover"
+	"github.com/ollama/ollama/llm"
+)
+
+func TestTokenize(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	mock := mockRunner{
+		CompletionResponse: llm.CompletionResponse{
+			Done:               true,
+			DoneReason:         "stop",
+			PromptEvalCount:    1,
+			PromptEvalDuration: 1,
+			EvalCount:          1,
+			EvalDuration:       1,
+		},
+	}
+
+	s := Server{
+		sched: &Scheduler{
+			pendingReqCh:  make(chan *LlmRequest, 1),
+			finishedReqCh: make(chan *LlmRequest, 1),
+			expiredCh:     make(chan *runnerRef, 1),
+			unloadedCh:    make(chan any, 1),
+			loaded:        make(map[string]*runnerRef),
+			newServerFn:   newMockServer(&mock),
+			getGpuFn:      discover.GetGPUInfo,
+			getCpuFn:      discover.GetCPUInfo,
+			reschedDelay:  250 * time.Millisecond,
+			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
+				// add small delay to simulate loading
+				time.Sleep(time.Millisecond)
+				req.successCh <- &runnerRef{
+					llama: &mock,
+				}
+			},
+		},
+	}
+
+	go s.sched.Run(context.TODO())
+
+	t.Run("missing body", func(t *testing.T) {
+		w := httptest.NewRecorder()
+		r := httptest.NewRequest(http.MethodPost, "/api/tokenize", nil)
+		s.TokenizeHandler(w, r)
+
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status 400, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), "missing request body\n"); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("missing model", func(t *testing.T) {
+		w := httptest.NewRecorder()
+		r := httptest.NewRequest(http.MethodPost, "/api/tokenize", strings.NewReader("{}"))
+		s.TokenizeHandler(w, r)
+
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status 400, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), "missing `text` for tokenization\n"); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+	t.Run("tokenize text", func(t *testing.T) {
+		// First create the model
+		w := createRequest(t, s.CreateHandler, api.CreateRequest{
+			Model: "test",
+			Modelfile: fmt.Sprintf(`FROM %s`, createBinFile(t, llm.KV{
+				"general.architecture":          "llama",
+				"llama.block_count":             uint32(1),
+				"llama.context_length":          uint32(8192),
+				"llama.embedding_length":        uint32(4096),
+				"llama.attention.head_count":    uint32(32),
+				"llama.attention.head_count_kv": uint32(8),
+				"tokenizer.ggml.tokens":         []string{""},
+				"tokenizer.ggml.scores":         []float32{0},
+				"tokenizer.ggml.token_type":     []int32{0},
+			}, []llm.Tensor{
+				{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+				{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			})),
+		})
+		if w.Code != http.StatusOK {
+			t.Fatalf("failed to create model: %d", w.Code)
+		}
+
+		// Now test tokenization
+		body, err := json.Marshal(api.TokenizeRequest{
+			Model: "test",
+			Text:  "Hello world how are you",
+		})
+		if err != nil {
+			t.Fatalf("failed to marshal request: %v", err)
+		}
+
+		w = httptest.NewRecorder()
+		r := httptest.NewRequest(http.MethodPost, "/api/tokenize", bytes.NewReader(body))
+		r.Header.Set("Content-Type", "application/json")
+		s.TokenizeHandler(w, r)
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
+		}
+
+		var resp api.TokenizeResponse
+		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+			t.Errorf("failed to decode response: %v", err)
+		}
+
+		// Our mock tokenizer creates sequential tokens based on word count
+		expected := []int{0, 1, 2, 3, 4}
+		if diff := cmp.Diff(resp.Tokens, expected); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("tokenize empty text", func(t *testing.T) {
+		body, err := json.Marshal(api.TokenizeRequest{
+			Model: "test",
+			Text:  "",
+		})
+		if err != nil {
+			t.Fatalf("failed to marshal request: %v", err)
+		}
+
+		w := httptest.NewRecorder()
+		r := httptest.NewRequest(http.MethodPost, "/api/tokenize", bytes.NewReader(body))
+		r.Header.Set("Content-Type", "application/json")
+		s.TokenizeHandler(w, r)
+
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status 400, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), "missing `text` for tokenization\n"); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+}
+
+func TestDetokenize(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	mock := mockRunner{
+		CompletionResponse: llm.CompletionResponse{
+			Done:               true,
+			DoneReason:         "stop",
+			PromptEvalCount:    1,
+			PromptEvalDuration: 1,
+			EvalCount:          1,
+			EvalDuration:       1,
+		},
+	}
+
+	s := Server{
+		sched: &Scheduler{
+			pendingReqCh:  make(chan *LlmRequest, 1),
+			finishedReqCh: make(chan *LlmRequest, 1),
+			expiredCh:     make(chan *runnerRef, 1),
+			unloadedCh:    make(chan any, 1),
+			loaded:        make(map[string]*runnerRef),
+			newServerFn:   newMockServer(&mock),
+			getGpuFn:      discover.GetGPUInfo,
+			getCpuFn:      discover.GetCPUInfo,
+			reschedDelay:  250 * time.Millisecond,
+			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
+				// add small delay to simulate loading
+				time.Sleep(time.Millisecond)
+				req.successCh <- &runnerRef{
+					llama: &mock,
+				}
+			},
+		},
+	}
+
+	go s.sched.Run(context.TODO())
+
+	t.Run("detokenize tokens", func(t *testing.T) {
+		// Create the model first
+		w := createRequest(t, s.CreateHandler, api.CreateRequest{
+			Model: "test",
+			Modelfile: fmt.Sprintf(`FROM %s`, createBinFile(t, llm.KV{
+				"general.architecture":          "llama",
+				"llama.block_count":             uint32(1),
+				"llama.context_length":          uint32(8192),
+				"llama.embedding_length":        uint32(4096),
+				"llama.attention.head_count":    uint32(32),
+				"llama.attention.head_count_kv": uint32(8),
+				"tokenizer.ggml.tokens":         []string{""},
+				"tokenizer.ggml.scores":         []float32{0},
+				"tokenizer.ggml.token_type":     []int32{0},
+			}, []llm.Tensor{
+				{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+				{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			})),
+			Stream: &stream,
+		})
+		if w.Code != http.StatusOK {
+			t.Fatalf("failed to create model: %d - %s", w.Code, w.Body.String())
+		}
+
+		body, err := json.Marshal(api.DetokenizeRequest{
+			Model:  "test",
+			Tokens: []int{0, 1, 2, 3, 4},
+		})
+		if err != nil {
+			t.Fatalf("failed to marshal request: %v", err)
+		}
+
+		w = httptest.NewRecorder()
+		r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
+		r.Header.Set("Content-Type", "application/json")
+		s.DetokenizeHandler(w, r)
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
+		}
+
+		var resp api.DetokenizeResponse
+		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+			t.Errorf("failed to decode response: %v", err)
+		}
+	})
+
+	t.Run("detokenize empty tokens", func(t *testing.T) {
+		body, err := json.Marshal(api.DetokenizeRequest{
+			Model: "test",
+		})
+		if err != nil {
+			t.Fatalf("failed to marshal request: %v", err)
+		}
+
+		w := httptest.NewRecorder()
+		r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
+		r.Header.Set("Content-Type", "application/json")
+		s.DetokenizeHandler(w, r)
+
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status 400, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), "missing tokens for detokenization\n"); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("detokenize missing model", func(t *testing.T) {
+		body, err := json.Marshal(api.DetokenizeRequest{
+			Tokens: []int{0, 1, 2},
+		})
+		if err != nil {
+			t.Fatalf("failed to marshal request: %v", err)
+		}
+
+		w := httptest.NewRecorder()
+		r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
+		r.Header.Set("Content-Type", "application/json")
+		s.DetokenizeHandler(w, r)
+
+		if w.Code != http.StatusNotFound {
+			t.Errorf("expected status 404, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), "model '' not found\n"); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+}