|
@@ -324,7 +324,11 @@ func (s *Server) run(ctx context.Context) {
|
|
|
case <-ctx.Done():
|
|
|
return
|
|
|
default:
|
|
|
- s.processBatch(tokenBatch, embedBatch)
|
|
|
+ err := s.processBatch(tokenBatch, embedBatch)
|
|
|
+ if err != nil {
|
|
|
+ panic(err)
|
|
|
+ }
|
|
|
+
|
|
|
tokenBatch.Clear()
|
|
|
embedBatch.Clear()
|
|
|
}
|
|
@@ -338,7 +342,7 @@ func (s *Server) run(ctx context.Context) {
|
|
|
// these should instead be handled by the handlers
|
|
|
// it should only be responsible for accepting tokens or embeddings and
|
|
|
// processing batches as fast as possible
|
|
|
-func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) {
|
|
|
+func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) error {
|
|
|
s.mu.Lock()
|
|
|
for s.allNil() {
|
|
|
s.cond.Wait() // Wait until an item is added
|
|
@@ -357,14 +361,6 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
- // If an error occurred during the processing of a previous batch then we may have emptied the inputs
|
|
|
- // without adding a new one. In this case, end the sequence rather than infinite looping.
|
|
|
- if len(seq.inputs) == 0 {
|
|
|
- slog.Error("removing sequence due to no input tokens", "index", seqIdx, "cache id", seq.cache.Id)
|
|
|
- s.removeSequence(seqIdx, "error")
|
|
|
- continue
|
|
|
- }
|
|
|
-
|
|
|
// if past the num predict limit
|
|
|
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
|
|
s.removeSequence(seqIdx, "limit")
|
|
@@ -419,7 +415,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|
|
}
|
|
|
|
|
|
if batch == nil || batch.NumTokens() == 0 {
|
|
|
- return
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
s.lc.SetCrossAttention(crossAttention)
|
|
@@ -432,8 +428,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|
|
err = s.lc.Decode(batch)
|
|
|
}
|
|
|
if err != nil {
|
|
|
- slog.Error("failed to decode batch", "error", err)
|
|
|
- return
|
|
|
+ return fmt.Errorf("failed to decode batch: %w", err)
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -531,6 +526,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|
|
s.removeSequence(i, "connection")
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
// TODO (jmorganca): use structs from the api package to avoid duplication
|