瀏覽代碼

Do not shift context for sliding window models (#5368)

* Do not shift context for sliding window models

* truncate prompt > 2/3 tokens

* only target gemma2
Jeffrey Morgan 10 月之前
父節點
當前提交
1f4f46800c
共有 1 個文件被更改,包括 37 次插入9 次删除
  1. 37 9
      llm/ext_server/server.cpp

+ 37 - 9
llm/ext_server/server.cpp

@@ -1650,26 +1650,41 @@ struct llama_server_context
                     }
                     slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
 
+                    char buf[256];
+                    llama_model_meta_val_str(model, "general.architecture", buf, 256);
+                    bool gemma2 = strcmp(buf, "gemma2") == 0;
+
+                    int32_t truncate_at = slot.n_ctx;
+
+                    // truncate at 2/3 of the context length for gemma2 models
+                    // as they do not support context shifts (from the sliding window implementation).
+                    // this way, prompts that almost fit the context length can still generate a full
+                    // response without a sudden stop from hitting the context limit
+                    if (gemma2) {
+                        truncate_at = 2 * slot.n_ctx / 3;
+                    }
+
                     // if input prompt is too big, truncate it, if group attention self-extend is disabled
-                    if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx)
+                    if (slot.ga_n == 1 && slot.n_prompt_tokens >= truncate_at)
                     {
                         const int n_left = slot.n_ctx - slot.params.n_keep;
-                        const int n_block_size = n_left / 2;
-                        const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+                        const int n_shift = n_left / 2;
+                        const int n_erase = slot.n_prompt_tokens - slot.params.n_keep - n_shift;
 
                         std::vector<llama_token> new_tokens(
                             prompt_tokens.begin(),
                             prompt_tokens.begin() + slot.params.n_keep);
                         new_tokens.insert(
                             new_tokens.end(),
-                            prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
+                            prompt_tokens.begin() + slot.params.n_keep + n_erase,
                             prompt_tokens.end());
 
-                        LOG_VERBOSE("input truncated", {
-                            {"n_ctx",      slot.n_ctx},
-                            {"n_keep",     slot.params.n_keep},
-                            {"n_left",     n_left},
-                            {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+                        LOG_INFO("input truncated", {
+                            {"n_ctx",        slot.n_ctx},
+                            {"n_keep",       slot.params.n_keep},
+                            {"n_left",       n_left},
+                            {"n_shift",      n_shift},
+                            {"n_erase",      n_erase},
                         });
                         slot.truncated = true;
                         prompt_tokens = new_tokens;
@@ -1678,6 +1693,19 @@ struct llama_server_context
                         GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
                     }
 
+                    // Models with sliding window attention do not work with context shifts, so
+                    // limit their prediction to the context length
+                    if (gemma2) {
+                        int32_t limit = slot.n_ctx - slot.n_prompt_tokens;
+                        slot.n_predict = limit;
+                        slot.params.n_predict = limit;
+                        LOG_INFO("model does not support sliding window, limiting generation", {
+                            {"n_ctx", slot.n_ctx},
+                            {"n_prompt_tokens", slot.n_prompt_tokens},
+                            {"n_predict", slot.n_predict}
+                        });
+                    }
+
                     if (!slot.params.cache_prompt)
                     {
                         llama_sampling_reset(slot.ctx_sampling);