浏览代码

llm: reserve required number of slots for embeddings (#6219)

Jeffrey Morgan 8 个月前
父节点
当前提交
de4fc29773
共有 1 个文件被更改,包括 12 次插入7 次删除
  1. 12 7
      llm/server.go

+ 12 - 7
llm/server.go

@@ -44,11 +44,12 @@ type LlamaServer interface {
 
 // llmServer is an instance of the llama.cpp server
 type llmServer struct {
-	port    int
-	cmd     *exec.Cmd
-	done    chan error // Channel to signal when the process exits
-	status  *StatusWriter
-	options api.Options
+	port        int
+	cmd         *exec.Cmd
+	done        chan error // Channel to signal when the process exits
+	status      *StatusWriter
+	options     api.Options
+	numParallel int
 
 	estimate    MemoryEstimate
 	totalLayers uint64
@@ -343,6 +344,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 			status:      NewStatusWriter(os.Stderr),
 			options:     opts,
 			estimate:    estimate,
+			numParallel: numParallel,
 			sem:         semaphore.NewWeighted(int64(numParallel)),
 			totalLayers: ggml.KV().BlockCount() + 1,
 			gpus:        gpus,
@@ -890,11 +892,14 @@ type EmbedResponse struct {
 }
 
 func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, error) {
-	if err := s.sem.Acquire(ctx, 1); err != nil {
+	// each input will use a slot, so we need to acquire the semaphore for
+	// the number of inputs up to numParallel
+	slots := int64(min(len(input), s.numParallel))
+	if err := s.sem.Acquire(ctx, slots); err != nil {
 		slog.Error("Failed to acquire semaphore", "error", err)
 		return nil, err
 	}
-	defer s.sem.Release(1)
+	defer s.sem.Release(slots)
 
 	// Make sure the server is ready
 	status, err := s.getServerStatusRetry(ctx)