Ver código fonte

Merge pull request #813 from jmorganca/mxyng/llama

refactor llm/llama.go
Michael Yang 1 ano atrás
pai
commit
08b0e04f40
1 arquivos alterados com 38 adições e 90 exclusões
  1. 38 90
      llm/llama.go

+ 38 - 90
llm/llama.go

@@ -442,68 +442,18 @@ func (llm *llama) SetOptions(opts api.Options) {
 	llm.Options = opts
 }
 
-type GenerationSettings struct {
-	FrequencyPenalty float64       `json:"frequency_penalty"`
-	IgnoreEOS        bool          `json:"ignore_eos"`
-	LogitBias        []interface{} `json:"logit_bias"`
-	Mirostat         int           `json:"mirostat"`
-	MirostatEta      float64       `json:"mirostat_eta"`
-	MirostatTau      float64       `json:"mirostat_tau"`
-	Model            string        `json:"model"`
-	NCtx             int           `json:"n_ctx"`
-	NKeep            int           `json:"n_keep"`
-	NPredict         int           `json:"n_predict"`
-	NProbs           int           `json:"n_probs"`
-	PenalizeNl       bool          `json:"penalize_nl"`
-	PresencePenalty  float64       `json:"presence_penalty"`
-	RepeatLastN      int           `json:"repeat_last_n"`
-	RepeatPenalty    float64       `json:"repeat_penalty"`
-	Seed             uint32        `json:"seed"`
-	Stop             []string      `json:"stop"`
-	Stream           bool          `json:"stream"`
-	Temp             float64       `json:"temp"`
-	TfsZ             float64       `json:"tfs_z"`
-	TopK             int           `json:"top_k"`
-	TopP             float64       `json:"top_p"`
-	TypicalP         float64       `json:"typical_p"`
-}
-
-type Timings struct {
-	PredictedN  int     `json:"predicted_n"`
-	PredictedMS float64 `json:"predicted_ms"`
-	PromptN     int     `json:"prompt_n"`
-	PromptMS    float64 `json:"prompt_ms"`
-}
-
-type Prediction struct {
+type prediction struct {
 	Content string `json:"content"`
 	Model   string `json:"model"`
 	Prompt  string `json:"prompt"`
 	Stop    bool   `json:"stop"`
 
-	Timings `json:"timings"`
-}
-
-type PredictRequest struct {
-	Prompt           string   `json:"prompt"`
-	Stream           bool     `json:"stream"`
-	NPredict         int      `json:"n_predict"`
-	NKeep            int      `json:"n_keep"`
-	Temperature      float32  `json:"temperature"`
-	TopK             int      `json:"top_k"`
-	TopP             float32  `json:"top_p"`
-	TfsZ             float32  `json:"tfs_z"`
-	TypicalP         float32  `json:"typical_p"`
-	RepeatLastN      int      `json:"repeat_last_n"`
-	RepeatPenalty    float32  `json:"repeat_penalty"`
-	PresencePenalty  float32  `json:"presence_penalty"`
-	FrequencyPenalty float32  `json:"frequency_penalty"`
-	Mirostat         int      `json:"mirostat"`
-	MirostatTau      float32  `json:"mirostat_tau"`
-	MirostatEta      float32  `json:"mirostat_eta"`
-	PenalizeNl       bool     `json:"penalize_nl"`
-	Seed             int      `json:"seed"`
-	Stop             []string `json:"stop,omitempty"`
+	Timings struct {
+		PredictedN  int     `json:"predicted_n"`
+		PredictedMS float64 `json:"predicted_ms"`
+		PromptN     int     `json:"prompt_n"`
+		PromptMS    float64 `json:"prompt_ms"`
+	}
 }
 
 const maxBufferSize = 512 * format.KiloByte
@@ -518,27 +468,26 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 	nextContext.WriteString(prevConvo)
 	nextContext.WriteString(prompt)
 
-	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
-	predReq := PredictRequest{
-		Prompt:           nextContext.String(),
-		Stream:           true,
-		NPredict:         llm.NumPredict,
-		NKeep:            llm.NumKeep,
-		Temperature:      llm.Temperature,
-		TopK:             llm.TopK,
-		TopP:             llm.TopP,
-		TfsZ:             llm.TFSZ,
-		TypicalP:         llm.TypicalP,
-		RepeatLastN:      llm.RepeatLastN,
-		RepeatPenalty:    llm.RepeatPenalty,
-		PresencePenalty:  llm.PresencePenalty,
-		FrequencyPenalty: llm.FrequencyPenalty,
-		Mirostat:         llm.Mirostat,
-		MirostatTau:      llm.MirostatTau,
-		MirostatEta:      llm.MirostatEta,
-		PenalizeNl:       llm.PenalizeNewline,
-		Seed:             llm.Seed,
-		Stop:             llm.Stop,
+	request := map[string]any{
+		"prompt":            nextContext.String(),
+		"stream":            true,
+		"n_predict":         llm.NumPredict,
+		"n_keep":            llm.NumKeep,
+		"temperature":       llm.Temperature,
+		"top_k":             llm.TopK,
+		"top_p":             llm.TopP,
+		"tfs_z":             llm.TFSZ,
+		"typical_p":         llm.TypicalP,
+		"repeat_last_n":     llm.RepeatLastN,
+		"repeat_penalty":    llm.RepeatPenalty,
+		"presence_penalty":  llm.PresencePenalty,
+		"frequency_penalty": llm.FrequencyPenalty,
+		"mirostat":          llm.Mirostat,
+		"mirostat_tau":      llm.MirostatTau,
+		"mirostat_eta":      llm.MirostatEta,
+		"penalize_nl":       llm.PenalizeNewline,
+		"seed":              llm.Seed,
+		"stop":              llm.Stop,
 	}
 
 	// Handling JSON marshaling with special characters unescaped.
@@ -546,10 +495,11 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 	enc := json.NewEncoder(buffer)
 	enc.SetEscapeHTML(false)
 
-	if err := enc.Encode(predReq); err != nil {
+	if err := enc.Encode(request); err != nil {
 		return fmt.Errorf("failed to marshal data: %v", err)
 	}
 
+	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
 	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
 	if err != nil {
 		return fmt.Errorf("error creating POST request: %v", err)
@@ -581,16 +531,14 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 			// This handles the request cancellation
 			return ctx.Err()
 		default:
-			line := scanner.Text()
-			if line == "" {
+			line := scanner.Bytes()
+			if len(line) == 0 {
 				continue
 			}
 
-			// Read data from the server-side event stream
-			if strings.HasPrefix(line, "data: ") {
-				evt := line[6:]
-				var p Prediction
-				if err := json.Unmarshal([]byte(evt), &p); err != nil {
+			if evt, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
+				var p prediction
+				if err := json.Unmarshal(evt, &p); err != nil {
 					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
 				}
 
@@ -608,10 +556,10 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 					fn(api.GenerateResponse{
 						Done:               true,
 						Context:            embd,
-						PromptEvalCount:    p.PromptN,
-						PromptEvalDuration: parseDurationMs(p.PromptMS),
-						EvalCount:          p.PredictedN,
-						EvalDuration:       parseDurationMs(p.PredictedMS),
+						PromptEvalCount:    p.Timings.PromptN,
+						PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
+						EvalCount:          p.Timings.PredictedN,
+						EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
 					})
 
 					return nil