瀏覽代碼

add a /tokenize endpoint

Bruce MacDonald 1 年之前
父節點
當前提交
19ce10e49e
共有 2 個文件被更改,包括 91 次插入0 次删除
  1. 12 0
      api/types.go
  2. 79 0
      server/routes.go

+ 12 - 0
api/types.go

@@ -195,6 +195,18 @@ type EmbeddingResponse struct {
 	Embedding []float64 `json:"embedding"`
 }
 
+type TokenizeRequest struct {
+	Model     string    `json:"model"`
+	Prompt    string    `json:"prompt"`
+	KeepAlive *Duration `json:"keep_alive,omitempty"`
+
+	Options map[string]interface{} `json:"options"`
+}
+
+type TokenizeResponse struct {
+	Tokens []int `json:"tokens"`
+}
+
 // CreateRequest is the request passed to [Client.Create].
 type CreateRequest struct {
 	Model        string `json:"model"`

+ 79 - 0
server/routes.go

@@ -407,6 +407,84 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, resp)
 }
 
+func (s *Server) TokenizeHandler(c *gin.Context) {
+	var req api.TokenizeRequest
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	if req.Model == "" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
+		return
+	}
+
+	model, err := GetModel(req.Model)
+	if err != nil {
+		var pErr *fs.PathError
+		if errors.As(err, &pErr) {
+			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
+			return
+		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	opts, err := modelOptions(model, req.Options)
+	if err != nil {
+		if errors.Is(err, api.ErrInvalidOpts) {
+			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+			return
+		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	var sessionDuration time.Duration
+	if req.KeepAlive == nil {
+		sessionDuration = getDefaultSessionDuration()
+	} else {
+		sessionDuration = req.KeepAlive.Duration
+	}
+
+	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
+	var runner *runnerRef
+	select {
+	case runner = <-rCh:
+	case err = <-eCh:
+		if errors.Is(err, context.Canceled) {
+			c.JSON(499, gin.H{"error": "request canceled"})
+			return
+		}
+
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	// an empty request loads the model
+	if req.Prompt == "" {
+		c.JSON(http.StatusOK, api.TokenizeResponse{Tokens: []int{}})
+		return
+	}
+
+	tokens, err := runner.llama.Tokenize(c.Request.Context(), req.Prompt)
+	if err != nil {
+		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
+		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
+		return
+	}
+
+	resp := api.TokenizeResponse{
+		Tokens: tokens,
+	}
+	c.JSON(http.StatusOK, resp)
+}
+
 func (s *Server) PullModelHandler(c *gin.Context) {
 	var req api.PullRequest
 	err := c.ShouldBindJSON(&req)
@@ -967,6 +1045,7 @@ func (s *Server) GenerateRoutes() http.Handler {
 	r.POST("/api/generate", s.GenerateHandler)
 	r.POST("/api/chat", s.ChatHandler)
 	r.POST("/api/embeddings", s.EmbeddingsHandler)
+	r.POST("/api/tokenize", s.TokenizeHandler)
 	r.POST("/api/create", s.CreateModelHandler)
 	r.POST("/api/push", s.PushModelHandler)
 	r.POST("/api/copy", s.CopyModelHandler)