瀏覽代碼

Fix parallel requests

Daniel Hiltgen 9 月之前
父節點
當前提交
f97ee8c506
共有 3 個文件被更改,包括 10 次插入17 次删除
  1. 7 15
      llama/runner/runner.go
  2. 1 1
      llm/generate/gen_linux.sh
  3. 2 1
      llm/generate/gen_windows.ps1

+ 7 - 15
llama/runner/runner.go

@@ -23,6 +23,9 @@ type Sequence struct {
 	// number of tokens evaluated
 	nPast int
 
+	// batch index
+	iBatch int
+
 	// number of tokens predicted so far
 	numPredicted int
 
@@ -122,6 +125,7 @@ func (s *Server) allNil() bool {
 }
 
 func (s *Server) run(ctx context.Context) {
+	// TODO - should this be n_ctx / parallel like the old server.cpp setup?
 	batch := llama.NewBatch(s.batchSize, 0, s.parallel)
 	defer batch.Free()
 
@@ -141,8 +145,6 @@ func (s *Server) run(ctx context.Context) {
 			}
 			s.mu.Unlock()
 
-			// prepare the batch
-			ibatch := make([]int, s.parallel)
 			for i, seq := range s.seqs {
 				if seq == nil {
 					continue
@@ -164,14 +166,10 @@ func (s *Server) run(ctx context.Context) {
 					if j > s.batchSize {
 						break
 					}
-
 					batch.Add(t, seq.nPast, []int{i}, !seq.prompt())
 					seq.nPast++
-
-					if seq.prompt() {
-						ibatch[i] = batch.NumTokens() + 1
-					}
 				}
+				seq.iBatch = batch.NumTokens() - 1
 			}
 
 			err := s.lc.Decode(batch)
@@ -186,12 +184,6 @@ func (s *Server) run(ctx context.Context) {
 
 				// don't sample prompt processing
 				if seq.prompt() {
-					if len(seq.tokens) < s.batchSize {
-						seq.tokens = []int{}
-					} else {
-						seq.tokens = seq.tokens[s.batchSize:]
-					}
-
 					continue
 				}
 
@@ -199,7 +191,7 @@ func (s *Server) run(ctx context.Context) {
 				if seq.embeddingOnly {
 					embd := s.lc.GetEmbeddingsSeq(i)
 					if embd == nil {
-						embd = s.lc.GetEmbeddingsIth(ibatch[i])
+						embd = s.lc.GetEmbeddingsIth(seq.iBatch)
 					}
 
 					seq.embedding <- embd
@@ -212,7 +204,7 @@ func (s *Server) run(ctx context.Context) {
 				// sample a token
 				// logits := s.lc.GetLogitsIth(ibatch[i])
 				// token := s.lc.SampleTokenGreedy(logits)
-				token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
+				token := seq.samplingCtx.Sample(s.lc, nil, seq.iBatch)
 
 				seq.samplingCtx.Accept(s.lc, token, true)
 				piece := s.model.TokenToPiece(token)

+ 1 - 1
llm/generate/gen_linux.sh

@@ -65,7 +65,7 @@ if [ -z "${OLLAMA_SKIP_STATIC_GENERATE}" -o "${OLLAMA_CPU_TARGET}" = "static" ];
     # Static build for linking into the Go binary
     init_vars
     CMAKE_TARGETS="--target llama --target ggml"
-    CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+    CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off -DGGML_OPENMP=off ${CMAKE_DEFS}"
     BUILD_DIR="../build/linux/${ARCH}_static"
     echo "Building static library"
     build

+ 2 - 1
llm/generate/gen_windows.ps1

@@ -200,7 +200,8 @@ function build_static() {
             "-DLLAMA_AVX2=off",
             "-DLLAMA_AVX512=off",
             "-DLLAMA_F16C=off",
-            "-DLLAMA_FMA=off")
+            "-DLLAMA_FMA=off",
+            "-DGGML_OPENMP=off")
         $script:buildDir="../build/windows/${script:ARCH}_static"
         write-host "Building static library"
         build