
Fix assert on small embedding inputs (#5491)

* Fix assert on small embedding inputs

* Update llm/patches/09-pooling.diff
Jeffrey Morgan, 10 months ago
commit e9188e971a
1 changed file with 60 additions and 0 deletions

llm/patches/09-pooling.diff (+60, −0)

@@ -0,0 +1,60 @@
+diff --git a/llama.cpp b/llama.cpp
+index 61948751..61fe7b57 100644
+--- a/llama.cpp
++++ b/llama.cpp
+@@ -7591,14 +7591,14 @@ struct llm_build_context {
+     }
+ 
+     struct ggml_tensor * build_inp_mean() {
+-        lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
++        lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, cparams.n_seq_max);
+         cb(lctx.inp_mean, "inp_mean", -1);
+         ggml_set_input(lctx.inp_mean);
+         return lctx.inp_mean;
+     }
+ 
+     struct ggml_tensor * build_inp_cls() {
+-        lctx.inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
++        lctx.inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_seq_max);
+         cb(lctx.inp_cls, "inp_cls", -1);
+         ggml_set_input(lctx.inp_cls);
+         return lctx.inp_cls;
+@@ -12062,19 +12062,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
+         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
+ 
+         float * data = (float *) lctx.inp_mean->data;
+-        memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
++        memset(lctx.inp_mean->data, 0, n_tokens * cparams.n_seq_max * ggml_element_size(lctx.inp_mean));
+ 
+         std::vector<uint64_t> sum(n_tokens, 0);
+         for (int i = 0; i < n_tokens; ++i) {
+             const llama_seq_id seq_id = batch.seq_id[i][0];
+-
+-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
+-
+             sum[seq_id] += 1;
+         }
+ 
+-        std::vector<float> div(n_tokens, 0.0f);
+-        for (int i = 0; i < n_tokens; ++i) {
++        std::vector<float> div(cparams.n_seq_max, 0.0f);
++        for (uint32_t i = 0; i < cparams.n_seq_max; ++i) {
+             const uint64_t s = sum[i];
+             if (s > 0) {
+                 div[i] = 1.0f/float(s);
+@@ -12094,14 +12091,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
+         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
+ 
+         uint32_t * data = (uint32_t *) lctx.inp_cls->data;
+-        memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
++        memset(lctx.inp_cls->data, 0, cparams.n_seq_max * ggml_element_size(lctx.inp_cls));
+ 
+         for (int i = 0; i < n_tokens; ++i) {
+             const llama_seq_id seq_id = batch.seq_id[i][0];
+             const llama_pos    pos    = batch.pos[i];
+-
+-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
+-
+             if (pos == 0) {
+                 data[seq_id] = i;
+             }
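
Why the resize fixes the assert: the pooling buffers were sized by n_tokens, but a token's seq_id ranges over the sequences the context was created with, i.e. up to cparams.n_seq_max - 1. With a short embedding prompt (small n_tokens) scheduled on a higher-numbered sequence slot, seq_id >= n_tokens, and the old GGML_ASSERT fired before the buffer could be indexed out of bounds. Sizing inp_mean as n_tokens x n_seq_max and inp_cls as n_seq_max entries makes every valid seq_id addressable, so the asserts can go. The standalone C++ sketch below reproduces the buffer-filling logic under a simplified, illustrative batch; the vectors and main() here are stand-ins, not llama.cpp API, and it sizes the sum counter by n_seq_max (the patch keeps it at n_tokens) so the example stays in bounds:

// Standalone sketch of the pooling-input construction after this patch.
// The batch layout below is illustrative; the real code fills
// lctx.inp_mean->data and lctx.inp_cls->data this way in llama_set_inputs().
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // A "small embedding input": 2 tokens, but the context was created with
    // n_seq_max = 4, so a token may carry seq_id == 3 (>= n_tokens).
    const int      n_tokens  = 2;
    const uint32_t n_seq_max = 4;
    const std::vector<int32_t> seq_id = {3, 3}; // both tokens in sequence 3
    const std::vector<int32_t> pos    = {0, 1}; // positions within the sequence

    // inp_mean is now n_tokens x n_seq_max; as n_tokens x n_tokens,
    // seq_id == 3 would index past the buffer, hence the removed assert.
    std::vector<float> inp_mean(n_tokens * n_seq_max, 0.0f);

    // Count tokens per sequence, then turn counts into 1/count weights
    // (sized by n_seq_max here so every valid seq_id can be counted).
    std::vector<uint64_t> sum(n_seq_max, 0);
    for (int i = 0; i < n_tokens; ++i) {
        sum[seq_id[i]] += 1;
    }
    std::vector<float> div(n_seq_max, 0.0f);
    for (uint32_t s = 0; s < n_seq_max; ++s) {
        if (sum[s] > 0) {
            div[s] = 1.0f / float(sum[s]);
        }
    }

    // Row s of inp_mean holds the averaging weights for sequence s, so
    // multiplying inp_mean by the token embeddings yields each sequence's
    // mean embedding.
    for (int i = 0; i < n_tokens; ++i) {
        inp_mean[seq_id[i] * n_tokens + i] = div[seq_id[i]];
    }

    // CLS pooling analogue: inp_cls is now n_seq_max entries; the index of
    // each sequence's first token (pos == 0) is recorded at slot seq_id.
    std::vector<uint32_t> inp_cls(n_seq_max, 0);
    for (int i = 0; i < n_tokens; ++i) {
        if (pos[i] == 0) {
            inp_cls[seq_id[i]] = (uint32_t) i;
        }
    }

    for (uint32_t s = 0; s < n_seq_max; ++s) {
        for (int i = 0; i < n_tokens; ++i) {
            printf("%.2f ", inp_mean[s * n_tokens + i]);
        }
        printf("\n");
    }
    return 0;
}

Built with any C++11 compiler, the sketch prints a 4 x 2 weight matrix whose rows 0-2 are zero and whose row 3 is 0.50 0.50: the per-sequence averaging weights that the old n_tokens x n_tokens buffer could not even address for this batch.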