06-embeddings.diff 2.0 KB

diff --git a/src/llama.cpp b/src/llama.cpp
index 1fe2b9f7..a43312a7 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -13689,7 +13689,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
     const auto n_embd  = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
-    const bool has_logits = !cparams.embeddings;
+    const bool has_logits = cparams.causal_attn || lctx.image_embeds;
     const bool has_embd   =  lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE));
 
     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
@@ -13959,17 +13959,25 @@ static int llama_decode_internal(
             // no output
             res  = nullptr;
             embd = nullptr;
-        } else if (cparams.embeddings) {
-            res = nullptr; // do not extract logits for embedding case
-            embd = gf->nodes[gf->n_nodes - 1];
-            if (strcmp(embd->name, "result_embd_pooled") != 0) {
-                embd = gf->nodes[gf->n_nodes - 2];
+        }
+
+        if (cparams.embeddings) {
+            for (int i = gf->n_nodes - 1; i >= 0; --i) {
+                embd = gf->nodes[i];
+                if (strcmp(embd->name, "result_embd_pooled") == 0) {
+                    break;
+                }
             }
             GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
-        } else {
+        } else {
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
+
+        if (!cparams.causal_attn && !has_image_embeds) {
+            res = nullptr; // do not extract logits when not needed
+        }
+
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
         ggml_backend_sched_alloc_graph(lctx.sched, gf);
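
Note on usage (not part of the patch file itself): with this change, a context that has embeddings enabled but still runs causal attention keeps its logits buffer, and the pooled embeddings tensor is found by scanning the graph backwards for "result_embd_pooled" instead of assuming it is one of the last two nodes. The sketch below is a minimal illustration of the caller-side effect, using the public llama.cpp C API of this period; the helper name read_outputs is made up for illustration, and the image-embed members referenced by the patch (lctx.image_embeds, has_image_embeds) come from other patches in this series and are not shown here.

// Minimal sketch, assuming a llama_context created with
// context params .embeddings = true on a causal-attention model, a pooling
// type other than NONE, and a batch whose last token was marked for logits.
#include "llama.h"
#include <cstdio>

static void read_outputs(llama_context * ctx, const llama_batch & batch) {
    if (llama_decode(ctx, batch) != 0) {
        fprintf(stderr, "llama_decode failed\n");
        return;
    }

    // Pooled embedding for sequence 0 (returns nullptr when pooling is NONE).
    const float * embd = llama_get_embeddings_seq(ctx, 0);
    if (embd != nullptr) {
        printf("embd[0] = %f\n", embd[0]);
    }

    // Logits for the last token in the batch; before this patch they were
    // not reserved or extracted whenever cparams.embeddings was set.
    const float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
    if (logits != nullptr) {
        printf("logits[0] = %f\n", logits[0]);
    }
}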