Ver código fonte

Fix clearing kv cache between requests with the same prompt (#2186)

* Fix clearing kv cache between requests with the same prompt

* fix powershell script
Jeffrey Morgan 1 ano atrás
pai
commit
a64570dcae

+ 1 - 0
llm/dyn_ext_server.go

@@ -190,6 +190,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
 		"seed":              predict.Options.Seed,
 		"stop":              predict.Options.Stop,
 		"image_data":        imageData,
+		"cache_prompt":      true,
 	}
 
 	if predict.Format == "json" {

+ 11 - 0
llm/generate/gen_common.sh

@@ -61,6 +61,17 @@ apply_patches() {
     if ! grep ollama ${LLAMACPP_DIR}/examples/server/CMakeLists.txt; then
         echo 'include (../../../ext_server/CMakeLists.txt) # ollama' >>${LLAMACPP_DIR}/examples/server/CMakeLists.txt
     fi
+
+    # apply temporary patches until fix is upstream
+    for patch in ../patches/*.diff; do
+        for file in $(grep "^+++ " ${patch} | cut -f2 -d' ' | cut -f2- -d/); do
+            (cd ${LLAMACPP_DIR}; git checkout ${file})
+        done
+    done
+    for patch in ../patches/*.diff; do
+        (cd ${LLAMACPP_DIR} && git apply ${patch})
+    done
+
     # Avoid duplicate main symbols when we link into the cgo binary
     sed -e 's/int main(/int __main(/g' <${LLAMACPP_DIR}/examples/server/server.cpp >${LLAMACPP_DIR}/examples/server/server.cpp.tmp &&
         mv ${LLAMACPP_DIR}/examples/server/server.cpp.tmp ${LLAMACPP_DIR}/examples/server/server.cpp

+ 23 - 0
llm/generate/gen_windows.ps1

@@ -40,6 +40,29 @@ function apply_patches {
     if (!(Select-String -Path "${script:llamacppDir}/examples/server/CMakeLists.txt" -Pattern 'ollama')) {
         Add-Content -Path "${script:llamacppDir}/examples/server/CMakeLists.txt" -Value 'include (../../../ext_server/CMakeLists.txt) # ollama'
     }
+
+    # Apply temporary patches until fix is upstream
+    $patches = Get-ChildItem "../patches/*.diff"
+    foreach ($patch in $patches) {
+        # Extract file paths from the patch file
+        $filePaths = Get-Content $patch.FullName | Where-Object { $_ -match '^\+\+\+ ' } | ForEach-Object {
+            $parts = $_ -split ' '
+            ($parts[1] -split '/', 2)[1]
+        }
+
+        # Checkout each file
+        foreach ($file in $filePaths) {
+            Set-Location -Path ${script:llamacppDir}
+            git checkout $file
+        }
+    }
+
+    # Apply each patch
+    foreach ($patch in $patches) {
+        Set-Location -Path ${script:llamacppDir}
+        git apply $patch.FullName
+    }
+
     # Avoid duplicate main symbols when we link into the cgo binary
     $content = Get-Content -Path "${script:llamacppDir}/examples/server/server.cpp"
     $content = $content -replace 'int main\(', 'int __main('

+ 30 - 0
llm/patches/01-cache.diff

@@ -0,0 +1,30 @@
+diff --git a/examples/server/server.cpp b/examples/server/server.cpp
+index 0462fbd2..4fa7b57f 100644
+--- a/examples/server/server.cpp
++++ b/examples/server/server.cpp
+@@ -1857,12 +1857,6 @@ struct llama_server_context
+                         LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
+                     }
+ 
+-                    LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past);
+-
+-                    llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1);
+-
+-                    slot.cache_tokens = prompt_tokens;
+-
+                     if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0)
+                     {
+                         // we have to evaluate at least 1 token to generate logits.
+@@ -1870,6 +1864,12 @@ struct llama_server_context
+                         slot.n_past--;
+                     }
+ 
++                    LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past);
++
++                    llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1);
++
++                    slot.cache_tokens = prompt_tokens;
++
+                     LOG_VERBOSE("prompt ingested", {
+                                                     {"n_past", slot.n_past},
+                                                     {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},