ParthSareen 4 months ago
parent
commit
11acb85ff3
3 changed files with 159 additions and 10 deletions
  1. 12 8
      server/model_loader.go
  2. 3 2
      server/routes.go
  3. 144 0
      server/routes_tokenize_test.go

+ 12 - 8
server/model_loader.go

@@ -9,14 +9,18 @@ import (
 )
 
 type loadedModel struct {
-	model     *llama.Model
+	model     llama.Model
 	modelPath string
 }
 
+type modelLoader struct {
+	cache sync.Map
+}
+
 // modelCache stores loaded models keyed by their full path and params hash
 var modelCache sync.Map // map[string]*loadedModel
 
-func LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
+func (ml *modelLoader) LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
 	modelName := model.ParseName(name)
 	if !modelName.IsValid() {
 		return nil, fmt.Errorf("invalid model name: %s", modelName)
@@ -34,7 +38,7 @@ func LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
 	}
 
 	// Evict existing model if any
-	evictExistingModel()
+	ml.evictExistingModel()
 
 	model, err := llama.LoadModelFromFile(modelPath.ModelPath, params)
 	if err != nil {
@@ -42,7 +46,7 @@ func LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
 	}
 
 	loaded := &loadedModel{
-		model:     model,
+		model:     *model,
 		modelPath: modelPath.ModelPath,
 	}
 	modelCache.Store(cacheKey, loaded)
@@ -53,10 +57,10 @@ func LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
 // evictExistingModel removes any currently loaded model from the cache
 // Currently only supports a single model in cache at a time
 // TODO: Add proper cache eviction policy (LRU/size/TTL based)
-func evictExistingModel() {
-	modelCache.Range(func(key, value any) bool {
-		if cached, ok := modelCache.LoadAndDelete(key); ok {
-			llama.FreeModel(cached.(*loadedModel).model)
+func (ml *modelLoader) evictExistingModel() {
+	ml.cache.Range(func(key, value any) bool {
+		if cached, ok := ml.cache.LoadAndDelete(key); ok {
+			llama.FreeModel(&cached.(*loadedModel).model)
 		}
 		return true
 	})

+ 3 - 2
server/routes.go

@@ -47,6 +47,7 @@ var mode string = gin.DebugMode
 type Server struct {
 	addr  net.Addr
 	sched *Scheduler
+	ml    modelLoader
 }
 
 func init() {
@@ -575,7 +576,7 @@ func (s *Server) TokenizeHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	loadedModel, err := LoadModel(req.Model, llama.ModelParams{
+	loadedModel, err := s.ml.LoadModel(req.Model, llama.ModelParams{
 		VocabOnly: true,
 	})
 	if err != nil {
@@ -625,7 +626,7 @@ func (s *Server) DetokenizeHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	loadedModel, err := LoadModel(req.Model, llama.ModelParams{
+	loadedModel, err := s.ml.LoadModel(req.Model, llama.ModelParams{
 		VocabOnly: true,
 	})
 	if err != nil {

+ 144 - 0
server/routes_tokenize_test.go

@@ -0,0 +1,144 @@
+package server
+
+import (
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"testing"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/discover"
+	"github.com/ollama/ollama/llama"
+	"github.com/ollama/ollama/llm"
+)
+
+type mockModelLoader struct {
+	LoadModelFn func(string, llama.ModelParams) (*loadedModel, error)
+}
+
+func (ml *mockModelLoader) LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
+	if ml.LoadModelFn != nil {
+		return ml.LoadModelFn(name, params)
+	}
+
+	return &loadedModel{
+		model: mockModel{},
+	}, nil
+}
+
+type mockModel struct {
+	llama.Model
+	TokenizeFn     func(text string, addBos bool, addEos bool) ([]int, error)
+	TokenToPieceFn func(token int) string
+}
+
+func (m *mockModel) Tokenize(text string, addBos bool, addEos bool) ([]int, error) {
+	return []int{1, 2, 3}, nil
+}
+
+func (m *mockModel) TokenToPiece(token int) string {
+	return fmt.Sprint(token)
+}
+
+func TestTokenizeHandler(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	mockLoader := mockModelLoader{
+		LoadModelFn: func(name string, params llama.ModelParams) (*loadedModel, error) {
+			return &loadedModel{
+				model: mockModel{},
+			}, nil
+		},
+	}
+
+	s := Server{
+		sched: &Scheduler{
+			pendingReqCh:  make(chan *LlmRequest, 1),
+			finishedReqCh: make(chan *LlmRequest, 1),
+			expiredCh:     make(chan *runnerRef, 1),
+			unloadedCh:    make(chan any, 1),
+			loaded:        make(map[string]*runnerRef),
+			newServerFn:   newMockServer(&mockRunner{}),
+			getGpuFn:      discover.GetGPUInfo,
+			getCpuFn:      discover.GetCPUInfo,
+			reschedDelay:  250 * time.Millisecond,
+			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
+				time.Sleep(time.Millisecond)
+				req.successCh <- &runnerRef{
+					llama: &mockRunner{},
+				}
+			},
+		},
+		ml: mockLoader,
+	}
+
+	t.Run("method not allowed", func(t *testing.T) {
+		w := createRequest(t, gin.WrapF(s.TokenizeHandler), nil)
+		if w.Code != http.StatusMethodNotAllowed {
+			t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
+		}
+	})
+
+	t.Run("missing body", func(t *testing.T) {
+		w := createRequest(t, gin.WrapF(s.TokenizeHandler), nil)
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
+		}
+	})
+
+	t.Run("missing text", func(t *testing.T) {
+		w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
+			Model: "test",
+		})
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
+		}
+	})
+
+	t.Run("missing model", func(t *testing.T) {
+		w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
+			Text: "test text",
+		})
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
+		}
+	})
+
+	t.Run("model not found", func(t *testing.T) {
+		w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
+			Model: "nonexistent",
+			Text:  "test text",
+		})
+		if w.Code != http.StatusInternalServerError {
+			t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
+		}
+	})
+
+	t.Run("successful tokenization", func(t *testing.T) {
+		w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
+			Model: "test",
+			Text:  "test text",
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
+		}
+
+		var resp api.TokenizeResponse
+		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+			t.Fatal(err)
+		}
+
+		expectedTokens := []int{0, 1}
+		if len(resp.Tokens) != len(expectedTokens) {
+			t.Errorf("expected %d tokens, got %d", len(expectedTokens), len(resp.Tokens))
+		}
+		for i, token := range resp.Tokens {
+			if token != expectedTokens[i] {
+				t.Errorf("expected token %d at position %d, got %d", expectedTokens[i], i, token)
+			}
+		}
+	})
+}