06-embeddings.diff

diff --git a/src/llama.cpp b/src/llama.cpp
index 88355971..d7db689b 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -15906,7 +15906,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
     const auto n_embd  = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
-    const bool has_logits = !cparams.embeddings;
+    const bool has_logits = cparams.causal_attn;
     const bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
 
     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
@@ -16175,20 +16175,23 @@ static int llama_decode_internal(
             // no output
             res  = nullptr;
             embd = nullptr;
-        } else if (cparams.embeddings) {
-            res = nullptr; // do not extract logits for embedding case
-            embd = nullptr;
+        }
+
+        if (cparams.embeddings) {
             for (int i = gf->n_nodes - 1; i >= 0; --i) {
-                if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
-                    embd = gf->nodes[i];
+                embd = gf->nodes[i];
+                if (strcmp(embd->name, "result_embd_pooled") == 0) {
                     break;
                 }
             }
-            GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
         } else {
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
+
+        if (!cparams.causal_attn) {
+            res = nullptr; // do not extract logits when not needed
+        }
+
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
         ggml_backend_sched_alloc_graph(lctx.sched, gf);
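Taken together, the two hunks decouple logits from the embeddings flag: the first hunk reserves the logits buffer whenever causal attention is enabled, and the second hunk only discards `res` when causal attention is off, so a single `llama_decode()` call can now yield both next-token logits and pooled embeddings. What follows is a minimal caller-side sketch, not part of the patch, assuming the public llama.h API contemporary with this diff; the model path, pooling choice, and batch construction are illustrative placeholders.

// sketch.cpp — a minimal sketch exercising both outputs after the change
// above. Assumes the llama.h API of this diff's era; "model.gguf" and the
// pooling choice are placeholders, not part of the patch.
#include "llama.h"
#include <cstdio>
#include <vector>

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("model.gguf", mparams); // placeholder path
    if (model == nullptr) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                    // request pooled embeddings ...
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN; // ... one vector per sequence
    // causal_attn stays at its default (enabled), so with this patch the
    // logits buffer is still reserved and res is still extracted.
    llama_context * ctx = llama_new_context_with_model(model, cparams);

    // a trivial single-token, single-sequence batch (4-arg signature of this era)
    std::vector<llama_token> tokens = { llama_token_bos(model) };
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size(), 0, 0);

    if (llama_decode(ctx, batch) == 0) {
        const float * logits = llama_get_logits_ith(ctx, -1);    // logits of the last output
        const float * embd   = llama_get_embeddings_seq(ctx, 0); // pooled embedding of seq 0
        if (logits != nullptr && embd != nullptr) {
            printf("one decode produced both logits and a pooled embedding\n");
        }
    }

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}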