
Merge pull request #463 from jmorganca/mxyng/fix-last-token

fix not forwarding last token
Michael Yang, 1 year ago
parent
commit
8dc68417e7
1 changed file with 23 additions and 45 deletions
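What the diff below does, in short: the old handler parsed the final server-sent event into a separate PredictComplete type and returned immediately after emitting the done response, so that event's Content was folded into the returned context but never forwarded to the streaming callback, and the last token went missing from the generated output. The fix collapses both response shapes into a single Prediction type and forwards content before checking the stop flag. A minimal sketch of the corrected ordering, paraphrased from the diff using its own identifiers:

	var p Prediction
	if err := json.Unmarshal([]byte(evt), &p); err != nil {
		return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
	}

	// Forward the chunk first; even the final event carries content.
	fn(api.GenerateResponse{Response: p.Content})
	nextContext.WriteString(p.Content)

	if p.Stop {
		// encode nextContext, emit the Done response, and return (see below)
	}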

llm/ggml_llama.go  +23 -45

@@ -353,11 +353,6 @@ func (llm *llama) SetOptions(opts api.Options) {
 	llm.Options = opts
 }
 
-type Prediction struct {
-	Content string `json:"content"`
-	Stop    bool   `json:"stop"`
-}
-
 type GenerationSettings struct {
 	FrequencyPenalty float64       `json:"frequency_penalty"`
 	IgnoreEOS        bool          `json:"ignore_eos"`
@@ -385,31 +380,19 @@ type GenerationSettings struct {
 }
 
 type Timings struct {
-	PredictedMS         float64 `json:"predicted_ms"`
-	PredictedN          int     `json:"predicted_n"`
-	PredictedPerSecond  float64 `json:"predicted_per_second"`
-	PredictedPerTokenMS float64 `json:"predicted_per_token_ms"`
-	PromptMS            float64 `json:"prompt_ms"`
-	PromptN             int     `json:"prompt_n"`
-	PromptPerSecond     float64 `json:"prompt_per_second"`
-	PromptPerTokenMS    float64 `json:"prompt_per_token_ms"`
+	PredictedN  int     `json:"predicted_n"`
+	PredictedMS float64 `json:"predicted_ms"`
+	PromptN     int     `json:"prompt_n"`
+	PromptMS    float64 `json:"prompt_ms"`
 }
 
-type PredictComplete struct {
-	Content            string             `json:"content"`
-	GenerationSettings GenerationSettings `json:"generation_settings"`
-	Model              string             `json:"model"`
-	Prompt             string             `json:"prompt"`
-	Stop               bool               `json:"stop"`
-	StoppedEOS         bool               `json:"stopped_eos"`
-	StoppedLimit       bool               `json:"stopped_limit"`
-	StoppedWord        bool               `json:"stopped_word"`
-	StoppingWord       string             `json:"stopping_word"`
-	Timings            Timings            `json:"timings"`
-	TokensCached       int                `json:"tokens_cached"`
-	TokensEvaluated    int                `json:"tokens_evaluated"`
-	TokensPredicted    int                `json:"tokens_predicted"`
-	Truncated          bool               `json:"truncated"`
+type Prediction struct {
+	Content string `json:"content"`
+	Model   string `json:"model"`
+	Prompt  string `json:"prompt"`
+	Stop    bool   `json:"stop"`
+
+	Timings `json:"timings"`
 }
 
 type PredictRequest struct {
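A note on the new Prediction type above: embedding Timings with a json:"timings" tag keeps the timing numbers nested under the timings key on the wire while promoting the fields in Go, so the handler can read p.PredictedN directly instead of p.Timings.PredictedN. A small self-contained sketch; the sample payload is invented for illustration and only shaped like a llama.cpp server event:

	package main

	import (
		"encoding/json"
		"fmt"
	)

	type Timings struct {
		PredictedN  int     `json:"predicted_n"`
		PredictedMS float64 `json:"predicted_ms"`
	}

	type Prediction struct {
		Content string `json:"content"`
		Stop    bool   `json:"stop"`

		Timings `json:"timings"`
	}

	func main() {
		// Invented sample event: the timings stay nested in the JSON...
		evt := `{"content":" world","stop":true,"timings":{"predicted_n":2,"predicted_ms":12.5}}`

		var p Prediction
		if err := json.Unmarshal([]byte(evt), &p); err != nil {
			panic(err)
		}

		// ...but the embedded fields are promoted on the Go side.
		fmt.Println(p.Content, p.Stop, p.PredictedN, p.PredictedMS) // " world" true 2 12.5
	}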
@@ -509,13 +492,15 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 			// Read data from the server-side event stream
 			if strings.HasPrefix(line, "data: ") {
 				evt := line[6:]
-				var complete PredictComplete
-				if err := json.Unmarshal([]byte(evt), &complete); err != nil {
-					return fmt.Errorf("error unmarshaling llm complete response: %v", err)
+				var p Prediction
+				if err := json.Unmarshal([]byte(evt), &p); err != nil {
+					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
 				}
 
-				if complete.Timings.PredictedMS > 0 {
-					nextContext.WriteString(complete.Content)
+				fn(api.GenerateResponse{Response: p.Content})
+				nextContext.WriteString(p.Content)
+
+				if p.Stop {
 					embd, err := llm.Encode(ctx, nextContext.String())
 					if err != nil {
 						return fmt.Errorf("encoding context: %v", err)
@@ -524,21 +509,14 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 					fn(api.GenerateResponse{
 						Done:               true,
 						Context:            embd,
-						PromptEvalCount:    int(complete.Timings.PromptN),
-						PromptEvalDuration: parseDurationMs(float64(complete.Timings.PromptMS)),
-						EvalCount:          int(complete.Timings.PredictedN),
-						EvalDuration:       parseDurationMs(float64(complete.Timings.PredictedMS)),
+						PromptEvalCount:    p.PromptN,
+						PromptEvalDuration: parseDurationMs(p.PromptMS),
+						EvalCount:          p.PredictedN,
+						EvalDuration:       parseDurationMs(p.PredictedMS),
 					})
-					return nil
-				}
 
-				var p Prediction
-				if err := json.Unmarshal([]byte(evt), &p); err != nil {
-					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
+					return nil
 				}
-
-				fn(api.GenerateResponse{Response: p.Content})
-				nextContext.WriteString(p.Content)
 			}
 		}
 	}
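One more consequence of the stop handling above: the Done response carries a Context built by re-encoding everything written to nextContext, so a caller can thread it into a follow-up request. A hypothetical caller sketch inside the llm package; the Predict signature is inferred from this diff's hunk header and the fn(...) calls, and is an assumption rather than verified code:

	// Hypothetical two-turn caller (assumed signature:
	// Predict(ctx, prevContext []int, prompt string, fn func(api.GenerateResponse)) error).
	func generateTwice(ctx context.Context, llm *llama, first, second string, out io.Writer) error {
		var turnContext []int
		stream := func(r api.GenerateResponse) {
			io.WriteString(out, r.Response) // every chunk arrives here now, including the last
			if r.Done {
				turnContext = r.Context // tokens encoded from the full accumulated output
			}
		}
		if err := llm.Predict(ctx, nil, first, stream); err != nil {
			return err
		}
		// Feed the encoded context back in so the second turn continues the first.
		return llm.Predict(ctx, turnContext, second, stream)
	}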