Browse Source

update to `d5c938cd`

jmorganca 11 months ago
parent
commit
783134a3bb
62 changed files with 7596 additions and 5554 deletions
  1. BIN
      ggml-metal.o
  2. 27 1
      llama/build-info.cpp
  3. 1 1
      llama/clip.cpp
  4. 2 2
      llama/clip.h
  5. 732 397
      llama/common.cpp
  6. 167 91
      llama/common.h
  7. 5 5
      llama/ggml-alloc.c
  8. 1 1
      llama/ggml-alloc.h
  9. 1 1
      llama/ggml-backend-impl.h
  10. 6 6
      llama/ggml-backend.c
  11. 2 2
      llama/ggml-backend.h
  12. 7 55
      llama/ggml-common.h
  13. 117 122
      llama/ggml-cuda.cu
  14. 2 1
      llama/ggml-cuda.h
  15. 156 6
      llama/ggml-cuda/common.cuh
  16. 157 10
      llama/ggml-cuda/concat.cu
  17. 0 138
      llama/ggml-cuda/convert.cu
  18. 21 160
      llama/ggml-cuda/dmmv.cu
  19. 585 6
      llama/ggml-cuda/fattn-common.cuh
  20. 10 3
      llama/ggml-cuda/fattn-tile-f16.cu
  21. 11 8
      llama/ggml-cuda/fattn-tile-f32.cu
  22. 394 2
      llama/ggml-cuda/fattn-vec-f16.cuh
  23. 372 1
      llama/ggml-cuda/fattn-vec-f32.cuh
  24. 490 0
      llama/ggml-cuda/fattn-wmma-f16.cuh
  25. 228 521
      llama/ggml-cuda/fattn.cu
  26. 0 2174
      llama/ggml-cuda/mmq.cu
  27. 1300 0
      llama/ggml-cuda/mmq.cuh
  28. 76 61
      llama/ggml-cuda/mmvq.cu
  29. 6 0
      llama/ggml-cuda/norm.cu
  30. 120 157
      llama/ggml-cuda/rope.cu
  31. 43 169
      llama/ggml-cuda/vecdotq.cuh
  32. 45 1
      llama/ggml-impl.h
  33. 148 90
      llama/ggml-metal-darwin_arm64.m
  34. 2 2
      llama/ggml-metal.h
  35. 168 473
      llama/ggml-metal.metal
  36. BIN
      llama/ggml-metal.o
  37. 541 258
      llama/ggml-quants.c
  38. 1 1
      llama/ggml-quants.h
  39. 371 229
      llama/ggml.c
  40. 54 46
      llama/ggml.h
  41. 119 32
      llama/grammar-parser.cpp
  42. 1 1
      llama/grammar-parser.h
  43. 21 59
      llama/json-schema-to-grammar.cpp
  44. 1 1
      llama/json-schema-to-grammar.h
  45. 393 163
      llama/llama.cpp
  46. 42 52
      llama/llama.h
  47. 1 1
      llama/llava.cpp
  48. 1 1
      llama/llava.h
  49. 1 1
      llama/log.h
  50. 18 9
      llama/patches/02-llamacpp.diff
  51. 3 3
      llama/patches/03-metal.diff
  52. 0 13
      llama/patches/04-qwen2.diff
  53. 90 8
      llama/sampling.cpp
  54. 6 1
      llama/sampling.h
  55. 260 0
      llama/sampling_ext.cpp
  56. 260 0
      llama/sampling_ext.h
  57. 1 1
      llama/stb_image.h
  58. 1 1
      llama/unicode-data.cpp
  59. 1 1
      llama/unicode-data.h
  60. 1 1
      llama/unicode.cpp
  61. 1 1
      llama/unicode.h
  62. 5 3
      scripts/sync_llama.sh

BIN
ggml-metal.o


+ 27 - 1
llama/build-info.cpp

@@ -1,4 +1,30 @@
+/**
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
 int LLAMA_BUILD_NUMBER = 0;
-char const *LLAMA_COMMIT = "059031b8c40e1f4ba60586842c5b1ed3ddf61842";
+char const *LLAMA_COMMIT = "";
 char const *LLAMA_COMPILER = "";
 char const *LLAMA_BUILD_TARGET = "";

+ 1 - 1
llama/clip.cpp

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *

+ 2 - 2
llama/clip.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *
@@ -94,7 +94,7 @@ CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8
 /** interpret bytes as an image file with length bytes_length, and use the result to populate img */
 CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img);
 
-/** preprocess img and store the result in res_imgs, pad_to_square may be overriden to false depending on model configuration */
+/** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */
 CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs );
 
 CLIP_API struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx);

File diff suppressed because it is too large
+ 732 - 397
llama/common.cpp


+ 167 - 91
llama/common.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *
@@ -53,7 +53,7 @@
 #define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0)
 
 #define print_build_info() do {                                                                     \
-    fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);           \
+    fprintf(stderr, "%s: build = %d (%s)\n",      __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);      \
     fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);    \
 } while(0)
 
@@ -61,14 +61,18 @@
 
 // build info
 extern int LLAMA_BUILD_NUMBER;
-extern char const *LLAMA_COMMIT;
-extern char const *LLAMA_COMPILER;
-extern char const *LLAMA_BUILD_TARGET;
+extern char const * LLAMA_COMMIT;
+extern char const * LLAMA_COMPILER;
+extern char const * LLAMA_BUILD_TARGET;
 
 struct llama_control_vector_load_info;
 
-int get_math_cpu_count();
-int32_t get_num_physical_cores();
+//
+// CPU utils
+//
+
+int32_t cpu_get_num_physical_cores();
+int32_t cpu_get_num_math();
 
 //
 // CLI argument parsing
@@ -77,67 +81,68 @@ int32_t get_num_physical_cores();
 struct gpt_params {
     uint32_t seed                 = LLAMA_DEFAULT_SEED; // RNG seed
 
-    int32_t n_threads             = get_math_cpu_count();
-    int32_t n_threads_draft       = -1;
-    int32_t n_threads_batch       = -1;    // number of threads to use for batch processing (-1 = use n_threads)
-    int32_t n_threads_batch_draft = -1;
-    int32_t n_predict             = -1;    // new tokens to predict
-    int32_t n_ctx                 = 512;   // context size
-    int32_t n_batch               = 2048;  // logical batch size for prompt processing (must be >=32 to use BLAS)
-    int32_t n_ubatch              = 512;   // physical batch size for prompt processing (must be >=32 to use BLAS)
-    int32_t n_keep                = 0;     // number of tokens to keep from initial prompt
-    int32_t n_draft               = 5;     // number of tokens to draft during speculative decoding
-    int32_t n_chunks              = -1;    // max number of chunks to process (-1 = unlimited)
-    int32_t n_parallel            = 1;     // number of parallel sequences to decode
-    int32_t n_sequences           = 1;     // number of sequences to decode
-    float   p_split               = 0.1f;  // speculative decoding split probability
-    int32_t n_gpu_layers          = -1;    // number of layers to store in VRAM (-1 - use default)
-    int32_t n_gpu_layers_draft    = -1;    // number of layers to store in VRAM for the draft model (-1 - use default)
-    llama_split_mode split_mode   = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
-    int32_t main_gpu              = 0;     // the GPU that is used for scratch and small tensors
-    float   tensor_split[128]     = {0};   // how split tensors should be distributed across GPUs
-    int32_t n_beams               = 0;     // if non-zero then use beam search of given width.
-    int32_t grp_attn_n            = 1;     // group-attention factor
-    int32_t grp_attn_w            = 512;   // group-attention width
-    int32_t n_print               = -1;    // print token count every n tokens (-1 = disabled)
-    float   rope_freq_base        = 0.0f;  // RoPE base frequency
-    float   rope_freq_scale       = 0.0f;  // RoPE frequency scaling factor
+    int32_t n_threads             = cpu_get_num_math();
+    int32_t n_threads_draft       =    -1;
+    int32_t n_threads_batch       =    -1; // number of threads to use for batch processing (-1 = use n_threads)
+    int32_t n_threads_batch_draft =    -1;
+    int32_t n_predict             =    -1; // new tokens to predict
+    int32_t n_ctx                 =     0; // context size
+    int32_t n_batch               =  2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
+    int32_t n_ubatch              =   512; // physical batch size for prompt processing (must be >=32 to use BLAS)
+    int32_t n_keep                =     0; // number of tokens to keep from initial prompt
+    int32_t n_draft               =     5; // number of tokens to draft during speculative decoding
+    int32_t n_chunks              =    -1; // max number of chunks to process (-1 = unlimited)
+    int32_t n_parallel            =     1; // number of parallel sequences to decode
+    int32_t n_sequences           =     1; // number of sequences to decode
+    float   p_split               =  0.1f; // speculative decoding split probability
+    int32_t n_gpu_layers          =    -1; // number of layers to store in VRAM (-1 - use default)
+    int32_t n_gpu_layers_draft    =    -1; // number of layers to store in VRAM for the draft model (-1 - use default)
+    int32_t main_gpu              =     0; // the GPU that is used for scratch and small tensors
+    float   tensor_split[128]     =   {0}; // how split tensors should be distributed across GPUs
+    int32_t n_beams               =     0; // if non-zero then use beam search of given width.
+    int32_t grp_attn_n            =     1; // group-attention factor
+    int32_t grp_attn_w            =   512; // group-attention width
+    int32_t n_print               =    -1; // print token count every n tokens (-1 = disabled)
+    float   rope_freq_base        =  0.0f; // RoPE base frequency
+    float   rope_freq_scale       =  0.0f; // RoPE frequency scaling factor
     float   yarn_ext_factor       = -1.0f; // YaRN extrapolation mix factor
-    float   yarn_attn_factor      = 1.0f;  // YaRN magnitude scaling factor
+    float   yarn_attn_factor      =  1.0f; // YaRN magnitude scaling factor
     float   yarn_beta_fast        = 32.0f; // YaRN low correction dim
-    float   yarn_beta_slow        = 1.0f;  // YaRN high correction dim
-    int32_t yarn_orig_ctx         = 0;     // YaRN original context length
+    float   yarn_beta_slow        =  1.0f; // YaRN high correction dim
+    int32_t yarn_orig_ctx         =     0; // YaRN original context length
     float   defrag_thold          = -1.0f; // KV cache defragmentation threshold
-    std::string rpc_servers       = "";    // comma separated list of RPC servers
 
     ggml_backend_sched_eval_callback cb_eval = nullptr;
     void * cb_eval_user_data                 = nullptr;
 
     ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
 
+    enum llama_split_mode        split_mode        = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
     enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
     enum llama_pooling_type      pooling_type      = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
 
     // // sampling parameters
     struct llama_sampling_params sparams;
 
-    std::string model                = "";  // model path
-    std::string model_draft          = "";  // draft model for speculative decoding
+    std::string model                = ""; // model path
+    std::string model_draft          = ""; // draft model for speculative decoding
     std::string model_alias          = "unknown"; // model alias
-    std::string model_url            = "";  // model url to download
-    std::string hf_repo              = "";  // HF repo
-    std::string hf_file              = "";  // HF file
+    std::string model_url            = ""; // model url to download
+    std::string hf_repo              = ""; // HF repo
+    std::string hf_file              = ""; // HF file
     std::string prompt               = "";
-    std::string prompt_file          = "";  // store the external prompt file name
-    std::string path_prompt_cache    = "";  // path to file for saving/loading prompt eval state
-    std::string input_prefix         = "";  // string to prefix user inputs with
-    std::string input_suffix         = "";  // string to suffix user inputs with
-    std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
-    std::string logdir               = "";  // directory in which to save YAML log files
+    std::string prompt_file          = ""; // store the external prompt file name
+    std::string path_prompt_cache    = ""; // path to file for saving/loading prompt eval state
+    std::string input_prefix         = ""; // string to prefix user inputs with
+    std::string input_suffix         = ""; // string to suffix user inputs with
+    std::string logdir               = ""; // directory in which to save YAML log files
     std::string lookup_cache_static  = ""; // path of static ngram cache file for lookup decoding
     std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding
-    std::string logits_file          = "";  // file for saving *all* logits
+    std::string logits_file          = ""; // file for saving *all* logits
+    std::string rpc_servers          = ""; // comma separated list of RPC servers
 
+    std::vector<std::string> in_files;   // all input files
+    std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
     std::vector<llama_model_kv_override> kv_overrides;
 
     // TODO: avoid tuple, use struct
@@ -146,36 +151,36 @@ struct gpt_params {
 
     std::vector<llama_control_vector_load_info> control_vectors; // control vector with user defined scale
 
+    int32_t verbosity                  = 0;
     int32_t control_vector_layer_start = -1; // layer range for control vector
     int32_t control_vector_layer_end   = -1; // layer range for control vector
 
-    int  ppl_stride        = 0;     // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
-    int  ppl_output_type   = 0;     // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
-                                    //                                       (which is more convenient to use for plotting)
-                                    //
-    bool   hellaswag       = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
-    size_t hellaswag_tasks = 400;   // number of tasks to use when computing the HellaSwag score
+    int32_t ppl_stride      = 0;     // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
+    int32_t ppl_output_type = 0;     // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
+                                     //                                       (which is more convenient to use for plotting)
+                                     //
+    bool   hellaswag        = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
+    size_t hellaswag_tasks  = 400;   // number of tasks to use when computing the HellaSwag score
 
-    bool   winogrande      = false; // compute Winogrande score over random tasks from datafile supplied in prompt
-    size_t winogrande_tasks= 0;     // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed
+    bool   winogrande       = false; // compute Winogrande score over random tasks from datafile supplied in prompt
+    size_t winogrande_tasks = 0;     // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed
 
-    bool   multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt
-    size_t multiple_choice_tasks = 0;     // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed
+    bool   multiple_choice  = false;  // compute TruthfulQA score over random tasks from datafile supplied in prompt
+    size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed
 
-    bool   kl_divergence   = false; // compute KL divergence
+    bool   kl_divergence    = false; // compute KL divergence
 
-    bool random_prompt     = false; // do not randomize prompt if none provided
+    bool usage             = false; // print usage
     bool use_color         = false; // use color to distinguish generations and inputs
+    bool special           = false; // enable special token output
     bool interactive       = false; // interactive mode
-    bool interactive_specials = false; // whether to allow special tokens from user, during interactive mode
+    bool interactive_first = false; // wait for user input immediately
     bool conversation      = false; // conversation mode (does not print special tokens and suffix/prefix)
-    bool chatml            = false; // chatml mode (used for models trained on chatml syntax)
     bool prompt_cache_all  = false; // save user input and generations to prompt cache
     bool prompt_cache_ro   = false; // open the prompt cache read-only and do not update it
 
     bool embedding         = false; // get only sentence embedding
-    bool escape            = false; // escape "\n", "\r", "\t", "\'", "\"", and "\\"
-    bool interactive_first = false; // wait for user input immediately
+    bool escape            = true;  // escape "\n", "\r", "\t", "\'", "\"", and "\\"
     bool multiline_input   = false; // reverse the usage of `\`
     bool simple_io         = false; // improves compatibility with subprocesses and limited consoles
     bool cont_batching     = true;  // insert new sequences for decoding on-the-fly
@@ -183,7 +188,6 @@ struct gpt_params {
 
     bool input_prefix_bos  = false; // prefix BOS to user inputs, preceding input_prefix
     bool ignore_eos        = false; // ignore generated EOS tokens
-    bool instruct          = false; // instruction mode (used for Alpaca models)
     bool logits_all        = false; // return logits for all tokens in the batch
     bool use_mmap          = true;  // use mmap for faster loads
     bool use_mlock         = false; // use mlock to keep model in memory
@@ -201,37 +205,102 @@ struct gpt_params {
     // multimodal models (see examples/llava)
     std::string mmproj = "";        // path to multimodal projector
     std::vector<std::string> image; // path to image file(s)
-};
 
-void gpt_params_handle_model_default(gpt_params & params);
+    // server params
+    int32_t port           = 8080;         // server listens on this network port
+    int32_t timeout_read   = 600;          // http read timeout in seconds
+    int32_t timeout_write  = timeout_read; // http write timeout in seconds
+    int32_t n_threads_http = -1;           // number of threads to process HTTP requests
+
+    std::string hostname      = "127.0.0.1";
+    std::string public_path   = "";
+    std::string chat_template = "";
+    std::string system_prompt = "";
+
+    std::vector<std::string> api_keys;
+
+    std::string ssl_file_key  = "";
+    std::string ssl_file_cert = "";
 
-bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
+    bool endpoint_slots   = true;
+    bool endpoint_metrics = false;
 
-bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params);
+    bool log_json = false;
 
-bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
+    std::string slot_save_path;
 
-void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
+    // batched-bench params
+    bool is_pp_shared = false;
 
-bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param);
+    std::vector<int32_t> n_pp;
+    std::vector<int32_t> n_tg;
+    std::vector<int32_t> n_pl;
 
-std::string get_system_info(const gpt_params & params);
+    // retrieval params
+    std::vector<std::string> context_files; // context files to embed
 
-std::string gpt_random_prompt(std::mt19937 & rng);
+    int32_t chunk_size = 64; // chunk size for context embedding
 
-void process_escapes(std::string& input);
+    std::string chunk_separator = "\n"; // chunk separator for context embedding
 
-bool validate_file_name(const std::string & filename);
+    // passkey params
+    int32_t n_junk = 250; // number of times to repeat the junk text
+    int32_t i_pos  = -1;  // position of the passkey in the junk text
+
+    // imatrix params
+    std::string out_file = "imatrix.dat"; // save the resulting imatrix to this file
+
+    int32_t n_out_freq  = 10; // output the imatrix every n_out_freq iterations
+    int32_t n_save_freq =  0; // save the imatrix every n_save_freq iterations
+    int32_t i_chunk     =  0; // start processing from this chunk
+
+    bool process_output = false; // collect data for the output tensor
+    bool compute_ppl    = true;  // whether to compute perplexity
+};
+
+void gpt_params_handle_model_default(gpt_params & params);
+
+bool gpt_params_parse_ex   (int argc, char ** argv, gpt_params & params);
+bool gpt_params_parse      (int argc, char ** argv, gpt_params & params);
+bool gpt_params_find_arg   (int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param);
+void gpt_params_print_usage(int argc, char ** argv, const gpt_params & params);
+
+std::string gpt_params_get_system_info(const gpt_params & params);
 
 //
 // String utils
 //
 
-std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
-std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & names_string);
 std::vector<std::string> string_split(std::string input, char separator);
+
 std::string string_strip(const std::string & str);
-std::string sampler_type_to_name_string(llama_sampler_type sampler_type);
+std::string string_get_sortable_timestamp();
+
+template<class T>
+static std::vector<T> string_split(const std::string & str, char delim) {
+    std::vector<T> values;
+    std::istringstream str_stream(str);
+    std::string token;
+    while (std::getline(str_stream, token, delim)) {
+        T value;
+        std::istringstream token_stream(token);
+        token_stream >> value;
+        values.push_back(value);
+    }
+    return values;
+}
+
+bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
+void string_process_escapes(std::string & input);
+
+//
+// Filesystem utils
+//
+
+bool fs_validate_filename(const std::string & filename);
+bool fs_create_directory_with_parents(const std::string & path);
+
+std::string fs_get_cache_directory();
 
 //
 // Model utils
@@ -303,28 +372,21 @@ std::string llama_detokenize_bpe(
 bool llama_should_add_bos_token(const llama_model * model);
 
 //
-// YAML utils
+// Chat template utils
 //
 
-bool create_directory_with_parents(const std::string & path);
-void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data);
-void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector<int> & data);
-void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data);
-std::string get_sortable_timestamp();
-
-void dump_non_result_info_yaml(
-    FILE * stream, const gpt_params & params, const llama_context * lctx,
-    const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);
+// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
+bool llama_chat_verify_template(const std::string & tmpl);
 
 //
 // KV cache utils
 //
 
 // Dump the KV cache view with the number of sequences per cell.
-void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
+void llama_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80);
 
 // Dump the KV cache view showing individual sequences in each cell (long output).
-void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
+void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
 
 //
 // Embedding utils
@@ -358,6 +420,20 @@ llama_control_vector_data llama_control_vector_load(const std::vector<llama_cont
 //
 // Split utils
 //
+
 static const char * const LLM_KV_SPLIT_NO            = "split.no";
 static const char * const LLM_KV_SPLIT_COUNT         = "split.count";
 static const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
+
+//
+// YAML utils
+//
+
+void yaml_dump_vector_float    (FILE * stream, const char * prop_name, const std::vector<float> & data);
+void yaml_dump_vector_int      (FILE * stream, const char * prop_name, const std::vector<int> & data);
+void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const char * data);
+
+void yaml_dump_non_result_info(
+    FILE * stream, const gpt_params & params, const llama_context * lctx,
+    const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);
+

+ 5 - 5
llama/ggml-alloc.c

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *
@@ -403,7 +403,7 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs
     galloc->bufts = calloc(n_bufs, sizeof(ggml_backend_buffer_type_t));
     GGML_ASSERT(galloc->bufts != NULL);
 
-    galloc->buffers = calloc(n_bufs, sizeof(ggml_backend_buffer_t) * n_bufs);
+    galloc->buffers = calloc(n_bufs, sizeof(ggml_backend_buffer_t));
     GGML_ASSERT(galloc->buffers != NULL);
 
     galloc->buf_tallocs = calloc(n_bufs, sizeof(struct ggml_dyn_tallocr *));
@@ -776,7 +776,7 @@ static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor *
                 // this tensor was allocated without ggml-backend
                 return;
             }
-            ggml_backend_view_init(galloc->buffers[buffer_id], tensor);
+            ggml_backend_view_init(tensor);
         }
     } else {
         if (tensor->data == NULL) {
@@ -925,12 +925,12 @@ static bool alloc_tensor_range(struct ggml_context * ctx,
             if (t->view_src == NULL) {
                 ggml_tallocr_alloc(&tallocr, t);
             } else if (t->buffer == NULL) {
-                ggml_backend_view_init(buffer, t);
+                ggml_backend_view_init(t);
             }
         } else {
             if (t->view_src != NULL && t->buffer == NULL) {
                 // view of a pre-allocated tensor
-                ggml_backend_view_init(buffer, t);
+                ggml_backend_view_init(t);
             }
         }
     }

+ 1 - 1
llama/ggml-alloc.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *

+ 1 - 1
llama/ggml-backend-impl.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *

+ 6 - 6
llama/ggml-backend.c

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *
@@ -182,7 +182,7 @@ void ggml_backend_buffer_reset(ggml_backend_buffer_t buffer) {
 bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst) {
     ggml_backend_buffer_t dst_buf = dst->view_src ? dst->view_src->buffer : dst->buffer;
     if (dst_buf->iface.cpy_tensor) {
-        return src->buffer->iface.cpy_tensor(dst_buf, src, dst);
+        return dst_buf->iface.cpy_tensor(dst_buf, src, dst);
     }
     return false;
 }
@@ -1918,15 +1918,15 @@ ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched,
 
 // utils
 
-void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+void ggml_backend_view_init(struct ggml_tensor * tensor) {
     GGML_ASSERT(tensor->buffer == NULL);
     GGML_ASSERT(tensor->view_src != NULL);
     GGML_ASSERT(tensor->view_src->buffer != NULL);
     GGML_ASSERT(tensor->view_src->data != NULL);
 
-    tensor->buffer = buffer;
+    tensor->buffer = tensor->view_src->buffer;
     tensor->data = (char *)tensor->view_src->data + tensor->view_offs;
-    ggml_backend_buffer_init_tensor(buffer, tensor);
+    ggml_backend_buffer_init_tensor(tensor->buffer, tensor);
 }
 
 void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) {
@@ -1985,7 +1985,7 @@ static void graph_copy_init_tensor(struct ggml_hash_set hash_set, struct ggml_te
     struct ggml_tensor * dst = node_copies[id];
     if (dst->view_src != NULL) {
         graph_copy_init_tensor(hash_set, node_copies, node_init, src->view_src);
-        ggml_backend_view_init(dst->view_src->buffer, dst);
+        ggml_backend_view_init(dst);
     }
     else {
         ggml_backend_tensor_copy(src, dst);

+ 2 - 2
llama/ggml-backend.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *
@@ -251,7 +251,7 @@ extern "C" {
 
     // Tensor initialization
     GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
-    GGML_API void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);
 
 
 #ifdef  __cplusplus

+ 7 - 55
llama/ggml-common.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *
@@ -91,13 +91,8 @@ typedef sycl::half2 ggml_half2;
 // QK = number of values after dequantization
 // QK_K = super-block size
 
-#ifdef GGML_QKK_64
-#define QK_K 64
-#define K_SCALE_SIZE 4
-#else
 #define QK_K 256
 #define K_SCALE_SIZE 12
-#endif // GGML_QKK_64
 
 #if defined(GGML_COMMON_DECL_CUDA) || defined(GGML_COMMON_DECL_HIP) || defined(GGML_COMMON_DECL_SYCL)
 // QR = QK / number of values before dequantization
@@ -154,16 +149,17 @@ typedef sycl::half2 ggml_half2;
 #define QI1_S (QK_K / (4*QR1_S))
 #define QR1_S 8
 
+#define QI1_M (QK_K / (4*QR1_M))
+#define QR1_M 8
+
 #define QI4_NL (QK4_NL / (4*QR4_NL))
 #define QR4_NL 2
 
-#if QK_K == 64
-#define QI4_XS QI4_NL
-#define QR4_XS QR4_NL
-#else
 #define QI4_XS (QK_K / (4*QR4_XS))
 #define QR4_XS 8
-#endif
+
+#define QI3_S (QK_K / (4*QR3_S))
+#define QR3_S 8
 
 #endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP
 
@@ -254,15 +250,6 @@ static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_half) + QK_K/16 + QK_K/4, "wro
 // weight is represented as x = a * q
 // 16 blocks of 16 elements each
 // Effectively 3.4375 bits per weight
-#ifdef GGML_QKK_64
-typedef struct {
-    uint8_t hmask[QK_K/8]; // quants - high bit
-    uint8_t qs[QK_K/4];    // quants - low 2 bits
-    uint8_t scales[2];
-    ggml_half d;           // super-block scale
-} block_q3_K;
-static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding");
-#else
 typedef struct {
     uint8_t hmask[QK_K/8]; // quants - high bit
     uint8_t qs[QK_K/4];    // quants - low 2 bits
@@ -270,20 +257,11 @@ typedef struct {
     ggml_half d;           // super-block scale
 } block_q3_K;
 static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
-#endif
 
 // 4-bit quantization
 // 8 blocks of 32 elements each
 // weight is represented as x = a * q + b
 // Effectively 4.5 bits per weight
-#ifdef GGML_QKK_64
-typedef struct {
-    ggml_half d[2];     // super-block scales/mins
-    uint8_t scales[2];  // 4-bit block scales/mins
-    uint8_t qs[QK_K/2]; // 4--bit quants
-} block_q4_K;
-static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + QK_K/2 + 2, "wrong q4_K block size/padding");
-#else
 typedef struct {
     union {
         struct {
@@ -296,21 +274,11 @@ typedef struct {
     uint8_t qs[QK_K/2];           // 4--bit quants
 } block_q4_K;
 static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
-#endif
 
 // 5-bit quantization
 // 8 blocks of 32 elements each
 // weight is represented as x = a * q + b
 // Effectively 5.5 bits per weight
-#ifdef GGML_QKK_64
-typedef struct {
-    ggml_half d;             // super-block scale
-    int8_t  scales[QK_K/16]; // 8-bit block scales
-    uint8_t qh[QK_K/8];      // quants, high bit
-    uint8_t qs[QK_K/2];      // quants, low 4 bits
-} block_q5_K;
-static_assert(sizeof(block_q5_K) == sizeof(ggml_half) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
-#else
 typedef struct {
     union {
         struct {
@@ -324,7 +292,6 @@ typedef struct {
     uint8_t qs[QK_K/2];           // quants, low 4 bits
 } block_q5_K;
 static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
-#endif
 
 // 6-bit quantization
 // weight is represented as x = a * q
@@ -382,11 +349,7 @@ typedef struct {
 static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_half) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
 
 // 3.4375 bpw
-#if QK_K == 64
-#define IQ3S_N_SCALE 2
-#else
 #define IQ3S_N_SCALE QK_K/64
-#endif
 typedef struct {
     ggml_half d;
     uint8_t qs[QK_K/4];
@@ -407,16 +370,9 @@ static_assert(sizeof(block_iq1_s) == sizeof(ggml_half) + QK_K/8 + QK_K/16, "wron
 typedef struct {
     uint8_t  qs[QK_K/8];      // grid index, low 8 bits
     uint8_t  qh[QK_K/16];     // grid index, high 3 bits + grid shift bit (for two groups of 8)
-#if QK_K == 64
-    ggml_half d;
-#endif
     uint8_t  scales[QK_K/32]; // 3-bit block scales (4-bit if QK_K == 64)
 } block_iq1_m;
-#if QK_K == 64
-static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32 + sizeof(ggml_half), "wrong iq1_m block size/padding");
-#else
 static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding");
-#endif
 
 // Used by IQ1_M quants
 typedef union {
@@ -432,9 +388,6 @@ typedef struct {
 } block_iq4_nl;
 static_assert(sizeof(block_iq4_nl) == sizeof(ggml_half) + QK4_NL/2, "wrong iq4_nl block size/padding");
 
-#if QK_K == 64
-#define block_iq4_xs block_iq4_nl
-#else
 typedef struct {
     ggml_half d;
     uint16_t scales_h;
@@ -442,7 +395,6 @@ typedef struct {
     uint8_t  qs[QK_K/2];
 } block_iq4_xs;
 static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
-#endif
 
 #endif // GGML_COMMON_DECL
 #endif // GGML_COMMON_DECL

+ 117 - 122
llama/ggml-cuda.cu

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *
@@ -69,19 +69,59 @@
 #include <mutex>
 #include <stdint.h>
 #include <stdio.h>
+#include <stdarg.h>
+#include <stdlib.h>
 #include <string>
 #include <vector>
 
 static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
 
+static void ggml_cuda_default_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
+    GGML_UNUSED(level);
+    GGML_UNUSED(user_data);
+    fprintf(stderr, "%s", msg);
+}
+
+ggml_log_callback ggml_cuda_log_callback = ggml_cuda_default_log_callback;
+void * ggml_cuda_log_user_data = NULL;
+
+GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback, void * user_data) {
+    ggml_cuda_log_callback = log_callback;
+    ggml_cuda_log_user_data = user_data;
+}
+
+#define GGML_CUDA_LOG_INFO(...) ggml_cuda_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
+#define GGML_CUDA_LOG_WARN(...) ggml_cuda_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
+#define GGML_CUDA_LOG_ERROR(...) ggml_cuda_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+
+GGML_ATTRIBUTE_FORMAT(2, 3)
+static void ggml_cuda_log(enum ggml_log_level level, const char * format, ...) {
+    if (ggml_cuda_log_callback != NULL) {
+        va_list args;
+        va_start(args, format);
+        char buffer[128];
+        int len = vsnprintf(buffer, 128, format, args);
+        if (len < 128) {
+            ggml_cuda_log_callback(level, buffer, ggml_cuda_log_user_data);
+        } else {
+            std::vector<char> buffer2(len + 1);  // vsnprintf adds a null terminator
+            va_end(args);
+            va_start(args, format);
+            vsnprintf(&buffer2[0], buffer2.size(), format, args);
+            ggml_cuda_log_callback(level, buffer2.data(), ggml_cuda_log_user_data);
+        }
+        va_end(args);
+    }
+}
+
 [[noreturn]]
 void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
     int id = -1; // in case cudaGetDevice fails
     cudaGetDevice(&id);
 
-    fprintf(stderr, "CUDA error: %s\n", msg);
-    fprintf(stderr, "  current device: %d, in function %s at %s:%d\n", id, func, file, line);
-    fprintf(stderr, "  %s\n", stmt);
+    GGML_CUDA_LOG_ERROR("CUDA error: %s\n", msg);
+    GGML_CUDA_LOG_ERROR("  current device: %d, in function %s at %s:%d\n", id, func, file, line);
+    GGML_CUDA_LOG_ERROR("  %s\n", stmt);
     // abort with GGML_ASSERT to get a stack trace
     GGML_ASSERT(!"CUDA error");
 }
@@ -105,6 +145,20 @@ int ggml_cuda_get_device() {
     return id;
 }
 
+static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
+    ggml_cuda_set_device(device);
+#if defined(GGML_USE_HIPBLAS) && defined(GGML_HIP_UMA)
+    auto res = hipMallocManaged(ptr, size);
+    if (res == hipSuccess) {
+        // if error we "need" to know why...
+        CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
+    }
+    return res;
+#else
+    return cudaMalloc(ptr, size);
+#endif
+}
+
 static ggml_cuda_device_info ggml_cuda_init() {
 #ifdef __HIP_PLATFORM_AMD__
     // Workaround for a rocBLAS bug when using multiple graphics cards:
@@ -117,7 +171,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
 
     cudaError_t err = cudaGetDeviceCount(&info.device_count);
     if (err != cudaSuccess) {
-        fprintf(stderr, "%s: failed to initialize " GGML_CUDA_NAME ": %s\n", __func__, cudaGetErrorString(err));
+        GGML_CUDA_LOG_ERROR("%s: failed to initialize " GGML_CUDA_NAME ": %s\n", __func__, cudaGetErrorString(err));
         return info;
     }
 
@@ -125,16 +179,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
 
     int64_t total_vram = 0;
 #if defined(GGML_CUDA_FORCE_MMQ)
-    fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ:   yes\n", __func__);
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:   yes\n", __func__);
 #else
-    fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ:   no\n", __func__);
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:   no\n", __func__);
 #endif
 #if defined(CUDA_USE_TENSOR_CORES)
-    fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: yes\n", __func__);
+    GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: yes\n", __func__);
 #else
-    fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: no\n", __func__);
+    GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: no\n", __func__);
 #endif
-    fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
+    GGML_CUDA_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
     for (int id = 0; id < info.device_count; ++id) {
         int device_vmm = 0;
 
@@ -155,7 +209,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
 
         cudaDeviceProp prop;
         CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
-        fprintf(stderr, "  Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
+        GGML_CUDA_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
 
         info.default_tensor_split[id] = total_vram;
         total_vram += prop.totalGlobalMem;
@@ -257,12 +311,12 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
         size_t look_ahead_size = (size_t) (1.05 * size);
         look_ahead_size = 256 * ((look_ahead_size + 255)/256);
         ggml_cuda_set_device(device);
-        CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
+        CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));
         *actual_size = look_ahead_size;
         pool_size += look_ahead_size;
 #ifdef DEBUG_CUDA_MALLOC
-        fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, device, nnz,
-                (uint32_t)(max_size/1024/1024), (uint32_t)(pool_size/1024/1024), (uint32_t)(size/1024/1024));
+        GGML_CUDA_LOG_INFO("%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, device, nnz,
+                           (uint32_t)(max_size / 1024 / 1024), (uint32_t)(pool_size / 1024 / 1024), (uint32_t)(size / 1024 / 1024));
 #endif
         return ptr;
     }
@@ -276,7 +330,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
                 return;
             }
         }
-        fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
+        GGML_CUDA_LOG_WARN("Cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
         ggml_cuda_set_device(device);
         CUDA_CHECK(cudaFree(ptr));
         pool_size -= size;
@@ -527,9 +581,11 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffe
     size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
 
     void * dev_ptr;
-    cudaError_t err = cudaMalloc(&dev_ptr, size);
+    cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
     if (err != cudaSuccess) {
-        fprintf(stderr, "%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size/1024.0/1024.0, buft_ctx->device, cudaGetErrorString(err));
+        // clear the error
+        cudaGetLastError();
+        GGML_CUDA_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err));
         return nullptr;
     }
 
@@ -607,88 +663,22 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
 
 // cuda split buffer
 
-static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
-    int64_t min_compute_capability = INT_MAX;
-    int64_t max_compute_capability = INT_MIN;
+static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
+    int64_t row_rounding = 0;
     for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
-        if (tensor_split[id] < (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
-            if (min_compute_capability > ggml_cuda_info().devices[id].cc) {
-                min_compute_capability = ggml_cuda_info().devices[id].cc;
-            }
-            if (max_compute_capability < ggml_cuda_info().devices[id].cc) {
-                max_compute_capability = ggml_cuda_info().devices[id].cc;
-            }
+        if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
+            continue;
         }
-    }
 
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    switch(type) {
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-            return max_compute_capability >= CC_RDNA2 ? 128 : 64;
-        case GGML_TYPE_F16:
-        case GGML_TYPE_F32:
-            return 1;
-        case GGML_TYPE_Q2_K:
-            return max_compute_capability >= CC_RDNA2 ? 128 : 32;
-        case GGML_TYPE_Q3_K:
-            return min_compute_capability < CC_RDNA2 ? 128 : 64;
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ2_S:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ4_NL:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ3_S:
-            return max_compute_capability >= CC_RDNA2 ? 128 : 64;
-        default:
-            GGML_ASSERT(false);
+        const int cc = ggml_cuda_info().devices[id].cc;
+        row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc, get_mmq_x_max_host(cc)));
     }
-#else
-    switch(type) {
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-            return max_compute_capability >= CC_VOLTA ? 128 : 64;
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-            return 64;
-        case GGML_TYPE_F16:
-        case GGML_TYPE_F32:
-            return 1;
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ2_S:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ4_NL:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ3_S:
-            return max_compute_capability >= CC_VOLTA ? 128 : 64;
-        case GGML_TYPE_Q6_K:
-            return 64;
-        default:
-            GGML_ASSERT(false);
-    }
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    return row_rounding;
 }
 
 static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split, int id) {
     const int64_t nrows = ggml_nrows(tensor);
-    const int64_t rounding = get_row_rounding(tensor->type, tensor_split);
+    const int64_t rounding = get_row_rounding(tensor_split);
 
     *row_low = id == 0 ? 0 : nrows*tensor_split[id];
     *row_low -= *row_low % rounding;
@@ -786,7 +776,7 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_bu
         // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
         ggml_cuda_set_device(id);
         char * buf;
-        CUDA_CHECK(cudaMalloc(&buf, size));
+        CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id));
 
         // set padding to 0 to avoid possible NaN values
         if (size > original_size) {
@@ -1032,8 +1022,8 @@ static void * ggml_cuda_host_malloc(size_t size) {
     if (err != cudaSuccess) {
         // clear the error
         cudaGetLastError();
-        fprintf(stderr, "%s: warning: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
-            size/1024.0/1024.0, cudaGetErrorString(err));
+        GGML_CUDA_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
+                           size / 1024.0 / 1024.0, cudaGetErrorString(err));
         return nullptr;
     }
 
@@ -1473,7 +1463,7 @@ static void ggml_cuda_op_mul_mat(
         // for multi GPU, get the row boundaries from tensor split
         // and round to mul_mat_q tile sizes
         if (split) {
-            const int64_t rounding = get_row_rounding(src0->type, tensor_split);
+            const int64_t rounding = get_row_rounding(tensor_split);
 
             if (id != 0) {
                 dev[id].row_low  = ne01*tensor_split[id];
@@ -1844,7 +1834,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         }
     }
 #else
-    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(
@@ -2276,7 +2266,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             break;
         case GGML_OP_MUL_MAT:
             if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
-                fprintf(stderr, "%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
+                GGML_CUDA_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
                 return false;
             } else {
                 ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
@@ -2330,7 +2320,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
 
     cudaError_t err = cudaGetLastError();
     if (err != cudaSuccess) {
-        fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst));
+        GGML_CUDA_LOG_ERROR("%s: %s failed\n", __func__, ggml_op_desc(dst));
         CUDA_CHECK(err);
     }
 
@@ -2498,15 +2488,15 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
 
     bool use_cuda_graph = true;
     bool cuda_graph_update_required = false;
-    // pointer to CUDA cpy kernel, which is required to identify
+    // vector of pointers to CUDA cpy kernels, which are required to identify
     // kernel parameters which need updated in the graph for each token
-    void * ggml_cuda_cpy_fn_ptr = nullptr;
+    std::vector<void *> ggml_cuda_cpy_fn_ptrs;
 
     if (cuda_ctx->cuda_graph->graph == nullptr) {
         if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
             cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
 #ifndef NDEBUG
-            fprintf(stderr, "%s: disabling CUDA graphs due to GPU architecture\n", __func__);
+            GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
 #endif
         }
     }
@@ -2553,14 +2543,14 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
             if (node->src[0] && ggml_backend_buffer_is_cuda_split(node->src[0]->buffer)) {
                 use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
 #ifndef NDEBUG
-                fprintf(stderr, "%s: disabling CUDA graphs due to split buffer\n", __func__);
+                GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to split buffer\n", __func__);
 #endif
             }
 
             if (node->op == GGML_OP_MUL_MAT_ID) {
                 use_cuda_graph = false; // This node type is not supported by CUDA graph capture
 #ifndef NDEBUG
-                fprintf(stderr, "%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
+                GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
 #endif
             }
 
@@ -2569,16 +2559,17 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
                 // Changes in batch size or context size can cause changes to the grid size of some kernels.
                 use_cuda_graph = false;
 #ifndef NDEBUG
-                fprintf(stderr, "%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+                GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
 #endif
             }
 
             if (node->op == GGML_OP_CPY) {
                 // store the copy op parameter which changes with each token.
                 cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
-                if (ggml_cuda_cpy_fn_ptr == nullptr) {
-                    // store a pointer to the copy op CUDA kernel to identify it later
-                    ggml_cuda_cpy_fn_ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+                // store a pointer to each copy op CUDA kernel to identify it later
+                void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+                if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
+                    ggml_cuda_cpy_fn_ptrs.push_back(ptr);
                 }
             }
 
@@ -2597,7 +2588,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
             cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
 #ifndef NDEBUG
-            fprintf(stderr, "%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
+            GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
 #endif
         }
     }
@@ -2635,7 +2626,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
 
                 bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
                 if (!ok) {
-                    fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+                    GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
                 }
                 GGML_ASSERT(ok);
             }
@@ -2654,7 +2645,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
                 use_cuda_graph = false;
                 cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
 #ifndef NDEBUG
-                fprintf(stderr, "%s: disabling CUDA graphs due to failed graph capture\n", __func__);
+                GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
 #endif
             } else {
                 graph_evaluated_or_captured = true; // CUDA graph has been captured
@@ -2675,10 +2666,8 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
 
         if (cuda_graph_update_required) {
             // Extract nodes from graph
-            if (cuda_ctx->cuda_graph->num_nodes == 0) {
-                // First call with null argument gets number of nodes in graph
-                CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
-            }
+            // First call with null argument gets number of nodes in graph
+            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
             // Subsequent call with non-null argument gets nodes
             cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
             cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
@@ -2708,7 +2697,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
             int k = 0;
             for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-                if (cuda_ctx->cuda_graph->params[i].func == ggml_cuda_cpy_fn_ptr) {
+                if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
                     char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
                     cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
                     CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
@@ -2721,7 +2710,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
         if (stat == cudaErrorGraphExecUpdateFailure) {
 #ifndef NDEBUG
-            fprintf(stderr, "%s: CUDA graph update failed\n", __func__);
+            GGML_CUDA_LOG_ERROR("%s: CUDA graph update failed\n", __func__);
 #endif
             // The pre-existing graph exec cannot be updated due to violated constraints
             // so instead clear error and re-instantiate
@@ -2859,7 +2848,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
+            return true;
         case GGML_OP_ROPE:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
@@ -2876,10 +2867,14 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
             return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128;
 #else
-            if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) {
+            if (op->src[0]->ne[0] == 128) {
+                return true;
+            }
+            if (op->src[0]->ne[0] ==  64 && op->src[1]->type == GGML_TYPE_F16) {
                 return true;
             }
-            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA;
+            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
+                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         default:
             return false;
@@ -2978,13 +2973,13 @@ static ggml_guid_t ggml_backend_cuda_guid() {
 
 GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device) {
     if (device < 0 || device >= ggml_backend_cuda_get_device_count()) {
-        fprintf(stderr, "%s: error: invalid device %d\n", __func__, device);
+        GGML_CUDA_LOG_ERROR("%s: invalid device %d\n", __func__, device);
         return nullptr;
     }
 
     ggml_backend_cuda_context * ctx = new ggml_backend_cuda_context(device);
     if (ctx == nullptr) {
-        fprintf(stderr, "%s: error: failed to allocate context\n", __func__);
+        GGML_CUDA_LOG_ERROR("%s: failed to allocate context\n", __func__);
         return nullptr;
     }
 
@@ -3028,8 +3023,8 @@ GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size
         // clear the error
         cudaGetLastError();
 
-        fprintf(stderr, "%s: warning: failed to register %.2f MiB of pinned memory: %s\n", __func__,
-                size/1024.0/1024.0, cudaGetErrorString(err));
+        GGML_CUDA_LOG_WARN("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__,
+                           size / 1024.0 / 1024.0, cudaGetErrorString(err));
         return false;
     }
     return true;

+ 2 - 1
llama/ggml-cuda.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *
@@ -66,6 +66,7 @@ GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t *
 GGML_API GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
 GGML_API GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer);
 
+GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback, void * user_data);
 #ifdef  __cplusplus
 }
 #endif

+ 156 - 6
llama/ggml-cuda/common.cuh

@@ -79,13 +79,8 @@
 #define cudaHostRegisterReadOnly hipHostRegisterReadOnly
 #define cudaHostUnregister hipHostUnregister
 #define cudaLaunchHostFunc hipLaunchHostFunc
-#ifdef GGML_HIP_UMA
-#define cudaMalloc hipMallocManaged
-#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size)
-#else
 #define cudaMalloc hipMalloc
 #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
-#endif
 #define cudaMemcpy hipMemcpy
 #define cudaMemcpyAsync hipMemcpyAsync
 #define cudaMemcpyPeerAsync hipMemcpyPeerAsync
@@ -165,7 +160,7 @@
 #endif
 
 #define MMVQ_MAX_BATCH_SIZE  8 // max batch size to use MMVQ kernels
-#define  MMQ_MAX_BATCH_SIZE 32 // max batch size to use MMQ kernels when tensor cores are available
+#define  MMQ_MAX_BATCH_SIZE 64 // max batch size to use MMQ kernels when tensor cores are available
 
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
@@ -489,6 +484,161 @@ static __device__ __forceinline__ float get_alibi_slope(
     return powf(base, exph);
 }
 
+template <ggml_type type>
+struct ggml_cuda_type_traits;
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_F16> {
+    static constexpr int qk = 1;
+    static constexpr int qr = 1;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
+    static constexpr int qk = QK4_0;
+    static constexpr int qr = QR4_0;
+    static constexpr int qi = QI4_0;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q4_1> {
+    static constexpr int qk = QK4_1;
+    static constexpr int qr = QR4_1;
+    static constexpr int qi = QI4_1;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q5_0> {
+    static constexpr int qk = QK5_0;
+    static constexpr int qr = QR5_0;
+    static constexpr int qi = QI5_0;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q5_1> {
+    static constexpr int qk = QK5_1;
+    static constexpr int qr = QR5_1;
+    static constexpr int qi = QI5_1;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
+    static constexpr int qk = QK8_0;
+    static constexpr int qr = QR8_0;
+    static constexpr int qi = QI8_0;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR2_K;
+    static constexpr int qi = QI2_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q3_K> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR3_K;
+    static constexpr int qi = QI3_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q4_K> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR4_K;
+    static constexpr int qi = QI4_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q5_K> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR5_K;
+    static constexpr int qi = QI5_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR6_K;
+    static constexpr int qi = QI6_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR2_XXS;
+    static constexpr int qi = QI2_XXS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XS> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR2_XS;
+    static constexpr int qi = QI2_XS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_S> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR2_S;
+    static constexpr int qi = QI2_S;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ3_XXS> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR3_XXS;
+    static constexpr int qi = QI3_XXS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ1_S> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR1_S;
+    static constexpr int qi = QI1_S;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR1_M;
+    static constexpr int qi = QI1_M;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
+    static constexpr int qk = QK4_NL;
+    static constexpr int qr = QR4_NL;
+    static constexpr int qi = QI4_NL;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR4_XS;
+    static constexpr int qi = QI4_XS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR3_S;
+    static constexpr int qi = QI3_S;
+};
+
+static int get_mmq_x_max_host(const int cc) {
+#ifdef CUDA_USE_TENSOR_CORES
+    return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
+#else
+    return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64;
+#endif // CUDA_USE_TENSOR_CORES
+}
+
+// Round rows to this value for --split-mode row:
+static int get_mmq_y_host(const int cc, const int mmq_x) {
+    return cc >= CC_VOLTA && mmq_x >= 32 ? 128 : 64;
+}
+
 //////////////////////
 
 struct ggml_cuda_device_info {

+ 157 - 10
llama/ggml-cuda/concat.cu

@@ -1,15 +1,69 @@
 #include "concat.cuh"
 
-static __global__ void concat_f32(const float * x,const float * y, float * dst, const int ne0, const int ne02) {
+// contiguous kernels
+static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
         return;
     }
-    // operation
+
     int offset_dst =
         nidx +
         blockIdx.y * ne0 +
         blockIdx.z * ne0 * gridDim.y;
+
+    if (nidx < ne00) { // src0
+        int offset_src =
+            nidx +
+            blockIdx.y * ne00 +
+            blockIdx.z * ne00 * gridDim.y;
+        dst[offset_dst] = x[offset_src];
+    } else {
+        int offset_src =
+            (nidx - ne00) +
+            blockIdx.y * (ne0 - ne00) +
+            blockIdx.z * (ne0 - ne00) * gridDim.y;
+        dst[offset_dst] = y[offset_src];
+    }
+}
+
+static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) {
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+
+    if (blockIdx.y < ne01) { // src0
+        int offset_src =
+            nidx +
+            blockIdx.y * ne0 +
+            blockIdx.z * ne0 * ne01;
+        dst[offset_dst] = x[offset_src];
+    } else {
+        int offset_src =
+            nidx +
+            (blockIdx.y - ne01) * ne0 +
+            blockIdx.z * ne0 * (gridDim.y - ne01);
+        dst[offset_dst] = y[offset_src];
+    }
+}
+
+static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) {
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+
     if (blockIdx.z < ne02) { // src0
         int offset_src =
             nidx +
@@ -25,25 +79,118 @@ static __global__ void concat_f32(const float * x,const float * y, float * dst,
     }
 }
 
-static void concat_f32_cuda(const float * x, const float * y, float * dst, const int ne0, int ne1, int ne2, int ne02, cudaStream_t stream) {
+static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) {
     int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
     dim3 gridDim(num_blocks, ne1, ne2);
-    concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
+    if (dim == 0) {
+        concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00);
+        return;
+    }
+    if (dim == 1) {
+        concat_f32_dim1<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne01);
+        return;
+    }
+    concat_f32_dim2<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
 }
 
+// non-contiguous kernel (slow)
+static __global__ void concat_f32_non_cont(
+        const char * src0,
+        const char * src1,
+              char * dst,
+           int64_t   ne00,
+           int64_t   ne01,
+           int64_t   ne02,
+           int64_t   ne03,
+          uint64_t   nb00,
+          uint64_t   nb01,
+          uint64_t   nb02,
+          uint64_t   nb03,
+           int64_t /*ne10*/,
+           int64_t /*ne11*/,
+           int64_t /*ne12*/,
+           int64_t /*ne13*/,
+          uint64_t   nb10,
+          uint64_t   nb11,
+          uint64_t   nb12,
+          uint64_t   nb13,
+           int64_t   ne0,
+           int64_t /*ne1*/,
+           int64_t /*ne2*/,
+           int64_t /*ne3*/,
+          uint64_t   nb0,
+          uint64_t   nb1,
+          uint64_t   nb2,
+          uint64_t   nb3,
+          int32_t   dim) {
+    const int64_t i3 = blockIdx.z;
+    const int64_t i2 = blockIdx.y;
+    const int64_t i1 = blockIdx.x;
+
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
+
+    const float * x;
+
+    for (int i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
+        if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+            x = (const float *)(src0 + (i3       )*nb03 + (i2       )*nb02 + (i1       )*nb01 + (i0       )*nb00);
+        } else {
+            x = (const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
+        }
+
+        float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+        *y = *x;
+    }
+}
+
+
 void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
-    const float * src0_d = (const float *)src0->data;
-    const float * src1_d = (const float *)src1->data;
-    float * dst_d = (float *)dst->data;
+
     cudaStream_t stream = ctx.stream();
 
+    const int32_t dim = ((int32_t *) dst->op_params)[0];
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+        const float * src0_d = (const float *)src0->data;
+        const float * src1_d = (const float *)src1->data;
 
-    for (int i3 = 0; i3 < dst->ne[3]; i3++) {
-        concat_f32_cuda(src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4), dst_d + i3 * (dst->nb[3] / 4), dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2], stream);
+        float * dst_d = (float *)dst->data;
+
+        if (dim != 3) {
+            for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+                concat_f32_cuda(
+                        src0_d + i3 * (src0->nb[3] / 4),
+                        src1_d + i3 * (src1->nb[3] / 4),
+                        dst_d + i3 * ( dst->nb[3] / 4),
+                        src0->ne[0], src0->ne[1], src0->ne[2],
+                        dst->ne[0],  dst->ne[1],  dst->ne[2], dim, stream);
+            }
+        } else {
+            const size_t size0 = ggml_nbytes(src0);
+            const size_t size1 = ggml_nbytes(src1);
+
+            CUDA_CHECK(cudaMemcpyAsync(dst_d,           src0_d, size0, cudaMemcpyDeviceToDevice, stream));
+            CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
+        }
+    } else {
+        dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
+        concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
+                (const char *)src0->data,
+                (const char *)src1->data,
+                (      char *)dst->data,
+                src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+                src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
+                dst->ne[0],  dst->ne[1],  dst->ne[2],  dst->ne[3],
+                dst->nb[0],  dst->nb[1],  dst->nb[2],  dst->nb[3], dim);
     }
 }

+ 0 - 138
llama/ggml-cuda/convert.cu

@@ -131,7 +131,6 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
     const block_q2_K * x = (const block_q2_K *) vx;
 
     const int64_t tid = threadIdx.x;
-#if QK_K == 256
     const int64_t n   = tid/32;
     const int64_t l   = tid - 32*n;
     const int64_t is  = 8*n + l/16;
@@ -145,17 +144,6 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
     y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
     y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
     y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
-#else
-    const int64_t is = tid/16;  // 0 or 1
-    const int64_t il = tid%16;  // 0...15
-    const uint8_t q = x[i].qs[il] >> (2*is);
-    dst_t * y = yy + i*QK_K + 16*is + il;
-    float dall = __low2half(x[i].dm);
-    float dmin = __high2half(x[i].dm);
-    y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
-    y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
-#endif
-
 }
 
 template<typename dst_t>
@@ -164,7 +152,6 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t
     const int64_t i = blockIdx.x;
     const block_q3_K * x = (const block_q3_K *) vx;
 
-#if QK_K == 256
     const int64_t r = threadIdx.x/4;
     const int64_t tid = r/2;
     const int64_t is0 = r%2;
@@ -188,31 +175,8 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t
     const uint8_t * hm = x[i].hmask;
 
     for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
-#else
-    const int64_t tid = threadIdx.x;
-    const int64_t is  = tid/16;  // 0 or 1
-    const int64_t il  = tid%16;  // 0...15
-    const int64_t im  = il/8;    // 0...1
-    const int64_t in  = il%8;    // 0...7
-
-    dst_t * y = yy + i*QK_K + 16*is + il;
-
-    const uint8_t q = x[i].qs[il] >> (2*is);
-    const uint8_t h = x[i].hmask[in] >> (2*is + im);
-    const float   d = (float)x[i].d;
-
-    if (is == 0) {
-        y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
-        y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
-    } else {
-        y[ 0] = d * ((x[i].scales[0] >>  4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
-        y[32] = d * ((x[i].scales[1] >>  4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
-    }
-#endif
-
 }
 
-#if QK_K == 256
 static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
     if (j < 4) {
         d = q[j] & 63; m = q[j + 4] & 63;
@@ -221,7 +185,6 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
         m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
     }
 }
-#endif
 
 template<typename dst_t>
 static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
@@ -229,7 +192,6 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t
 
     const int64_t i = blockIdx.x;
 
-#if QK_K == 256
     // assume 32 threads
     const int64_t tid = threadIdx.x;
     const int64_t il  = tid/8;
@@ -253,15 +215,6 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t
         y[l + 0] = d1 * (q[l] & 0xF) - m1;
         y[l +32] = d2 * (q[l] >>  4) - m2;
     }
-#else
-    const int64_t tid = threadIdx.x;
-    const uint8_t * q = x[i].qs;
-    dst_t * y = yy + i*QK_K;
-    const float d = (float)x[i].dm[0];
-    const float m = (float)x[i].dm[1];
-    y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
-    y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >>  4) - m * (x[i].scales[1] >> 4);
-#endif
 }
 
 template<typename dst_t>
@@ -270,7 +223,6 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t
 
     const int64_t i = blockIdx.x;
 
-#if QK_K == 256
     // assume 64 threads - this is very slightly better than the one below
     const int64_t tid = threadIdx.x;
     const int64_t il  = tid/16;   // il is in 0...3
@@ -297,18 +249,6 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t
     hm <<= 1;
     y[32] = d2 * ((ql[ 0] >>  4) + (qh[ 0] & hm ? 16 : 0)) - m2;
     y[33] = d2 * ((ql[ 1] >>  4) + (qh[ 1] & hm ? 16 : 0)) - m2;
-#else
-    const int64_t tid = threadIdx.x;
-    const uint8_t q = x[i].qs[tid];
-    const int64_t im = tid/8;  // 0...3
-    const int64_t in = tid%8;  // 0...7
-    const int64_t is = tid/16; // 0 or 1
-    const uint8_t h = x[i].qh[in] >> im;
-    const float d = x[i].d;
-    dst_t * y = yy + i*QK_K + tid;
-    y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
-    y[32] = d * x[i].scales[is+2] * ((q >>  4) - ((h >> 4) & 1 ? 0 : 16));
-#endif
 }
 
 template<typename dst_t>
@@ -316,7 +256,6 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
     const block_q6_K * x = (const block_q6_K *) vx;
 
     const int64_t i = blockIdx.x;
-#if QK_K == 256
 
     // assume 64 threads - this is very slightly better than the one below
     const int64_t tid = threadIdx.x;
@@ -336,24 +275,6 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
     y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
     y[64] = d * sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh >> 4) & 3) << 4)) - 32);
     y[96] = d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh >> 6) & 3) << 4)) - 32);
-#else
-
-    // assume 32 threads
-    const int64_t tid = threadIdx.x;
-    const int64_t ip  = tid/16;         // 0 or 1
-    const int64_t il  = tid - 16*ip;    // 0...15
-
-    dst_t * y = yy + i*QK_K + 16*ip + il;
-
-    const float d = x[i].d;
-
-    const uint8_t   ql = x[i].ql[16*ip + il];
-    const uint8_t   qh = x[i].qh[il] >> (2*ip);
-    const int8_t  * sc = x[i].scales;
-
-    y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
-    y[32] = d * sc[ip+2] * ((int8_t)((ql  >> 4) | (((qh >> 4) & 3) << 4)) - 32);
-#endif
 }
 
 template<typename dst_t>
@@ -363,7 +284,6 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
     const block_iq2_xxs * x = (const block_iq2_xxs  *) vx;
 
     const int64_t tid = threadIdx.x;
-#if QK_K == 256
     const int64_t il = tid/8; // 0...3
     const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -374,10 +294,6 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
     const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
     const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
     for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
-#else
-    NO_DEVICE_CODE;
-#endif
-
 }
 
 template<typename dst_t>
@@ -387,7 +303,6 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
     const block_iq2_xs * x = (const block_iq2_xs *) vx;
 
     const int64_t tid = threadIdx.x;
-#if QK_K == 256
     const int64_t il = tid/8; // 0...3
     const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -396,10 +311,6 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
     const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
     const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
     for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
-#else
-    NO_DEVICE_CODE;
-#endif
-
 }
 
 template<typename dst_t>
@@ -409,7 +320,6 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_
     const block_iq2_s * x = (const block_iq2_s *) vx;
 
     const int64_t tid = threadIdx.x;
-#if QK_K == 256
     const int64_t il = tid/8; // 0...3
     const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -417,10 +327,6 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_
     const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
     const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
     for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
-#else
-    NO_DEVICE_CODE;
-#endif
-
 }
 
 template<typename dst_t>
@@ -430,7 +336,6 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
     const block_iq3_xxs * x = (const block_iq3_xxs  *) vx;
 
     const int64_t tid = threadIdx.x;
-#if QK_K == 256
     const int64_t il = tid/8; // 0...3
     const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -445,10 +350,6 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
         y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
         y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
     }
-#else
-    NO_DEVICE_CODE;
-#endif
-
 }
 
 template<typename dst_t>
@@ -458,7 +359,6 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
     const block_iq3_s * x = (const block_iq3_s *) vx;
 
     const int64_t tid = threadIdx.x;
-#if QK_K == 256
     const int64_t il = tid/8; // 0...3
     const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -471,10 +371,6 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
         y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
         y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
     }
-#else
-    NO_DEVICE_CODE;
-#endif
-
 }
 
 template<typename dst_t>
@@ -484,7 +380,6 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
     const block_iq1_s * x = (const block_iq1_s  *) vx;
 
     const int64_t tid = threadIdx.x;
-#if QK_K == 256
     const int64_t il = tid/8; // 0...3
     const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -497,10 +392,6 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
     for (int j = 0; j < 8; ++j) {
         y[j] = d * (q[j] + delta);
     }
-#else
-    NO_DEVICE_CODE;
-#endif
-
 }
 
 template<typename dst_t>
@@ -510,7 +401,6 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
     const block_iq1_m * x = (const block_iq1_m  *) vx;
 
     const int64_t tid = threadIdx.x;
-#if QK_K == 256
     const int64_t il = tid/8; // 0...3
     const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -527,13 +417,8 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
     for (int j = 0; j < 8; ++j) {
         y[j] = d * (q[j] + delta);
     }
-#else
-    NO_DEVICE_CODE;
-#endif
-
 }
 
-
 template<typename dst_t>
 static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
@@ -550,10 +435,8 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
         y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
         y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
     }
-
 }
 
-#if QK_K != 64
 template<typename dst_t>
 static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const int64_t i   = blockIdx.x;
@@ -570,7 +453,6 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
         y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
     }
 }
-#endif
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
@@ -592,21 +474,13 @@ static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half *
 template<typename dst_t>
 static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
-#if QK_K == 256
     dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
-#else
-    dequantize_block_q2_K<<<nb, 32, 0, stream>>>(vx, y);
-#endif
 }
 
 template<typename dst_t>
 static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
-#if QK_K == 256
     dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
-#else
-    dequantize_block_q3_K<<<nb, 32, 0, stream>>>(vx, y);
-#endif
 }
 
 template<typename dst_t>
@@ -632,21 +506,13 @@ static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k
 template<typename dst_t>
 static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
-#if QK_K == 256
     dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
-#else
-    dequantize_block_q5_K<<<nb, 32, 0, stream>>>(vx, y);
-#endif
 }
 
 template<typename dst_t>
 static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
-#if QK_K == 256
     dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
-#else
-    dequantize_block_q6_K<<<nb, 32, 0, stream>>>(vx, y);
-#endif
 }
 
 template<typename dst_t>
@@ -700,11 +566,7 @@ static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t
 template<typename dst_t>
 static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = (k + QK_K - 1) / QK_K;
-#if QK_K == 64
-    dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
-#else
     dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
-#endif
 }
 
 template <typename src_t, typename dst_t>

+ 21 - 160
llama/ggml-cuda/dmmv.cu

@@ -22,7 +22,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
 
     float tmp = 0; // partial sum for thread in warp
 
-#if QK_K == 256
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
     const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
@@ -71,37 +70,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
         tmp += dall * sum1 - dmin * sum2;
 
     }
-#else
-    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...15 or 0...7
-    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);  // 0....1 or 0...3
-    const int offset = tid * K_QUANTS_PER_ITERATION;
-
-    uint32_t uaux[2];
-    const uint8_t * d = (const uint8_t *)uaux;
-
-    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
-
-        const float   * y = yy + i * QK_K + offset;
-        const uint8_t * q = x[i].qs + offset;
-        const uint32_t * s = (const uint32_t *)x[i].scales;
-
-        uaux[0] = s[0] & 0x0f0f0f0f;
-        uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
-
-        const float2 dall = __half22float2(x[i].dm);
-
-        float sum1 = 0, sum2 = 0;
-        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
-            const uint8_t ql = q[l];
-            sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
-                  + y[l+16] * d[1] * ((ql >> 2) & 3)
-                  + y[l+32] * d[2] * ((ql >> 4) & 3)
-                  + y[l+48] * d[3] * ((ql >> 6) & 3);
-            sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
-        }
-        tmp += dall.x * sum1 - dall.y * sum2;
-    }
-#endif
 
     // sum up partial sums and write back result
     tmp = warp_reduce_sum(tmp);
@@ -123,8 +91,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
 
     float tmp = 0; // partial sum for thread in warp
 
-#if QK_K == 256
-
     const uint16_t kmask1 = 0x0303;
     const uint16_t kmask2 = 0x0f0f;
 
@@ -175,34 +141,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
         tmp += d * sum;
 
     }
-#else
-
-    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...15 or 0...7
-    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);  // 0....1 or 0...3
-    const int offset = tid * K_QUANTS_PER_ITERATION;         // 0...15 or 0...14
-    const int in = offset/8;                                 // 0 or 1
-    const int im = offset%8;                                 // 0...7
-
-    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
-
-        const float   * y = yy + i * QK_K + offset;
-        const uint8_t * q = x[i].qs + offset;
-        const uint8_t * s = x[i].scales;
-
-        const float dall = (float)x[i].d;
-
-        float sum = 0;
-        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
-            const uint8_t hl = x[i].hmask[im+l] >> in;
-            const uint8_t ql = q[l];
-            sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
-                 + y[l+16] * dall * ((s[0] >>  4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
-                 + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
-                 + y[l+48] * dall * ((s[1] >>  4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
-        }
-        tmp += sum;
-    }
-#endif
 
     // sum up partial sums and write back result
     tmp = warp_reduce_sum(tmp);
@@ -221,7 +159,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
 
     const block_q4_K * x = (const block_q4_K *)vx + ib0;
 
-#if QK_K == 256
     const uint16_t kmask1 = 0x3f3f;
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask3 = 0xc0c0;
@@ -306,36 +243,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
 #endif
 
     }
-#else
-    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...15
-    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
-
-    const int step = tid * K_QUANTS_PER_ITERATION;
-
-    uint16_t aux16[2];
-    const uint8_t * s = (const uint8_t *)aux16;
-
-    float tmp = 0;
-
-    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
-        const uint8_t * q = x[i].qs + step;
-        const float   * y = yy + i*QK_K + step;
-        const uint16_t * a = (const uint16_t *)x[i].scales;
-        aux16[0] = a[0] & 0x0f0f;
-        aux16[1] = (a[0] >> 4) & 0x0f0f;
-        const float d = (float)x[i].dm[0];
-        const float m = (float)x[i].dm[1];
-        float sum = 0.f;
-        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
-            sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
-                 + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
-                 + y[j+32] * (d * s[1] * (q[j+ 0] >>  4) - m * s[3])
-                 + y[j+48] * (d * s[1] * (q[j+16] >>  4) - m * s[3]);
-        }
-        tmp += sum;
-    }
-
-#endif
 
     // sum up partial sums and write back result
     tmp = warp_reduce_sum(tmp);
@@ -355,7 +262,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
 
     float tmp = 0; // partial sum for thread in warp
 
-#if QK_K == 256
     const uint16_t kmask1 = 0x3f3f;
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask3 = 0xc0c0;
@@ -426,30 +332,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
         tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
     }
 
-#else
-    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...15
-    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
-    const int step = tid * K_QUANTS_PER_ITERATION;
-    const int im = step/8;
-    const int in = step%8;
-
-    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
-        const uint8_t * q = x[i].qs + step;
-        const int8_t  * s = x[i].scales;
-        const float   * y = yy + i*QK_K + step;
-        const float     d = x[i].d;
-        float sum = 0.f;
-        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
-            const uint8_t h = x[i].qh[in+j] >> im;
-            sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
-                 + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
-                 + y[j+32] * d * s[2] * ((q[j+ 0] >>  4) - ((h >> 4) & 1 ? 0 : 16))
-                 + y[j+48] * d * s[3] * ((q[j+16] >>  4) - ((h >> 6) & 1 ? 0 : 16));
-        }
-        tmp += sum;
-    }
-#endif
-
     // sum up partial sums and write back result
     tmp = warp_reduce_sum(tmp);
 
@@ -470,8 +352,6 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
 
     const block_q6_K * x = (const block_q6_K *)vx + ib0;
 
-#if QK_K == 256
-
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
     const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
 
@@ -526,37 +406,6 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
 
     }
 
-#else
-
-    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...7
-    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);  // 0...3
-
-    const int step = tid * K_QUANTS_PER_ITERATION;
-
-    float tmp = 0; // partial sum for thread in warp
-
-    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
-
-        const float   * y  = yy + i * QK_K + step;
-        const uint8_t * ql = x[i].ql + step;
-        const uint8_t * qh = x[i].qh + step;
-        const int8_t  * s  = x[i].scales;
-
-        const float d = x[i+0].d;
-
-        float sum = 0;
-        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
-            sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
-                 + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
-                 + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >>  4) | ((qh[j] & 0x30) >> 0)) - 32)
-                 + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >>  4) | ((qh[j] & 0xc0) >> 2)) - 32);
-        }
-        tmp += sum;
-
-    }
-
-#endif
-
     // sum up partial sums and write back result
     tmp = warp_reduce_sum(tmp);
 
@@ -573,10 +422,22 @@ static __device__ void convert_f16(const void * vx, const int64_t ib, const int
     v.y = x[ib + iqs + 1];
 }
 
-template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
+    return type == GGML_TYPE_Q4_0 ? dequantize_q4_0 :
+        type == GGML_TYPE_Q4_1 ? dequantize_q4_1 :
+        type == GGML_TYPE_Q5_0 ? dequantize_q5_0 :
+        type == GGML_TYPE_Q5_1 ? dequantize_q5_1 :
+        type == GGML_TYPE_Q8_0 ? dequantize_q8_0 :
+        type == GGML_TYPE_F16 ? convert_f16 :
+        nullptr;
+}
+
+template <ggml_type type>
 static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
-    // qk = quantized weights per x block
-    // qr = number of quantized weights per data value in x block
+    constexpr int qk = ggml_cuda_type_traits<type>::qk; // quantized weights per x block
+    constexpr int qr = ggml_cuda_type_traits<type>::qr; // number of quantized weights per data value in x block
+    constexpr dequantize_kernel_t dequantize_kernel = get_dequantize_kernel(type);
+
     const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y;
 
     if (row >= nrows) {
@@ -644,7 +505,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y,
     // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
+    dequantize_mul_mat_vec<GGML_TYPE_Q4_0>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
@@ -653,7 +514,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y,
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
+    dequantize_mul_mat_vec<GGML_TYPE_Q4_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
@@ -662,7 +523,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y,
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
+    dequantize_mul_mat_vec<GGML_TYPE_Q5_0>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
@@ -671,7 +532,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y,
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
+    dequantize_mul_mat_vec<GGML_TYPE_Q5_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
@@ -680,7 +541,7 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y,
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
+    dequantize_mul_mat_vec<GGML_TYPE_Q8_0>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
@@ -731,7 +592,7 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<1, 1, convert_f16>
+    dequantize_mul_mat_vec<GGML_TYPE_F16>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 

+ 585 - 6
llama/ggml-cuda/fattn-common.cuh

@@ -1,4 +1,8 @@
+#pragma once
+
 #include "common.cuh"
+#include "convert.cuh"
+#include "vecdotq.cuh"
 
 #include <cstdint>
 
@@ -34,11 +38,523 @@ typedef void (* fattn_kernel_t)(
         const int nb11,
         const int nb12,
         const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
         const int ne0,
         const int ne1,
         const int ne2,
         const int ne3);
 
+typedef half (*vec_dot_KQ_f16_t)(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
+typedef float (*vec_dot_KQ_f32_t)(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
+
+template<typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
+
+    const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    half sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI4_0;
+        const int shift = k_KQ & (QI8_1/2);
+
+        const int v = (get_int_from_uint8(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+        const int sumi = __dp4a(v, u, 0);
+
+#if FP16_AVAILABLE
+        if (std::is_same<T, half>::value) {
+            const half2  * Q_ds = (const half2  *) Q_ds_v;
+
+            const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
+        } else
+#endif // FP16_AVAILABLE
+        {
+            const float2 * Q_ds = (const float2 *) Q_ds_v;
+
+            sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+        }
+    }
+
+    return sum;
+#else
+    GGML_UNUSED(K_c);
+    GGML_UNUSED(Q_v);
+    GGML_UNUSED(Q_q8);
+    GGML_UNUSED(Q_ds_v);
+    NO_DEVICE_CODE;
+#endif  // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+template<typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
+
+    const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    T sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI4_1;
+        const int shift = k_KQ & (QI8_1/2);
+
+        const int v = (get_int_from_uint8_aligned(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+        const int sumi = __dp4a(v, u, 0);
+
+#if FP16_AVAILABLE
+        if (std::is_same<T, half>::value) {
+            const half2  * Q_ds = (const half2  *) Q_ds_v;
+
+            const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
+            sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
+        } else
+#endif // FP16_AVAILABLE
+        {
+            const float2 * Q_ds = (const float2 *) Q_ds_v;
+
+            const float sumid4d8   =  __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
+            const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+
+            sum += (T) (sumid4d8 + m4s8scaled);
+        }
+    }
+
+    return sum;
+#else
+    GGML_UNUSED(K_c);
+    GGML_UNUSED(Q_v);
+    GGML_UNUSED(Q_q8);
+    GGML_UNUSED(Q_ds_v);
+    NO_DEVICE_CODE;
+#endif  // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+template<typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
+
+    const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    T sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI5_0;
+        const int iqs8  = k_KQ %  QI8_1;
+        const int shift = k_KQ & (QI8_1/2);
+
+        int v = (get_int_from_uint8(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int vh = get_int_from_uint8(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
+        v |= (vh <<  4) & 0x00000010; // 0 ->  4
+        v |= (vh << 11) & 0x00001000; // 1 -> 12
+        v |= (vh << 18) & 0x00100000; // 2 -> 20
+        v |= (vh << 25) & 0x10000000; // 3 -> 28
+
+        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+        const int sumi = __dp4a(v, u, 0);
+
+#if FP16_AVAILABLE
+        if (std::is_same<T, half>::value) {
+            const half2  * Q_ds = (const half2  *) Q_ds_v;
+
+            const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
+        } else
+#endif // FP16_AVAILABLE
+        {
+            const float2 * Q_ds = (const float2 *) Q_ds_v;
+
+            sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+        }
+    }
+
+    return sum;
+#else
+    GGML_UNUSED(K_c);
+    GGML_UNUSED(Q_v);
+    GGML_UNUSED(Q_q8);
+    GGML_UNUSED(Q_ds_v);
+    NO_DEVICE_CODE;
+#endif  // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+template<typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
+
+    const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    T sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI5_1;
+        const int iqs8  = k_KQ %  QI8_1;
+        const int shift = k_KQ & (QI8_1/2);
+
+        int v = (get_int_from_uint8(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int vh = get_int_from_uint8(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
+        v |= (vh <<  4) & 0x00000010; // 0 ->  4
+        v |= (vh << 11) & 0x00001000; // 1 -> 12
+        v |= (vh << 18) & 0x00100000; // 2 -> 20
+        v |= (vh << 25) & 0x10000000; // 3 -> 28
+
+        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+        const int sumi = __dp4a(v, u, 0);
+
+#if FP16_AVAILABLE
+        if (std::is_same<T, half>::value) {
+            const half2  * Q_ds = (const half2  *) Q_ds_v;
+
+            const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
+            sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
+        } else
+#endif // FP16_AVAILABLE
+        {
+            const float2 * Q_ds = (const float2 *) Q_ds_v;
+
+            const float sumid5d8   =  __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
+            const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+
+            sum += (T) (sumid5d8 + m5s8scaled);
+        }
+    }
+
+    return sum;
+#else
+    GGML_UNUSED(K_c);
+    GGML_UNUSED(Q_v);
+    GGML_UNUSED(Q_q8);
+    GGML_UNUSED(Q_ds_v);
+    NO_DEVICE_CODE;
+#endif  // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+template <typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
+
+    const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    T sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const int ib  = k_KQ / QI8_0;
+        const int iqs = k_KQ % QI8_0;
+
+        const int v = get_int_from_int8(K_q8_0[ib].qs, iqs);
+
+        T Q_d;
+        if (std::is_same<T, half>::value) {
+            const half2  * Q_ds = (const half2  *) Q_ds_v;
+            Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]);
+        } else {
+            const float2 * Q_ds = (const float2 *) Q_ds_v;
+            Q_d = Q_ds[k_KQ_0/WARP_SIZE].x;
+        }
+
+        sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d);
+    }
+
+    return sum;
+#else
+    GGML_UNUSED(K_c);
+    GGML_UNUSED(Q_v);
+    GGML_UNUSED(Q_q8);
+    GGML_UNUSED(Q_ds_v);
+    NO_DEVICE_CODE;
+#endif  // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+template <typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
+
+    const half2 * K_h2 = (const half2 *) K_c;
+    GGML_UNUSED(Q_q8);
+    GGML_UNUSED(Q_ds_v);
+
+#if FP16_AVAILABLE
+    if (std::is_same<T, half>::value) {
+        const half2 * Q_h2 = (const half2 *) Q_v;
+
+        half2 sum2 = make_half2(0.0f, 0.0f);
+
+#pragma unroll
+        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+            const int k_KQ = k_KQ_0 + threadIdx.x;
+
+            const half2 K_ik = K_h2[k_KQ];
+            sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
+        }
+
+        return __low2half(sum2) + __high2half(sum2);
+    }
+#endif // FP16_AVAILABLE
+
+    const float2 * Q_f2 = (const float2 *) Q_v;
+
+    float sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const half2 K_ik = K_h2[k_KQ];
+        sum +=  __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x;
+        sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y;
+    }
+
+    return sum;
+}
+
+template <typename Tds>
+static __device__ __forceinline__ void quantize_q8_1_to_shared(
+    const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) {
+
+    float vals[sizeof(int)] = {0.0f};
+#pragma unroll
+    for (int l = 0; l < sizeof(int); ++l) {
+        vals[l] = scale * x[4*threadIdx.x + l];
+    }
+
+    float amax = fabsf(vals[0]);
+    float sum  = vals[0];
+#pragma unroll
+    for (int l = 1; l < sizeof(int); ++l) {
+        amax = fmaxf(amax, fabsf(vals[l]));
+        sum += vals[l];
+    }
+#pragma unroll
+    for (int mask = QI8_1/2; mask > 0; mask >>= 1) {
+        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32));
+        sum +=             __shfl_xor_sync(0xFFFFFFFF, sum,  mask, 32);
+    }
+
+    const float d = amax / 127;
+    int q32 = 0;
+    int8_t * q8 = (int8_t *) &q32;
+
+    if (d != 0.0f) {
+#pragma unroll
+        for (int l = 0; l < sizeof(int); ++l) {
+            q8[l] = roundf(vals[l] / d);
+        }
+    }
+
+    yq32[threadIdx.x] = q32;
+    if (threadIdx.x % QI8_1 == 0) {
+        if (std::is_same<Tds, half2>::value) {
+            ((half2  *) yds)[threadIdx.x/QI8_1] =  make_half2(d, sum);
+        } else {
+            ((float2 *) yds)[threadIdx.x/QI8_1] = make_float2(d, sum);
+        }
+    }
+}
+
+typedef half  (*dequantize_1_f16_t)(const void *, const int64_t);
+typedef float (*dequantize_1_f32_t)(const void *, const int64_t);
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ vx, const int64_t i) {
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+
+    const int64_t ib    =  i          /  QK4_0;
+    const int     iqs   =  i          % (QK4_0/2);
+    const int     shift = (i % QK4_0) / (QK4_0/2);
+
+    const T   d  = x[ib].d;
+    const int q0 = x[ib].qs[iqs];
+    const int q  = ((q0 >> (4*shift)) & 0x0F) - 8;
+
+#if FP16_AVAILABLE
+    if (std::is_same<T, half>::value) {
+        return ((half) d)*((half) q);
+    }
+#endif // FP16_AVAILABLE
+
+    return ((float) d)*((float) q);
+}
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) {
+    const block_q4_1 * x = (const block_q4_1 *) vx;
+
+    const int64_t ib    =  i          /  QK4_1;
+    const int     iqs   =  i          % (QK4_1/2);
+    const int     shift = (i % QK4_1) / (QK4_1/2);
+
+    const half2 dm = x[ib].dm;
+    const int   q0 = x[ib].qs[iqs];
+    const int   q  = ((q0 >> (4*shift)) & 0x0F);
+
+#if FP16_AVAILABLE
+    if (std::is_same<T, half>::value) {
+        return __low2half(dm)*((half) q) + __high2half(dm);
+    }
+#endif // FP16_AVAILABLE
+
+    return __low2float(dm)*((float) q) + __high2float(dm);
+}
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ vx, const int64_t i) {
+    const block_q5_0 * x = (const block_q5_0 *) vx;
+
+    const int64_t ib    =  i          /  QK5_0;
+    const int     idq   =  i          %  QK5_0;
+    const int     iqs   =  i          % (QK5_0/2);
+    const int     shift = (i % QK5_0) / (QK5_0/2);
+
+    const T   d   = x[ib].d;
+    const int ql0 = x[ib].qs[iqs];
+    const int qh0 = get_int_from_uint8(x[ib].qh, 0);
+    const int ql  = ((ql0 >> (4*shift)) & 0x0F);
+    const int qh  = ((qh0 >> idq) << 4) & 0x10;
+    const int q   = (ql | qh) - 16;
+
+#if FP16_AVAILABLE
+    if (std::is_same<T, half>::value) {
+        return ((half) d)*((half) q);
+    }
+#endif // FP16_AVAILABLE
+
+    return ((float) d)*((float) q);
+}
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ vx, const int64_t i) {
+    const block_q5_1 * x = (const block_q5_1 *) vx;
+
+    const int64_t ib    =  i          /  QK5_1;
+    const int     idq   =  i          %  QK5_1;
+    const int     iqs   =  i          % (QK5_1/2);
+    const int     shift = (i % QK5_1) / (QK5_1/2);
+
+    const half2 dm  = x[ib].dm;
+    const int   ql0 = x[ib].qs[iqs];
+    const int   qh0 = get_int_from_uint8_aligned(x[ib].qh, 0);
+    const int   ql  = ((ql0 >> (4*shift)) & 0x0F);
+    const int   qh  = ((qh0 >> idq) << 4) & 0x10;
+    const int   q   = (ql | qh);
+
+#if FP16_AVAILABLE
+    if (std::is_same<T, half>::value) {
+        return __low2half(dm)*((half) q) + __high2half(dm);
+    }
+#endif // FP16_AVAILABLE
+
+    return __low2float(dm)*((float) q) + __high2float(dm);
+}
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) {
+    const block_q8_0 * x = (const block_q8_0 *) vx;
+
+    const int64_t ib  = i / QK8_0;
+    const int     iqs = i % QK8_0;
+
+    const T   d = x[ib].d;
+    const int q = x[ib].qs[iqs];
+
+#if FP16_AVAILABLE
+    if (std::is_same<T, half>::value) {
+        return ((half) d)*((half) q);
+    }
+#endif // FP16_AVAILABLE
+
+    return ((float) d)*((float) q);
+}
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) {
+    const half * x = (const half *) vx;
+
+    return x[i];
+}
+
+template <int D>
+constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
+        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
+        nullptr;
+}
+
+template <int D>
+constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
+        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
+        nullptr;
+}
+
+constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) {
+    return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<half> :
+        type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<half> :
+        type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<half> :
+        type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<half> :
+        type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<half> :
+        type_V == GGML_TYPE_F16 ? dequantize_1_f16<half> :
+        nullptr;
+}
+
+constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
+    return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<float> :
+        type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<float> :
+        type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<float> :
+        type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<float> :
+        type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<float> :
+        type_V == GGML_TYPE_F16 ? dequantize_1_f16<float> :
+        nullptr;
+}
+
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -83,8 +599,32 @@ static __global__ void flash_attn_combine_results(
     dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
 }
 
+static void on_no_fattn_vec_case(const int D) {
+    if (D == 64) {
+        fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
+        fprintf(stderr, "By default only f16 KV cache is supported.\n");
+        fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
+        GGML_ASSERT(false);
+    } else if (D == 128) {
+        fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
+        fprintf(stderr, "Supported combinations:\n");
+        fprintf(stderr, "  - K == q4_0, V == q4_0,  4.50 BPV\n");
+        fprintf(stderr, "  - K == q8_0, V == q8_0,  8.50 BPV\n");
+        fprintf(stderr, "  - K == f16,  V == f16,  16.00 BPV\n");
+        fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
+        GGML_ASSERT(false);
+    } else {
+        fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
+        fprintf(stderr, "Only f16 is supported.\n");
+        GGML_ASSERT(false);
+    }
+}
+
 template <int D, int parallel_blocks>
-void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) {
+void launch_fattn(
+    ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
+    const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
+) {
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
@@ -94,8 +634,6 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
     ggml_tensor * KQV = dst;
 
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
-    GGML_ASSERT(K->type == GGML_TYPE_F16);
-    GGML_ASSERT(V->type == GGML_TYPE_F16);
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
@@ -107,9 +645,49 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
 
+    ggml_cuda_pool_alloc<half>   K_f16(pool);
+    ggml_cuda_pool_alloc<half>   V_f16(pool);
     ggml_cuda_pool_alloc<float>  dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
+    char * K_data = (char *) K->data;
+    size_t nb11 = K->nb[1];
+    size_t nb12 = K->nb[2];
+    size_t nb13 = K->nb[3];
+
+    char * V_data = (char *) V->data;
+    size_t nb21 = V->nb[1];
+    size_t nb22 = V->nb[2];
+    size_t nb23 = V->nb[3];
+
+    if (need_f16_K && K->type != GGML_TYPE_F16) {
+        K_f16.alloc(ggml_nelements(K));
+        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
+        to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
+        K_data = (char *) K_f16.ptr;
+
+        const size_t bs = ggml_blck_size(K->type);
+        const size_t ts = ggml_type_size(K->type);
+
+        nb11 = nb11*bs*sizeof(half)/ts;
+        nb12 = nb12*bs*sizeof(half)/ts;
+        nb13 = nb13*bs*sizeof(half)/ts;
+    }
+
+    if (need_f16_V && V->type != GGML_TYPE_F16) {
+        V_f16.alloc(ggml_nelements(V));
+        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+        to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+        V_data = (char *) V_f16.ptr;
+
+        const size_t bs = ggml_blck_size(V->type);
+        const size_t ts = ggml_type_size(V->type);
+
+        nb21 = nb21*bs*sizeof(half)/ts;
+        nb22 = nb22*bs*sizeof(half)/ts;
+        nb23 = nb23*bs*sizeof(half)/ts;
+    }
+
     if (parallel_blocks > 1) {
         dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
         dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
@@ -133,8 +711,8 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
 
     fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
         (const char *) Q->data,
-        (const char *) K->data,
-        (const char *) V->data,
+        K_data,
+        V_data,
         mask ? ((const char *) mask->data) : nullptr,
         (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2,
@@ -142,7 +720,8 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
         mask ? mask->ne[1] : 0, mask ?  mask->nb[1] : 0,
         Q->nb[1], Q->nb[2], Q->nb[3],
-        K->nb[1], K->nb[2], K->nb[3],
+        nb11, nb12, nb13,
+        nb21, nb22, nb23,
         KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
     );
     CUDA_CHECK(cudaGetLastError());

+ 10 - 3
llama/ggml-cuda/fattn-tile-f16.cu

@@ -36,6 +36,9 @@ static __global__ void flash_attn_tile_ext_f16(
         const int nb11,
         const int nb12,
         const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
         const int ne0,
         const int ne1,
         const int ne2,
@@ -83,7 +86,7 @@ static __global__ void flash_attn_tile_ext_f16(
         for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;
 
-            const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
+            const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
             Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
         }
     }
@@ -238,6 +241,10 @@ static __global__ void flash_attn_tile_ext_f16(
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
 
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+
         half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
         kqsum_j = warp_reduce_sum(kqsum_j);
 
@@ -271,13 +278,13 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int      D = 64;
             constexpr int nwarps = 8;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int      D = 128;
             constexpr int nwarps = 8;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
             GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");

+ 11 - 8
llama/ggml-cuda/fattn-tile-f32.cu

@@ -36,6 +36,9 @@ static __global__ void flash_attn_tile_ext_f32(
         const int nb11,
         const int nb12,
         const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
         const int ne0,
         const int ne1,
         const int ne2,
@@ -79,7 +82,7 @@ static __global__ void flash_attn_tile_ext_f32(
 
 #pragma unroll
         for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
-            float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x];
+            float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f);
             Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
             Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
         }
@@ -237,6 +240,10 @@ static __global__ void flash_attn_tile_ext_f32(
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
 
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+
         float kqsum_j = kqsum[j_VKQ_0/nwarps];
         kqsum_j = warp_reduce_sum(kqsum_j);
 
@@ -268,13 +275,13 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int      D = 64;
             constexpr int nwarps = 8;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int      D = 128;
             constexpr int nwarps = 8;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
             GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
@@ -283,11 +290,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
 }
 
 void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * KQV = dst;
-    const ggml_tensor * Q   = dst->src[0];
-
-    const int32_t precision = KQV->op_params[2];
-    GGML_ASSERT(precision == GGML_PREC_DEFAULT);
+    const ggml_tensor * Q = dst->src[0];
 
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;

+ 394 - 2
llama/ggml-cuda/fattn-vec-f16.cuh

@@ -1,5 +1,397 @@
 #include "common.cuh"
+#include "fattn-common.cuh"
 
-void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_vec_ext_f16(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+#if FP16_AVAILABLE
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
-void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+    constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
+    constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
+    constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);
+
+    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
+    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    Q += nb02* blockIdx.y              + nb01*ic0;
+    K += nb12*(blockIdx.y / gqa_ratio);
+    V += nb22*(blockIdx.y / gqa_ratio);
+
+    const half * maskh = (const half   *)  mask + ne11*ic0;
+
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
+
+    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
+    constexpr int nwarps = D / WARP_SIZE;
+    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    __builtin_assume(tid < D);
+
+    __shared__ half KQ[ncols*D];
+    half2 * KQ2 = (half2 *) KQ;
+
+    half kqmax[ncols];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        kqmax[j] = -HALF_MAX_HALF;
+    }
+    half kqsum[ncols] = {0.0f};
+
+    __shared__ half kqmax_shared[ncols][WARP_SIZE];
+    __shared__ half kqsum_shared[ncols][WARP_SIZE];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (threadIdx.y == 0) {
+            kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF;
+            kqsum_shared[j][threadIdx.x] = 0.0f;
+        }
+    }
+    __syncthreads();
+
+    // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
+    half2  Q_h2[ncols][D/(2*WARP_SIZE)];
+    int   Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)];
+    half2  Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
+    if (Q_q8_1) {
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            if (j0 + nwarps > ncols && j >= ncols) {
+                break;
+            }
+
+            // Reuse KQ as temporary storage for converting Q to q8_1:
+            int   * tmp_q_i32 = (int   *) &KQ[j*D];
+            half2 * tmp_q_ds  = (half2 *) (tmp_q_i32 + D/sizeof(int));
+
+            // Set memory to zero if out of bounds:
+            if (ncols > 2 && ic0 + j >= ne01) {
+#pragma unroll
+                for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+                    const int i = i0 + threadIdx.x;
+
+                    tmp_q_i32[i] = 0;
+                }
+                if (threadIdx.x < D/QK8_1) {
+                    tmp_q_ds[threadIdx.x] = make_half2(0.0f, 0.0f);
+                }
+                continue;
+            }
+
+            const float * Q_f = (const float *) (Q + j*nb01);
+#pragma unroll
+            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+                quantize_q8_1_to_shared<half2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            int   * tmp_q_i32 = (int   *) &KQ[j*D];
+            half2 * tmp_q_ds  = (half2 *) (tmp_q_i32 + D/sizeof(int));
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
+                Q_ds[j][i0/WARP_SIZE]  = tmp_q_ds[i/QI8_1];
+            }
+        }
+
+        __syncthreads();
+    } else {
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
+                Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
+            }
+        }
+    }
+
+
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        KQ[j*D + tid] = -HALF_MAX_HALF;
+    }
+
+    half2 VKQ[ncols] = {{0.0f, 0.0f}};
+
+    const int k_start = parallel_blocks == 1 ? 0 : ip*D;
+    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
+        // Calculate KQ tile and keep track of new maximum KQ values:
+
+        // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
+        // see https://github.com/ggerganov/llama.cpp/pull/7061 .
+        // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
+        half kqmax_new = kqmax[0];
+        half kqmax_new_arr[ncols];
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            kqmax_new_arr[j] = kqmax[j];
+        }
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
+            const int i_KQ = i_KQ_0 + threadIdx.y;
+
+            if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
+                break;
+            }
+
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
+                sum = warp_reduce_sum(sum);
+                sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+
+                if (ncols == 1) {
+                    kqmax_new        = ggml_cuda_hmax(kqmax_new,        sum);
+                } else {
+                    kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
+                }
+
+                if (threadIdx.x == 0) {
+                    KQ[j*D + i_KQ] = sum;
+                }
+            }
+        }
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
+
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+            if (threadIdx.x == 0) {
+                kqmax_shared[j][threadIdx.y] = kqmax_new_j;
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            half kqmax_new_j = kqmax_shared[j][threadIdx.x];
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
+            kqmax[j] = kqmax_new_j;
+
+            const half val = hexp(KQ[j*D + tid] - kqmax[j]);
+            kqsum[j] = kqsum[j]*KQ_max_scale + val;
+            KQ[j*D + tid] = val;
+
+            VKQ[j] *= __half2half2(KQ_max_scale);
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int k0 = 0; k0 < D; k0 += 2) {
+            if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
+                break;
+            }
+
+            half2 V_k;
+            reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid);
+            reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
+            }
+        }
+
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        kqsum[j] = warp_reduce_sum(kqsum[j]);
+        if (threadIdx.x == 0) {
+            kqsum_shared[j][threadIdx.y] = kqsum[j];
+        }
+    }
+
+    __syncthreads();
+
+#pragma unroll
+    for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+        if (ncols > 2 && ic0 + j_VKQ >= ne01) {
+            break;
+        }
+
+        kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
+        kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
+
+        half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
+        if (parallel_blocks == 1) {
+            dst_val /= kqsum[j_VKQ];
+        }
+        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+        dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
+    }
+
+    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
+        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
+    }
+#else
+   NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+}
+
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
+void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    constexpr int nwarps = D/WARP_SIZE;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V>;
+    constexpr bool need_f16_K = D != 128;
+    constexpr bool need_f16_V = D != 128 && D != 64;
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
+}
+
+template <int D, ggml_type type_K, ggml_type type_V>
+void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * KQV = dst;
+    ggml_tensor * Q   = dst->src[0];
+    ggml_tensor * K   = dst->src[1];
+    ggml_tensor * V   = dst->src[2];
+
+    const int32_t precision = KQV->op_params[2];
+    GGML_ASSERT(precision == GGML_PREC_DEFAULT);
+
+    GGML_ASSERT(K->type == type_K);
+    GGML_ASSERT(V->type == type_V);
+
+    if (Q->ne[1] == 1) {
+        constexpr int cols_per_block  = 1;
+        constexpr int parallel_blocks = 4;
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        return;
+    }
+
+    if (Q->ne[1] == 2) {
+        constexpr int cols_per_block  = 2;
+        constexpr int parallel_blocks = 4;
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        return;
+    }
+
+    if (Q->ne[1] <= 4) {
+        constexpr int cols_per_block  = 4;
+        constexpr int parallel_blocks = 4;
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        return;
+    }
+
+    if (Q->ne[1] <= 8) {
+        constexpr int cols_per_block  = 8;
+        constexpr int parallel_blocks = 4;
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        return;
+    }
+
+    constexpr int cols_per_block  = 8;
+    constexpr int parallel_blocks = 1;
+    ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+}
+
+#define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V)                          \
+    template void ggml_cuda_flash_attn_ext_vec_f16_case                     \
+    <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_0);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_1);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_0);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_1);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q8_0);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16);
+
+extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);

+ 372 - 1
llama/ggml-cuda/fattn-vec-f32.cuh

@@ -1,3 +1,374 @@
 #include "common.cuh"
+#include "fattn-common.cuh"
 
-void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_vec_ext_f32(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K);
+    constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
+    constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);
+
+    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
+    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    Q += nb02* blockIdx.y              + nb01*ic0;
+    K += nb12*(blockIdx.y / gqa_ratio);
+    V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
+    const half * maskh = (const half   *)  mask + ne11*ic0;
+
+    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+
+    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
+    constexpr int nwarps = D / WARP_SIZE;
+    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    __builtin_assume(tid < D);
+
+    __shared__ float KQ[ncols*D];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        KQ[j*D + tid] = -FLT_MAX/2.0f;
+    }
+
+    float kqmax[ncols];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        kqmax[j] = -FLT_MAX/2.0f;
+    }
+    float kqsum[ncols] = {0.0f};
+
+    __shared__ float kqmax_shared[ncols][WARP_SIZE];
+    __shared__ float kqsum_shared[ncols][WARP_SIZE];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (threadIdx.y == 0) {
+            kqmax_shared[j][threadIdx.x] = -FLT_MAX/2.0f;
+            kqsum_shared[j][threadIdx.x] = 0.0f;
+        }
+    }
+    __syncthreads();
+
+    // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
+    float2  Q_f2[ncols][D/(2*WARP_SIZE)];
+    int    Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D >= D/(sizeof(int)*QK8_1)];
+    float2  Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
+    if (Q_q8_1) {
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            if (j0 + nwarps > ncols && j >= ncols) {
+                break;
+            }
+
+            // Reuse KQ as temporary storage for converting Q to q8_1:
+            int    * tmp_q_i32 = (int    *) &KQ[j*D];
+            float2 * tmp_q_ds  = (float2 *) (tmp_q_i32 + D/sizeof(int));
+
+            // Set memory to zero if out of bounds:
+            if (ncols > 2 && ic0 + j >= ne01) {
+#pragma unroll
+                for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+                    const int i = i0 + threadIdx.x;
+
+                    tmp_q_i32[i] = 0;
+                }
+                if (threadIdx.x < D/QK8_1) {
+                    tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f);
+                }
+                continue;
+            }
+
+            const float * Q_f = (const float *) (Q + j*nb01);
+#pragma unroll
+            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+                quantize_q8_1_to_shared<float2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            int    * tmp_q_i32 = (int    *) &KQ[j*D];
+            float2 * tmp_q_ds  = (float2 *) (tmp_q_i32 + D/sizeof(int));
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
+                Q_ds[j][i0/WARP_SIZE]  = tmp_q_ds[i/QI8_1];
+            }
+        }
+
+        __syncthreads();
+    } else {
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                Q_f2[j][i0/WARP_SIZE]    = ncols <= 2 || ic0 + j ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
+                Q_f2[j][i0/WARP_SIZE].x *= scale;
+                Q_f2[j][i0/WARP_SIZE].y *= scale;
+            }
+        }
+    }
+
+    float VKQ[ncols] = {0.0f};
+
+    const int k_start = parallel_blocks == 1 ? 0 : ip*D;
+    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
+        // Calculate KQ tile and keep track of new maximum KQ values:
+
+        float kqmax_new_arr[ncols];
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            kqmax_new_arr[j] = kqmax[j];
+        }
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
+            const int i_KQ = i_KQ_0 + threadIdx.y;
+
+            if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
+                break;
+            }
+
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
+                sum = warp_reduce_sum(sum);
+                sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
+
+                kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
+
+                if (threadIdx.x == 0) {
+                    KQ[j*D + i_KQ] = sum;
+                }
+            }
+        }
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            float kqmax_new_j = kqmax_new_arr[j];
+
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+            if (threadIdx.x == 0) {
+                kqmax_shared[j][threadIdx.y] = kqmax_new_j;
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            float kqmax_new_j = kqmax_shared[j][threadIdx.x];
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
+            kqmax[j] = kqmax_new_j;
+
+            const float val = expf(KQ[j*D + tid] - kqmax[j]);
+            kqsum[j] = kqsum[j]*KQ_max_scale + val;
+            KQ[j*D + tid] = val;
+
+            VKQ[j] *= KQ_max_scale;
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int k = 0; k < D; ++k) {
+            if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k >= ne11) {
+                break;
+            }
+
+            const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                VKQ[j] += V_ki*KQ[j*D + k];
+            }
+        }
+
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        kqsum[j] = warp_reduce_sum(kqsum[j]);
+        if (threadIdx.x == 0) {
+            kqsum_shared[j][threadIdx.y] = kqsum[j];
+        }
+    }
+
+    __syncthreads();
+
+#pragma unroll
+    for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+        if (ncols > 2 && ic0 + j_VKQ >= ne01) {
+            break;
+        }
+
+        kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
+        kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
+
+        float dst_val = VKQ[j_VKQ];
+        if (parallel_blocks == 1) {
+            dst_val /= kqsum[j_VKQ];
+        }
+        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+        dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
+    }
+
+    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
+        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
+    }
+}
+
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
+void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    constexpr int nwarps = D/WARP_SIZE;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V>;
+    constexpr bool need_f16_K = D != 128;
+    constexpr bool need_f16_V = D != 128 && D != 64;
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
+}
+
+template <int D, ggml_type type_K, ggml_type type_V>
+void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * Q   = dst->src[0];
+    ggml_tensor * K   = dst->src[1];
+    ggml_tensor * V   = dst->src[2];
+
+    GGML_ASSERT(K->type == type_K);
+    GGML_ASSERT(V->type == type_V);
+
+    if (Q->ne[1] == 1) {
+        constexpr int cols_per_block  = 1;
+        constexpr int parallel_blocks = 4;
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        return;
+    }
+
+    if (Q->ne[1] == 2) {
+        constexpr int cols_per_block  = 2;
+        constexpr int parallel_blocks = 4;
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        return;
+    }
+
+    if (Q->ne[1] <= 4) {
+        constexpr int cols_per_block  = 4;
+        constexpr int parallel_blocks = 4;
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        return;
+    }
+
+    if (Q->ne[1] <= 8) {
+        constexpr int cols_per_block  = 8;
+        constexpr int parallel_blocks = 4;
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        return;
+    }
+
+    constexpr int cols_per_block  = 8;
+    constexpr int parallel_blocks = 1;
+    ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+}
+
+#define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V)                          \
+    template void ggml_cuda_flash_attn_ext_vec_f32_case                     \
+    <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_0);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_1);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_0);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_1);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q8_0);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16);
+
+extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);

+ 490 - 0
llama/ggml-cuda/fattn-wmma-f16.cuh

@@ -0,0 +1,490 @@
+#include "common.cuh"
+#include "fattn-common.cuh"
+
+#if FP16_MMA_AVAILABLE
+#include <mma.h>
+#endif
+
+// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
+template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(nwarps*WARP_SIZE, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_ext_f16(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+#if FP16_MMA_AVAILABLE
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
+    const int ip  =        blockIdx.x % parallel_blocks;  // Index in group of blocks running for the same column in parallel.
+
+    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
+    static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
+    constexpr int frag_m = ncols == 8 ? 32 : 16;
+    constexpr int frag_n = ncols == 8 ?  8 : 16;
+    static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;
+
+    constexpr int KQ_stride_tc  = nwarps*frag_m; // Number of KQ rows calculated in parallel.
+    constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
+    static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
+
+    // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
+    constexpr int D_padded = D + 8;
+    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
+    constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float * Q_f   = (const float *) (Q + nb02* blockIdx.y              + nb01*ic0);
+    const half  * K_h   = (const half  *) (K + nb12*(blockIdx.y / gqa_ratio));
+    const half  * V_h   = (const half  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const half  * maskh = (const half  *)  mask + (nb31/sizeof(half))* ic0;
+    const half2 * mask2 = (const half2 *)  mask + (nb31/sizeof(half))*(ic0/2);
+
+    const int stride_Q  = nb01 / sizeof(float);
+    const int stride_KV = nb11 / sizeof(half);
+
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
+    const half2 slope2 = make_half2(slopef, slopef);
+
+    frag_b Q_b[D/16][ncols/frag_n];
+
+    // A single buffer for temporarily holding tiles of KQ and VKQ parts:
+    constexpr int mem_KQ = ncols*kqs_padded*kqar;
+    constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
+    __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
+    float * KQ_f = (float *) KQ;
+    half2 * KQ2 = (half2 *) KQ;
+
+    float    KQ_rowsum_f[ncols/nwarps] = {0.0f};
+    float       KQ_max_f[ncols/nwarps];
+    float KQ_max_scale_f[ncols/nwarps] = {0.0f};
+
+#pragma unroll
+    for (int j = 0; j < ncols/nwarps; ++j) {
+        KQ_max_f[j] = -FLT_MAX/2.0f;
+    }
+
+    half2    KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+    half2       KQ_max_h2[ncols/nwarps];
+    half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+
+#pragma unroll
+    for (int j = 0; j < ncols/nwarps; ++j) {
+        KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
+    }
+
+    __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
+    half2 * VKQ2 = (half2 *) VKQ;
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+#pragma unroll
+        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+                break;
+            }
+            VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
+        }
+    }
+
+    // Convert Q to half and apply scale, temporarily store in KQ:
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+#pragma unroll
+        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D && i >= D) {
+                break;
+            }
+            KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
+        }
+    }
+
+    __syncthreads();
+
+    // Load Q into tensor core fragments/registers since it will be used frequently:
+#pragma unroll
+    for (int i0 = 0; i0 < D; i0 += 16) {
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+        }
+    }
+
+    __syncthreads();
+
+    // Iterate over ne11 == previous tokens:
+    for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
+        // Calculate tile of KQ:
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
+            frag_c_KQ KQ_c[ncols/frag_n];
+#pragma unroll
+            for (int j = 0; j < ncols/frag_n; ++j) {
+                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+            }
+#pragma unroll
+            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
+                frag_a_K K_a;
+                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+#pragma unroll
+                for (int j = 0; j < ncols/frag_n; ++j) {
+                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                }
+            }
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+                nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+            }
+        }
+
+        __syncthreads();
+
+        // Calculate softmax for each KQ column using the current max. value.
+        // The divisor is stored in KQ_rowsum and will be applied at the end.
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            if (std::is_same<KQ_acc_t, float>::value) {
+                float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+                }
+
+                float KQ_max_new = KQ_max_f[j0/nwarps];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
+                }
+                KQ_max_new = warp_reduce_max(KQ_max_new);
+
+                const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
+                KQ_max_scale_f[j0/nwarps] = expf(diff);
+                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+                    KQ_max_scale_f[j0/nwarps] = 0.0f;
+                }
+                KQ_max_f[j0/nwarps] = KQ_max_new;
+
+                float KQ_rowsum_add = 0.0f;
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
+                    KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
+                    if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+                        KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
+                    }
+                    KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
+                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
+                }
+                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+
+                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+                KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
+            } else {
+                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+                }
+
+                half2 KQ_max_new = KQ_max_h2[j0/nwarps];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
+                }
+                KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
+                const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
+                KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
+                const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+                *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
+                KQ_max_h2[j0/nwarps] = KQ_max_new;
+
+                half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
+                    KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
+                    const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+                    *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
+                    KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
+                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
+                }
+                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+
+                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+                KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
+            }
+        }
+
+        __syncthreads();
+
+        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+#pragma unroll
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+                nvcuda::wmma::load_matrix_sync(
+                    KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
+                    KQ + j0*(kqar*kqs_padded) + k,
+                    kqar*kqs_padded);
+            }
+        }
+
+        frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
+#pragma unroll
+        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
+#pragma unroll
+            for (int j = 0; j < ncols/frag_n; ++j) {
+                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+            }
+
+#pragma unroll
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+
+                frag_a_V v_a;
+                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+#pragma unroll
+                for (int j = 0; j < ncols/frag_n; ++j) {
+                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                }
+            }
+        }
+
+        __syncthreads();
+
+        const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+                nvcuda::wmma::store_matrix_sync(
+                    KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
+                    VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
+                    D_padded, nvcuda::wmma::mem_col_major);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            half2 VKQ_scale;
+            if (std::is_same<KQ_acc_t, float>::value) {
+                VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
+            } else {
+                VKQ_scale = KQ_max_scale_h2[j0/nwarps];
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+                if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+                    break;
+                }
+
+                half2 VKQ_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+                for (int l = 0; l < VKQ_ratio; ++l) {
+                    VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
+                }
+                VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
+            }
+        }
+
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j_VKQ = j0 + threadIdx.y;
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+
+        float KQ_rowsum_j;
+        if (std::is_same<KQ_acc_t, float>::value) {
+            KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
+        } else {
+            KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D && i >= D) {
+                break;
+            }
+            float dst_val = VKQ[j_VKQ*D_padded + i];
+            if (parallel_blocks == 1) {
+                dst_val /= KQ_rowsum_j;
+            }
+            dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
+        }
+
+        if (parallel_blocks == 1 || threadIdx.x != 0) {
+            continue;
+        }
+
+        float2 dst_meta_val;
+        if (std::is_same<KQ_acc_t, float>::value) {
+            dst_meta_val.x = KQ_max_f[j0/nwarps];
+        } else {
+            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
+        }
+        dst_meta_val.y = KQ_rowsum_j;
+        dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
+    }
+#else
+   NO_DEVICE_CODE;
+#endif // FP16_MMA_AVAILABLE
+}
+
+constexpr int get_max_power_of_2(int x) {
+    return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
+}
+
+static_assert(get_max_power_of_2(1) == 1, "Test failed.");
+static_assert(get_max_power_of_2(2) == 2, "Test failed.");
+static_assert(get_max_power_of_2(4) == 4, "Test failed.");
+static_assert(get_max_power_of_2(6) == 2, "Test failed.");
+
+// Number of VKQ rows calculated in parallel:
+constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
+    return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
+}
+
+static_assert(get_VKQ_stride(128, 1, 32) ==  32, "Test failed.");
+static_assert(get_VKQ_stride(128, 2, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
+static_assert(get_VKQ_stride( 64, 1, 32) ==  32, "Test failed.");
+static_assert(get_VKQ_stride( 64, 2, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride( 64, 4, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride( 80, 1, 16) ==  16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 2, 16) ==  16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");
+
+template <int D, int cols_per_block, typename KQ_acc_t>
+void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+
+    constexpr int nwarps = 4;
+
+    constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
+    const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+
+    if (4*blocks_num_pb1 < 2*nsm) {
+        constexpr int parallel_blocks = 4;
+        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+        return;
+    }
+    if (2*blocks_num_pb1 < 2*nsm) {
+        constexpr int parallel_blocks = 2;
+        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+        return;
+    }
+    constexpr int parallel_blocks = 1;
+    fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+}
+
+#define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t)                         \
+    template void ggml_cuda_flash_attn_ext_wmma_f16_case                              \
+    <D, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
+extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float);
+extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float);
+extern DECL_FATTN_WMMA_F16_CASE( 96, 16, float);
+extern DECL_FATTN_WMMA_F16_CASE(112, 16, float);
+extern DECL_FATTN_WMMA_F16_CASE(128, 16, float);
+extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
+
+extern DECL_FATTN_WMMA_F16_CASE( 64, 32, float);
+extern DECL_FATTN_WMMA_F16_CASE( 80, 32, float);
+extern DECL_FATTN_WMMA_F16_CASE( 96, 32, float);
+extern DECL_FATTN_WMMA_F16_CASE(112, 32, float);
+extern DECL_FATTN_WMMA_F16_CASE(128, 32, float);
+// extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
+
+extern DECL_FATTN_WMMA_F16_CASE( 64,  8, half);
+extern DECL_FATTN_WMMA_F16_CASE( 96,  8, half);
+extern DECL_FATTN_WMMA_F16_CASE(128,  8, half);
+extern DECL_FATTN_WMMA_F16_CASE(256,  8, half);
+
+extern DECL_FATTN_WMMA_F16_CASE( 64, 16, half);
+extern DECL_FATTN_WMMA_F16_CASE( 80, 16, half);
+extern DECL_FATTN_WMMA_F16_CASE( 96, 16, half);
+extern DECL_FATTN_WMMA_F16_CASE(112, 16, half);
+extern DECL_FATTN_WMMA_F16_CASE(128, 16, half);
+extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
+
+extern DECL_FATTN_WMMA_F16_CASE( 64, 32, half);
+extern DECL_FATTN_WMMA_F16_CASE( 80, 32, half);
+extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half);
+extern DECL_FATTN_WMMA_F16_CASE(112, 32, half);
+extern DECL_FATTN_WMMA_F16_CASE(128, 32, half);
+extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);

+ 228 - 521
llama/ggml-cuda/fattn.cu

@@ -4,519 +4,38 @@
 #include "fattn-tile-f32.cuh"
 #include "fattn-vec-f16.cuh"
 #include "fattn-vec-f32.cuh"
+#include "fattn-wmma-f16.cuh"
 #include "fattn.cuh"
 
 #include <cstdint>
 
-#if FP16_MMA_AVAILABLE
-#include <mma.h>
-#endif
-
-// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
-template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(nwarps*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
-static __global__ void flash_attn_ext_f16(
-        const char * __restrict__ Q,
-        const char * __restrict__ K,
-        const char * __restrict__ V,
-        const char * __restrict__ mask,
-        float      * __restrict__ dst,
-        float2     * __restrict__ dst_meta,
-        const float scale,
-        const float max_bias,
-        const float m0,
-        const float m1,
-        const uint32_t n_head_log2,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int nb31,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
-#if FP16_MMA_AVAILABLE
-    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
-
-    const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
-    const int ip  =        blockIdx.x % parallel_blocks;  // Index in group of blocks running for the same column in parallel.
-
-    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
-    static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
-    constexpr int frag_m = ncols == 8 ? 32 : 16;
-    constexpr int frag_n = ncols == 8 ?  8 : 16;
-    static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;
-
-    constexpr int KQ_stride_tc  = nwarps*frag_m; // Number of KQ rows calculated in parallel.
-    constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
-    static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
-
-    // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
-    constexpr int D_padded = D + 8;
-    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
-    constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
-
-    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f   = (const float *) (Q + nb02* blockIdx.y              + nb01*ic0);
-    const half  * K_h   = (const half  *) (K + nb12*(blockIdx.y / gqa_ratio));
-    const half  * V_h   = (const half  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
-    const half  * maskh = (const half  *)  mask + (nb31/sizeof(half))* ic0;
-    const half2 * mask2 = (const half2 *)  mask + (nb31/sizeof(half))*(ic0/2);
-
-    const int stride_Q  = nb01 / sizeof(float);
-    const int stride_KV = nb11 / sizeof(half);
-
-    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
-    const half  slopeh = __float2half(slopef);
-    const half2 slope2 = make_half2(slopef, slopef);
-
-    frag_b Q_b[D/16][ncols/frag_n];
-
-    // A single buffer for temporarily holding tiles of KQ and VKQ parts:
-    constexpr int mem_KQ = ncols*kqs_padded*kqar;
-    constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
-    __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
-    float * KQ_f = (float *) KQ;
-    half2 * KQ2 = (half2 *) KQ;
-
-    float    KQ_rowsum_f[ncols/nwarps] = {0.0f};
-    float       KQ_max_f[ncols/nwarps];
-    float KQ_max_scale_f[ncols/nwarps] = {0.0f};
-
-#pragma unroll
-    for (int j = 0; j < ncols/nwarps; ++j) {
-        KQ_max_f[j] = -FLT_MAX/2.0f;
-    }
-
-    half2    KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
-    half2       KQ_max_h2[ncols/nwarps];
-    half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
-
-#pragma unroll
-    for (int j = 0; j < ncols/nwarps; ++j) {
-        KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
-    }
-
-    __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
-    half2 * VKQ2 = (half2 *) VKQ;
-#pragma unroll
-    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-        const int j = j0 + threadIdx.y;
-#pragma unroll
-        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D/2 && i >= D/2) {
-                break;
-            }
-            VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
-        }
-    }
-
-    // Convert Q to half and apply scale, temporarily store in KQ:
-#pragma unroll
-    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-        const int j = j0 + threadIdx.y;
-#pragma unroll
-        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D && i >= D) {
-                break;
-            }
-            KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
-        }
-    }
-
-    __syncthreads();
-
-    // Load Q into tensor core fragments/registers since it will be used frequently:
-#pragma unroll
-    for (int i0 = 0; i0 < D; i0 += 16) {
-#pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
-        }
-    }
-
-    __syncthreads();
-
-    // Iterate over ne11 == previous tokens:
-    for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
-        // Calculate tile of KQ:
-#pragma unroll
-        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
-            frag_c_KQ KQ_c[ncols/frag_n];
-#pragma unroll
-            for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
-            }
-#pragma unroll
-            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
-                frag_a_K K_a;
-                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
-#pragma unroll
-                for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
-                }
-            }
-#pragma unroll
-            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
-            }
-        }
-
-        __syncthreads();
-
-        // Calculate softmax for each KQ column using the current max. value.
-        // The divisor is stored in KQ_rowsum and will be applied at the end.
-#pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-            const int j = j0 + threadIdx.y;
-
-            if (std::is_same<KQ_acc_t, float>::value) {
-                float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
-                }
-
-                float KQ_max_new = KQ_max_f[j0/nwarps];
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
-                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
-                }
-                KQ_max_new = warp_reduce_max(KQ_max_new);
-
-                const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
-                KQ_max_scale_f[j0/nwarps] = expf(diff);
-                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
-                    KQ_max_scale_f[j0/nwarps] = 0.0f;
-                }
-                KQ_max_f[j0/nwarps] = KQ_max_new;
-
-                float KQ_rowsum_add = 0.0f;
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
-                    KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
-                    if (diff <= SOFTMAX_FTZ_THRESHOLD) {
-                        KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
-                    }
-                    KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
-                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
-                }
-                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
-
-                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
-                KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
-            } else {
-                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
-                }
-
-                half2 KQ_max_new = KQ_max_h2[j0/nwarps];
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
-                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
-                }
-                KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
-                const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
-                KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
-                const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
-                *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
-                KQ_max_h2[j0/nwarps] = KQ_max_new;
-
-                half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
-                    KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
-                    const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
-                    *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
-                    KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
-                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
-                }
-                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
-
-                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
-                KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
-            }
-        }
-
-        __syncthreads();
-
-        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
-#pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-#pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
-                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-                nvcuda::wmma::load_matrix_sync(
-                    KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
-                    KQ + j0*(kqar*kqs_padded) + k,
-                    kqar*kqs_padded);
-            }
-        }
-
-        frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
-#pragma unroll
-        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
-#pragma unroll
-            for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
-            }
-
-#pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
-                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-
-                frag_a_V v_a;
-                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
-#pragma unroll
-                for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
-                }
-            }
-        }
-
-        __syncthreads();
-
-        const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
-#pragma unroll
-        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
-#pragma unroll
-            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync(
-                    KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
-                    VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                    D_padded, nvcuda::wmma::mem_col_major);
-            }
-        }
-
-        __syncthreads();
-
-#pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-            const int j = j0 + threadIdx.y;
-
-            half2 VKQ_scale;
-            if (std::is_same<KQ_acc_t, float>::value) {
-                VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
-            } else {
-                VKQ_scale = KQ_max_scale_h2[j0/nwarps];
-            }
-
-#pragma unroll
-            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-                const int i = i0 + threadIdx.x;
-                if (i0 + WARP_SIZE > D/2 && i >= D/2) {
-                    break;
-                }
-
-                half2 VKQ_add = make_half2(0.0f, 0.0f);
-#pragma unroll
-                for (int l = 0; l < VKQ_ratio; ++l) {
-                    VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
-                }
-                VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
-            }
-        }
-
-        __syncthreads();
-    }
-
-#pragma unroll
-    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-        const int j_VKQ = j0 + threadIdx.y;
-        if (ic0 + j_VKQ >= ne01) {
-            return;
-        }
-        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-
-        float KQ_rowsum_j;
-        if (std::is_same<KQ_acc_t, float>::value) {
-            KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
-        } else {
-            KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
-        }
-
-#pragma unroll
-        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D && i >= D) {
-                break;
-            }
-            float dst_val = VKQ[j_VKQ*D_padded + i];
-            if (parallel_blocks == 1) {
-                dst_val /= KQ_rowsum_j;
-            }
-            dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
-        }
-
-        if (parallel_blocks == 1 || threadIdx.x != 0) {
-            continue;
-        }
-
-        float2 dst_meta_val;
-        if (std::is_same<KQ_acc_t, float>::value) {
-            dst_meta_val.x = KQ_max_f[j0/nwarps];
-        } else {
-            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
-        }
-        dst_meta_val.y = KQ_rowsum_j;
-        dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
-    }
-#else
-   NO_DEVICE_CODE;
-#endif // FP16_MMA_AVAILABLE
-}
-
-constexpr int get_max_power_of_2(int x) {
-    return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
-}
-
-static_assert(get_max_power_of_2(1) == 1, "Test failed.");
-static_assert(get_max_power_of_2(2) == 2, "Test failed.");
-static_assert(get_max_power_of_2(4) == 4, "Test failed.");
-static_assert(get_max_power_of_2(6) == 2, "Test failed.");
-
-// Number of VKQ rows calculated in parallel:
-constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
-    return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
-}
-
-static_assert(get_VKQ_stride(128, 1, 32) ==  32, "Test failed.");
-static_assert(get_VKQ_stride(128, 2, 32) ==  64, "Test failed.");
-static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
-static_assert(get_VKQ_stride( 64, 1, 32) ==  32, "Test failed.");
-static_assert(get_VKQ_stride( 64, 2, 32) ==  64, "Test failed.");
-static_assert(get_VKQ_stride( 64, 4, 32) ==  64, "Test failed.");
-static_assert(get_VKQ_stride( 80, 1, 16) ==  16, "Test failed.");
-static_assert(get_VKQ_stride( 80, 2, 16) ==  16, "Test failed.");
-static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");
-
-template <int D, int cols_per_block, int nwarps, typename KQ_acc_t>
-void launch_fattn_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-
-    constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
-    const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
-    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
-
-    if (4*blocks_num_pb1 < 2*nsm) {
-        constexpr int parallel_blocks = 4;
-        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
-        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
-        return;
-    }
-    if (2*blocks_num_pb1 < 2*nsm) {
-        constexpr int parallel_blocks = 2;
-        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
-        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
-        return;
-    }
-    constexpr int parallel_blocks = 1;
-    fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
-    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
-}
-
-void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
 
-    ggml_cuda_set_device(ctx.device);
-    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int32_t precision = KQV->op_params[2];
 
-    // On AMD the tile kernels perform poorly, use the vec kernel instead:
-    if (cc >= CC_OFFSET_AMD) {
-        if (precision == GGML_PREC_DEFAULT) {
-            ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-        }
-        return;
-    }
-
-    if (!fast_fp16_available(cc)) {
-        if (Q->ne[1] <= 8) {
-            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
-        }
-        return;
-    }
-
-    if (!fp16_mma_available(cc)) {
-        if (Q->ne[1] <= 8) {
-            ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
-        }
-        return;
-    }
-
     if (precision != GGML_PREC_DEFAULT) {
-        if (Q->ne[1] == 1 && (Q->ne[0] == 64 || Q->ne[0] == 128)) {
-            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-            return;
-        }
-
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
             constexpr int cols_per_block = 16;
-            constexpr int nwarps         =  4;
             switch (Q->ne[0]) {
                 case 64:
-                    launch_fattn_f16< 64, cols_per_block, nwarps, float>(ctx, dst);
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
                     break;
                 case 80:
-                    launch_fattn_f16< 80, cols_per_block, nwarps, float>(ctx, dst);
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
                     break;
                 case 96:
-                    launch_fattn_f16< 96, cols_per_block, nwarps, float>(ctx, dst);
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
                     break;
                 case 112:
-                    launch_fattn_f16<112, cols_per_block, nwarps, float>(ctx, dst);
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
                     break;
                 case 128:
-                    launch_fattn_f16<128, cols_per_block, nwarps, float>(ctx, dst);
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
                     break;
                 case 256:
-                    launch_fattn_f16<256, cols_per_block, nwarps, float>(ctx, dst);
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
                     break;
                 default:
                     GGML_ASSERT(false);
@@ -524,25 +43,24 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             }
         } else {
             constexpr int cols_per_block = 32;
-            constexpr int nwarps         =  4;
             switch (Q->ne[0]) {
                 case 64:
-                    launch_fattn_f16< 64, cols_per_block, nwarps, float>(ctx, dst);
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
                     break;
                 case 80:
-                    launch_fattn_f16< 80, cols_per_block, nwarps, float>(ctx, dst);
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
                     break;
                 case 96:
-                    launch_fattn_f16< 96, cols_per_block, nwarps, float>(ctx, dst);
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
                     break;
                 case 112:
-                    launch_fattn_f16<112, cols_per_block, nwarps, float>(ctx, dst);
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
                     break;
                 case 128:
-                    launch_fattn_f16<128, cols_per_block, nwarps, float>(ctx, dst);
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
                     break;
                 // case 256:
-                //     launch_fattn_f16<256, cols_per_block, nwarps, float>(ctx, dst);
+                //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
                 //     break;
                 default:
                     GGML_ASSERT(false);
@@ -552,26 +70,20 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }
 
-    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
-        ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-        return;
-    }
-
     if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
         constexpr int cols_per_block = 8;
-        constexpr int nwarps         = 4;
         switch (Q->ne[0]) {
             case 64:
-                launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
                 break;
             case 96:
-                launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
                 break;
             case 128:
-                launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
                 break;
             case 256:
-                launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
                 break;
             default:
                 GGML_ASSERT(false);
@@ -582,25 +94,24 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 16;
-        constexpr int nwarps         =  4;
         switch (Q->ne[0]) {
             case 64:
-                launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
                 break;
             case 80:
-                launch_fattn_f16< 80, cols_per_block, nwarps, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
                 break;
             case 96:
-                launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
                 break;
             case 112:
-                launch_fattn_f16<112, cols_per_block, nwarps, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
                 break;
             case 128:
-                launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
                 break;
             case 256:
-                launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
                 break;
             default:
                 GGML_ASSERT(false);
@@ -610,29 +121,225 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     constexpr int cols_per_block = 32;
-    constexpr int nwarps         =  4;
     switch (Q->ne[0]) {
         case 64:
-            launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
             break;
         case 80:
-            launch_fattn_f16< 80, cols_per_block, nwarps, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
             break;
         case 96:
-            launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
             break;
         case 112:
-            launch_fattn_f16<112, cols_per_block, nwarps, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
             break;
         case 128:
-            launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
             break;
         case 256:
-            launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
             break;
         default:
             GGML_ASSERT(false);
             break;
     }
-    return;
+}
+#define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \
+    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
+        ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
+        return;                                                             \
+    }                                                                       \
+
+static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * Q = dst->src[1];
+    ggml_tensor * K = dst->src[1];
+    ggml_tensor * V = dst->src[2];
+
+#ifdef GGML_CUDA_FA_ALL_QUANTS
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16 )
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_0)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_1)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_0)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_1)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q8_0)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16)
+
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+#else
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+#endif // GGML_CUDA_FA_ALL_QUANTS
+
+    on_no_fattn_vec_case(Q->ne[0]);
+}
+
+#define FATTN_VEC_F32_CASE(D, type_K, type_V)                               \
+    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
+        ggml_cuda_flash_attn_ext_vec_f32_case<D, type_K, type_V>(ctx, dst); \
+        return;                                                             \
+    }                                                                       \
+
+static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * Q = dst->src[1];
+    ggml_tensor * K = dst->src[1];
+    ggml_tensor * V = dst->src[2];
+
+#ifdef GGML_CUDA_FA_ALL_QUANTS
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_0)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_1)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_0)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_1)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q8_0)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16)
+
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+#else
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+#endif // GGML_CUDA_FA_ALL_QUANTS
+
+    on_no_fattn_vec_case(Q->ne[0]);
+}
+
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+
+    ggml_cuda_set_device(ctx.device);
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const int32_t precision = KQV->op_params[2];
+
+    // On AMD the tile kernels perform poorly, use the vec kernel instead:
+    if (cc >= CC_OFFSET_AMD) {
+        if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+        }
+        return;
+    }
+
+    if (!fast_fp16_available(cc)) {
+        if (Q->ne[1] <= 8) {
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
+        }
+        return;
+    }
+
+    if (!fp16_mma_available(cc)) {
+        if (Q->ne[1] <= 8) {
+            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+        }
+        return;
+    }
+
+    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
+        if (precision == GGML_PREC_DEFAULT) {
+            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+            return;
+        } else if(Q->ne[0] <= 128) {
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+            return;
+        }
+    }
+
+    ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
 }

File diff suppressed because it is too large
+ 0 - 2174
llama/ggml-cuda/mmq.cu


+ 1300 - 0
llama/ggml-cuda/mmq.cuh

@@ -1,4 +1,1304 @@
 #include "common.cuh"
+#include "vecdotq.cuh"
+
+#include <climits>
+#include <cstdint>
+
+typedef void (*load_tiles_mmq_t)(
+    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);
+typedef void (*vec_dot_mmq_t)(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, float * __restrict__ sum, const int & k0);
+
+struct tile_x_sizes {
+    int ql;
+    int dm;
+    int qh;
+    int sc;
+};
+
+// get_mmq_x_max_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row
+
+static constexpr __device__ int get_mmq_x_max_device() {
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    return 64;
+#else
+#if __CUDA_ARCH__ >= CC_VOLTA
+#ifdef CUDA_USE_TENSOR_CORES
+    return MMQ_MAX_BATCH_SIZE;
+#else
+    return 128;
+#endif // CUDA_USE_TENSOR_CORES
+#else
+    return 64;
+#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+}
+
+// get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row
+
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+static constexpr __device__ int get_mmq_y_device(int mmq_x) {
+    return mmq_x >= 32 ? 128 : 64;
+}
+#else
+#if __CUDA_ARCH__ >= CC_VOLTA
+static constexpr __device__ int get_mmq_y_device(int mmq_x) {
+    return mmq_x >= 32 ? 128 : 64;
+}
+#else
+static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) {
+    return 64;
+}
+#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
+#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0,                           0}
+#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0,                           0}
+#define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0,                           0}
+#define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0,                           0}
+#define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0,                           0}
+#define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI2_K + mmq_y/QI2_K, 0,                           mmq_y*WARP_SIZE/4 + mmq_y/4}
+#define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/2 + mmq_y/2, mmq_y*WARP_SIZE/4 + mmq_y/4}
+#define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, 0,                           mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, 0,                           mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, 0,                           mmq_y*WARP_SIZE/8 + mmq_y/8}
+
+#define GET_TILE_X_SIZES_BODY                           \
+    return type == GGML_TYPE_Q4_0 ? TILE_X_SIZES_Q4_0 : \
+        type == GGML_TYPE_Q4_1 ? TILE_X_SIZES_Q4_1 :    \
+        type == GGML_TYPE_Q5_0 ? TILE_X_SIZES_Q5_0 :    \
+        type == GGML_TYPE_Q5_1 ? TILE_X_SIZES_Q5_1 :    \
+        type == GGML_TYPE_Q8_0 ? TILE_X_SIZES_Q8_0 :    \
+        type == GGML_TYPE_Q2_K ? TILE_X_SIZES_Q2_K :    \
+        type == GGML_TYPE_Q3_K ? TILE_X_SIZES_Q3_K :    \
+        type == GGML_TYPE_Q4_K ? TILE_X_SIZES_Q4_K :    \
+        type == GGML_TYPE_Q5_K ? TILE_X_SIZES_Q5_K :    \
+        type == GGML_TYPE_Q6_K ? TILE_X_SIZES_Q6_K :    \
+        tile_x_sizes{0, 0, 0, 0}
+
+static tile_x_sizes get_tile_x_sizes_host(const ggml_type type, const int mmq_y) {
+    GET_TILE_X_SIZES_BODY;
+}
+
+template <int mmq_y>
+static constexpr __device__ tile_x_sizes get_tile_x_sizes_device(ggml_type type) {
+    GET_TILE_X_SIZES_BODY;
+}
+
+// ------------------------------------------------------------
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
+    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    const int kbx  = threadIdx.x / QI4_0;
+    const int kqsx = threadIdx.x % QI4_0;
+
+    float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
+        int i = i0 + threadIdx.y * QI4_0 + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
+
+        x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
+            const float * x_dmf = (const float *) x_dm;
+
+            int u[2*VDR_Q4_0_Q8_1_MMQ];
+
+#pragma unroll
+            for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
+                u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
+                u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE];
+            }
+
+            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
+                (&x_ql[i * (WARP_SIZE + 1) + k0], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
+                y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
+    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    const int kbx  = threadIdx.x / QI4_1;
+    const int kqsx = threadIdx.x % QI4_1;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
+        int i = i0 + threadIdx.y * QI4_1 + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
+
+        x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
+
+            int u[2*VDR_Q4_1_Q8_1_MMQ];
+
+#pragma unroll
+            for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
+                u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
+                u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE];
+            }
+
+            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
+                (&x_ql[i * (WARP_SIZE + 1) + k0], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
+                y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
+    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    const int kbx  = threadIdx.x / QI5_0;
+    const int kqsx = threadIdx.x % QI5_0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
+
+        const int ql = get_int_from_uint8(bxi->qs, kqsx);
+        const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
+
+        int qs0 = (ql >>  0)   & 0x0F0F0F0F;
+        qs0    |= (qh <<  4)   & 0x00000010;  // 0 ->  4
+        qs0    |= (qh << 11)   & 0x00001000;  // 1 -> 12
+        qs0    |= (qh << 18)   & 0x00100000;  // 2 -> 20
+        qs0    |= (qh << 25)   & 0x10000000;  // 3 -> 28
+        qs0     = __vsubss4(qs0, 0x10101010); // subtract 16
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
+
+        int qs1 = (ql >>  4)   & 0x0F0F0F0F;
+        qs1    |= (qh >> 12)   & 0x00000010;  // 16 ->  4
+        qs1    |= (qh >>  5)   & 0x00001000;  // 17 -> 12
+        qs1    |= (qh <<  2)   & 0x00100000;  // 18 -> 20
+        qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28
+        qs1     = __vsubss4(qs1, 0x10101010); // subtract 16
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+    float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
+        int i = i0 + threadIdx.y * QI5_0 + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
+
+        x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
+            const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0;
+            const float * x_dmf = (const float *) x_dm;
+            const float * y_df  = (const float *) y_ds;
+
+            int u[2*VDR_Q5_0_Q8_1_MMQ];
+
+#pragma unroll
+            for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
+                u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
+                u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
+            }
+
+            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
+                (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k0], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+        }
+    }
+}
+
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
+    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    const int kbx  = threadIdx.x / QI5_1;
+    const int kqsx = threadIdx.x % QI5_1;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
+
+        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
+
+        int qs0 = (ql >>  0) & 0x0F0F0F0F;
+        qs0    |= (qh <<  4) & 0x00000010; // 0 ->  4
+        qs0    |= (qh << 11) & 0x00001000; // 1 -> 12
+        qs0    |= (qh << 18) & 0x00100000; // 2 -> 20
+        qs0    |= (qh << 25) & 0x10000000; // 3 -> 28
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
+
+        int qs1 = (ql >>  4) & 0x0F0F0F0F;
+        qs1    |= (qh >> 12) & 0x00000010; // 16 ->  4
+        qs1    |= (qh >>  5) & 0x00001000; // 17 -> 12
+        qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
+        qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
+        int i = i0 + threadIdx.y * QI5_1 + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
+
+        x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
+            const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k0/QI5_1;
+
+            int u[2*VDR_Q5_1_Q8_1_MMQ];
+
+#pragma unroll
+            for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
+                u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
+                u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];
+            }
+
+            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
+                (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k0], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
+    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+    const int kbx  = threadIdx.x / QI8_0;
+    const int kqsx = threadIdx.x % QI8_0;
+    float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
+        int i = i0 + threadIdx.y * QI8_0 + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
+
+        x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            const float * x_dmf = (const float *) x_dm;
+            const float * y_df  = (const float *) y_ds;
+
+            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
+                (&x_ql[i * (WARP_SIZE + 1) + k0], &y_qs[j * WARP_SIZE + k0], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
+                y_df[j * (WARP_SIZE/QI8_1) + k0/QI8_1]);
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
+    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    GGML_UNUSED(x_qh);
+
+    const int kbx  = threadIdx.x / QI2_K;
+    const int kqsx = threadIdx.x % QI2_K;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
+        int i = (i0 + threadIdx.y * QI2_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbxd;
+
+        x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI2_K/4);
+
+        x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, threadIdx.x % (QI2_K/4));
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh);
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            const int kbx = k0 / QI2_K;
+            const int ky  = (k0 % QI2_K) * QR2_K;
+            const float * y_df = (const float *) y_ds;
+
+            int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
+
+            const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
+            const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
+
+#pragma unroll
+            for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
+                v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
+            }
+
+            const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
+
+            const int index_y = j * WARP_SIZE + (QR2_K*k0) % WARP_SIZE;
+            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
+                v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
+    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+
+    const int kbx  = threadIdx.x / QI3_K;
+    const int kqsx = threadIdx.x % QI3_K;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+    float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
+        int i = (i0 + threadIdx.y * QI3_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbxd;
+
+        x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
+        int i = i0 + threadIdx.y * 2 + threadIdx.x / (WARP_SIZE/2);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/2)) / (QI3_K/2);
+
+        // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
+        x_qh[i * (WARP_SIZE/2) + i / 2 + threadIdx.x % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, threadIdx.x % (QI3_K/2));
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI3_K/4);
+
+        const int ksc = threadIdx.x % (QI3_K/4);
+
+        const int ksc_low = ksc % (QI3_K/8);
+        const int shift_low = 4 * (ksc / (QI3_K/8));
+        const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
+
+        const int ksc_high = QI3_K/8;
+        const int shift_high = 2 * ksc;
+        const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
+
+        const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
+
+        x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = sc;
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            const int kbx  = k0 / QI3_K;
+            const int ky  = (k0 % QI3_K) * QR3_K;
+            const float * x_dmf = (const float *) x_dm;
+            const float * y_df  = (const float *) y_ds;
+
+            const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
+
+            int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
+
+#pragma unroll
+            for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
+                const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
+                const int shift = 2 * ((ky % 32) / 8);
+                const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
+
+                const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
+                const int vlh = (vh << 2) & 0x04040404;
+
+                v[l] = __vsubss4(vll, vlh);
+            }
+
+            const int index_y = j * WARP_SIZE + (k0*QR3_K) % WARP_SIZE;
+            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
+                v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
+    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    GGML_UNUSED(x_qh);
+
+    const int kbx  = 0;           // threadIdx.x / QI4_K
+    const int kqsx = threadIdx.x; // threadIdx.x % QI4_K
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_K;  // == 1 if QK_K == 256
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
+        int i = (i0 + threadIdx.y * QI4_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbxd;
+
+        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8);
+
+        const int * scales = (const int *) bxi->scales;
+
+        const int ksc = threadIdx.x % (WARP_SIZE/8);
+
+        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
+        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
+        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
+
+        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh);
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8);
+
+            const int index_y = j * WARP_SIZE + (QR4_K*k0) % WARP_SIZE;
+            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
+                &x_ql[i * (WARP_SIZE + 1) + k0], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
+    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    GGML_UNUSED(x_qh);
+
+    const int kbx  = 0;           // threadIdx.x / QI5_K
+    const int kqsx = threadIdx.x; // threadIdx.x % QI5_K
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbx;
+        const int ky = QR5_K*kqsx;
+
+        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
+        const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
+        const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;
+
+        const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
+        const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4);
+
+        x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
+        x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI5_K;  // == 1 if QK_K == 256
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
+        int i = (i0 + threadIdx.y * QI5_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbxd;
+
+        x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI5_K/8);
+
+        const int * scales = (const int *) bxi->scales;
+
+        const int ksc = threadIdx.x % (WARP_SIZE/8);
+
+        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
+        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
+        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
+
+        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh);
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
+
+            const int index_x = i * (QR5_K*WARP_SIZE + 1) +  QR5_K*k0;
+            const int index_y = j * WARP_SIZE             + (QR5_K*k0) % WARP_SIZE;
+            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
+                &x_ql[index_x], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
+    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
+    GGML_UNUSED(x_qh);
+
+    const int kbx  = 0;           // threadIdx.x / QI6_K
+    const int kqsx = threadIdx.x; // threadIdx.x % QI6_K
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbx;
+        const int ky = QR6_K*kqsx;
+
+        const int ql = get_int_from_uint8(bxi->ql, kqsx);
+        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
+        const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
+        const int qh1 =  (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4))))       & 0x30303030;
+
+        const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0;
+        const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2);
+
+        x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+        x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI6_K;  // == 1 if QK_K == 256
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
+    float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
+        int i = (i0 + threadIdx.y * QI6_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
+
+        x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
+
+        x_sc[i * (WARP_SIZE/8) + i / 8 + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
+
+    GGML_UNUSED(x_qh);
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            const float * x_dmf = (const float *) x_dm;
+            const float * y_df  = (const float *) y_ds;
+
+            const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
+
+            const int index_x = i * (QR6_K*WARP_SIZE + 1) +  QR6_K*k0;
+            const int index_y = j * WARP_SIZE             + (QR6_K*k0) % WARP_SIZE;
+            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
+                &x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
+        }
+    }
+}
+
+// -------------------------------------------------------------------------------------------------------------------------------------
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
+struct mmq_type_traits;
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
+    static constexpr bool             need_sum   = true;
+    static constexpr int              vdr        = VDR_Q4_0_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
+    static constexpr bool             need_sum   = true;
+    static constexpr int              vdr        = VDR_Q4_1_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
+    static constexpr bool             need_sum   = false;
+    static constexpr int              vdr        = VDR_Q5_0_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
+    static constexpr bool             need_sum   = true;
+    static constexpr int              vdr        = VDR_Q5_1_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
+    static constexpr bool             need_sum   = false;
+    static constexpr int              vdr        = VDR_Q8_0_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q8_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
+    static constexpr bool             need_sum   = false;
+    static constexpr int              vdr        = VDR_Q2_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
+    static constexpr bool             need_sum   = false;
+    static constexpr int              vdr        = VDR_Q3_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
+    static constexpr bool             need_sum   = true;
+    static constexpr int              vdr        = VDR_Q4_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
+    static constexpr bool             need_sum   = true;
+    static constexpr int              vdr        = VDR_Q5_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
+    static constexpr bool             need_sum   = false;
+    static constexpr int              vdr        = VDR_Q6_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
+};
+
+template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*nwarps, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#else
+#if __CUDA_ARCH__ >= CC_VOLTA
+    __launch_bounds__(WARP_SIZE*nwarps, 1)
+#else
+    __launch_bounds__(WARP_SIZE*nwarps, type == GGML_TYPE_Q2_K ? 1 : 2)
+#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+static __global__ void mul_mat_q(
+    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst,
+    const int ne00, const int ne01, const int stride00, const int ne10, const int ne11, const int ne0) {
+
+    // Skip unused template specializations for faster compilation:
+    if (mmq_x > get_mmq_x_max_device()) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    constexpr int              qk         = ggml_cuda_type_traits<type>::qk;
+    constexpr int              qr         = ggml_cuda_type_traits<type>::qr;
+    constexpr int              qi         = ggml_cuda_type_traits<type>::qi;
+    constexpr int              mmq_y      = get_mmq_y_device(mmq_x);
+    constexpr bool             need_sum   = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::need_sum;
+    constexpr int              vdr        = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
+    constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
+    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;
+
+    constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);
+
+    extern __shared__ char data_mul_mat_q[];
+    int   * tile_x_ql = (int   *)  data_mul_mat_q;
+    half2 * tile_x_dm = (half2 *) (tile_x_ql + txs.ql);
+    int   * tile_x_qh = (int   *) (tile_x_dm + txs.dm);
+    int   * tile_x_sc = (int   *) (tile_x_qh + txs.qh);
+    int   * tile_y_qs = (int   *) (tile_x_sc + txs.sc);          // [mmq_x * WARP_SIZE]
+    half2 * tile_y_ds = (half2 *) (tile_y_qs + mmq_x*WARP_SIZE); // [mmq_x * WARP_SIZE/QI8_1];
+
+    const block_q8_1 * y = (const block_q8_1 *) yc;
+
+    const int blocks_per_row_x = ne00 / qk;
+    const int blocks_per_col_y = ne10 / QK8_1;
+    const int blocks_per_warp = WARP_SIZE / qi;
+
+    const int & ne1 = ne11;
+
+    const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1;
+
+    float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f};
+
+    for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
+
+        load_tiles(x, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, stride00*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride00);
+
+#pragma unroll
+        for (int kr = 0; kr < qr; ++kr) {
+            const int kqs = kr*WARP_SIZE + threadIdx.x;
+            const int kbxd = kqs / QI8_1;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_x; i0 += nwarps) {
+                const int i = min(blockIdx.y*mmq_x + threadIdx.y + i0, ne11-1); // to prevent out-of-bounds memory accesses
+
+                const block_q8_1 * by0 = &y[i*blocks_per_col_y + kb0 * (qk/QK8_1) + kbxd];
+
+                const int index_y = (i0 + threadIdx.y) * WARP_SIZE + kqs % WARP_SIZE;
+                tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
+            }
+
+#pragma unroll
+            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
+                const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x;
+                const int kby = threadIdx.x % (WARP_SIZE/QI8_1);
+                const int i_y_eff = min(blockIdx.y*mmq_x + ids, ne11-1);
+
+                // if the sum is not needed it's faster to transform the scale to f32 ahead of time
+                const half2 * dsi_src = &y[i_y_eff*blocks_per_col_y + kb0 * (qk/QK8_1) + kr*(WARP_SIZE/QI8_1) + kby].ds;
+                half2       * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby];
+                if (need_sum) {
+                    *dsi_dst = *dsi_src;
+                } else {
+                    float * dfi_dst = (float *) dsi_dst;
+                    *dfi_dst = __low2float(*dsi_src);
+                }
+            }
+
+            __syncthreads();
+
+// #pragma unroll // unrolling this loop causes too much register pressure
+            for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
+                vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, sum, k0);
+            }
+
+            __syncthreads();
+        }
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
+
+        if (j >= ne1) {
+            return;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
+
+            if (need_check && i >= ne0) {
+                continue;
+            }
+
+            dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+        }
+    }
+}
+
+struct mmq_args {
+    const char * x; const char * y; float * dst;
+    int64_t ne00; int64_t ne01; int64_t stride00;
+    int64_t ne10; int64_t ne11;
+    int64_t ne0;
+};
+
+template <ggml_type type, int mmq_x, int nwarps>
+static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
+    const int id = ggml_cuda_get_device();
+    const int cc = ggml_cuda_info().devices[id].cc;
+    const int mmq_y = get_mmq_y_host(cc, mmq_x);
+
+    const int block_num_x = (args.ne01 + mmq_y - 1) / mmq_y;
+    const int block_num_y = (args.ne11 + mmq_x - 1) / mmq_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
+    const int shmem_x = txs.ql*sizeof(int) + txs.dm*sizeof(half2) + txs.qh*sizeof(int) + txs.sc*sizeof(int);
+    const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
+    const int shmem = shmem_x + shmem_y;
+
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+    static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+    if (!shmem_limit_raised[id]) {
+        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, true>,  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+        shmem_limit_raised[id] = true;
+    }
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+
+    if (args.ne01 % mmq_y == 0) {
+        const bool need_check = false;
+        mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
+            (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride00, args.ne10, args.ne11, args.ne0);
+    } else {
+        const bool need_check = true;
+        mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
+            (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride00, args.ne10, args.ne11, args.ne0);
+    }
+}
+
+template <ggml_type type>
+void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
+    const int id = ggml_cuda_get_device();
+    const int nsm = ggml_cuda_info().devices[id].nsm;
+    const int cc  = ggml_cuda_info().devices[id].cc;
+
+    const int mmq_x_max = get_mmq_x_max_host(cc);
+    const int mmq_y = get_mmq_y_host(cc, mmq_x_max);
+    const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
+
+    int mmq_x_best  = 0;
+    int nwaves_best = INT_MAX;
+
+    for (int mmq_x = 8; mmq_x <= mmq_x_max && nwaves_best > 1; mmq_x += 8) {
+        const int block_num_x = (args.ne11 + mmq_x - 1) / mmq_x;
+        const int nwaves = (block_num_x*block_num_y + nsm - 1) / nsm;
+
+        if (nwaves < nwaves_best) {
+            mmq_x_best  = mmq_x;
+            nwaves_best = nwaves;
+        }
+    }
+
+    switch (mmq_x_best) {
+        case   8:
+            launch_mul_mat_q<type,   8, 4>(args, stream);
+            break;
+        case  16:
+            launch_mul_mat_q<type,  16, 8>(args, stream);
+            break;
+        case  24:
+            launch_mul_mat_q<type,  24, 8>(args, stream);
+            break;
+        case  32:
+            launch_mul_mat_q<type,  32, 8>(args, stream);
+            break;
+        case  40:
+            launch_mul_mat_q<type,  40, 8>(args, stream);
+            break;
+        case  48:
+            launch_mul_mat_q<type,  48, 8>(args, stream);
+            break;
+        case  56:
+            launch_mul_mat_q<type,  56, 8>(args, stream);
+            break;
+        case  64:
+            launch_mul_mat_q<type,  64, 8>(args, stream);
+            break;
+        case  72:
+            launch_mul_mat_q<type,  72, 8>(args, stream);
+            break;
+        case  80:
+            launch_mul_mat_q<type,  80, 8>(args, stream);
+            break;
+        case  88:
+            launch_mul_mat_q<type,  88, 8>(args, stream);
+            break;
+        case  96:
+            launch_mul_mat_q<type,  96, 8>(args, stream);
+            break;
+        case 104:
+            launch_mul_mat_q<type, 104, 8>(args, stream);
+            break;
+        case 112:
+            launch_mul_mat_q<type, 112, 8>(args, stream);
+            break;
+        case 120:
+            launch_mul_mat_q<type, 120, 8>(args, stream);
+            break;
+        case 128:
+            launch_mul_mat_q<type, 128, 8>(args, stream);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+}
+
+#define DECL_MMQ_CASE(type)                                                        \
+    template void mul_mat_q_case<type>(const mmq_args & args, cudaStream_t stream) \
+
+extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
+extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
+extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
+extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
+extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
+extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
+
+// -------------------------------------------------------------------------------------------------------------------------
 
 void ggml_cuda_op_mul_mat_q(
     ggml_backend_cuda_context & ctx,

+ 76 - 61
llama/ggml-cuda/mmvq.cu

@@ -1,9 +1,47 @@
 #include "mmvq.cuh"
 #include "vecdotq.cuh"
 
-typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
+typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
+
+static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
+    return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1 :
+        type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 :
+        type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 :
+        type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 :
+        type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 :
+        type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 :
+        type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 :
+        type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 :
+        type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 :
+        type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 :
+        type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
+        type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 :
+        type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 :
+        type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 :
+        type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 :
+        type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 :
+        type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 :
+        type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 :
+        type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 :
+        nullptr;
+}
+
+static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
+    return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ :
+        type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ :
+        type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ :
+        type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ :
+        type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ :
+        type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ :
+        type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ :
+        type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ :
+        type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ :
+        type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ4_NL ? VDR_Q4_K_Q8_1_MMVQ :
+        1;
+}
 
-template <int ncols_y, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
+template <ggml_type type, int ncols_y>
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 // tell the compiler to use as many registers as it wants, see nwarps definition below
 __launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
@@ -12,6 +50,12 @@ static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 
+    constexpr int qk  = ggml_cuda_type_traits<type>::qk;
+    constexpr int qi  = ggml_cuda_type_traits<type>::qi;
+    constexpr int vdr = get_vdr_mmvq(type);
+
+    constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
+
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
     constexpr int nwarps              = 1;
     constexpr int rows_per_cuda_block = 1;
@@ -29,7 +73,6 @@ static __global__ void mul_mat_vec_q(
 // partial sum for each thread
     float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
 
-    const block_q_t  * x = (const block_q_t  *) vx;
     const block_q8_1 * y = (const block_q8_1 *) vy;
 
     for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
@@ -42,8 +85,7 @@ static __global__ void mul_mat_vec_q(
         for (int j = 0; j < ncols_y; ++j) {
 #pragma unroll
             for (int i = 0; i < rows_per_cuda_block; ++i) {
-                tmp[j][i] += vec_dot_q_cuda(
-                    &x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs);
+                tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs);
             }
         }
     }
@@ -81,12 +123,12 @@ static __global__ void mul_mat_vec_q(
     }
 }
 
-template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot>
+template <ggml_type type>
 static void mul_mat_vec_q_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    GGML_ASSERT(ncols_x % qk == 0);
+    GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
     GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
 
     int id = ggml_cuda_get_device();
@@ -124,36 +166,28 @@ static void mul_mat_vec_q_cuda(
 
     switch (ncols_y) {
         case 1:
-            mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
         case 2:
-            mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
         case 3:
-            mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
         case 4:
-            mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
         case 5:
-            mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
         case 6:
-            mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
         case 7:
-            mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
         case 8:
-            mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
         default:
             GGML_ASSERT(false);
@@ -165,152 +199,133 @@ static void mul_mat_vec_q4_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q4_1_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q5_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q5_1_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q8_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q2_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q3_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q4_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q5_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q6_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq2_xxs_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq2_xs_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq2_s_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq3_xxs_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq1_s_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq1_m_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq4_nl_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq4_xs_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq3_s_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 void ggml_cuda_op_mul_mat_vec_q(

+ 6 - 0
llama/ggml-cuda/norm.cu

@@ -170,6 +170,8 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -188,6 +190,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -202,6 +206,8 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 

+ 120 - 157
llama/ggml-cuda/rope.cu

@@ -1,7 +1,7 @@
 #include "rope.cuh"
 
 struct rope_corr_dims {
-    float v[4];
+    float v[2];
 };
 
 static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
@@ -13,8 +13,7 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 static __device__ void rope_yarn(
     float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
-    float * cos_theta, float * sin_theta
-) {
+    float * cos_theta, float * sin_theta) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;
@@ -29,27 +28,38 @@ static __device__ void rope_yarn(
     *sin_theta = sinf(theta) * mscale;
 }
 
-// rope == RoPE == rotary positional embedding
-template<typename T, bool has_pos>
-static __global__ void rope(
-    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims
-) {
-    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+template<typename T, bool has_ff>
+static __global__ void rope_norm(
+    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
-    if (col >= ncols) {
+    if (i0 >= ne0) {
         return;
     }
 
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int i = row*ncols + col;
+
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;
+
+        dst[i + 0] = x[i + 0];
+        dst[i + 1] = x[i + 1];
+
+        return;
+    }
+
+    const int i  = row*ne0 + i0;
     const int i2 = row/p_delta_rows;
 
-    const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p*powf(freq_base, -float(col)/ncols);
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+    float cos_theta;
+    float sin_theta;
 
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
     const float x1 = x[i + 1];
@@ -58,23 +68,20 @@ static __global__ void rope(
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
-template<typename T, bool has_pos>
+template<typename T, bool has_ff>
 static __global__ void rope_neox(
-    const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
-) {
-    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
-    if (col >= ncols) {
+    if (i0 >= ne0) {
         return;
     }
 
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int ib = col / n_dims;
-    const int ic = col % n_dims;
 
-    if (ib > 0) {
-        const int i = row*ncols + ib*n_dims + ic;
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;
 
         dst[i + 0] = x[i + 0];
         dst[i + 1] = x[i + 1];
@@ -82,16 +89,17 @@ static __global__ void rope_neox(
         return;
     }
 
-    const int i  = row*ncols + ib*n_dims + ic/2;
+    const int i  = row*ne0 + i0/2;
     const int i2 = row/p_delta_rows;
 
-    float cur_rot = inv_ndims * ic - ib;
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
-    const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
+    float cos_theta;
+    float sin_theta;
 
-    float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
     const float x1 = x[i + n_dims/2];
@@ -100,158 +108,117 @@ static __global__ void rope_neox(
     dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
-static __global__ void rope_glm_f32(
-    const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    int n_ctx
-) {
-    const int col = blockDim.x*blockIdx.x + threadIdx.x;
-    const int half_n_dims = ncols/4;
-
-    if (col >= half_n_dims) {
-        return;
-    }
-
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
-    const int i = row*ncols + col;
-    const int i2 = row/p_delta_rows;
-
-    const float col_theta_scale = powf(freq_base, -2.0f*col/ncols);
-     // FIXME: this is likely wrong
-    const int p = pos != nullptr ? pos[i2] : 0;
-
-    const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale;
-    const float sin_theta = sinf(theta);
-    const float cos_theta = cosf(theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + half_n_dims];
-
-    dst[i + 0]           = x0*cos_theta - x1*sin_theta;
-    dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
-
-    const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale;
-    const float sin_block_theta = sinf(block_theta);
-    const float cos_block_theta = cosf(block_theta);
-
-    const float x2 = x[i + half_n_dims * 2];
-    const float x3 = x[i + half_n_dims * 3];
-
-    dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
-    dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
-}
-
-
 template<typename T>
-static void rope_cuda(
-    const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
-) {
-    GGML_ASSERT(ncols % 2 == 0);
+static void rope_norm_cuda(
+    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(nrows, num_blocks_x, 1);
-    if (pos == nullptr) {
-        rope<T, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
-        );
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    if (freq_factors == nullptr) {
+        rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors
+                );
     } else {
-        rope<T, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
-        );
+        rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors
+                );
     }
 }
 
 template<typename T>
 static void rope_neox_cuda(
-    const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
-) {
-    GGML_ASSERT(ncols % 2 == 0);
+    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(nrows, num_blocks_x, 1);
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float inv_ndims = -1.0f / n_dims;
 
-    if (pos == nullptr) {
+    if (freq_factors == nullptr) {
         rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-            theta_scale, inv_ndims
-        );
+                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors
+                );
     } else {
         rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-            theta_scale, inv_ndims
-        );
+                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors
+                );
     }
 }
 
-static void rope_glm_f32_cuda(
-    const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, int n_ctx, cudaStream_t stream
-) {
-    GGML_ASSERT(ncols % 4 == 0);
-    const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
-    const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
-    const dim3 block_nums(num_blocks_x, nrows, 1);
-    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, n_ctx);
-}
-
-static void rope_cuda_f16(
-    const half * x, half * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
+static void rope_norm_cuda_f16(
+    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
 
-    rope_cuda<half>(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+    rope_norm_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
-static void rope_cuda_f32(
-    const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
+static void rope_norm_cuda_f32(
+    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
 
-    rope_cuda<float>(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+    rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
 static void rope_neox_cuda_f16(
-    const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
+    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
 
-    rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+    rope_neox_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
 static void rope_neox_cuda_f32(
-    const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
+    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
 ) {
 
-    rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+    rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
 void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
     const float * src0_d = (const float *)src0->data;
     const float * src1_d = (const float *)src1->data;
+
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
     GGML_ASSERT(src0->type == dst->type);
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t nr = ggml_nrows(src0);
 
-    //const int n_past      = ((int32_t *) dst->op_params)[0];
-    const int n_dims      = ((int32_t *) dst->op_params)[1];
-    const int mode        = ((int32_t *) dst->op_params)[2];
-    const int n_ctx       = ((int32_t *) dst->op_params)[3];
-    const int n_orig_ctx  = ((int32_t *) dst->op_params)[4];
+    //const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
     // RoPE alteration for extended context
-    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    float freq_base;
+    float freq_scale;
+    float ext_factor;
+    float attn_factor;
+    float beta_fast;
+    float beta_slow;
+
     memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
     memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
     memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
@@ -259,47 +226,43 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 
-    const int32_t * pos = nullptr;
-    if ((mode & 1) == 0) {
-        GGML_ASSERT(src1->type == GGML_TYPE_I32);
-        GGML_ASSERT(src1->ne[0] == ne2);
-        pos = (const int32_t *) src1_d;
-    }
-
     const bool is_neox = mode & 2;
-    const bool is_glm  = mode & 4;
+
+    const int32_t * pos = (const int32_t *) src1_d;
+
+    const float * freq_factors = nullptr;
+    if (src2 != nullptr) {
+        freq_factors = (const float *) src2->data;
+    }
 
     rope_corr_dims corr_dims;
-    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
 
     // compute
-    if (is_glm) {
-        GGML_ASSERT(false);
-        rope_glm_f32_cuda(src0_d, dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, stream);
-    } else if (is_neox) {
+    if (is_neox) {
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, stream
+                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
             );
         } else if (src0->type == GGML_TYPE_F16) {
             rope_neox_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, stream
+                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
             );
         } else {
             GGML_ASSERT(false);
         }
     } else {
         if (src0->type == GGML_TYPE_F32) {
-            rope_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, stream
+            rope_norm_cuda_f32(
+                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
             );
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, stream
+            rope_norm_cuda_f16(
+                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
             );
         } else {
             GGML_ASSERT(false);

+ 43 - 169
llama/ggml-cuda/vecdotq.cuh

@@ -180,8 +180,8 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
 #define VDR_Q8_0_Q8_1_MMVQ 2
 #define VDR_Q8_0_Q8_1_MMQ 8
 
-template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl(
-    const int * v, const int * u, const float & d8_0, const float & d8_1) {
+template <typename T, int vdr> static __device__ __forceinline__ T vec_dot_q8_0_q8_1_impl(
+    const int * v, const int * u, const T & d8_0, const T & d8_1) {
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     int sumi = 0;
@@ -192,7 +192,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp
         sumi = __dp4a(v[i], u[i], sumi);
     }
 
-    return d8_0*d8_1 * sumi;
+    return d8_0*d8_1 * ((T) sumi);
 #else
     NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@@ -566,9 +566,9 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
 }
 
 static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
+    const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq + kbx;
 
     int v[VDR_Q4_0_Q8_1_MMVQ];
     int u[2*VDR_Q4_0_Q8_1_MMVQ];
@@ -585,9 +585,9 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
 
 
 static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
+    const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq + kbx;
 
     int v[VDR_Q4_1_Q8_1_MMVQ];
     int u[2*VDR_Q4_1_Q8_1_MMVQ];
@@ -603,9 +603,9 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
+    const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq + kbx;
 
     int vl[VDR_Q5_0_Q8_1_MMVQ];
     int vh[VDR_Q5_0_Q8_1_MMVQ];
@@ -623,9 +623,9 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
+    const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq + kbx;
 
     int vl[VDR_Q5_1_Q8_1_MMVQ];
     int vh[VDR_Q5_1_Q8_1_MMVQ];
@@ -643,9 +643,9 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
+    const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq + kbx;
 
     int v[VDR_Q8_0_Q8_1_MMVQ];
     int u[VDR_Q8_0_Q8_1_MMVQ];
@@ -656,13 +656,13 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
         u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
     }
 
-    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds));
+    return vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds));
 }
 
 static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q2_K * bq2_K = (const block_q2_K *) vbq;
+    const block_q2_K * bq2_K = (const block_q2_K *) vbq + kbx;
 
     const int bq8_offset = QR2_K * (iqs / QI8_1);
     const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
@@ -683,9 +683,9 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q3_K * bq3_K = (const block_q3_K *) vbq;
+    const block_q3_K * bq3_K = (const block_q3_K *) vbq + kbx;
 
     const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
     const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
@@ -710,10 +710,9 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-#ifndef GGML_QKK_64
-    const block_q4_K * bq4_K = (const block_q4_K *) vbq;
+    const block_q4_K * bq4_K = (const block_q4_K *) vbq + kbx;
 
     int    v[2];
     int    u[2*QR4_K];
@@ -754,59 +753,12 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
     }
 
     return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
-
-#else
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_q4_K * bq4_K = (const block_q4_K *) vbq;
-
-    float sumf_d = 0.0f;
-    float sumf_m = 0.0f;
-
-    uint16_t aux16[2];
-    const uint8_t * s = (const uint8_t *)aux16;
-
-    const uint16_t * a = (const uint16_t *)bq4_K->scales;
-    aux16[0] = a[0] & 0x0f0f;
-    aux16[1] = (a[0] >> 4) & 0x0f0f;
-
-    const float dall = bq4_K->dm[0];
-    const float dmin = bq4_K->dm[1];
-
-    const float d8_1 = __low2float(bq8_1[0].ds);
-    const float d8_2 = __low2float(bq8_1[1].ds);
-
-    const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
-    const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
-    const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));
-    const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);
-
-    const int * q4 = (const int *)bq4_K->qs + (iqs/2);
-    const int v1 = q4[0];
-    const int v2 = q4[4];
-
-    const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0));
-    const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));
-    const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
-    const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0));
-
-    sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]);
-    sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]);
-
-    return dall * sumf_d - dmin * sumf_m;
-
-#else
-    NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
-
-#endif
 }
 
 static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-#ifndef GGML_QKK_64
-    const block_q5_K * bq5_K = (const block_q5_K *) vbq;
+    const block_q5_K * bq5_K = (const block_q5_K *) vbq + kbx;
 
     int   vl[2];
     int   vh[2];
@@ -847,54 +799,12 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
     }
 
     return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);
-
-#else
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_q5_K * bq5_K = (const block_q5_K *) vbq;
-
-    const int8_t * s = bq5_K->scales;
-
-    const float d = bq5_K->d;
-
-    const float d8_1 = __low2half(bq8_1[0].ds);
-    const float d8_2 = __low2half(bq8_1[1].ds);
-
-    const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
-    const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
-    const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));
-    const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);
-
-    const int * ql = (const int *)bq5_K->qs + (iqs/2);
-    const int vl1 = ql[0];
-    const int vl2 = ql[4];
-
-    const int step = 4 * (iqs/2); // 0, 4, 8, 12
-    const int im = step/8; // = 0 for iqs = 0, 2, = 1 for iqs = 4, 6
-    const int in = step%8; // 0, 4, 0, 4
-    const int vh = (*((const int *)(bq5_K->qh + in))) >> im;
-
-    const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f);
-    const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f);
-    const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f);
-    const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);
-
-    const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1])
-                       + d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]);
-
-    return d * sumf_d;
-
-#else
-    NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
-
-#endif
 }
 
 static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q6_K * bq6_K = (const block_q6_K *) vbq;
+    const block_q6_K * bq6_K = (const block_q6_K *) vbq + kbx;
 
     const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);
     const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);
@@ -918,9 +828,8 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
-#if QK_K == 256
-    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq + kbx;
 
 #if QR2_XXS == 8
     const int ib32 = iqs;
@@ -960,16 +869,12 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
     }
     return d * (sumi1 + sumi2);
 #endif
-#else
-    NO_DEVICE_CODE;
-#endif
 }
 
 static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-#if QK_K == 256
-    const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;
+    const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq + kbx;
 
     const int ib32 = iqs;
     const uint16_t * q2 = bq2->qs + 4*ib32;
@@ -1002,18 +907,13 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
     GGML_UNUSED(ksigns64);
     NO_DEVICE_CODE;
 #endif
-#else
-    GGML_UNUSED(ksigns64);
-    NO_DEVICE_CODE;
-#endif
 }
 
 // TODO
 static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-#if QK_K == 256
-    const block_iq2_s * bq2 = (const block_iq2_s *) vbq;
+    const block_iq2_s * bq2 = (const block_iq2_s *) vbq + kbx;
 
     const int ib32 = iqs;
     const int8_t  * q8 = bq8_1[ib32].qs;
@@ -1048,17 +948,12 @@ static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
     GGML_UNUSED(ksigns64);
     NO_DEVICE_CODE;
 #endif
-#else
-    GGML_UNUSED(ksigns64);
-    NO_DEVICE_CODE;
-#endif
 }
 
 static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-#if QK_K == 256
-    const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq;
+    const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq + kbx;
 
     const int ib32 = iqs;
     const uint8_t  * q3 = bq2->qs + 8*ib32;
@@ -1082,17 +977,13 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
 #else
     NO_DEVICE_CODE;
 #endif
-#else
-    NO_DEVICE_CODE;
-#endif
 }
 
 // TODO: don't use lookup table for signs
 static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-#if QK_K == 256
-    const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
+    const block_iq3_s * bq2 = (const block_iq3_s *) vbq + kbx;
 
     const int ib32 = iqs;
     const uint8_t  * qs = bq2->qs + 8*ib32;
@@ -1114,15 +1005,11 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
 #else
     NO_DEVICE_CODE;
 #endif
-#else
-    NO_DEVICE_CODE;
-#endif
 }
 
 static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
-#if QK_K == 256
-    const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+    const block_iq1_s * bq1 = (const block_iq1_s *) vbq + kbx;
 
     const int ib32 = iqs;
     int sumi = 0;
@@ -1149,15 +1036,11 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
     const float d = d1q * __low2float (bq8_1[ib32].ds);
     const float m = d1q * __high2float(bq8_1[ib32].ds);
     return d * sumi + m * delta;
-#else
-    NO_DEVICE_CODE;
-#endif
 }
 
 static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
-#if QK_K == 256
-    const block_iq1_m * bq1 = (const block_iq1_m *) vbq;
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+    const block_iq1_m * bq1 = (const block_iq1_m *) vbq + kbx;
 
     const int ib32 = iqs;
     int   sumi[2] = {0, 0};
@@ -1192,9 +1075,6 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
     scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
     const float d = (float)scale.f16 * __low2float (bq8_1[ib32].ds);
     return d * ((sumi[0] + sumf[0]) * (2*((sc[ib32/2] >> 6*(ib32%2)) & 0x7) + 1) + (sumi[1] + sumf[1]) * (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1));
-#else
-    NO_DEVICE_CODE;
-#endif
 }
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
@@ -1214,9 +1094,9 @@ static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4
 #endif
 
 static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_iq4_nl * bq = (const block_iq4_nl *) vbq;
+    const block_iq4_nl * bq = (const block_iq4_nl *) vbq + kbx;
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs;
@@ -1248,12 +1128,10 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-#if QK_K == 256
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-
-    const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;
+    const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq + kbx;
     const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
 
     // iqs is 0...7
@@ -1270,11 +1148,7 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
         sumi2 = __dp4a(v2, q8[j+4], sumi2);
     }
     return d * (sumi1 + sumi2);
-
-#else
-    NO_DEVICE_CODE;
-#endif
 #else
-    return vec_dot_iq4_xs_q8_1(vbq, bq8_1, iqs);
+    return vec_dot_iq4_xs_q8_1(vbq, bq8_1, kbx, iqs);
 #endif
 }

+ 45 - 1
llama/ggml-impl.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *
@@ -43,6 +43,18 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+#if defined(_WIN32)
+
+#define m512bh(p) p
+#define m512i(p) p
+
+#else
+
+#define m512bh(p) (__m512bh)(p)
+#define m512i(p) (__m512i)(p)
+
+#endif
+
 /**
  * Converts brain16 to float32.
  *
@@ -158,6 +170,10 @@ extern "C" {
 #endif
 #endif
 
+#if defined(__ARM_FEATURE_SVE)
+#include <arm_sve.h>
+#endif
+
 // 16-bit float
 // on Arm, we use __fp16
 // on x86, we use uint16_t
@@ -469,6 +485,34 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
 #include <riscv_vector.h>
 #endif
 
+#if defined(__loongarch64)
+#if defined(__loongarch_asx)
+#include <lasxintrin.h>
+#endif
+#if defined(__loongarch_sx)
+#include <lsxintrin.h>
+#endif
+#endif
+
+#if defined(__loongarch_asx)
+
+typedef union {
+    int32_t i;
+    float f;
+} ft_union;
+
+/* float type data load instructions */
+static __m128 __lsx_vreplfr2vr_s(float val) {
+    ft_union fi_tmpval = {.f = val};
+    return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
+}
+
+static __m256 __lasx_xvreplfr2vr_s(float val) {
+    ft_union fi_tmpval = {.f = val};
+    return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
+}
+#endif
+
 #ifdef __F16C__
 
 #ifdef _MSC_VER

+ 148 - 90
llama/ggml-metal-darwin_arm64.m

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *
@@ -61,6 +61,10 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_ROW,
     GGML_METAL_KERNEL_TYPE_DIV,
     GGML_METAL_KERNEL_TYPE_DIV_ROW,
+    GGML_METAL_KERNEL_TYPE_REPEAT_F32,
+    GGML_METAL_KERNEL_TYPE_REPEAT_F16,
+    GGML_METAL_KERNEL_TYPE_REPEAT_I32,
+    GGML_METAL_KERNEL_TYPE_REPEAT_I16,
     GGML_METAL_KERNEL_TYPE_SCALE,
     GGML_METAL_KERNEL_TYPE_SCALE_4,
     GGML_METAL_KERNEL_TYPE_CLAMP,
@@ -194,8 +198,10 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
-    GGML_METAL_KERNEL_TYPE_ROPE_F32,
-    GGML_METAL_KERNEL_TYPE_ROPE_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -210,9 +216,9 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+  //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,     // https://github.com/ggerganov/llama.cpp/issues/7261
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+  //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -407,10 +413,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
                 // dictionary of preprocessor macros
                 NSMutableDictionary * prep = [NSMutableDictionary dictionary];
 
-#ifdef GGML_QKK_64
-                prep[@"GGML_QKK_64"] = @(1);
-#endif
-
                 MTLCompileOptions* options = [MTLCompileOptions new];
                 options.preprocessorMacros = prep;
 
@@ -515,6 +517,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,                       mul_row,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,                           div,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW,                       div_row,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32,                    repeat_f32,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16,                    repeat_f16,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32,                    repeat_i32,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16,                    repeat_i16,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE,                         scale,                          true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4,                       scale_4,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP,                         clamp,                          true);
@@ -648,8 +654,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,           mul_mm_id_iq1_m_f32,            ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,          mul_mm_id_iq4_nl_f32,           ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,          mul_mm_id_iq4_xs_f32,           ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32,                      rope_f32,                       true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16,                      rope_f16,                       true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,                 rope_norm_f32,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,                 rope_norm_f16,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,                 rope_neox_f32,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,                 rope_neox_f16,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                    im2col_f16,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                    im2col_f32,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true);
@@ -664,9 +672,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,        flash_attn_ext_f16_h96,         ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,       flash_attn_ext_f16_h112,        ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,       flash_attn_ext_f16_h128,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        ctx->support_simdgroup_mm);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,   flash_attn_ext_vec_f16_h128,    ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    ctx->support_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                   cpy_f32_f16,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                   cpy_f32_f32,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,                  cpy_f32_q8_0,                   true);
@@ -776,6 +784,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_ACC:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+        case GGML_OP_REPEAT:
         case GGML_OP_SCALE:
         case GGML_OP_CLAMP:
         case GGML_OP_SQR:
@@ -800,6 +809,15 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_LEAKY_RELU:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
+            if (op->src[1]->type != GGML_TYPE_F16) {
+                return false;
+            }
+            if (op->src[2]->type != GGML_TYPE_F16) {
+                return false;
+            }
+            if (op->src[0]->ne[0] == 256) {
+                return false;
+            }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
@@ -953,22 +971,32 @@ static enum ggml_status ggml_metal_graph_compute(
             const int64_t  ne10 = src1 ? src1->ne[0] : 0;
             const int64_t  ne11 = src1 ? src1->ne[1] : 0;
             const int64_t  ne12 = src1 ? src1->ne[2] : 0;
-            const int64_t  ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
+            const int64_t  ne13 = src1 ? src1->ne[3] : 0;
 
             const uint64_t nb10 = src1 ? src1->nb[0] : 0;
             const uint64_t nb11 = src1 ? src1->nb[1] : 0;
             const uint64_t nb12 = src1 ? src1->nb[2] : 0;
-            const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
+            const uint64_t nb13 = src1 ? src1->nb[3] : 0;
+
+            const int64_t  ne20 = src2 ? src2->ne[0] : 0;
+            const int64_t  ne21 = src2 ? src2->ne[1] : 0;
+            const int64_t  ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
+            const int64_t  ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
+
+            const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
+            const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+            const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+            const uint64_t nb23 = src2 ? src2->nb[3] : 0;
 
-            const int64_t  ne0  = dst ? dst->ne[0] : 0;
-            const int64_t  ne1  = dst ? dst->ne[1] : 0;
-            const int64_t  ne2  = dst ? dst->ne[2] : 0;
-            const int64_t  ne3  = dst ? dst->ne[3] : 0;
+            const int64_t  ne0  =  dst ?  dst->ne[0] : 0;
+            const int64_t  ne1  =  dst ?  dst->ne[1] : 0;
+            const int64_t  ne2  =  dst ?  dst->ne[2] : 0;
+            const int64_t  ne3  =  dst ?  dst->ne[3] : 0;
 
-            const uint64_t nb0  = dst ? dst->nb[0] : 0;
-            const uint64_t nb1  = dst ? dst->nb[1] : 0;
-            const uint64_t nb2  = dst ? dst->nb[2] : 0;
-            const uint64_t nb3  = dst ? dst->nb[3] : 0;
+            const uint64_t nb0  =  dst ?  dst->nb[0] : 0;
+            const uint64_t nb1  =  dst ?  dst->nb[1] : 0;
+            const uint64_t nb2  =  dst ?  dst->nb[2] : 0;
+            const uint64_t nb3  =  dst ?  dst->nb[3] : 0;
 
             const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
             const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
@@ -996,10 +1024,10 @@ static enum ggml_status ggml_metal_graph_compute(
             switch (dst->op) {
                 case GGML_OP_CONCAT:
                     {
-                        const int64_t nb = ne00;
-
                         id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
 
+                        const int32_t dim = ((int32_t *) dst->op_params)[0];
+
                         [encoder setComputePipelineState:pipeline];
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -1028,7 +1056,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
                         [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
                         [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
-                        [encoder setBytes:&nb   length:sizeof(nb)   atIndex:27];
+                        [encoder setBytes:&dim  length:sizeof(dim)  atIndex:27];
 
                         const int nth = MIN(1024, ne0);
 
@@ -1038,11 +1066,14 @@ static enum ggml_status ggml_metal_graph_compute(
                 case GGML_OP_MUL:
                 case GGML_OP_DIV:
                     {
+                        GGML_ASSERT(src0t == GGML_TYPE_F32);
+                        GGML_ASSERT(src1t == GGML_TYPE_F32);
+
                         const size_t offs = 0;
 
                         bool bcast_row = false;
 
-                        int64_t nb = ne00;
+                        int64_t nb = ne00; // used by the "row" kernels
 
                         id<MTLComputePipelineState> pipeline = nil;
 
@@ -1111,6 +1142,42 @@ static enum ggml_status ggml_metal_graph_compute(
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         }
                     } break;
+                case GGML_OP_REPEAT:
+                    {
+                        id<MTLComputePipelineState> pipeline;
+
+                        switch (src0t) {
+                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
+                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
+                            case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
+                            case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
+                            default: GGML_ASSERT(false);
+                        }
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                        [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                        [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                        [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                        [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
+                        [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
+                        [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
+                        [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
+                        [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
+                        [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
+                        [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
+                        [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+
+                        const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                    } break;
                 case GGML_OP_ACC:
                     {
                         GGML_ASSERT(src0t == GGML_TYPE_F32);
@@ -1488,7 +1555,6 @@ static enum ggml_status ggml_metal_graph_compute(
                     {
                         GGML_ASSERT(ne00 == ne10);
 
-                        // TODO: assert that dim2 and dim3 are contiguous
                         GGML_ASSERT(ne12 % ne02 == 0);
                         GGML_ASSERT(ne13 % ne03 == 0);
 
@@ -1499,11 +1565,9 @@ static enum ggml_status ggml_metal_graph_compute(
                         // to the matrix-vector kernel
                         int ne11_mm_min = 1;
 
-
                         // the numbers below are measured on M2 Ultra for 7B and 13B models
                         // these numbers do not translate to other devices or model sizes
                         // TODO: need to find a better approach
-                        // if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
                         switch (src0t) {
                             case GGML_TYPE_F16:  ne11_mm_min = 2;  break;
                             case GGML_TYPE_Q8_0: ne11_mm_min = 7;  break;
@@ -1518,8 +1582,6 @@ static enum ggml_status ggml_metal_graph_compute(
                             case GGML_TYPE_Q6_K: ne11_mm_min = 7;  break;
                             default:             ne11_mm_min = 1;  break;
                         }
-                        // }
-
 
                         // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                         // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
@@ -1789,11 +1851,7 @@ static enum ggml_status ggml_metal_graph_compute(
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src0t == GGML_TYPE_Q3_K) {
-#ifdef GGML_QKK_64
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-#else
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-#endif
                             }
                             else if (src0t == GGML_TYPE_Q5_K) {
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -1811,16 +1869,6 @@ static enum ggml_status ggml_metal_graph_compute(
                         const int n_as = src0->ne[2];
 
                         // src2 = ids
-                        const int64_t  ne20 = src2->ne[0];
-                        const int64_t  ne21 = src2->ne[1];
-                        const int64_t  ne22 = src2->ne[2]; GGML_UNUSED(ne22);
-                        const int64_t  ne23 = src2->ne[3]; GGML_UNUSED(ne23);
-
-                        const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
-                        const uint64_t nb21 = src2->nb[1];
-                        const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
-                        const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
-
                         const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
 
                         GGML_ASSERT(src2t == GGML_TYPE_I32);
@@ -2044,12 +2092,7 @@ static enum ggml_status ggml_metal_graph_compute(
                                     {
                                         nth0 = 4;
                                         nth1 = 16;
-                                    #if QK_K == 64
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
-                                    #else
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
-                                    #endif
-
                                     } break;
                                 default:
                                     {
@@ -2114,11 +2157,7 @@ static enum ggml_status ggml_metal_graph_compute(
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src0t == GGML_TYPE_Q3_K) {
-#ifdef GGML_QKK_64
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-#else
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-#endif
                             }
                             else if (src0t == GGML_TYPE_Q5_K) {
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -2179,6 +2218,7 @@ static enum ggml_status ggml_metal_graph_compute(
                 case GGML_OP_RMS_NORM:
                     {
                         GGML_ASSERT(ne00 % 4 == 0);
+                        GGML_ASSERT(ggml_is_contiguous_1(src0));
 
                         float eps;
                         memcpy(&eps, dst->op_params, sizeof(float));
@@ -2206,6 +2246,7 @@ static enum ggml_status ggml_metal_graph_compute(
                 case GGML_OP_GROUP_NORM:
                     {
                         GGML_ASSERT(ne00 % 4 == 0);
+                        GGML_ASSERT(ggml_is_contiguous(src0));
 
                         //float eps;
                         //memcpy(&eps, dst->op_params, sizeof(float));
@@ -2239,6 +2280,8 @@ static enum ggml_status ggml_metal_graph_compute(
                     } break;
                 case GGML_OP_NORM:
                     {
+                        GGML_ASSERT(ggml_is_contiguous_1(src0));
+
                         float eps;
                         memcpy(&eps, dst->op_params, sizeof(float));
 
@@ -2268,9 +2311,15 @@ static enum ggml_status ggml_metal_graph_compute(
                         const int n_dims     = ((int32_t *) dst->op_params)[1];
                         const int mode       = ((int32_t *) dst->op_params)[2];
                         // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
-                        const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+                        const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
+                        float freq_base;
+                        float freq_scale;
+                        float ext_factor;
+                        float attn_factor;
+                        float beta_fast;
+                        float beta_slow;
 
-                        float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
                         memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
                         memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
                         memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
@@ -2278,38 +2327,52 @@ static enum ggml_status ggml_metal_graph_compute(
                         memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
                         memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 
+                        const bool is_neox = mode & 2;
+
                         id<MTLComputePipelineState> pipeline = nil;
 
-                        switch (src0->type) {
-                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
-                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
-                            default: GGML_ASSERT(false);
-                        };
+                        if (!is_neox) {
+                            switch (src0->type) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
+                                default: GGML_ASSERT(false);
+                            };
+                        } else {
+                            switch (src0->type) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+                                default: GGML_ASSERT(false);
+                            };
+                        }
 
                         [encoder setComputePipelineState:pipeline];
                         [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
                         [encoder setBuffer:id_src1     offset:offs_src1        atIndex:1];
-                        [encoder setBuffer:id_dst      offset:offs_dst         atIndex:2];
-                        [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:3];
-                        [encoder setBytes:&ne01        length:sizeof( int64_t) atIndex:4];
-                        [encoder setBytes:&ne02        length:sizeof( int64_t) atIndex:5];
-                        [encoder setBytes:&ne03        length:sizeof( int64_t) atIndex:6];
-                        [encoder setBytes:&nb00        length:sizeof(uint64_t) atIndex:7];
-                        [encoder setBytes:&nb01        length:sizeof(uint64_t) atIndex:8];
-                        [encoder setBytes:&nb02        length:sizeof(uint64_t) atIndex:9];
-                        [encoder setBytes:&nb03        length:sizeof(uint64_t) atIndex:10];
-                        [encoder setBytes:&ne0         length:sizeof( int64_t) atIndex:11];
-                        [encoder setBytes:&ne1         length:sizeof( int64_t) atIndex:12];
-                        [encoder setBytes:&ne2         length:sizeof( int64_t) atIndex:13];
-                        [encoder setBytes:&ne3         length:sizeof( int64_t) atIndex:14];
-                        [encoder setBytes:&nb0         length:sizeof(uint64_t) atIndex:15];
-                        [encoder setBytes:&nb1         length:sizeof(uint64_t) atIndex:16];
-                        [encoder setBytes:&nb2         length:sizeof(uint64_t) atIndex:17];
-                        [encoder setBytes:&nb3         length:sizeof(uint64_t) atIndex:18];
-                        [encoder setBytes:&n_past      length:sizeof(     int) atIndex:19];
-                        [encoder setBytes:&n_dims      length:sizeof(     int) atIndex:20];
-                        [encoder setBytes:&mode        length:sizeof(     int) atIndex:21];
-                        [encoder setBytes:&n_orig_ctx  length:sizeof(     int) atIndex:22];
+                        if (id_src2 != nil) {
+                            [encoder setBuffer:id_src2 offset:offs_src2        atIndex:2];
+                        } else {
+                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:2];
+                        }
+                        [encoder setBuffer:id_dst      offset:offs_dst         atIndex:3];
+                        [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:4];
+                        [encoder setBytes:&ne01        length:sizeof( int64_t) atIndex:5];
+                        [encoder setBytes:&ne02        length:sizeof( int64_t) atIndex:6];
+                        [encoder setBytes:&ne03        length:sizeof( int64_t) atIndex:7];
+                        [encoder setBytes:&nb00        length:sizeof(uint64_t) atIndex:8];
+                        [encoder setBytes:&nb01        length:sizeof(uint64_t) atIndex:9];
+                        [encoder setBytes:&nb02        length:sizeof(uint64_t) atIndex:10];
+                        [encoder setBytes:&nb03        length:sizeof(uint64_t) atIndex:11];
+                        [encoder setBytes:&ne0         length:sizeof( int64_t) atIndex:12];
+                        [encoder setBytes:&ne1         length:sizeof( int64_t) atIndex:13];
+                        [encoder setBytes:&ne2         length:sizeof( int64_t) atIndex:14];
+                        [encoder setBytes:&ne3         length:sizeof( int64_t) atIndex:15];
+                        [encoder setBytes:&nb0         length:sizeof(uint64_t) atIndex:16];
+                        [encoder setBytes:&nb1         length:sizeof(uint64_t) atIndex:17];
+                        [encoder setBytes:&nb2         length:sizeof(uint64_t) atIndex:18];
+                        [encoder setBytes:&nb3         length:sizeof(uint64_t) atIndex:19];
+                        [encoder setBytes:&n_past      length:sizeof(     int) atIndex:20];
+                        [encoder setBytes:&n_dims      length:sizeof(     int) atIndex:21];
+                        [encoder setBytes:&n_ctx_orig  length:sizeof(     int) atIndex:22];
                         [encoder setBytes:&freq_base   length:sizeof(   float) atIndex:23];
                         [encoder setBytes:&freq_scale  length:sizeof(   float) atIndex:24];
                         [encoder setBytes:&ext_factor  length:sizeof(   float) atIndex:25];
@@ -2561,11 +2624,6 @@ static enum ggml_status ggml_metal_graph_compute(
                         GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
                                 "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
 
-                        const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
-                        const uint64_t nb21 = src2 ? src2->nb[1] : 0;
-                        const uint64_t nb22 = src2 ? src2->nb[2] : 0;
-                        const uint64_t nb23 = src2 ? src2->nb[3] : 0;
-
                         const int64_t  ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
                       //const int64_t  ne31 = src3 ? src3->ne[1] : 0;
                         const int64_t  ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
@@ -2601,7 +2659,7 @@ static enum ggml_status ggml_metal_graph_compute(
                                 case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
                                 case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
                                 case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
-                                case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                              //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
                                 default:
                                           {
                                               GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
@@ -2614,7 +2672,7 @@ static enum ggml_status ggml_metal_graph_compute(
 
                             switch (ne00) {
                                 case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
-                                case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                              //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
                                 default:
                                           {
                                               GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);

+ 2 - 2
llama/ggml-metal.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *
@@ -27,7 +27,7 @@
 // An interface allowing to compute ggml_cgraph with Metal
 //
 // This is a fully functional interface that extends ggml with GPU support for Apple devices.
-// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, OpenCL, etc.)
+// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, etc.)
 //
 // How it works?
 //

File diff suppressed because it is too large
+ 168 - 473
llama/ggml-metal.metal


BIN
llama/ggml-metal.o


File diff suppressed because it is too large
+ 541 - 258
llama/ggml-quants.c


+ 1 - 1
llama/ggml-quants.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *

File diff suppressed because it is too large
+ 371 - 229
llama/ggml.c


+ 54 - 46
llama/ggml.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *
@@ -507,9 +507,7 @@ extern "C" {
         GGML_OP_ARGSORT,
         GGML_OP_LEAKY_RELU,
 
-        GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_ATTN_EXT,
-        GGML_OP_FLASH_FF,
         GGML_OP_FLASH_ATTN_BACK,
         GGML_OP_SSM_CONV,
         GGML_OP_SSM_SCAN,
@@ -784,7 +782,6 @@ extern "C" {
     GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);
 
     GGML_API GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor);
-    GGML_API GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor);
     GGML_API GGML_CALL bool ggml_is_permuted  (const struct ggml_tensor * tensor);
     GGML_API GGML_CALL bool ggml_is_empty     (const struct ggml_tensor * tensor);
     GGML_API           bool ggml_is_scalar    (const struct ggml_tensor * tensor);
@@ -793,6 +790,11 @@ extern "C" {
     GGML_API           bool ggml_is_3d        (const struct ggml_tensor * tensor);
     GGML_API           int  ggml_n_dims       (const struct ggml_tensor * tensor); // returns 1 for scalars
 
+    GGML_API GGML_CALL bool ggml_is_contiguous  (const struct ggml_tensor * tensor);
+    GGML_API GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous()
+    GGML_API GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
+    GGML_API GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2
+
     GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
     GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
 
@@ -1035,12 +1037,13 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
-    // concat a and b on dim 2
+    // concat a and b along dim
     // used in stable-diffusion
     GGML_API struct ggml_tensor * ggml_concat(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * b,
+            int                   dim);
 
     GGML_API struct ggml_tensor * ggml_abs(
             struct ggml_context * ctx,
@@ -1486,18 +1489,17 @@ extern "C" {
             struct ggml_tensor  * b);
 
     // rotary position embedding
-    // if mode & 1 == 1, skip n_past elements (DEPRECATED)
+    // if mode & 1 == 1, skip n_past elements (NOT SUPPORTED)
     // if mode & 2 == 1, GPT-NeoX style
-    // if mode & 4 == 1, ChatGLM style
     //
     // b is an int32 vector with size a->ne[2], it contains the positions
+    // c is freq factors (e.g. phi3-128k), (optional)
     GGML_API struct ggml_tensor * ggml_rope(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
             int                   n_dims,
-            int                   mode,
-            int                   n_ctx);
+            int                   mode);
 
     // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_rope_inplace(
@@ -1505,18 +1507,17 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
             int                   n_dims,
-            int                   mode,
-            int                   n_ctx);
+            int                   mode);
 
     // custom RoPE
-    GGML_API struct ggml_tensor * ggml_rope_custom(
+    GGML_API struct ggml_tensor * ggml_rope_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
             int                   n_dims,
             int                   mode,
-            int                   n_ctx,
-            int                   n_orig_ctx,
+            int                   n_ctx_orig,
             float                 freq_base,
             float                 freq_scale,
             float                 ext_factor,
@@ -1525,14 +1526,14 @@ extern "C" {
             float                 beta_slow);
 
     // in-place, returns view(a)
-    GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
+    GGML_API struct ggml_tensor * ggml_rope_ext_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
             int                   n_dims,
             int                   mode,
-            int                   n_ctx,
-            int                   n_orig_ctx,
+            int                   n_ctx_orig,
             float                 freq_base,
             float                 freq_scale,
             float                 ext_factor,
@@ -1540,18 +1541,39 @@ extern "C" {
             float                 beta_fast,
             float                 beta_slow);
 
-    // compute correction dims for YaRN RoPE scaling
-    GGML_CALL void ggml_rope_yarn_corr_dims(
-        int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   n_dims,
+            int                   mode,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow),
+        "use ggml_rope_ext instead");
 
-    // xPos RoPE, in-place, returns view(a)
-    GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
             int                   n_dims,
-            float                 base,
-            bool                  down);
+            int                   mode,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow),
+        "use ggml_rope_ext_inplace instead");
+
+    // compute correction dims for YaRN RoPE scaling
+    GGML_CALL void ggml_rope_yarn_corr_dims(
+        int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
 
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
@@ -1559,18 +1581,16 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
             int                   n_dims,
             int                   mode,
-            int                   n_ctx,
-            int                   n_orig_ctx,
+            int                   n_ctx_orig,
             float                 freq_base,
             float                 freq_scale,
             float                 ext_factor,
             float                 attn_factor,
             float                 beta_fast,
-            float                 beta_slow,
-            float                 xpos_base,
-            bool                  xpos_down);
+            float                 beta_slow);
 
     // clamp
     // in-place, returns view(a)
@@ -1760,13 +1780,6 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   k);
 
-    GGML_API struct ggml_tensor * ggml_flash_attn(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * q,
-            struct ggml_tensor  * k,
-            struct ggml_tensor  * v,
-            bool                  masked);
-
 #define GGML_KQ_MASK_PAD 32
 
     // q:    [n_embd, n_batch,     n_head,    1]
@@ -1787,6 +1800,7 @@ extern "C" {
             struct ggml_tensor * a,
             enum ggml_prec       prec);
 
+    // TODO: needs to be adapted to ggml_flash_attn_ext
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
            struct ggml_context * ctx,
            struct ggml_tensor  * q,
@@ -1795,14 +1809,6 @@ extern "C" {
            struct ggml_tensor  * d,
            bool                  masked);
 
-    GGML_API struct ggml_tensor * ggml_flash_ff(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b0,
-            struct ggml_tensor  * b1,
-            struct ggml_tensor  * c0,
-            struct ggml_tensor  * c1);
-
     GGML_API struct ggml_tensor * ggml_ssm_conv(
             struct ggml_context * ctx,
             struct ggml_tensor  * s,
@@ -2416,8 +2422,10 @@ extern "C" {
     GGML_API int ggml_cpu_has_avx512     (void);
     GGML_API int ggml_cpu_has_avx512_vbmi(void);
     GGML_API int ggml_cpu_has_avx512_vnni(void);
+    GGML_API int ggml_cpu_has_avx512_bf16(void);
     GGML_API int ggml_cpu_has_fma        (void);
     GGML_API int ggml_cpu_has_neon       (void);
+    GGML_API int ggml_cpu_has_sve        (void);
     GGML_API int ggml_cpu_has_arm_fma    (void);
     GGML_API int ggml_cpu_has_metal      (void);
     GGML_API int ggml_cpu_has_f16c       (void);
@@ -2425,13 +2433,13 @@ extern "C" {
     GGML_API int ggml_cpu_has_wasm_simd  (void);
     GGML_API int ggml_cpu_has_blas       (void);
     GGML_API int ggml_cpu_has_cuda       (void);
-    GGML_API int ggml_cpu_has_clblast    (void);
     GGML_API int ggml_cpu_has_vulkan     (void);
     GGML_API int ggml_cpu_has_kompute    (void);
     GGML_API int ggml_cpu_has_gpublas    (void);
     GGML_API int ggml_cpu_has_sse3       (void);
     GGML_API int ggml_cpu_has_ssse3      (void);
     GGML_API int ggml_cpu_has_sycl       (void);
+    GGML_API int ggml_cpu_has_rpc        (void);
     GGML_API int ggml_cpu_has_vsx        (void);
     GGML_API int ggml_cpu_has_matmul_int8(void);
 

+ 119 - 32
llama/grammar-parser.cpp

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *
@@ -72,8 +72,12 @@ namespace grammar_parser {
         state.rules[rule_id] = rule;
     }
 
+    static bool is_digit_char(char c) {
+        return '0' <= c && c <= '9';
+    }
+
     static bool is_word_char(char c) {
-        return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
+        return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
     }
 
     static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
@@ -125,6 +129,17 @@ namespace grammar_parser {
         return pos;
     }
 
+    static const char * parse_int(const char * src) {
+        const char * pos = src;
+        while (is_digit_char(*pos)) {
+            pos++;
+        }
+        if (pos == src) {
+            throw std::runtime_error(std::string("expecting integer at ") + src);
+        }
+        return pos;
+    }
+
     static std::pair<uint32_t, const char *> parse_char(const char * src) {
         if (*src == '\\') {
             switch (src[1]) {
@@ -163,6 +178,60 @@ namespace grammar_parser {
             bool                                 is_nested) {
         size_t last_sym_start = out_elements.size();
         const char * pos = src;
+
+        auto handle_repetitions = [&](int min_times, int max_times) {
+
+            if (last_sym_start == out_elements.size()) {
+                throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
+            }
+
+            // apply transformation to previous symbol (last_sym_start to end) according to
+            // the following rewrite rules:
+            // S{m,n} --> S S S (m times) S'(n-m)
+            //            S'(x)   ::= S S'(x-1) |
+            //            (... n-m definitions of these S' rules ...)
+            //            S'(1)   ::= S |
+            // S{m,} -->  S S S (m times) S'
+            //            S'     ::= S S' |
+            // S*     --> S{0,}
+            //        --> S'     ::= S S' |
+            // S+     --> S{1,}
+            //        --> S S'
+            //            S'     ::= S S' |
+            // S?     --> S{0,1}
+            //        --> S'
+            //            S'     ::= S |
+
+            std::vector<llama_grammar_element> previous_elements(out_elements.begin() + last_sym_start, out_elements.end());
+            if (min_times == 0) {
+                out_elements.resize(last_sym_start);
+            } else {
+                // Repeat the previous elements (min_times - 1) times
+                for (int i = 1; i < min_times; i++) {
+                    out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end());
+                }
+            }
+
+            uint32_t last_rec_rule_id = 0;
+            auto n_opt = max_times < 0 ? 1 : max_times - min_times;
+
+            std::vector<llama_grammar_element> rec_rule(previous_elements);
+            for (int i = 0; i < n_opt; i++) {
+                rec_rule.resize(previous_elements.size());
+                uint32_t rec_rule_id = generate_symbol_id(state, rule_name);
+                if (i > 0 || max_times < 0) {
+                    rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
+                }
+                rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
+                rec_rule.push_back({LLAMA_GRETYPE_END, 0});
+                add_rule(state, rec_rule_id, rec_rule);
+                last_rec_rule_id = rec_rule_id;
+            }
+            if (n_opt > 0) {
+                out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
+            }
+        };
+
         while (*pos) {
             if (*pos == '"') { // literal string
                 pos++;
@@ -223,40 +292,51 @@ namespace grammar_parser {
                     throw std::runtime_error(std::string("expecting ')' at ") + pos);
                 }
                 pos = parse_space(pos + 1, is_nested);
-            } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
-                if (last_sym_start == out_elements.size()) {
-                    throw std::runtime_error(std::string("expecting preceding item to */+/? at ") + pos);
-                }
+            } else if (*pos == '.') { // any char
+                last_sym_start = out_elements.size();
+                out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
+                pos = parse_space(pos + 1, is_nested);
+            } else if (*pos == '*') {
+                pos = parse_space(pos + 1, is_nested);
+                handle_repetitions(0, -1);
+            } else if (*pos == '+') {
+                pos = parse_space(pos + 1, is_nested);
+                handle_repetitions(1, -1);
+            } else if (*pos == '?') {
+                pos = parse_space(pos + 1, is_nested);
+                handle_repetitions(0, 1);
+            } else if (*pos == '{') {
+                pos = parse_space(pos + 1, is_nested);
 
-                // apply transformation to previous symbol (last_sym_start to end) according to
-                // rewrite rules:
-                // S* --> S' ::= S S' |
-                // S+ --> S' ::= S S' | S
-                // S? --> S' ::= S |
-                uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
-                std::vector<llama_grammar_element> sub_rule;
-                // add preceding symbol to generated rule
-                sub_rule.insert(
-                    sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
-                if (*pos == '*' || *pos == '+') {
-                    // cause generated rule to recurse
-                    sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
+                if (!is_digit_char(*pos)) {
+                    throw std::runtime_error(std::string("expecting an int at ") + pos);
                 }
-                // mark start of alternate def
-                sub_rule.push_back({LLAMA_GRETYPE_ALT, 0});
-                if (*pos == '+') {
-                    // add preceding symbol as alternate only for '+' (otherwise empty)
-                    sub_rule.insert(
-                        sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
-                }
-                sub_rule.push_back({LLAMA_GRETYPE_END, 0});
-                add_rule(state, sub_rule_id, sub_rule);
+                const char * int_end = parse_int(pos);
+                int min_times = std::stoul(std::string(pos, int_end - pos));
+                pos = parse_space(int_end, is_nested);
 
-                // in original rule, replace previous symbol with reference to generated rule
-                out_elements.resize(last_sym_start);
-                out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
+                int max_times = -1;
 
-                pos = parse_space(pos + 1, is_nested);
+                if (*pos == '}') {
+                    max_times = min_times;
+                    pos = parse_space(pos + 1, is_nested);
+                } else if (*pos == ',') {
+                    pos = parse_space(pos + 1, is_nested);
+
+                    if (is_digit_char(*pos)) {
+                        const char * int_end = parse_int(pos);
+                        max_times = std::stoul(std::string(pos, int_end - pos));
+                        pos = parse_space(int_end, is_nested);
+                    }
+
+                    if (*pos != '}') {
+                        throw std::runtime_error(std::string("expecting '}' at ") + pos);
+                    }
+                    pos = parse_space(pos + 1, is_nested);
+                } else {
+                    throw std::runtime_error(std::string("expecting ',' at ") + pos);
+                }
+                handle_repetitions(min_times, max_times);
             } else {
                 break;
             }
@@ -351,6 +431,7 @@ namespace grammar_parser {
             case LLAMA_GRETYPE_CHAR_NOT:       return true;
             case LLAMA_GRETYPE_CHAR_ALT:       return true;
             case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
+            case LLAMA_GRETYPE_CHAR_ANY:       return true;
             default:                           return false;
         }
     }
@@ -365,6 +446,7 @@ namespace grammar_parser {
                 case LLAMA_GRETYPE_CHAR_NOT:       fprintf(file, "CHAR_NOT");       break;
                 case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
                 case LLAMA_GRETYPE_CHAR_ALT:       fprintf(file, "CHAR_ALT");       break;
+                case LLAMA_GRETYPE_CHAR_ANY:       fprintf(file, "CHAR_ANY");       break;
             }
             switch (elem.type) {
                 case LLAMA_GRETYPE_END:
@@ -376,6 +458,7 @@ namespace grammar_parser {
                 case LLAMA_GRETYPE_CHAR_NOT:
                 case LLAMA_GRETYPE_CHAR_RNG_UPPER:
                 case LLAMA_GRETYPE_CHAR_ALT:
+                case LLAMA_GRETYPE_CHAR_ANY:
                     fprintf(file, "(\"");
                     print_grammar_char(file, elem.value);
                     fprintf(file, "\") ");
@@ -433,11 +516,15 @@ namespace grammar_parser {
                     }
                     print_grammar_char(file, elem.value);
                     break;
+                case LLAMA_GRETYPE_CHAR_ANY:
+                    fprintf(file, ".");
+                    break;
             }
             if (is_char_element(elem)) {
                 switch (rule[i + 1].type) {
                     case LLAMA_GRETYPE_CHAR_ALT:
                     case LLAMA_GRETYPE_CHAR_RNG_UPPER:
+                    case LLAMA_GRETYPE_CHAR_ANY:
                         break;
                     default:
                         fprintf(file, "] ");

+ 1 - 1
llama/grammar-parser.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *

+ 21 - 59
llama/json-schema-to-grammar.cpp

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *
@@ -42,58 +42,27 @@ static std::string join(Iterator begin, Iterator end, const std::string & separa
 
 static std::string repeat(const std::string & str, size_t n);
 
-static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "", bool item_rule_is_literal = false) {
-    if (separator_rule.empty()) {
-        if (min_items == 0 && max_items == 1) {
-            return item_rule + "?";
-        } else if (min_items == 1 && max_items == std::numeric_limits<int>::max()) {
-            return item_rule + "+";
-        }
-    }
+static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
+    auto has_max = max_items != std::numeric_limits<int>::max();
 
-    std::string result;
-    if (min_items > 0) {
-        if (item_rule_is_literal && separator_rule.empty()) {
-            result = "\"" + repeat(std::string(item_rule.begin() + 1, item_rule.end() - 1), min_items) + "\"";
-        } else {
-            std::vector<std::string> items(min_items, item_rule);
-            result = join(items.begin(), items.end(), separator_rule.empty() ? " " : " " + separator_rule + " ");
-        }
+    if (min_items == 0 && max_items == 1) {
+        return item_rule + "?";
     }
 
-    std::function<std::string(int, bool)> opt_repetitions = [&](int up_to_n, bool prefix_with_sep) -> std::string {
-        auto content = prefix_with_sep && !separator_rule.empty() ? separator_rule + " " + item_rule : item_rule;
-
-        if (up_to_n == 0) {
-            return "";
-        } else if (up_to_n == 1) {
-            return "(" + content + ")?";
-        } else if (!separator_rule.empty() && !prefix_with_sep) {
-            return "(" + content + " " + opt_repetitions(up_to_n - 1, true) + ")?";
+    if (separator_rule.empty()) {
+        if (min_items == 1 && !has_max) {
+            return item_rule + "+";
+        } else if (min_items == 0 && !has_max) {
+            return item_rule + "*";
         } else {
-            std::string res = repeat("(" + content + " ", up_to_n);
-            // strip trailing space
-            res = res.substr(0, res.length() - 1);
-            res += repeat(")?", up_to_n);
-            return res;
+            return item_rule + "{" + std::to_string(min_items) + "," + (has_max ? std::to_string(max_items) : "") + "}";
         }
-    };
-
-    if (min_items > 0 && max_items != min_items) {
-        result += " ";
     }
 
-    if (max_items != std::numeric_limits<int>::max()) {
-        result += opt_repetitions(max_items - min_items, min_items > 0);
-    } else {
-        std::string item_operator = "(" + (separator_rule.empty() ? "" : separator_rule + " ") + item_rule + ")";
-        if (min_items == 0 && !separator_rule.empty()) {
-            result = "(" + item_rule + " " + item_operator + "*)?";
-        } else {
-            result += item_operator + "*";
-        }
+    auto result = item_rule + " " + build_repetition("(" + separator_rule + " " + item_rule + ")", min_items == 0 ? 0 : min_items - 1, has_max ? max_items - 1 : max_items);
+    if (min_items == 0) {
+        result = "(" + result + ")?";
     }
-
     return result;
 }
 
@@ -104,30 +73,24 @@ struct BuiltinRule {
     std::vector<std::string> deps;
 };
 
-const std::string _up_to_15_digits = build_repetition("[0-9]", 0, 15);
-
 std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
     {"boolean", {"(\"true\" | \"false\") space", {}}},
-    {"decimal-part", {"[0-9] " + _up_to_15_digits, {}}},
-    {"integral-part", {"[0-9] | [1-9] " + _up_to_15_digits, {}}},
+    {"decimal-part", {"[0-9]{1,16}", {}}},
+    {"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}},
     {"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)? space", {"integral-part", "decimal-part"}}},
     {"integer", {"(\"-\"? integral-part) space", {"integral-part"}}},
     {"value", {"object | array | string | number | boolean | null", {"object", "array", "string", "number", "boolean", "null"}}},
     {"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space", {"string", "value"}}},
     {"array", {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}},
-    {"uuid", {"\"\\\"\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
-                "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
-                "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
-                "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
-                "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] \"\\\"\" space", {}}},
-    {"char",   {"[^\"\\\\] | \"\\\\\" ([\"\\\\/bfnrt] | \"u\" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])", {}}},
+    {"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space", {}}},
+    {"char",   {"[^\"\\\\] | \"\\\\\" ([\"\\\\/bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}},
     {"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}},
     {"null", {"\"null\" space", {}}},
 };
 
 std::unordered_map<std::string, BuiltinRule> STRING_FORMAT_RULES = {
-    {"date", {"[0-9] [0-9] [0-9] [0-9] \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}},
-    {"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9] [0-9] [0-9] )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}},
+    {"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}},
+    {"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}},
     {"date-time", {"date \"T\" time", {"date", "time"}}},
     {"date-string", {"\"\\\"\" date \"\\\"\" space", {"date"}}},
     {"time-string", {"\"\\\"\" time \"\\\"\" space", {"time"}}},
@@ -411,8 +374,7 @@ private:
                         sub_is_literal ? "\"" + sub + "\"" : sub,
                         min_times,
                         max_times,
-                        "",
-                        sub_is_literal
+                        ""
                     );
                     seq.back().second = false;
                 } else {

+ 1 - 1
llama/json-schema-to-grammar.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *

File diff suppressed because it is too large
+ 393 - 163
llama/llama.cpp


+ 42 - 52
llama/llama.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *
@@ -107,9 +107,11 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_GPT2           = 7,
         LLAMA_VOCAB_PRE_TYPE_REFACT         = 8,
         LLAMA_VOCAB_PRE_TYPE_COMMAND_R      = 9,
-        LLAMA_VOCAB_PRE_TYPE_QWEN2          = 10,
-        LLAMA_VOCAB_PRE_TYPE_OLMO           = 11,
-        LLAMA_VOCAB_PRE_TYPE_DBRX           = 12,
+        LLAMA_VOCAB_PRE_TYPE_STABLELM2      = 10,
+        LLAMA_VOCAB_PRE_TYPE_QWEN2          = 11,
+        LLAMA_VOCAB_PRE_TYPE_OLMO           = 12,
+        LLAMA_VOCAB_PRE_TYPE_DBRX           = 13,
+        LLAMA_VOCAB_PRE_TYPE_SMAUG          = 14,
     };
 
     // note: these values should be synchronized with ggml_rope
@@ -121,7 +123,7 @@ extern "C" {
         LLAMA_ROPE_TYPE_GLM  =  4,
     };
 
-    enum llama_token_type {
+    enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file
         LLAMA_TOKEN_TYPE_UNDEFINED    = 0,
         LLAMA_TOKEN_TYPE_NORMAL       = 1,
         LLAMA_TOKEN_TYPE_UNKNOWN      = 2,
@@ -131,6 +133,20 @@ extern "C" {
         LLAMA_TOKEN_TYPE_BYTE         = 6,
     };
 
+    enum llama_token_attr {
+        LLAMA_TOKEN_ATTR_UNDEFINED    = 0,
+        LLAMA_TOKEN_ATTR_UNKNOWN      = 1 << 0,
+        LLAMA_TOKEN_ATTR_UNUSED       = 1 << 1,
+        LLAMA_TOKEN_ATTR_NORMAL       = 1 << 2,
+        LLAMA_TOKEN_ATTR_CONTROL      = 1 << 3,  // SPECIAL?
+        LLAMA_TOKEN_ATTR_USER_DEFINED = 1 << 4,
+        LLAMA_TOKEN_ATTR_BYTE         = 1 << 5,
+        LLAMA_TOKEN_ATTR_NORMALIZED   = 1 << 6,
+        LLAMA_TOKEN_ATTR_LSTRIP       = 1 << 7,
+        LLAMA_TOKEN_ATTR_RSTRIP       = 1 << 8,
+        LLAMA_TOKEN_ATTR_SINGLE_WORD  = 1 << 9,
+    };
+
     // model file types
     enum llama_ftype {
         LLAMA_FTYPE_ALL_F32              = 0,
@@ -289,6 +305,8 @@ extern "C" {
         bool check_tensors; // validate model tensor data
     };
 
+    // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
+    //       https://github.com/ggerganov/llama.cpp/pull/7544
     struct llama_context_params {
         uint32_t seed;              // RNG seed, -1 for random
         uint32_t n_ctx;             // text context, 0 = from model
@@ -315,14 +333,14 @@ extern "C" {
         ggml_backend_sched_eval_callback cb_eval;
         void * cb_eval_user_data;
 
-        enum ggml_type type_k; // data type for K cache
-        enum ggml_type type_v; // data type for V cache
+        enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
+        enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
 
         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
-        bool flash_attn;  // whether to use flash attention
+        bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
 
         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
@@ -373,6 +391,9 @@ extern "C" {
         // modifies a preceding LLAMA_GRETYPE_CHAR or
         // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
         LLAMA_GRETYPE_CHAR_ALT       = 6,
+
+        // any character (.)
+        LLAMA_GRETYPE_CHAR_ANY       = 7,
     };
 
     typedef struct llama_grammar_element {
@@ -446,8 +467,8 @@ extern "C" {
 
     LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
 
-    LLAMA_API enum llama_vocab_type   llama_vocab_type  (const struct llama_model   * model);
-    LLAMA_API enum llama_rope_type    llama_rope_type   (const struct llama_model   * model);
+    LLAMA_API enum llama_vocab_type   llama_vocab_type  (const struct llama_model * model);
+    LLAMA_API enum llama_rope_type    llama_rope_type   (const struct llama_model * model);
 
     LLAMA_API int32_t llama_n_vocab    (const struct llama_model * model);
     LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
@@ -784,6 +805,12 @@ extern "C" {
     // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
     LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
 
+    // Get the number of threads used for generation of a single token.
+    LLAMA_API uint32_t llama_n_threads(struct llama_context * ctx);
+
+    // Get the number of threads used for prompt and batch processing (multiple token).
+    LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx);
+
     // Set whether to use causal attention or not
     // If set to true, the model will only attend to the past tokens
     LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
@@ -837,11 +864,14 @@ extern "C" {
 
     LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
 
-    LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
+    LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token);
 
     // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
     LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);
 
+    // Identify if Token Id is a control token or a render-able token
+    LLAMA_API bool llama_token_is_control(const struct llama_model * model, llama_token token);
+
     // Special tokens
     LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
     LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
@@ -1055,49 +1085,9 @@ extern "C" {
                      llama_token   token);
 
     //
-    // Beam search
+    // Model split
     //
 
-    struct llama_beam_view {
-        const llama_token * tokens;
-
-        size_t n_tokens;
-        float  p;        // Cumulative beam probability (renormalized relative to all beams)
-        bool   eob;      // Callback should set this to true when a beam is at end-of-beam.
-    };
-
-    // Passed to beam_search_callback function.
-    // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
-    // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
-    // These pointers are valid only during the synchronous callback, so should not be saved.
-    struct llama_beams_state {
-        struct llama_beam_view * beam_views;
-
-        size_t n_beams;               // Number of elements in beam_views[].
-        size_t common_prefix_length;  // Current max length of prefix tokens shared by all beams.
-        bool   last_call;             // True iff this is the last callback invocation.
-    };
-
-    // Type of pointer to the beam_search_callback function.
-    // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
-    // passed back to beam_search_callback. This avoids having to use global variables in the callback.
-    typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state);
-
-    /// @details Deterministically returns entire sentence constructed by a beam search.
-    /// @param ctx Pointer to the llama_context.
-    /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
-    /// @param callback_data A pointer that is simply passed back to callback.
-    /// @param n_beams Number of beams to use.
-    /// @param n_past Number of tokens already evaluated.
-    /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
-    LLAMA_API void llama_beam_search(
-                   struct llama_context * ctx,
-        llama_beam_search_callback_fn_t   callback,
-                                   void * callback_data,
-                                 size_t   n_beams,
-                                int32_t   n_past,
-                                int32_t   n_predict);
-
     /// @details Build a split GGUF final path for this chunk.
     ///          llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf"
     //  Returns the split_path length.

+ 1 - 1
llama/llava.cpp

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *

+ 1 - 1
llama/llava.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *

+ 1 - 1
llama/log.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *

+ 18 - 9
llama/patches/02-default-pretokenizer.diff → llama/patches/02-llamacpp.diff

@@ -1,8 +1,8 @@
-diff --git a/llama.cpp b/llama.cpp
-index 40d2ec2c..74f3ee9c 100644
---- a/llama.cpp
-+++ b/llama.cpp
-@@ -4642,16 +4642,7 @@ static void llm_load_vocab(
+diff --git a/llama/llama.cpp b/llama/llama.cpp
+index 8b675ea9..bcc6ae75 100644
+--- a/llama/llama.cpp
++++ b/llama/llama.cpp
+@@ -4645,16 +4645,7 @@ static void llm_load_vocab(
  
          // for now, only BPE models have pre-tokenizers
          if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
@@ -15,12 +15,12 @@ index 40d2ec2c..74f3ee9c 100644
 -                LLAMA_LOG_WARN("%s: ************************************        \n", __func__);
 -                LLAMA_LOG_WARN("%s:                                             \n", __func__);
 -                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
--            } else if (
-+            if (
-                     tokenizer_pre == "default") {
+-            } else if (tokenizer_pre == "default") {
++            if (tokenizer_pre == "default") {
                  vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
              } else if (
-@@ -4703,7 +4694,8 @@ static void llm_load_vocab(
+                     tokenizer_pre == "llama3"   ||
+@@ -4706,7 +4697,8 @@ static void llm_load_vocab(
                  tokenizer_pre == "smaug-bpe") {
                  vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG;
              } else {
@@ -30,3 +30,12 @@ index 40d2ec2c..74f3ee9c 100644
              }
          } else {
              vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+@@ -7009,7 +7001,7 @@ static struct ggml_tensor * llm_build_kqv(
+         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
+         cb(kq, "kq", il);
+ 
+-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
++        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
+             // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
+             // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
+             ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

+ 3 - 3
llama/patches/03-metal.diff

@@ -1,7 +1,7 @@
-diff --git a/ggml-metal.m b/ggml-metal.m
+diff --git a/llama/ggml-metal-darwin_arm64.m b/llama/ggml-metal-darwin_arm64.m
 index 0207b787..b5e9884b 100644
---- a/ggml-metal.m
-+++ b/ggml-metal.m
+--- a/llama/ggml-metal-darwin_arm64.m
++++ b/llama/ggml-metal-darwin_arm64.m
 @@ -1396,27 +1396,23 @@ static enum ggml_status ggml_metal_graph_compute(
                          // to the matrix-vector kernel
                          int ne11_mm_min = 1;

+ 0 - 13
llama/patches/04-qwen2.diff

@@ -1,13 +0,0 @@
-diff --git a/llama.cpp b/llama.cpp
-index 40d2ec2c..f34eb79a 100644
---- a/llama.cpp
-+++ b/llama.cpp
-@@ -6943,7 +6943,7 @@ static struct ggml_tensor * llm_build_kqv(
-         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
-         cb(kq, "kq", il);
- 
--        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
-+        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
-             // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
-             // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
-             ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

+ 90 - 8
llama/sampling.cpp

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *
@@ -151,7 +151,7 @@ std::string llama_sampling_order_print(const llama_sampling_params & params) {
     std::string result = "CFG -> Penalties ";
     if (params.mirostat == 0) {
         for (auto sampler_type : params.samplers_sequence) {
-            const auto sampler_type_name = sampler_type_to_name_string(sampler_type);
+            const auto sampler_type_name = llama_sampling_type_to_str(sampler_type);
             if (!sampler_type_name.empty()) {
                 result += "-> " + sampler_type_name + " ";
             }
@@ -163,6 +163,87 @@ std::string llama_sampling_order_print(const llama_sampling_params & params) {
     return result;
 }
 
+std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
+    switch (sampler_type) {
+        case llama_sampler_type::TOP_K:       return "top_k";
+        case llama_sampler_type::TFS_Z:       return "tfs_z";
+        case llama_sampler_type::TYPICAL_P:   return "typical_p";
+        case llama_sampler_type::TOP_P:       return "top_p";
+        case llama_sampler_type::MIN_P:       return "min_p";
+        case llama_sampler_type::TEMPERATURE: return "temperature";
+        default : return "";
+    }
+}
+
+std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
+    std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
+        {"top_k",       llama_sampler_type::TOP_K},
+        {"top_p",       llama_sampler_type::TOP_P},
+        {"typical_p",   llama_sampler_type::TYPICAL_P},
+        {"min_p",       llama_sampler_type::MIN_P},
+        {"tfs_z",       llama_sampler_type::TFS_Z},
+        {"temperature", llama_sampler_type::TEMPERATURE}
+    };
+
+    // since samplers names are written multiple ways
+    // make it ready for both system names and input names
+    std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map {
+        {"top-k",       llama_sampler_type::TOP_K},
+        {"top-p",       llama_sampler_type::TOP_P},
+        {"nucleus",     llama_sampler_type::TOP_P},
+        {"typical-p",   llama_sampler_type::TYPICAL_P},
+        {"typical",     llama_sampler_type::TYPICAL_P},
+        {"min-p",       llama_sampler_type::MIN_P},
+        {"tfs-z",       llama_sampler_type::TFS_Z},
+        {"tfs",         llama_sampler_type::TFS_Z},
+        {"temp",        llama_sampler_type::TEMPERATURE}
+    };
+
+    std::vector<llama_sampler_type> sampler_types;
+    sampler_types.reserve(names.size());
+    for (const auto & name : names)
+    {
+        auto sampler_item = sampler_canonical_name_map.find(name);
+        if (sampler_item != sampler_canonical_name_map.end())
+        {
+            sampler_types.push_back(sampler_item->second);
+        }
+        else
+        {
+            if (allow_alt_names)
+            {
+                sampler_item = sampler_alt_name_map.find(name);
+                if (sampler_item != sampler_alt_name_map.end())
+                {
+                    sampler_types.push_back(sampler_item->second);
+                }
+            }
+        }
+    }
+    return sampler_types;
+}
+
+std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string) {
+    std::unordered_map<char, llama_sampler_type> sampler_name_map {
+        {'k', llama_sampler_type::TOP_K},
+        {'p', llama_sampler_type::TOP_P},
+        {'y', llama_sampler_type::TYPICAL_P},
+        {'m', llama_sampler_type::MIN_P},
+        {'f', llama_sampler_type::TFS_Z},
+        {'t', llama_sampler_type::TEMPERATURE}
+    };
+
+    std::vector<llama_sampler_type> sampler_types;
+    sampler_types.reserve(names_string.size());
+    for (const auto & c : names_string) {
+        const auto sampler_item = sampler_name_map.find(c);
+        if (sampler_item != sampler_name_map.end()) {
+            sampler_types.push_back(sampler_item->second);
+        }
+    }
+    return sampler_types;
+}
+
 // no reasons to expose this function in header
 static void sampler_queue(
                    struct llama_context * ctx_main,
@@ -205,7 +286,7 @@ static llama_token llama_sampling_sample_impl(
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
                   const int idx,
-                  bool is_resampling) {  // Add a parameter to indicate if we are resampling
+                  bool is_resampling) {
     const llama_sampling_params & params = ctx_sampling->params;
 
     const float   temp            = params.temp;
@@ -214,8 +295,8 @@ static llama_token llama_sampling_sample_impl(
     const float   mirostat_eta    = params.mirostat_eta;
 
     std::vector<float> original_logits;
-    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
-    if (!is_resampling) {
+    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
+    if (ctx_sampling->grammar != NULL && !is_resampling) {
         GGML_ASSERT(!original_logits.empty());
     }
     llama_token id = 0;
@@ -278,7 +359,7 @@ static llama_token llama_sampling_sample_impl(
             // Restore logits from the copy
             std::copy(original_logits.begin(), original_logits.end(), logits);
 
-            return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true);  // Pass true for is_resampling
+            return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true);
         }
     }
 
@@ -311,7 +392,8 @@ static llama_token_data_array llama_sampling_prepare_impl(
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
 
-    if (apply_grammar && original_logits != NULL) {
+    if (ctx_sampling->grammar != NULL && !apply_grammar) {
+        GGML_ASSERT(original_logits != NULL);
         // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
         *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
     }
@@ -368,7 +450,7 @@ llama_token llama_sampling_sample(
                   struct llama_context * ctx_cfg,
                   const int idx) {
     // Call the implementation function with is_resampling set to false by default
-    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
+    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false);
 }
 
 llama_token_data_array llama_sampling_prepare(

+ 6 - 1
llama/sampling.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *
@@ -142,6 +142,11 @@ std::string llama_sampling_print(const llama_sampling_params & params);
 // Print sampling order into a string
 std::string llama_sampling_order_print(const llama_sampling_params & params);
 
+std::string llama_sampling_type_to_str(llama_sampler_type sampler_type);
+
+std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
+std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string);
+
 // this is a common sampling function used across the examples for convenience
 // it can serve as a starting point for implementing your own sampling function
 // Note: When using multiple sequences, it is the caller's responsibility to call

+ 260 - 0
llama/sampling_ext.cpp

@@ -1,3 +1,263 @@
+/**
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
 // TODO: this is a temporary wrapper to allow calling C++ code from CGo
 #include "sampling.h"
 #include "sampling_ext.h"

+ 260 - 0
llama/sampling_ext.h

@@ -1,3 +1,263 @@
+/**
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
 // TODO: this is a temporary wrapper to allow calling C++ code from CGo
 #ifndef LLAMA_SAMPLING_EXT_H
 #define LLAMA_SAMPLING_EXT_H

+ 1 - 1
llama/stb_image.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *

+ 1 - 1
llama/unicode-data.cpp

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *

+ 1 - 1
llama/unicode-data.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *

+ 1 - 1
llama/unicode.cpp

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *

+ 1 - 1
llama/unicode.h

@@ -1,5 +1,5 @@
 /**
- * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842
+ * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28
  *
  * MIT License
  *

+ 5 - 3
scripts/sync_llama.sh

@@ -1,5 +1,7 @@
 #!/bin/bash
 
+set -e
+
 # Set the source directory
 src_dir=$1
 
@@ -8,7 +10,7 @@ if [ -z "$src_dir" ]; then
   exit 1
 fi
 
-# Set the destination directory (current directory)
+# Set the destination directory
 dst_dir=./llama
 
 # llama.cpp
@@ -72,7 +74,7 @@ char const *LLAMA_BUILD_TARGET = "";
 EOF
 
 # apply patches
-for patch in $dst_dir/patches/*.patch; do
+for patch in $dst_dir/patches/*.diff; do
   git apply "$patch"
 done
 
@@ -112,6 +114,6 @@ echo "_ggml_metallib_start:"              >> $TEMP_ASSEMBLY
 echo ".incbin \"temp.metal\"" >> $TEMP_ASSEMBLY
 echo ".globl _ggml_metallib_end"          >> $TEMP_ASSEMBLY
 echo "_ggml_metallib_end:"                >> $TEMP_ASSEMBLY
-as -mmacosx-version-min=11.3 $TEMP_ASSEMBLY -o ggml-metal.o
+as -mmacosx-version-min=11.3 $TEMP_ASSEMBLY -o $dst_dir/ggml-metal.o
 rm -f $TEMP_ASSEMBLY
 rm -rf temp.metal

Some files were not shown because too many files changed in this diff