Bläddra i källkod

runner.go: Enforce NUM_PARALLEL directly in the runner

NUM_PARALLEL is currently enforced by the Ollama server process - it
will only issue requests to the runner if the maximum number of
concurrent requests has not been exceeded. Although this should
be sufficient, it is good for the runner to protect its own data
structures. Currently, if too many requests get through to the
runner, they will just get stuck and never return.

This may help with reports of Ollama hanging, though it is unclear
how it would actually occur.

Bug #7573
Jesse Gross 5 månader sedan
förälder
incheckning
17b386a891
1 ändrade filer med 47 tillägg och 20 borttagningar
  1. 47 20
      llama/runner/runner.go

+ 47 - 20
llama/runner/runner.go

@@ -20,6 +20,8 @@ import (
 	"time"
 	"time"
 	"unicode/utf8"
 	"unicode/utf8"
 
 
+	"golang.org/x/sync/semaphore"
+
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llama"
 	"github.com/ollama/ollama/llama"
 )
 )
@@ -203,38 +205,51 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
 }
 }
 
 
 type Server struct {
 type Server struct {
+	// is the server ready to process requests?
+	// protects access to model and image
+	ready sync.WaitGroup
+
+	// loaded model
 	model *llama.Model
 	model *llama.Model
-	lc    *llama.Context
 
 
-	// required for image embeddings
+	// image model context for multi-modal models
 	image *ImageContext
 	image *ImageContext
 
 
+	// status for external health reporting - loading, ready to serve, etc.
+	status ServerStatus
+
+	// current progress on loading the model
+	progress float32
+
+	// number of simultaneous requests to handle
+	parallel int
+
+	// maximum number of elements in a batch (per sequence)
 	// TODO (jmorganca): make this n_batch
 	// TODO (jmorganca): make this n_batch
 	batchSize int
 	batchSize int
 
 
-	// parallel is the number of parallel requests to handle
-	parallel int
+	// protects access to everything below this line
+	// this is context state needed for decoding
+	mu sync.Mutex
+
+	// indicates that data is ready for processing
+	cond *sync.Cond
+
+	// decoding state
+	lc *llama.Context
 
 
-	// seqs is the list of parallel sequences being evaluated
-	// TODO (jmorganca): this can probably be moved into run()
+	// the list of simultaneous sequences being evaluated
 	seqs []*Sequence
 	seqs []*Sequence
 
 
+	// seqs can have a maximum of parallel entries, which
+	// is enforced by seqsSem
+	seqsSem *semaphore.Weighted
+
 	// KV cache
 	// KV cache
 	cache *InputCache
 	cache *InputCache
 
 
 	// next sequence for prompt processing to avoid starvation
 	// next sequence for prompt processing to avoid starvation
 	nextSeq int
 	nextSeq int
-
-	// is the server ready to process requests?
-	ready sync.WaitGroup
-
-	mu sync.Mutex
-
-	cond *sync.Cond
-
-	progress float32
-
-	status ServerStatus
 }
 }
 
 
 func (s *Server) allNil() bool {
 func (s *Server) allNil() bool {
@@ -616,8 +631,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 		return
 	}
 	}
 
 
-	// TODO (jmorganca): add to sequence queue instead of
-	// failing if a slot isn't available
+	// Ensure that a place to put the sequence is available
+	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return
+	}
+	defer s.seqsSem.Release(1)
+
 	s.mu.Lock()
 	s.mu.Lock()
 	for i, sq := range s.seqs {
 	for i, sq := range s.seqs {
 		if sq == nil {
 		if sq == nil {
@@ -700,7 +720,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 		return
 		return
 	}
 	}
 
 
-	// TODO (jessegross): Wait for a free slot instead of failing and blocking forever
+	// Ensure that a place to put the sequence is available
+	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return
+	}
+	defer s.seqsSem.Release(1)
+
 	s.mu.Lock()
 	s.mu.Lock()
 	for i, sq := range s.seqs {
 	for i, sq := range s.seqs {
 		if sq == nil {
 		if sq == nil {
@@ -855,6 +881,7 @@ func main() {
 		batchSize: *batchSize,
 		batchSize: *batchSize,
 		parallel:  *parallel,
 		parallel:  *parallel,
 		seqs:      make([]*Sequence, *parallel),
 		seqs:      make([]*Sequence, *parallel),
+		seqsSem:   semaphore.NewWeighted(int64(*parallel)),
 		status:    ServerStatusLoadingModel,
 		status:    ServerStatusLoadingModel,
 	}
 	}