Browse Source

Merge pull request #462 from jmorganca/mxyng/rm-marshal-prompt

remove marshalPrompt which is no longer needed
Michael Yang 1 year ago
parent
commit
7b5aefb427
2 changed files with 38 additions and 82 deletions
  1. 36 81
      llm/ggml_llama.go
  2. 2 1
      server/routes.go

+ 36 - 81
llm/ggml_llama.go

@@ -286,8 +286,8 @@ func newLlama(model string, adapters []string, runner ModelRunner, opts api.Opti
 			runner.Path,
 			append(params, "--port", strconv.Itoa(port))...,
 		)
-		var stderr bytes.Buffer
-		cmd.Stderr = &stderr
+		cmd.Stdout = os.Stderr
+		cmd.Stderr = os.Stderr
 
 		llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel}}
 
@@ -353,11 +353,6 @@ func (llm *llama) SetOptions(opts api.Options) {
 	llm.Options = opts
 }
 
-type Prediction struct {
-	Content string `json:"content"`
-	Stop    bool   `json:"stop"`
-}
-
 type GenerationSettings struct {
 	FrequencyPenalty float64       `json:"frequency_penalty"`
 	IgnoreEOS        bool          `json:"ignore_eos"`
@@ -385,31 +380,19 @@ type GenerationSettings struct {
 }
 
 type Timings struct {
-	PredictedMS         float64 `json:"predicted_ms"`
-	PredictedN          int     `json:"predicted_n"`
-	PredictedPerSecond  float64 `json:"predicted_per_second"`
-	PredictedPerTokenMS float64 `json:"predicted_per_token_ms"`
-	PromptMS            float64 `json:"prompt_ms"`
-	PromptN             int     `json:"prompt_n"`
-	PromptPerSecond     float64 `json:"prompt_per_second"`
-	PromptPerTokenMS    float64 `json:"prompt_per_token_ms"`
+	PredictedN  int     `json:"predicted_n"`
+	PredictedMS float64 `json:"predicted_ms"`
+	PromptN     int     `json:"prompt_n"`
+	PromptMS    float64 `json:"prompt_ms"`
 }
 
-type PredictComplete struct {
-	Content            string             `json:"content"`
-	GenerationSettings GenerationSettings `json:"generation_settings"`
-	Model              string             `json:"model"`
-	Prompt             string             `json:"prompt"`
-	Stop               bool               `json:"stop"`
-	StoppedEOS         bool               `json:"stopped_eos"`
-	StoppedLimit       bool               `json:"stopped_limit"`
-	StoppedWord        bool               `json:"stopped_word"`
-	StoppingWord       string             `json:"stopping_word"`
-	Timings            Timings            `json:"timings"`
-	TokensCached       int                `json:"tokens_cached"`
-	TokensEvaluated    int                `json:"tokens_evaluated"`
-	TokensPredicted    int                `json:"tokens_predicted"`
-	Truncated          bool               `json:"truncated"`
+type Prediction struct {
+	Content string `json:"content"`
+	Model   string `json:"model"`
+	Prompt  string `json:"prompt"`
+	Stop    bool   `json:"stop"`
+
+	Timings `json:"timings"`
 }
 
 type PredictRequest struct {
@@ -437,15 +420,19 @@ type PredictRequest struct {
 	Stop             []string        `json:"stop,omitempty"`
 }
 
-func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string, fn func(api.GenerateResponse)) error {
-	// we need to find the trimmed prompt context before predicting so that we can return it to the client
-	trimmedPrompt, err := llm.marshalPrompt(ctx, predictCtx, prompt)
+func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
+	prevConvo, err := llm.Decode(ctx, prevContext)
 	if err != nil {
-		return fmt.Errorf("marshaling prompt: %v", err)
+		return err
 	}
+
+	var nextContext strings.Builder
+	nextContext.WriteString(prevConvo)
+	nextContext.WriteString(prompt)
+
 	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
 	predReq := PredictRequest{
-		Prompt:           trimmedPrompt,
+		Prompt:           nextContext.String(),
 		Stream:           true,
 		NPredict:         llm.NumPredict,
 		NKeep:            llm.NumKeep,
@@ -491,7 +478,6 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
 	}
 
 	scanner := bufio.NewScanner(resp.Body)
-	genCtx := trimmedPrompt // start with the trimmed prompt
 	for scanner.Scan() {
 		select {
 		case <-ctx.Done():
@@ -506,34 +492,31 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
 			// Read data from the server-side event stream
 			if strings.HasPrefix(line, "data: ") {
 				evt := line[6:]
-				var complete PredictComplete
-				if err := json.Unmarshal([]byte(evt), &complete); err != nil {
-					return fmt.Errorf("error unmarshaling llm complete response: %v", err)
+				var p Prediction
+				if err := json.Unmarshal([]byte(evt), &p); err != nil {
+					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
 				}
 
-				if complete.Timings.PredictedMS > 0 {
-					genCtx += complete.Content
-					embd, err := llm.Encode(ctx, genCtx)
+				fn(api.GenerateResponse{Response: p.Content})
+				nextContext.WriteString(p.Content)
+
+				if p.Stop {
+					embd, err := llm.Encode(ctx, nextContext.String())
 					if err != nil {
 						return fmt.Errorf("encoding context: %v", err)
 					}
+
 					fn(api.GenerateResponse{
 						Done:               true,
 						Context:            embd,
-						PromptEvalCount:    int(complete.Timings.PromptN),
-						PromptEvalDuration: parseDurationMs(float64(complete.Timings.PromptMS)),
-						EvalCount:          int(complete.Timings.PredictedN),
-						EvalDuration:       parseDurationMs(float64(complete.Timings.PredictedMS)),
+						PromptEvalCount:    p.PromptN,
+						PromptEvalDuration: parseDurationMs(p.PromptMS),
+						EvalCount:          p.PredictedN,
+						EvalDuration:       parseDurationMs(p.PredictedMS),
 					})
-					return nil
-				}
 
-				var pred Prediction
-				if err := json.Unmarshal([]byte(evt), &pred); err != nil {
-					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
+					return nil
 				}
-				genCtx += pred.Content
-				fn(api.GenerateResponse{Response: pred.Content})
 			}
 		}
 	}
@@ -545,34 +528,6 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
 	return nil
 }
 
-func (llm *llama) marshalPrompt(ctx context.Context, pCtx []int, prompt string) (string, error) {
-	pEncode, err := llm.Encode(ctx, prompt)
-	if err != nil {
-		return "", fmt.Errorf("encoding prompt context: %w", err)
-	}
-	tokens := append(pCtx, pEncode...)
-	if llm.NumKeep < 0 {
-		llm.NumKeep = len(tokens)
-	}
-
-	// min(llm.NumCtx - 4, llm.NumKeep)
-	if llm.NumCtx-4 < llm.NumKeep {
-		llm.NumKeep = llm.NumCtx - 4
-	}
-
-	if len(tokens) >= llm.NumCtx {
-		// truncate input
-		numLeft := (llm.NumCtx - llm.NumKeep) / 2
-		truncated := tokens[:llm.NumKeep]
-		erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
-		truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
-		tokens = truncated
-		log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
-	}
-
-	return llm.Decode(ctx, tokens)
-}
-
 type TokenizeRequest struct {
 	Content string `json:"content"`
 }

+ 2 - 1
server/routes.go

@@ -117,12 +117,13 @@ func load(ctx context.Context, model *Model, reqOpts map[string]interface{}, ses
 			if err != nil {
 				return err
 			}
+
 			tokensNoSystem, err := llmModel.Encode(ctx, promptNoSystem)
 			if err != nil {
 				return err
 			}
 
-			opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1
+			opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem)
 
 			llmModel.SetOptions(opts)
 		}