@@ -442,68 +442,18 @@ func (llm *llama) SetOptions(opts api.Options) {
 	llm.Options = opts
 }
 
-type GenerationSettings struct {
-	FrequencyPenalty float64       `json:"frequency_penalty"`
-	IgnoreEOS        bool          `json:"ignore_eos"`
-	LogitBias        []interface{} `json:"logit_bias"`
-	Mirostat         int           `json:"mirostat"`
-	MirostatEta      float64       `json:"mirostat_eta"`
-	MirostatTau      float64       `json:"mirostat_tau"`
-	Model            string        `json:"model"`
-	NCtx             int           `json:"n_ctx"`
-	NKeep            int           `json:"n_keep"`
-	NPredict         int           `json:"n_predict"`
-	NProbs           int           `json:"n_probs"`
-	PenalizeNl       bool          `json:"penalize_nl"`
-	PresencePenalty  float64       `json:"presence_penalty"`
-	RepeatLastN      int           `json:"repeat_last_n"`
-	RepeatPenalty    float64       `json:"repeat_penalty"`
-	Seed             uint32        `json:"seed"`
-	Stop             []string      `json:"stop"`
-	Stream           bool          `json:"stream"`
-	Temp             float64       `json:"temp"`
-	TfsZ             float64       `json:"tfs_z"`
-	TopK             int           `json:"top_k"`
-	TopP             float64       `json:"top_p"`
-	TypicalP         float64       `json:"typical_p"`
-}
-
-type Timings struct {
-	PredictedN  int     `json:"predicted_n"`
-	PredictedMS float64 `json:"predicted_ms"`
-	PromptN     int     `json:"prompt_n"`
-	PromptMS    float64 `json:"prompt_ms"`
-}
-
-type Prediction struct {
+type prediction struct {
 	Content string `json:"content"`
 	Model   string `json:"model"`
 	Prompt  string `json:"prompt"`
 	Stop    bool   `json:"stop"`
 
-	Timings `json:"timings"`
-}
-
-type PredictRequest struct {
-	Prompt           string   `json:"prompt"`
-	Stream           bool     `json:"stream"`
-	NPredict         int      `json:"n_predict"`
-	NKeep            int      `json:"n_keep"`
-	Temperature      float32  `json:"temperature"`
-	TopK             int      `json:"top_k"`
-	TopP             float32  `json:"top_p"`
-	TfsZ             float32  `json:"tfs_z"`
-	TypicalP         float32  `json:"typical_p"`
-	RepeatLastN      int      `json:"repeat_last_n"`
-	RepeatPenalty    float32  `json:"repeat_penalty"`
-	PresencePenalty  float32  `json:"presence_penalty"`
-	FrequencyPenalty float32  `json:"frequency_penalty"`
-	Mirostat         int      `json:"mirostat"`
-	MirostatTau      float32  `json:"mirostat_tau"`
-	MirostatEta      float32  `json:"mirostat_eta"`
-	PenalizeNl       bool     `json:"penalize_nl"`
-	Seed             int      `json:"seed"`
-	Stop             []string `json:"stop,omitempty"`
+	Timings struct {
+		PredictedN  int     `json:"predicted_n"`
+		PredictedMS float64 `json:"predicted_ms"`
+		PromptN     int     `json:"prompt_n"`
+		PromptMS    float64 `json:"prompt_ms"`
+	}
 }
 
 const maxBufferSize = 512 * format.KiloByte
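Aside (not part of the diff): folding the old Timings type into an anonymous nested field does not change how events decode, because encoding/json matches the Timings field name against the "timings" key case-insensitively. A self-contained sketch; the event payload below is illustrative, not captured from a real run:

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // prediction mirrors the struct introduced in the hunk above.
    type prediction struct {
        Content string `json:"content"`
        Model   string `json:"model"`
        Prompt  string `json:"prompt"`
        Stop    bool   `json:"stop"`

        Timings struct {
            PredictedN  int     `json:"predicted_n"`
            PredictedMS float64 `json:"predicted_ms"`
            PromptN     int     `json:"prompt_n"`
            PromptMS    float64 `json:"prompt_ms"`
        }
    }

    func main() {
        // Illustrative event body; a real llama.cpp response carries more fields.
        evt := `{"content":"","stop":true,"timings":{"predicted_n":128,"predicted_ms":1500.2,"prompt_n":11,"prompt_ms":92.4}}`

        var p prediction
        if err := json.Unmarshal([]byte(evt), &p); err != nil {
            panic(err)
        }
        fmt.Println(p.Timings.PredictedN, p.Timings.PromptMS) // 128 92.4
    }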
@@ -518,27 +468,26 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 	nextContext.WriteString(prevConvo)
 	nextContext.WriteString(prompt)
 
-	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
-	predReq := PredictRequest{
-		Prompt:           nextContext.String(),
-		Stream:           true,
-		NPredict:         llm.NumPredict,
-		NKeep:            llm.NumKeep,
-		Temperature:      llm.Temperature,
-		TopK:             llm.TopK,
-		TopP:             llm.TopP,
-		TfsZ:             llm.TFSZ,
-		TypicalP:         llm.TypicalP,
-		RepeatLastN:      llm.RepeatLastN,
-		RepeatPenalty:    llm.RepeatPenalty,
-		PresencePenalty:  llm.PresencePenalty,
-		FrequencyPenalty: llm.FrequencyPenalty,
-		Mirostat:         llm.Mirostat,
-		MirostatTau:      llm.MirostatTau,
-		MirostatEta:      llm.MirostatEta,
-		PenalizeNl:       llm.PenalizeNewline,
-		Seed:             llm.Seed,
-		Stop:             llm.Stop,
+	request := map[string]any{
+		"prompt":            nextContext.String(),
+		"stream":            true,
+		"n_predict":         llm.NumPredict,
+		"n_keep":            llm.NumKeep,
+		"temperature":       llm.Temperature,
+		"top_k":             llm.TopK,
+		"top_p":             llm.TopP,
+		"tfs_z":             llm.TFSZ,
+		"typical_p":         llm.TypicalP,
+		"repeat_last_n":     llm.RepeatLastN,
+		"repeat_penalty":    llm.RepeatPenalty,
+		"presence_penalty":  llm.PresencePenalty,
+		"frequency_penalty": llm.FrequencyPenalty,
+		"mirostat":          llm.Mirostat,
+		"mirostat_tau":      llm.MirostatTau,
+		"mirostat_eta":      llm.MirostatEta,
+		"penalize_nl":       llm.PenalizeNewline,
+		"seed":              llm.Seed,
+		"stop":              llm.Stop,
 	}
 
 	// Handling JSON marshaling with special characters unescaped.
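One behavioral difference worth flagging (my reading, not stated in the diff): the old struct tagged Stop with omitempty, so an unset stop list was dropped from the request body, whereas a nil slice stored in a map still encodes, as null. A self-contained sketch:

    package main

    import (
        "encoding/json"
        "fmt"
    )

    func main() {
        var stop []string // nil: no stop sequences configured

        old := struct {
            Stop []string `json:"stop,omitempty"`
        }{Stop: stop}
        b1, _ := json.Marshal(old)
        fmt.Println(string(b1)) // {}

        b2, _ := json.Marshal(map[string]any{"stop": stop})
        fmt.Println(string(b2)) // {"stop":null}
    }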
@@ -546,10 +495,11 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 	enc := json.NewEncoder(buffer)
 	enc.SetEscapeHTML(false)
 
-	if err := enc.Encode(predReq); err != nil {
+	if err := enc.Encode(request); err != nil {
 		return fmt.Errorf("failed to marshal data: %v", err)
 	}
 
+	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
 	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
 	if err != nil {
 		return fmt.Errorf("error creating POST request: %v", err)
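Aside: the SetEscapeHTML(false) context line above is why this code encodes through a json.Encoder rather than json.Marshal, which always escapes HTML characters. Without it, <, > and & in a prompt would reach the model server as \u003c, \u003e and \u0026. A self-contained sketch with an illustrative prompt:

    package main

    import (
        "bytes"
        "encoding/json"
        "fmt"
    )

    func main() {
        payload := map[string]any{"prompt": "<s>[INST] a & b [/INST]"}

        var escaped bytes.Buffer
        json.NewEncoder(&escaped).Encode(payload) // default: HTML escaping on
        fmt.Print(escaped.String())               // {"prompt":"\u003cs\u003e[INST] a \u0026 b [/INST]"}

        var plain bytes.Buffer
        enc := json.NewEncoder(&plain)
        enc.SetEscapeHTML(false)
        enc.Encode(payload)
        fmt.Print(plain.String()) // {"prompt":"<s>[INST] a & b [/INST]"}
    }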
@@ -589,7 +539,7 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 		// Read data from the server-side event stream
 		if strings.HasPrefix(line, "data: ") {
 			evt := line[6:]
-			var p Prediction
+			var p prediction
 			if err := json.Unmarshal([]byte(evt), &p); err != nil {
 				return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
 			}
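For orientation, the read loop these context lines sit in looks roughly like the sketch below. It is a reconstruction under assumptions, not part of the diff: streamEvents is a hypothetical name, resp is assumed to be the *http.Response from the POST above, and maxBufferSize from the first hunk is assumed to size the scanner because a single "data: " event line can exceed bufio.Scanner's 64 KB default token limit.

    // Hypothetical helper showing the shape of the surrounding loop.
    func streamEvents(resp *http.Response) error {
        scanner := bufio.NewScanner(resp.Body)

        // Grow the scanner's token buffer so one long event line fits.
        buf := make([]byte, 0, maxBufferSize)
        scanner.Buffer(buf, maxBufferSize)

        for scanner.Scan() {
            line := scanner.Text()
            if !strings.HasPrefix(line, "data: ") {
                continue
            }

            var p prediction
            if err := json.Unmarshal([]byte(line[6:]), &p); err != nil {
                return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
            }
            // p.Content carries the newly generated text; p.Stop marks the end.
        }
        return scanner.Err()
    }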
@@ -608,10 +558,10 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 			fn(api.GenerateResponse{
 				Done:               true,
 				Context:            embd,
-				PromptEvalCount:    p.PromptN,
-				PromptEvalDuration: parseDurationMs(p.PromptMS),
-				EvalCount:          p.PredictedN,
-				EvalDuration:       parseDurationMs(p.PredictedMS),
+				PromptEvalCount:    p.Timings.PromptN,
+				PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
+				EvalCount:          p.Timings.PredictedN,
+				EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
 			})
 
 			return nil
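parseDurationMs is used but not defined in this diff. Assuming it simply converts the server's float64 millisecond timings into a time.Duration, a minimal sketch would be:

    // Hypothetical reconstruction; the real helper lives elsewhere in
    // llama.go. Requires the "time" import.
    func parseDurationMs(ms float64) time.Duration {
        return time.Duration(ms * float64(time.Millisecond))
    }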