Browse Source

Fix `no slots available` error with concurrent requests (#4160)

Jeffrey Morgan 1 year ago
parent
commit
ed740a2504
1 changed files with 115 additions and 112 deletions
  1. 115 112
      llm/server.go

+ 115 - 112
llm/server.go

@@ -338,7 +338,7 @@ type ServerStatus int
 
 
 const ( // iota is reset to 0
 const ( // iota is reset to 0
 	ServerStatusReady ServerStatus = iota
 	ServerStatusReady ServerStatus = iota
-	ServerStatusNoSlotsAvaialble
+	ServerStatusNoSlotsAvailable
 	ServerStatusLoadingModel
 	ServerStatusLoadingModel
 	ServerStatusNotResponding
 	ServerStatusNotResponding
 	ServerStatusError
 	ServerStatusError
@@ -348,7 +348,7 @@ func (s ServerStatus) ToString() string {
 	switch s {
 	switch s {
 	case ServerStatusReady:
 	case ServerStatusReady:
 		return "llm server ready"
 		return "llm server ready"
-	case ServerStatusNoSlotsAvaialble:
+	case ServerStatusNoSlotsAvailable:
 		return "llm busy - no slots available"
 		return "llm busy - no slots available"
 	case ServerStatusLoadingModel:
 	case ServerStatusLoadingModel:
 		return "llm server loading model"
 		return "llm server loading model"
@@ -405,7 +405,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 	case "ok":
 	case "ok":
 		return ServerStatusReady, nil
 		return ServerStatusReady, nil
 	case "no slot available":
 	case "no slot available":
-		return ServerStatusNoSlotsAvaialble, nil
+		return ServerStatusNoSlotsAvailable, nil
 	case "loading model":
 	case "loading model":
 		return ServerStatusLoadingModel, nil
 		return ServerStatusLoadingModel, nil
 	default:
 	default:
@@ -413,6 +413,29 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 	}
 	}
 }
 }
 
 
+// getServerStatusRetry will retry if ServerStatusNoSlotsAvailable is received
+func (s *llmServer) getServerStatusRetry(ctx context.Context) (ServerStatus, error) {
+	var retries int
+	for {
+		status, err := s.getServerStatus(ctx)
+		if err != nil {
+			return status, err
+		}
+
+		if status == ServerStatusNoSlotsAvailable {
+			if retries >= 10 {
+				return status, fmt.Errorf("no slots available after %d retries", retries)
+			}
+
+			time.Sleep(5 * time.Millisecond)
+			retries++
+			continue
+		}
+
+		return status, nil
+	}
+}
+
 func (s *llmServer) Ping(ctx context.Context) error {
 func (s *llmServer) Ping(ctx context.Context) error {
 	_, err := s.getServerStatus(ctx)
 	_, err := s.getServerStatus(ctx)
 	if err != nil {
 	if err != nil {
@@ -510,7 +533,6 @@ ws ::= ([ \t\n] ws)?
 `
 `
 
 
 const maxBufferSize = 512 * format.KiloByte
 const maxBufferSize = 512 * format.KiloByte
-const maxRetries = 3
 
 
 type ImageData struct {
 type ImageData struct {
 	Data []byte `json:"data"`
 	Data []byte `json:"data"`
@@ -586,7 +608,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	}
 	}
 
 
 	// Make sure the server is ready
 	// Make sure the server is ready
-	status, err := s.getServerStatus(ctx)
+	status, err := s.getServerStatusRetry(ctx)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	} else if status != ServerStatusReady {
 	} else if status != ServerStatusReady {
@@ -600,133 +622,113 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 		}
 		}
 	}
 	}
 
 
-	retryDelay := 100 * time.Microsecond
-	for retries := 0; retries < maxRetries; retries++ {
-		if retries > 0 {
-			time.Sleep(retryDelay) // wait before retrying
-			retryDelay *= 2        // exponential backoff
-		}
+	// Handling JSON marshaling with special characters unescaped.
+	buffer := &bytes.Buffer{}
+	enc := json.NewEncoder(buffer)
+	enc.SetEscapeHTML(false)
 
 
-		// Handling JSON marshaling with special characters unescaped.
-		buffer := &bytes.Buffer{}
-		enc := json.NewEncoder(buffer)
-		enc.SetEscapeHTML(false)
+	if err := enc.Encode(request); err != nil {
+		return fmt.Errorf("failed to marshal data: %v", err)
+	}
 
 
-		if err := enc.Encode(request); err != nil {
-			return fmt.Errorf("failed to marshal data: %v", err)
-		}
+	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
+	serverReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
+	if err != nil {
+		return fmt.Errorf("error creating POST request: %v", err)
+	}
+	serverReq.Header.Set("Content-Type", "application/json")
 
 
-		endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
-		req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
-		if err != nil {
-			return fmt.Errorf("error creating POST request: %v", err)
-		}
-		req.Header.Set("Content-Type", "application/json")
+	res, err := http.DefaultClient.Do(serverReq)
+	if err != nil {
+		return fmt.Errorf("POST predict: %v", err)
+	}
+	defer res.Body.Close()
 
 
-		resp, err := http.DefaultClient.Do(req)
+	if res.StatusCode >= 400 {
+		bodyBytes, err := io.ReadAll(res.Body)
 		if err != nil {
 		if err != nil {
-			return fmt.Errorf("POST predict: %v", err)
-		}
-		defer resp.Body.Close()
-
-		if resp.StatusCode >= 400 {
-			bodyBytes, err := io.ReadAll(resp.Body)
-			if err != nil {
-				return fmt.Errorf("failed reading llm error response: %w", err)
-			}
-			log.Printf("llm predict error: %s", bodyBytes)
-			return fmt.Errorf("%s", bodyBytes)
+			return fmt.Errorf("failed reading llm error response: %w", err)
 		}
 		}
+		log.Printf("llm predict error: %s", bodyBytes)
+		return fmt.Errorf("%s", bodyBytes)
+	}
 
 
-		scanner := bufio.NewScanner(resp.Body)
-		buf := make([]byte, 0, maxBufferSize)
-		scanner.Buffer(buf, maxBufferSize)
-
-		retryNeeded := false
-		// keep track of the last token generated, this is used to abort if the model starts looping
-		var lastToken string
-		var tokenRepeat int
+	scanner := bufio.NewScanner(res.Body)
+	buf := make([]byte, 0, maxBufferSize)
+	scanner.Buffer(buf, maxBufferSize)
 
 
-		for scanner.Scan() {
-			select {
-			case <-ctx.Done():
-				// This handles the request cancellation
-				return ctx.Err()
-			default:
-				line := scanner.Bytes()
-				if len(line) == 0 {
-					continue
-				}
+	// keep track of the last token generated, this is used to abort if the model starts looping
+	var lastToken string
+	var tokenRepeat int
 
 
-				// try again on slot unavailable
-				if bytes.Contains(line, []byte("slot unavailable")) {
-					retryNeeded = true
-					break
-				}
+	for scanner.Scan() {
+		select {
+		case <-ctx.Done():
+			// This handles the request cancellation
+			return ctx.Err()
+		default:
+			line := scanner.Bytes()
+			if len(line) == 0 {
+				continue
+			}
 
 
-				evt, ok := bytes.CutPrefix(line, []byte("data: "))
-				if !ok {
-					return fmt.Errorf("error parsing llm response stream: %s", line)
-				}
+			evt, ok := bytes.CutPrefix(line, []byte("data: "))
+			if !ok {
+				return fmt.Errorf("error parsing llm response stream: %s", line)
+			}
 
 
-				var c completion
-				if err := json.Unmarshal(evt, &c); err != nil {
-					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
-				}
+			var c completion
+			if err := json.Unmarshal(evt, &c); err != nil {
+				return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
+			}
 
 
-				switch {
-				case strings.TrimSpace(c.Content) == lastToken:
-					tokenRepeat++
-				default:
-					lastToken = strings.TrimSpace(c.Content)
-					tokenRepeat = 0
-				}
+			switch {
+			case strings.TrimSpace(c.Content) == lastToken:
+				tokenRepeat++
+			default:
+				lastToken = strings.TrimSpace(c.Content)
+				tokenRepeat = 0
+			}
 
 
-				// 30 picked as an arbitrary max token repeat limit, modify as needed
-				if tokenRepeat > 30 {
-					slog.Debug("prediction aborted, token repeat limit reached")
-					return ctx.Err()
-				}
+			// 30 picked as an arbitrary max token repeat limit, modify as needed
+			if tokenRepeat > 30 {
+				slog.Debug("prediction aborted, token repeat limit reached")
+				return ctx.Err()
+			}
 
 
-				if c.Content != "" {
-					fn(CompletionResponse{
-						Content: c.Content,
-					})
-				}
+			if c.Content != "" {
+				fn(CompletionResponse{
+					Content: c.Content,
+				})
+			}
 
 
-				if c.Stop {
-					fn(CompletionResponse{
-						Done:               true,
-						PromptEvalCount:    c.Timings.PromptN,
-						PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
-						EvalCount:          c.Timings.PredictedN,
-						EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
-					})
-					return nil
-				}
+			if c.Stop {
+				fn(CompletionResponse{
+					Done:               true,
+					PromptEvalCount:    c.Timings.PromptN,
+					PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
+					EvalCount:          c.Timings.PredictedN,
+					EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
+				})
+				return nil
 			}
 			}
 		}
 		}
+	}
 
 
-		if err := scanner.Err(); err != nil {
-			if strings.Contains(err.Error(), "unexpected EOF") {
-				s.Close()
-				msg := ""
-				if s.status != nil && s.status.LastErrMsg != "" {
-					msg = s.status.LastErrMsg
-				}
-
-				return fmt.Errorf("an unknown error was encountered while running the model %s", msg)
+	if err := scanner.Err(); err != nil {
+		if strings.Contains(err.Error(), "unexpected EOF") {
+			s.Close()
+			msg := ""
+			if s.status != nil && s.status.LastErrMsg != "" {
+				msg = s.status.LastErrMsg
 			}
 			}
-			return fmt.Errorf("error reading llm response: %v", err)
+			return fmt.Errorf("an unknown error was encountered while running the model %s", msg)
 		}
 		}
 
 
-		if !retryNeeded {
-			return nil // success
-		}
+		return fmt.Errorf("error reading llm response: %v", err)
 	}
 	}
 
 
-	// should never reach here ideally
-	return fmt.Errorf("max retries exceeded")
+	return nil
 }
 }
 
 
 type EmbeddingRequest struct {
 type EmbeddingRequest struct {
@@ -743,8 +745,9 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 		return nil, err
 		return nil, err
 	}
 	}
 	defer s.sem.Release(1)
 	defer s.sem.Release(1)
+
 	// Make sure the server is ready
 	// Make sure the server is ready
-	status, err := s.getServerStatus(ctx)
+	status, err := s.getServerStatusRetry(ctx)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	} else if status != ServerStatusReady {
 	} else if status != ServerStatusReady {
@@ -799,7 +802,7 @@ func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error)
 	status, err := s.getServerStatus(ctx)
 	status, err := s.getServerStatus(ctx)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
-	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble {
+	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 	}
 
 
@@ -851,7 +854,7 @@ func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error
 	status, err := s.getServerStatus(ctx)
 	status, err := s.getServerStatus(ctx)
 	if err != nil {
 	if err != nil {
 		return "", err
 		return "", err
-	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble {
+	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
 		return "", fmt.Errorf("unexpected server status: %s", status.ToString())
 		return "", fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 	}