@@ -338,7 +338,7 @@ type ServerStatus int
 
 const ( // iota is reset to 0
 	ServerStatusReady ServerStatus = iota
-	ServerStatusNoSlotsAvaialble
+	ServerStatusNoSlotsAvailable
 	ServerStatusLoadingModel
 	ServerStatusNotResponding
 	ServerStatusError
@@ -348,7 +348,7 @@ func (s ServerStatus) ToString() string {
 	switch s {
 	case ServerStatusReady:
 		return "llm server ready"
-	case ServerStatusNoSlotsAvaialble:
+	case ServerStatusNoSlotsAvailable:
		return "llm busy - no slots available"
 	case ServerStatusLoadingModel:
 		return "llm server loading model"
@@ -405,7 +405,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 	case "ok":
 		return ServerStatusReady, nil
 	case "no slot available":
-		return ServerStatusNoSlotsAvaialble, nil
+		return ServerStatusNoSlotsAvailable, nil
 	case "loading model":
 		return ServerStatusLoadingModel, nil
 	default:
@@ -413,6 +413,30 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 	}
 }
 
+// getServerStatusRetry will retry if ServerStatusNoSlotsAvailable is received
+func (s *llmServer) getServerStatusRetry(ctx context.Context) (ServerStatus, error) {
+	var retries int
+	for {
+		status, err := s.getServerStatus(ctx)
+		if err != nil {
+			return status, err
+		}
+
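+		// Server reports all slots busy: poll again shortly (10 tries, 5ms apart, ~50ms max).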
+		if status == ServerStatusNoSlotsAvailable {
+			if retries >= 10 {
+				return status, fmt.Errorf("no slots available after %d retries", retries)
+			}
+
+			time.Sleep(5 * time.Millisecond)
+			retries++
+			continue
+		}
+
+		return status, nil
+	}
+}
+
 func (s *llmServer) Ping(ctx context.Context) error {
 	_, err := s.getServerStatus(ctx)
 	if err != nil {
@@ -510,7 +533,6 @@ ws ::= ([ \t\n] ws)?
 `
 
 const maxBufferSize = 512 * format.KiloByte
-const maxRetries = 3
 
 type ImageData struct {
 	Data []byte `json:"data"`
@@ -586,7 +608,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	}
 
 	// Make sure the server is ready
-	status, err := s.getServerStatus(ctx)
+	status, err := s.getServerStatusRetry(ctx)
 	if err != nil {
 		return err
 	} else if status != ServerStatusReady {
@@ -600,133 +622,114 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 		}
 	}
 
-	retryDelay := 100 * time.Microsecond
-	for retries := 0; retries < maxRetries; retries++ {
-		if retries > 0 {
-			time.Sleep(retryDelay) // wait before retrying
-			retryDelay *= 2        // exponential backoff
-		}
+	// Handling JSON marshaling with special characters unescaped.
+	buffer := &bytes.Buffer{}
+	enc := json.NewEncoder(buffer)
+	enc.SetEscapeHTML(false)
 
-		// Handling JSON marshaling with special characters unescaped.
-		buffer := &bytes.Buffer{}
-		enc := json.NewEncoder(buffer)
-		enc.SetEscapeHTML(false)
+	if err := enc.Encode(request); err != nil {
+		return fmt.Errorf("failed to marshal data: %v", err)
+	}
 
-		if err := enc.Encode(request); err != nil {
-			return fmt.Errorf("failed to marshal data: %v", err)
-		}
+	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
+	serverReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
+	if err != nil {
+		return fmt.Errorf("error creating POST request: %v", err)
+	}
+	serverReq.Header.Set("Content-Type", "application/json")
 
-		endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
-		req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
-		if err != nil {
-			return fmt.Errorf("error creating POST request: %v", err)
-		}
-		req.Header.Set("Content-Type", "application/json")
+	res, err := http.DefaultClient.Do(serverReq)
+	if err != nil {
+		return fmt.Errorf("POST predict: %v", err)
+	}
+	defer res.Body.Close()
 
-		resp, err := http.DefaultClient.Do(req)
+	if res.StatusCode >= 400 {
+		bodyBytes, err := io.ReadAll(res.Body)
 		if err != nil {
-			return fmt.Errorf("POST predict: %v", err)
-		}
-		defer resp.Body.Close()
-
-		if resp.StatusCode >= 400 {
-			bodyBytes, err := io.ReadAll(resp.Body)
-			if err != nil {
-				return fmt.Errorf("failed reading llm error response: %w", err)
-			}
-			log.Printf("llm predict error: %s", bodyBytes)
-			return fmt.Errorf("%s", bodyBytes)
+			return fmt.Errorf("failed reading llm error response: %w", err)
 		}
+		log.Printf("llm predict error: %s", bodyBytes)
+		return fmt.Errorf("%s", bodyBytes)
+	}
 
-		scanner := bufio.NewScanner(resp.Body)
-		buf := make([]byte, 0, maxBufferSize)
-		scanner.Buffer(buf, maxBufferSize)
-
-		retryNeeded := false
-		// keep track of the last token generated, this is used to abort if the model starts looping
-		var lastToken string
-		var tokenRepeat int
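+	// Lines longer than maxBufferSize will end the scan with bufio.ErrTooLong.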
+	scanner := bufio.NewScanner(res.Body)
+	buf := make([]byte, 0, maxBufferSize)
+	scanner.Buffer(buf, maxBufferSize)
 
-		for scanner.Scan() {
-			select {
-			case <-ctx.Done():
-				// This handles the request cancellation
-				return ctx.Err()
-			default:
-				line := scanner.Bytes()
-				if len(line) == 0 {
-					continue
-				}
+	// keep track of the last token generated, this is used to abort if the model starts looping
+	var lastToken string
+	var tokenRepeat int
 
-				// try again on slot unavailable
-				if bytes.Contains(line, []byte("slot unavailable")) {
-					retryNeeded = true
-					break
-				}
+	for scanner.Scan() {
+		select {
+		case <-ctx.Done():
+			// This handles the request cancellation
+			return ctx.Err()
+		default:
+			line := scanner.Bytes()
+			if len(line) == 0 {
+				continue
+			}
 
-				evt, ok := bytes.CutPrefix(line, []byte("data: "))
-				if !ok {
-					return fmt.Errorf("error parsing llm response stream: %s", line)
-				}
+			evt, ok := bytes.CutPrefix(line, []byte("data: "))
+			if !ok {
+				return fmt.Errorf("error parsing llm response stream: %s", line)
+			}
 
-				var c completion
-				if err := json.Unmarshal(evt, &c); err != nil {
-					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
-				}
+			var c completion
+			if err := json.Unmarshal(evt, &c); err != nil {
+				return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
+			}
 
-				switch {
-				case strings.TrimSpace(c.Content) == lastToken:
-					tokenRepeat++
-				default:
-					lastToken = strings.TrimSpace(c.Content)
-					tokenRepeat = 0
-				}
+			switch {
+			case strings.TrimSpace(c.Content) == lastToken:
+				tokenRepeat++
+			default:
+				lastToken = strings.TrimSpace(c.Content)
+				tokenRepeat = 0
+			}
 
-				// 30 picked as an arbitrary max token repeat limit, modify as needed
-				if tokenRepeat > 30 {
-					slog.Debug("prediction aborted, token repeat limit reached")
-					return ctx.Err()
-				}
+			// 30 picked as an arbitrary max token repeat limit, modify as needed
+			if tokenRepeat > 30 {
+				slog.Debug("prediction aborted, token repeat limit reached")
+				return ctx.Err()
+			}
 
-				if c.Content != "" {
-					fn(CompletionResponse{
-						Content: c.Content,
-					})
-				}
+			if c.Content != "" {
+				fn(CompletionResponse{
+					Content: c.Content,
+				})
+			}
 
-				if c.Stop {
-					fn(CompletionResponse{
-						Done:               true,
-						PromptEvalCount:    c.Timings.PromptN,
-						PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
-						EvalCount:          c.Timings.PredictedN,
-						EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
-					})
-					return nil
-				}
+			if c.Stop {
+				fn(CompletionResponse{
+					Done:               true,
+					PromptEvalCount:    c.Timings.PromptN,
+					PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
+					EvalCount:          c.Timings.PredictedN,
+					EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
+				})
+				return nil
 			}
 		}
+	}
 
-		if err := scanner.Err(); err != nil {
-			if strings.Contains(err.Error(), "unexpected EOF") {
-				s.Close()
-				msg := ""
-				if s.status != nil && s.status.LastErrMsg != "" {
-					msg = s.status.LastErrMsg
-				}
-
-				return fmt.Errorf("an unknown error was encountered while running the model %s", msg)
+	if err := scanner.Err(); err != nil {
+		if strings.Contains(err.Error(), "unexpected EOF") {
+			s.Close()
+			msg := ""
+			if s.status != nil && s.status.LastErrMsg != "" {
+				msg = s.status.LastErrMsg
 			}
-			return fmt.Errorf("error reading llm response: %v", err)
+			return fmt.Errorf("an unknown error was encountered while running the model %s", msg)
 		}
 
-		if !retryNeeded {
-			return nil // success
-		}
+		return fmt.Errorf("error reading llm response: %v", err)
 	}
 
-	// should never reach here ideally
-	return fmt.Errorf("max retries exceeded")
+	return nil
 }
 
 type EmbeddingRequest struct {
@@ -743,8 +745,9 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 		return nil, err
 	}
 	defer s.sem.Release(1)
+
 	// Make sure the server is ready
-	status, err := s.getServerStatus(ctx)
+	status, err := s.getServerStatusRetry(ctx)
 	if err != nil {
 		return nil, err
 	} else if status != ServerStatusReady {
@@ -799,7 +802,7 @@ func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error)
 	status, err := s.getServerStatus(ctx)
 	if err != nil {
 		return nil, err
-	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble {
+	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
@@ -851,7 +854,7 @@ func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error
 	status, err := s.getServerStatus(ctx)
 	if err != nil {
 		return "", err
-	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble {
+	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
 		return "", fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 