
fixed patches, llava

Josh Yan 8 months ago
parent
commit
e6802df906
2 changed files with 70 additions and 40 deletions
  1. llm/ext_server/server.cpp (+23 -5)
  2. llm/patches/12-paligemma.diff (+47 -35)

+ 23 - 5
llm/ext_server/server.cpp

@@ -1313,8 +1313,7 @@ struct llama_server_context
         return true;
     }
 
-    // for multiple images processing
-    bool ingest_images(server_slot &slot, int n_batch)
+    bool process_llava(server_slot &slot, int n_batch)
     {
         int image_idx = 0;
 
@@ -1391,6 +1390,20 @@ struct llama_server_context
         return true;
     }
 
+    // dispatch multi-image processing based on the model architecture
+    bool ingest_images(server_slot &slot, int n_batch)
+    {
+        switch (llama_get_architecture(model))
+        {
+        case 0:
+            return process_llava(slot, n_batch);
+        case 25:
+            return prepare_pali(slot, n_batch);
+        default:
+            return false;
+        }
+    }
+
     void request_cancel(int task_id)
     {
         task_server task;
@@ -1880,9 +1893,14 @@ struct llama_server_context
                         llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
                         slot_npast++;
                     }
-
-                    // if (has_images && !ingest_images(slot, n_batch))
-                    if (has_images && !prepare_pali(slot, n_batch))
+                    LOG_DEBUG("hi gpt params processing images", {
+                                                                     {"gpt_params.model", params.model.c_str()},
+                                                                     {"model alias", params.model_alias.c_str()},
+                                                                 });
+                    printf("gpt_params model is %s\n", params.model.c_str());
+                    printf("gpt_params model is %s\n", params.model.c_str());
+
+                    if (has_images && !ingest_images(slot, n_batch))
                     {
                         LOG_ERROR("failed processing images", {
                             {"slot_id", slot.id},

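For context: the new ingest_images wrapper keys on llama.cpp's internal llm_arch enum, exposed through the llama_get_architecture hook added in the patch below. 0 is LLM_ARCH_LLAMA, the first enumerator; 25 is assumed in this commit to be the Gemma-family entry that PaliGemma maps to. A minimal caller-side sketch of that dispatch follows; image_path and pick_image_path are illustrative names, not code from this commit:

    #include "llama.h"

    // Sketch of the architecture dispatch, assuming 0 = LLM_ARCH_LLAMA
    // (LLaVA-style CLIP projector path) and 25 = the Gemma-family entry
    // used by PaliGemma, as hard-coded in the commit above.
    enum class image_path { llava, paligemma, unsupported };

    static image_path pick_image_path(struct llama_model *model) {
        switch (llama_get_architecture(model)) {
        case 0:  return image_path::llava;       // batch CLIP embeds per image
        case 25: return image_path::paligemma;   // hand embeds to the decoder
        default: return image_path::unsupported; // reject other architectures
        }
    }
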
+ 47 - 35
llm/patches/12-paligemma.diff

@@ -1,78 +1,72 @@
 diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
-index 7cda5f10..50fbcf08 100644
+index 7cda5f10..671806fd 100644
 --- a/examples/llava/clip.cpp
 +++ b/examples/llava/clip.cpp
-@@ -709,9 +709,12 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
+@@ -708,11 +708,12 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
+         if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
              embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
              embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
- 
+-
 -            embeddings = ggml_gelu(ctx0, embeddings);
 -            embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
 -            embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
-+            // paligemma missing second linear layer
-+            if (model.mm_2_w) {
+-
++            if (model.mm_2_w)
++            {
 +                embeddings = ggml_gelu(ctx0, embeddings);
 +                embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
 +                embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
 +            }
- 
          } else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
              embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
-@@ -2076,7 +2079,10 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
+             embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
+@@ -2076,6 +2077,10 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
          return ctx->vision_model.mm_model_peg_0_b->ne[0];
      }
      if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
--        return ctx->vision_model.mm_2_b->ne[0];
-+        // paligemma missing second linear layer
-+        if (ctx->vision_model.mm_2_b == nullptr) {
++        if (ctx->vision_model.mm_2_b == nullptr)
++        {
 +            return ctx->vision_model.mm_0_b->ne[0];
 +        }
+         return ctx->vision_model.mm_2_b->ne[0];
      }
      if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
-         return ctx->vision_model.mm_3_b->ne[0];
 diff --git a/include/llama.h b/include/llama.h
-index f23355a6..7c6301bf 100644
+index f23355a6..e48da401 100644
 --- a/include/llama.h
 +++ b/include/llama.h
-@@ -444,6 +444,9 @@ extern "C" {
+@@ -444,6 +444,12 @@ extern "C" {
      // Frees all allocated memory
      LLAMA_API void llama_free(struct llama_context * ctx);
  
-+    // save image embeddings
++    // Sets image embeddings
 +    LLAMA_API void set_image_embeds(struct llama_context *ctx, float *data);
++
++    // Gets architecture
++    LLAMA_API int llama_get_architecture(struct llama_model *model);
 +
      LLAMA_API int64_t llama_time_us(void);
  
      LLAMA_API size_t llama_max_devices(void);
 diff --git a/src/llama.cpp b/src/llama.cpp
-index a7b1c9eb..b0a6bc27 100644
+index a7b1c9eb..ee067919 100644
 --- a/src/llama.cpp
 +++ b/src/llama.cpp
-@@ -2668,6 +2668,7 @@ struct llama_context {
+@@ -2710,6 +2710,8 @@ struct llama_context {
  
-     const struct llama_model & model;
+     bool logits_all = false;
  
 +    float *image_embeds = nullptr;
-     struct llama_cparams        cparams;
-     struct llama_sampling       sampling;
-     struct llama_kv_cache       kv_self;
-@@ -2751,6 +2752,10 @@ struct llama_context {
-     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
- };
- 
-+void set_image_embeds(llama_context *ctx, float *data) {
-+    ctx->image_embeds = data;
-+}
 +
- struct llama_lora_weight {
-     struct ggml_tensor * a = nullptr;
-     struct ggml_tensor * b = nullptr;
-@@ -11599,6 +11604,15 @@ struct llm_build_context {
+     // embeddings output (2-dimensional array: [n_outputs][n_embd])
+     // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
+     size_t  embd_size = 0; // capacity (of floats) for embeddings
+@@ -11599,6 +11601,15 @@ struct llm_build_context {
  
          inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
  
-+        // set the image embeddings in the input tensor
-+        if (lctx.image_embeds) {
++        if (lctx.image_embeds)
++        {
 +            struct ggml_tensor *image_embeds = ggml_dup_tensor(ctx0, inpL);
 +            image_embeds->data = lctx.image_embeds;
 +            image_embeds->ne[1] = 256;
@@ -83,12 +77,30 @@ index a7b1c9eb..b0a6bc27 100644
          inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
          cb(inpL, "inp_scaled", -1);
  
-@@ -14589,7 +14603,7 @@ static int llama_decode_internal(
+@@ -14589,7 +14600,8 @@ static int llama_decode_internal(
          }
  
          // non-causal masks do not use the KV cache
 -        if (hparams.causal_attn) {
-+        if (hparams.causal_attn || lctx.image_embeds) {
++        if (hparams.causal_attn || lctx.image_embeds)
++        {
              llama_kv_cache_update(&lctx);
  
              // if we have enough unused cells before the current head ->
+@@ -16448,6 +16460,16 @@ void llama_free_model(struct llama_model * model) {
+     delete model;
+ }
+ 
++void set_image_embeds(llama_context *ctx, float *data)
++{
++    ctx->image_embeds = data;
++}
++
++int llama_get_architecture(llama_model *model)
++{
++    return model->arch;
++}
++
+ struct llama_context * llama_new_context_with_model(
+                  struct llama_model * model,
+         struct llama_context_params   params) {
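Read end to end, the patch works in three steps: the CLIP MLP projector tolerates PaliGemma's missing second linear layer, set_image_embeds stashes the encoder output on the llama_context, and the build graph splices those embeddings (fixed at 256 image tokens via ne[1] = 256) ahead of the token embeddings, with the causal_attn || lctx.image_embeds change keeping the KV-cache update on the non-causal prefill path. Below is a hypothetical prefill helper showing how a caller might drive the two new entry points; everything other than set_image_embeds and llama_decode is an assumption for illustration:

    #include "llama.h"

    // Sketch, not committed code: prefill one image plus the text prompt.
    // image_embd is assumed to hold 256 * n_embd floats, matching the
    // hard-coded image-token count in the build graph above.
    static int prefill_with_image(struct llama_context *ctx,
                                  struct llama_batch batch,
                                  float *image_embd) {
        set_image_embeds(ctx, image_embd); // decoder splices these into inpL

        // The patched causal_attn || image_embeds check keeps the KV cache
        // updated even though this prefill is non-causal.
        int ret = llama_decode(ctx, batch);

        // Assumption: clearing the pointer stops later decode calls from
        // re-splicing the same image.
        set_image_embeds(ctx, nullptr);
        return ret;
    }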