Browse Source

Update sync with latest llama.cpp layout, and run against b3485

Daniel Hiltgen 9 months ago
parent
commit
41bf8d9932
100 changed files with 5278 additions and 1576 deletions
  1. 1 1
      llama/build-info.cpp
  2. 24 15
      llama/clip.cpp
  3. 1 1
      llama/clip.h
  4. 140 368
      llama/common.cpp
  5. 64 20
      llama/common.h
  6. 2219 0
      llama/ggml-aarch64.c
  7. 65 0
      llama/ggml-aarch64.h
  8. 97 46
      llama/ggml-alloc.c
  9. 1 1
      llama/ggml-alloc.h
  10. 21 9
      llama/ggml-backend-impl.h
  11. 303 163
      llama/ggml-backend.c
  12. 22 17
      llama/ggml-backend.h
  13. 37 9
      llama/ggml-common.h
  14. 195 151
      llama/ggml-cuda.cu
  15. 4 1
      llama/ggml-cuda.h
  16. 1 1
      llama/ggml-cuda/acc.cu
  17. 1 1
      llama/ggml-cuda/acc.cuh
  18. 1 79
      llama/ggml-cuda/alibi.cu
  19. 1 53
      llama/ggml-cuda/alibi.cuh
  20. 1 1
      llama/ggml-cuda/arange.cu
  21. 1 1
      llama/ggml-cuda/arange.cuh
  22. 3 2
      llama/ggml-cuda/argsort.cu
  23. 1 1
      llama/ggml-cuda/argsort.cuh
  24. 2 2
      llama/ggml-cuda/binbcast.cu
  25. 1 1
      llama/ggml-cuda/binbcast.cuh
  26. 1 1
      llama/ggml-cuda/clamp.cu
  27. 1 1
      llama/ggml-cuda/clamp.cuh
  28. 273 67
      llama/ggml-cuda/common.cuh
  29. 1 1
      llama/ggml-cuda/concat.cu
  30. 1 1
      llama/ggml-cuda/concat.cuh
  31. 113 0
      llama/ggml-cuda/conv-transpose-1d.cu
  32. 31 0
      llama/ggml-cuda/conv-transpose-1d.cuh
  33. 1 1
      llama/ggml-cuda/convert.cu
  34. 1 1
      llama/ggml-cuda/convert.cuh
  35. 3 4
      llama/ggml-cuda/cpy.cu
  36. 1 1
      llama/ggml-cuda/cpy.cuh
  37. 1 1
      llama/ggml-cuda/dequantize.cuh
  38. 1 1
      llama/ggml-cuda/diagmask.cu
  39. 1 1
      llama/ggml-cuda/diagmask.cuh
  40. 2 2
      llama/ggml-cuda/dmmv.cu
  41. 1 1
      llama/ggml-cuda/dmmv.cuh
  42. 30 70
      llama/ggml-cuda/fattn-common.cuh
  43. 3 3
      llama/ggml-cuda/fattn-tile-f16.cu
  44. 1 1
      llama/ggml-cuda/fattn-tile-f16.cuh
  45. 2 2
      llama/ggml-cuda/fattn-tile-f32.cu
  46. 1 1
      llama/ggml-cuda/fattn-tile-f32.cuh
  47. 2 2
      llama/ggml-cuda/fattn-vec-f16.cuh
  48. 2 2
      llama/ggml-cuda/fattn-vec-f32.cuh
  49. 4 4
      llama/ggml-cuda/fattn-wmma-f16.cuh
  50. 6 6
      llama/ggml-cuda/fattn.cu
  51. 1 1
      llama/ggml-cuda/fattn.cuh
  52. 2 3
      llama/ggml-cuda/getrows.cu
  53. 1 1
      llama/ggml-cuda/getrows.cuh
  54. 1 1
      llama/ggml-cuda/im2col.cu
  55. 1 1
      llama/ggml-cuda/im2col.cuh
  56. 247 0
      llama/ggml-cuda/mma.cuh
  57. 79 16
      llama/ggml-cuda/mmq.cu
  58. 1033 364
      llama/ggml-cuda/mmq.cuh
  59. 21 15
      llama/ggml-cuda/mmvq.cu
  60. 3 1
      llama/ggml-cuda/mmvq.cuh
  61. 1 1
      llama/ggml-cuda/norm.cu
  62. 1 1
      llama/ggml-cuda/norm.cuh
  63. 1 1
      llama/ggml-cuda/pad.cu
  64. 1 1
      llama/ggml-cuda/pad.cuh
  65. 1 1
      llama/ggml-cuda/pool2d.cu
  66. 1 1
      llama/ggml-cuda/pool2d.cuh
  67. 135 11
      llama/ggml-cuda/quantize.cu
  68. 22 3
      llama/ggml-cuda/quantize.cuh
  69. 3 3
      llama/ggml-cuda/rope.cu
  70. 1 1
      llama/ggml-cuda/rope.cuh
  71. 1 1
      llama/ggml-cuda/scale.cu
  72. 1 1
      llama/ggml-cuda/scale.cuh
  73. 2 1
      llama/ggml-cuda/softmax.cu
  74. 1 1
      llama/ggml-cuda/softmax.cuh
  75. 1 1
      llama/ggml-cuda/sumrows.cu
  76. 1 1
      llama/ggml-cuda/sumrows.cuh
  77. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu
  78. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu
  79. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu
  80. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu
  81. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu
  82. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu
  83. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu
  84. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu
  85. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu
  86. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu
  87. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu
  88. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu
  89. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu
  90. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu
  91. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu
  92. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu
  93. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu
  94. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu
  95. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu
  96. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu
  97. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu
  98. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu
  99. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu
  100. 1 1
      llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu

+ 1 - 1
llama/build-info.cpp

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 24 - 15
llama/clip.cpp

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -42,6 +42,10 @@
 #include "ggml-metal.h"
 #endif
 
+#ifdef GGML_USE_CANN
+#include "ggml-cann.h"
+#endif
+
 #define STB_IMAGE_IMPLEMENTATION
 #include "stb_image.h"
 
@@ -891,7 +895,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             embeddings = peg_0;
         }
         else {
-            GGML_ASSERT(false);
+            GGML_ABORT("fatal error");
         }
     }
 
@@ -1027,6 +1031,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     LOG_TEE("%s: CLIP using Metal backend\n", __func__);
 #endif
 
+#ifdef GGML_USE_CANN
+    new_clip->backend = ggml_backend_cann_init(0);
+    LOG_TEE("%s: CLIP using CANN backend\n", __func__);
+#endif
+
 
     if (!new_clip->backend) {
         new_clip->backend = ggml_backend_cpu_init();
@@ -1147,20 +1156,20 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
             }
             if (n < 32)
                 hparams.image_grid_pinpoints[n] = 0;
-        } catch (std::runtime_error & e) {
+        } catch (std::runtime_error & /*e*/) {
             hparams.image_grid_pinpoints[0]=0;
         }
 
         try {
             int idx = get_key_idx(ctx, KEY_MM_PATCH_MERGE_TYPE);
             strcpy(hparams.mm_patch_merge_type, gguf_get_val_str(ctx, idx));
-        } catch (std::runtime_error & e) {
+        } catch (std::runtime_error & /*e*/) {
             strcpy(hparams.mm_patch_merge_type, "flat");
         }
 
         try {
             hparams.image_crop_resolution = get_u32(ctx, KEY_IMAGE_CROP_RESOLUTION); // llava-1.6
-        } catch(const std::exception& e) {
+        } catch(const std::exception& /*e*/) {
             hparams.image_crop_resolution = hparams.image_size;
         }
 
@@ -1199,7 +1208,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         try {
             vision_model.class_embedding  = get_tensor(new_clip->ctx_data, TN_CLASS_EMBD);
             new_clip->has_class_embedding = true;
-        } catch (const std::exception& e) {
+        } catch (const std::exception& /*e*/) {
             new_clip->has_class_embedding = false;
         }
 
@@ -1207,7 +1216,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
             vision_model.pre_ln_w  = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "weight"));
             vision_model.pre_ln_b  = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "bias"));
             new_clip->has_pre_norm = true;
-        } catch (std::exception & e) {
+        } catch (std::exception & /*e*/) {
             new_clip->has_pre_norm = false;
         }
 
@@ -1215,21 +1224,21 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
             vision_model.post_ln_w  = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "weight"));
             vision_model.post_ln_b  = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "bias"));
             new_clip->has_post_norm = true;
-        } catch (std::exception & e) {
+        } catch (std::exception & /*e*/) {
             new_clip->has_post_norm = false;
         }
 
         try {
             vision_model.patch_bias = get_tensor(new_clip->ctx_data, TN_PATCH_BIAS);
             new_clip->has_patch_bias = true;
-        } catch (std::exception & e) {
+        } catch (std::exception & /*e*/) {
             new_clip->has_patch_bias = false;
         }
 
         try {
             vision_model.patch_embeddings    = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
             vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v"));
-        } catch(const std::exception& e) {
+        } catch(const std::exception& /*e*/) {
             LOG_TEE("%s: failed to load vision model tensors\n", __func__);
         }
 
@@ -1241,26 +1250,26 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
                 // Yi-type llava
                 vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 1, "weight"));
                 vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 1, "bias"));
-            } catch (std::runtime_error & e) {  }
+            } catch (std::runtime_error & /*e*/) { }
             try {
                 // missing in Yi-type llava
                 vision_model.mm_2_w              = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight"));
                 vision_model.mm_2_b              = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias"));
-            } catch (std::runtime_error & e) {  }
+            } catch (std::runtime_error & /*e*/) { }
             try {
                 // Yi-type llava
                 vision_model.mm_3_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 3, "weight"));
                 vision_model.mm_3_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 3, "bias"));
-            } catch (std::runtime_error & e) {  }
+            } catch (std::runtime_error & /*e*/) { }
             try {
                 // Yi-type llava
                 vision_model.mm_4_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 4, "weight"));
                 vision_model.mm_4_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 4, "bias"));
-            } catch (std::runtime_error & e) {  }
+            } catch (std::runtime_error & /*e*/) { }
             try {
                 vision_model.image_newline = get_tensor(new_clip->ctx_data, TN_IMAGE_NEWLINE);
                 // LOG_TEE("%s: image_newline tensor (llava-1.6) found\n", __func__);
-            } catch (std::runtime_error & e) {  }
+            } catch (std::runtime_error & /*e*/) { }
         } else if (new_clip->proj_type == PROJECTOR_TYPE_LDP) {
             // MobileVLM projection
             vision_model.mm_model_mlp_1_w               = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 1, "weight"));

+ 1 - 1
llama/clip.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

File diff suppressed because it is too large
+ 140 - 368
llama/common.cpp


+ 64 - 20
llama/common.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -78,6 +78,12 @@ int32_t cpu_get_num_math();
 // CLI argument parsing
 //
 
+// dimensionality reduction methods, used by cvector-generator
+enum dimre_method {
+    DIMRE_METHOD_PCA,
+    DIMRE_METHOD_MEAN,
+};
+
 struct gpt_params {
     uint32_t seed                 = LLAMA_DEFAULT_SEED; // RNG seed
 
@@ -99,7 +105,6 @@ struct gpt_params {
     int32_t n_gpu_layers_draft    =    -1; // number of layers to store in VRAM for the draft model (-1 - use default)
     int32_t main_gpu              =     0; // the GPU that is used for scratch and small tensors
     float   tensor_split[128]     =   {0}; // how split tensors should be distributed across GPUs
-    int32_t n_beams               =     0; // if non-zero then use beam search of given width.
     int32_t grp_attn_n            =     1; // group-attention factor
     int32_t grp_attn_w            =   512; // group-attention width
     int32_t n_print               =    -1; // print token count every n tokens (-1 = disabled)
@@ -120,6 +125,7 @@ struct gpt_params {
     enum llama_split_mode        split_mode        = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
     enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
     enum llama_pooling_type      pooling_type      = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
+    enum llama_attention_type    attention_type    = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
 
     // // sampling parameters
     struct llama_sampling_params sparams;
@@ -128,6 +134,7 @@ struct gpt_params {
     std::string model_draft          = ""; // draft model for speculative decoding
     std::string model_alias          = "unknown"; // model alias
     std::string model_url            = ""; // model url to download
+    std::string hf_token             = ""; // HF token
     std::string hf_repo              = ""; // HF repo
     std::string hf_file              = ""; // HF file
     std::string prompt               = "";
@@ -147,7 +154,6 @@ struct gpt_params {
 
     // TODO: avoid tuple, use struct
     std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
-    std::string lora_base  = "";                              // base model path for the lora adapter
 
     std::vector<llama_control_vector_load_info> control_vectors; // control vector with user defined scale
 
@@ -179,7 +185,6 @@ struct gpt_params {
     bool prompt_cache_all  = false; // save user input and generations to prompt cache
     bool prompt_cache_ro   = false; // open the prompt cache read-only and do not update it
 
-    bool embedding         = false; // get only sentence embedding
     bool escape            = true;  // escape "\n", "\r", "\t", "\'", "\"", and "\\"
     bool multiline_input   = false; // reverse the usage of `\`
     bool simple_io         = false; // improves compatibility with subprocesses and limited consoles
@@ -206,6 +211,12 @@ struct gpt_params {
     std::string mmproj = "";        // path to multimodal projector
     std::vector<std::string> image; // path to image file(s)
 
+    // embedding
+    bool embedding         = false; // get only sentence embedding
+    int32_t embd_normalize = 2;     // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
+    std::string embd_out   = "";    // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
+    std::string embd_sep   = "\n";  // separator of embendings
+
     // server params
     int32_t port           = 8080;         // server listens on this network port
     int32_t timeout_read   = 600;          // http read timeout in seconds
@@ -216,6 +227,7 @@ struct gpt_params {
     std::string public_path   = "";
     std::string chat_template = "";
     std::string system_prompt = "";
+    bool enable_chat_template = true;
 
     std::vector<std::string> api_keys;
 
@@ -229,6 +241,8 @@ struct gpt_params {
 
     std::string slot_save_path;
 
+    float slot_prompt_similarity = 0.5f;
+
     // batched-bench params
     bool is_pp_shared = false;
 
@@ -256,8 +270,21 @@ struct gpt_params {
 
     bool process_output = false; // collect data for the output tensor
     bool compute_ppl    = true;  // whether to compute perplexity
+
+    // cvector-generator params
+    int n_pca_batch = 100;
+    int n_pca_iterations = 1000;
+    dimre_method cvector_dimre_method = DIMRE_METHOD_PCA;
+    std::string cvector_outfile       = "control_vector.gguf";
+    std::string cvector_positive_file = "examples/cvector-generator/positive.txt";
+    std::string cvector_negative_file = "examples/cvector-generator/negative.txt";
+
+    bool spm_infill = false; // suffix/prefix/middle pattern for infill
+
+    std::string lora_outfile = "ggml-lora-merged-f16.gguf";
 };
 
+void gpt_params_handle_hf_token(gpt_params & params);
 void gpt_params_handle_model_default(gpt_params & params);
 
 bool gpt_params_parse_ex   (int argc, char ** argv, gpt_params & params);
@@ -301,6 +328,7 @@ bool fs_validate_filename(const std::string & filename);
 bool fs_create_directory_with_parents(const std::string & path);
 
 std::string fs_get_cache_directory();
+std::string fs_get_cache_file(const std::string & filename);
 
 //
 // Model utils
@@ -312,8 +340,8 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 struct llama_model_params   llama_model_params_from_gpt_params  (const gpt_params & params);
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
 
-struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params);
-struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params);
+struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params);
+struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params);
 
 // Batch utils
 
@@ -351,21 +379,13 @@ std::string llama_token_to_piece(
                        llama_token   token,
                        bool          special = true);
 
-// TODO: these should be moved in llama.h C-style API under single `llama_detokenize` function
-//       that takes into account the tokenizer type and decides how to handle the leading space
-//
-// detokenizes a vector of tokens into a string
-// should work similar to Python's `tokenizer.decode`
-// removes the leading space from the first non-BOS token
-std::string llama_detokenize_spm(
-                         llama_context * ctx,
-        const std::vector<llama_token> & tokens);
-
 // detokenizes a vector of tokens into a string
 // should work similar to Python's `tokenizer.decode`
-std::string llama_detokenize_bpe(
+// optionally renders special/control tokens
+std::string llama_detokenize(
                          llama_context * ctx,
-        const std::vector<llama_token> & tokens);
+        const std::vector<llama_token> & tokens,
+                                  bool   special = true);
 
 // Uses the value from the model metadata if possible, otherwise
 // defaults to true when model type is SPM, otherwise false.
@@ -375,9 +395,34 @@ bool llama_should_add_bos_token(const llama_model * model);
 // Chat template utils
 //
 
+// same with llama_chat_message, but uses std::string
+struct llama_chat_msg {
+    std::string role;
+    std::string content;
+};
+
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
 bool llama_chat_verify_template(const std::string & tmpl);
 
+// CPP wrapper for llama_chat_apply_template
+// If the built-in template is not supported, we default to chatml
+// If the custom "tmpl" is not supported, we throw an error
+std::string llama_chat_apply_template(const struct llama_model * model,
+        const std::string & tmpl,
+        const std::vector<llama_chat_msg> & chat,
+        bool add_ass);
+
+// Format single message, while taking into account the position of that message in chat history
+std::string llama_chat_format_single(const struct llama_model * model,
+        const std::string & tmpl,
+        const std::vector<llama_chat_msg> & past_msg,
+        const llama_chat_msg & new_msg,
+        bool add_ass);
+
+// Returns an example of formatted chat
+std::string llama_chat_format_example(const struct llama_model * model,
+        const std::string & tmpl);
+
 //
 // KV cache utils
 //
@@ -392,7 +437,7 @@ void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_siz
 // Embedding utils
 //
 
-void llama_embd_normalize(const float * inp, float * out, int n);
+void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
 
 float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n);
 
@@ -436,4 +481,3 @@ void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const cha
 void yaml_dump_non_result_info(
     FILE * stream, const gpt_params & params, const llama_context * lctx,
     const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);
-

+ 2219 - 0
llama/ggml-aarch64.c

@@ -0,0 +1,2219 @@
+/**
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
+#include "ggml-quants.h"
+#include "ggml-impl.h"
+
+#include <math.h>
+#include <string.h>
+#include <assert.h>
+#include <float.h>
+#include <stdlib.h> // for qsort
+#include <stdio.h>  // for GGML_ASSERT
+
+#include "ggml-aarch64.h"
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic ignored "-Woverlength-strings"
+#endif
+
+#define UNUSED GGML_UNUSED
+
+// Functions to create the interleaved data layout formats
+
+// interleave 4 block_q4_0s in blocks of blck_size_interleave
+// returns an interleaved block_q4_0x4
+// in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks
+// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave
+//
+// - in                  : an array of block_q4_0 pointers
+// - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of
+//                         blck_size_interleave bytes
+// - xor_mask            : the mask to convert the nibbles in block_q4_0 quants bytes
+//                         from bias offset form to pure sign form (this saves subtract
+//                         operations durin unpacking)
+//
+static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
+    block_q4_0x4 out;
+
+    for (int i = 0; i < 4; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    for (int i = 0; i < QK4_0 * 2; i++) {
+        int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave;
+        int src_id = (i % (4 * blck_size_interleave)) / blck_size_interleave;
+        src_offset += (i % blck_size_interleave);
+
+        out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
+    }
+
+    return out;
+}
+
+// interleave 8 block_q4_0s in blocks of blck_size_interleave
+// returns an interleaved block_q4_0x8
+// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
+// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
+static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
+    block_q4_0x8 out;
+
+    for (int i = 0; i < 8; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    for (int i = 0; i < QK4_0 * 4; i++) {
+        int src_offset = (i / (8 * blck_size_interleave)) * blck_size_interleave;
+        int src_id = (i % (8 * blck_size_interleave)) / blck_size_interleave;
+        src_offset += (i % blck_size_interleave);
+
+        out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
+    }
+
+    return out;
+}
+
+void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(QK8_0 == 32);
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    block_q8_0x4 * restrict y = (block_q8_0x4 *) vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t srcv[4][8];
+    float id[4];
+
+    for (int i = 0; i < nb; i++) {
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);
+            for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
+
+            for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
+            for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
+            for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
+
+            const float amax = vmaxvq_f32(amaxv[0]);
+
+            const float d = amax / ((1 << 7) - 1);
+            id[row_iter] = d ? 1.0f / d : 0.0f;
+
+            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
+        }
+
+        for (int j = 0; j < 8; j++) {
+            float32x4_t v = vmulq_n_f32(srcv[0][j], id[0]);
+            int32x4_t vi = vcvtnq_s32_f32(v);
+            y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[1][j], id[1]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0);
+            y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1);
+            y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2);
+            y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[2][j], id[2]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0);
+            y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1);
+            y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2);
+            y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[3][j], id[3]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0);
+            y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1);
+            y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2);
+            y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3);
+        }
+    }
+#else
+    // scalar
+    const int blck_size_interleave = 4;
+    float srcv[4][QK8_0];
+    float id[4];
+
+    for (int i = 0; i < nb; i++) {
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            float amax = 0.0f; // absolute max
+
+            for (int j = 0; j < QK8_0; j++) {
+                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
+                amax = MAX(amax, fabsf(srcv[row_iter][j]));
+            }
+
+            const float d = amax / ((1 << 7) - 1);
+            id[row_iter] = d ? 1.0f / d : 0.0f;
+
+            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
+        }
+
+        for (int j = 0; j < QK8_0 * 4; j++) {
+            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
+            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
+            src_offset += (j % blck_size_interleave);
+
+            float x0 = srcv[src_id][src_offset] * id[src_id];
+            y[i].qs[j] = roundf(x0);
+        }
+    }
+#endif
+}
+
+void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(QK8_0 == 32);
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    block_q8_0x4 * restrict y = (block_q8_0x4 *) vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t srcv[4][8];
+    float id[4];
+
+    for (int i = 0; i < nb; i++) {
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);
+            for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
+
+            for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
+            for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
+            for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
+
+            const float amax = vmaxvq_f32(amaxv[0]);
+
+            const float d = amax / ((1 << 7) - 1);
+            id[row_iter] = d ? 1.0f / d : 0.0f;
+
+            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
+        }
+
+        for (int j = 0; j < 4; j++) {
+            float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]);
+            int32x4_t vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3);
+            v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[1][2 * j], id[1]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3);
+            v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[2][2 * j], id[2]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3);
+            v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[3][2 * j], id[3]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3);
+            v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3);
+        }
+    }
+#else
+    // scalar
+    const int blck_size_interleave = 8;
+    float srcv[4][QK8_0];
+    float id[4];
+
+    for (int i = 0; i < nb; i++) {
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            float amax = 0.0f; // absolute max
+
+            for (int j = 0; j < QK8_0; j++) {
+                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
+                amax = MAX(amax, fabsf(srcv[row_iter][j]));
+            }
+
+            const float d = amax / ((1 << 7) - 1);
+            id[row_iter] = d ? 1.0f / d : 0.0f;
+
+            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
+        }
+
+        for (int j = 0; j < QK8_0 * 4; j++) {
+            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
+            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
+            src_offset += (j % blck_size_interleave);
+
+            float x0 = srcv[src_id][src_offset] * id[src_id];
+            y[i].qs[j] = roundf(x0);
+        }
+    }
+#endif
+}
+
+void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
+    assert(nrow == 4);
+    UNUSED(nrow);
+    if (blck_size_interleave == 4) {
+        quantize_q8_0_4x4(x, vy, n_per_row);
+    } else if (blck_size_interleave == 8) {
+        quantize_q8_0_4x8(x, vy, n_per_row);
+    } else {
+        assert(false);
+    }
+}
+
+static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int blck_size_interleave) {
+    assert(n_per_row % QK4_0 == 0);
+    const int nb = n_per_row / QK4_0;
+
+    void * out_ptr = NULL;
+    if (nrows_interleaved == 8) {
+        out_ptr = (block_q4_0x8 *) dst;
+    }
+    else if (nrows_interleaved == 4) {
+        out_ptr = (block_q4_0x4 *) dst;
+    }
+    assert(nrows_interleaved <= 8);
+    block_q4_0 dst_tmp[8];
+
+    for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) {
+
+        for (int64_t x = 0; x < nb; x++) {
+
+            for (int i  = 0; i < nrows_interleaved; i++ ) {
+                quantize_row_q4_0_ref(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0);
+            }
+
+            if (nrows_interleaved == 8) {
+                *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave, 0x88);
+                out_ptr = (block_q4_0x8 *) out_ptr + 1;
+            }
+            else if (nrows_interleaved == 4) {
+                *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave, 0x88);
+                out_ptr = (block_q4_0x4 *) out_ptr + 1;
+            }
+        }
+    }
+
+    return ((nrow * n_per_row) / QK4_0 * sizeof(block_q4_0));
+}
+
+size_t quantize_q4_0_4x4(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    if (!quant_weights) {
+        return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4);
+    }
+    else {
+        assert(false);
+        return 0;
+    }
+}
+
+size_t quantize_q4_0_4x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    if (!quant_weights) {
+        return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8);
+    }
+    else {
+        assert(false);
+        return 0;
+    }
+}
+
+size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    if (!quant_weights) {
+        return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8);
+    }
+    else {
+        assert(false);
+        return 0;
+    }
+}
+
+void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__ARM_FEATURE_SVE)
+    if (svcntw() == 8) {
+        GGML_ASSERT(!(ggml_cpu_has_sve() && (svcntw() == 8)) &&
+                    "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
+    }
+#endif
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    GGML_ASSERT(!(ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) &&
+                "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance");
+#elif defined(__ARM_NEON) && defined(__aarch64__) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
+    const void * b_ptr = vx;
+    const void * a_ptr = vy;
+    float * res_ptr = s;
+
+    __asm__ __volatile__(
+        "movi v31.16b, #0x4\n"
+        "movi v30.16b, #0xf0\n"
+        "add %x[b_ptr], %x[b_ptr], #0x8\n"
+        "1:"  // Column loop
+        "add x22, %x[a_ptr], #0x2\n"
+        "movi v29.16b, #0x0\n"
+        "mov x21, %x[nb]\n"
+        "2:"  // Block loop
+        "ldr q28, [%x[b_ptr], #0x0]\n"
+        "ldr q27, [x22, #0x0]\n"
+        "movi v26.4s, #0x0\n"
+        "sub x20, x22, #0x2\n"
+        "ldr q25, [x22, #0x10]\n"
+        "ldr q24, [%x[b_ptr], #0x10]\n"
+        "sub x21, x21, #0x1\n"
+        "add x22, x22, #0x22\n"
+        "ldr q23, [%x[b_ptr], #0x20]\n"
+        "ldr q22, [%x[b_ptr], #0x30]\n"
+        "ld1r { v21.8h }, [x20]\n"
+        "ldr q20, [%x[b_ptr], #-0x8]\n"
+        "sshl v16.16b, v28.16b, v31.16b\n"
+        "and v28.16b, v28.16b, v30.16b\n"
+        "sshl v19.16b, v24.16b, v31.16b\n"
+        "and v24.16b, v24.16b, v30.16b\n"
+        "add %x[b_ptr], %x[b_ptr], #0x48\n"
+        "sshl v18.16b, v23.16b, v31.16b\n"
+        "and v23.16b, v23.16b, v30.16b\n"
+        ".inst 0x4f9be21a  // sdot v26.4s, v16.16b, v27.4b[0]\n"
+        "sshl v17.16b, v22.16b, v31.16b\n"
+        "and v22.16b, v22.16b, v30.16b\n"
+        "fcvtl v21.4s, v21.4h\n"
+        "fcvtl v16.4s, v20.4h\n"
+        ".inst 0x4f99e39a  // sdot v26.4s, v28.16b, v25.4b[0]\n"
+        "fmul v16.4s, v16.4s, v21.4s\n"
+        ".inst 0x4fbbe27a  // sdot v26.4s, v19.16b, v27.4b[1]\n"
+        ".inst 0x4fb9e31a  // sdot v26.4s, v24.16b, v25.4b[1]\n"
+        ".inst 0x4f9bea5a  // sdot v26.4s, v18.16b, v27.4b[2]\n"
+        ".inst 0x4f99eafa  // sdot v26.4s, v23.16b, v25.4b[2]\n"
+        ".inst 0x4fbbea3a  // sdot v26.4s, v17.16b, v27.4b[3]\n"
+        ".inst 0x4fb9eada  // sdot v26.4s, v22.16b, v25.4b[3]\n"
+        "scvtf v26.4s, v26.4s, #0x4\n"
+        "fmla v29.4s, v26.4s, v16.4s\n"
+        "cbnz x21, 2b\n"
+        "sub %x[nc], %x[nc], #0x4\n"
+        "str q29, [%x[res_ptr], #0x0]\n"
+        "add %x[res_ptr], %x[res_ptr], #0x10\n"
+        "cbnz %x[nc], 1b\n"
+        : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
+        : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
+        : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22"
+    );
+#else
+    float sumf[4];
+    int sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
+                    }
+                    sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+    }
+#endif
+}
+
+void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__ARM_FEATURE_SVE)
+    if (svcntw() == 8) {
+        GGML_ASSERT(!(ggml_cpu_has_sve() && (svcntw() == 8)) &&
+                    "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
+    }
+#endif
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
+    const void * b_ptr = vx;
+    const void * a_ptr = vy;
+    float * res_ptr = s;
+
+    __asm__ __volatile__(
+        "movi v2.16b, #0x4\n"
+        "movi v1.16b, #0xf0\n"
+        "add %x[b_ptr], %x[b_ptr], #0x8\n"
+        "1:"  // Column loop
+        "add x23, %x[a_ptr], #0x2\n"
+        "movi v0.16b, #0x0\n"
+        "mov x22, %x[nb]\n"
+        "2:"  // Block loop
+        "ldr q31, [%x[b_ptr], #0x0]\n"
+        "ldr q30, [%x[b_ptr], #0x10]\n"
+        "mov x21, x23\n"
+        "movi v29.4s, #0x0\n"
+        "ldr q28, [%x[b_ptr], #0x20]\n"
+        "ldr q27, [%x[b_ptr], #0x30]\n"
+        "movi v26.4s, #0x0\n"
+        "sub x20, x23, #0x2\n"
+        "ld1r { v25.8h }, [x20]\n"
+        "ldr q24, [%x[b_ptr], #-0x8]\n"
+        "sub x22, x22, #0x1\n"
+        "add x23, x23, #0x22\n"
+        "ld1r { v23.2d }, [x21], #0x8\n"
+        "sshl v22.16b, v31.16b, v2.16b\n"
+        "sshl v16.16b, v30.16b, v2.16b\n"
+        "add %x[b_ptr], %x[b_ptr], #0x48\n"
+        "ld1r { v21.2d }, [x21], #0x8\n"
+        "sshl v20.16b, v28.16b, v2.16b\n"
+        "sshl v19.16b, v27.16b, v2.16b\n"
+        "ld1r { v18.2d }, [x21], #0x8\n"
+        "ld1r { v17.2d }, [x21], #0x8\n"
+        "and v31.16b, v31.16b, v1.16b\n"
+        "and v30.16b, v30.16b, v1.16b\n"
+        ".inst 0x4e9796dd  // sdot v29.4s, v22.16b, v23.16b\n"
+        ".inst 0x4e97961a  // sdot v26.4s, v16.16b, v23.16b\n"
+        "and v28.16b, v28.16b, v1.16b\n"
+        "and v27.16b, v27.16b, v1.16b\n"
+        "fcvtl v25.4s, v25.4h\n"
+        "fcvtl v16.4s, v24.4h\n"
+        ".inst 0x4e95969d  // sdot v29.4s, v20.16b, v21.16b\n"
+        ".inst 0x4e95967a  // sdot v26.4s, v19.16b, v21.16b\n"
+        "fmul v16.4s, v16.4s, v25.4s\n"
+        ".inst 0x4e9297fd  // sdot v29.4s, v31.16b, v18.16b\n"
+        ".inst 0x4e9297da  // sdot v26.4s, v30.16b, v18.16b\n"
+        ".inst 0x4e91979d  // sdot v29.4s, v28.16b, v17.16b\n"
+        ".inst 0x4e91977a  // sdot v26.4s, v27.16b, v17.16b\n"
+        "addp v29.4s, v29.4s, v26.4s\n"
+        "scvtf v29.4s, v29.4s, #0x4\n"
+        "fmla v0.4s, v29.4s, v16.4s\n"
+        "cbnz x22, 2b\n"
+        "sub %x[nc], %x[nc], #0x4\n"
+        "str q0, [%x[res_ptr], #0x0]\n"
+        "add %x[res_ptr], %x[res_ptr], #0x10\n"
+        "cbnz %x[nc], 1b\n"
+        : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
+        : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
+        : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23"
+    );
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
+                "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
+                "performance");
+#else
+    float sumf[4];
+    int sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
+                    }
+                    sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+    }
+#endif
+}
+
+void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__ARM_FEATURE_SVE) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
+    if (svcntw() == 8) {
+        const void * b_ptr = vx;
+        const void * a_ptr = vy;
+        float * res_ptr = s;
+
+        __asm__ __volatile__(
+            "ptrue p0.b\n"
+            "add %x[b_ptr], %x[b_ptr], #0x10\n"
+            "1:"  // Column loop
+            "add x22, %x[a_ptr], #0x2\n"
+            "mov z31.b, #0x0\n"
+            "mov x21, %x[nb]\n"
+            "2:"  // Block loop
+            "ld1b { z30.b }, p0/Z, [%x[b_ptr]]\n"
+            "ld1b { z29.b }, p0/Z, [%x[b_ptr], #1, MUL VL]\n"
+            "mov z28.s, #0x0\n"
+            "mov z27.s, #0x0\n"
+            "ld1rd { z26.d }, p0/Z, [x22]\n"
+            "ld1b { z25.b }, p0/Z, [%x[b_ptr], #2, MUL VL]\n"
+            "sub x20, x22, #0x2\n"
+            "sub x21, x21, #0x1\n"
+            "ld1b { z24.b }, p0/Z, [%x[b_ptr], #3, MUL VL]\n"
+            "ld1rd { z23.d }, p0/Z, [x22, #8]\n"
+            "lsl z22.b, z30.b, #0x4\n"
+            "lsl z16.b, z29.b, #0x4\n"
+            "and z30.b, z30.b, #0xf0\n"
+            "and z29.b, z29.b, #0xf0\n"
+            "ld1rd { z21.d }, p0/Z, [x22, #16]\n"
+            "ld1rd { z20.d }, p0/Z, [x22, #24]\n"
+            "lsl z19.b, z25.b, #0x4\n"
+            "and z25.b, z25.b, #0xf0\n"
+            "ld1rh { z17.h }, p0/Z, [x20]\n"
+            "ld1h { z18.s }, p0/Z, [%x[b_ptr], #-1, MUL VL]\n"
+            "sdot z28.s, z22.b, z26.b\n"
+            "sdot z27.s, z16.b, z26.b\n"
+            "lsl z16.b, z24.b, #0x4\n"
+            "add x22, x22, #0x22\n"
+            "and z24.b, z24.b, #0xf0\n"
+            "add %x[b_ptr], %x[b_ptr], #0x90\n"
+            "fcvt z17.s, p0/m, z17.h\n"
+            "fcvt z18.s, p0/m, z18.h\n"
+            "sdot z28.s, z19.b, z23.b\n"
+            "sdot z27.s, z16.b, z23.b\n"
+            "fmul z18.s, z18.s, z17.s\n"
+            "sdot z28.s, z30.b, z21.b\n"
+            "sdot z27.s, z29.b, z21.b\n"
+            "sdot z28.s, z25.b, z20.b\n"
+            "sdot z27.s, z24.b, z20.b\n"
+            "uzp1 z17.s, z28.s, z27.s\n"
+            "uzp2 z16.s, z28.s, z27.s\n"
+            "add z17.s, z17.s, z16.s\n"
+            "asr z17.s, z17.s, #0x4\n"
+            "scvtf z17.s, p0/m, z17.s\n"
+            "fmla z31.s, p0/M, z17.s, z18.s\n"
+            "cbnz x21, 2b\n"
+            "sub %x[nc], %x[nc], #0x8\n"
+            "st1w { z31.s }, p0, [%x[res_ptr]]\n"
+            "add %x[res_ptr], %x[res_ptr], #0x20\n"
+            "cbnz %x[nc], 1b\n"
+            : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
+            : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
+            : "memory", "p0", "x20", "x21", "x22", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
+        );
+        return;
+    }
+    else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+        GGML_ASSERT((ggml_cpu_has_sve() && (svcntw() == 8)) &&
+                    "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal "
+                    "performance");
+    }
+    else if (ggml_cpu_has_neon()) {
+        GGML_ASSERT(((ggml_cpu_has_sve() && (svcntw() == 8)) || ggml_cpu_has_matmul_int8()) &&
+                    "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 "
+                    "quantization format for optimal performance");
+    }
+#endif
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    GGML_ASSERT(ggml_cpu_has_sve() &&
+                "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance");
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
+                "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
+                "performance");
+#else
+    float sumf[8];
+    int sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
+                    }
+                    sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+    }
+#endif
+}
+
+void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+    if (svcntw() == 8) {
+        GGML_ASSERT(!(ggml_cpu_has_sve() && (svcntw() == 8)) &&
+                    "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
+    }
+#endif
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    GGML_ASSERT(!(ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) &&
+                "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance");
+#elif defined(__ARM_NEON) && defined(__aarch64__) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
+    const void * b_ptr = vx;
+    const void * a_ptr = vy;
+    float * res_ptr = s;
+    size_t res_stride = bs * sizeof(float);
+
+    __asm__ __volatile__(
+        "mov x10, %x[nr]\n"
+        "mov x9, #0x88\n"
+        "cmp x10, #0x10\n"
+        "mul x9, %x[nb], x9\n"
+        "blt 4f\n"
+        "1:"  // Row loop
+        "add x28, %x[b_ptr], #0x8\n"
+        "mov x27, %x[nc]\n"
+        "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
+        "2:"  // Column loop
+        "add x25, %x[a_ptr], #0x8\n"
+        "movi v15.16b, #0x0\n"
+        "movi v19.16b, #0x0\n"
+        "mov x24, %x[nb]\n"
+        "add x23, x25, x9\n"
+        "movi v18.16b, #0x0\n"
+        "movi v14.16b, #0x0\n"
+        "add x22, x23, x9\n"
+        "movi v11.16b, #0x0\n"
+        "movi v13.16b, #0x0\n"
+        "add x21, x22, x9\n"
+        "movi v23.16b, #0x0\n"
+        "movi v16.16b, #0x0\n"
+        "movi v25.16b, #0x0\n"
+        "movi v7.16b, #0x0\n"
+        "movi v0.16b, #0x0\n"
+        "movi v4.16b, #0x0\n"
+        "movi v5.16b, #0x0\n"
+        "movi v21.16b, #0x0\n"
+        "movi v8.16b, #0x0\n"
+        "movi v1.16b, #0x0\n"
+        "3:"  // Block loop
+        "ldr q3, [x28, #0x0]\n"
+        "ldr q31, [x25, #0x0]\n"
+        "movi v28.16b, #0x4\n"
+        "movi v10.4s, #0x0\n"
+        "ldr q22, [x28, #0x10]\n"
+        "ldr q6, [x25, #0x10]\n"
+        "movi v29.4s, #0x0\n"
+        "movi v9.4s, #0x0\n"
+        "ldr q27, [x28, #0x20]\n"
+        "ldr q30, [x28, #0x30]\n"
+        "movi v20.4s, #0x0\n"
+        "movi v24.16b, #0xf0\n"
+        "ldr d2, [x25, #-0x8]\n"
+        "ldr d26, [x23, #-0x8]\n"
+        "sshl v12.16b, v3.16b, v28.16b\n"
+        "sub x20, x28, #0x8\n"
+        "ldr d17, [x20, #0x0]\n"
+        "and v3.16b, v3.16b, v24.16b\n"
+        "subs x24, x24, #0x1\n"
+        "add x28, x28, #0x48\n"
+        ".inst 0x4f9fe18a  // sdot v10.4s, v12.16b, v31.4b[0]\n"
+        ".inst 0x4fbfe19d  // sdot v29.4s, v12.16b, v31.4b[1]\n"
+        ".inst 0x4f9fe989  // sdot v9.4s, v12.16b, v31.4b[2]\n"
+        ".inst 0x4fbfe994  // sdot v20.4s, v12.16b, v31.4b[3]\n"
+        "sshl v31.16b, v22.16b, v28.16b\n"
+        "and v22.16b, v22.16b, v24.16b\n"
+        "fcvtl v17.4s, v17.4h\n"
+        "fcvtl v2.4s, v2.4h\n"
+        "fcvtl v26.4s, v26.4h\n"
+        ".inst 0x4f86e3ea  // sdot v10.4s, v31.16b, v6.4b[0]\n"
+        ".inst 0x4fa6e3fd  // sdot v29.4s, v31.16b, v6.4b[1]\n"
+        ".inst 0x4f86ebe9  // sdot v9.4s, v31.16b, v6.4b[2]\n"
+        ".inst 0x4fa6ebf4  // sdot v20.4s, v31.16b, v6.4b[3]\n"
+        "sshl v6.16b, v27.16b, v28.16b\n"
+        "sshl v28.16b, v30.16b, v28.16b\n"
+        "and v27.16b, v27.16b, v24.16b\n"
+        "and v30.16b, v30.16b, v24.16b\n"
+        "ldr q24, [x25, #0x20]\n"
+        ".inst 0x4f98e0ca  // sdot v10.4s, v6.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e0dd  // sdot v29.4s, v6.16b, v24.4b[1]\n"
+        ".inst 0x4f98e8c9  // sdot v9.4s, v6.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e8d4  // sdot v20.4s, v6.16b, v24.4b[3]\n"
+        "ldr q24, [x25, #0x30]\n"
+        ".inst 0x4f98e38a  // sdot v10.4s, v28.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e39d  // sdot v29.4s, v28.16b, v24.4b[1]\n"
+        ".inst 0x4f98eb89  // sdot v9.4s, v28.16b, v24.4b[2]\n"
+        ".inst 0x4fb8eb94  // sdot v20.4s, v28.16b, v24.4b[3]\n"
+        "ldr q24, [x25, #0x40]\n"
+        ".inst 0x4f98e06a  // sdot v10.4s, v3.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e07d  // sdot v29.4s, v3.16b, v24.4b[1]\n"
+        ".inst 0x4f98e869  // sdot v9.4s, v3.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e874  // sdot v20.4s, v3.16b, v24.4b[3]\n"
+        "ldr q24, [x25, #0x50]\n"
+        ".inst 0x4f98e2ca  // sdot v10.4s, v22.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e2dd  // sdot v29.4s, v22.16b, v24.4b[1]\n"
+        ".inst 0x4f98eac9  // sdot v9.4s, v22.16b, v24.4b[2]\n"
+        ".inst 0x4fb8ead4  // sdot v20.4s, v22.16b, v24.4b[3]\n"
+        "ldr q24, [x25, #0x60]\n"
+        ".inst 0x4f98e36a  // sdot v10.4s, v27.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e37d  // sdot v29.4s, v27.16b, v24.4b[1]\n"
+        ".inst 0x4f98eb69  // sdot v9.4s, v27.16b, v24.4b[2]\n"
+        ".inst 0x4fb8eb74  // sdot v20.4s, v27.16b, v24.4b[3]\n"
+        "ldr q24, [x25, #0x70]\n"
+        "add x25, x25, #0x88\n"
+        ".inst 0x4f98e3ca  // sdot v10.4s, v30.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e3dd  // sdot v29.4s, v30.16b, v24.4b[1]\n"
+        ".inst 0x4f98ebc9  // sdot v9.4s, v30.16b, v24.4b[2]\n"
+        ".inst 0x4fb8ebd4  // sdot v20.4s, v30.16b, v24.4b[3]\n"
+        "fmul v24.4s, v17.4s, v2.s[0]\n"
+        "scvtf v10.4s, v10.4s, #0x4\n"
+        "scvtf v29.4s, v29.4s, #0x4\n"
+        "scvtf v9.4s, v9.4s, #0x4\n"
+        "scvtf v20.4s, v20.4s, #0x4\n"
+        "fmla v15.4s, v10.4s, v24.4s\n"
+        "ldr q24, [x23, #0x0]\n"
+        "fmul v10.4s, v17.4s, v2.s[1]\n"
+        "fmla v19.4s, v29.4s, v10.4s\n"
+        "ldr q10, [x23, #0x10]\n"
+        "fmul v29.4s, v17.4s, v2.s[2]\n"
+        "fmul v2.4s, v17.4s, v2.s[3]\n"
+        "fmla v18.4s, v9.4s, v29.4s\n"
+        "movi v9.4s, #0x0\n"
+        "movi v29.4s, #0x0\n"
+        ".inst 0x4f98e189  // sdot v9.4s, v12.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e19d  // sdot v29.4s, v12.16b, v24.4b[1]\n"
+        "fmla v14.4s, v20.4s, v2.4s\n"
+        "movi v20.4s, #0x0\n"
+        "movi v2.4s, #0x0\n"
+        ".inst 0x4f98e994  // sdot v20.4s, v12.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e982  // sdot v2.4s, v12.16b, v24.4b[3]\n"
+        "ldr q24, [x23, #0x20]\n"
+        ".inst 0x4f8ae3e9  // sdot v9.4s, v31.16b, v10.4b[0]\n"
+        ".inst 0x4faae3fd  // sdot v29.4s, v31.16b, v10.4b[1]\n"
+        ".inst 0x4f8aebf4  // sdot v20.4s, v31.16b, v10.4b[2]\n"
+        ".inst 0x4faaebe2  // sdot v2.4s, v31.16b, v10.4b[3]\n"
+        "ldr q10, [x23, #0x30]\n"
+        ".inst 0x4f98e0c9  // sdot v9.4s, v6.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e0dd  // sdot v29.4s, v6.16b, v24.4b[1]\n"
+        ".inst 0x4f98e8d4  // sdot v20.4s, v6.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e8c2  // sdot v2.4s, v6.16b, v24.4b[3]\n"
+        "ldr q24, [x23, #0x40]\n"
+        ".inst 0x4f8ae389  // sdot v9.4s, v28.16b, v10.4b[0]\n"
+        ".inst 0x4faae39d  // sdot v29.4s, v28.16b, v10.4b[1]\n"
+        ".inst 0x4f8aeb94  // sdot v20.4s, v28.16b, v10.4b[2]\n"
+        ".inst 0x4faaeb82  // sdot v2.4s, v28.16b, v10.4b[3]\n"
+        "ldr q10, [x23, #0x50]\n"
+        ".inst 0x4f98e069  // sdot v9.4s, v3.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e07d  // sdot v29.4s, v3.16b, v24.4b[1]\n"
+        ".inst 0x4f98e874  // sdot v20.4s, v3.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e862  // sdot v2.4s, v3.16b, v24.4b[3]\n"
+        "ldr q24, [x23, #0x60]\n"
+        ".inst 0x4f8ae2c9  // sdot v9.4s, v22.16b, v10.4b[0]\n"
+        ".inst 0x4faae2dd  // sdot v29.4s, v22.16b, v10.4b[1]\n"
+        ".inst 0x4f8aead4  // sdot v20.4s, v22.16b, v10.4b[2]\n"
+        ".inst 0x4faaeac2  // sdot v2.4s, v22.16b, v10.4b[3]\n"
+        "ldr q10, [x23, #0x70]\n"
+        "add x23, x23, #0x88\n"
+        ".inst 0x4f98e369  // sdot v9.4s, v27.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e37d  // sdot v29.4s, v27.16b, v24.4b[1]\n"
+        ".inst 0x4f98eb74  // sdot v20.4s, v27.16b, v24.4b[2]\n"
+        ".inst 0x4fb8eb62  // sdot v2.4s, v27.16b, v24.4b[3]\n"
+        "ldr q24, [x22, #0x0]\n"
+        ".inst 0x4f8ae3c9  // sdot v9.4s, v30.16b, v10.4b[0]\n"
+        ".inst 0x4faae3dd  // sdot v29.4s, v30.16b, v10.4b[1]\n"
+        ".inst 0x4f8aebd4  // sdot v20.4s, v30.16b, v10.4b[2]\n"
+        ".inst 0x4faaebc2  // sdot v2.4s, v30.16b, v10.4b[3]\n"
+        "fmul v10.4s, v17.4s, v26.s[0]\n"
+        "scvtf v9.4s, v9.4s, #0x4\n"
+        "scvtf v29.4s, v29.4s, #0x4\n"
+        "scvtf v20.4s, v20.4s, #0x4\n"
+        "scvtf v2.4s, v2.4s, #0x4\n"
+        "fmla v11.4s, v9.4s, v10.4s\n"
+        "ldr q9, [x22, #0x10]\n"
+        "fmul v10.4s, v17.4s, v26.s[1]\n"
+        "fmla v13.4s, v29.4s, v10.4s\n"
+        "ldr d29, [x22, #-0x8]\n"
+        "fmul v10.4s, v17.4s, v26.s[2]\n"
+        "fmul v26.4s, v17.4s, v26.s[3]\n"
+        "fcvtl v29.4s, v29.4h\n"
+        "fmla v23.4s, v20.4s, v10.4s\n"
+        "movi v20.4s, #0x0\n"
+        "movi v10.4s, #0x0\n"
+        "fmla v16.4s, v2.4s, v26.4s\n"
+        "movi v26.4s, #0x0\n"
+        "movi v2.4s, #0x0\n"
+        ".inst 0x4f98e194  // sdot v20.4s, v12.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e18a  // sdot v10.4s, v12.16b, v24.4b[1]\n"
+        ".inst 0x4f98e99a  // sdot v26.4s, v12.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e982  // sdot v2.4s, v12.16b, v24.4b[3]\n"
+        "ldr q24, [x22, #0x20]\n"
+        ".inst 0x4f89e3f4  // sdot v20.4s, v31.16b, v9.4b[0]\n"
+        ".inst 0x4fa9e3ea  // sdot v10.4s, v31.16b, v9.4b[1]\n"
+        ".inst 0x4f89ebfa  // sdot v26.4s, v31.16b, v9.4b[2]\n"
+        ".inst 0x4fa9ebe2  // sdot v2.4s, v31.16b, v9.4b[3]\n"
+        "ldr q9, [x22, #0x30]\n"
+        ".inst 0x4f98e0d4  // sdot v20.4s, v6.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e0ca  // sdot v10.4s, v6.16b, v24.4b[1]\n"
+        ".inst 0x4f98e8da  // sdot v26.4s, v6.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e8c2  // sdot v2.4s, v6.16b, v24.4b[3]\n"
+        "ldr q24, [x22, #0x40]\n"
+        ".inst 0x4f89e394  // sdot v20.4s, v28.16b, v9.4b[0]\n"
+        ".inst 0x4fa9e38a  // sdot v10.4s, v28.16b, v9.4b[1]\n"
+        ".inst 0x4f89eb9a  // sdot v26.4s, v28.16b, v9.4b[2]\n"
+        ".inst 0x4fa9eb82  // sdot v2.4s, v28.16b, v9.4b[3]\n"
+        "ldr q9, [x22, #0x50]\n"
+        ".inst 0x4f98e074  // sdot v20.4s, v3.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e06a  // sdot v10.4s, v3.16b, v24.4b[1]\n"
+        ".inst 0x4f98e87a  // sdot v26.4s, v3.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e862  // sdot v2.4s, v3.16b, v24.4b[3]\n"
+        "ldr q24, [x22, #0x60]\n"
+        ".inst 0x4f89e2d4  // sdot v20.4s, v22.16b, v9.4b[0]\n"
+        ".inst 0x4fa9e2ca  // sdot v10.4s, v22.16b, v9.4b[1]\n"
+        ".inst 0x4f89eada  // sdot v26.4s, v22.16b, v9.4b[2]\n"
+        ".inst 0x4fa9eac2  // sdot v2.4s, v22.16b, v9.4b[3]\n"
+        "ldr q9, [x22, #0x70]\n"
+        "add x22, x22, #0x88\n"
+        ".inst 0x4f98e374  // sdot v20.4s, v27.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e36a  // sdot v10.4s, v27.16b, v24.4b[1]\n"
+        ".inst 0x4f98eb7a  // sdot v26.4s, v27.16b, v24.4b[2]\n"
+        ".inst 0x4fb8eb62  // sdot v2.4s, v27.16b, v24.4b[3]\n"
+        "ldr q24, [x21, #0x0]\n"
+        ".inst 0x4f89e3d4  // sdot v20.4s, v30.16b, v9.4b[0]\n"
+        ".inst 0x4fa9e3ca  // sdot v10.4s, v30.16b, v9.4b[1]\n"
+        ".inst 0x4f89ebda  // sdot v26.4s, v30.16b, v9.4b[2]\n"
+        ".inst 0x4fa9ebc2  // sdot v2.4s, v30.16b, v9.4b[3]\n"
+        "fmul v9.4s, v17.4s, v29.s[0]\n"
+        "scvtf v20.4s, v20.4s, #0x4\n"
+        "scvtf v10.4s, v10.4s, #0x4\n"
+        "scvtf v26.4s, v26.4s, #0x4\n"
+        "scvtf v2.4s, v2.4s, #0x4\n"
+        "fmla v25.4s, v20.4s, v9.4s\n"
+        "ldr q9, [x21, #0x10]\n"
+        "fmul v20.4s, v17.4s, v29.s[1]\n"
+        "fmla v7.4s, v10.4s, v20.4s\n"
+        "ldr d20, [x21, #-0x8]\n"
+        "fmul v10.4s, v17.4s, v29.s[2]\n"
+        "fmul v29.4s, v17.4s, v29.s[3]\n"
+        "fcvtl v20.4s, v20.4h\n"
+        "fmla v0.4s, v26.4s, v10.4s\n"
+        "movi v26.4s, #0x0\n"
+        "movi v10.4s, #0x0\n"
+        "fmla v4.4s, v2.4s, v29.4s\n"
+        "movi v2.4s, #0x0\n"
+        "movi v29.4s, #0x0\n"
+        ".inst 0x4f98e19a  // sdot v26.4s, v12.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e18a  // sdot v10.4s, v12.16b, v24.4b[1]\n"
+        ".inst 0x4f98e982  // sdot v2.4s, v12.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e99d  // sdot v29.4s, v12.16b, v24.4b[3]\n"
+        "ldr q12, [x21, #0x20]\n"
+        "fmul v24.4s, v17.4s, v20.s[0]\n"
+        ".inst 0x4f89e3fa  // sdot v26.4s, v31.16b, v9.4b[0]\n"
+        ".inst 0x4fa9e3ea  // sdot v10.4s, v31.16b, v9.4b[1]\n"
+        ".inst 0x4f89ebe2  // sdot v2.4s, v31.16b, v9.4b[2]\n"
+        ".inst 0x4fa9ebfd  // sdot v29.4s, v31.16b, v9.4b[3]\n"
+        "ldr q9, [x21, #0x30]\n"
+        "fmul v31.4s, v17.4s, v20.s[1]\n"
+        ".inst 0x4f8ce0da  // sdot v26.4s, v6.16b, v12.4b[0]\n"
+        ".inst 0x4face0ca  // sdot v10.4s, v6.16b, v12.4b[1]\n"
+        ".inst 0x4f8ce8c2  // sdot v2.4s, v6.16b, v12.4b[2]\n"
+        ".inst 0x4face8dd  // sdot v29.4s, v6.16b, v12.4b[3]\n"
+        "ldr q12, [x21, #0x40]\n"
+        "fmul v6.4s, v17.4s, v20.s[2]\n"
+        "fmul v20.4s, v17.4s, v20.s[3]\n"
+        ".inst 0x4f89e39a  // sdot v26.4s, v28.16b, v9.4b[0]\n"
+        ".inst 0x4fa9e38a  // sdot v10.4s, v28.16b, v9.4b[1]\n"
+        ".inst 0x4f89eb82  // sdot v2.4s, v28.16b, v9.4b[2]\n"
+        ".inst 0x4fa9eb9d  // sdot v29.4s, v28.16b, v9.4b[3]\n"
+        "ldr q9, [x21, #0x50]\n"
+        ".inst 0x4f8ce07a  // sdot v26.4s, v3.16b, v12.4b[0]\n"
+        ".inst 0x4face06a  // sdot v10.4s, v3.16b, v12.4b[1]\n"
+        ".inst 0x4f8ce862  // sdot v2.4s, v3.16b, v12.4b[2]\n"
+        ".inst 0x4face87d  // sdot v29.4s, v3.16b, v12.4b[3]\n"
+        "ldr q12, [x21, #0x60]\n"
+        ".inst 0x4f89e2da  // sdot v26.4s, v22.16b, v9.4b[0]\n"
+        ".inst 0x4fa9e2ca  // sdot v10.4s, v22.16b, v9.4b[1]\n"
+        ".inst 0x4f89eac2  // sdot v2.4s, v22.16b, v9.4b[2]\n"
+        ".inst 0x4fa9eadd  // sdot v29.4s, v22.16b, v9.4b[3]\n"
+        "ldr q17, [x21, #0x70]\n"
+        "add x21, x21, #0x88\n"
+        ".inst 0x4f8ce37a  // sdot v26.4s, v27.16b, v12.4b[0]\n"
+        ".inst 0x4face36a  // sdot v10.4s, v27.16b, v12.4b[1]\n"
+        ".inst 0x4f8ceb62  // sdot v2.4s, v27.16b, v12.4b[2]\n"
+        ".inst 0x4faceb7d  // sdot v29.4s, v27.16b, v12.4b[3]\n"
+        ".inst 0x4f91e3da  // sdot v26.4s, v30.16b, v17.4b[0]\n"
+        ".inst 0x4fb1e3ca  // sdot v10.4s, v30.16b, v17.4b[1]\n"
+        ".inst 0x4f91ebc2  // sdot v2.4s, v30.16b, v17.4b[2]\n"
+        ".inst 0x4fb1ebdd  // sdot v29.4s, v30.16b, v17.4b[3]\n"
+        "scvtf v26.4s, v26.4s, #0x4\n"
+        "scvtf v10.4s, v10.4s, #0x4\n"
+        "fmla v5.4s, v26.4s, v24.4s\n"
+        "scvtf v2.4s, v2.4s, #0x4\n"
+        "scvtf v29.4s, v29.4s, #0x4\n"
+        "fmla v21.4s, v10.4s, v31.4s\n"
+        "fmla v8.4s, v2.4s, v6.4s\n"
+        "fmla v1.4s, v29.4s, v20.4s\n"
+        "bgt 3b\n"
+        "mov x20, %x[res_ptr]\n"
+        "subs x27, x27, #0x4\n"
+        "add %x[res_ptr], %x[res_ptr], #0x10\n"
+        "str q15, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q19, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q18, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q14, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q11, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q13, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q23, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q16, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q25, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q7, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q0, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q4, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q5, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q21, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q8, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q1, [x20, #0x0]\n"
+        "bne 2b\n"
+        "mov x20, #0x4\n"
+        "sub x10, x10, #0x10\n"
+        "cmp x10, #0x10\n"
+        "mov %x[res_ptr], x26\n"
+        "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
+        "bge 1b\n"
+        "4:"  // Row loop skip
+        "cbz x10, 9f\n"
+        "5:"  // Row tail: Row loop
+        "add x24, %x[b_ptr], #0x8\n"
+        "mov x23, %x[nc]\n"
+        "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
+        "6:"  // Row tail: Column loop
+        "movi v15.16b, #0x0\n"
+        "movi v19.16b, #0x0\n"
+        "add x25, %x[a_ptr], #0x8\n"
+        "mov x21, %x[nb]\n"
+        "movi v18.16b, #0x0\n"
+        "movi v14.16b, #0x0\n"
+        "7:"  // Row tail: Block loop
+        "ldr q7, [x24, #0x0]\n"
+        "ldr q5, [x25, #0x0]\n"
+        "movi v9.16b, #0x4\n"
+        "movi v4.4s, #0x0\n"
+        "ldr q3, [x24, #0x10]\n"
+        "ldr q2, [x25, #0x10]\n"
+        "movi v1.4s, #0x0\n"
+        "movi v0.4s, #0x0\n"
+        "ldr q13, [x24, #0x20]\n"
+        "ldr q31, [x25, #0x20]\n"
+        "movi v30.4s, #0x0\n"
+        "movi v29.16b, #0xf0\n"
+        "ldr q28, [x24, #0x30]\n"
+        "ldr q27, [x25, #0x30]\n"
+        "sshl v20.16b, v7.16b, v9.16b\n"
+        "sub x20, x24, #0x8\n"
+        "ldr q26, [x25, #0x40]\n"
+        "ldr q25, [x25, #0x50]\n"
+        "sshl v17.16b, v3.16b, v9.16b\n"
+        "and v7.16b, v7.16b, v29.16b\n"
+        "ldr q24, [x25, #0x60]\n"
+        "ldr q16, [x25, #0x70]\n"
+        "sshl v22.16b, v13.16b, v9.16b\n"
+        "and v3.16b, v3.16b, v29.16b\n"
+        "ldr d21, [x20, #0x0]\n"
+        "ldr d12, [x25, #-0x8]\n"
+        ".inst 0x4f85e284  // sdot v4.4s, v20.16b, v5.4b[0]\n"
+        ".inst 0x4fa5e281  // sdot v1.4s, v20.16b, v5.4b[1]\n"
+        ".inst 0x4f85ea80  // sdot v0.4s, v20.16b, v5.4b[2]\n"
+        ".inst 0x4fa5ea9e  // sdot v30.4s, v20.16b, v5.4b[3]\n"
+        "sshl v9.16b, v28.16b, v9.16b\n"
+        "subs x21, x21, #0x1\n"
+        "and v13.16b, v13.16b, v29.16b\n"
+        "and v28.16b, v28.16b, v29.16b\n"
+        "add x25, x25, #0x88\n"
+        "add x24, x24, #0x48\n"
+        "fcvtl v21.4s, v21.4h\n"
+        "fcvtl v12.4s, v12.4h\n"
+        ".inst 0x4f82e224  // sdot v4.4s, v17.16b, v2.4b[0]\n"
+        ".inst 0x4fa2e221  // sdot v1.4s, v17.16b, v2.4b[1]\n"
+        ".inst 0x4f82ea20  // sdot v0.4s, v17.16b, v2.4b[2]\n"
+        ".inst 0x4fa2ea3e  // sdot v30.4s, v17.16b, v2.4b[3]\n"
+        "fmul v11.4s, v21.4s, v12.s[0]\n"
+        "fmul v23.4s, v21.4s, v12.s[1]\n"
+        "fmul v17.4s, v21.4s, v12.s[2]\n"
+        ".inst 0x4f9fe2c4  // sdot v4.4s, v22.16b, v31.4b[0]\n"
+        "fmul v6.4s, v21.4s, v12.s[3]\n"
+        ".inst 0x4fbfe2c1  // sdot v1.4s, v22.16b, v31.4b[1]\n"
+        ".inst 0x4f9feac0  // sdot v0.4s, v22.16b, v31.4b[2]\n"
+        ".inst 0x4fbfeade  // sdot v30.4s, v22.16b, v31.4b[3]\n"
+        ".inst 0x4f9be124  // sdot v4.4s, v9.16b, v27.4b[0]\n"
+        ".inst 0x4fbbe121  // sdot v1.4s, v9.16b, v27.4b[1]\n"
+        ".inst 0x4f9be920  // sdot v0.4s, v9.16b, v27.4b[2]\n"
+        ".inst 0x4fbbe93e  // sdot v30.4s, v9.16b, v27.4b[3]\n"
+        ".inst 0x4f9ae0e4  // sdot v4.4s, v7.16b, v26.4b[0]\n"
+        ".inst 0x4fbae0e1  // sdot v1.4s, v7.16b, v26.4b[1]\n"
+        ".inst 0x4f9ae8e0  // sdot v0.4s, v7.16b, v26.4b[2]\n"
+        ".inst 0x4fbae8fe  // sdot v30.4s, v7.16b, v26.4b[3]\n"
+        ".inst 0x4f99e064  // sdot v4.4s, v3.16b, v25.4b[0]\n"
+        ".inst 0x4fb9e061  // sdot v1.4s, v3.16b, v25.4b[1]\n"
+        ".inst 0x4f99e860  // sdot v0.4s, v3.16b, v25.4b[2]\n"
+        ".inst 0x4fb9e87e  // sdot v30.4s, v3.16b, v25.4b[3]\n"
+        ".inst 0x4f98e1a4  // sdot v4.4s, v13.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e1a1  // sdot v1.4s, v13.16b, v24.4b[1]\n"
+        ".inst 0x4f98e9a0  // sdot v0.4s, v13.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e9be  // sdot v30.4s, v13.16b, v24.4b[3]\n"
+        ".inst 0x4f90e384  // sdot v4.4s, v28.16b, v16.4b[0]\n"
+        ".inst 0x4fb0e381  // sdot v1.4s, v28.16b, v16.4b[1]\n"
+        ".inst 0x4f90eb80  // sdot v0.4s, v28.16b, v16.4b[2]\n"
+        ".inst 0x4fb0eb9e  // sdot v30.4s, v28.16b, v16.4b[3]\n"
+        "scvtf v4.4s, v4.4s, #0x4\n"
+        "scvtf v1.4s, v1.4s, #0x4\n"
+        "scvtf v0.4s, v0.4s, #0x4\n"
+        "fmla v15.4s, v4.4s, v11.4s\n"
+        "scvtf v30.4s, v30.4s, #0x4\n"
+        "fmla v19.4s, v1.4s, v23.4s\n"
+        "fmla v18.4s, v0.4s, v17.4s\n"
+        "fmla v14.4s, v30.4s, v6.4s\n"
+        "bgt 7b\n"
+        "mov x20, %x[res_ptr]\n"
+        "cmp x10, #0x1\n"
+        "str q15, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "ble 8f\n"
+        "cmp x10, #0x2\n"
+        "str q19, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "ble 8f\n"
+        "cmp x10, #0x3\n"
+        "str q18, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "ble 8f\n"
+        "str q14, [x20, #0x0]\n"
+        "8:"  // Row tail: Accumulator store skip
+        "subs x23, x23, #0x4\n"
+        "add %x[res_ptr], %x[res_ptr], #0x10\n"
+        "bne 6b\n"
+        "subs x10, x10, #0x4\n"
+        "add %x[a_ptr], %x[a_ptr], x9\n"
+        "mov %x[res_ptr], x22\n"
+        "bgt 5b\n"
+        "9:"  // Row tail: Row loop skip
+        : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
+        : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
+        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+    );
+#else
+    float sumf[4][4];
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+                            }
+                            sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++)
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+            }
+        }
+    }
+#endif
+}
+
+void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+    if (svcntw() == 8) {
+        GGML_ASSERT(!(ggml_cpu_has_sve() && (svcntw() == 8)) &&
+                    "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
+    }
+#endif
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
+    const void * b_ptr = vx;
+    const void * a_ptr = vy;
+    float * res_ptr = s;
+    size_t res_stride = bs * sizeof(float);
+
+    __asm__ __volatile__(
+        "mov x10, %x[nr]\n"
+        "mov x9, #0x88\n"
+        "cmp x10, #0x10\n"
+        "mul x9, %x[nb], x9\n"
+        "blt 4f\n"
+        "1:"  // Row loop
+        "add x28, %x[b_ptr], #0x8\n"
+        "mov x27, %x[nc]\n"
+        "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
+        "2:"  // Column loop
+        "add x25, %x[a_ptr], #0x8\n"
+        "movi v2.16b, #0x0\n"
+        "movi v10.16b, #0x0\n"
+        "mov x24, %x[nb]\n"
+        "add x23, x25, x9\n"
+        "movi v12.16b, #0x0\n"
+        "movi v28.16b, #0x0\n"
+        "add x22, x23, x9\n"
+        "movi v11.16b, #0x0\n"
+        "movi v13.16b, #0x0\n"
+        "add x21, x22, x9\n"
+        "movi v22.16b, #0x0\n"
+        "movi v23.16b, #0x0\n"
+        "movi v25.16b, #0x0\n"
+        "movi v5.16b, #0x0\n"
+        "movi v7.16b, #0x0\n"
+        "movi v4.16b, #0x0\n"
+        "movi v6.16b, #0x0\n"
+        "movi v30.16b, #0x0\n"
+        "movi v24.16b, #0x0\n"
+        "movi v14.16b, #0x0\n"
+        "3:"  // Block loop
+        "ldr q21, [x28, #0x0]\n"
+        "ldr q16, [x28, #0x10]\n"
+        "movi v1.16b, #0x4\n"
+        "movi v19.4s, #0x0\n"
+        "ldr q27, [x25, #0x0]\n"
+        "ldr q15, [x25, #0x10]\n"
+        "movi v26.4s, #0x0\n"
+        "movi v18.4s, #0x0\n"
+        "ldr q29, [x28, #0x20]\n"
+        "ldr q3, [x28, #0x30]\n"
+        "movi v17.4s, #0x0\n"
+        "movi v0.16b, #0xf0\n"
+        "ldr d20, [x25, #-0x8]\n"
+        "ldr d9, [x23, #-0x8]\n"
+        "sshl v8.16b, v21.16b, v1.16b\n"
+        "sshl v31.16b, v16.16b, v1.16b\n"
+        "and v21.16b, v21.16b, v0.16b\n"
+        "and v16.16b, v16.16b, v0.16b\n"
+        "sub x20, x28, #0x8\n"
+        "subs x24, x24, #0x1\n"
+        "add x28, x28, #0x48\n"
+        ".inst 0x4e88a773  // smmla v19.4s, v27.16b, v8.16b\n"
+        ".inst 0x4e9fa77a  // smmla v26.4s, v27.16b, v31.16b\n"
+        "ldr q27, [x25, #0x20]\n"
+        ".inst 0x4e88a5f2  // smmla v18.4s, v15.16b, v8.16b\n"
+        ".inst 0x4e9fa5f1  // smmla v17.4s, v15.16b, v31.16b\n"
+        "sshl v15.16b, v29.16b, v1.16b\n"
+        "sshl v1.16b, v3.16b, v1.16b\n"
+        "and v29.16b, v29.16b, v0.16b\n"
+        "and v3.16b, v3.16b, v0.16b\n"
+        "ldr q0, [x25, #0x30]\n"
+        "fcvtl v20.4s, v20.4h\n"
+        ".inst 0x4e8fa773  // smmla v19.4s, v27.16b, v15.16b\n"
+        "fcvtl v9.4s, v9.4h\n"
+        ".inst 0x4e81a77a  // smmla v26.4s, v27.16b, v1.16b\n"
+        "ldr q27, [x25, #0x40]\n"
+        ".inst 0x4e8fa412  // smmla v18.4s, v0.16b, v15.16b\n"
+        ".inst 0x4e81a411  // smmla v17.4s, v0.16b, v1.16b\n"
+        "ldr q0, [x25, #0x50]\n"
+        ".inst 0x4e95a773  // smmla v19.4s, v27.16b, v21.16b\n"
+        ".inst 0x4e90a77a  // smmla v26.4s, v27.16b, v16.16b\n"
+        "ldr q27, [x25, #0x60]\n"
+        ".inst 0x4e95a412  // smmla v18.4s, v0.16b, v21.16b\n"
+        ".inst 0x4e90a411  // smmla v17.4s, v0.16b, v16.16b\n"
+        "ldr q0, [x25, #0x70]\n"
+        "add x25, x25, #0x88\n"
+        ".inst 0x4e9da773  // smmla v19.4s, v27.16b, v29.16b\n"
+        ".inst 0x4e83a77a  // smmla v26.4s, v27.16b, v3.16b\n"
+        "ldr d27, [x20, #0x0]\n"
+        ".inst 0x4e9da412  // smmla v18.4s, v0.16b, v29.16b\n"
+        ".inst 0x4e83a411  // smmla v17.4s, v0.16b, v3.16b\n"
+        "fcvtl v27.4s, v27.4h\n"
+        "uzp1 v0.2d, v19.2d, v26.2d\n"
+        "uzp2 v26.2d, v19.2d, v26.2d\n"
+        "fmul v19.4s, v27.4s, v20.s[0]\n"
+        "scvtf v0.4s, v0.4s, #0x4\n"
+        "scvtf v26.4s, v26.4s, #0x4\n"
+        "fmla v2.4s, v0.4s, v19.4s\n"
+        "ldr q19, [x23, #0x0]\n"
+        "uzp1 v0.2d, v18.2d, v17.2d\n"
+        "uzp2 v18.2d, v18.2d, v17.2d\n"
+        "fmul v17.4s, v27.4s, v20.s[1]\n"
+        "scvtf v0.4s, v0.4s, #0x4\n"
+        "scvtf v18.4s, v18.4s, #0x4\n"
+        "fmla v10.4s, v26.4s, v17.4s\n"
+        "ldr q17, [x23, #0x10]\n"
+        "fmul v26.4s, v27.4s, v20.s[2]\n"
+        "fmul v20.4s, v27.4s, v20.s[3]\n"
+        "fmla v12.4s, v0.4s, v26.4s\n"
+        "ldr d0, [x22, #-0x8]\n"
+        "ldr d26, [x21, #-0x8]\n"
+        "fcvtl v0.4s, v0.4h\n"
+        "fmla v28.4s, v18.4s, v20.4s\n"
+        "movi v20.4s, #0x0\n"
+        "movi v18.4s, #0x0\n"
+        ".inst 0x4e88a674  // smmla v20.4s, v19.16b, v8.16b\n"
+        ".inst 0x4e9fa672  // smmla v18.4s, v19.16b, v31.16b\n"
+        "ldr q19, [x23, #0x20]\n"
+        "fcvtl v26.4s, v26.4h\n"
+        ".inst 0x4e8fa674  // smmla v20.4s, v19.16b, v15.16b\n"
+        ".inst 0x4e81a672  // smmla v18.4s, v19.16b, v1.16b\n"
+        "ldr q19, [x23, #0x40]\n"
+        ".inst 0x4e95a674  // smmla v20.4s, v19.16b, v21.16b\n"
+        ".inst 0x4e90a672  // smmla v18.4s, v19.16b, v16.16b\n"
+        "ldr q19, [x23, #0x60]\n"
+        ".inst 0x4e9da674  // smmla v20.4s, v19.16b, v29.16b\n"
+        ".inst 0x4e83a672  // smmla v18.4s, v19.16b, v3.16b\n"
+        "uzp1 v19.2d, v20.2d, v18.2d\n"
+        "scvtf v19.4s, v19.4s, #0x4\n"
+        "uzp2 v20.2d, v20.2d, v18.2d\n"
+        "fmul v18.4s, v27.4s, v9.s[0]\n"
+        "scvtf v20.4s, v20.4s, #0x4\n"
+        "fmla v11.4s, v19.4s, v18.4s\n"
+        "ldr q18, [x22, #0x0]\n"
+        "fmul v19.4s, v27.4s, v9.s[1]\n"
+        "fmla v13.4s, v20.4s, v19.4s\n"
+        "movi v19.4s, #0x0\n"
+        "movi v20.4s, #0x0\n"
+        ".inst 0x4e88a633  // smmla v19.4s, v17.16b, v8.16b\n"
+        ".inst 0x4e9fa634  // smmla v20.4s, v17.16b, v31.16b\n"
+        "ldr q17, [x23, #0x30]\n"
+        ".inst 0x4e8fa633  // smmla v19.4s, v17.16b, v15.16b\n"
+        ".inst 0x4e81a634  // smmla v20.4s, v17.16b, v1.16b\n"
+        "ldr q17, [x23, #0x50]\n"
+        ".inst 0x4e95a633  // smmla v19.4s, v17.16b, v21.16b\n"
+        ".inst 0x4e90a634  // smmla v20.4s, v17.16b, v16.16b\n"
+        "ldr q17, [x23, #0x70]\n"
+        "add x23, x23, #0x88\n"
+        ".inst 0x4e9da633  // smmla v19.4s, v17.16b, v29.16b\n"
+        ".inst 0x4e83a634  // smmla v20.4s, v17.16b, v3.16b\n"
+        "uzp1 v17.2d, v19.2d, v20.2d\n"
+        "scvtf v17.4s, v17.4s, #0x4\n"
+        "uzp2 v20.2d, v19.2d, v20.2d\n"
+        "fmul v19.4s, v27.4s, v9.s[2]\n"
+        "fmul v9.4s, v27.4s, v9.s[3]\n"
+        "scvtf v20.4s, v20.4s, #0x4\n"
+        "fmla v22.4s, v17.4s, v19.4s\n"
+        "ldr q17, [x22, #0x10]\n"
+        "movi v19.4s, #0x0\n"
+        ".inst 0x4e88a653  // smmla v19.4s, v18.16b, v8.16b\n"
+        "fmla v23.4s, v20.4s, v9.4s\n"
+        "movi v20.4s, #0x0\n"
+        "movi v9.4s, #0x0\n"
+        ".inst 0x4e9fa654  // smmla v20.4s, v18.16b, v31.16b\n"
+        "ldr q18, [x22, #0x20]\n"
+        ".inst 0x4e88a629  // smmla v9.4s, v17.16b, v8.16b\n"
+        ".inst 0x4e8fa653  // smmla v19.4s, v18.16b, v15.16b\n"
+        ".inst 0x4e81a654  // smmla v20.4s, v18.16b, v1.16b\n"
+        "ldr q18, [x22, #0x40]\n"
+        ".inst 0x4e95a653  // smmla v19.4s, v18.16b, v21.16b\n"
+        ".inst 0x4e90a654  // smmla v20.4s, v18.16b, v16.16b\n"
+        "ldr q18, [x22, #0x60]\n"
+        ".inst 0x4e9da653  // smmla v19.4s, v18.16b, v29.16b\n"
+        ".inst 0x4e83a654  // smmla v20.4s, v18.16b, v3.16b\n"
+        "movi v18.4s, #0x0\n"
+        ".inst 0x4e9fa632  // smmla v18.4s, v17.16b, v31.16b\n"
+        "ldr q17, [x22, #0x30]\n"
+        ".inst 0x4e8fa629  // smmla v9.4s, v17.16b, v15.16b\n"
+        ".inst 0x4e81a632  // smmla v18.4s, v17.16b, v1.16b\n"
+        "ldr q17, [x22, #0x50]\n"
+        ".inst 0x4e95a629  // smmla v9.4s, v17.16b, v21.16b\n"
+        ".inst 0x4e90a632  // smmla v18.4s, v17.16b, v16.16b\n"
+        "ldr q17, [x22, #0x70]\n"
+        "add x22, x22, #0x88\n"
+        ".inst 0x4e9da629  // smmla v9.4s, v17.16b, v29.16b\n"
+        ".inst 0x4e83a632  // smmla v18.4s, v17.16b, v3.16b\n"
+        "uzp1 v17.2d, v19.2d, v20.2d\n"
+        "uzp2 v20.2d, v19.2d, v20.2d\n"
+        "fmul v19.4s, v27.4s, v0.s[0]\n"
+        "scvtf v17.4s, v17.4s, #0x4\n"
+        "scvtf v20.4s, v20.4s, #0x4\n"
+        "fmla v25.4s, v17.4s, v19.4s\n"
+        "ldr q19, [x21, #0x0]\n"
+        "fmul v17.4s, v27.4s, v0.s[1]\n"
+        "fmla v5.4s, v20.4s, v17.4s\n"
+        "ldr q17, [x21, #0x10]\n"
+        "uzp1 v20.2d, v9.2d, v18.2d\n"
+        "uzp2 v9.2d, v9.2d, v18.2d\n"
+        "fmul v18.4s, v27.4s, v0.s[2]\n"
+        "fmul v0.4s, v27.4s, v0.s[3]\n"
+        "scvtf v20.4s, v20.4s, #0x4\n"
+        "scvtf v9.4s, v9.4s, #0x4\n"
+        "fmla v7.4s, v20.4s, v18.4s\n"
+        "movi v20.4s, #0x0\n"
+        "movi v18.4s, #0x0\n"
+        ".inst 0x4e88a674  // smmla v20.4s, v19.16b, v8.16b\n"
+        ".inst 0x4e9fa672  // smmla v18.4s, v19.16b, v31.16b\n"
+        "ldr q19, [x21, #0x20]\n"
+        "fmla v4.4s, v9.4s, v0.4s\n"
+        "movi v9.4s, #0x0\n"
+        "movi v0.4s, #0x0\n"
+        ".inst 0x4e88a629  // smmla v9.4s, v17.16b, v8.16b\n"
+        "fmul v8.4s, v27.4s, v26.s[0]\n"
+        ".inst 0x4e9fa620  // smmla v0.4s, v17.16b, v31.16b\n"
+        "ldr q17, [x21, #0x30]\n"
+        ".inst 0x4e8fa674  // smmla v20.4s, v19.16b, v15.16b\n"
+        "fmul v31.4s, v27.4s, v26.s[1]\n"
+        ".inst 0x4e81a672  // smmla v18.4s, v19.16b, v1.16b\n"
+        "ldr q19, [x21, #0x40]\n"
+        ".inst 0x4e8fa629  // smmla v9.4s, v17.16b, v15.16b\n"
+        "fmul v15.4s, v27.4s, v26.s[2]\n"
+        "fmul v27.4s, v27.4s, v26.s[3]\n"
+        ".inst 0x4e81a620  // smmla v0.4s, v17.16b, v1.16b\n"
+        "ldr q1, [x21, #0x50]\n"
+        ".inst 0x4e95a674  // smmla v20.4s, v19.16b, v21.16b\n"
+        ".inst 0x4e90a672  // smmla v18.4s, v19.16b, v16.16b\n"
+        "ldr q26, [x21, #0x60]\n"
+        ".inst 0x4e95a429  // smmla v9.4s, v1.16b, v21.16b\n"
+        ".inst 0x4e90a420  // smmla v0.4s, v1.16b, v16.16b\n"
+        "ldr q21, [x21, #0x70]\n"
+        "add x21, x21, #0x88\n"
+        ".inst 0x4e9da754  // smmla v20.4s, v26.16b, v29.16b\n"
+        ".inst 0x4e83a752  // smmla v18.4s, v26.16b, v3.16b\n"
+        ".inst 0x4e9da6a9  // smmla v9.4s, v21.16b, v29.16b\n"
+        ".inst 0x4e83a6a0  // smmla v0.4s, v21.16b, v3.16b\n"
+        "uzp1 v29.2d, v20.2d, v18.2d\n"
+        "uzp2 v21.2d, v20.2d, v18.2d\n"
+        "scvtf v29.4s, v29.4s, #0x4\n"
+        "uzp1 v18.2d, v9.2d, v0.2d\n"
+        "uzp2 v16.2d, v9.2d, v0.2d\n"
+        "scvtf v21.4s, v21.4s, #0x4\n"
+        "fmla v6.4s, v29.4s, v8.4s\n"
+        "scvtf v18.4s, v18.4s, #0x4\n"
+        "scvtf v16.4s, v16.4s, #0x4\n"
+        "fmla v30.4s, v21.4s, v31.4s\n"
+        "fmla v24.4s, v18.4s, v15.4s\n"
+        "fmla v14.4s, v16.4s, v27.4s\n"
+        "bgt 3b\n"
+        "mov x20, %x[res_ptr]\n"
+        "subs x27, x27, #0x4\n"
+        "add %x[res_ptr], %x[res_ptr], #0x10\n"
+        "str q2, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q10, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q12, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q28, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q11, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q13, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q22, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q23, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q25, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q5, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q7, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q4, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q6, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q30, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q24, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q14, [x20, #0x0]\n"
+        "bne 2b\n"
+        "mov x20, #0x4\n"
+        "sub x10, x10, #0x10\n"
+        "cmp x10, #0x10\n"
+        "mov %x[res_ptr], x26\n"
+        "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
+        "bge 1b\n"
+        "4:"  // Row loop skip
+        "cbz x10, 9f\n"
+        "5:"  // Row tail: Row loop
+        "add x24, %x[b_ptr], #0x8\n"
+        "mov x23, %x[nc]\n"
+        "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
+        "6:"  // Row tail: Column loop
+        "movi v2.16b, #0x0\n"
+        "movi v10.16b, #0x0\n"
+        "add x25, %x[a_ptr], #0x8\n"
+        "mov x21, %x[nb]\n"
+        "movi v12.16b, #0x0\n"
+        "movi v28.16b, #0x0\n"
+        "7:"  // Row tail: Block loop
+        "ldr q6, [x24, #0x0]\n"
+        "ldr q5, [x24, #0x10]\n"
+        "movi v17.16b, #0x4\n"
+        "movi v8.4s, #0x0\n"
+        "ldr q4, [x25, #0x0]\n"
+        "ldr q13, [x25, #0x10]\n"
+        "movi v27.4s, #0x0\n"
+        "movi v0.4s, #0x0\n"
+        "ldr q31, [x24, #0x20]\n"
+        "ldr q14, [x24, #0x30]\n"
+        "movi v29.4s, #0x0\n"
+        "movi v22.16b, #0xf0\n"
+        "ldr q11, [x25, #0x20]\n"
+        "ldr q23, [x25, #0x30]\n"
+        "sshl v21.16b, v6.16b, v17.16b\n"
+        "sshl v16.16b, v5.16b, v17.16b\n"
+        "ldr q20, [x25, #0x40]\n"
+        "ldr q26, [x25, #0x50]\n"
+        "and v6.16b, v6.16b, v22.16b\n"
+        "and v5.16b, v5.16b, v22.16b\n"
+        "ldr q25, [x25, #0x60]\n"
+        "ldr q3, [x25, #0x70]\n"
+        "sshl v19.16b, v31.16b, v17.16b\n"
+        "sshl v18.16b, v14.16b, v17.16b\n"
+        "ldr d17, [x25, #-0x8]\n"
+        ".inst 0x4e95a488  // smmla v8.4s, v4.16b, v21.16b\n"
+        ".inst 0x4e90a49b  // smmla v27.4s, v4.16b, v16.16b\n"
+        "and v31.16b, v31.16b, v22.16b\n"
+        ".inst 0x4e95a5a0  // smmla v0.4s, v13.16b, v21.16b\n"
+        ".inst 0x4e90a5bd  // smmla v29.4s, v13.16b, v16.16b\n"
+        "and v14.16b, v14.16b, v22.16b\n"
+        "sub x20, x24, #0x8\n"
+        "ldr d16, [x20, #0x0]\n"
+        "subs x21, x21, #0x1\n"
+        "add x25, x25, #0x88\n"
+        "fcvtl v17.4s, v17.4h\n"
+        "add x24, x24, #0x48\n"
+        ".inst 0x4e93a568  // smmla v8.4s, v11.16b, v19.16b\n"
+        ".inst 0x4e92a57b  // smmla v27.4s, v11.16b, v18.16b\n"
+        ".inst 0x4e93a6e0  // smmla v0.4s, v23.16b, v19.16b\n"
+        ".inst 0x4e92a6fd  // smmla v29.4s, v23.16b, v18.16b\n"
+        "fcvtl v16.4s, v16.4h\n"
+        ".inst 0x4e86a688  // smmla v8.4s, v20.16b, v6.16b\n"
+        ".inst 0x4e85a69b  // smmla v27.4s, v20.16b, v5.16b\n"
+        "fmul v23.4s, v16.4s, v17.s[0]\n"
+        "fmul v21.4s, v16.4s, v17.s[1]\n"
+        "fmul v1.4s, v16.4s, v17.s[2]\n"
+        "fmul v20.4s, v16.4s, v17.s[3]\n"
+        ".inst 0x4e86a740  // smmla v0.4s, v26.16b, v6.16b\n"
+        ".inst 0x4e85a75d  // smmla v29.4s, v26.16b, v5.16b\n"
+        ".inst 0x4e9fa728  // smmla v8.4s, v25.16b, v31.16b\n"
+        ".inst 0x4e8ea73b  // smmla v27.4s, v25.16b, v14.16b\n"
+        ".inst 0x4e9fa460  // smmla v0.4s, v3.16b, v31.16b\n"
+        ".inst 0x4e8ea47d  // smmla v29.4s, v3.16b, v14.16b\n"
+        "uzp1 v19.2d, v8.2d, v27.2d\n"
+        "uzp2 v18.2d, v8.2d, v27.2d\n"
+        "scvtf v19.4s, v19.4s, #0x4\n"
+        "uzp1 v17.2d, v0.2d, v29.2d\n"
+        "uzp2 v16.2d, v0.2d, v29.2d\n"
+        "scvtf v18.4s, v18.4s, #0x4\n"
+        "fmla v2.4s, v19.4s, v23.4s\n"
+        "scvtf v17.4s, v17.4s, #0x4\n"
+        "scvtf v16.4s, v16.4s, #0x4\n"
+        "fmla v10.4s, v18.4s, v21.4s\n"
+        "fmla v12.4s, v17.4s, v1.4s\n"
+        "fmla v28.4s, v16.4s, v20.4s\n"
+        "bgt 7b\n"
+        "mov x20, %x[res_ptr]\n"
+        "cmp x10, #0x1\n"
+        "str q2, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "ble 8f\n"
+        "cmp x10, #0x2\n"
+        "str q10, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "ble 8f\n"
+        "cmp x10, #0x3\n"
+        "str q12, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "ble 8f\n"
+        "str q28, [x20, #0x0]\n"
+        "8:"  // Row tail: Accumulator store skip
+        "subs x23, x23, #0x4\n"
+        "add %x[res_ptr], %x[res_ptr], #0x10\n"
+        "bne 6b\n"
+        "subs x10, x10, #0x4\n"
+        "add %x[a_ptr], %x[a_ptr], x9\n"
+        "mov %x[res_ptr], x22\n"
+        "bgt 5b\n"
+        "9:"  // Row tail: Row loop skip
+        : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
+        : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
+        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+    );
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
+                "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
+                "performance");
+#else
+    float sumf[4][4];
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+                            }
+                            sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++)
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+            }
+        }
+    }
+#endif
+}
+
+void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
+    if (svcntw() == 8) {
+        const void * b_ptr = vx;
+        const void * a_ptr = vy;
+        float * res_ptr = s;
+        size_t res_stride = bs * sizeof(float);
+
+        __asm__ __volatile__(
+            "mov x20, #0x4\n"
+            "mov x13, %x[nr]\n"
+            "mov z28.s, #-0x4\n"
+            "mov x12, #0x88\n"
+            "ptrue p1.b\n"
+            "whilelt p0.s, XZR, x20\n"
+            "cmp x13, #0x10\n"
+            "mul x12, %x[nb], x12\n"
+            "blt 4f\n"
+            "1:"  // Row loop
+            "add x11, %x[b_ptr], #0x10\n"
+            "mov x10, %x[nc]\n"
+            "add x9, %x[res_ptr], %x[res_stride], LSL #4\n"
+            "2:"  // Column loop
+            "add x28, %x[a_ptr], #0x8\n"
+            "mov z24.b, #0x0\n"
+            "mov z15.b, #0x0\n"
+            "mov x27, %x[nb]\n"
+            "add x26, x28, x12\n"
+            "mov z12.b, #0x0\n"
+            "mov z0.b, #0x0\n"
+            "add x25, x26, x12\n"
+            "mov z13.b, #0x0\n"
+            "mov z1.b, #0x0\n"
+            "add x24, x25, x12\n"
+            "mov z20.b, #0x0\n"
+            "mov z25.b, #0x0\n"
+            "mov z11.b, #0x0\n"
+            "mov z16.b, #0x0\n"
+            "mov z19.b, #0x0\n"
+            "mov z26.b, #0x0\n"
+            "mov z8.b, #0x0\n"
+            "mov z29.b, #0x0\n"
+            "mov z27.b, #0x0\n"
+            "mov z10.b, #0x0\n"
+            "3:"  // Block loop
+            "ld1b { z30.b }, p1/Z, [x11]\n"
+            "ld1b { z21.b }, p1/Z, [x11, #1, MUL VL]\n"
+            "mov z18.s, #0x0\n"
+            "mov z7.s, #0x0\n"
+            "ld1rqb { z3.b }, p1/Z, [x28]\n"
+            "ld1rqb { z5.b }, p1/Z, [x28, #16]\n"
+            "mov z9.s, #0x0\n"
+            "mov z22.s, #0x0\n"
+            "ld1b { z4.b }, p1/Z, [x11, #2, MUL VL]\n"
+            "ld1b { z17.b }, p1/Z, [x11, #3, MUL VL]\n"
+            "sub x20, x11, #0x10\n"
+            "sub x23, x28, #0x8\n"
+            "lsl z31.b, z30.b, #0x4\n"
+            "lsl z6.b, z21.b, #0x4\n"
+            "ld1h { z23.s }, p1/Z, [x20]\n"
+            "sub x22, x26, #0x8\n"
+            "and z30.b, z30.b, #0xf0\n"
+            "and z21.b, z21.b, #0xf0\n"
+            "sub x21, x25, #0x8\n"
+            "sub x20, x24, #0x8\n"
+            "lsl z14.b, z4.b, #0x4\n"
+            "lsl z2.b, z17.b, #0x4\n"
+            "subs x27, x27, #0x1\n"
+            "add x11, x11, #0x90\n"
+            ".inst 0x451f9872  // smmla z18.s, z3.b, z31.b\n"
+            ".inst 0x45069867  // smmla z7.s, z3.b, z6.b\n"
+            "ld1rqb { z3.b }, p1/Z, [x28, #32]\n"
+            "and z4.b, z4.b, #0xf0\n"
+            ".inst 0x451f98a9  // smmla z9.s, z5.b, z31.b\n"
+            ".inst 0x450698b6  // smmla z22.s, z5.b, z6.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x28, #48]\n"
+            "and z17.b, z17.b, #0xf0\n"
+            "fcvt z23.s, p1/m, z23.h\n"
+            ".inst 0x450e9872  // smmla z18.s, z3.b, z14.b\n"
+            ".inst 0x45029867  // smmla z7.s, z3.b, z2.b\n"
+            "ld1rqb { z3.b }, p1/Z, [x28, #64]\n"
+            ".inst 0x450e98a9  // smmla z9.s, z5.b, z14.b\n"
+            ".inst 0x450298b6  // smmla z22.s, z5.b, z2.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x28, #80]\n"
+            "fscale z23.s, p1/m, z23.s, z28.s\n"
+            ".inst 0x451e9872  // smmla z18.s, z3.b, z30.b\n"
+            ".inst 0x45159867  // smmla z7.s, z3.b, z21.b\n"
+            "ld1rqb { z3.b }, p1/Z, [x28, #96]\n"
+            ".inst 0x451e98a9  // smmla z9.s, z5.b, z30.b\n"
+            ".inst 0x451598b6  // smmla z22.s, z5.b, z21.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x28, #112]\n"
+            "add x28, x28, #0x88\n"
+            ".inst 0x45049872  // smmla z18.s, z3.b, z4.b\n"
+            ".inst 0x45119867  // smmla z7.s, z3.b, z17.b\n"
+            "ld1h { z3.s }, p0/Z, [x23]\n"
+            ".inst 0x450498a9  // smmla z9.s, z5.b, z4.b\n"
+            ".inst 0x451198b6  // smmla z22.s, z5.b, z17.b\n"
+            "fcvt z3.s, p1/m, z3.h\n"
+            "uzp1 z5.d, z18.d, z7.d\n"
+            "uzp2 z18.d, z18.d, z7.d\n"
+            "mov z3.q, z3.q[0]\n"
+            "uzp1 z7.d, z9.d, z22.d\n"
+            "uzp2 z22.d, z9.d, z22.d\n"
+            "fmul z9.s, z23.s, z3.s[0]\n"
+            "scvtf z5.s, p1/m, z5.s\n"
+            "scvtf z18.s, p1/m, z18.s\n"
+            "scvtf z7.s, p1/m, z7.s\n"
+            "scvtf z22.s, p1/m, z22.s\n"
+            "fmla z24.s, p1/M, z5.s, z9.s\n"
+            "ld1rqb { z5.b }, p1/Z, [x26]\n"
+            "fmul z9.s, z23.s, z3.s[1]\n"
+            "fmla z15.s, p1/M, z18.s, z9.s\n"
+            "ld1rqb { z18.b }, p1/Z, [x26, #16]\n"
+            "fmul z9.s, z23.s, z3.s[2]\n"
+            "fmul z3.s, z23.s, z3.s[3]\n"
+            "fmla z12.s, p1/M, z7.s, z9.s\n"
+            "mov z9.s, #0x0\n"
+            "ld1h { z7.s }, p0/Z, [x22]\n"
+            ".inst 0x451f98a9  // smmla z9.s, z5.b, z31.b\n"
+            "fmla z0.s, p1/M, z22.s, z3.s\n"
+            "mov z22.s, #0x0\n"
+            "ld1h { z3.s }, p0/Z, [x21]\n"
+            ".inst 0x450698b6  // smmla z22.s, z5.b, z6.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x26, #32]\n"
+            "fcvt z7.s, p1/m, z7.h\n"
+            "fcvt z3.s, p1/m, z3.h\n"
+            ".inst 0x450e98a9  // smmla z9.s, z5.b, z14.b\n"
+            ".inst 0x450298b6  // smmla z22.s, z5.b, z2.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x26, #64]\n"
+            "mov z7.q, z7.q[0]\n"
+            "mov z3.q, z3.q[0]\n"
+            ".inst 0x451e98a9  // smmla z9.s, z5.b, z30.b\n"
+            ".inst 0x451598b6  // smmla z22.s, z5.b, z21.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x26, #96]\n"
+            ".inst 0x450498a9  // smmla z9.s, z5.b, z4.b\n"
+            ".inst 0x451198b6  // smmla z22.s, z5.b, z17.b\n"
+            "uzp1 z5.d, z9.d, z22.d\n"
+            "scvtf z5.s, p1/m, z5.s\n"
+            "uzp2 z22.d, z9.d, z22.d\n"
+            "fmul z9.s, z23.s, z7.s[0]\n"
+            "scvtf z22.s, p1/m, z22.s\n"
+            "fmla z13.s, p1/M, z5.s, z9.s\n"
+            "ld1rqb { z9.b }, p1/Z, [x25]\n"
+            "fmul z5.s, z23.s, z7.s[1]\n"
+            "fmla z1.s, p1/M, z22.s, z5.s\n"
+            "mov z5.s, #0x0\n"
+            "mov z22.s, #0x0\n"
+            ".inst 0x451f9a45  // smmla z5.s, z18.b, z31.b\n"
+            ".inst 0x45069a56  // smmla z22.s, z18.b, z6.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x26, #48]\n"
+            ".inst 0x450e9a45  // smmla z5.s, z18.b, z14.b\n"
+            ".inst 0x45029a56  // smmla z22.s, z18.b, z2.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x26, #80]\n"
+            ".inst 0x451e9a45  // smmla z5.s, z18.b, z30.b\n"
+            ".inst 0x45159a56  // smmla z22.s, z18.b, z21.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x26, #112]\n"
+            "add x26, x26, #0x88\n"
+            ".inst 0x45049a45  // smmla z5.s, z18.b, z4.b\n"
+            ".inst 0x45119a56  // smmla z22.s, z18.b, z17.b\n"
+            "uzp1 z18.d, z5.d, z22.d\n"
+            "scvtf z18.s, p1/m, z18.s\n"
+            "uzp2 z22.d, z5.d, z22.d\n"
+            "fmul z5.s, z23.s, z7.s[2]\n"
+            "fmul z7.s, z23.s, z7.s[3]\n"
+            "scvtf z22.s, p1/m, z22.s\n"
+            "fmla z20.s, p1/M, z18.s, z5.s\n"
+            "ld1rqb { z18.b }, p1/Z, [x25, #16]\n"
+            "ld1h { z5.s }, p0/Z, [x20]\n"
+            "fcvt z5.s, p1/m, z5.h\n"
+            "fmla z25.s, p1/M, z22.s, z7.s\n"
+            "mov z22.s, #0x0\n"
+            "mov z7.s, #0x0\n"
+            ".inst 0x451f9936  // smmla z22.s, z9.b, z31.b\n"
+            ".inst 0x45069927  // smmla z7.s, z9.b, z6.b\n"
+            "ld1rqb { z9.b }, p1/Z, [x25, #32]\n"
+            "mov z5.q, z5.q[0]\n"
+            ".inst 0x450e9936  // smmla z22.s, z9.b, z14.b\n"
+            ".inst 0x45029927  // smmla z7.s, z9.b, z2.b\n"
+            "ld1rqb { z9.b }, p1/Z, [x25, #64]\n"
+            ".inst 0x451e9936  // smmla z22.s, z9.b, z30.b\n"
+            ".inst 0x45159927  // smmla z7.s, z9.b, z21.b\n"
+            "ld1rqb { z9.b }, p1/Z, [x25, #96]\n"
+            ".inst 0x45049936  // smmla z22.s, z9.b, z4.b\n"
+            ".inst 0x45119927  // smmla z7.s, z9.b, z17.b\n"
+            "uzp1 z9.d, z22.d, z7.d\n"
+            "scvtf z9.s, p1/m, z9.s\n"
+            "uzp2 z22.d, z22.d, z7.d\n"
+            "fmul z7.s, z23.s, z3.s[0]\n"
+            "scvtf z22.s, p1/m, z22.s\n"
+            "fmla z11.s, p1/M, z9.s, z7.s\n"
+            "ld1rqb { z9.b }, p1/Z, [x24]\n"
+            "fmul z7.s, z23.s, z3.s[1]\n"
+            "fmla z16.s, p1/M, z22.s, z7.s\n"
+            "mov z22.s, #0x0\n"
+            "mov z7.s, #0x0\n"
+            ".inst 0x451f9a56  // smmla z22.s, z18.b, z31.b\n"
+            ".inst 0x45069a47  // smmla z7.s, z18.b, z6.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x25, #48]\n"
+            ".inst 0x450e9a56  // smmla z22.s, z18.b, z14.b\n"
+            ".inst 0x45029a47  // smmla z7.s, z18.b, z2.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x25, #80]\n"
+            ".inst 0x451e9a56  // smmla z22.s, z18.b, z30.b\n"
+            ".inst 0x45159a47  // smmla z7.s, z18.b, z21.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x25, #112]\n"
+            "add x25, x25, #0x88\n"
+            ".inst 0x45049a56  // smmla z22.s, z18.b, z4.b\n"
+            ".inst 0x45119a47  // smmla z7.s, z18.b, z17.b\n"
+            "uzp1 z18.d, z22.d, z7.d\n"
+            "scvtf z18.s, p1/m, z18.s\n"
+            "uzp2 z7.d, z22.d, z7.d\n"
+            "fmul z22.s, z23.s, z3.s[2]\n"
+            "fmul z3.s, z23.s, z3.s[3]\n"
+            "scvtf z7.s, p1/m, z7.s\n"
+            "fmla z19.s, p1/M, z18.s, z22.s\n"
+            "ld1rqb { z18.b }, p1/Z, [x24, #16]\n"
+            "fmul z22.s, z23.s, z5.s[0]\n"
+            "fmla z26.s, p1/M, z7.s, z3.s\n"
+            "mov z3.s, #0x0\n"
+            "mov z7.s, #0x0\n"
+            ".inst 0x451f9923  // smmla z3.s, z9.b, z31.b\n"
+            ".inst 0x45069927  // smmla z7.s, z9.b, z6.b\n"
+            "ld1rqb { z9.b }, p1/Z, [x24, #32]\n"
+            ".inst 0x450e9923  // smmla z3.s, z9.b, z14.b\n"
+            ".inst 0x45029927  // smmla z7.s, z9.b, z2.b\n"
+            "mov z9.s, #0x0\n"
+            ".inst 0x451f9a49  // smmla z9.s, z18.b, z31.b\n"
+            "mov z31.s, #0x0\n"
+            ".inst 0x45069a5f  // smmla z31.s, z18.b, z6.b\n"
+            "ld1rqb { z6.b }, p1/Z, [x24, #48]\n"
+            "ld1rqb { z18.b }, p1/Z, [x24, #64]\n"
+            ".inst 0x450e98c9  // smmla z9.s, z6.b, z14.b\n"
+            "fmul z14.s, z23.s, z5.s[1]\n"
+            ".inst 0x450298df  // smmla z31.s, z6.b, z2.b\n"
+            "ld1rqb { z6.b }, p1/Z, [x24, #80]\n"
+            "fmul z2.s, z23.s, z5.s[2]\n"
+            "fmul z23.s, z23.s, z5.s[3]\n"
+            ".inst 0x451e9a43  // smmla z3.s, z18.b, z30.b\n"
+            ".inst 0x45159a47  // smmla z7.s, z18.b, z21.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x24, #96]\n"
+            ".inst 0x451e98c9  // smmla z9.s, z6.b, z30.b\n"
+            ".inst 0x451598df  // smmla z31.s, z6.b, z21.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x24, #112]\n"
+            "add x24, x24, #0x88\n"
+            ".inst 0x450498a3  // smmla z3.s, z5.b, z4.b\n"
+            ".inst 0x451198a7  // smmla z7.s, z5.b, z17.b\n"
+            ".inst 0x45049a49  // smmla z9.s, z18.b, z4.b\n"
+            ".inst 0x45119a5f  // smmla z31.s, z18.b, z17.b\n"
+            "uzp1 z18.d, z3.d, z7.d\n"
+            "uzp2 z5.d, z3.d, z7.d\n"
+            "scvtf z18.s, p1/m, z18.s\n"
+            "uzp1 z6.d, z9.d, z31.d\n"
+            "uzp2 z9.d, z9.d, z31.d\n"
+            "scvtf z5.s, p1/m, z5.s\n"
+            "fmla z8.s, p1/M, z18.s, z22.s\n"
+            "scvtf z6.s, p1/m, z6.s\n"
+            "scvtf z9.s, p1/m, z9.s\n"
+            "fmla z29.s, p1/M, z5.s, z14.s\n"
+            "fmla z27.s, p1/M, z6.s, z2.s\n"
+            "fmla z10.s, p1/M, z9.s, z23.s\n"
+            "bgt 3b\n"
+            "mov x20, %x[res_ptr]\n"
+            "subs x10, x10, #0x8\n"
+            "add %x[res_ptr], %x[res_ptr], #0x20\n"
+            "st1w { z24.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z15.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z12.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z0.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z13.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z1.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z20.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z25.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z11.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z16.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z19.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z26.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z8.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z29.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z27.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z10.s }, p1, [x20]\n"
+            "bne 2b\n"
+            "mov x20, #0x4\n"
+            "sub x13, x13, #0x10\n"
+            "cmp x13, #0x10\n"
+            "mov %x[res_ptr], x9\n"
+            "madd %x[a_ptr], x20, x12, %x[a_ptr]\n"
+            "bge 1b\n"
+            "4:"  // Row loop skip
+            "cbz x13, 9f\n"
+            "5:"  // Row tail: Row loop
+            "add x25, %x[b_ptr], #0x10\n"
+            "mov x24, %x[nc]\n"
+            "add x23, %x[res_ptr], %x[res_stride], LSL #2\n"
+            "6:"  // Row tail: Column loop
+            "mov z24.b, #0x0\n"
+            "mov z15.b, #0x0\n"
+            "add x28, %x[a_ptr], #0x8\n"
+            "mov x22, %x[nb]\n"
+            "mov z12.b, #0x0\n"
+            "mov z0.b, #0x0\n"
+            "7:"  // Row tail: Block loop
+            "ld1b { z3.b }, p1/Z, [x25]\n"
+            "ld1b { z6.b }, p1/Z, [x25, #1, MUL VL]\n"
+            "mov z2.s, #0x0\n"
+            "mov z25.s, #0x0\n"
+            "ld1rqb { z26.b }, p1/Z, [x28]\n"
+            "ld1rqb { z21.b }, p1/Z, [x28, #16]\n"
+            "mov z27.s, #0x0\n"
+            "mov z19.s, #0x0\n"
+            "ld1b { z29.b }, p1/Z, [x25, #2, MUL VL]\n"
+            "ld1b { z16.b }, p1/Z, [x25, #3, MUL VL]\n"
+            "sub x21, x25, #0x10\n"
+            "sub x20, x28, #0x8\n"
+            "lsl z20.b, z3.b, #0x4\n"
+            "lsl z4.b, z6.b, #0x4\n"
+            "ld1rqb { z10.b }, p1/Z, [x28, #32]\n"
+            "ld1rqb { z23.b }, p1/Z, [x28, #48]\n"
+            "and z3.b, z3.b, #0xf0\n"
+            "and z6.b, z6.b, #0xf0\n"
+            "ld1rqb { z11.b }, p1/Z, [x28, #64]\n"
+            "ld1rqb { z7.b }, p1/Z, [x28, #80]\n"
+            "lsl z8.b, z29.b, #0x4\n"
+            "lsl z14.b, z16.b, #0x4\n"
+            "ld1rqb { z18.b }, p1/Z, [x28, #96]\n"
+            "ld1rqb { z30.b }, p1/Z, [x28, #112]\n"
+            ".inst 0x45149b42  // smmla z2.s, z26.b, z20.b\n"
+            ".inst 0x45049b59  // smmla z25.s, z26.b, z4.b\n"
+            "and z29.b, z29.b, #0xf0\n"
+            "ld1h { z17.s }, p1/Z, [x21]\n"
+            ".inst 0x45149abb  // smmla z27.s, z21.b, z20.b\n"
+            ".inst 0x45049ab3  // smmla z19.s, z21.b, z4.b\n"
+            "and z16.b, z16.b, #0xf0\n"
+            "ld1h { z4.s }, p0/Z, [x20]\n"
+            "subs x22, x22, #0x1\n"
+            "add x28, x28, #0x88\n"
+            "fcvt z17.s, p1/m, z17.h\n"
+            "add x25, x25, #0x90\n"
+            ".inst 0x45089942  // smmla z2.s, z10.b, z8.b\n"
+            ".inst 0x450e9959  // smmla z25.s, z10.b, z14.b\n"
+            "fcvt z4.s, p1/m, z4.h\n"
+            ".inst 0x45089afb  // smmla z27.s, z23.b, z8.b\n"
+            ".inst 0x450e9af3  // smmla z19.s, z23.b, z14.b\n"
+            "fscale z17.s, p1/m, z17.s, z28.s\n"
+            "mov z4.q, z4.q[0]\n"
+            ".inst 0x45039962  // smmla z2.s, z11.b, z3.b\n"
+            ".inst 0x45069979  // smmla z25.s, z11.b, z6.b\n"
+            "fmul z23.s, z17.s, z4.s[0]\n"
+            "fmul z9.s, z17.s, z4.s[1]\n"
+            "fmul z21.s, z17.s, z4.s[2]\n"
+            "fmul z4.s, z17.s, z4.s[3]\n"
+            ".inst 0x450398fb  // smmla z27.s, z7.b, z3.b\n"
+            ".inst 0x450698f3  // smmla z19.s, z7.b, z6.b\n"
+            ".inst 0x451d9a42  // smmla z2.s, z18.b, z29.b\n"
+            ".inst 0x45109a59  // smmla z25.s, z18.b, z16.b\n"
+            ".inst 0x451d9bdb  // smmla z27.s, z30.b, z29.b\n"
+            ".inst 0x45109bd3  // smmla z19.s, z30.b, z16.b\n"
+            "uzp1 z31.d, z2.d, z25.d\n"
+            "uzp2 z13.d, z2.d, z25.d\n"
+            "scvtf z31.s, p1/m, z31.s\n"
+            "uzp1 z17.d, z27.d, z19.d\n"
+            "uzp2 z18.d, z27.d, z19.d\n"
+            "scvtf z13.s, p1/m, z13.s\n"
+            "fmla z24.s, p1/M, z31.s, z23.s\n"
+            "scvtf z17.s, p1/m, z17.s\n"
+            "scvtf z18.s, p1/m, z18.s\n"
+            "fmla z15.s, p1/M, z13.s, z9.s\n"
+            "fmla z12.s, p1/M, z17.s, z21.s\n"
+            "fmla z0.s, p1/M, z18.s, z4.s\n"
+            "bgt 7b\n"
+            "mov x20, %x[res_ptr]\n"
+            "cmp x13, #0x1\n"
+            "st1w { z24.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "cmp x13, #0x2\n"
+            "st1w { z15.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "cmp x13, #0x3\n"
+            "st1w { z12.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "st1w { z0.s }, p1, [x20]\n"
+            "8:"  // Row tail: Accumulator store skip
+            "subs x24, x24, #0x8\n"
+            "add %x[res_ptr], %x[res_ptr], #0x20\n"
+            "bne 6b\n"
+            "subs x13, x13, #0x4\n"
+            "add %x[a_ptr], %x[a_ptr], x12\n"
+            "mov %x[res_ptr], x23\n"
+            "bgt 5b\n"
+            "9:"  // Row tail: Row loop skip
+            : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
+            : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
+            : "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
+        );
+        return;
+    }
+    else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+        GGML_ASSERT((ggml_cpu_has_sve() && (svcntw() == 8)) &&
+                    "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal "
+                    "performance");
+    }
+    else if (ggml_cpu_has_neon()) {
+        GGML_ASSERT(((ggml_cpu_has_sve() && (svcntw() == 8)) || ggml_cpu_has_matmul_int8()) &&
+                    "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 "
+                    "quantization format for optimal performance");
+    }
+#endif
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    GGML_ASSERT(ggml_cpu_has_sve() &&
+                "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance");
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
+                "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
+                "performance");
+#else
+    float sumf[4][8];
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+                            }
+                            sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++)
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+            }
+        }
+    }
+#endif
+}

+ 65 - 0
llama/ggml-aarch64.h

@@ -0,0 +1,65 @@
+/**
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
+#pragma once
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+
+#include "ggml.h"
+
+// GGML internal header
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Quantization
+void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave);
+
+// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
+size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q4_0_4x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q4_0_8x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+
+// GEMV
+void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+
+// GEMM
+void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+
+#ifdef __cplusplus
+}
+#endif
+

+ 97 - 46
llama/ggml-alloc.c

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -117,8 +117,7 @@ void ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tenso
     if (talloc->offset + size > ggml_backend_buffer_get_size(talloc->buffer)) {
         fprintf(stderr, "%s: not enough space in the buffer to allocate %s (needed %zu, available %zu)\n",
                 __func__, tensor->name, size, ggml_backend_buffer_get_size(talloc->buffer) - talloc->offset);
-        GGML_ASSERT(!"not enough space in the buffer");
-        return;
+        GGML_ABORT("not enough space in the buffer");
     }
 
     void * addr = (char *)ggml_backend_buffer_get_base(talloc->buffer) + talloc->offset;
@@ -159,7 +158,7 @@ static void add_allocated_tensor(struct ggml_dyn_tallocr * alloc, size_t offset,
             return;
         }
     }
-    GGML_ASSERT(!"out of allocated_tensors");
+    GGML_ABORT("out of allocated_tensors");
 }
 static void remove_allocated_tensor(struct ggml_dyn_tallocr * alloc, size_t offset, const struct ggml_tensor * tensor) {
     for (int i = 0; i < 1024; i++) {
@@ -168,8 +167,7 @@ static void remove_allocated_tensor(struct ggml_dyn_tallocr * alloc, size_t offs
             return;
         }
     }
-    fprintf(stderr, "tried to free tensor %s not found\n", tensor->name);
-    GGML_ASSERT(!"tensor not found");
+    GGML_ABORT("tried to free tensor %s not found\n", tensor->name);
 }
 #endif
 
@@ -202,8 +200,7 @@ static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t siz
             // this should never happen
             fprintf(stderr, "%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n",
                     __func__, size, max_avail);
-            GGML_ASSERT(!"not enough space in the buffer");
-            GGML_UNREACHABLE();
+            GGML_ABORT("not enough space in the buffer");
         }
     }
 
@@ -365,6 +362,7 @@ struct hash_node {
 };
 
 struct tensor_alloc {
+    int buffer_id;
     size_t offset;
     size_t size_max; // 0 = pre-allocated, unused, or view
 };
@@ -375,7 +373,6 @@ struct leaf_alloc {
 };
 
 struct node_alloc {
-    int buffer_id;
     struct tensor_alloc dst;
     struct tensor_alloc src[GGML_MAX_SRC];
 };
@@ -412,8 +409,19 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs
     for (int i = 0; i < n_bufs; i++) {
         galloc->bufts[i] = bufts[i];
         galloc->buffers[i] = NULL;
-        size_t alignment = ggml_backend_buft_get_alignment(bufts[i]);
-        galloc->buf_tallocs[i] = ggml_dyn_tallocr_new(alignment);
+
+        // check if the same buffer type is used multiple times and reuse the same allocator
+        for (int j = 0; j < i; j++) {
+            if (bufts[i] == bufts[j]) {
+                galloc->buf_tallocs[i] = galloc->buf_tallocs[j];
+                break;
+            }
+        }
+
+        if (galloc->buf_tallocs[i] == NULL) {
+            size_t alignment = ggml_backend_buft_get_alignment(bufts[i]);
+            galloc->buf_tallocs[i] = ggml_dyn_tallocr_new(alignment);
+        }
     }
     galloc->n_buffers = n_bufs;
 
@@ -431,14 +439,34 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) {
 
     for (int i = 0; i < galloc->n_buffers; i++) {
         if (galloc->buffers != NULL) {
-            ggml_backend_buffer_free(galloc->buffers[i]);
+            // skip if already freed
+            bool freed = false;
+            for (int j = 0; j < i; j++) {
+                if (galloc->buffers[j] == galloc->buffers[i]) {
+                    freed = true;
+                    break;
+                }
+            }
+            if (!freed) {
+                ggml_backend_buffer_free(galloc->buffers[i]);
+            }
         }
         if (galloc->buf_tallocs != NULL) {
-            ggml_dyn_tallocr_free(galloc->buf_tallocs[i]);
+            // skip if already freed
+            bool freed = false;
+            for (int j = 0; j < i; j++) {
+                if (galloc->buf_tallocs[j] == galloc->buf_tallocs[i]) {
+                    freed = true;
+                    break;
+                }
+            }
+            if (!freed) {
+                ggml_dyn_tallocr_free(galloc->buf_tallocs[i]);
+            }
         }
     }
 
-    free(galloc->hash_set.keys);
+    ggml_hash_set_free(&galloc->hash_set);
     free(galloc->hash_values);
     free(galloc->bufts);
     free(galloc->buffers);
@@ -451,7 +479,7 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) {
 typedef struct ggml_gallocr * ggml_gallocr_t;
 
 static struct hash_node * ggml_gallocr_hash_get(ggml_gallocr_t galloc, struct ggml_tensor * t) {
-    size_t i = ggml_hash_find_or_insert(galloc->hash_set, t);
+    size_t i = ggml_hash_find_or_insert(&galloc->hash_set, t);
     return &galloc->hash_values[i];
 }
 
@@ -537,17 +565,18 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
     }
 }
 
-static void ggml_gallocr_free_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) {
+static void ggml_gallocr_free_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {
     // graph outputs are never freed
     if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
         AT_PRINTF("not freeing output %s\n", node->name);
         return;
     }
 
-    struct ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id];
-    ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id];
     struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
     size_t offset = hn->offset;
+    int buffer_id = hn->buffer_id;
+    struct ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id];
+    ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id];
     size_t size = ggml_backend_buft_get_alloc_size(buft, node);
     ggml_dyn_tallocr_free_tensor(alloc, offset, size, node);
     hn->allocated = false;
@@ -559,8 +588,8 @@ static int get_node_buffer_id(const int * node_buffer_ids, int i) {
 
 static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids) {
     // clear hash tables
-    memset(galloc->hash_set.keys, 0, galloc->hash_set.size * sizeof(struct ggml_tensor *));
-    memset(galloc->hash_values,   0, galloc->hash_set.size * sizeof(struct hash_node));
+    ggml_hash_set_reset(&galloc->hash_set);
+    memset(galloc->hash_values, 0, sizeof(struct hash_node) * galloc->hash_set.size);
 
     // allocate leafs
     // these may be tensors that the application is not using in the graph, but may still want to allocate for other purposes
@@ -652,11 +681,11 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
                     AT_PRINTF("view_src %s: %d children, %d views\n",
                         view_src->name, view_src_hn->n_children, view_src_hn->n_views);
                     if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src_hn->allocated) {
-                        ggml_gallocr_free_node(galloc, view_src, buffer_id);
+                        ggml_gallocr_free_node(galloc, view_src);
                     }
                 }
                 else if (p_hn->allocated) {
-                    ggml_gallocr_free_node(galloc, parent, buffer_id);
+                    ggml_gallocr_free_node(galloc, parent);
                 }
             }
             AT_PRINTF("\n");
@@ -665,21 +694,19 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
 }
 
 bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids) {
-    size_t hash_size = graph->visited_hash_table.size;
+    size_t min_hash_size = graph->n_nodes + graph->n_leafs;
+    // add 25% margin to avoid hash collisions
+    min_hash_size += min_hash_size / 4;
 
     // initialize hash table
-    if (galloc->hash_set.size < hash_size) {
-        free(galloc->hash_set.keys);
-        free(galloc->hash_values);
-        galloc->hash_set.size = hash_size;
-        galloc->hash_set.keys = calloc(hash_size, sizeof(struct ggml_tensor *));
-        galloc->hash_values   = calloc(hash_size, sizeof(struct hash_node));
+    if (galloc->hash_set.size < min_hash_size) {
+        ggml_hash_set_free(&galloc->hash_set);
+        galloc->hash_set = ggml_hash_set_new(min_hash_size);
         GGML_ASSERT(galloc->hash_set.keys != NULL);
+
+        free(galloc->hash_values);
+        galloc->hash_values = malloc(sizeof(struct hash_node) * galloc->hash_set.size);
         GGML_ASSERT(galloc->hash_values != NULL);
-    } else {
-        // reset hash table
-        memset(galloc->hash_set.keys, 0, sizeof(struct ggml_tensor *) * galloc->hash_set.size);
-        memset(galloc->hash_values,   0, sizeof(struct hash_node) * galloc->hash_set.size);
     }
 
     // reset allocators
@@ -700,22 +727,25 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
     for (int i = 0; i < graph->n_nodes; i++) {
         struct ggml_tensor * node = graph->nodes[i];
         struct node_alloc * node_alloc = &galloc->node_allocs[i];
-        node_alloc->buffer_id = get_node_buffer_id(node_buffer_ids, i);
         if (node->view_src || node->data) {
+            node_alloc->dst.buffer_id = -1;
             node_alloc->dst.offset = SIZE_MAX;
             node_alloc->dst.size_max = 0;
         } else {
             struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
-            node_alloc->dst.offset   = hn->offset;
-            node_alloc->dst.size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node);
+            node_alloc->dst.buffer_id = hn->buffer_id;
+            node_alloc->dst.offset    = hn->offset;
+            node_alloc->dst.size_max  = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node);
         }
         for (int j = 0; j < GGML_MAX_SRC; j++) {
             struct ggml_tensor * src = node->src[j];
             if (!src || src->view_src || src->data) {
+                node_alloc->src[j].buffer_id = -1;
                 node_alloc->src[j].offset = SIZE_MAX;
                 node_alloc->src[j].size_max = 0;
             } else {
                 struct hash_node * hn = ggml_gallocr_hash_get(galloc, src);
+                node_alloc->src[j].buffer_id = hn->buffer_id;
                 node_alloc->src[j].offset   = hn->offset;
                 node_alloc->src[j].size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], src);
             }
@@ -732,9 +762,11 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
         struct hash_node * hn = ggml_gallocr_hash_get(galloc, leaf);
         galloc->leaf_allocs[i].buffer_id = hn->buffer_id;
         if (leaf->view_src || leaf->data) {
+            galloc->leaf_allocs[i].leaf.buffer_id = -1;
             galloc->leaf_allocs[i].leaf.offset = SIZE_MAX;
             galloc->leaf_allocs[i].leaf.size_max = 0;
         } else {
+            galloc->leaf_allocs[i].leaf.buffer_id = hn->buffer_id;
             galloc->leaf_allocs[i].leaf.offset = hn->offset;
             galloc->leaf_allocs[i].leaf.size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], leaf);
         }
@@ -742,6 +774,14 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
 
     // reallocate buffers if needed
     for (int i = 0; i < galloc->n_buffers; i++) {
+        // if the buffer type is used multiple times, we reuse the same buffer
+        for (int j = 0; j < i; j++) {
+            if (galloc->buf_tallocs[j] == galloc->buf_tallocs[i]) {
+                galloc->buffers[i] = galloc->buffers[j];
+                break;
+            }
+        }
+
         size_t cur_size = galloc->buffers[i] ? ggml_backend_buffer_get_size(galloc->buffers[i]) : 0;
         size_t new_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i]);
 
@@ -750,12 +790,14 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
 #ifndef NDEBUG
             fprintf(stderr, "%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
 #endif
+
             ggml_backend_buffer_free(galloc->buffers[i]);
             galloc->buffers[i] = ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size);
             if (galloc->buffers[i] == NULL) {
                 fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size);
                 return false;
             }
+            ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
         }
     }
 
@@ -766,7 +808,8 @@ bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) {
     return ggml_gallocr_reserve_n(galloc, graph, NULL, NULL);
 }
 
-static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor * tensor, int buffer_id, struct tensor_alloc * tensor_alloc) {
+static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor * tensor, struct tensor_alloc * tensor_alloc) {
+    int buffer_id = tensor_alloc->buffer_id;
     assert(tensor->data || tensor->view_src || ggml_backend_buffer_get_alloc_size(galloc->buffers[buffer_id], tensor) <= tensor_alloc->size_max);
 
     if (tensor->view_src != NULL) {
@@ -794,9 +837,8 @@ static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor *
     }
 }
 
-static bool ggml_gallocr_node_needs_realloc(ggml_gallocr_t galloc, struct ggml_tensor * node, struct node_alloc * nalloc, struct tensor_alloc * talloc) {
-    ggml_backend_buffer_type_t buft = galloc->bufts[nalloc->buffer_id];
-    size_t node_size = (node->data || node->view_src) ? 0 : ggml_backend_buft_get_alloc_size(buft, node);
+static bool ggml_gallocr_node_needs_realloc(ggml_gallocr_t galloc, struct ggml_tensor * node, struct tensor_alloc * talloc) {
+    size_t node_size = (node->data || node->view_src) ? 0 : ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node);
     return talloc->size_max >= node_size;
 }
 
@@ -819,7 +861,7 @@ static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph
         struct ggml_tensor * node = graph->nodes[i];
         struct node_alloc * node_alloc = &galloc->node_allocs[i];
 
-        if (!ggml_gallocr_node_needs_realloc(galloc, node, node_alloc, &node_alloc->dst)) {
+        if (!ggml_gallocr_node_needs_realloc(galloc, node, &node_alloc->dst)) {
 #ifndef NDEBUG
             fprintf(stderr, "%s: node %s is not valid\n", __func__, node->name);
 #endif
@@ -831,7 +873,7 @@ static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph
             if (src == NULL) {
                 continue;
             }
-            if (!ggml_gallocr_node_needs_realloc(galloc, src, node_alloc, &node_alloc->src[j])) {
+            if (!ggml_gallocr_node_needs_realloc(galloc, src, &node_alloc->src[j])) {
 #ifndef NDEBUG
                 fprintf(stderr, "%s: src %d (%s) of node %s is not valid\n", __func__, j, src->name, node->name);
 #endif
@@ -872,7 +914,7 @@ bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph)
     for (int i = 0; i < graph->n_leafs; i++) {
         struct ggml_tensor * leaf = graph->leafs[i];
         struct leaf_alloc * leaf_alloc = &galloc->leaf_allocs[i];
-        ggml_gallocr_init_tensor(galloc, leaf, leaf_alloc->buffer_id, &leaf_alloc->leaf);
+        ggml_gallocr_init_tensor(galloc, leaf, &leaf_alloc->leaf);
     }
     // nodes
     for (int i = 0; i < graph->n_nodes; i++) {
@@ -883,9 +925,9 @@ bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph)
             if (src == NULL) {
                 continue;
             }
-            ggml_gallocr_init_tensor(galloc, src, node_alloc->buffer_id, &node_alloc->src[j]);
+            ggml_gallocr_init_tensor(galloc, src, &node_alloc->src[j]);
         }
-        ggml_gallocr_init_tensor(galloc, node, node_alloc->buffer_id, &node_alloc->dst);
+        ggml_gallocr_init_tensor(galloc, node, &node_alloc->dst);
     }
 
     return true;
@@ -897,6 +939,15 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) {
     if (galloc->buffers[buffer_id] == NULL) {
         return 0;
     }
+
+    for (int i = 0; i < buffer_id; i++) {
+        if (galloc->buffers[i] == galloc->buffers[buffer_id]) {
+            // this buffer is the same as a previous one due to the same buffer type being used multiple times
+            // only return the buffer size the first time it appears to avoid double counting
+            return 0;
+        }
+    }
+
     return ggml_backend_buffer_get_size(galloc->buffers[buffer_id]);
 }
 
@@ -912,7 +963,7 @@ static bool alloc_tensor_range(struct ggml_context * ctx,
         fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(buft), size);
 #endif
         for (size_t i = 0; i < *n_buffers; i++) {
-            ggml_backend_buffer_free(*buffers[i]);
+            ggml_backend_buffer_free((*buffers)[i]);
         }
         free(*buffers);
         return false;

+ 1 - 1
llama/ggml-alloc.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 21 - 9
llama/ggml-backend-impl.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -43,13 +43,15 @@ extern "C" {
 
     struct ggml_backend_buffer_type_i {
         const char *          (*GGML_CALL get_name)        (ggml_backend_buffer_type_t buft);
+        // allocate a buffer of this type
         ggml_backend_buffer_t (*GGML_CALL alloc_buffer)    (ggml_backend_buffer_type_t buft, size_t size);
-        size_t                (*GGML_CALL get_alignment)   (ggml_backend_buffer_type_t buft); // tensor alignment
-        size_t                (*GGML_CALL get_max_size)    (ggml_backend_buffer_type_t buft); // allocation max size
-        size_t                (*GGML_CALL get_alloc_size)  (ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding
-        bool                  (*GGML_CALL supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend
+        // tensor alignment
+        size_t                (*GGML_CALL get_alignment)   (ggml_backend_buffer_type_t buft);
+        // max buffer size that can be allocated
+        size_t                (*GGML_CALL get_max_size)    (ggml_backend_buffer_type_t buft);
+        // data size needed to allocate the tensor, including padding
+        size_t                (*GGML_CALL get_alloc_size)  (ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);
         // check if tensor data is in host memory
-        // should be equivalent to supports_backend(buft, ggml_backend_cpu_init())
         bool                  (*GGML_CALL is_host)         (ggml_backend_buffer_type_t buft);
     };
 
@@ -118,27 +120,37 @@ extern "C" {
         void (*GGML_CALL synchronize)(ggml_backend_t backend);
 
         // compute graph with a plan (not used currently)
+        // create a new plan for a graph
         ggml_backend_graph_plan_t (*GGML_CALL graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph);
         void                      (*GGML_CALL graph_plan_free)   (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+        // update the plan with a new graph - this should be faster than creating a new plan when the graph has the same topology
+        void                      (*GGML_CALL graph_plan_update) (ggml_backend_t backend, ggml_backend_graph_plan_t plan, const struct ggml_cgraph * cgraph);
+        // compute the graph with the plan
+        enum ggml_status          (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
 
-        // compute graph with a plan
-        enum ggml_status (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
         // compute graph without a plan (async)
         enum ggml_status (*GGML_CALL graph_compute)     (ggml_backend_t backend, struct ggml_cgraph * cgraph);
 
-        // check if the backend supports an operation
+        // check if the backend can compute an operation
         bool (*GGML_CALL supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
 
+        // check if the backend can use tensors allocated in a buffer type
+        bool (*GGML_CALL supports_buft)(ggml_backend_t backend, ggml_backend_buffer_type_t buft);
+
         // check if the backend wants to run an operation, even if the weights are allocated in a CPU buffer
         // these should be expensive operations with large batch sizes that may benefit from running on this backend
         // even if the weight has to be copied from the CPU temporarily
         bool (*GGML_CALL offload_op)(ggml_backend_t backend, const struct ggml_tensor * op);
 
         // (optional) event synchronization
+        // create a new event that can record events on this backend instance
         ggml_backend_event_t (*GGML_CALL event_new)         (ggml_backend_t backend);
         void                 (*GGML_CALL event_free)        (ggml_backend_event_t event);
+        // record an event on the backend instance that created it
         void                 (*GGML_CALL event_record)      (ggml_backend_event_t event);
+        // wait for an event on on a different backend instance
         void                 (*GGML_CALL event_wait)        (ggml_backend_t backend, ggml_backend_event_t event);
+        // block until an event is recorded
         void                 (*GGML_CALL event_synchronize) (ggml_backend_event_t event);
     };
 

+ 303 - 163
llama/ggml-backend.c

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -70,10 +70,6 @@ GGML_CALL size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buf
     return ggml_nbytes(tensor);
 }
 
-bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
-    return buft->iface.supports_backend(buft, backend);
-}
-
 bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) {
     if (buft->iface.is_host) {
         return buft->iface.is_host(buft);
@@ -169,6 +165,10 @@ void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backe
     }
 }
 
+enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage(ggml_backend_buffer_t buffer) {
+    return buffer->usage;
+}
+
 ggml_backend_buffer_type_t ggml_backend_buffer_get_type(ggml_backend_buffer_t buffer) {
     return buffer->buft;
 }
@@ -317,6 +317,10 @@ bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor *
     return backend->iface.supports_op(backend, op);
 }
 
+bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+    return backend->iface.supports_buft(backend, buft);
+}
+
 bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op) {
     if (backend->iface.offload_op != NULL) {
         return backend->iface.offload_op(backend, op);
@@ -425,7 +429,7 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event)
 
 // backend registry
 
-#define GGML_REG_MAX_BACKENDS 16
+#define GGML_REG_MAX_BACKENDS 64
 
 struct ggml_backend_reg {
     char name[128];
@@ -476,6 +480,11 @@ GGML_CALL static void ggml_backend_registry_init(void) {
     extern GGML_CALL void ggml_backend_kompute_reg_devices(void);
     ggml_backend_kompute_reg_devices();
 #endif
+
+#ifdef GGML_USE_CANN
+    extern GGML_CALL int ggml_backend_cann_reg_devices(void);
+    ggml_backend_cann_reg_devices();
+#endif
 }
 
 GGML_CALL void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
@@ -670,12 +679,6 @@ GGML_CALL static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_
     GGML_UNUSED(buft);
 }
 
-GGML_CALL static bool ggml_backend_cpu_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
-    return ggml_backend_is_cpu(backend);
-
-    GGML_UNUSED(buft);
-}
-
 GGML_CALL static bool ggml_backend_cpu_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
     return true;
 
@@ -690,7 +693,6 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) {
             /* .get_alignment    = */ ggml_backend_cpu_buffer_type_get_alignment,
             /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
             /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
-            /* .supports_backend = */ ggml_backend_cpu_buffer_type_supports_backend,
             /* .is_host          = */ ggml_backend_cpu_buffer_type_is_host,
         },
         /* .context = */ NULL,
@@ -746,7 +748,6 @@ ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) {
             /* .get_alignment    = */ ggml_backend_cpu_buffer_type_get_alignment,
             /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
             /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
-            /* .supports_backend = */ ggml_backend_cpu_buffer_type_supports_backend,
             /* .is_host          = */ ggml_backend_cpu_buffer_type_is_host,
         },
         /* .context  = */ NULL,
@@ -867,6 +868,12 @@ GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const
     GGML_UNUSED(backend);
 }
 
+GGML_CALL static bool ggml_backend_cpu_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+    return ggml_backend_buft_is_host(buft);
+
+    GGML_UNUSED(backend);
+}
+
 static struct ggml_backend_i cpu_backend_i = {
     /* .get_name                = */ ggml_backend_cpu_name,
     /* .free                    = */ ggml_backend_cpu_free,
@@ -877,9 +884,11 @@ static struct ggml_backend_i cpu_backend_i = {
     /* .synchronize             = */ NULL,
     /* .graph_plan_create       = */ ggml_backend_cpu_graph_plan_create,
     /* .graph_plan_free         = */ ggml_backend_cpu_graph_plan_free,
+    /* .graph_plan_update       = */ NULL,
     /* .graph_plan_compute      = */ ggml_backend_cpu_graph_plan_compute,
     /* .graph_compute           = */ ggml_backend_cpu_graph_compute,
     /* .supports_op             = */ ggml_backend_cpu_supports_op,
+    /* .supports_buft           = */ ggml_backend_cpu_supports_buft,
     /* .offload_op              = */ NULL,
     /* .event_new               = */ NULL,
     /* .event_free              = */ NULL,
@@ -1077,17 +1086,19 @@ struct ggml_backend_sched {
     ggml_backend_buffer_type_t bufts[GGML_SCHED_MAX_BACKENDS];
     ggml_gallocr_t galloc;
 
-    // hash keys of the nodes in the graph
-    struct ggml_hash_set    hash_set;
-    // hash values
-    int * tensor_backend_id;
-    struct ggml_tensor * (* tensor_copies)[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPIES];
+    // hash map of the nodes in the graph
+    struct ggml_hash_set  hash_set;
+    int                 * hv_tensor_backend_ids; // [hash_set.size]
+    struct ggml_tensor ** hv_tensor_copies;      // [hash_set.size][n_backends][n_copies]
 
     int * node_backend_ids; // [graph_size]
     int * leaf_backend_ids; // [graph_size]
 
+    int * prev_node_backend_ids; // [graph_size]
+    int * prev_leaf_backend_ids; // [graph_size]
+
     // copy of the graph with modified inputs
-    struct ggml_cgraph * graph;
+    struct ggml_cgraph graph;
 
     // graph splits
     struct ggml_backend_sched_split * splits;
@@ -1106,17 +1117,16 @@ struct ggml_backend_sched {
     ggml_backend_sched_eval_callback callback_eval;
     void * callback_eval_user_data;
 
-    // align context_buffer to GGML_MEM_ALIGN
-#ifdef _MSC_VER
-    __declspec(align(GGML_MEM_ALIGN))
-#else
-    __attribute__((aligned(GGML_MEM_ALIGN)))
-#endif
-    char context_buffer[GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)];
+    char * context_buffer;
+    size_t context_buffer_size;
+
+    bool debug;
 };
 
-#define hash_id(tensor) ggml_hash_find_or_insert(sched->hash_set, tensor)
-#define tensor_backend_id(tensor) sched->tensor_backend_id[hash_id(tensor)]
+#define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor)
+#define tensor_backend_id(tensor) sched->hv_tensor_backend_ids[hash_id(tensor)]
+#define tensor_id_copy(id, backend_id, copy_id) sched->hv_tensor_copies[(id) * sched->n_backends * sched->n_copies + (backend_id) * sched->n_copies + (copy_id)]
+#define tensor_copy(tensor, backend_id, copy_id) tensor_id_copy(hash_id(tensor), backend_id, copy_id)
 
 // returns the priority of the backend, lower id is higher priority
 static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backend_t backend) {
@@ -1128,22 +1138,24 @@ static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backen
     return -1;
 }
 
-static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor) {
+static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor, const struct ggml_tensor * op) {
     ggml_backend_buffer_t buffer = tensor->buffer;
     if (buffer == NULL) {
         return -1;
     }
 
-    // find highest prio backend that supports the buffer type
+    // find highest prio backend that supports the buffer type and the op
     for (int i = 0; i < sched->n_backends; i++) {
-        if (ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i])) {
+        if (ggml_backend_supports_buft(sched->backends[i], buffer->buft) &&
+            ggml_backend_supports_op(sched->backends[i], op)) {
             return i;
         }
     }
 
-    fprintf(stderr, "%s: error: no backend supports buffer type %s used in tensor %s\n",
-        __func__, ggml_backend_buffer_name(buffer), tensor->name);
-    GGML_ASSERT(false);
+#ifndef NDEBUG
+    fprintf(stderr, "%s: warning: no backend supports op %s with a weight with buffer type %s used in tensor %s, the weight will need to be copied\n",
+        __func__, ggml_op_desc(tensor), ggml_backend_buffer_name(buffer), tensor->name);
+#endif
 
     return -1;
 }
@@ -1162,7 +1174,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
     // TODO: use supports_op to check if the backend supports the op
 
     // assign pre-allocated nodes to their backend
-    int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor);
+    int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor, tensor);
     if (cur_backend_id != -1) {
         SET_CAUSE(tensor, "1.dst");
         return cur_backend_id;
@@ -1170,7 +1182,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
 
     // view_src
     if (tensor->view_src != NULL) {
-        cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src);
+        cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src, tensor);
         if (cur_backend_id != -1) {
             SET_CAUSE(tensor, "1.vsrc");
             return cur_backend_id;
@@ -1184,7 +1196,6 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
         return cur_backend_id;
     }
 
-    // assign nodes that use weights to the backend of the weights
     // operations with weights are preferably run on the same backend as the weights
     for (int i = 0; i < GGML_MAX_SRC; i++) {
         const struct ggml_tensor * src = tensor->src[i];
@@ -1192,11 +1203,11 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
             continue;
         }
         if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
-            int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src);
+            int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
             // check if a backend with higher prio wants to offload the op
             if (src_backend_id == sched->n_backends - 1) {
                 for (int b = 0; b < src_backend_id; b++) {
-                    if (ggml_backend_offload_op(sched->backends[b], tensor)) {
+                    if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) {
                         SET_CAUSE(tensor, "1.off");
                         return b;
                     }
@@ -1254,10 +1265,33 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
     }
 }
 
-//#define DEBUG_PASS1
-//#define DEBUG_PASS2
-//#define DEBUG_PASS3
-//#define DEBUG_PASS4
+static bool ggml_backend_sched_buffer_supported(ggml_backend_sched_t sched, struct ggml_tensor * t, int backend_id) {
+    ggml_backend_buffer_t buf = t->view_src ? t->view_src->buffer : t->buffer;
+    ggml_backend_buffer_type_t buft = NULL;
+
+    if (buf) {
+        // the tensor is already allocated
+        buft = buf->buft;
+    } else {
+        // see if the tensor already has a backend assigned, and use the buffer type of that backend
+        int tensor_backend_id = tensor_backend_id(t);
+        if (tensor_backend_id == -1 && t->view_src) {
+            tensor_backend_id = tensor_backend_id(t->view_src);
+        }
+        if (tensor_backend_id != -1) {
+            buft = sched->bufts[tensor_backend_id];
+        }
+    }
+
+    return buft != NULL && ggml_backend_supports_buft(sched->backends[backend_id], buft);
+}
+
+static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, struct ggml_tensor * node, int cur_backend_id, int * node_backend_id) {
+    if (ggml_backend_supports_op(sched->backends[cur_backend_id], node)) {
+        *node_backend_id = cur_backend_id;
+        SET_CAUSE(node, "2.sup");
+    }
+}
 
 // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
 static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
@@ -1267,7 +1301,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
     sched->is_reset = false;
 
     struct ggml_init_params params = {
-        /* .mem_size =   */ sizeof(sched->context_buffer),
+        /* .mem_size =   */ sched->context_buffer_size,
         /* .mem_buffer = */ sched->context_buffer,
         /* .no_alloc =   */ true
     };
@@ -1276,52 +1310,52 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
 
     sched->ctx = ggml_init(params);
     if (sched->ctx == NULL) {
-        fprintf(stderr, "%s: failed to initialize context\n", __func__);
-        GGML_ASSERT(false);
+        GGML_ABORT("%s: failed to initialize context\n", __func__);
     }
 
     // pass 1: assign backends to ops with pre-allocated inputs
     for (int i = 0; i < graph->n_leafs; i++) {
         struct ggml_tensor * leaf = graph->leafs[i];
         int * leaf_backend_id = &tensor_backend_id(leaf);
-        if (*leaf_backend_id != -1) {
-            // do not overwrite user assignments
-            continue;
+        // do not overwrite user assignments
+        if (*leaf_backend_id == -1) {
+            *leaf_backend_id = ggml_backend_sched_backend_id_from_cur(sched, leaf);
         }
-        *leaf_backend_id = ggml_backend_sched_backend_id_from_cur(sched, leaf);
     }
 
     for (int i = 0; i < graph->n_nodes; i++) {
         struct ggml_tensor * node = graph->nodes[i];
         int * node_backend_id = &tensor_backend_id(node);
-        if (*node_backend_id != -1) {
-            // do not overwrite user assignments
-            continue;
-        }
-        *node_backend_id = ggml_backend_sched_backend_id_from_cur(sched, node);
-        // src
-        for (int j = 0; j < GGML_MAX_SRC; j++) {
-            struct ggml_tensor * src = node->src[j];
-            if (src == NULL) {
+        // do not overwrite user assignments
+        if (*node_backend_id == -1) {
+            *node_backend_id = ggml_backend_sched_backend_id_from_cur(sched, node);
+
+#if 0
+            // src
+            if (node->op == GGML_OP_NONE) {
                 continue;
             }
-            int * src_backend_id = &tensor_backend_id(src);
-            if (*src_backend_id == -1) {
-                *src_backend_id = ggml_backend_sched_backend_id_from_cur(sched, src);
+
+            for (int j = 0; j < GGML_MAX_SRC; j++) {
+                struct ggml_tensor * src = node->src[j];
+                if (src == NULL) {
+                    continue;
+                }
+                int * src_backend_id = &tensor_backend_id(src);
+                if (*src_backend_id == -1) {
+                    *src_backend_id = ggml_backend_sched_backend_id_from_cur(sched, src);
+                }
             }
+#endif
         }
     }
-#ifdef DEBUG_PASS1
-    fprintf(stderr, "PASS 1 ASSIGNMENTS\n"); ggml_backend_sched_print_assignments(sched, graph);
-#endif
 
     // pass 2: expand current backend assignments
     // assign the same backend to adjacent nodes
     // expand gpu backends (i.e. non last prio) up and down, ignoring cpu (the lowest priority backend)
     // thus, cpu will never be used unless weights are on cpu, or there are no gpu ops between cpu ops
-
-
-    // pass 2.2 expand gpu down
+    // ops unsupported by the backend being expanded will be left unassigned so that they can be assigned later when the locations of its inputs are known
+    // expand gpu down
     {
         int cur_backend_id = -1;
         for (int i = 0; i < graph->n_nodes; i++) {
@@ -1337,13 +1371,12 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                 } else {
                     cur_backend_id = *node_backend_id;
                 }
-            } else {
-                *node_backend_id = cur_backend_id;
-                SET_CAUSE(node, "2.2");
+            } else if (cur_backend_id != -1) {
+                ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
             }
         }
     }
-    // pass 2.1 expand gpu up
+    // expand gpu up
     {
         int cur_backend_id = -1;
         for (int i = graph->n_nodes - 1; i >= 0; i--) {
@@ -1359,13 +1392,12 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                 } else {
                     cur_backend_id = *node_backend_id;
                 }
-            } else {
-                *node_backend_id = cur_backend_id;
-                SET_CAUSE(node, "2.1");
+            } else if (cur_backend_id != -1) {
+                ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
             }
         }
     }
-    // pass 2.4 expand rest down
+    // expand rest down
     {
         int cur_backend_id = -1;
         for (int i = 0; i < graph->n_nodes; i++) {
@@ -1376,13 +1408,12 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
             int * node_backend_id = &tensor_backend_id(node);
             if (*node_backend_id != -1) {
                 cur_backend_id = *node_backend_id;
-            } else {
-                *node_backend_id = cur_backend_id;
-                SET_CAUSE(node, "2.4");
+            } else if (cur_backend_id != -1) {
+                ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
             }
         }
     }
-    // pass 2.3 expand rest up
+    // expand rest up
     {
         int cur_backend_id = -1;
         for (int i = graph->n_nodes - 1; i >= 0; i--) {
@@ -1393,24 +1424,80 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
             int * node_backend_id = &tensor_backend_id(node);
             if (*node_backend_id != -1) {
                 cur_backend_id = *node_backend_id;
-            } else {
-                *node_backend_id = cur_backend_id;
-                SET_CAUSE(node, "2.3");
+            } else if (cur_backend_id != -1) {
+                ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
             }
         }
     }
 
-#ifdef DEBUG_PASS2
-    fprintf(stderr, "PASS 2 ASSIGNMENTS\n"); ggml_backend_sched_print_assignments(sched, graph);
-#endif
+    // pass 3: upgrade nodes to higher prio backends with compatible buffer types
+    // if the tensor is already in the same buffer type (*) as another higher priority backend, we should move it there
+    // however, we also need to verify that the sources are in compatible buffer types
+    // (*) the actual requirement is more relaxed, the buffer type of the backend should be supported by all the users of this tensor further down the graph
+    // however, this is slow to verify, so we have a more strict requirement that the buffer type is the same
+    // this is not uncommon since multiple backends can use host memory, with the same buffer type (eg. BLAS and CPU)
+    // additionally, set remaining unassigned nodes to the backend with the most supported inputs
+    // only nodes that could not be assigned during expansion due to the backend not supporting the op should be unassigned at this point
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        if (ggml_is_view_op(node->op)) {
+            continue;
+        }
+        int * node_backend_id = &tensor_backend_id(node);
+        if (*node_backend_id == -1) {
+            // unassigned node: find the backend with the most supported inputs
+            int n_supported_best = -1;
+            for (int b = 0; b < sched->n_backends; b++) {
+                if (ggml_backend_supports_op(sched->backends[b], node)) {
+                    int n_supported = 0;
+                    for (int j = 0; j < GGML_MAX_SRC; j++) {
+                        struct ggml_tensor * src = node->src[j];
+                        if (src == NULL) {
+                            continue;
+                        }
+                        if ((tensor_backend_id(src) != -1 || tensor_backend_id(src->view_src) != -1) && ggml_backend_sched_buffer_supported(sched, src, b)) {
+                            n_supported++;
+                        }
+                    }
+                    if (n_supported > n_supported_best) {
+                        n_supported_best = n_supported;
+                        *node_backend_id = b;
+                        SET_CAUSE(node, "3.best");
+                    }
+                }
+            }
+        } else {
+            // assigned node: upgrade to higher prio backend if possible
+            for (int b = 0; b < *node_backend_id; b++) {
+                if (sched->bufts[b] == sched->bufts[*node_backend_id] && ggml_backend_supports_op(sched->backends[b], node)) {
+                    bool supported = true;
+                    for (int j = 0; j < GGML_MAX_SRC; j++) {
+                        struct ggml_tensor * src = node->src[j];
+                        if (src == NULL) {
+                            continue;
+                        }
+                        if (!ggml_backend_sched_buffer_supported(sched, src, b)) {
+                            supported = false;
+                            break;
+                        }
+                    }
+                    if (supported) {
+                        *node_backend_id = b;
+                        SET_CAUSE(node, "3.upg");
+                        break;
+                    }
+                }
+            }
+        }
+    }
 
-    // pass 3: assign backends to remaining src from dst and view_src
+    // pass 4: assign backends to remaining src from dst and view_src
     for (int i = 0; i < graph->n_nodes; i++) {
         struct ggml_tensor * node = graph->nodes[i];
         int * cur_backend_id = &tensor_backend_id(node);
         if (node->view_src != NULL && *cur_backend_id == -1) {
             *cur_backend_id = tensor_backend_id(node->view_src);
-            SET_CAUSE(node, "3.vsrc");
+            SET_CAUSE(node, "4.vsrc");
         }
         for (int j = 0; j < GGML_MAX_SRC; j++) {
             struct ggml_tensor * src = node->src[j];
@@ -1422,24 +1509,22 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                 if (src->view_src != NULL) {
                     // views are always on the same backend as the source
                     *src_backend_id = tensor_backend_id(src->view_src);
-                    SET_CAUSE(src, "3.vsrc");
+                    SET_CAUSE(src, "4.vsrc");
                 } else {
                     *src_backend_id = *cur_backend_id;
-                    SET_CAUSE(src, "3.cur");
+                    SET_CAUSE(src, "4.cur");
                 }
             }
         }
     }
-#ifdef DEBUG_PASS3
-    fprintf(stderr, "PASS 3 ASSIGNMENTS\n"); ggml_backend_sched_print_assignments(sched, graph);
-#endif
 
-    // pass 4: split graph, find tensors that need to be copied
+    // pass 5: split graph, find tensors that need to be copied
     {
         int i_split = 0;
         struct ggml_backend_sched_split * split = &sched->splits[0];
         // find the backend of the first split, skipping view ops
-        for (int i = 0; i < graph->n_nodes; i++) {
+        int i = 0;
+        for (; i < graph->n_nodes; i++) {
             struct ggml_tensor * node = graph->nodes[i];
             if (!ggml_is_view_op(node->op)) {
                 split->backend_id = tensor_backend_id(node);
@@ -1448,9 +1533,8 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
         }
         split->i_start = 0;
         split->n_inputs = 0;
-        memset(split->inputs, 0, sizeof(split->inputs)); //HACK
         int cur_backend_id = split->backend_id;
-        for (int i = 0; i < graph->n_nodes; i++) {
+        for (; i < graph->n_nodes; i++) {
             struct ggml_tensor * node = graph->nodes[i];
 
             if (ggml_is_view_op(node->op)) {
@@ -1459,7 +1543,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
 
             const int node_backend_id = tensor_backend_id(node);
 
-            GGML_ASSERT(node_backend_id != -1); // all nodes should be assigned by now
+            assert(node_backend_id != -1); // all nodes should be assigned by now
 
             // check if we should start a new split based on the sources of the current node
             bool need_new_split = false;
@@ -1473,16 +1557,18 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                     // by starting a new split, the memory of the previously offloaded weights can be reused
                     if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
                         int src_backend_id = tensor_backend_id(src);
-                        if (src_backend_id != -1 && src_backend_id != cur_backend_id) {
+                        if (src_backend_id != cur_backend_id) {
                             need_new_split = true;
                             break;
                         }
                     }
                     // check if the split has too many inputs
+                    // FIXME: count the number of inputs instead of only checking when full
                     if (split->n_inputs == GGML_SCHED_MAX_SPLIT_INPUTS) {
                         const size_t id = hash_id(src);
-                        int src_backend_id = sched->tensor_backend_id[id];
-                        if (src_backend_id != cur_backend_id && sched->tensor_copies[hash_id(src)][cur_backend_id][0] == NULL) {
+                        int src_backend_id = sched->hv_tensor_backend_ids[id];
+                        bool supported = ggml_backend_sched_buffer_supported(sched, src, cur_backend_id);
+                        if (src_backend_id != cur_backend_id && tensor_id_copy(id, cur_backend_id, 0) == NULL && !supported) {
                             //printf("starting new split because of too many inputs: node %s, input %s\n", node->name, src->name);
                             need_new_split = true;
                             break;
@@ -1514,12 +1600,12 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                     continue;
                 }
 
-                const int src_backend_id = tensor_backend_id(src);
+                size_t src_id = hash_id(src);
+                const int src_backend_id = sched->hv_tensor_backend_ids[src_id];
                 assert(src_backend_id != -1); // all inputs should be assigned by now
 
-                if (src->flags & GGML_TENSOR_FLAG_INPUT && sched->n_copies > 1)  {
-                    size_t id = hash_id(src);
-                    if (sched->tensor_copies[id][src_backend_id][0] == NULL) {
+                if (src->flags & GGML_TENSOR_FLAG_INPUT && sched->n_copies > 1) {
+                    if (tensor_id_copy(src_id, src_backend_id, 0) == NULL) {
                         ggml_backend_t backend = sched->backends[src_backend_id];
                         for (int c = 0; c < sched->n_copies; c++) {
                             struct ggml_tensor * tensor_copy;
@@ -1533,7 +1619,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                                 ggml_set_input(tensor_copy);
                                 ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor
                             }
-                            sched->tensor_copies[id][src_backend_id][c] = tensor_copy;
+                            tensor_id_copy(src_id, src_backend_id, c) = tensor_copy;
                             SET_CAUSE(tensor_copy, "4.cpy");
                         }
                         int n_graph_inputs = sched->n_graph_inputs++;
@@ -1542,10 +1628,9 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                     }
                 }
 
-                if (src_backend_id != node_backend_id) {
+                if (src_backend_id != cur_backend_id && !ggml_backend_sched_buffer_supported(sched, src, cur_backend_id)) {
                     // create a copy of the input in the split's backend
-                    const size_t id = hash_id(src);
-                    if (sched->tensor_copies[id][cur_backend_id][0] == NULL) {
+                    if (tensor_id_copy(src_id, cur_backend_id, 0) == NULL) {
                         ggml_backend_t backend = sched->backends[cur_backend_id];
                         for (int c = 0; c < sched->n_copies; c++) {
                             struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
@@ -1554,27 +1639,49 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                                 ggml_set_input(tensor_copy);
                                 ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor
                             }
-                            sched->tensor_copies[id][cur_backend_id][c] = tensor_copy;
+                            tensor_id_copy(src_id, cur_backend_id, c) = tensor_copy;
                             SET_CAUSE(tensor_copy, "4.cpy");
                         }
                         int n_inputs = split->n_inputs++;
                         GGML_ASSERT(n_inputs < GGML_SCHED_MAX_SPLIT_INPUTS);
                         split->inputs[n_inputs] = src;
                     }
-                    node->src[j] = sched->tensor_copies[id][cur_backend_id][sched->cur_copy];
+                    node->src[j] = tensor_id_copy(src_id, cur_backend_id, sched->cur_copy);
                 }
             }
         }
         split->i_end = graph->n_nodes;
         sched->n_splits = i_split + 1;
     }
-#ifdef DEBUG_PASS4
-    fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); ggml_backend_sched_print_assignments(sched, graph);
-#endif
 
-    // create copies of the graph for each split
-    // TODO: avoid this copy
-    struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2, false);
+    if (sched->debug) {
+        ggml_backend_sched_print_assignments(sched, graph);
+    }
+
+    // swap node_backend_ids and leaf _backend_ids with prevs
+    {
+        int * tmp = sched->node_backend_ids;
+        sched->node_backend_ids = sched->prev_node_backend_ids;
+        sched->prev_node_backend_ids = tmp;
+
+        tmp = sched->leaf_backend_ids;
+        sched->leaf_backend_ids = sched->prev_leaf_backend_ids;
+        sched->prev_leaf_backend_ids = tmp;
+    }
+
+    int graph_size = graph->n_nodes + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2;
+    if (sched->graph.size < graph_size) {
+        sched->graph.size = graph_size;
+        sched->graph.nodes = realloc(sched->graph.nodes, graph_size * sizeof(struct ggml_tensor *));
+        sched->graph.leafs = realloc(sched->graph.leafs, graph_size * sizeof(struct ggml_tensor *));
+        GGML_ASSERT(sched->graph.nodes != NULL);
+        GGML_ASSERT(sched->graph.leafs != NULL);
+    }
+    sched->graph.n_nodes = 0;
+    sched->graph.n_leafs = 0;
+
+    struct ggml_cgraph * graph_copy = &sched->graph;
+
     for (int i = 0; i < sched->n_splits; i++) {
         struct ggml_backend_sched_split * split = &sched->splits[i];
         split->graph = ggml_graph_view(graph, split->i_start, split->i_end);
@@ -1585,12 +1692,12 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
 
             struct ggml_tensor * input = split->inputs[j];
             const size_t input_id = hash_id(input);
-            struct ggml_tensor * input_cpy = sched->tensor_copies[input_id][split->backend_id][sched->cur_copy];
+            struct ggml_tensor * input_cpy = tensor_id_copy(input_id, split->backend_id, sched->cur_copy);
 
             // add a dependency to the input source so that it is not freed before the copy is done
             struct ggml_tensor * input_dep = ggml_view_tensor(sched->ctx, input);
             input_dep->src[0] = input;
-            sched->node_backend_ids[graph_copy->n_nodes] = sched->tensor_backend_id[input_id];
+            sched->node_backend_ids[graph_copy->n_nodes] = sched->hv_tensor_backend_ids[input_id];
             graph_copy->nodes[graph_copy->n_nodes++] = input_dep;
 
             // add a dependency to the input copy so that it is allocated at the start of the split
@@ -1612,7 +1719,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
             size_t id = hash_id(input);
             int backend_id = tensor_backend_id(input);
             for (int c = 0; c < sched->n_copies; c++) {
-                struct ggml_tensor * input_cpy = sched->tensor_copies[id][backend_id][c];
+                struct ggml_tensor * input_cpy = tensor_id_copy(id, backend_id, c);
                 sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id;
                 graph_copy->leafs[graph_copy->n_leafs++] = input_cpy;
             }
@@ -1625,7 +1732,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                 struct ggml_tensor * input = split->inputs[j];
                 size_t id = hash_id(input);
                 for (int c = 0; c < sched->n_copies; c++) {
-                    struct ggml_tensor * input_cpy = sched->tensor_copies[id][backend_id][c];
+                    struct ggml_tensor * input_cpy = tensor_id_copy(id, backend_id, c);
                     sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id;
                     graph_copy->leafs[graph_copy->n_leafs++] = input_cpy;
                 }
@@ -1639,20 +1746,36 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
         sched->leaf_backend_ids[graph_copy->n_leafs] = tensor_backend_id(leaf);
         graph_copy->leafs[graph_copy->n_leafs++] = leaf;
     }
-
-    sched->graph = graph_copy;
 }
 
 static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
+    bool backend_ids_changed = false;
+    for (int i = 0; i < sched->graph.n_nodes; i++) {
+        if (sched->node_backend_ids[i] != sched->prev_node_backend_ids[i] &&
+            sched->bufts[sched->node_backend_ids[i]] != sched->bufts[sched->prev_node_backend_ids[i]]) {
+            backend_ids_changed = true;
+            break;
+        }
+    }
+    if (!backend_ids_changed) {
+        for (int i = 0; i < sched->graph.n_leafs; i++) {
+            if (sched->leaf_backend_ids[i] != sched->prev_leaf_backend_ids[i] &&
+                sched->bufts[sched->leaf_backend_ids[i]] != sched->bufts[sched->prev_leaf_backend_ids[i]]) {
+                backend_ids_changed = true;
+                break;
+            }
+        }
+    }
+
     // allocate graph
-    if (!ggml_gallocr_alloc_graph(sched->galloc, sched->graph)) {
+    if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
         // the re-allocation may cause the split inputs to be moved to a different address
         ggml_backend_sched_synchronize(sched);
 #ifndef NDEBUG
-        fprintf(stderr, "%s: failed to allocate graph, reserving\n", __func__);
+        fprintf(stderr, "%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
 #endif
-        ggml_gallocr_reserve_n(sched->galloc, sched->graph, sched->node_backend_ids, sched->leaf_backend_ids);
-        if (!ggml_gallocr_alloc_graph(sched->galloc, sched->graph)) {
+        ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids);
+        if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
             fprintf(stderr, "%s: failed to allocate graph\n", __func__);
             return false;
         }
@@ -1673,7 +1796,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
         for (int j = 0; j < split->n_inputs; j++) {
             ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]);
             struct ggml_tensor * input = split->inputs[j];
-            struct ggml_tensor * input_cpy = sched->tensor_copies[hash_id(input)][split_backend_id][sched->cur_copy];
+            struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy);
 
             if (input->flags & GGML_TENSOR_FLAG_INPUT) {
                 // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
@@ -1758,18 +1881,24 @@ ggml_backend_sched_t ggml_backend_sched_new(
 
     struct ggml_backend_sched * sched = calloc(1, sizeof(struct ggml_backend_sched));
 
+    sched->debug = getenv("GGML_SCHED_DEBUG") != NULL;
+    sched->n_backends = n_backends;
+    sched->n_copies = parallel ? GGML_SCHED_MAX_COPIES : 1;
+
     // initialize hash table
-    sched->hash_set          = ggml_hash_set_new(graph_size);
-    sched->tensor_backend_id = calloc(sched->hash_set.size, sizeof(sched->tensor_backend_id[0]));
-    sched->tensor_copies     = calloc(sched->hash_set.size, sizeof(sched->tensor_copies[0]));
+    // FIXME: needs to be size*2 to account for leafs (do it in graph_split instead)
+    sched->hash_set    = ggml_hash_set_new(graph_size);
+    sched->hv_tensor_backend_ids = malloc(sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0]));
+    sched->hv_tensor_copies      = malloc(sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *));
 
     const size_t nodes_size = graph_size + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2;
-    sched->node_backend_ids  = calloc(nodes_size, sizeof(sched->node_backend_ids[0]));
-    sched->leaf_backend_ids  = calloc(nodes_size, sizeof(sched->leaf_backend_ids[0]));
+    sched->node_backend_ids = calloc(nodes_size, sizeof(sched->node_backend_ids[0]));
+    sched->leaf_backend_ids = calloc(nodes_size, sizeof(sched->leaf_backend_ids[0]));
+    sched->prev_node_backend_ids = calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0]));
+    sched->prev_leaf_backend_ids = calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0]));
 
-    sched->n_backends = n_backends;
-
-    sched->n_copies = parallel ? GGML_SCHED_MAX_COPIES : 1;
+    sched->context_buffer_size = GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false);
+    sched->context_buffer = malloc(sched->context_buffer_size);
 
     const int initial_splits_capacity = 16;
     sched->splits = calloc(initial_splits_capacity, sizeof(sched->splits[0]));
@@ -1778,7 +1907,7 @@ ggml_backend_sched_t ggml_backend_sched_new(
     for (int b = 0; b < n_backends; b++) {
         sched->backends[b] = backends[b];
         sched->bufts[b] = bufts ? bufts[b] : ggml_backend_get_default_buffer_type(backends[b]);
-        GGML_ASSERT(ggml_backend_buft_supports_backend(sched->bufts[b], backends[b]));
+        GGML_ASSERT(ggml_backend_supports_buft(backends[b], sched->bufts[b]));
         if (sched->n_copies > 1) {
             for (int c = 0; c < sched->n_copies; c++) {
                 sched->events[b][c] = ggml_backend_event_new(backends[b]);
@@ -1804,35 +1933,37 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
     }
     ggml_gallocr_free(sched->galloc);
     ggml_free(sched->ctx);
+    ggml_hash_set_free(&sched->hash_set);
     free(sched->splits);
-    free(sched->hash_set.keys);
-    free(sched->tensor_backend_id);
-    free(sched->tensor_copies);
+    free(sched->hv_tensor_backend_ids);
+    free(sched->hv_tensor_copies);
     free(sched->node_backend_ids);
     free(sched->leaf_backend_ids);
+    free(sched->prev_node_backend_ids);
+    free(sched->prev_leaf_backend_ids);
+    free(sched->context_buffer);
+    free(sched->graph.nodes);
+    free(sched->graph.leafs);
     free(sched);
 }
 
 void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
     // reset state for the next run
     if (!sched->is_reset) {
-        size_t hash_size = sched->hash_set.size;
-        memset(sched->hash_set.keys,      0, sizeof(sched->hash_set.keys[0])     * hash_size); // NOLINT
-        memset(sched->tensor_backend_id, -1, sizeof(sched->tensor_backend_id[0]) * hash_size);
-        memset(sched->tensor_copies,      0, sizeof(sched->tensor_copies[0])     * hash_size);
-
+        ggml_hash_set_reset(&sched->hash_set);
+        memset(sched->hv_tensor_backend_ids, -1, sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0]));
+        memset(sched->hv_tensor_copies,       0, sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *));
         sched->is_reset = true;
     }
     sched->is_alloc = false;
 }
 
 bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
-    GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes);
+    GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
 
     ggml_backend_sched_split_graph(sched, measure_graph);
 
-    // TODO: extract this to a separate function
-    if (!ggml_gallocr_reserve_n(sched->galloc, sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) {
+    if (!ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) {
         return false;
     }
 
@@ -1843,10 +1974,11 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
 }
 
 bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
-    GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes);
+    GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs);
 
     ggml_backend_sched_split_graph(sched, graph);
 
+
     if (!ggml_backend_sched_alloc_splits(sched)) {
         return false;
     }
@@ -1895,6 +2027,15 @@ int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched) {
     return sched->n_copies;
 }
 
+int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched) {
+    return sched->n_backends;
+}
+
+ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i) {
+    GGML_ASSERT(i >= 0 && i < sched->n_backends);
+    return sched->backends[i];
+}
+
 size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) {
     int backend_index = ggml_backend_sched_backend_id(sched, backend);
     GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
@@ -1906,6 +2047,8 @@ void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct gg
     int backend_index = ggml_backend_sched_backend_id(sched, backend);
     GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
     tensor_backend_id(node) = backend_index;
+    SET_CAUSE(node, "usr");
+    sched->is_reset = false;
 }
 
 ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node) {
@@ -1948,9 +2091,9 @@ static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set,
     GGML_ASSERT(src != NULL);
     GGML_ASSERT(src->data && "graph must be allocated");
 
-    size_t id = ggml_hash_insert(hash_set, src);
-    if (id == GGML_HASHTABLE_ALREADY_EXISTS) {
-        return node_copies[ggml_hash_find(hash_set, src)];
+    size_t id = ggml_hash_insert(&hash_set, src);
+    if (id == GGML_HASHSET_ALREADY_EXISTS) {
+        return node_copies[ggml_hash_find(&hash_set, src)];
     }
 
     struct ggml_tensor * dst = ggml_dup_tensor_layout(src->data && !src->view_src ? ctx_allocated : ctx_unallocated, src);
@@ -1975,7 +2118,7 @@ static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set,
     return dst;
 }
 
-static void graph_copy_init_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies, bool * node_init, struct ggml_tensor * src) {
+static void graph_copy_init_tensor(struct ggml_hash_set * hash_set, struct ggml_tensor ** node_copies, bool * node_init, struct ggml_tensor * src) {
     size_t id = ggml_hash_find(hash_set, src);
     if (node_init[id]) {
         return;
@@ -2002,10 +2145,7 @@ static void graph_copy_init_tensor(struct ggml_hash_set hash_set, struct ggml_te
 }
 
 struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) {
-    struct ggml_hash_set hash_set = {
-        /* .size = */ graph->visited_hash_table.size,
-        /* .keys = */ calloc(graph->visited_hash_table.size, sizeof(hash_set.keys[0])) // NOLINT
-    };
+    struct ggml_hash_set hash_set = ggml_hash_set_new(graph->visited_hash_set.size);
     struct ggml_tensor ** node_copies = calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT
     bool * node_init = calloc(hash_set.size, sizeof(node_init[0]));
 
@@ -2020,7 +2160,7 @@ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, s
 
     if (ctx_allocated == NULL || ctx_unallocated == NULL) {
         fprintf(stderr, "failed to allocate context for graph copy\n");
-        free(hash_set.keys);
+        ggml_hash_set_free(&hash_set);
         free(node_copies);
         free(node_init);
         ggml_free(ctx_allocated);
@@ -2043,7 +2183,7 @@ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, s
     ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx_allocated, backend);
     if (buffer == NULL) {
         fprintf(stderr, "failed to allocate buffer for graph copy\n");
-        free(hash_set.keys);
+        ggml_hash_set_free(&hash_set);
         free(node_copies);
         free(node_init);
         ggml_free(ctx_allocated);
@@ -2061,19 +2201,19 @@ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, s
     // copy data and init views
     for (int i = 0; i < graph->n_nodes; i++) {
         struct ggml_tensor * node = graph->nodes[i];
-        graph_copy_init_tensor(hash_set, node_copies, node_init, node);
+        graph_copy_init_tensor(&hash_set, node_copies, node_init, node);
     }
 
     // build graph copy
     struct ggml_cgraph * graph_copy = ggml_new_graph_custom(ctx_allocated, graph->size, false);
     for (int i = 0; i < graph->n_nodes; i++) {
         struct ggml_tensor * node = graph->nodes[i];
-        struct ggml_tensor * node_copy = node_copies[ggml_hash_find(hash_set, node)];
+        struct ggml_tensor * node_copy = node_copies[ggml_hash_find(&hash_set, node)];
         graph_copy->nodes[i] = node_copy;
     }
     graph_copy->n_nodes = graph->n_nodes;
 
-    free(hash_set.keys);
+    ggml_hash_set_free(&hash_set);
     free(node_copies);
     free(node_init);
 

+ 22 - 17
llama/ggml-backend.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -49,28 +49,29 @@ extern "C" {
     GGML_API           size_t                ggml_backend_buft_get_alignment   (ggml_backend_buffer_type_t buft);
     GGML_API           size_t                ggml_backend_buft_get_max_size    (ggml_backend_buffer_type_t buft);
     GGML_API GGML_CALL size_t                ggml_backend_buft_get_alloc_size  (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
-    GGML_API           bool                  ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend);
     GGML_API           bool                  ggml_backend_buft_is_host         (ggml_backend_buffer_type_t buft);
 
     // buffer
     enum ggml_backend_buffer_usage {
         GGML_BACKEND_BUFFER_USAGE_ANY = 0,
         GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1,
+        GGML_BACKEND_BUFFER_USAGE_COMPUTE = 2,
     };
 
-    GGML_API           const char *               ggml_backend_buffer_name          (ggml_backend_buffer_t buffer);
-    GGML_API           void                       ggml_backend_buffer_free          (ggml_backend_buffer_t buffer);
-    GGML_API           void *                     ggml_backend_buffer_get_base      (ggml_backend_buffer_t buffer);
-    GGML_API           size_t                     ggml_backend_buffer_get_size      (ggml_backend_buffer_t buffer);
-    GGML_API GGML_CALL void                       ggml_backend_buffer_init_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
-    GGML_API           size_t                     ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
-    GGML_API           size_t                     ggml_backend_buffer_get_max_size  (ggml_backend_buffer_t buffer);
-    GGML_API           size_t                     ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
-    GGML_API           void                       ggml_backend_buffer_clear         (ggml_backend_buffer_t buffer, uint8_t value);
-    GGML_API           bool                       ggml_backend_buffer_is_host       (ggml_backend_buffer_t buffer);
-    GGML_API           void                       ggml_backend_buffer_set_usage     (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
-    GGML_API           ggml_backend_buffer_type_t ggml_backend_buffer_get_type      (ggml_backend_buffer_t buffer);
-    GGML_API           void                       ggml_backend_buffer_reset         (ggml_backend_buffer_t buffer);
+    GGML_API           const char *                   ggml_backend_buffer_name          (ggml_backend_buffer_t buffer);
+    GGML_API           void                           ggml_backend_buffer_free          (ggml_backend_buffer_t buffer);
+    GGML_API           void *                         ggml_backend_buffer_get_base      (ggml_backend_buffer_t buffer);
+    GGML_API           size_t                         ggml_backend_buffer_get_size      (ggml_backend_buffer_t buffer);
+    GGML_API GGML_CALL void                           ggml_backend_buffer_init_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API           size_t                         ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
+    GGML_API           size_t                         ggml_backend_buffer_get_max_size  (ggml_backend_buffer_t buffer);
+    GGML_API           size_t                         ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API           void                           ggml_backend_buffer_clear         (ggml_backend_buffer_t buffer, uint8_t value);
+    GGML_API           bool                           ggml_backend_buffer_is_host       (ggml_backend_buffer_t buffer);
+    GGML_API           void                           ggml_backend_buffer_set_usage     (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
+    GGML_API           enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage     (ggml_backend_buffer_t buffer);
+    GGML_API           ggml_backend_buffer_type_t     ggml_backend_buffer_get_type      (ggml_backend_buffer_t buffer);
+    GGML_API           void                           ggml_backend_buffer_reset         (ggml_backend_buffer_t buffer);
 
     //
     // Backend
@@ -100,6 +101,7 @@ extern "C" {
     GGML_API enum ggml_status ggml_backend_graph_compute      (ggml_backend_t backend, struct ggml_cgraph * cgraph);
     GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph);
     GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);
+    GGML_API bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft);
     GGML_API bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op);
 
     // tensor copy between different backends
@@ -116,7 +118,7 @@ extern "C" {
     GGML_API void                   ggml_backend_event_free       (ggml_backend_event_t event);
     GGML_API void                   ggml_backend_event_record     (ggml_backend_event_t event);
     GGML_API void                   ggml_backend_event_synchronize(ggml_backend_event_t event);
-    GGML_API void                   ggml_backend_event_wait       (ggml_backend_t backend, ggml_backend_event_t event); // wait async on event
+    GGML_API void                   ggml_backend_event_wait       (ggml_backend_t backend, ggml_backend_event_t event);
 
     //
     // CPU backend
@@ -145,7 +147,7 @@ extern "C" {
 
     GGML_API size_t                     ggml_backend_reg_get_count(void);
     GGML_API size_t                     ggml_backend_reg_find_by_name(const char * name);
-    GGML_API ggml_backend_t             ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is name[:params]
+    GGML_API ggml_backend_t             ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is backend_name:params (params is optional)
     GGML_API const char *               ggml_backend_reg_get_name(size_t i);
     GGML_API ggml_backend_t             ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
     GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i);
@@ -208,6 +210,9 @@ extern "C" {
     // Initialize backend buffers from a measure graph
     GGML_API bool                 ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
 
+    GGML_API int                  ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched);
+    GGML_API ggml_backend_t       ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i);
+
     // Get the number of splits of the last graph
     GGML_API int                  ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
     GGML_API int                  ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched);

+ 37 - 9
llama/ggml-common.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -45,7 +45,11 @@ typedef half2 ggml_half2;
 
 #define GGML_COMMON_DECL
 #elif defined(GGML_COMMON_DECL_CUDA)
+#if defined(GGML_COMMON_DECL_MUSA)
+#include <musa_fp16.h>
+#else
 #include <cuda_fp16.h>
+#endif
 #include <cstdint>
 
 typedef half  ggml_half;
@@ -132,19 +136,19 @@ typedef sycl::half2 ggml_half2;
 #define QR6_K 2
 
 #define QI2_XXS (QK_K / (4*QR2_XXS))
-#define QR2_XXS 8
+#define QR2_XXS 4
 
 #define QI2_XS (QK_K / (4*QR2_XS))
-#define QR2_XS 8
+#define QR2_XS 4
 
 #define QI2_S (QK_K / (4*QR2_S))
-#define QR2_S 8
+#define QR2_S 4
 
 #define QI3_XXS (QK_K / (4*QR3_XXS))
-#define QR3_XXS 8
+#define QR3_XXS 4
 
 #define QI3_XS (QK_K / (4*QR3_XS))
-#define QR3_XS 8
+#define QR3_XS 4
 
 #define QI1_S (QK_K / (4*QR1_S))
 #define QR1_S 8
@@ -156,10 +160,10 @@ typedef sycl::half2 ggml_half2;
 #define QR4_NL 2
 
 #define QI4_XS (QK_K / (4*QR4_XS))
-#define QR4_XS 8
+#define QR4_XS 2
 
 #define QI3_S (QK_K / (4*QR3_S))
-#define QR3_S 8
+#define QR3_S 4
 
 #endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP
 
@@ -225,6 +229,30 @@ typedef struct {
 } block_q8_1;
 static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_half) + QK8_1, "wrong q8_1 block size/padding");
 
+typedef struct {
+    ggml_half d[4];        // deltas for 4 q4_0 blocks
+    uint8_t qs[QK4_0 * 2]; // nibbles / quants for 4 q4_0 blocks
+} block_q4_0x4;
+static_assert(sizeof(block_q4_0x4) == 4 * sizeof(ggml_half) + QK4_0 * 2, "wrong q4_0x4 block size/padding");
+
+typedef struct {
+    ggml_half d[8];        // deltas for 8 q4_0 blocks
+    uint8_t qs[QK4_0 * 4]; // nibbles / quants for 8 q4_0 blocks
+} block_q4_0x8;
+static_assert(sizeof(block_q4_0x8) == 8 * sizeof(ggml_half) + QK4_0 * 4, "wrong q4_0x8 block size/padding");
+
+typedef struct {
+    ggml_half d[4];        // deltas for 4 q8_0 blocks
+    int8_t qs[QK8_0 * 4];  // quants for 4 q8_0 blocks
+} block_q8_0x4;
+static_assert(sizeof(block_q8_0x4) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong q8_0x4 block size/padding");
+
+typedef struct {
+    ggml_half d[8];        // deltas for 8 q8_0 blocks
+    int8_t qs[QK8_0 * 8];  // quants for 8 q8_0 blocks
+} block_q8_0x8;
+static_assert(sizeof(block_q8_0x8) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong q8_0x8 block size/padding");
+
 //
 // Super-block quantization structures
 //
@@ -417,7 +445,7 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_
 #define GGML_TABLE_END() };
 
 #define GGML_COMMON_IMPL
-#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP)
+#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) || defined(GGML_COMMON_IMPL_MUSA)
 #include <cstdint>
 
 #define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = {

+ 195 - 151
llama/ggml-cuda.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -55,6 +55,7 @@
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
+#include "ggml-cuda/conv-transpose-1d.cuh"
 
 #include <algorithm>
 #include <array>
@@ -123,7 +124,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
     GGML_CUDA_LOG_ERROR("  current device: %d, in function %s at %s:%d\n", id, func, file, line);
     GGML_CUDA_LOG_ERROR("  %s\n", stmt);
     // abort with GGML_ASSERT to get a stack trace
-    GGML_ASSERT(!"CUDA error");
+    GGML_ABORT("CUDA error");
 }
 
 // this is faster on Windows
@@ -178,21 +179,21 @@ static ggml_cuda_device_info ggml_cuda_init() {
     GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
 
     int64_t total_vram = 0;
-#if defined(GGML_CUDA_FORCE_MMQ)
-    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:   yes\n", __func__);
+#ifdef GGML_CUDA_FORCE_MMQ
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:    yes\n", __func__);
 #else
-    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:   no\n", __func__);
-#endif
-#if defined(CUDA_USE_TENSOR_CORES)
-    GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: yes\n", __func__);
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:    no\n", __func__);
+#endif // GGML_CUDA_FORCE_MMQ
+#ifdef GGML_CUDA_FORCE_CUBLAS
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__);
 #else
-    GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: no\n", __func__);
-#endif
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
+#endif // GGML_CUDA_FORCE_CUBLAS
     GGML_CUDA_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
     for (int id = 0; id < info.device_count; ++id) {
         int device_vmm = 0;
 
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
         CUdevice device;
         CU_CHECK(cuDeviceGet(&device, id));
         CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
@@ -204,7 +205,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
             alloc_prop.location.id = id;
             CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
         }
-#endif // !defined(GGML_USE_HIPBLAS)
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
         info.devices[id].vmm = !!device_vmm;
 
         cudaDeviceProp prop;
@@ -214,13 +215,15 @@ static ggml_cuda_device_info ggml_cuda_init() {
         info.default_tensor_split[id] = total_vram;
         total_vram += prop.totalGlobalMem;
 
+        info.devices[id].nsm   = prop.multiProcessorCount;
+        info.devices[id].smpb  = prop.sharedMemPerBlock;
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+        info.devices[id].smpbo = prop.sharedMemPerBlock;
         info.devices[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
 #else
+        info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
         info.devices[id].cc = 100*prop.major + 10*prop.minor;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-        info.devices[id].smpb = prop.sharedMemPerBlock;
-        info.devices[id].nsm  = prop.multiProcessorCount;
     }
 
     for (int id = 0; id < info.device_count; ++id) {
@@ -338,7 +341,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
 };
 
 // pool with virtual memory
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
 struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
     static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
 
@@ -432,14 +435,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
         GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
     }
 };
-#endif // !defined(GGML_USE_HIPBLAS)
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
 
 std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
     if (ggml_cuda_info().devices[device].vmm) {
         return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
     }
-#endif
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
     return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
 }
 
@@ -491,12 +494,12 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t
         return;
     }
 
-    if (ggml_is_quantized(tensor->type)) {
+    if (ggml_is_quantized(tensor->type) && tensor->view_src == nullptr && ggml_backend_buffer_get_usage(buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
         // initialize padding to 0 to avoid possible NaN values
         size_t original_size = ggml_nbytes(tensor);
         size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
 
-        if (padded_size > original_size && tensor->view_src == nullptr) {
+        if (padded_size > original_size) {
             ggml_cuda_set_device(ctx->device);
             CUDA_CHECK(cudaMemset((char *)tensor->data + original_size, 0, padded_size - original_size));
         }
@@ -573,6 +576,10 @@ GGML_CALL static const char * ggml_backend_cuda_buffer_type_name(ggml_backend_bu
     return ctx->name.c_str();
 }
 
+static bool ggml_backend_buft_is_cuda(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_cuda_buffer_type_name;
+}
+
 GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
     ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
 
@@ -615,24 +622,12 @@ GGML_CALL static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backen
     GGML_UNUSED(buft);
 }
 
-GGML_CALL static bool ggml_backend_cuda_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
-    if (!ggml_backend_is_cuda(backend)) {
-        return false;
-    }
-
-    ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
-    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
-
-    return buft_ctx->device == cuda_ctx->device;
-}
-
 static ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {
     /* .get_name         = */ ggml_backend_cuda_buffer_type_name,
     /* .alloc_buffer     = */ ggml_backend_cuda_buffer_type_alloc_buffer,
     /* .get_alignment    = */ ggml_backend_cuda_buffer_type_get_alignment,
     /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
     /* .get_alloc_size   = */ ggml_backend_cuda_buffer_type_get_alloc_size,
-    /* .supports_backend = */ ggml_backend_cuda_buffer_type_supports_backend,
     /* .is_host          = */ NULL,
 };
 
@@ -671,7 +666,7 @@ static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> &
         }
 
         const int cc = ggml_cuda_info().devices[id].cc;
-        row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc, get_mmq_x_max_host(cc)));
+        row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc));
     }
     return row_rounding;
 }
@@ -893,6 +888,10 @@ GGML_CALL static const char * ggml_backend_cuda_split_buffer_type_name(ggml_back
     GGML_UNUSED(buft);
 }
 
+static bool ggml_backend_buft_is_cuda_split(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_cuda_split_buffer_type_name;
+}
+
 GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
     // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point
     // instead, we allocate them for each tensor separately in init_tensor
@@ -936,12 +935,6 @@ GGML_CALL static size_t ggml_backend_cuda_split_buffer_type_get_alloc_size(ggml_
     return total_size;
 }
 
-GGML_CALL static bool ggml_backend_cuda_split_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
-    return ggml_backend_is_cuda(backend);
-
-    GGML_UNUSED(buft);
-}
-
 GGML_CALL static bool ggml_backend_cuda_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
     return false;
 
@@ -954,7 +947,6 @@ static ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_interface
     /* .get_alignment    = */ ggml_backend_cuda_split_buffer_type_get_alignment,
     /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
     /* .get_alloc_size   = */ ggml_backend_cuda_split_buffer_type_get_alloc_size,
-    /* .supports_backend = */ ggml_backend_cuda_split_buffer_type_supports_backend,
     /* .is_host          = */ ggml_backend_cuda_split_buffer_type_is_host,
 };
 
@@ -1054,7 +1046,6 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() {
             /* .get_alignment    = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
             /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
             /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
-            /* .supports_backend = */ ggml_backend_cpu_buffer_type()->iface.supports_backend,
             /* .is_host          = */ ggml_backend_cpu_buffer_type()->iface.is_host,
         },
         /* .context  = */ nullptr,
@@ -1377,10 +1368,30 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
     GGML_UNUSED(main_device);
 }
 
+static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
+    void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {
+
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+    // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
+    cudaMemcpy3DPeerParms p = {};
+    p.dstDevice = dstDevice;
+    p.dstPtr = make_cudaPitchedPtr(dst, dpitch, dpitch, height);
+    p.srcDevice = srcDevice;
+    p.srcPtr = make_cudaPitchedPtr(src, spitch, spitch, height);
+    p.extent = make_cudaExtent(width, height, 1);
+    return cudaMemcpy3DPeerAsync(&p, stream);
+#else
+    // HIP does not support cudaMemcpy3DPeerAsync or vmm pools
+    GGML_UNUSED(dstDevice);
+    GGML_UNUSED(srcDevice);
+    return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream);
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+}
+
 static void ggml_cuda_op_mul_mat(
     ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
-    const bool convert_src1_to_q8_1) {
+    quantize_cuda_t quantize_src1) {
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
@@ -1437,7 +1448,9 @@ static void ggml_cuda_op_mul_mat(
     }
 
     struct dev_data {
-        ggml_cuda_pool_alloc<char>  src0_dd_alloc;
+        int cc;
+
+        ggml_cuda_pool_alloc<char>   src0_dd_alloc;
         ggml_cuda_pool_alloc<float> src1_ddf_alloc;
         ggml_cuda_pool_alloc<char>  src1_ddq_alloc;
         ggml_cuda_pool_alloc<float>   dst_dd_alloc;
@@ -1456,6 +1469,8 @@ static void ggml_cuda_op_mul_mat(
     int used_devices = 0;
 
     for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        dev[id].cc = ggml_cuda_info().devices[id].cc;
+
         // by default, use all rows
         dev[id].row_low  = 0;
         dev[id].row_high = ne01;
@@ -1500,17 +1515,28 @@ static void ggml_cuda_op_mul_mat(
             dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), ggml_nbytes(src0));
         }
 
+        // If src0 is on a temporary compute buffers (partial offloading) there may be some padding that needs to be cleared:
+        if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
+            const int64_t nbytes_data    = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
+            const int64_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
+            CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream));
+        }
+
         if (src1_on_device && src1_is_contiguous) {
             dev[id].src1_ddf = (float *) src1->data;
         } else {
             dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ctx.pool(id), ggml_nelements(src1));
         }
 
-        if (convert_src1_to_q8_1) {
-            dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
+        if (quantize_src1) {
+            size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
+            if (quantize_src1 == quantize_mmq_q8_1_cuda) {
+                src_1_ddq_size += get_mmq_x_max_host(dev[id].cc)*sizeof(block_q8_1_mmq);
+            }
+            dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size);
 
             if (src1_on_device && src1_is_contiguous) {
-                quantize_row_q8_1_cuda(dev[id].src1_ddf, dev[id].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
+                quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream);
                 CUDA_CHECK(cudaGetLastError());
             }
         }
@@ -1556,7 +1582,12 @@ static void ggml_cuda_op_mul_mat(
                 const int64_t i03 = i0 / ne12;
                 const int64_t i02 = i0 % ne12;
 
-                const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
+                size_t src1_ddq_i_offset = i0*ne11 * src1_padded_col_size*q8_1_ts/q8_1_bs;
+                if (quantize_src1 == quantize_mmq_q8_1_cuda) {
+                    src1_ddq_i_offset += src1_col_0 * sizeof(block_q8_1_mmq);
+                } else {
+                    src1_ddq_i_offset += src1_col_0 * src1_padded_col_size*q8_1_ts/q8_1_bs;
+                }
 
                 // for split tensors the data begins at i0 == i0_offset_low
                 char  *  src0_dd_i =  dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
@@ -1573,10 +1604,17 @@ static void ggml_cuda_op_mul_mat(
                 // copy src0, src1 to device if necessary
                 if (src1_is_contiguous) {
                     if (id != ctx.device) {
-                        if (convert_src1_to_q8_1) {
+                        if (quantize_src1) {
                             char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
-                            CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddq_i, id, src1_ddq_i_source, ctx.device,
-                                                            src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream));
+                            if (quantize_src1 == quantize_mmq_q8_1_cuda) {
+                                const size_t pitch = ne11*sizeof(block_q8_1_mmq);
+                                const size_t width = src1_ncols*sizeof(block_q8_1_mmq);
+                                const size_t height = src1_padded_col_size/(4*QK8_1);
+                                CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(src1_ddq_i, id, pitch, src1_ddq_i_source, ctx.device, pitch, width, height, stream));
+                            } else {
+                                CUDA_CHECK(cudaMemcpyPeerAsync(
+                                    src1_ddq_i, id, src1_ddq_i_source, ctx.device, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream));
+                            }
                         } else {
                             float * src1_ddf_i_source = (float *) src1->data;
                             src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
@@ -1588,11 +1626,11 @@ static void ggml_cuda_op_mul_mat(
                     CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
                                 src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
                 } else {
-                    GGML_ASSERT(false);
+                    GGML_ABORT("fatal error");
                 }
 
-                if (convert_src1_to_q8_1 && !src1_is_contiguous) {
-                    quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
+                if (quantize_src1 && !src1_is_contiguous) {
+                    quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream);
                     CUDA_CHECK(cudaGetLastError());
                 }
 
@@ -1617,22 +1655,8 @@ static void ggml_cuda_op_mul_mat(
                         float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
                         GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
                         dhf_dst_i += src1_col_0*ne0 + dev[id].row_low;
-#if !defined(GGML_USE_HIPBLAS)
-                        // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
-                        cudaMemcpy3DPeerParms p = {};
-                        p.dstDevice = ctx.device;
-                        p.dstPtr = make_cudaPitchedPtr(dhf_dst_i, ne0*sizeof(float), row_diff, src1_ncols);
-                        p.srcDevice = id;
-                        p.srcPtr = make_cudaPitchedPtr(dst_dd_i, row_diff*sizeof(float), row_diff, src1_ncols);
-                        p.extent = make_cudaExtent(row_diff*sizeof(float), src1_ncols, 1);
-                        CUDA_CHECK(cudaMemcpy3DPeerAsync(&p, stream));
-#else
-                        // HIP does not support cudaMemcpy3DPeerAsync or vmm pools
-                        CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float),
-                                                        dst_dd_i, row_diff*sizeof(float),
-                                                        row_diff*sizeof(float), src1_ncols,
-                                                        cudaMemcpyDeviceToDevice, stream));
-#endif
+                        CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(
+                            dhf_dst_i, ctx.device, ne0*sizeof(float), dst_dd_i, id, row_diff*sizeof(float), row_diff*sizeof(float), src1_ncols, stream));
                     } else {
                         float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
                         GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
@@ -1834,6 +1858,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         }
     }
 #else
+#ifdef GGML_USE_MUSA
+    GGML_ASSERT(false);
+#else // !GGML_USE_MUSA
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
@@ -1876,6 +1903,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     }
+#endif // GGML_USE_MUSA
 #endif
 
     if (dst->op_params[0] == GGML_PREC_DEFAULT) {
@@ -1887,9 +1915,23 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
 
-    int64_t min_compute_capability = INT_MAX;
+    bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[0] >= GGML_CUDA_DMMV_X*2
+        && src1->ne[1] == 1;
+    bool          use_mul_mat_vec_q =  ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+    bool              use_mul_mat_q =  ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
+    // if mmvq is available it's a better choice than dmmv:
+#ifndef GGML_CUDA_FORCE_DMMV
+    use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
+#endif // GGML_CUDA_FORCE_DMMV
+
+    bool any_gpus_with_slow_fp16 = false;
 
-    bool any_pascal_with_slow_fp16 = false;
     if (split) {
         ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
         auto & tensor_split = buft_ctx->tensor_split;
@@ -1899,60 +1941,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
                 continue;
             }
 
-            if (min_compute_capability > ggml_cuda_info().devices[id].cc) {
-                min_compute_capability = ggml_cuda_info().devices[id].cc;
-            }
-            if (ggml_cuda_info().devices[id].cc == 610) {
-                any_pascal_with_slow_fp16 = true;
-            }
+            const int cc            = ggml_cuda_info().devices[id].cc;
+            use_mul_mat_q           = use_mul_mat_q           && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+            any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
         }
     } else {
-        min_compute_capability    = ggml_cuda_info().devices[ctx.device].cc;
-        any_pascal_with_slow_fp16 = ggml_cuda_info().devices[ctx.device].cc == 610;
+        const int cc            = ggml_cuda_info().devices[ctx.device].cc;
+        use_mul_mat_q           = use_mul_mat_q           && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+        any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
     }
 
-    // check data types and tensor shapes for custom matrix multiplication kernels:
-    bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1;
-
-    bool          use_mul_mat_vec_q =  ggml_is_quantized(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
-
-    bool              use_mul_mat_q =  ggml_cuda_supports_mmq(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
-
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-
-    const bool fp16_performance_good = min_compute_capability >= CC_RDNA1;
-
-#ifdef CUDA_USE_TENSOR_CORES
-    use_mul_mat_q = use_mul_mat_q && min_compute_capability < CC_RDNA3;
-#endif // CUDA_USE_TENSOR_CORES
-
-#else
-
-    // fp16 performance is good on Volta or newer and on P100 (compute capability 6.0)
-    const bool fp16_performance_good = min_compute_capability >= CC_PASCAL && !any_pascal_with_slow_fp16;
-
-    // mmvq and mmq need the __dp4a instruction which on NVIDIA is only available for CC >= 6.1
-    use_mul_mat_vec_q = use_mul_mat_vec_q && min_compute_capability >= MIN_CC_DP4A;
-    use_mul_mat_q     = use_mul_mat_q     && min_compute_capability >= MIN_CC_DP4A;
-
-#ifdef CUDA_USE_TENSOR_CORES
-    // when tensor cores are available, use them for large batch size
-    // ref: https://github.com/ggerganov/llama.cpp/pull/3776
-    use_mul_mat_q     = use_mul_mat_q     && (!fp16_performance_good || src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
-#endif // CUDA_USE_TENSOR_CORES
-
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-
-    // if mmvq is available it's a better choice than dmmv:
-#ifndef GGML_CUDA_FORCE_DMMV
-    use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
-#endif // GGML_CUDA_FORCE_DMMV
-
     // debug helpers
     //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
     //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
@@ -1961,23 +1959,24 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
-        // KQ single-batch
+    if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+        // FP32 precision KQ single-batch for batch size 1 without FlashAttention
         ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst);
-    } else if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
-        // KQV single-batch
+    } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+        // FP32 precision KQV single-batch for batch size 1 without FlashAttention
         ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || fp16_performance_good) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
-        // KQ + KQV multi-batch
+    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
+               && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+        // KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
     } else if (use_mul_mat_vec_q) {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
     } else if (use_mul_mat_q) {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
     } else {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
     }
 }
 
@@ -2281,6 +2280,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SQR:
             ggml_cuda_op_sqr(ctx, dst);
             break;
+        case GGML_OP_SQRT:
+            ggml_cuda_op_sqrt(ctx, dst);
+            break;
         case GGML_OP_CLAMP:
             ggml_cuda_op_clamp(ctx, dst);
             break;
@@ -2302,6 +2304,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_IM2COL:
             ggml_cuda_op_im2col(ctx, dst);
             break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            ggml_cuda_op_conv_transpose_1d(ctx,dst);
+            break;
         case GGML_OP_POOL_2D:
             ggml_cuda_op_pool2d(ctx, dst);
             break;
@@ -2744,7 +2749,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_TANH:
-                    return true;
+                    return ggml_is_contiguous(op->src[0]);
                 default:
                     return false;
             }
@@ -2752,27 +2757,40 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             {
-                struct ggml_tensor * a;
-                struct ggml_tensor * b;
+                struct ggml_tensor * a = op->src[0];
                 if (op->op == GGML_OP_MUL_MAT) {
-                    a = op->src[0];
-                    b = op->src[1];
-                } else {
-                    a = op->src[2];
-                    b = op->src[1];
-                }
-                if (a->ne[3] != b->ne[3]) {
-                    return false;
-                }
-                ggml_type a_type = a->type;
-                if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
-                    a_type == GGML_TYPE_IQ1_S   || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S   ||
-                    a_type == GGML_TYPE_IQ1_M   || a_type == GGML_TYPE_IQ2_S  || a_type == GGML_TYPE_IQ4_XS) {
-                    if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
+                    struct ggml_tensor * b = op->src[1];
+                    if (a->ne[3] != b->ne[3]) {
                         return false;
                     }
                 }
-                return true;
+                switch (a->type) {
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_F16:
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1:
+                    case GGML_TYPE_Q5_0:
+                    case GGML_TYPE_Q5_1:
+                    case GGML_TYPE_Q8_0:
+                    case GGML_TYPE_Q2_K:
+                    case GGML_TYPE_Q3_K:
+                    case GGML_TYPE_Q4_K:
+                    case GGML_TYPE_Q5_K:
+                    case GGML_TYPE_Q6_K:
+                    case GGML_TYPE_Q8_K:
+                    case GGML_TYPE_IQ1_M:
+                    case GGML_TYPE_IQ1_S:
+                    case GGML_TYPE_IQ2_S:
+                    case GGML_TYPE_IQ2_XS:
+                    case GGML_TYPE_IQ2_XXS:
+                    case GGML_TYPE_IQ3_S:
+                    case GGML_TYPE_IQ3_XXS:
+                    case GGML_TYPE_IQ4_NL:
+                    case GGML_TYPE_IQ4_XS:
+                        return true;
+                    default:
+                        return false;
+                }
             } break;
         case GGML_OP_GET_ROWS:
             {
@@ -2832,6 +2850,15 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 ggml_type src0_type = op->src[0]->type;
                 return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
             } break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                ggml_type src1_type = op->src[1]->type;
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                return false;
+            } break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
@@ -2844,6 +2871,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_RMS_NORM:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
+        case GGML_OP_SQRT:
         case GGML_OP_CLAMP:
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
@@ -2883,6 +2911,20 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
     GGML_UNUSED(backend);
 }
 
+GGML_CALL static bool ggml_backend_cuda_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+    if (ggml_backend_buft_is_cuda_split(buft)) {
+        return true;
+    }
+
+    if (ggml_backend_buft_is_cuda(buft)) {
+        ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+        ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
+        return buft_ctx->device == cuda_ctx->device;
+    }
+
+    return false;
+}
+
 GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
     const int min_batch_size = 32;
 
@@ -2937,7 +2979,7 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
 
         CUDA_CHECK(cudaLaunchHostFunc(cuda_ctx->stream(), wait_fn, event));
 #endif
-        GGML_ASSERT(false);
+        GGML_ABORT("fatal error");
     }
 }
 
@@ -2955,9 +2997,11 @@ static ggml_backend_i ggml_backend_cuda_interface = {
     /* .synchronize             = */ ggml_backend_cuda_synchronize,
     /* .graph_plan_create       = */ NULL,
     /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_update       = */ NULL,
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_cuda_graph_compute,
     /* .supports_op             = */ ggml_backend_cuda_supports_op,
+    /* .supports_buft           = */ ggml_backend_cuda_supports_buft,
     /* .offload_op              = */ ggml_backend_cuda_offload_op,
     /* .event_new               = */ ggml_backend_cuda_event_new,
     /* .event_free              = */ ggml_backend_cuda_event_free,
@@ -3017,7 +3061,7 @@ GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size
         return false;
     }
 
-#if CUDART_VERSION >= 11100
+#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
     cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
     if (err != cudaSuccess) {
         // clear the error

+ 4 - 1
llama/ggml-cuda.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -32,6 +32,9 @@
 #ifdef GGML_USE_HIPBLAS
 #define GGML_CUDA_NAME "ROCm"
 #define GGML_CUBLAS_NAME "hipBLAS"
+#elif defined(GGML_USE_MUSA)
+#define GGML_CUDA_NAME "MUSA"
+#define GGML_CUBLAS_NAME "muBLAS"
 #else
 #define GGML_CUDA_NAME "CUDA"
 #define GGML_CUBLAS_NAME "cuBLAS"

+ 1 - 1
llama/ggml-cuda/acc.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/acc.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 79
llama/ggml-cuda/alibi.cu

@@ -1,83 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
- *
- * MIT License
- *
- * Copyright (c) 2023-2024 The ggml authors
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-/**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
- *
- * MIT License
- *
- * Copyright (c) 2023-2024 The ggml authors
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-/**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
- *
- * MIT License
- *
- * Copyright (c) 2023-2024 The ggml authors
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-/**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 53
llama/ggml-cuda/alibi.cuh

@@ -1,57 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
- *
- * MIT License
- *
- * Copyright (c) 2023-2024 The ggml authors
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-/**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
- *
- * MIT License
- *
- * Copyright (c) 2023-2024 The ggml authors
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-/**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/arange.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/arange.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 3 - 2
llama/ggml-cuda/argsort.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -99,6 +99,7 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
     const dim3 block_nums(1, nrows, 1);
     const size_t shared_mem = ncols_pad * sizeof(int);
 
+    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
     GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
 
     if (order == GGML_SORT_ORDER_ASC) {
@@ -106,7 +107,7 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
     } else if (order == GGML_SORT_ORDER_DESC) {
         k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
     } else {
-        GGML_ASSERT(false);
+        GGML_ABORT("fatal error");
     }
 }
 

+ 1 - 1
llama/ggml-cuda/argsort.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 2 - 2
llama/ggml-cuda/binbcast.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -285,7 +285,7 @@ static void ggml_cuda_op_bin_bcast(
     } else {
         fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
             ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
-        GGML_ASSERT(false);
+        GGML_ABORT("fatal error");
     }
 }
 

+ 1 - 1
llama/ggml-cuda/binbcast.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/clamp.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/clamp.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 273 - 67
llama/ggml-cuda/common.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -29,6 +29,7 @@
 #include "ggml.h"
 #include "ggml-cuda.h"
 
+#include <cstdint>
 #include <memory>
 
 #if defined(GGML_USE_HIPBLAS)
@@ -37,6 +38,10 @@
 #else
 #define GGML_COMMON_DECL_CUDA
 #define GGML_COMMON_IMPL_CUDA
+#if defined(GGML_USE_MUSA)
+#define GGML_COMMON_DECL_MUSA
+#define GGML_COMMON_IMPL_MUSA
+#endif
 #endif
 #include "ggml-common.h"
 
@@ -129,7 +134,7 @@
 #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
-#define __trap abort
+#define __trap() do { abort(); __builtin_unreachable(); } while(0)
 #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
 #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
 #define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
@@ -139,6 +144,150 @@
 #define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
 #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
 #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
+#elif defined(GGML_USE_MUSA)
+#include <musa_runtime.h>
+#include <musa.h>
+#include <mublas.h>
+#include <musa_fp16.h>
+// XXX: Keep the following order the same as hipBLAS
+// #define CUBLAS_COMPUTE_16F MUBLAS_COMPUTE_16F
+// #define CUBLAS_COMPUTE_32F MUBLAS_COMPUTE_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
+#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
+#define CUBLAS_OP_N MUBLAS_OP_N
+#define CUBLAS_OP_T MUBLAS_OP_T
+#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
+// #define CUBLAS_TF32_TENSOR_OP_MATH 0
+#define CUDA_R_16F  MUSA_R_16F
+#define CUDA_R_32F  MUSA_R_32F
+// #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
+// #define cublasComputeType_t mublasComputeType_t
+#define cublasCreate mublasCreate
+#define cublasDestroy mublasDestroy
+#define cublasGemmEx mublasGemmEx
+#define cublasGemmBatchedEx mublasGemmBatchedEx
+#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
+#define cublasHandle_t mublasHandle_t
+// #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
+#define cublasSetMathMode mublasSetMathMode
+#define cublasSetStream mublasSetStream
+#define cublasSgemm mublasSgemm
+#define cublasStatus_t mublasStatus_t
+#define cudaDataType_t musaDataType_t //deprecated, new hipblasDatatype not in 5.6
+#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
+#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
+#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
+#define cudaDeviceProp musaDeviceProp
+#define cudaDeviceSynchronize musaDeviceSynchronize
+#define cudaError_t musaError_t
+#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
+#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
+#define cudaEventCreateWithFlags musaEventCreateWithFlags
+#define cudaEventDisableTiming musaEventDisableTiming
+#define cudaEventRecord musaEventRecord
+#define cudaEventSynchronize musaEventSynchronize
+#define cudaEvent_t musaEvent_t
+#define cudaEventDestroy musaEventDestroy
+#define cudaFree musaFree
+#define cudaFreeHost musaFreeHost
+#define cudaGetDevice musaGetDevice
+#define cudaGetDeviceCount musaGetDeviceCount
+#define cudaGetDeviceProperties musaGetDeviceProperties
+#define cudaGetErrorString musaGetErrorString
+#define cudaGetLastError musaGetLastError
+#define cudaHostRegister musaHostRegister
+#define cudaHostRegisterPortable musaHostRegisterPortable
+#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
+#define cudaHostUnregister musaHostUnregister
+#define cudaLaunchHostFunc musaLaunchHostFunc
+#define cudaMalloc musaMalloc
+#define cudaMallocHost musaMallocHost
+#define cudaMemcpy musaMemcpy
+#define cudaMemcpyAsync musaMemcpyAsync
+#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
+#define cudaMemcpy2DAsync musaMemcpy2DAsync
+#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
+#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
+#define cudaMemcpyKind musaMemcpyKind
+#define cudaMemset musaMemset
+#define cudaMemsetAsync musaMemsetAsync
+#define cudaMemGetInfo musaMemGetInfo
+#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
+#define cudaSetDevice musaSetDevice
+#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
+#define cudaStreamDestroy musaStreamDestroy
+#define cudaStreamFireAndForget musaStreamFireAndForget
+#define cudaStreamNonBlocking musaStreamNonBlocking
+#define cudaStreamPerThread musaStreamPerThread
+#define cudaStreamSynchronize musaStreamSynchronize
+#define cudaStreamWaitEvent musaStreamWaitEvent
+#define cudaStream_t musaStream_t
+#define cudaSuccess musaSuccess
+
+// XXX: Other CUDA => MUSA mapping
+#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
+#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
+#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
+#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
+#define CUdevice MUdevice
+#define CUdeviceptr MUdeviceptr
+#define CUmemAccessDesc MUmemAccessDesc
+#define CUmemAllocationProp MUmemAllocationProp
+#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
+#define cuDeviceGet muDeviceGet
+#define cuDeviceGetAttribute muDeviceGetAttribute
+#define cuMemAddressFree muMemAddressFree
+#define cuMemAddressReserve muMemAddressReserve
+#define cuMemCreate muMemCreate
+#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
+#define cuMemMap muMemMap
+#define cuMemRelease muMemRelease
+#define cuMemSetAccess muMemSetAccess
+#define cuMemUnmap muMemUnmap
+#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
+#define cudaFuncSetAttribute musaFuncSetAttribute
+#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
+#define make_cudaExtent make_musaExtent
+#define make_cudaPitchedPtr make_musaPitchedPtr
+
+// XXX: USE_CUDA_GRAPH
+#define CUDA_SUCCESS MUSA_SUCCESS
+#define CUresult MUresult
+#define cuGetErrorString muGetErrorString
+#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
+#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
+#define cudaGraphDestroy musaGraphDestroy
+#define cudaGraphExecDestroy musaGraphExecDestroy
+#define cudaGraphExec_t musaGraphExec_t
+#define cudaGraphExecUpdate musaGraphExecUpdate
+#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
+#define cudaGraphGetNodes musaGraphGetNodes
+#define cudaGraphInstantiate musaGraphInstantiate
+#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
+#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
+#define cudaGraphLaunch musaGraphLaunch
+#define cudaGraphNodeGetType musaGraphNodeGetType
+#define cudaGraphNode_t musaGraphNode_t
+#define cudaGraphNodeType musaGraphNodeType
+#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
+#define cudaGraph_t musaGraph_t
+#define cudaKernelNodeParams musaKernelNodeParams
+#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
+#define cudaStreamEndCapture musaStreamEndCapture
+
+// XXX: cuBLAS => muBLAS mapping
+#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
+#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
+#define CUBLAS_COMPUTE_16F CUDA_R_16F
+#define CUBLAS_COMPUTE_32F CUDA_R_32F
+#define cublasComputeType_t cudaDataType_t
+
+// XXX: Clang builtins mapping
+#define __vsub4   __vsub4_musa
+#define __vcmpeq4 __vcmpeq4_musa
+#define __vcmpne4 __vcmpne4_musa
 #else
 #include <cuda_runtime.h>
 #include <cuda.h>
@@ -165,29 +314,13 @@
 #define CC_PASCAL     600
 #define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA      700
+#define CC_TURING     750
 #define CC_AMPERE     800
 #define CC_OFFSET_AMD 1000000
 #define CC_RDNA1      (CC_OFFSET_AMD + 1010)
 #define CC_RDNA2      (CC_OFFSET_AMD + 1030)
 #define CC_RDNA3      (CC_OFFSET_AMD + 1100)
 
-// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
-// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
-// for large computational tasks. the drawback is that this requires some extra amount of VRAM:
-// -  7B quantum model: +100-200 MB
-// - 13B quantum model: +200-400 MB
-//
-//#define GGML_CUDA_FORCE_MMQ
-
-// TODO: improve this to be correct for more hardware
-//       for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores
-#if !defined(GGML_CUDA_FORCE_MMQ)
-#define CUDA_USE_TENSOR_CORES
-#endif
-
-#define MMVQ_MAX_BATCH_SIZE  8 // max batch size to use MMVQ kernels
-#define  MMQ_MAX_BATCH_SIZE 64 // max batch size to use MMQ kernels when tensor cores are available
-
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
 #if defined(_MSC_VER)
@@ -209,9 +342,13 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
 
 #define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
 
-#if CUDART_VERSION >= 12000
+#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
     static const char * cublas_get_error_str(const cublasStatus_t err) {
+#ifndef GGML_USE_MUSA
         return cublasGetStatusString(err);
+#else
+        return mublasStatus_to_string(err);
+#endif // GGML_USE_MUSA
     }
 #else
     static const char * cublas_get_error_str(const cublasStatus_t err) {
@@ -241,7 +378,7 @@ static const char * cu_get_error_str(CUresult err) {
 #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
 #endif
 
-#if CUDART_VERSION >= 11100
+#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
 #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
 #else
 #define GGML_CUDA_ASSUME(x)
@@ -255,6 +392,42 @@ typedef float dfloat; // dequantize float
 typedef float2 dfloat2;
 #endif //GGML_CUDA_F16
 
+#if defined(GGML_USE_MUSA)
+#ifndef __has_builtin
+    #define __has_builtin(x) 0
+#endif
+
+typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
+
+static __device__ __forceinline__ int __vsub4_musa(const int a, const int b) {
+    return __vsubss4(a, b);
+}
+
+static __device__ __forceinline__ unsigned int __vcmpeq4_musa(unsigned int a, unsigned int b) {
+    const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
+    const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
+    unsigned int c;
+    uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
+#pragma unroll
+    for (int i = 0; i < 4; ++i) {
+        vc[i] = va[i] == vb[i] ? 0xff : 0x00;
+    }
+    return c;
+}
+
+static __device__ __forceinline__ unsigned int __vcmpne4_musa(unsigned int a, unsigned int b) {
+    const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
+    const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
+    unsigned int c;
+    uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
+#pragma unroll
+    for (int i = 0; i < 4; ++i) {
+        vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
+    }
+    return c;
+}
+#endif // defined(GGML_USE_MUSA)
+
 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
 
@@ -268,6 +441,10 @@ typedef float2 dfloat2;
 #define RDNA2
 #endif
 
+#if defined(__gfx1010__) || defined(__gfx1012__)
+#define RDNA1
+#endif
+
 #ifndef __has_builtin
     #define __has_builtin(x) 0
 #endif
@@ -310,30 +487,15 @@ static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigne
     return c;
 }
 
-static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
-#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
-    c = __builtin_amdgcn_sdot4(a, b, c, false);
-#elif defined(RDNA3)
-    c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
-#elif defined(__gfx1010__) || defined(__gfx900__)
-    int tmp1;
-    int tmp2;
-    asm("\n \
-        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
-        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
-        v_add3_u32 %0, %1, %2, %0 \n \
-        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
-        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
-        v_add3_u32 %0, %1, %2, %0 \n \
-        "
-        : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
-        : "v"(a), "v"(b)
-    );
-#else
-    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
-    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
-    c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
-#endif
+static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigned int b) {
+    const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
+    const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
+    unsigned int c;
+    uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
+#pragma unroll
+    for (int i = 0; i < 4; ++i) {
+        vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
+    }
     return c;
 }
 
@@ -352,18 +514,34 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
 #endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
 #endif // defined(GGML_USE_HIPBLAS)
 
-#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
+#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
+#define FP16_AVAILABLE
+#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
+
+#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
+#define FAST_FP16_AVAILABLE
+#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
 
-#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#define FP16_MMA_AVAILABLE
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
 
-static bool fast_fp16_available(const int cc) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
+#define INT8_MMA_AVAILABLE
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
+
+static constexpr bool fast_fp16_available(const int cc) {
     return cc >= CC_PASCAL && cc != 610;
 }
 
-static bool fp16_mma_available(const int cc) {
+static constexpr bool fp16_mma_available(const int cc) {
     return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
 }
 
+static constexpr bool int8_mma_available(const int cc) {
+    return cc < CC_OFFSET_AMD && cc >= CC_TURING;
+}
+
 [[noreturn]]
 static __device__ void no_device_code(
     const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
@@ -384,7 +562,7 @@ static __device__ void no_device_code(
 #ifdef __CUDA_ARCH__
 #define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
 #else
-#define NO_DEVICE_CODE //GGML_ASSERT(false && "NO_DEVICE_CODE not valid in host code.")
+#define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
 #endif // __CUDA_ARCH__
 
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
@@ -405,7 +583,7 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 }
 
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
 
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 #pragma unroll
@@ -438,7 +616,7 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
 }
 
 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
 
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
     return __float2half(fmaxf(__half2float(a), __half2float(b)));
@@ -491,10 +669,50 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
     const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
     return mask_low | mask_high;
 }
-#endif // CUDART_VERSION < 12000
+#endif // CUDART_VERSION < CUDART_HMASK
+
+static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
+    c = __builtin_amdgcn_sdot4(a, b, c, false);
+#elif defined(RDNA3)
+    c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
+#elif defined(__gfx1010__) || defined(__gfx900__)
+    int tmp1;
+    int tmp2;
+    asm("\n \
+        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
+        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
+        v_add3_u32 %0, %1, %2, %0 \n \
+        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
+        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
+        v_add3_u32 %0, %1, %2, %0 \n \
+        "
+        : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
+        : "v"(a), "v"(b)
+    );
+#else
+    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
+    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+    c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
+#endif
+    return c;
+
+#else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
+    return __dp4a(a, b, c);
+#else // __CUDA_ARCH__ >= MIN_CC_DP4A
+    const int8_t * a8 = (const int8_t *) &a;
+    const int8_t * b8 = (const int8_t *) &b;
+    return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+}
 
 // TODO: move to ggml-common.h
-static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
 
 typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
 
@@ -652,19 +870,6 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
     static constexpr int qi = QI3_S;
 };
 
-static int get_mmq_x_max_host(const int cc) {
-#ifdef CUDA_USE_TENSOR_CORES
-    return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
-#else
-    return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64;
-#endif // CUDA_USE_TENSOR_CORES
-}
-
-// Round rows to this value for --split-mode row:
-static int get_mmq_y_host(const int cc, const int mmq_x) {
-    return cc >= CC_VOLTA && mmq_x >= 32 ? 128 : 64;
-}
-
 //////////////////////
 
 struct ggml_cuda_device_info {
@@ -674,6 +879,7 @@ struct ggml_cuda_device_info {
         int     cc;                 // compute capability
         int     nsm;                // number of streaming multiprocessors
         size_t  smpb;               // max. shared memory per block
+        size_t  smpbo;              // max. shared memory per block (with opt-in)
         bool    vmm;                // virtual memory support
         size_t  vmm_granularity;    // granularity of virtual memory
         size_t  total_vram;

+ 1 - 1
llama/ggml-cuda/concat.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/concat.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 113 - 0
llama/ggml-cuda/conv-transpose-1d.cu

@@ -0,0 +1,113 @@
+/**
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "conv-transpose-1d.cuh"
+
+static  __global__ void conv_transpose_1d_kernel(
+        const int s0, const int p0, const int d0, const int output_size,
+        const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
+        const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
+        const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
+        const float * src0, const float * src1,  float * dst) {
+    int global_index = threadIdx.x + blockIdx.x * blockDim.x;
+    if (global_index >= output_size) {
+        return;
+    }
+
+    int out_index = global_index / dst_ne0;
+
+    float accumulator = 0;
+
+    for (int c = 0; c < src0_ne2; c++) {
+        int idx = global_index % dst_ne0;
+
+        int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0);
+        int input_offset = src1_ne0 * c;
+
+        for (int i = 0; i < src1_ne0; i++) {
+            if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) {
+                continue;
+            }
+            int weight_idx = idx - i*s0;
+
+            float kernel_weight = src0[kernel_offset + weight_idx];
+            float input_value =  src1[input_offset+i];
+
+            accumulator += kernel_weight * input_value;
+        }
+    }
+    dst[global_index] = accumulator;
+}
+
+static void conv_transpose_1d_f32_f32_cuda(
+        const int s0, const int p0, const int d0, const int output_size,
+        const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
+        const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
+        const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
+        const float * src0, const float * src1,  float * dst,
+        cudaStream_t stream) {
+
+    const int num_blocks = (output_size + CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE;
+    conv_transpose_1d_kernel<<<num_blocks,CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE, 0, stream>>>(
+        s0,p0,d0,output_size,
+        src0_ne0, src0_ne1,  src0_ne2, src0_ne3,
+        src1_ne0, src1_ne1,  src1_ne2, src1_ne3,
+        dst_ne0,  dst_ne1,   dst_ne2,  dst_ne3,
+        src0,src1, dst);
+}
+
+void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+
+    const ggml_tensor * src1 = dst->src[1];
+    const float * src1_d = (const float *)src1->data;
+
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+
+    const int s0 = opts[0];
+    const int p0 = 0;//opts[3];
+    const int d0 = 1;//opts[4];
+
+    const int64_t kernel_size = ggml_nelements(src0);
+    const int64_t input_size = ggml_nelements(src1);
+    const int64_t output_size = ggml_nelements(dst);
+
+    conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size,
+        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+        src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+        src0_d, src1_d, dst_d, stream);
+}

+ 31 - 0
llama/ggml-cuda/conv-transpose-1d.cuh

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE 256
+
+void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 1 - 1
llama/ggml-cuda/convert.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/convert.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 3 - 4
llama/ggml-cuda/cpy.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -477,7 +477,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else {
         fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
-        GGML_ASSERT(false);
+        GGML_ABORT("fatal error");
     }
 }
 
@@ -510,7 +510,6 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
     } else {
         fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
-        GGML_ASSERT(false);
+        GGML_ABORT("fatal error");
     }
 }
-

+ 1 - 1
llama/ggml-cuda/cpy.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/dequantize.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/diagmask.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/diagmask.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 2 - 2
llama/ggml-cuda/dmmv.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -688,7 +688,7 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
             convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
             break;
         default:
-            GGML_ASSERT(false);
+            GGML_ABORT("fatal error");
             break;
     }
 

+ 1 - 1
llama/ggml-cuda/dmmv.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 30 - 70
llama/ggml-cuda/fattn-common.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -80,12 +80,11 @@ typedef float (*vec_dot_KQ_f32_t)(
 template<typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
     GGML_UNUSED(Q_v);
 
-    half sum = 0.0f;
+    T sum = 0.0f;
 
 #pragma unroll
     for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
@@ -95,12 +94,12 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
         const int iqs4  = k_KQ %  QI4_0;
         const int shift = k_KQ & (QI8_1/2);
 
-        const int v = (get_int_from_uint8(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
         const int u = Q_q8[k_KQ_0/WARP_SIZE];
 
-        const int sumi = __dp4a(v, u, 0);
+        const int sumi = ggml_cuda_dp4a(v, u, 0);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2  * Q_ds = (const half2  *) Q_ds_v;
 
@@ -116,19 +115,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     }
 
     return sum;
-#else
-    GGML_UNUSED(K_c);
-    GGML_UNUSED(Q_v);
-    GGML_UNUSED(Q_q8);
-    GGML_UNUSED(Q_ds_v);
-    NO_DEVICE_CODE;
-#endif  // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template<typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
     GGML_UNUSED(Q_v);
@@ -143,12 +134,12 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
         const int iqs4  = k_KQ %  QI4_1;
         const int shift = k_KQ & (QI8_1/2);
 
-        const int v = (get_int_from_uint8_aligned(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
         const int u = Q_q8[k_KQ_0/WARP_SIZE];
 
-        const int sumi = __dp4a(v, u, 0);
+        const int sumi = ggml_cuda_dp4a(v, u, 0);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2  * Q_ds = (const half2  *) Q_ds_v;
 
@@ -168,19 +159,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     }
 
     return sum;
-#else
-    GGML_UNUSED(K_c);
-    GGML_UNUSED(Q_v);
-    GGML_UNUSED(Q_q8);
-    GGML_UNUSED(Q_ds_v);
-    NO_DEVICE_CODE;
-#endif  // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template<typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
     GGML_UNUSED(Q_v);
@@ -196,8 +179,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
         const int iqs8  = k_KQ %  QI8_1;
         const int shift = k_KQ & (QI8_1/2);
 
-        int v = (get_int_from_uint8(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-        const int vh = get_int_from_uint8(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
+        int v = (get_int_b2(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int vh = get_int_b2(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
         v |= (vh <<  4) & 0x00000010; // 0 ->  4
         v |= (vh << 11) & 0x00001000; // 1 -> 12
         v |= (vh << 18) & 0x00100000; // 2 -> 20
@@ -205,9 +188,9 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
 
         const int u = Q_q8[k_KQ_0/WARP_SIZE];
 
-        const int sumi = __dp4a(v, u, 0);
+        const int sumi = ggml_cuda_dp4a(v, u, 0);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2  * Q_ds = (const half2  *) Q_ds_v;
 
@@ -223,19 +206,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     }
 
     return sum;
-#else
-    GGML_UNUSED(K_c);
-    GGML_UNUSED(Q_v);
-    GGML_UNUSED(Q_q8);
-    GGML_UNUSED(Q_ds_v);
-    NO_DEVICE_CODE;
-#endif  // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template<typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
     GGML_UNUSED(Q_v);
@@ -251,8 +226,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         const int iqs8  = k_KQ %  QI8_1;
         const int shift = k_KQ & (QI8_1/2);
 
-        int v = (get_int_from_uint8(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-        const int vh = get_int_from_uint8(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
+        int v = (get_int_b2(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int vh = get_int_b2(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
         v |= (vh <<  4) & 0x00000010; // 0 ->  4
         v |= (vh << 11) & 0x00001000; // 1 -> 12
         v |= (vh << 18) & 0x00100000; // 2 -> 20
@@ -260,9 +235,9 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
 
         const int u = Q_q8[k_KQ_0/WARP_SIZE];
 
-        const int sumi = __dp4a(v, u, 0);
+        const int sumi = ggml_cuda_dp4a(v, u, 0);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2  * Q_ds = (const half2  *) Q_ds_v;
 
@@ -282,19 +257,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     }
 
     return sum;
-#else
-    GGML_UNUSED(K_c);
-    GGML_UNUSED(Q_v);
-    GGML_UNUSED(Q_q8);
-    GGML_UNUSED(Q_ds_v);
-    NO_DEVICE_CODE;
-#endif  // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template <typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
     GGML_UNUSED(Q_v);
@@ -308,7 +275,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
         const int ib  = k_KQ / QI8_0;
         const int iqs = k_KQ % QI8_0;
 
-        const int v = get_int_from_int8(K_q8_0[ib].qs, iqs);
+        const int v = get_int_b2(K_q8_0[ib].qs, iqs);
 
         T Q_d;
         if (std::is_same<T, half>::value) {
@@ -323,13 +290,6 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     }
 
     return sum;
-#else
-    GGML_UNUSED(K_c);
-    GGML_UNUSED(Q_v);
-    GGML_UNUSED(Q_q8);
-    GGML_UNUSED(Q_ds_v);
-    NO_DEVICE_CODE;
-#endif  // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template <typename T, int D>
@@ -340,7 +300,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         const half2 * Q_h2 = (const half2 *) Q_v;
 
@@ -433,7 +393,7 @@ static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__
     const int q0 = x[ib].qs[iqs];
     const int q  = ((q0 >> (4*shift)) & 0x0F) - 8;
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         return ((half) d)*((half) q);
     }
@@ -454,7 +414,7 @@ static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__
     const int   q0 = x[ib].qs[iqs];
     const int   q  = ((q0 >> (4*shift)) & 0x0F);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         return __low2half(dm)*((half) q) + __high2half(dm);
     }
@@ -474,12 +434,12 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__
 
     const T   d   = x[ib].d;
     const int ql0 = x[ib].qs[iqs];
-    const int qh0 = get_int_from_uint8(x[ib].qh, 0);
+    const int qh0 = get_int_b2(x[ib].qh, 0);
     const int ql  = ((ql0 >> (4*shift)) & 0x0F);
     const int qh  = ((qh0 >> idq) << 4) & 0x10;
     const int q   = (ql | qh) - 16;
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         return ((half) d)*((half) q);
     }
@@ -499,12 +459,12 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__
 
     const half2 dm  = x[ib].dm;
     const int   ql0 = x[ib].qs[iqs];
-    const int   qh0 = get_int_from_uint8_aligned(x[ib].qh, 0);
+    const int   qh0 = get_int_b4(x[ib].qh, 0);
     const int   ql  = ((ql0 >> (4*shift)) & 0x0F);
     const int   qh  = ((qh0 >> idq) << 4) & 0x10;
     const int   q   = (ql | qh);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         return __low2half(dm)*((half) q) + __high2half(dm);
     }
@@ -523,7 +483,7 @@ static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__
     const T   d = x[ib].d;
     const int q = x[ib].qs[iqs];
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         return ((half) d)*((half) q);
     }
@@ -629,20 +589,20 @@ static void on_no_fattn_vec_case(const int D) {
     if (D == 64) {
         fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
         fprintf(stderr, "By default only f16 KV cache is supported.\n");
-        fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
-        GGML_ASSERT(false);
+        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
+        GGML_ABORT("fatal error");
     } else if (D == 128) {
         fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
         fprintf(stderr, "Supported combinations:\n");
         fprintf(stderr, "  - K == q4_0, V == q4_0,  4.50 BPV\n");
         fprintf(stderr, "  - K == q8_0, V == q8_0,  8.50 BPV\n");
         fprintf(stderr, "  - K == f16,  V == f16,  16.00 BPV\n");
-        fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
-        GGML_ASSERT(false);
+        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
+        GGML_ABORT("fatal error");
     } else {
         fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
         fprintf(stderr, "Only f16 is supported.\n");
-        GGML_ASSERT(false);
+        GGML_ABORT("fatal error");
     }
 }
 

+ 3 - 3
llama/ggml-cuda/fattn-tile-f16.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -69,7 +69,7 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
@@ -313,7 +313,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
-            GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+            GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
         } break;
     }
 }

+ 1 - 1
llama/ggml-cuda/fattn-tile-f16.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 2 - 2
llama/ggml-cuda/fattn-tile-f32.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -310,7 +310,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
-            GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+            GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
         } break;
     }
 }

+ 1 - 1
llama/ggml-cuda/fattn-tile-f32.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 2 - 2
llama/ggml-cuda/fattn-vec-f16.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -66,7 +66,7 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);

+ 2 - 2
llama/ggml-cuda/fattn-vec-f32.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -175,7 +175,7 @@ static __global__ void flash_attn_vec_ext_f32(
             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;
 
-                Q_f2[j][i0/WARP_SIZE]    = ncols <= 2 || ic0 + j ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
+                Q_f2[j][i0/WARP_SIZE]    = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
                 Q_f2[j][i0/WARP_SIZE].x *= scale;
                 Q_f2[j][i0/WARP_SIZE].y *= scale;
             }

+ 4 - 4
llama/ggml-cuda/fattn-wmma-f16.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -27,9 +27,9 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 
-#if FP16_MMA_AVAILABLE
+#ifdef FP16_MMA_AVAILABLE
 #include <mma.h>
-#endif
+#endif // FP16_MMA_AVAILABLE
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
@@ -71,7 +71,7 @@ static __global__ void flash_attn_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#if FP16_MMA_AVAILABLE
+#ifdef FP16_MMA_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.

+ 6 - 6
llama/ggml-cuda/fattn.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -64,7 +64,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
                     ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
                     break;
                 default:
-                    GGML_ASSERT(false);
+                    GGML_ABORT("fatal error");
                     break;
             }
         } else {
@@ -89,7 +89,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
                 //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
                 //     break;
                 default:
-                    GGML_ASSERT(false);
+                    GGML_ABORT("fatal error");
                     break;
             }
         }
@@ -112,7 +112,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
                 ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
                 break;
             default:
-                GGML_ASSERT(false);
+                GGML_ABORT("fatal error");
                 break;
         }
         return;
@@ -140,7 +140,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
                 ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
                 break;
             default:
-                GGML_ASSERT(false);
+                GGML_ABORT("fatal error");
                 break;
         }
         return;
@@ -167,7 +167,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
             ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
             break;
         default:
-            GGML_ASSERT(false);
+            GGML_ABORT("fatal error");
             break;
     }
 }

+ 1 - 1
llama/ggml-cuda/fattn.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 2 - 3
llama/ggml-cuda/getrows.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -197,8 +197,7 @@ void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             break;
         default:
             // TODO: k-quants
-            fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
-            GGML_ASSERT(false);
+            GGML_ABORT("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
             break;
     }
 }

+ 1 - 1
llama/ggml-cuda/getrows.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/im2col.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/im2col.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 247 - 0
llama/ggml-cuda/mma.cuh

@@ -0,0 +1,247 @@
+/**
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+struct mma_int_A_I16K4 {
+    static constexpr int I  = 16;
+    static constexpr int K  = 4;
+    static constexpr int ne = 2;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l%2) * (I/2) + threadIdx.x / K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int /* l */) {
+        const int ret = threadIdx.x % K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  K);
+        return ret;
+    }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+        const int * xs = xs0 + (threadIdx.x%I)*stride;
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(x[0]), "+r"(x[1])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
+};
+
+struct mma_int_A_I16K8 {
+    static constexpr int I  = 16;
+    static constexpr int K  = 8;
+    static constexpr int ne = 4;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int l) {
+        const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  K);
+        return ret;
+    }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
+
+    __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
+        ((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
+    }
+};
+
+struct mma_int_B_J8K4 {
+    static constexpr int J  = 8;
+    static constexpr int K  = 4;
+    static constexpr int ne = 1;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_j(const int /* l */) {
+        const int ret = threadIdx.x / K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  J);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int /* l */) {
+        const int ret = threadIdx.x % K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  K);
+        return ret;
+    }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+        const int * xs = xs0 + (threadIdx.x%J)*stride;
+        asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
+            : "+r"(x[0])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
+};
+
+struct mma_int_B_J8K8 {
+    static constexpr int J  = 8;
+    static constexpr int K  = 8;
+    static constexpr int ne = 2;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_j(const int /* l */) {
+        const int ret = threadIdx.x / (K/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  J);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int l) {
+        const int ret = l * (K/2) + threadIdx.x % (K/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  K);
+        return ret;
+    }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+        const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(x[0]), "+r"(x[1])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
+};
+
+struct mma_int_C_I16J8 {
+    static constexpr int I  = 16;
+    static constexpr int J  = 8;
+    static constexpr int ne = 4;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_j(const int l) {
+        const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  J);
+        return ret;
+    }
+
+    __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
+#ifdef INT8_MMA_AVAILABLE
+#if __CUDA_ARCH__ >= CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+#else
+        // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[0]), "+r"(x[1])
+            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+#endif // __CUDA_ARCH__ >= CC_AMPERE
+#else
+        GGML_UNUSED(mma_A);
+        GGML_UNUSED(mma_B);
+        NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
+#ifdef INT8_MMA_AVAILABLE
+#if __CUDA_ARCH__ >= CC_AMPERE
+        asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
+#else
+        // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[0]), "+r"(x[1])
+            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[0]), "+r"(x[1])
+            : "r"(mma_A.x[2]), "r"(mma_B.x[1]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[3]), "r"(mma_B.x[1]));
+#endif // __CUDA_ARCH__ >= CC_AMPERE
+#else
+        GGML_UNUSED(mma_A);
+        GGML_UNUSED(mma_B);
+        NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+    }
+};

+ 79 - 16
llama/ggml-cuda/mmq.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -37,6 +37,7 @@ void ggml_cuda_op_mul_mat_q(
     const int64_t nb01 = src0->nb[1];
 
     const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
     GGML_ASSERT(ne10 % QK8_1 == 0);
 
     const int64_t ne0 = dst->ne[0];
@@ -51,41 +52,65 @@ void ggml_cuda_op_mul_mat_q(
     // nrows_dst == nrows of the matrix that the kernel writes into
     const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
 
-    const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, nrows_dst};
+    const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst};
 
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
-            mul_mat_q_case<GGML_TYPE_Q4_0>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
             break;
         case GGML_TYPE_Q4_1:
-            mul_mat_q_case<GGML_TYPE_Q4_1>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q4_1>(ctx, args, stream);
             break;
         case GGML_TYPE_Q5_0:
-            mul_mat_q_case<GGML_TYPE_Q5_0>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q5_0>(ctx, args, stream);
             break;
         case GGML_TYPE_Q5_1:
-            mul_mat_q_case<GGML_TYPE_Q5_1>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q5_1>(ctx, args, stream);
             break;
         case GGML_TYPE_Q8_0:
-            mul_mat_q_case<GGML_TYPE_Q8_0>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
             break;
         case GGML_TYPE_Q2_K:
-            mul_mat_q_case<GGML_TYPE_Q2_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q3_K:
-            mul_mat_q_case<GGML_TYPE_Q3_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q3_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q4_K:
-            mul_mat_q_case<GGML_TYPE_Q4_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q4_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q5_K:
-            mul_mat_q_case<GGML_TYPE_Q5_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q5_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q6_K:
-            mul_mat_q_case<GGML_TYPE_Q6_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_XXS:
+            mul_mat_q_case<GGML_TYPE_IQ2_XXS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_XS:
+            mul_mat_q_case<GGML_TYPE_IQ2_XS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_S:
+            mul_mat_q_case<GGML_TYPE_IQ2_S>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ3_XXS:
+            mul_mat_q_case<GGML_TYPE_IQ3_XXS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ3_S:
+            mul_mat_q_case<GGML_TYPE_IQ3_S>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ1_S:
+            mul_mat_q_case<GGML_TYPE_IQ1_S>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ4_XS:
+            mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ4_NL:
+            mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
             break;
         default:
-            GGML_ASSERT(false);
+            GGML_ABORT("fatal error");
             break;
     }
 
@@ -94,7 +119,13 @@ void ggml_cuda_op_mul_mat_q(
     GGML_UNUSED(src1_ddf_i);
 }
 
-bool ggml_cuda_supports_mmq(enum ggml_type type) {
+bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
+#ifdef GGML_CUDA_FORCE_CUBLAS
+    return false;
+#endif // GGML_CUDA_FORCE_CUBLAS
+
+    bool mmq_supported;
+
     switch (type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
@@ -106,8 +137,40 @@ bool ggml_cuda_supports_mmq(enum ggml_type type) {
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
-            return true;
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ4_NL:
+            mmq_supported = true;
+            break;
         default:
-            return false;
+            mmq_supported = false;
+            break;
+    }
+
+    if (!mmq_supported) {
+        return false;
     }
+
+    if (int8_mma_available(cc)) {
+        return true;
+    }
+
+    if (cc < MIN_CC_DP4A) {
+        return false;
+    }
+
+#ifdef GGML_CUDA_FORCE_MMQ
+    return true;
+#endif //GGML_CUDA_FORCE_MMQ
+
+    if (cc < CC_OFFSET_AMD) {
+        return cc < CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+    }
+
+    return cc < CC_RDNA3 || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
 }

File diff suppressed because it is too large
+ 1033 - 364
llama/ggml-cuda/mmq.cuh


+ 21 - 15
llama/ggml-cuda/mmvq.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -54,16 +54,22 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
 
 static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
     return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ :
-        type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ :
-        type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ :
-        type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ :
-        type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ :
-        type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ :
-        type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ :
-        type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ :
-        type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ :
-        type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ :
-        type == GGML_TYPE_IQ4_NL ? VDR_Q4_K_Q8_1_MMVQ :
+        type == GGML_TYPE_Q4_1    ? VDR_Q4_1_Q8_1_MMVQ :
+        type == GGML_TYPE_Q5_0    ? VDR_Q5_0_Q8_1_MMVQ :
+        type == GGML_TYPE_Q5_1    ? VDR_Q5_1_Q8_1_MMVQ :
+        type == GGML_TYPE_Q8_0    ? VDR_Q8_0_Q8_1_MMVQ :
+        type == GGML_TYPE_Q2_K    ? VDR_Q2_K_Q8_1_MMVQ :
+        type == GGML_TYPE_Q3_K    ? VDR_Q3_K_Q8_1_MMVQ :
+        type == GGML_TYPE_Q4_K    ? VDR_Q4_K_Q8_1_MMVQ :
+        type == GGML_TYPE_Q5_K    ? VDR_Q5_K_Q8_1_MMVQ :
+        type == GGML_TYPE_Q6_K    ? VDR_Q6_K_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ2_XS  ? VDR_IQ2_XS_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ2_S   ? VDR_IQ2_S_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ3_XXS ? VDR_IQ3_XXS_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ3_S   ? VDR_IQ3_S_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ4_NL  ? VDR_IQ4_NL_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ4_XS  ? VDR_IQ4_XS_Q8_1_MMVQ :
         1;
 }
 
@@ -143,7 +149,7 @@ static __global__ void mul_mat_vec_q(
             tmp[j][i] = warp_reduce_sum(tmp[j][i]);
         }
 
-        if (threadIdx.x < rows_per_cuda_block) {
+        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
             dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
         }
     }
@@ -182,7 +188,7 @@ static void mul_mat_vec_q_cuda(
                 rows_per_cuda_block = 2;
                 break;
             default:
-                GGML_ASSERT(false);
+                GGML_ABORT("fatal error");
                 break;
         }
     }
@@ -216,7 +222,7 @@ static void mul_mat_vec_q_cuda(
             mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
         default:
-            GGML_ASSERT(false);
+            GGML_ABORT("fatal error");
             break;
     }
 }
@@ -433,7 +439,7 @@ void ggml_cuda_op_mul_mat_vec_q(
             mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
         default:
-            GGML_ASSERT(false);
+            GGML_ABORT("fatal error");
             break;
     }
 

+ 3 - 1
llama/ggml-cuda/mmvq.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -26,6 +26,8 @@
 
 #include "common.cuh"
 
+#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
+
 void ggml_cuda_op_mul_mat_vec_q(
     ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,

+ 1 - 1
llama/ggml-cuda/norm.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/norm.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/pad.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/pad.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/pool2d.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/pool2d.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 135 - 11
llama/ggml-cuda/quantize.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -25,24 +25,25 @@
  */
 
 #include "quantize.cuh"
+#include <cstdint>
 
-static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx_padded) {
-    const int64_t ix = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) {
+    const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (ix >= kx_padded) {
+    if (ix0 >= kx0_padded) {
         return;
     }
 
-    const int64_t iy = (int64_t)blockDim.y*blockIdx.y + threadIdx.y;
+    const int64_t ix1 = blockIdx.y;
 
-    const int64_t i_padded = (int64_t)iy*kx_padded + ix;
+    const int64_t i_padded = ix1*kx0_padded + ix0;
 
     block_q8_1 * y = (block_q8_1 *) vy;
 
     const int64_t ib = i_padded / QK8_1; // block index
     const int64_t iqs = i_padded % QK8_1; // quant index
 
-    const float xi = ix < kx ? x[iy*kx + ix] : 0.0f;
+    const float xi = ix0 < kx ? x[ix1*kx + ix0] : 0.0f;
     float amax = fabsf(xi);
     float sum = xi;
 
@@ -62,10 +63,133 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
     reinterpret_cast<half&>(y[ib].ds.y) = sum;
 }
 
-void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream) {
-    const int64_t block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
-    const dim3 num_blocks(block_num_x, ky, 1);
+template <mmq_q8_1_ds_layout ds_layout>
+static __global__ void quantize_mmq_q8_1(
+    const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
+
+    constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
+    constexpr int vals_per_sum   = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
+
+    const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
+
+    if (ix0 >= kx0_padded) {
+        return;
+    }
+
+    const float4 * x4 = (const float4 *) x;
+
+    const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
+
+    block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
+
+    const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
+    const int64_t ib  = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y;                   // block index in channel
+    const int64_t iqs = ix0 % (4*QK8_1);                                            // quant index in block
+
+    // Load 4 floats per thread and calculate max. abs. value between them:
+    const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
+    float amax = fabsf(xi.x);
+    amax = fmaxf(amax, fabsf(xi.y));
+    amax = fmaxf(amax, fabsf(xi.z));
+    amax = fmaxf(amax, fabsf(xi.w));
+
+    // Exchange max. abs. value between vals_per_scale/4 threads.
+#pragma unroll
+    for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
+        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
+    }
+
+    float sum;
+    if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
+        sum = xi.x + xi.y + xi.z + xi.w;
+
+        // Exchange calculate sum across vals_per_sum/4 threads.
+#pragma unroll
+        for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
+            sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
+        }
+    }
+
+    const float d_inv = 127.0f / amax;
+    char4 q;
+    q.x = roundf(xi.x*d_inv);
+    q.y = roundf(xi.y*d_inv);
+    q.z = roundf(xi.z*d_inv);
+    q.w = roundf(xi.w*d_inv);
+
+    // Write back 4 int8 values as a single 32 bit value for better memroy bandwidth:
+    char4 * yqs4 = (char4 *) y[ib].qs;
+    yqs4[iqs/4] = q;
+
+    if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
+        if (iqs % 16 != 0 || iqs >= 96) {
+            return;
+        }
+
+        y[ib].d2s6[2 + iqs/16] = sum;
+
+        if (iqs % 64 != 0) {
+            return;
+        }
+
+        const float d = 1.0f / d_inv;
+
+        y[ib].d2s6[iqs/64] = d;
+
+        return;
+    }
+
+    if (iqs % 32 != 0) {
+        return;
+    }
+
+    const float d = 1.0f / d_inv;
+
+    if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
+        y[ib].ds4[iqs/32] = make_half2(d, sum);
+    } else {
+        y[ib].d4[iqs/32]  = d;
+    }
+}
+
+void quantize_row_q8_1_cuda(
+    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
+    const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+
+    GGML_ASSERT(kx0_padded % QK8_1 == 0);
+
+    const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    const dim3 num_blocks(block_num_x, kx1*channels, 1);
     const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
-    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
+    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx0_padded);
+
+    GGML_UNUSED(type_x);
 }
 
+void quantize_mmq_q8_1_cuda(
+    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
+    const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+
+    GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
+
+    const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
+    const dim3 num_blocks(block_num_x, kx1, channels);
+    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
+    switch (mmq_get_q8_1_ds_layout(type_x)) {
+        case MMQ_Q8_1_DS_LAYOUT_D4:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        case MMQ_Q8_1_DS_LAYOUT_DS4:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        case MMQ_Q8_1_DS_LAYOUT_D2S6:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}

+ 22 - 3
llama/ggml-cuda/quantize.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -24,8 +24,27 @@
  * SOFTWARE.
  */
 
+#pragma once
+
 #include "common.cuh"
+#include "mmq.cuh"
+
+#include <cstdint>
+
+#define CUDA_QUANTIZE_BLOCK_SIZE     256
+#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128
+
+static_assert(MATRIX_ROW_PADDING %    CUDA_QUANTIZE_BLOCK_SIZE      == 0, "Risk of out-of-bounds access.");
+static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
+
+typedef void (*quantize_cuda_t)(
+    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
+    const ggml_type type_x, cudaStream_t stream);
 
-#define CUDA_QUANTIZE_BLOCK_SIZE 256
+void quantize_row_q8_1_cuda(
+    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
+    const ggml_type type_x, cudaStream_t stream);
 
-void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream);
+void quantize_mmq_q8_1_cuda(
+    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
+    const ggml_type type_x, cudaStream_t stream);

+ 3 - 3
llama/ggml-cuda/rope.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -277,7 +277,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
                 attn_factor, corr_dims, freq_factors, stream
             );
         } else {
-            GGML_ASSERT(false);
+            GGML_ABORT("fatal error");
         }
     } else {
         if (src0->type == GGML_TYPE_F32) {
@@ -291,7 +291,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
                 attn_factor, corr_dims, freq_factors, stream
             );
         } else {
-            GGML_ASSERT(false);
+            GGML_ABORT("fatal error");
         }
     }
 }

+ 1 - 1
llama/ggml-cuda/rope.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/scale.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/scale.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 2 - 1
llama/ggml-cuda/softmax.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -156,6 +156,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
     if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
         switch (ncols_x) {
             case 32:

+ 1 - 1
llama/ggml-cuda/softmax.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/sumrows.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/sumrows.cuh

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *

Some files were not shown because too many files changed in this diff