|
@@ -417,28 +417,25 @@ type Prediction struct {
|
|
}
|
|
}
|
|
|
|
|
|
type PredictRequest struct {
|
|
type PredictRequest struct {
|
|
- Stream bool `json:"stream"`
|
|
|
|
- NPredict int `json:"n_predict,omitempty"`
|
|
|
|
- TopK int `json:"top_k,omitempty"`
|
|
|
|
- TopP float32 `json:"top_p,omitempty"`
|
|
|
|
- TfsZ float32 `json:"tfs_z,omitempty"`
|
|
|
|
- TypicalP float32 `json:"typical_p,omitempty"`
|
|
|
|
- RepeatLastN int `json:"repeat_last_n,omitempty"`
|
|
|
|
- Temperature float32 `json:"temperature,omitempty"`
|
|
|
|
- RepeatPenalty float32 `json:"repeat_penalty,omitempty"`
|
|
|
|
- PresencePenalty float32 `json:"presence_penalty,omitempty"`
|
|
|
|
- FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
|
|
|
|
- Mirostat int `json:"mirostat,omitempty"`
|
|
|
|
- MirostatTau float32 `json:"mirostat_tau,omitempty"`
|
|
|
|
- MirostatEta float32 `json:"mirostat_eta,omitempty"`
|
|
|
|
- PenalizeNl bool `json:"penalize_nl,omitempty"`
|
|
|
|
- NKeep int `json:"n_keep,omitempty"`
|
|
|
|
- Seed int `json:"seed,omitempty"`
|
|
|
|
- Prompt string `json:"prompt,omitempty"`
|
|
|
|
- NProbs int `json:"n_probs,omitempty"`
|
|
|
|
- LogitBias map[int]float32 `json:"logit_bias,omitempty"`
|
|
|
|
- IgnoreEos bool `json:"ignore_eos,omitempty"`
|
|
|
|
- Stop []string `json:"stop,omitempty"`
|
|
|
|
|
|
+ Prompt string `json:"prompt"`
|
|
|
|
+ Stream bool `json:"stream"`
|
|
|
|
+ NPredict int `json:"n_predict"`
|
|
|
|
+ NKeep int `json:"n_keep"`
|
|
|
|
+ Temperature float32 `json:"temperature"`
|
|
|
|
+ TopK int `json:"top_k"`
|
|
|
|
+ TopP float32 `json:"top_p"`
|
|
|
|
+ TfsZ float32 `json:"tfs_z"`
|
|
|
|
+ TypicalP float32 `json:"typical_p"`
|
|
|
|
+ RepeatLastN int `json:"repeat_last_n"`
|
|
|
|
+ RepeatPenalty float32 `json:"repeat_penalty"`
|
|
|
|
+ PresencePenalty float32 `json:"presence_penalty"`
|
|
|
|
+ FrequencyPenalty float32 `json:"frequency_penalty"`
|
|
|
|
+ Mirostat int `json:"mirostat"`
|
|
|
|
+ MirostatTau float32 `json:"mirostat_tau"`
|
|
|
|
+ MirostatEta float32 `json:"mirostat_eta"`
|
|
|
|
+ PenalizeNl bool `json:"penalize_nl"`
|
|
|
|
+ Seed int `json:"seed"`
|
|
|
|
+ Stop []string `json:"stop,omitempty"`
|
|
}
|
|
}
|
|
|
|
|
|
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
|
|
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
|
|
@@ -470,8 +467,10 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
|
|
MirostatTau: llm.MirostatTau,
|
|
MirostatTau: llm.MirostatTau,
|
|
MirostatEta: llm.MirostatEta,
|
|
MirostatEta: llm.MirostatEta,
|
|
PenalizeNl: llm.PenalizeNewline,
|
|
PenalizeNl: llm.PenalizeNewline,
|
|
|
|
+ Seed: llm.Seed,
|
|
Stop: llm.Stop,
|
|
Stop: llm.Stop,
|
|
}
|
|
}
|
|
|
|
+
|
|
data, err := json.Marshal(predReq)
|
|
data, err := json.Marshal(predReq)
|
|
if err != nil {
|
|
if err != nil {
|
|
return fmt.Errorf("error marshaling data: %v", err)
|
|
return fmt.Errorf("error marshaling data: %v", err)
|