Browse Source

retry on concurrent request failure (#1483)

- remove parallel
Bruce MacDonald 1 year ago
parent
commit
c0960e29b5
1 changed files with 98 additions and 77 deletions
  1. 98 77
      llm/llama.go

+ 98 - 77
llm/llama.go

@@ -412,10 +412,6 @@ func newLlama(model string, adapters, projectors []string, runners []ModelRunner
 		port := rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
 		port := rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
 		params := append(params, "--port", strconv.Itoa(port))
 		params := append(params, "--port", strconv.Itoa(port))
 
 
-		if runner.Type == "gguf" {
-			params = append(params, "--parallel", "2")
-		}
-
 		ctx, cancel := context.WithCancel(context.Background())
 		ctx, cancel := context.WithCancel(context.Background())
 		cmd := exec.CommandContext(
 		cmd := exec.CommandContext(
 			ctx,
 			ctx,
@@ -549,6 +545,8 @@ type prediction struct {
 }
 }
 
 
 const maxBufferSize = 512 * format.KiloByte
 const maxBufferSize = 512 * format.KiloByte
+const maxRetries = 3
+const retryDelay = 1 * time.Second
 
 
 type PredictOpts struct {
 type PredictOpts struct {
 	Prompt           string
 	Prompt           string
@@ -570,6 +568,11 @@ type PredictResult struct {
 	EvalDuration       time.Duration
 	EvalDuration       time.Duration
 }
 }
 
 
+// IsRetryable checks if the line matches a condition that can be retried
+func isRetryable(line []byte) bool {
+	return bytes.Contains(line, []byte("slot unavailable"))
+}
+
 func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
 func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
 	imageData := llm.ImageData
 	imageData := llm.ImageData
 	if len(predict.Images) > 0 {
 	if len(predict.Images) > 0 {
@@ -607,98 +610,116 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
 		request["grammar"] = jsonGrammar
 		request["grammar"] = jsonGrammar
 	}
 	}
 
 
-	// Handling JSON marshaling with special characters unescaped.
-	buffer := &bytes.Buffer{}
-	enc := json.NewEncoder(buffer)
-	enc.SetEscapeHTML(false)
+	for retries := 0; retries < maxRetries; retries++ {
+		if retries > 0 {
+			time.Sleep(retryDelay) // wait before retrying
+		}
 
 
-	if err := enc.Encode(request); err != nil {
-		return fmt.Errorf("failed to marshal data: %v", err)
-	}
+		// Handling JSON marshaling with special characters unescaped.
+		buffer := &bytes.Buffer{}
+		enc := json.NewEncoder(buffer)
+		enc.SetEscapeHTML(false)
 
 
-	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
-	if err != nil {
-		return fmt.Errorf("error creating POST request: %v", err)
-	}
-	req.Header.Set("Content-Type", "application/json")
+		if err := enc.Encode(request); err != nil {
+			return fmt.Errorf("failed to marshal data: %v", err)
+		}
 
 
-	resp, err := http.DefaultClient.Do(req)
-	if err != nil {
-		return fmt.Errorf("POST predict: %v", err)
-	}
-	defer resp.Body.Close()
+		endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
+		req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
+		if err != nil {
+			return fmt.Errorf("error creating POST request: %v", err)
+		}
+		req.Header.Set("Content-Type", "application/json")
 
 
-	if resp.StatusCode >= 400 {
-		bodyBytes, err := io.ReadAll(resp.Body)
+		resp, err := http.DefaultClient.Do(req)
 		if err != nil {
 		if err != nil {
-			return fmt.Errorf("failed reading llm error response: %w", err)
+			return fmt.Errorf("POST predict: %v", err)
 		}
 		}
-		log.Printf("llm predict error: %s", bodyBytes)
-		return fmt.Errorf("%s", bodyBytes)
-	}
+		defer resp.Body.Close()
 
 
-	scanner := bufio.NewScanner(resp.Body)
-	// increase the buffer size to avoid running out of space
-	buf := make([]byte, 0, maxBufferSize)
-	scanner.Buffer(buf, maxBufferSize)
-	for scanner.Scan() {
-		select {
-		case <-ctx.Done():
-			// This handles the request cancellation
-			return ctx.Err()
-		default:
-			line := scanner.Bytes()
-			if len(line) == 0 {
-				continue
+		if resp.StatusCode >= 400 {
+			bodyBytes, err := io.ReadAll(resp.Body)
+			if err != nil {
+				return fmt.Errorf("failed reading llm error response: %w", err)
 			}
 			}
+			log.Printf("llm predict error: %s", bodyBytes)
+			return fmt.Errorf("%s", bodyBytes)
+		}
 
 
-			evt, ok := bytes.CutPrefix(line, []byte("data: "))
-			if !ok {
-				return fmt.Errorf("error parsing llm response stream: %s", line)
-			}
+		scanner := bufio.NewScanner(resp.Body)
+		// increase the buffer size to avoid running out of space
+		buf := make([]byte, 0, maxBufferSize)
+		scanner.Buffer(buf, maxBufferSize)
 
 
-			var p prediction
-			if err := json.Unmarshal(evt, &p); err != nil {
-				return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
-			}
+		retryNeeded := false
+		for scanner.Scan() {
+			select {
+			case <-ctx.Done():
+				// This handles the request cancellation
+				return ctx.Err()
+			default:
+				line := scanner.Bytes()
+				if len(line) == 0 {
+					continue
+				}
 
 
-			if p.Content != "" {
-				fn(PredictResult{
-					CreatedAt: time.Now().UTC(),
-					Content:   p.Content,
-				})
-			}
+				if isRetryable(line) {
+					retryNeeded = true
+					break
+				}
 
 
-			if p.Stop {
-				fn(PredictResult{
-					CreatedAt:     time.Now().UTC(),
-					TotalDuration: time.Since(predict.CheckpointStart),
-
-					Done:               true,
-					PromptEvalCount:    p.Timings.PromptN,
-					PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
-					EvalCount:          p.Timings.PredictedN,
-					EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
-				})
-				return nil
+				evt, ok := bytes.CutPrefix(line, []byte("data: "))
+				if !ok {
+					return fmt.Errorf("error parsing llm response stream: %s", line)
+				}
+
+				var p prediction
+				if err := json.Unmarshal(evt, &p); err != nil {
+					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
+				}
+
+				if p.Content != "" {
+					fn(PredictResult{
+						CreatedAt: time.Now().UTC(),
+						Content:   p.Content,
+					})
+				}
+
+				if p.Stop {
+					fn(PredictResult{
+						CreatedAt:     time.Now().UTC(),
+						TotalDuration: time.Since(predict.CheckpointStart),
+
+						Done:               true,
+						PromptEvalCount:    p.Timings.PromptN,
+						PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
+						EvalCount:          p.Timings.PredictedN,
+						EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
+					})
+					return nil
+				}
 			}
 			}
 		}
 		}
-	}
 
 
-	if err := scanner.Err(); err != nil {
-		if strings.Contains(err.Error(), "unexpected EOF") {
-			// this means the llama runner subprocess crashed
-			llm.Close()
-			if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
-				return fmt.Errorf("llama runner exited: %v", llm.StatusWriter.LastErrMsg)
+		if err := scanner.Err(); err != nil {
+			if strings.Contains(err.Error(), "unexpected EOF") {
+				// this means the llama runner subprocess crashed
+				llm.Close()
+				if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
+					return fmt.Errorf("llama runner exited: %v", llm.StatusWriter.LastErrMsg)
+				}
+				return fmt.Errorf("llama runner exited, you may not have enough available memory to run this model")
 			}
 			}
-			return fmt.Errorf("llama runner exited, you may not have enough available memory to run this model")
+			return fmt.Errorf("error reading llm response: %v", err)
+		}
+
+		if !retryNeeded {
+			return nil // success
 		}
 		}
-		return fmt.Errorf("error reading llm response: %v", err)
 	}
 	}
 
 
-	return nil
+	// should never reach here ideally
+	return fmt.Errorf("max retries exceeded")
 }
 }
 
 
 type TokenizeRequest struct {
 type TokenizeRequest struct {