@@ -1650,26 +1650,41 @@ struct llama_server_context
             }
             slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
 
+            char buf[256];
+            llama_model_meta_val_str(model, "general.architecture", buf, 256);
+            bool gemma2 = strcmp(buf, "gemma2") == 0;
+
+            int32_t truncate_at = slot.n_ctx;
+
+            // truncate at 2/3 of the context length for gemma2 models
+            // as they do not support context shifts (from the sliding window implementation).
+            // this way, prompts that almost fit the context length can still generate a full
+            // response without a sudden stop from hitting the context limit
+            if (gemma2) {
+                truncate_at = 2 * slot.n_ctx / 3;
+            }
+
             // if input prompt is too big, truncate it, if group attention self-extend is disabled
-            if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx)
+            if (slot.ga_n == 1 && slot.n_prompt_tokens >= truncate_at)
             {
                 const int n_left = slot.n_ctx - slot.params.n_keep;
-                const int n_block_size = n_left / 2;
-                const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+                const int n_shift = n_left / 2;
+                const int n_erase = slot.n_prompt_tokens - slot.params.n_keep - n_shift;
 
                 std::vector<llama_token> new_tokens(
                     prompt_tokens.begin(),
                     prompt_tokens.begin() + slot.params.n_keep);
                 new_tokens.insert(
                     new_tokens.end(),
-                    prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
+                    prompt_tokens.begin() + slot.params.n_keep + n_erase,
                     prompt_tokens.end());
 
-                LOG_VERBOSE("input truncated", {
-                    {"n_ctx",      slot.n_ctx},
-                    {"n_keep",     slot.params.n_keep},
-                    {"n_left",     n_left},
-                    {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+                LOG_INFO("input truncated", {
+                    {"n_ctx",   slot.n_ctx},
+                    {"n_keep",  slot.params.n_keep},
+                    {"n_left",  n_left},
+                    {"n_shift", n_shift},
+                    {"n_erase", n_erase},
                 });
                 slot.truncated = true;
                 prompt_tokens = new_tokens;
@@ -1678,6 +1693,19 @@ struct llama_server_context
                 GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
             }
 
+            // Models with sliding window attention do not work with context shifts, so
+            // limit their prediction to the context length
+            if (gemma2) {
+                int32_t limit = slot.n_ctx - slot.n_prompt_tokens;
+                slot.n_predict = limit;
+                slot.params.n_predict = limit;
+ LOG_INFO("model does not support sliding window, limiting generation", {
+                    {"n_ctx",           slot.n_ctx},
+                    {"n_prompt_tokens", slot.n_prompt_tokens},
+                    {"n_predict",       slot.n_predict}
+                });
+            }
+
             if (!slot.params.cache_prompt)
             {
                 llama_sampling_reset(slot.ctx_sampling);
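
For reference, below is a minimal standalone sketch of the truncation arithmetic the patch introduces: the 2/3-of-context cutoff for gemma2 and the keep/shift/erase split of the prompt. The token counts are hypothetical and the program is illustrative only; it is not part of the patch or of server.cpp.

#include <cstdint>
#include <cstdio>

int main() {
    const int32_t n_ctx    = 8192;  // hypothetical slot context size
    const int32_t n_keep   = 256;   // tokens always kept at the start of the prompt
    const int32_t n_prompt = 9000;  // hypothetical prompt length
    const bool    gemma2   = true;  // pretend the loaded model reports "gemma2"

    // gemma2 cannot context-shift, so truncate once the prompt reaches 2/3 of the context,
    // leaving room to generate a response without hitting the context limit
    const int32_t truncate_at = gemma2 ? 2 * n_ctx / 3 : n_ctx;

    if (n_prompt >= truncate_at) {
        const int n_left  = n_ctx - n_keep;               // room left after the kept prefix
        const int n_shift = n_left / 2;                   // how much of the prompt tail survives
        const int n_erase = n_prompt - n_keep - n_shift;  // tokens dropped from the middle

        // resulting prompt = kept prefix + surviving tail
        const int n_new = n_keep + (n_prompt - n_keep - n_erase);
        printf("n_left=%d n_shift=%d n_erase=%d -> %d prompt tokens\n",
               n_left, n_shift, n_erase, n_new);
        // prints: n_left=7936 n_shift=3968 n_erase=4776 -> 4224 prompt tokens
    }
    return 0;
}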