0003-embeddings.patch

From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Mon, 16 Sep 2024 15:53:14 -0700
Subject: [PATCH] embeddings

---
 src/llama-context.cpp | 2 +-
 src/llama.cpp         | 6 ++++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 671d2a81..47e79ed4 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -479,7 +479,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
     const auto n_embd  = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
-    const bool has_logits = !cparams.embeddings;
+    const bool has_logits =  cparams.causal_attn;
     const bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
 
     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
diff --git a/src/llama.cpp b/src/llama.cpp
index 607f2786..ac85bfed 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8652,7 +8652,6 @@ static int llama_decode_impl(
             res  = nullptr;
             embd = nullptr;
         } else if (cparams.embeddings) {
-            res  = nullptr; // do not extract logits for embedding case
             embd = nullptr;
             for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
                 if (strcmp(ggml_graph_node(gf, i)->name, "result_embd_pooled") == 0) {
@@ -8660,12 +8659,15 @@ static int llama_decode_impl(
                     break;
                 }
             }
-            GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
         } else {
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
 
+        if (!cparams.causal_attn) {
+            res = nullptr; // do not extract logits when not needed
+        }
+
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
         ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
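
Notes: the two hunks implement one change, decoupling logits extraction from embedding extraction. llama_output_reserve now sizes the logits buffer from cparams.causal_attn instead of !cparams.embeddings, and llama_decode_impl no longer nulls the logits tensor (res) merely because embeddings were requested; it only drops it when attention is non-causal. Removing the GGML_ASSERT likewise tolerates a decode with embeddings enabled but no "result_embd_pooled" node in the graph. The net effect is that a causal model with embeddings enabled can return both logits and embeddings from a single decode.

A minimal sketch of the resulting API behavior, assuming the llama.cpp C API around the revision this patch targets. check_outputs is a hypothetical helper; ctx is assumed to be a context created with embeddings enabled on a causal model, and i indexes a token that was marked for output in the decoded batch:

#include "llama.h"
#include <cstdio>

// Hypothetical helper: inspect both outputs after a successful llama_decode().
static void check_outputs(struct llama_context * ctx, int32_t i) {
    // Pre-patch, cparams.embeddings == true forced res to nullptr, so logits
    // were never extracted. Post-patch, logits track causal_attn and
    // embeddings track cparams.embeddings, so both can be non-null here.
    float * logits = llama_get_logits_ith(ctx, i);     // buffer reserved when cparams.causal_attn
    float * embd   = llama_get_embeddings_ith(ctx, i); // reserved when embeddings && pooling == NONE
    std::printf("logits=%p embeddings=%p\n", (void *) logits, (void *) embd);
}

Conversely, switching the context to the non-causal path with llama_set_causal_attn(ctx, false) before decoding now skips logit extraction entirely, which is what the new !cparams.causal_attn guard in llama_decode_impl implements.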