Browse Source

api: expose tokenize and detokenize endpoints

Yurzs 8 months ago
parent
commit
e60db349b7
3 changed files with 108 additions and 0 deletions
  1. 18 0
      api/client.go
  2. 38 0
      api/types.go
  3. 52 0
      server/routes.go

+ 18 - 0
api/client.go

@@ -360,6 +360,24 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*Embedd
 	return &resp, nil
 }
 
+// Tokenize tokenizes a string.
+func (c *Client) Tokenize(ctx context.Context, req *TokenizeRequest) (*TokenizeResponse, error) {
+	var resp TokenizeResponse
+	if err := c.do(ctx, http.MethodPost, "/api/tokenize", req, &resp); err != nil {
+		return nil, err
+	}
+	return &resp, nil
+}
+
+// Detokenize detokenizes a string.
+func (c *Client) Detokenize(ctx context.Context, req *DetokenizeRequest) (*DetokenizeResponse, error) {
+	var resp DetokenizeResponse
+	if err := c.do(ctx, http.MethodPost, "/api/detokenize", req, &resp); err != nil {
+		return nil, err
+	}
+	return &resp, nil
+}
+
 // CreateBlob creates a blob from a file on the server. digest is the
 // expected SHA256 digest of the file, and r represents the file.
 func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {

+ 38 - 0
api/types.go

@@ -293,6 +293,44 @@ type EmbeddingResponse struct {
 	Embedding []float64 `json:"embedding"`
 }
 
+// TokenizeRequest is the request passed to [Client.Tokenize].
+type TokenizeRequest struct {
+	Model  string `json:"model"`
+	Prompt string `json:"prompt"`
+
+	// KeepAlive controls how long the model will stay loaded in memory following
+	// this request.
+	KeepAlive *Duration `json:"keep_alive,omitempty"`
+
+	// Options lists model-specific options.
+	Options map[string]interface{} `json:"options"`
+}
+
+// TokenizeResponse is the response from [Client.Tokenize].
+type TokenizeResponse struct {
+	Model  string `json:"model"`
+	Tokens []int  `json:"tokens"`
+}
+
+// DetokenizeRequest is the request passed to [Client.Detokenize].
+type DetokenizeRequest struct {
+	Model  string `json:"model"`
+	Tokens []int  `json:"tokens"`
+
+	// KeepAlive controls how long the model will stay loaded in memory following
+	// this request.
+	KeepAlive *Duration `json:"keep_alive,omitempty"`
+
+	// Options lists model-specific options.
+	Options map[string]interface{} `json:"options"`
+}
+
+// DetokenizeResponse is the response from [Client.Detokenize].
+type DetokenizeResponse struct {
+	Model string `json:"model"`
+	Text  string `json:"text"`
+}
+
 // CreateRequest is the request passed to [Client.Create].
 type CreateRequest struct {
 	Model     string `json:"model"`

+ 52 - 0
server/routes.go

@@ -548,6 +548,56 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, resp)
 }
 
+func (s *Server) TokenizeHandler(c *gin.Context) {
+	var req api.TokenizeRequest
+	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	} else if err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
+	if err != nil {
+		handleScheduleError(c, req.Model, err)
+		return
+	}
+
+	tokens, err := r.Tokenize(c.Request.Context(), req.Prompt)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	c.JSON(http.StatusOK, api.TokenizeResponse{Model: req.Model, Tokens: tokens})
+}
+
+func (s *Server) DetokenizeHandler(c *gin.Context) {
+	var req api.DetokenizeRequest
+	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	} else if err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
+	if err != nil {
+		handleScheduleError(c, req.Model, err)
+		return
+	}
+
+	text, err := r.Detokenize(c.Request.Context(), req.Tokens)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	c.JSON(http.StatusOK, api.DetokenizeResponse{Model: req.Model, Text: text})
+}
+
 func (s *Server) PullHandler(c *gin.Context) {
 	var req api.PullRequest
 	err := c.ShouldBindJSON(&req)
@@ -1214,6 +1264,8 @@ func (s *Server) GenerateRoutes() http.Handler {
 	r.POST("/api/chat", s.ChatHandler)
 	r.POST("/api/embed", s.EmbedHandler)
 	r.POST("/api/embeddings", s.EmbeddingsHandler)
+	r.POST("/api/tokenize", s.TokenizeHandler)
+	r.POST("/api/detokenize", s.DetokenizeHandler)
 	r.POST("/api/create", s.CreateHandler)
 	r.POST("/api/push", s.PushHandler)
 	r.POST("/api/copy", s.CopyHandler)