diff --git a/src/llama.cpp b/src/llama.cpp
index 1fe2b9f7..a43312a7 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -13689,7 +13689,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
     const auto n_embd  = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
-    const bool has_logits = !cparams.embeddings;
+    const bool has_logits = cparams.causal_attn || lctx.image_embeds;
     const bool has_embd   = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE));
 
     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
@@ -13959,17 +13959,25 @@ static int llama_decode_internal(
             // no output
             res  = nullptr;
             embd = nullptr;
-        } else if (cparams.embeddings) {
-            res = nullptr; // do not extract logits for embedding case
-            embd = gf->nodes[gf->n_nodes - 1];
-            if (strcmp(embd->name, "result_embd_pooled") != 0) {
-                embd = gf->nodes[gf->n_nodes - 2];
+        }
+
+        if (cparams.embeddings) {
+            for (int i = gf->n_nodes - 1; i >= 0; --i) {
+                embd = gf->nodes[i];
+                if (strcmp(embd->name, "result_embd_pooled") == 0) {
+                    break;
+                }
             }
             GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
-        } else {
+        } else {
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
+
+        if (!cparams.causal_attn && !has_image_embeds) {
+            res = nullptr; // do not extract logits when not needed
+        }
+
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
         ggml_backend_sched_alloc_graph(lctx.sched, gf);
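
For readers who want the gist of the second hunk in isolation: instead of assuming the pooled-embedding tensor is one of the last two graph nodes, the patch walks the node list back to front until it finds a tensor named result_embd_pooled. Below is a minimal standalone C++ sketch of that reverse scan; the Tensor and Graph types here are hypothetical stand-ins for ggml's ggml_tensor and ggml_cgraph, and find_node_by_name is an illustrative helper, not part of llama.cpp.

#include <cstdio>
#include <cstring>

struct Tensor {
    const char * name;
};

struct Graph {
    Tensor * nodes[8];
    int      n_nodes;
};

// Walk the graph back to front and return the first tensor whose name
// matches `target`, or nullptr if no node matches. This mirrors the loop in
// the patch, which no longer assumes the pooled embeddings sit at
// nodes[n_nodes - 1] or nodes[n_nodes - 2].
static Tensor * find_node_by_name(Graph * gf, const char * target) {
    for (int i = gf->n_nodes - 1; i >= 0; --i) {
        if (std::strcmp(gf->nodes[i]->name, target) == 0) {
            return gf->nodes[i];
        }
    }
    return nullptr;
}

int main() {
    Tensor logits = { "result_output" };
    Tensor pooled = { "result_embd_pooled" };
    Tensor extra  = { "img_embd_scale" }; // hypothetical node appended after pooling

    // With an extra node appended after the pooled embeddings, a fixed-offset
    // lookup at the end of the graph would miss `pooled`; the reverse scan
    // still finds it wherever it sits.
    Graph gf = { { &logits, &pooled, &extra }, 3 };

    Tensor * embd = find_node_by_name(&gf, "result_embd_pooled");
    std::printf("found: %s\n", embd ? embd->name : "(none)");
    return 0;
}

One difference from the patch, for clarity of the sketch: on a miss this helper returns nullptr, whereas the patched loop leaves embd pointing at the last node it inspected and relies on the GGML_ASSERT immediately afterwards to catch the missing tensor. The search order itself is the same.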