
restore model load duration on generate response (#1524)

* restore model load duration on generate response

- set model load duration on the generate and chat done responses
- calculate CreatedAt when the response is created

* remove checkpoint fields from PredictOpts

* Update routes.go
Bruce MacDonald · 1 year ago
Commit 6ee8c80199
2 files changed, 27 insertions(+), 36 deletions(-)
  1. llm/llama.go (+4, -13)
  2. server/routes.go (+23, -23)
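
As the commit message notes, timing responsibility moves out of the llama runner and into the HTTP handlers: PredictOpts no longer carries checkpoint timestamps, PredictResult no longer carries CreatedAt or durations, and each handler stamps CreatedAt when it builds a response and computes TotalDuration and LoadDuration from two local checkpoints on the final (Done) chunk. A minimal, self-contained sketch of that checkpoint pattern (the sleeps are stand-ins for model load and token generation, not real work):

package main

import (
	"fmt"
	"time"
)

func main() {
	checkpointStart := time.Now() // request received

	time.Sleep(50 * time.Millisecond) // stand-in for loading the model
	checkpointLoaded := time.Now()    // model is resident in memory

	time.Sleep(20 * time.Millisecond) // stand-in for generating tokens

	// Mirrors what the handlers now set on the done response:
	fmt.Println("load duration: ", checkpointLoaded.Sub(checkpointStart))
	fmt.Println("total duration:", time.Since(checkpointStart))
}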

llm/llama.go (+4, -13)

@@ -548,17 +548,12 @@ const maxBufferSize = 512 * format.KiloByte
 const maxRetries = 6
 
 type PredictOpts struct {
-	Prompt           string
-	Format           string
-	Images           []api.ImageData
-	CheckpointStart  time.Time
-	CheckpointLoaded time.Time
+	Prompt string
+	Format string
+	Images []api.ImageData
 }
 
 type PredictResult struct {
-	CreatedAt          time.Time
-	TotalDuration      time.Duration
-	LoadDuration       time.Duration
 	Content            string
 	Done               bool
 	PromptEvalCount    int
@@ -681,16 +676,12 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
 
 				if p.Content != "" {
 					fn(PredictResult{
-						CreatedAt: time.Now().UTC(),
-						Content:   p.Content,
+						Content: p.Content,
 					})
 				}
 
 				if p.Stop {
 					fn(PredictResult{
-						CreatedAt:     time.Now().UTC(),
-						TotalDuration: time.Since(predict.CheckpointStart),
-
 						Done:               true,
 						PromptEvalCount:    p.Timings.PromptN,
 						PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
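
For reference, the two structs in llm/llama.go read as follows after these hunks. The tail of PredictResult is cut off by the diff context; the last three fields below are inferred from what the handlers read in server/routes.go (r.PromptEvalDuration, r.EvalCount, r.EvalDuration), so treat their exact order and types as an assumption:

type PredictOpts struct {
	Prompt string
	Format string
	Images []api.ImageData
}

type PredictResult struct {
	Content            string
	Done               bool
	PromptEvalCount    int
	PromptEvalDuration time.Duration // inferred from routes.go, not visible in this diff
	EvalCount          int           // inferred from routes.go
	EvalDuration       time.Duration // inferred from routes.go
}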

server/routes.go (+23, -23)

@@ -261,12 +261,10 @@ func GenerateHandler(c *gin.Context) {
 
 			resp := api.GenerateResponse{
 				Model:     req.Model,
-				CreatedAt: r.CreatedAt,
+				CreatedAt: time.Now().UTC(),
 				Done:      r.Done,
 				Response:  r.Content,
 				Metrics: api.Metrics{
-					TotalDuration:      r.TotalDuration,
-					LoadDuration:       r.LoadDuration,
 					PromptEvalCount:    r.PromptEvalCount,
 					PromptEvalDuration: r.PromptEvalDuration,
 					EvalCount:          r.EvalCount,
@@ -274,13 +272,18 @@ func GenerateHandler(c *gin.Context) {
 				},
 			}
 
-			if r.Done && !req.Raw {
-				embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
-				if err != nil {
-					ch <- gin.H{"error": err.Error()}
-					return
+			if r.Done {
+				resp.TotalDuration = time.Since(checkpointStart)
+				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+
+				if !req.Raw {
+					embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
+					if err != nil {
+						ch <- gin.H{"error": err.Error()}
+						return
+					}
+					resp.Context = embd
 				}
-				resp.Context = embd
 			}
 
 			ch <- resp
@@ -288,11 +291,9 @@ func GenerateHandler(c *gin.Context) {
 
 		// Start prediction
 		predictReq := llm.PredictOpts{
-			Prompt:           prompt,
-			Format:           req.Format,
-			CheckpointStart:  checkpointStart,
-			CheckpointLoaded: checkpointLoaded,
-			Images:           req.Images,
+			Prompt: prompt,
+			Format: req.Format,
+			Images: req.Images,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
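
The two checkpoints read here are defined earlier in the handler, above this diff's context. A reconstruction of that setup, assumed from the variable names rather than shown in this diff:

	checkpointStart := time.Now()
	// ... resolve the model and ensure it is loaded into the runner ...
	checkpointLoaded := time.Now()

One consequence of LoadDuration = checkpointLoaded.Sub(checkpointStart): when the model is already resident, the two checkpoints are nearly equal, so warm requests report a load duration close to zero.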
@@ -1012,11 +1013,9 @@ func ChatHandler(c *gin.Context) {
 
 			resp := api.ChatResponse{
 				Model:     req.Model,
-				CreatedAt: r.CreatedAt,
+				CreatedAt: time.Now().UTC(),
 				Done:      r.Done,
 				Metrics: api.Metrics{
-					TotalDuration:      r.TotalDuration,
-					LoadDuration:       r.LoadDuration,
 					PromptEvalCount:    r.PromptEvalCount,
 					PromptEvalDuration: r.PromptEvalDuration,
 					EvalCount:          r.EvalCount,
@@ -1024,7 +1023,10 @@ func ChatHandler(c *gin.Context) {
 				},
 			}
 
-			if !r.Done {
+			if r.Done {
+				resp.TotalDuration = time.Since(checkpointStart)
+				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+			} else {
 				resp.Message = &api.Message{Role: "assistant", Content: r.Content}
 			}
 
@@ -1033,11 +1035,9 @@ func ChatHandler(c *gin.Context) {
 
 		// Start prediction
 		predictReq := llm.PredictOpts{
-			Prompt:           prompt,
-			Format:           req.Format,
-			CheckpointStart:  checkpointStart,
-			CheckpointLoaded: checkpointLoaded,
-			Images:           images,
+			Prompt: prompt,
+			Format: req.Format,
+			Images: images,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
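
On the wire, the restored metrics surface only on the final streamed chunk. An illustrative exchange (values are made up; the snake_case keys assume the existing JSON tags on api.GenerateResponse and api.Metrics, with time.Duration fields marshaling as integer nanoseconds):

curl http://localhost:11434/api/generate -d '{"model": "llama2", "prompt": "why is the sky blue?"}'

...
{
  "model": "llama2",
  "created_at": "2023-12-18T19:52:07.071755Z",
  "response": "",
  "done": true,
  "total_duration": 4935886791,
  "load_duration": 2559292187,
  "prompt_eval_count": 26,
  "prompt_eval_duration": 386801000,
  "eval_count": 290,
  "eval_duration": 1912021000
}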