浏览代码

server: fix `context`, `load_duration` and `total_duration` fields (#5676)

* server: fix `context`, `load_duration` and `total_duration` fields

* Update server/routes.go
Jeffrey Morgan 9 月之前
父节点
当前提交
1ed0aa8fea
共有 1 个文件被更改，包括 46 次插入和 10 次删除
  1. 46 10
      server/routes.go

+ 46 - 10
server/routes.go

@@ -102,6 +102,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
 }
 }
 
 
 func (s *Server) GenerateHandler(c *gin.Context) {
 func (s *Server) GenerateHandler(c *gin.Context) {
+	checkpointStart := time.Now()
 	var req api.GenerateRequest
 	var req api.GenerateRequest
 	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
 	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@@ -129,6 +130,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 		return
 	}
 	}
 
 
+	checkpointLoaded := time.Now()
+
 	if req.Prompt == "" {
 	if req.Prompt == "" {
 		c.JSON(http.StatusOK, api.GenerateResponse{
 		c.JSON(http.StatusOK, api.GenerateResponse{
 			Model:      req.Model,
 			Model:      req.Model,
@@ -191,26 +194,48 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 
 	ch := make(chan any)
 	ch := make(chan any)
 	go func() {
 	go func() {
+		// TODO (jmorganca): avoid building the response twice both here and below
+		var sb strings.Builder
 		defer close(ch)
 		defer close(ch)
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Prompt:  prompt,
 			Images:  images,
 			Images:  images,
 			Format:  req.Format,
 			Format:  req.Format,
 			Options: opts,
 			Options: opts,
-		}, func(r llm.CompletionResponse) {
-			ch <- api.GenerateResponse{
+		}, func(cr llm.CompletionResponse) {
+			res := api.GenerateResponse{
 				Model:      req.Model,
 				Model:      req.Model,
 				CreatedAt:  time.Now().UTC(),
 				CreatedAt:  time.Now().UTC(),
-				Response:   r.Content,
-				Done:       r.Done,
-				DoneReason: r.DoneReason,
+				Response:   cr.Content,
+				Done:       cr.Done,
+				DoneReason: cr.DoneReason,
 				Metrics: api.Metrics{
 				Metrics: api.Metrics{
-					PromptEvalCount:    r.PromptEvalCount,
-					PromptEvalDuration: r.PromptEvalDuration,
-					EvalCount:          r.EvalCount,
-					EvalDuration:       r.EvalDuration,
+					PromptEvalCount:    cr.PromptEvalCount,
+					PromptEvalDuration: cr.PromptEvalDuration,
+					EvalCount:          cr.EvalCount,
+					EvalDuration:       cr.EvalDuration,
 				},
 				},
 			}
 			}
+
+			if _, err := sb.WriteString(cr.Content); err != nil {
+				ch <- gin.H{"error": err.Error()}
+			}
+
+			if cr.Done {
+				res.TotalDuration = time.Since(checkpointStart)
+				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+
+				if !req.Raw {
+					tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
+					if err != nil {
+						ch <- gin.H{"error": err.Error()}
+						return
+					}
+					res.Context = append(req.Context, tokens...)
+				}
+			}
+
+			ch <- res
 		}); err != nil {
 		}); err != nil {
 			ch <- gin.H{"error": err.Error()}
 			ch <- gin.H{"error": err.Error()}
 		}
 		}
@@ -1122,6 +1147,8 @@ func (s *Server) ProcessHandler(c *gin.Context) {
 }
 }
 
 
 func (s *Server) ChatHandler(c *gin.Context) {
 func (s *Server) ChatHandler(c *gin.Context) {
+	checkpointStart := time.Now()
+
 	var req api.ChatRequest
 	var req api.ChatRequest
 	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
 	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@@ -1141,6 +1168,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 		return
 	}
 	}
 
 
+	checkpointLoaded := time.Now()
+
 	if len(req.Messages) == 0 {
 	if len(req.Messages) == 0 {
 		c.JSON(http.StatusOK, api.ChatResponse{
 		c.JSON(http.StatusOK, api.ChatResponse{
 			Model:      req.Model,
 			Model:      req.Model,
@@ -1169,7 +1198,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			Format:  req.Format,
 			Format:  req.Format,
 			Options: opts,
 			Options: opts,
 		}, func(r llm.CompletionResponse) {
 		}, func(r llm.CompletionResponse) {
-			ch <- api.ChatResponse{
+			res := api.ChatResponse{
 				Model:      req.Model,
 				Model:      req.Model,
 				CreatedAt:  time.Now().UTC(),
 				CreatedAt:  time.Now().UTC(),
 				Message:    api.Message{Role: "assistant", Content: r.Content},
 				Message:    api.Message{Role: "assistant", Content: r.Content},
@@ -1182,6 +1211,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
 					EvalDuration:       r.EvalDuration,
 					EvalDuration:       r.EvalDuration,
 				},
 				},
 			}
 			}
+
+			if r.Done {
+				res.TotalDuration = time.Since(checkpointStart)
+				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+			}
+
+			ch <- res
 		}); err != nil {
 		}); err != nil {
 			ch <- gin.H{"error": err.Error()}
 			ch <- gin.H{"error": err.Error()}
 		}
 		}