
llm: add server entrypoint for mllama

jmorganca · 7 months ago
commit d0c8ce5ea4
2 changed files with 28 additions and 74 deletions
  1. llm/ext_server/server.cpp  +12 -0
  2. llm/patches/0009-mllama.patch  +16 -74

+ 12 - 0
llm/ext_server/server.cpp

@@ -1032,6 +1032,18 @@ struct llama_server_context
 
     bool process_images(server_slot &slot) const
     {
+        // Set cross attention state for mllama models
+        // TODO (jmorganca): this should be provided via the API
+        // TODO (jmorganca): generalize this beyond mllama models
+        char arch_str[256];
+        llama_model_meta_val_str(model, "general.architecture", arch_str, 256);
+        if (strcmp(arch_str, "mllama") == 0) {
+            // TODO (jmorganca): this should be passed in via the llama_decode api
+            // or similar, maybe using the llama_batch struct
+            // llama_reset_cross_attn_state(ctx);
+            // llama_set_cross_attn_state(ctx, (float*)cross_attn_state);
+        }
+
         for (slot_image &img : slot.images)
         {
             if (!img.request_encode_image)

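The block added to process_images() keys off the model's "general.architecture" GGUF metadata to detect mllama checkpoints before any cross attention state is wired up. As a rough sketch (not part of this commit; the helper names are illustrative), the same check can be factored into a small helper around llama_model_meta_val_str(), which returns a negative value when the key is absent:

    // Rough sketch (not from this commit): the architecture check from
    // process_images(), factored into a helper. Helper names are illustrative.
    #include "llama.h"

    #include <string>

    // Read "general.architecture" from the loaded model's GGUF metadata.
    // llama_model_meta_val_str() returns a negative value if the key is missing.
    static std::string model_architecture(const llama_model * model) {
        char buf[256] = {0};
        const int32_t n = llama_model_meta_val_str(model, "general.architecture", buf, sizeof(buf));
        if (n < 0) {
            return "";
        }
        return std::string(buf);
    }

    static bool is_mllama(const llama_model * model) {
        return model_architecture(model) == "mllama";
    }

The hunk above performs the same check inline with a fixed 256-byte buffer and strcmp, leaving the actual cross attention calls commented out until the API is finalized.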
+ 16 - 74
llm/patches/0009-mllama.patch

@@ -1,4 +1,4 @@
-From c2db1ad0fc86de189959b628021a970511e9c6f9 Mon Sep 17 00:00:00 2001
+From 9935fbbf26ad4d9ca7735ec6ba4c0a206c0c8329 Mon Sep 17 00:00:00 2001
 From: jmorganca <jmorganca@gmail.com>
 Date: Tue, 24 Sep 2024 11:53:40 -0700
 Subject: [PATCH] add mllama support
@@ -13,8 +13,8 @@ kv cache once per run
 remaining is to implement the cross attention mask
 ---
  include/llama.h |   5 +
- src/llama.cpp   | 514 ++++++++++++++++++++++++++++++++++++++++++++++--
- 2 files changed, 499 insertions(+), 20 deletions(-)
+ src/llama.cpp   | 470 ++++++++++++++++++++++++++++++++++++++++++++++--
+ 2 files changed, 461 insertions(+), 14 deletions(-)
 
 diff --git a/include/llama.h b/include/llama.h
 index bfc37e88..94ce82a4 100644
@@ -33,7 +33,7 @@ index bfc37e88..94ce82a4 100644
      LLAMA_API void llama_free(struct llama_context * ctx);
  
 diff --git a/src/llama.cpp b/src/llama.cpp
-index b7771f53..75bbc226 100644
+index b7771f53..72a57a38 100644
 --- a/src/llama.cpp
 +++ b/src/llama.cpp
 @@ -170,6 +170,7 @@ static std::string format(const char * fmt, ...) {
@@ -193,25 +193,6 @@ index b7771f53..75bbc226 100644
  };
  
  // very similar to llama_batch,
-@@ -2684,12 +2749,12 @@ struct llama_ubatch {
-     uint32_t n_seq_tokens; // tokens per sequence
-     uint32_t n_seqs;
- 
--    llama_token  *  token;    // [n_tokens]
--    float        *  embd;     // [n_embd, n_tokens]
--    llama_pos    *  pos;      // [n_tokens]
--    int32_t      *  n_seq_id; // [n_seqs]
--    llama_seq_id ** seq_id;   // [n_seqs]
--    int8_t       *  output;   // [n_tokens]
-+    llama_token  *  token;             // [n_tokens]
-+    float        *  embd;              // [n_embd, n_tokens]
-+    llama_pos    *  pos;               // [n_tokens]
-+    int32_t      *  n_seq_id;          // [n_seqs]
-+    llama_seq_id ** seq_id;            // [n_seqs]
-+    int8_t       *  output;            // [n_tokens]
- };
- 
- struct llama_kv_cell {
 @@ -3268,6 +3333,10 @@ struct llama_context {
      // host buffer for the model output (logits and embeddings)
      ggml_backend_buffer_t buf_output = nullptr;
@@ -404,48 +385,7 @@ index b7771f53..75bbc226 100644
  
      // note: storing RoPE-ed version of K in the KV cache
      ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
-@@ -9625,6 +9788,40 @@ static struct ggml_tensor * llm_build_rwkv6_channel_mix(
-     return ggml_mul(ctx, r, llm_build_lora_mm(lctx, ctx, layer->channel_mix_value, k));
- }
- 
-+
-+static void show_tensor(std::string name, ggml_tensor *t) {
-+    LLAMA_LOG_INFO("%s [%lld, %lld]\n", name.c_str(), t->ne[0], t->ne[1]);
-+
-+    int cols = int(t->ne[0]);
-+    int rows = int(t->ne[1]);
-+
-+    for(int r=0; r<3; r++) {
-+        for(int c=0; c<3; c++) {
-+            float v = ggml_get_f32_nd(t, c, r, 0, 0);
-+            LLAMA_LOG_INFO("%11.8f ", v);
-+        }
-+        LLAMA_LOG_INFO("... ");
-+        for(int c=0; c<3; c++) {
-+            float v = ggml_get_f32_nd(t, cols-3+c, r, 0, 0);
-+            LLAMA_LOG_INFO("%11.8f ", v);
-+        }
-+        LLAMA_LOG_INFO("\n");
-+    }
-+    LLAMA_LOG_INFO(" ...\n");
-+    for(int r=0; r<3; r++) {
-+        for(int c=0; c<3; c++) {
-+            float v = ggml_get_f32_nd(t, c, rows-3+r, 0, 0);
-+            LLAMA_LOG_INFO("%11.8f ", v);
-+        }
-+        LLAMA_LOG_INFO("... ");
-+        for(int c=0; c<3; c++) {
-+            float v = ggml_get_f32_nd(t, cols-3+c, rows-3+r, 0, 0);
-+            LLAMA_LOG_INFO("%11.8f ", v);
-+        }
-+        LLAMA_LOG_INFO("\n");
-+    }
-+}
-+
- struct llm_build_context {
-     const llama_model    & model;
-           llama_context  & lctx;
-@@ -9743,6 +9940,7 @@ struct llm_build_context {
+@@ -9743,6 +9906,7 @@ struct llm_build_context {
          lctx.inp_pos_bucket    = nullptr;
          lctx.inp_embd_enc      = nullptr;
          lctx.inp_KQ_mask_cross = nullptr;
@@ -453,7 +393,7 @@ index b7771f53..75bbc226 100644
      }
  
      void free() {
-@@ -10158,6 +10356,253 @@ struct llm_build_context {
+@@ -10158,6 +10322,253 @@ struct llm_build_context {
                  LLM_NORM_RMS, cb, -1);
          cb(cur, "result_norm", -1);
  
@@ -707,7 +647,7 @@ index b7771f53..75bbc226 100644
          // lm_head
          cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
          cb(cur, "result_output", -1);
-@@ -15493,6 +15938,10 @@ static struct ggml_cgraph * llama_build_graph(
+@@ -15493,6 +15904,10 @@ static struct ggml_cgraph * llama_build_graph(
              {
                  result = llm.build_llama();
              } break;
@@ -718,7 +658,7 @@ index b7771f53..75bbc226 100644
          case LLM_ARCH_BAICHUAN:
              {
                  result = llm.build_baichuan();
-@@ -15736,7 +16185,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
+@@ -15736,7 +16151,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
  
      if (batch.token) {
          const int64_t n_tokens = batch.n_tokens;
@@ -726,7 +666,7 @@ index b7771f53..75bbc226 100644
          ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
      }
  
-@@ -16123,6 +16571,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
+@@ -16123,6 +16537,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
              }
          }
      }
@@ -734,13 +674,15 @@ index b7771f53..75bbc226 100644
 +    // TODO (jmorganca): this might copy a lot of data on every request of a
 +    // single generation even though it doesn't change, so we should
 +    // find a way to not set this more than one time per image
-+    if (lctx.cross_attn_state && lctx.inp_cross_attn_state->buffer) {
++    if (lctx.cross_attn_state &&
++        lctx.inp_cross_attn_state &&
++        lctx.inp_cross_attn_state->buffer) {
 +        ggml_backend_tensor_set(lctx.inp_cross_attn_state, lctx.cross_attn_state, 0, hparams.n_embd * 1601 * 4 * ggml_element_size(lctx.inp_cross_attn_state));
 +    }
  }
  
  // Make sure enough space is available for outputs.
-@@ -16430,6 +16885,10 @@ static int llama_decode_internal(
+@@ -16430,6 +16853,10 @@ static int llama_decode_internal(
  
          llama_set_inputs(lctx, ubatch);
  
@@ -751,7 +693,7 @@ index b7771f53..75bbc226 100644
          llama_graph_compute(lctx, gf, n_threads, threadpool);
  
          // update the kv ring buffer
-@@ -17586,7 +18045,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
+@@ -17586,7 +18013,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
          if (llama_model_has_encoder(&model)) {
              n_attn_layer *= 3;
          }
@@ -762,7 +704,7 @@ index b7771f53..75bbc226 100644
      }
  
      size_t total_size_org = 0;
-@@ -18681,6 +19142,18 @@ struct llama_context * llama_new_context_with_model(
+@@ -18681,6 +19110,18 @@ struct llama_context * llama_new_context_with_model(
      return ctx;
  }
  
@@ -781,7 +723,7 @@ index b7771f53..75bbc226 100644
  void llama_free(struct llama_context * ctx) {
      delete ctx;
  }
-@@ -18731,6 +19204,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
+@@ -18731,6 +19172,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
  
          // use what we call a normal RoPE, operating on pairs of consecutive head values
          case LLM_ARCH_LLAMA:
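For context on how the patched API is meant to be driven: judging from the commented-out calls in server.cpp and the llama_set_inputs hunk, a caller would set the cross attention state once per image and clear it afterwards. A hedged sketch of that flow follows; the declarations added to include/llama.h are not visible in this excerpt, so the signatures are inferred and may not match the patch exactly.

    // Illustrative sketch only: signatures inferred from the commented-out
    // calls in server.cpp, not copied from the patched include/llama.h.
    #include "llama.h"

    // Hypothetical helper: attach the vision encoder's projected output as
    // cross attention state before decoding an image-bearing prompt, then
    // clear it so later text-only requests skip the cross attention path.
    static int decode_with_image(llama_context * ctx,
                                 llama_batch batch,
                                 float * cross_attn_state) {
        // Per the llama_set_inputs hunk, the buffer appears to hold
        // n_embd * 1601 * 4 floats of projected image embeddings.
        llama_set_cross_attn_state(ctx, cross_attn_state);

        // llama_set_inputs() copies the state into the graph input tensor
        // during the decode call.
        const int ret = llama_decode(ctx, batch);

        // Reset so the state is not reused by unrelated requests.
        llama_reset_cross_attn_state(ctx);

        return ret;
    }

As the TODOs in process_images() note, the longer-term plan is to pass this state through the llama_decode API (for example via llama_batch) rather than a context-level setter, so this interface should be treated as temporary.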