0006-embeddings.patch

From 235b6d876a74cb09abe26985fa89ebe5bfc9f562 Mon Sep 17 00:00:00 2001
From: Gabe Goodhart <ghart@us.ibm.com>
Date: Thu, 19 Sep 2024 17:06:17 -0600
Subject: [PATCH] embeddings

---
 src/llama.cpp | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 1a8e0c51..e55ec3f8 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -16516,7 +16516,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
     const auto n_embd  = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
-    const bool has_logits = !cparams.embeddings;
+    const bool has_logits =  cparams.causal_attn;
     const bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
 
     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
@@ -16794,20 +16794,23 @@ static int llama_decode_internal(
             // no output
             res  = nullptr;
             embd = nullptr;
-        } else if (cparams.embeddings) {
-            res  = nullptr; // do not extract logits for embedding case
-            embd = nullptr;
+        }
+
+        if (cparams.embeddings) {
             for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
+                embd = ggml_graph_node(gf, i);
                 if (strcmp(ggml_graph_node(gf, i)->name, "result_embd_pooled") == 0) {
-                    embd = ggml_graph_node(gf, i);
                     break;
                 }
             }
-            GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
         } else {
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
+
+        if (!cparams.causal_attn) {
+            res = nullptr; // do not extract logits when not needed
+        }
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
         ggml_backend_sched_alloc_graph(lctx.sched, gf);
-- 
2.39.3 (Apple Git-146)
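
Reviewer note, not part of the patch: upstream ties logits to !cparams.embeddings, so enabling embeddings silences logits entirely. The patch keys logits on cparams.causal_attn and locates the pooled-embeddings tensor by walking the graph backwards, so a causal model can return both logits and pooled embeddings from a single decode. Below is a minimal caller sketch, assuming the llama.cpp C API of this era (llama_load_model_from_file, llama_set_causal_attn, llama_get_embeddings_seq, llama_batch_init); the prompt and batch handling are illustrative only, not taken from the patch.

// sketch: one decode yielding both next-token logits and a pooled embedding
// (assumes llama.h circa Sep 2024; exact signatures may differ by revision)
#include "llama.h"
#include <cstdio>
#include <cstring>
#include <vector>

int main(int argc, char ** argv) {
    if (argc < 2) { fprintf(stderr, "usage: %s model.gguf\n", argv[0]); return 1; }

    llama_backend_init();
    llama_model * model = llama_load_model_from_file(argv[1], llama_model_default_params());
    if (!model) { return 1; }

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                    // allocate the embeddings output
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN; // pooled per sequence
    llama_context * ctx = llama_new_context_with_model(model, cparams);

    // with this patch, logits follow causal_attn instead of being
    // dropped whenever embeddings are enabled
    llama_set_causal_attn(ctx, true);

    const char * prompt = "hello world";
    std::vector<llama_token> toks(64);
    const int n_tok = llama_tokenize(model, prompt, (int) strlen(prompt),
                                     toks.data(), (int) toks.size(),
                                     /*add_special*/ true, /*parse_special*/ false);
    if (n_tok <= 0) { return 1; }

    llama_batch batch = llama_batch_init(n_tok, 0, 1);
    for (int i = 0; i < n_tok; ++i) {
        batch.token   [i]    = toks[i];
        batch.pos     [i]    = i;
        batch.n_seq_id[i]    = 1;
        batch.seq_id  [i][0] = 0;
        batch.logits  [i]    = (i == n_tok - 1); // request output for the last token only
    }
    batch.n_tokens = n_tok;

    if (llama_decode(ctx, batch) == 0) {
        const float * logits = llama_get_logits(ctx);            // pre-patch: no logits reserved when embeddings are on
        const float * embd   = llama_get_embeddings_seq(ctx, 0); // pooled embedding for seq 0
        printf("logits: %s, pooled embedding: %s\n",
               logits ? "present" : "absent", embd ? "present" : "absent");
    }

    llama_batch_free(batch);
    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}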