ParthSareen 4 months ago
Parent
Commit
1e545ea7a0
2 changed files with 67 additions and 44 deletions
  1. +63 -0
      server/model_loader.go
  2. +4 -44
      server/routes.go

+ 63 - 0
server/model_loader.go

@@ -0,0 +1,63 @@
+package server
+
+import (
+	"fmt"
+	"sync"
+
+	"github.com/ollama/ollama/llama"
+	"github.com/ollama/ollama/types/model"
+)
+
+type loadedModel struct {
+	model     *llama.Model
+	modelPath string
+}
+
+// modelCache stores loaded models keyed by model path plus a formatted-params string
+var modelCache sync.Map // map[string]*loadedModel
+
+func LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
+	modelName := model.ParseName(name)
+	if !modelName.IsValid() {
+		return nil, fmt.Errorf("invalid model name: %s", name)
+	}
+
+	modelPath, err := GetModel(modelName.String())
+	if err != nil {
+		return nil, fmt.Errorf("model not found: %s", modelName)
+	}
+
+	// Build the cache key from the model path and the printf-formatted params
+	cacheKey := fmt.Sprintf("%s-%+v", modelPath.ModelPath, params)
+	if cached, ok := modelCache.Load(cacheKey); ok {
+		return cached.(*loadedModel), nil
+	}
+
+	// Evict existing model if any
+	evictExistingModel()
+
+	llamaModel, err := llama.LoadModelFromFile(modelPath.ModelPath, params)
+	if err != nil {
+		return nil, fmt.Errorf("failed to load model: %w", err)
+	}
+
+	loaded := &loadedModel{
+		model:     llamaModel,
+		modelPath: modelPath.ModelPath,
+	}
+	modelCache.Store(cacheKey, loaded)
+
+	return loaded, nil
+}
+
+// evictExistingModel removes any currently loaded model from the cache
+// Currently only supports a single model in cache at a time
+// TODO: Add proper cache eviction policy (LRU/size/TTL based)
+func evictExistingModel() {
+	modelCache.Range(func(key, _ any) bool {
+		if cached, ok := modelCache.LoadAndDelete(key); ok {
+			llama.FreeModel(cached.(*loadedModel).model)
+		}
+		return true
+	})
+}
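A note on behavior: because the cache key combines the resolved model path with the printf-formatted params, repeated calls with the same model and params reuse one *loadedModel, while any change to the params produces a new key and triggers eviction plus a reload. A minimal sketch of that contract, written as a test inside package server (the model name "llama3.2" is illustrative and not part of this commit):

package server

import (
	"testing"

	"github.com/ollama/ollama/llama"
)

func TestLoadModelCacheHit(t *testing.T) {
	params := llama.ModelParams{VocabOnly: true}

	first, err := LoadModel("llama3.2", params)
	if err != nil {
		t.Skipf("model not available locally: %v", err)
	}

	// Same name and params produce the same cache key, so the second
	// call should return the already-loaded model rather than reloading.
	second, err := LoadModel("llama3.2", params)
	if err != nil {
		t.Fatal(err)
	}
	if first != second {
		t.Error("expected the cached *loadedModel to be reused")
	}
}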

+ 4 - 44
server/routes.go

@@ -575,36 +575,16 @@ func (s *Server) TokenizeHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	name := model.ParseName(req.Model)
-	if !name.IsValid() {
-		http.Error(w, fmt.Sprintf("model name `%q` is invalid", req.Model), http.StatusBadRequest)
-		return
-	}
-	name, err := getExistingName(name)
-	if err != nil {
-		http.Error(w, fmt.Sprintf("model `%s` not found", req.Model), http.StatusNotFound)
-		return
-	}
-
-	// Get local model path
-	modelPath, err := GetModel(name.String())
-	if err != nil {
-		http.Error(w, fmt.Sprintf("model `%s` not found", req.Model), http.StatusNotFound)
-		return
-	}
-
-	model, err := llama.LoadModelFromFile(modelPath.ModelPath, llama.ModelParams{
+	loadedModel, err := LoadModel(req.Model, llama.ModelParams{
 		VocabOnly: true,
-		UseMmap:   true,
 	})
 	if err != nil {
 		http.Error(w, fmt.Sprintf("failed to load model: %v", err), http.StatusInternalServerError)
 		return
 	}
-	defer llama.FreeModel(model)
 
 	// Tokenize the text
-	tokens, err := model.Tokenize(req.Text, false, true)
+	tokens, err := loadedModel.model.Tokenize(req.Text, false, true)
 	if err != nil {
 		http.Error(w, fmt.Sprintf("failed to tokenize text: %v", err), http.StatusInternalServerError)
 		return
@@ -645,37 +625,17 @@ func (s *Server) DetokenizeHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	name := model.ParseName(req.Model)
-	if !name.IsValid() {
-		http.Error(w, fmt.Sprintf("model name `%q` is invalid", req.Model), http.StatusBadRequest)
-		return
-	}
-	name, err := getExistingName(name)
-	if err != nil {
-		http.Error(w, fmt.Sprintf("model `%s` not found", req.Model), http.StatusNotFound)
-		return
-	}
-
-	// Get local model path
-	modelPath, err := GetModel(name.String())
-	if err != nil {
-		http.Error(w, fmt.Sprintf("model `%s` not found", req.Model), http.StatusNotFound)
-		return
-	}
-
-	model, err := llama.LoadModelFromFile(modelPath.ModelPath, llama.ModelParams{
+	loadedModel, err := LoadModel(req.Model, llama.ModelParams{
 		VocabOnly: true,
-		UseMmap:   true,
 	})
 	if err != nil {
 		http.Error(w, fmt.Sprintf("failed to load model: %v", err), http.StatusInternalServerError)
 		return
 	}
-	defer llama.FreeModel(model)
 
 	var text string
 	for _, token := range req.Tokens {
-		text += model.TokenToPiece(token)
+		text += loadedModel.model.TokenToPiece(token)
 	}
 
 	w.Header().Set("Content-Type", "application/json")
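The routes.go change is ultimately about ownership of the llama.Model: before this commit each Tokenize/Detokenize request loaded the vocab and freed it via defer on return; after it, the model lives in modelCache and llama.FreeModel is only called on eviction. A condensed, illustrative comparison (error handling elided):

// Before: per-request load, freed as soon as the handler returns.
model, err := llama.LoadModelFromFile(modelPath.ModelPath, llama.ModelParams{
	VocabOnly: true,
	UseMmap:   true,
})
defer llama.FreeModel(model)

// After: the cache owns the model; no defer in the handler.
// llama.FreeModel runs inside evictExistingModel when a request
// arrives for a different model/params cache key.
loaded, err := LoadModel(req.Model, llama.ModelParams{VocabOnly: true})
tokens, err := loaded.model.Tokenize(req.Text, false, true)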