@@ -412,10 +412,6 @@ func newLlama(model string, adapters, projectors []string, runners []ModelRunner
 	port := rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
 	params := append(params, "--port", strconv.Itoa(port))
 
-	if runner.Type == "gguf" {
-		params = append(params, "--parallel", "2")
-	}
-
 	ctx, cancel := context.WithCancel(context.Background())
 	cmd := exec.CommandContext(
 		ctx,
@@ -549,6 +545,8 @@ type prediction struct {
 }
 
 const maxBufferSize = 512 * format.KiloByte
+const maxRetries = 3
+const retryDelay = 1 * time.Second
 
 type PredictOpts struct {
 	Prompt  string
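
A note on the two new constants: with maxRetries = 3 and retryDelay = 1 * time.Second, a request that keeps hitting a busy runner makes three attempts and sleeps before the second and third, so roughly two extra seconds pass before Predict gives up. The sketch below is not part of the patch (the attempt helper is hypothetical, standing in for one POST /completion round trip); it only shows the control-flow skeleton these constants drive:

package main

import (
	"errors"
	"fmt"
	"time"
)

const maxRetries = 3
const retryDelay = 1 * time.Second

// attempt stands in for one POST /completion round trip; it reports
// whether the response hit a retryable condition such as "slot unavailable".
func attempt(i int) (retryNeeded bool, err error) {
	return true, nil // pretend the runner is busy on every attempt
}

func predictWithRetry() error {
	for retries := 0; retries < maxRetries; retries++ {
		if retries > 0 {
			time.Sleep(retryDelay) // total added wait: (maxRetries-1) * retryDelay = 2s
		}
		retryNeeded, err := attempt(retries)
		if err != nil {
			return err
		}
		if !retryNeeded {
			return nil // success
		}
	}
	return errors.New("max retries exceeded")
}

func main() {
	fmt.Println(predictWithRetry()) // prints "max retries exceeded" after about 2s
}
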
@@ -570,6 +568,11 @@ type PredictResult struct {
 	EvalDuration       time.Duration
 }
 
+// isRetryable checks if the line matches a condition that can be retried
+func isRetryable(line []byte) bool {
+	return bytes.Contains(line, []byte("slot unavailable"))
+}
+
 func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
 	imageData := llm.ImageData
 	if len(predict.Images) > 0 {
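
For illustration, here is how the new isRetryable helper behaves on two made-up response lines (the samples are illustrative, not captured llama.cpp server output): a line mentioning "slot unavailable" triggers another attempt, while a normal data: event falls through to the usual stream parsing.

package main

import (
	"bytes"
	"fmt"
)

// same body as the helper added in this patch
func isRetryable(line []byte) bool {
	return bytes.Contains(line, []byte("slot unavailable"))
}

func main() {
	busy := []byte(`{"error":"slot unavailable"}`)            // illustrative only
	event := []byte(`data: {"content":"Hello","stop":false}`) // illustrative only

	fmt.Println(isRetryable(busy))  // true  -> Predict retries the request
	fmt.Println(isRetryable(event)) // false -> the line is parsed as a prediction event
}
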
@@ -607,98 +610,116 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
 		request["grammar"] = jsonGrammar
 	}
 
-	// Handling JSON marshaling with special characters unescaped.
-	buffer := &bytes.Buffer{}
-	enc := json.NewEncoder(buffer)
-	enc.SetEscapeHTML(false)
+	for retries := 0; retries < maxRetries; retries++ {
+		if retries > 0 {
+			time.Sleep(retryDelay) // wait before retrying
+		}
 
-	if err := enc.Encode(request); err != nil {
-		return fmt.Errorf("failed to marshal data: %v", err)
-	}
+		// Handling JSON marshaling with special characters unescaped.
+		buffer := &bytes.Buffer{}
+		enc := json.NewEncoder(buffer)
+		enc.SetEscapeHTML(false)
 
-	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
-	if err != nil {
-		return fmt.Errorf("error creating POST request: %v", err)
-	}
-	req.Header.Set("Content-Type", "application/json")
+		if err := enc.Encode(request); err != nil {
+			return fmt.Errorf("failed to marshal data: %v", err)
+		}
 
-	resp, err := http.DefaultClient.Do(req)
-	if err != nil {
-		return fmt.Errorf("POST predict: %v", err)
-	}
-	defer resp.Body.Close()
+		endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
+		req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
+		if err != nil {
+			return fmt.Errorf("error creating POST request: %v", err)
+		}
+		req.Header.Set("Content-Type", "application/json")
 
-	if resp.StatusCode >= 400 {
-		bodyBytes, err := io.ReadAll(resp.Body)
+		resp, err := http.DefaultClient.Do(req)
 		if err != nil {
-			return fmt.Errorf("failed reading llm error response: %w", err)
+			return fmt.Errorf("POST predict: %v", err)
 		}
-		log.Printf("llm predict error: %s", bodyBytes)
-		return fmt.Errorf("%s", bodyBytes)
-	}
+		defer resp.Body.Close()
 
-	scanner := bufio.NewScanner(resp.Body)
-	// increase the buffer size to avoid running out of space
-	buf := make([]byte, 0, maxBufferSize)
-	scanner.Buffer(buf, maxBufferSize)
-	for scanner.Scan() {
-		select {
-		case <-ctx.Done():
-			// This handles the request cancellation
-			return ctx.Err()
-		default:
-			line := scanner.Bytes()
-			if len(line) == 0 {
-				continue
+		if resp.StatusCode >= 400 {
+			bodyBytes, err := io.ReadAll(resp.Body)
+			if err != nil {
+				return fmt.Errorf("failed reading llm error response: %w", err)
 			}
+			log.Printf("llm predict error: %s", bodyBytes)
+			return fmt.Errorf("%s", bodyBytes)
+		}
 
-			evt, ok := bytes.CutPrefix(line, []byte("data: "))
-			if !ok {
-				return fmt.Errorf("error parsing llm response stream: %s", line)
-			}
+		scanner := bufio.NewScanner(resp.Body)
+		// increase the buffer size to avoid running out of space
+		buf := make([]byte, 0, maxBufferSize)
+		scanner.Buffer(buf, maxBufferSize)
 
-			var p prediction
-			if err := json.Unmarshal(evt, &p); err != nil {
-				return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
-			}
+		retryNeeded := false
+		for scanner.Scan() {
+			select {
+			case <-ctx.Done():
+				// This handles the request cancellation
+				return ctx.Err()
+			default:
+				line := scanner.Bytes()
+				if len(line) == 0 {
+					continue
+				}
 
-			if p.Content != "" {
-				fn(PredictResult{
-					CreatedAt: time.Now().UTC(),
-					Content:   p.Content,
-				})
-			}
+				if isRetryable(line) {
+					retryNeeded = true
+					break
+				}
 
-			if p.Stop {
-				fn(PredictResult{
-					CreatedAt:     time.Now().UTC(),
-					TotalDuration: time.Since(predict.CheckpointStart),
-
-					Done:               true,
-					PromptEvalCount:    p.Timings.PromptN,
-					PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
-					EvalCount:          p.Timings.PredictedN,
-					EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
-				})
-				return nil
+				evt, ok := bytes.CutPrefix(line, []byte("data: "))
+				if !ok {
+					return fmt.Errorf("error parsing llm response stream: %s", line)
+				}
+
+				var p prediction
+				if err := json.Unmarshal(evt, &p); err != nil {
+					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
+				}
+
+				if p.Content != "" {
+					fn(PredictResult{
+						CreatedAt: time.Now().UTC(),
+						Content:   p.Content,
+					})
+				}
+
+				if p.Stop {
+					fn(PredictResult{
+						CreatedAt:     time.Now().UTC(),
+						TotalDuration: time.Since(predict.CheckpointStart),
+
+						Done:               true,
+						PromptEvalCount:    p.Timings.PromptN,
+						PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
+						EvalCount:          p.Timings.PredictedN,
+						EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
+					})
+					return nil
+				}
 			}
 		}
-	}
 
-	if err := scanner.Err(); err != nil {
-		if strings.Contains(err.Error(), "unexpected EOF") {
-			// this means the llama runner subprocess crashed
-			llm.Close()
-			if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
-				return fmt.Errorf("llama runner exited: %v", llm.StatusWriter.LastErrMsg)
+		if err := scanner.Err(); err != nil {
+			if strings.Contains(err.Error(), "unexpected EOF") {
+				// this means the llama runner subprocess crashed
+				llm.Close()
+				if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
+					return fmt.Errorf("llama runner exited: %v", llm.StatusWriter.LastErrMsg)
+				}
+				return fmt.Errorf("llama runner exited, you may not have enough available memory to run this model")
 			}
-			return fmt.Errorf("llama runner exited, you may not have enough available memory to run this model")
-		}
-		return fmt.Errorf("error reading llm response: %v", err)
+			return fmt.Errorf("error reading llm response: %v", err)
+		}
+
+		if !retryNeeded {
+			return nil // success
+		}
 	}
 
-	return nil
+	// only reached when every attempt hit a retryable error
+	return fmt.Errorf("max retries exceeded")
 }
 
 type TokenizeRequest struct {
|