浏览代码

runner: Release semaphore and improve error messages on failures

If we have an error after creating a new sequence but before
finding a slot for it, we return without releasing the semaphore.
This reduces our parallel sequences and eventually leads to deadlock.

In practice this should never happen because once we have acquired
the semaphore, we should always be able to find a slot. However, the
code is clearly not correct.
Jesse Gross 1 月之前
父节点
当前提交
97e569475e
共有 2 个文件被更改,包括 9 次插入和 3 次删除
  1. + 6 - 2
      runner/llamarunner/runner.go
  2. + 3 - 1
      runner/ollamarunner/runner.go

+ 6 - 2
runner/llamarunner/runner.go

@@ -599,7 +599,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		if errors.Is(err, context.Canceled) {
 		if errors.Is(err, context.Canceled) {
 			slog.Info("aborting completion request due to client closing the connection")
 			slog.Info("aborting completion request due to client closing the connection")
 		} else {
 		} else {
-			slog.Error("Failed to acquire semaphore", "error", err)
+			http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
 		}
 		}
 		return
 		return
 	}
 	}
@@ -611,6 +611,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
 			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
 			if err != nil {
 			if err != nil {
 				s.mu.Unlock()
 				s.mu.Unlock()
+				s.seqsSem.Release(1)
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
 				return
 				return
 			}
 			}
@@ -626,6 +627,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	s.mu.Unlock()
 	s.mu.Unlock()
 
 
 	if !found {
 	if !found {
+		s.seqsSem.Release(1)
 		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
 		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
 		return
 		return
 	}
 	}
@@ -691,7 +693,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 		if errors.Is(err, context.Canceled) {
 		if errors.Is(err, context.Canceled) {
 			slog.Info("aborting embeddings request due to client closing the connection")
 			slog.Info("aborting embeddings request due to client closing the connection")
 		} else {
 		} else {
-			slog.Error("Failed to acquire semaphore", "error", err)
+			http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
 		}
 		}
 		return
 		return
 	}
 	}
@@ -703,6 +705,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
 			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
 			if err != nil {
 			if err != nil {
 				s.mu.Unlock()
 				s.mu.Unlock()
+				s.seqsSem.Release(1)
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
 				return
 				return
 			}
 			}
@@ -715,6 +718,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	s.mu.Unlock()
 	s.mu.Unlock()
 
 
 	if !found {
 	if !found {
+		s.seqsSem.Release(1)
 		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
 		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
 		return
 		return
 	}
 	}

+ 3 - 1
runner/ollamarunner/runner.go

@@ -588,7 +588,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		if errors.Is(err, context.Canceled) {
 		if errors.Is(err, context.Canceled) {
 			slog.Info("aborting completion request due to client closing the connection")
 			slog.Info("aborting completion request due to client closing the connection")
 		} else {
 		} else {
-			slog.Error("Failed to acquire semaphore", "error", err)
+			http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
 		}
 		}
 		return
 		return
 	}
 	}
@@ -600,6 +600,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
 			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
 			if err != nil {
 			if err != nil {
 				s.mu.Unlock()
 				s.mu.Unlock()
+				s.seqsSem.Release(1)
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
 				return
 				return
 			}
 			}
@@ -613,6 +614,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	s.mu.Unlock()
 	s.mu.Unlock()
 
 
 	if !found {
 	if !found {
+		s.seqsSem.Release(1)
 		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
 		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
 		return
 		return
 	}
 	}