|
@@ -407,6 +407,84 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
|
|
c.JSON(http.StatusOK, resp)
|
|
|
}
|
|
|
|
|
|
+func (s *Server) TokenizeHandler(c *gin.Context) {
|
|
|
+ var req api.TokenizeRequest
|
|
|
+ err := c.ShouldBindJSON(&req)
|
|
|
+ switch {
|
|
|
+ case errors.Is(err, io.EOF):
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
|
|
+ return
|
|
|
+ case err != nil:
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if req.Model == "" {
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ model, err := GetModel(req.Model)
|
|
|
+ if err != nil {
|
|
|
+ var pErr *fs.PathError
|
|
|
+ if errors.As(err, &pErr) {
|
|
|
+ c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
|
|
|
+ return
|
|
|
+ }
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ opts, err := modelOptions(model, req.Options)
|
|
|
+ if err != nil {
|
|
|
+ if errors.Is(err, api.ErrInvalidOpts) {
|
|
|
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
+ return
|
|
|
+ }
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ var sessionDuration time.Duration
|
|
|
+ if req.KeepAlive == nil {
|
|
|
+ sessionDuration = getDefaultSessionDuration()
|
|
|
+ } else {
|
|
|
+ sessionDuration = req.KeepAlive.Duration
|
|
|
+ }
|
|
|
+
|
|
|
+ rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
|
|
|
+ var runner *runnerRef
|
|
|
+ select {
|
|
|
+ case runner = <-rCh:
|
|
|
+ case err = <-eCh:
|
|
|
+ if errors.Is(err, context.Canceled) {
|
|
|
+ c.JSON(499, gin.H{"error": "request canceled"})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // an empty request loads the model
|
|
|
+ if req.Prompt == "" {
|
|
|
+ c.JSON(http.StatusOK, api.TokenizeResponse{Tokens: []int{}})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ tokens, err := runner.llama.Tokenize(c.Request.Context(), req.Prompt)
|
|
|
+ if err != nil {
|
|
|
+ slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ resp := api.TokenizeResponse{
|
|
|
+ Tokens: tokens,
|
|
|
+ }
|
|
|
+ c.JSON(http.StatusOK, resp)
|
|
|
+}
|
|
|
+
|
|
|
func (s *Server) PullModelHandler(c *gin.Context) {
|
|
|
var req api.PullRequest
|
|
|
err := c.ShouldBindJSON(&req)
|
|
@@ -967,6 +1045,7 @@ func (s *Server) GenerateRoutes() http.Handler {
|
|
|
r.POST("/api/generate", s.GenerateHandler)
|
|
|
r.POST("/api/chat", s.ChatHandler)
|
|
|
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
|
|
+ r.POST("/api/tokenize", s.TokenizeHandler)
|
|
|
r.POST("/api/create", s.CreateModelHandler)
|
|
|
r.POST("/api/push", s.PushModelHandler)
|
|
|
r.POST("/api/copy", s.CopyModelHandler)
|