Bruce MacDonald 1 год назад
Родитель
Сommit
5b5cc9c9f1
2 измененных файлов с 85 добавлено и 31 удалено
  1. 11 0
      api/types.go
  2. 74 31
      server/routes.go

+ 11 - 0
api/types.go

@@ -42,6 +42,17 @@ type GenerateRequest struct {
 	Options map[string]interface{} `json:"options"`
 }
 
+type EmbeddingRequest struct {
+	Model  string `json:"model"`
+	Prompt string `json:"prompt"`
+
+	Options map[string]interface{} `json:"options"`
+}
+
+type EmbeddingResponse struct {
+	Embedding []float64 `json:"embedding"`
+}
+
 type CreateRequest struct {
 	Name string `json:"name"`
 	Path string `json:"path"`

+ 74 - 31
server/routes.go

@@ -38,35 +38,17 @@ var loaded struct {
 	options api.Options
 }
 
-func GenerateHandler(c *gin.Context) {
-	loaded.mu.Lock()
-	defer loaded.mu.Unlock()
-
-	checkpointStart := time.Now()
-
-	var req api.GenerateRequest
-	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		return
-	}
-
-	model, err := GetModel(req.Model)
-	if err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		return
-	}
-
+// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
+func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error {
 	opts := api.DefaultOptions()
 	if err := opts.FromMap(model.Options); err != nil {
 		log.Printf("could not load model options: %v", err)
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
+		return err
 	}
 
-	if err := opts.FromMap(req.Options); err != nil {
+	if err := opts.FromMap(reqOpts); err != nil {
 		log.Printf("could not merge model options: %v", err)
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
+		return err
 	}
 
 	if model.Digest != loaded.digest || !reflect.DeepEqual(loaded.options, opts) {
@@ -83,21 +65,18 @@ func GenerateHandler(c *gin.Context) {
 
 		llm, err := llama.New(model.ModelPath, opts)
 		if err != nil {
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-			return
+			return err
 		}
 
 		if opts.NumKeep < 0 {
 			promptWithSystem, err := model.Prompt(api.GenerateRequest{}, "")
 			if err != nil {
-				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-				return
+				return err
 			}
 
 			promptNoSystem, err := model.Prompt(api.GenerateRequest{Context: []int{0}}, "")
 			if err != nil {
-				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-				return
+				return err
 			}
 
 			tokensWithSystem := llm.Encode(promptWithSystem)
@@ -110,9 +89,8 @@ func GenerateHandler(c *gin.Context) {
 		loaded.digest = model.Digest
 		loaded.options = opts
 	}
-	sessionDuration := 5 * time.Minute
-
 	loaded.expireAt = time.Now().Add(sessionDuration)
+
 	if loaded.expireTimer == nil {
 		loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
 			loaded.mu.Lock()
@@ -132,6 +110,32 @@ func GenerateHandler(c *gin.Context) {
 		})
 	}
 	loaded.expireTimer.Reset(sessionDuration)
+	return nil
+}
+
+func GenerateHandler(c *gin.Context) {
+	loaded.mu.Lock()
+	defer loaded.mu.Unlock()
+
+	checkpointStart := time.Now()
+
+	var req api.GenerateRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	model, err := GetModel(req.Model)
+	if err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	sessionDuration := 5 * time.Minute
+	if err := load(model, req.Options, sessionDuration); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
 
 	checkpointLoaded := time.Now()
 
@@ -181,6 +185,44 @@ func GenerateHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
+func EmbeddingHandler(c *gin.Context) {
+	loaded.mu.Lock()
+	defer loaded.mu.Unlock()
+
+	var req api.EmbeddingRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	model, err := GetModel(req.Model)
+	if err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+	if err := load(model, req.Options, 5*time.Minute); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	if !loaded.options.EmbeddingOnly {
+		c.JSON(http.StatusBadRequest, gin.H{"error": "embedding option must be set to true"})
+		return
+	}
+
+	embedding, err := loaded.llm.Embedding(req.Prompt)
+	if err != nil {
+		log.Printf("embedding generation failed: %v", err)
+		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
+		return
+	}
+
+	resp := api.EmbeddingResponse{
+		Embedding: embedding,
+	}
+	c.JSON(http.StatusOK, resp)
+}
+
 func PullModelHandler(c *gin.Context) {
 	var req api.PullRequest
 	if err := c.ShouldBindJSON(&req); err != nil {
@@ -381,6 +423,7 @@ func Serve(ln net.Listener, extraOrigins []string) error {
 
 	r.POST("/api/pull", PullModelHandler)
 	r.POST("/api/generate", GenerateHandler)
+	r.POST("/api/embeddings", EmbeddingHandler)
 	r.POST("/api/create", CreateModelHandler)
 	r.POST("/api/push", PushModelHandler)
 	r.POST("/api/copy", CopyModelHandler)