@@ -286,8 +286,8 @@ func newLlama(model string, adapters []string, runner ModelRunner, opts api.Opti
 		runner.Path,
 		append(params, "--port", strconv.Itoa(port))...,
 	)
-	var stderr bytes.Buffer
-	cmd.Stderr = &stderr
+	cmd.Stdout = os.Stderr
+	cmd.Stderr = os.Stderr
 
 	llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel}}
 
@@ -353,11 +353,6 @@ func (llm *llama) SetOptions(opts api.Options) {
 	llm.Options = opts
 }
 
-type Prediction struct {
-	Content string `json:"content"`
-	Stop    bool   `json:"stop"`
-}
-
 type GenerationSettings struct {
 	FrequencyPenalty float64 `json:"frequency_penalty"`
 	IgnoreEOS        bool    `json:"ignore_eos"`
@@ -385,31 +380,19 @@ type GenerationSettings struct {
 }
 
 type Timings struct {
-	PredictedMS         float64 `json:"predicted_ms"`
-	PredictedN          int     `json:"predicted_n"`
-	PredictedPerSecond  float64 `json:"predicted_per_second"`
-	PredictedPerTokenMS float64 `json:"predicted_per_token_ms"`
-	PromptMS            float64 `json:"prompt_ms"`
-	PromptN             int     `json:"prompt_n"`
-	PromptPerSecond     float64 `json:"prompt_per_second"`
-	PromptPerTokenMS    float64 `json:"prompt_per_token_ms"`
+	PredictedN  int     `json:"predicted_n"`
+	PredictedMS float64 `json:"predicted_ms"`
+	PromptN     int     `json:"prompt_n"`
+	PromptMS    float64 `json:"prompt_ms"`
 }
 
-type PredictComplete struct {
-	Content            string             `json:"content"`
-	GenerationSettings GenerationSettings `json:"generation_settings"`
-	Model              string             `json:"model"`
-	Prompt             string             `json:"prompt"`
-	Stop               bool               `json:"stop"`
-	StoppedEOS         bool               `json:"stopped_eos"`
-	StoppedLimit       bool               `json:"stopped_limit"`
-	StoppedWord        bool               `json:"stopped_word"`
-	StoppingWord       string             `json:"stopping_word"`
-	Timings            Timings            `json:"timings"`
-	TokensCached       int                `json:"tokens_cached"`
-	TokensEvaluated    int                `json:"tokens_evaluated"`
-	TokensPredicted    int                `json:"tokens_predicted"`
-	Truncated          bool               `json:"truncated"`
+type Prediction struct {
+	Content string `json:"content"`
+	Model   string `json:"model"`
+	Prompt  string `json:"prompt"`
+	Stop    bool   `json:"stop"`
+
+	Timings `json:"timings"`
 }
 
 type PredictRequest struct {
@@ -437,15 +420,19 @@ type PredictRequest struct {
 	Stop []string `json:"stop,omitempty"`
 }
 
-func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string, fn func(api.GenerateResponse)) error {
-	// we need to find the trimmed prompt context before predicting so that we can return it to the client
-	trimmedPrompt, err := llm.marshalPrompt(ctx, predictCtx, prompt)
+func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
+	prevConvo, err := llm.Decode(ctx, prevContext)
 	if err != nil {
-		return fmt.Errorf("marshaling prompt: %v", err)
+		return err
 	}
+
+	var nextContext strings.Builder
+	nextContext.WriteString(prevConvo)
+	nextContext.WriteString(prompt)
+
 	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
 	predReq := PredictRequest{
-		Prompt:   trimmedPrompt,
+		Prompt:   nextContext.String(),
 		Stream:   true,
 		NPredict: llm.NumPredict,
 		NKeep:    llm.NumKeep,
@@ -491,7 +478,6 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
 	}
 
 	scanner := bufio.NewScanner(resp.Body)
-	genCtx := trimmedPrompt // start with the trimmed prompt
 	for scanner.Scan() {
 		select {
 		case <-ctx.Done():
@@ -506,34 +492,31 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
 			// Read data from the server-side event stream
 			if strings.HasPrefix(line, "data: ") {
 				evt := line[6:]
-				var complete PredictComplete
-				if err := json.Unmarshal([]byte(evt), &complete); err != nil {
-					return fmt.Errorf("error unmarshaling llm complete response: %v", err)
+				var p Prediction
+				if err := json.Unmarshal([]byte(evt), &p); err != nil {
+					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
 				}
 
-				if complete.Timings.PredictedMS > 0 {
-					genCtx += complete.Content
-					embd, err := llm.Encode(ctx, genCtx)
+				fn(api.GenerateResponse{Response: p.Content})
+				nextContext.WriteString(p.Content)
+
+				if p.Stop {
+					embd, err := llm.Encode(ctx, nextContext.String())
 					if err != nil {
 						return fmt.Errorf("encoding context: %v", err)
 					}
+
 					fn(api.GenerateResponse{
 						Done:               true,
 						Context:            embd,
-						PromptEvalCount:    int(complete.Timings.PromptN),
-						PromptEvalDuration: parseDurationMs(float64(complete.Timings.PromptMS)),
-						EvalCount:          int(complete.Timings.PredictedN),
-						EvalDuration:       parseDurationMs(float64(complete.Timings.PredictedMS)),
+						PromptEvalCount:    p.PromptN,
+						PromptEvalDuration: parseDurationMs(p.PromptMS),
+						EvalCount:          p.PredictedN,
+						EvalDuration:       parseDurationMs(p.PredictedMS),
 					})
-					return nil
-				}
 
-				var pred Prediction
-				if err := json.Unmarshal([]byte(evt), &pred); err != nil {
-					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
+					return nil
 				}
-				genCtx += pred.Content
-				fn(api.GenerateResponse{Response: pred.Content})
 			}
 		}
 	}
@@ -545,34 +528,6 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
 	return nil
 }
 
-func (llm *llama) marshalPrompt(ctx context.Context, pCtx []int, prompt string) (string, error) {
-	pEncode, err := llm.Encode(ctx, prompt)
-	if err != nil {
-		return "", fmt.Errorf("encoding prompt context: %w", err)
-	}
-	tokens := append(pCtx, pEncode...)
-	if llm.NumKeep < 0 {
-		llm.NumKeep = len(tokens)
-	}
-
-	// min(llm.NumCtx - 4, llm.NumKeep)
-	if llm.NumCtx-4 < llm.NumKeep {
-		llm.NumKeep = llm.NumCtx - 4
-	}
-
-	if len(tokens) >= llm.NumCtx {
-		// truncate input
-		numLeft := (llm.NumCtx - llm.NumKeep) / 2
-		truncated := tokens[:llm.NumKeep]
-		erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
-		truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
-		tokens = truncated
-		log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
-	}
-
-	return llm.Decode(ctx, tokens)
-}
-
 type TokenizeRequest struct {
 	Content string `json:"content"`
 }
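
The hunks above collapse the old PredictComplete/Prediction pair into a single Prediction type whose embedded Timings fields are promoted, so one json.Unmarshal handles both the streaming content fragments and the final stop event. A minimal standalone sketch of that decoding follows (not part of the patch); the sample "data:" payload and its numbers are illustrative only, not captured server output.

// sketch.go: decode one server-sent event line into the consolidated Prediction type.
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// Types copied from the patch above.
type Timings struct {
	PredictedN  int     `json:"predicted_n"`
	PredictedMS float64 `json:"predicted_ms"`
	PromptN     int     `json:"prompt_n"`
	PromptMS    float64 `json:"prompt_ms"`
}

type Prediction struct {
	Content string `json:"content"`
	Model   string `json:"model"`
	Prompt  string `json:"prompt"`
	Stop    bool   `json:"stop"`

	Timings `json:"timings"`
}

func main() {
	// A final event carries stop=true plus timings; intermediate events carry only a
	// content fragment, so both decode into the same struct. Values below are made up.
	line := `data: {"content":"","stop":true,"timings":{"predicted_n":12,"predicted_ms":340.2,"prompt_n":8,"prompt_ms":95.7}}`

	var p Prediction
	if err := json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &p); err != nil {
		panic(err)
	}

	// The embedded Timings fields are promoted, so p.PredictedN etc. read directly,
	// which is what lets Predict fill PromptEvalCount/EvalDuration from one struct.
	fmt.Println(p.Stop, p.PredictedN, p.PredictedMS, p.PromptN, p.PromptMS)
}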