浏览代码

remove marshalPrompt which is no longer needed

Michael Yang 1 年之前
父节点
当前提交
5d3f314b0b
共有 1 个文件被更改,包括 19 次插入42 次删除
  1. 19 42
      llm/ggml_llama.go

+ 19 - 42
llm/ggml_llama.go

@@ -286,8 +286,8 @@ func newLlama(model string, adapters []string, runner ModelRunner, opts api.Opti
 			runner.Path,
 			runner.Path,
 			append(params, "--port", strconv.Itoa(port))...,
 			append(params, "--port", strconv.Itoa(port))...,
 		)
 		)
-		var stderr bytes.Buffer
-		cmd.Stderr = &stderr
+		cmd.Stdout = os.Stderr
+		cmd.Stderr = os.Stderr
 
 
 		llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel}}
 		llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel}}
 
 
@@ -437,15 +437,19 @@ type PredictRequest struct {
 	Stop             []string        `json:"stop,omitempty"`
 	Stop             []string        `json:"stop,omitempty"`
 }
 }
 
 
-func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string, fn func(api.GenerateResponse)) error {
-	// we need to find the trimmed prompt context before predicting so that we can return it to the client
-	trimmedPrompt, err := llm.marshalPrompt(ctx, predictCtx, prompt)
+func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
+	prevConvo, err := llm.Decode(ctx, prevContext)
 	if err != nil {
 	if err != nil {
-		return fmt.Errorf("marshaling prompt: %v", err)
+		return err
 	}
 	}
+
+	var nextContext strings.Builder
+	nextContext.WriteString(prevConvo)
+	nextContext.WriteString(prompt)
+
 	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
 	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
 	predReq := PredictRequest{
 	predReq := PredictRequest{
-		Prompt:           trimmedPrompt,
+		Prompt:           nextContext.String(),
 		Stream:           true,
 		Stream:           true,
 		NPredict:         llm.NumPredict,
 		NPredict:         llm.NumPredict,
 		NKeep:            llm.NumKeep,
 		NKeep:            llm.NumKeep,
@@ -491,7 +495,6 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
 	}
 	}
 
 
 	scanner := bufio.NewScanner(resp.Body)
 	scanner := bufio.NewScanner(resp.Body)
-	genCtx := trimmedPrompt // start with the trimmed prompt
 	for scanner.Scan() {
 	for scanner.Scan() {
 		select {
 		select {
 		case <-ctx.Done():
 		case <-ctx.Done():
@@ -512,11 +515,12 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
 				}
 				}
 
 
 				if complete.Timings.PredictedMS > 0 {
 				if complete.Timings.PredictedMS > 0 {
-					genCtx += complete.Content
-					embd, err := llm.Encode(ctx, genCtx)
+					nextContext.WriteString(complete.Content)
+					embd, err := llm.Encode(ctx, nextContext.String())
 					if err != nil {
 					if err != nil {
 						return fmt.Errorf("encoding context: %v", err)
 						return fmt.Errorf("encoding context: %v", err)
 					}
 					}
+
 					fn(api.GenerateResponse{
 					fn(api.GenerateResponse{
 						Done:               true,
 						Done:               true,
 						Context:            embd,
 						Context:            embd,
@@ -528,12 +532,13 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
 					return nil
 					return nil
 				}
 				}
 
 
-				var pred Prediction
-				if err := json.Unmarshal([]byte(evt), &pred); err != nil {
+				var p Prediction
+				if err := json.Unmarshal([]byte(evt), &p); err != nil {
 					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
 					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
 				}
 				}
-				genCtx += pred.Content
-				fn(api.GenerateResponse{Response: pred.Content})
+
+				fn(api.GenerateResponse{Response: p.Content})
+				nextContext.WriteString(p.Content)
 			}
 			}
 		}
 		}
 	}
 	}
@@ -545,34 +550,6 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
 	return nil
 	return nil
 }
 }
 
 
-func (llm *llama) marshalPrompt(ctx context.Context, pCtx []int, prompt string) (string, error) {
-	pEncode, err := llm.Encode(ctx, prompt)
-	if err != nil {
-		return "", fmt.Errorf("encoding prompt context: %w", err)
-	}
-	tokens := append(pCtx, pEncode...)
-	if llm.NumKeep < 0 {
-		llm.NumKeep = len(tokens)
-	}
-
-	// min(llm.NumCtx - 4, llm.NumKeep)
-	if llm.NumCtx-4 < llm.NumKeep {
-		llm.NumKeep = llm.NumCtx - 4
-	}
-
-	if len(tokens) >= llm.NumCtx {
-		// truncate input
-		numLeft := (llm.NumCtx - llm.NumKeep) / 2
-		truncated := tokens[:llm.NumKeep]
-		erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
-		truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
-		tokens = truncated
-		log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
-	}
-
-	return llm.Decode(ctx, tokens)
-}
-
 type TokenizeRequest struct {
 type TokenizeRequest struct {
 	Content string `json:"content"`
 	Content string `json:"content"`
 }
 }