Re-introduce the `llama` package (#5034)

* Re-introduce the llama package

This PR brings back the llama package, making it possible to call llama.cpp and
ggml APIs from Go directly via CGo. This has a few advantages:

- C APIs can be called directly from Go without needing to use the previous
  "server" REST API
- On macOS and for CPU builds on Linux and Windows, Ollama can be built without
  a `go generate ./...` step, making it easy to get up and running to hack on
  parts of Ollama that don't require fast inference
- Faster build times for AVX, AVX2, CUDA, and ROCm (a full build of all runners
  takes <5 min on a fast CPU)
- No git submodule, making it easier to clone and build from source

This is a big PR, but much of it is vendored code, except for:

- llama.go: CGo bindings
- example/: a simple example of running inference
- runner/: a subprocess server designed to replace the llm/ext_server package
- Makefile: a Makefile, kept as minimal as possible, to build the runner
  package for different targets (cpu, avx, avx2, cuda, rocm)

Co-authored-by: Jesse Gross <jesse@ollama.com>
Co-authored-by: Daniel Hiltgen <daniel@ollama.com>

* cache: Clear old KV cache entries when evicting a slot

When forking a cache entry, if no empty slots are available, we
evict the least recently used one and copy over the KV entries
from the closest match. However, this copy does not overwrite
existing values but only adds new ones. Therefore, we need to
clear the old slot first (see the sketch after the list below).

This change fixes two issues:
 - The KV cache fills up and runs out of space even though we think
   we are managing it correctly
 - Performance gets worse over time as we use new cache entries that
   are not hot in the processor caches
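
A minimal sketch of the fix, using a toy map-backed slot rather than the
real llama.cpp KV structures; all names here are illustrative:

```go
package main

import "fmt"

// slot is a toy stand-in for a KV cache slot: a map from token position
// to a cached value. The real cache holds llama.cpp KV entries.
type slot struct {
	kv map[int]string
}

// forkInto copies src's entries into the evicted slot dst. The copy only
// adds entries, so dst must be cleared first or stale values survive.
func forkInto(src, dst *slot) {
	dst.kv = map[int]string{} // clear the old slot before copying
	for pos, v := range src.kv {
		dst.kv[pos] = v
	}
}

func main() {
	evicted := &slot{kv: map[int]string{0: "stale", 7: "stale"}}
	closest := &slot{kv: map[int]string{0: "fresh"}}
	forkInto(closest, evicted)
	fmt.Println(evicted.kv) // map[0:fresh]; position 7 no longer holds a stale value
}
```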

* doc: explain golang objc linker warning (#6830)

* llama: gather transitive dependencies for rocm for dist packaging (#6848)

* Refine go server makefiles to be more DRY (#6924)

This breaks up the monolithic Makefile for the Go-based runners into a
set of utility files as well as recursive Makefiles for the runners.
Files starting with the name "Makefile" are buildable, while files that
end with ".make" are utilities to include in other Makefiles.  This
reduces the amount of nearly identical targets and helps set a pattern
for future community contributions for new GPU runner architectures.

When we are ready to switch over to the Go runners, these files should
move to the top of the repo, and we should add targets for the main CLI,
as well as helper "install" (put all the built binaries on the local
system in a runnable state) and "dist" (generate the various tar/zip
files for distribution) targets for local developer use.

* llama: don't create extraneous directories (#6988)

* llama: Exercise the new build in CI (#6989)

Wire up some basic sanity testing in CI for the Go runner.  GPU runners are not covered yet.

* llama: Refine developer docs for Go server (#6842)

This enhances the development documentation, focusing on the new Go
server. After we complete the transition, further doc refinements can
remove the "transition" discussion.

* runner.go: Allocate batches for all sequences during init

We should tell the model that we could have full batches for all
sequences. We already do this when we allocate the batches, but it was
missed during initialization.

* llama.go: Don't return nil from Tokenize on zero length input

Potentially receiving nil in a non-error condition is surprising to
most callers - it's better to return an empty slice.
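
A sketch of the guard, with a simplified signature (the real CGo binding
takes more parameters and can return an error):

```go
package main

import "fmt"

// tokenize sketches the zero-length guard; names and signature are
// illustrative, not the actual binding.
func tokenize(text string) []int {
	if len(text) == 0 {
		// Return an empty slice, not nil: callers can range over the
		// result and take its length without special-casing nil.
		return []int{}
	}
	return []int{15496, 995} // stand-in for real token IDs
}

func main() {
	tokens := tokenize("")
	fmt.Println(tokens == nil, len(tokens)) // false 0
}
```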

* runner.go: Remove stop tokens from cache

If the last token is EOG, then we don't return it and it isn't
present in the cache (because it was never submitted to Decode).
This works well for extending the cache entry with a new sequence.

However, for multi-token stop sequences, we won't return any of the
stop tokens, yet all but the last one will be in the cache. This means
that when the conversation continues, the cache will contain tokens
that don't overlap with the new prompt.

This works (we will pick up the portion where there is overlap) but
it causes unnecessary cache thrashing because we will fork the original
cache entry as it is not a perfect match.

By trimming the cache to the tokens that we actually return this
issue can be avoided.
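
The trimming itself is simple; a toy version, under the assumption that
a cache entry is just a token slice:

```go
package main

import "fmt"

// trimToReturned drops trailing tokens from a cache entry so it only
// holds tokens that were actually returned to the client. Trailing
// stop-sequence tokens were decoded but never returned.
func trimToReturned(cached []int, returned int) []int {
	if returned < len(cached) {
		return cached[:returned]
	}
	return cached
}

func main() {
	cached := []int{10, 11, 12, 98, 99}    // 98, 99: first tokens of a stop sequence
	fmt.Println(trimToReturned(cached, 3)) // [10 11 12]
}
```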

* runner.go: Simplify flushing of pending tokens

* runner.go: Update TODOs

* runner.go: Don't panic when processing sequences

If there is an error processing a sequence, we should return a
clean HTTP error back to Ollama rather than panicking. This will
make us more resilient to transient failures.

Panics can still occur during startup as there is no way to serve
requests if that fails.
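
The pattern, sketched with a stand-in for the real decode loop:

```go
package main

import (
	"errors"
	"log"
	"net/http"
)

// processSequence stands in for the real decode loop; assume it can fail
// transiently (for example, a failed allocation) rather than panicking.
func processSequence(r *http.Request) error {
	return errors.New("transient decode failure")
}

// completionHandler returns a clean HTTP 500 to the Ollama server on
// failure, leaving the runner alive for subsequent requests.
func completionHandler(w http.ResponseWriter, r *http.Request) {
	if err := processSequence(r); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
}

func main() {
	http.HandleFunc("/completion", completionHandler)
	// Startup errors may still be fatal: there is nothing to serve if this fails.
	log.Fatal(http.ListenAndServe("127.0.0.1:8080", nil))
}
```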

Co-authored-by: jmorganca <jmorganca@gmail.com>

* runner.go: More accurately capture timings

Currently, prompt processing time doesn't capture the time it takes
to tokenize the input, only the decoding time. We should capture the
full process to more accurately reflect reality. This is especially
true once we start processing images where the initial processing
can take significant time. This is also more consistent with the
existing C++ runner.
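
In outline, the change is just where the clock starts; the functions
below are stand-ins:

```go
package main

import (
	"fmt"
	"time"
)

func tokenize(prompt string) []int { return []int{1, 2, 3} }             // stand-in
func decode(tokens []int)          { time.Sleep(10 * time.Millisecond) } // stand-in

func main() {
	// Start timing before tokenization so prompt processing time covers
	// the full pipeline (tokenize + decode), not just llama.cpp decoding.
	start := time.Now()
	tokens := tokenize("Why is the sky blue?")
	decode(tokens)
	fmt.Println("prompt processing:", time.Since(start))
}
```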

* runner.go: Support for vision models

In addition to bringing feature parity with the C++ runner, this also
incorporates several improvements:
 - Prompt caching works with images, avoiding the need to re-decode
   embeddings for every message in a conversation
 - Parallelism is supported, avoiding the need to restrict to one
   sequence at a time. (Though for now, Ollama will not schedule
   parallel sequences while we might still need to fall back to the
   old runner.)

Co-authored-by: jmorganca <jmorganca@gmail.com>

* runner.go: Move Unicode checking code and add tests

* runner.go: Export external cache members

Runner and cache are in the same package, so the change doesn't
affect anything, but it is more internally consistent.

* runner.go: Image embedding cache

Generating embeddings from images can take significant time (on
my machine between 100ms and 8s depending on the model). Although
we already cache the result of decoding these images, the embeddings
need to be regenerated every time. This is not necessary if we get
the same image over and over again, for example, during a conversation.

This currently uses a very small cache with a very simple algorithm,
but it is easy to improve as warranted.
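
A sketch of such a small cache, keyed by image hash with round-robin
replacement; the capacity and names are illustrative:

```go
package main

import (
	"crypto/sha256"
	"fmt"
)

// imageCache is a tiny fixed-capacity cache from image bytes (hashed)
// to the embeddings generated for that image.
type imageCache struct {
	keys [4][sha256.Size]byte
	vals [4][][]float32
	next int // round-robin replacement slot
}

func (c *imageCache) get(image []byte) ([][]float32, bool) {
	key := sha256.Sum256(image)
	for i := range c.keys {
		if c.keys[i] == key && c.vals[i] != nil {
			return c.vals[i], true
		}
	}
	return nil, false
}

func (c *imageCache) put(image []byte, embed [][]float32) {
	c.keys[c.next] = sha256.Sum256(image)
	c.vals[c.next] = embed
	c.next = (c.next + 1) % len(c.keys)
}

func main() {
	var c imageCache
	img := []byte("png bytes go here")
	if _, ok := c.get(img); !ok {
		c.put(img, [][]float32{{0.1, 0.2}}) // stand-in embedding
	}
	_, hit := c.get(img)
	fmt.Println("cache hit:", hit) // cache hit: true
}
```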

* llama: catch up on patches

Carry forward solar-pro and cli-unicode patches

* runner.go: Don't re-allocate memory for every batch

We can reuse memory allocated from batch to batch since batch
size is fixed. This both saves the cost of reallocation and
keeps the cache lines hot.

This results in a roughly 1% performance improvement for token
generation with Nvidia GPUs on Linux.
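
An illustrative sketch of the reuse pattern: reslice fixed-capacity
buffers rather than allocating fresh ones each decode call.

```go
package main

// batch holds fixed-capacity buffers that are resliced between decode
// calls instead of being reallocated each time.
type batch struct {
	tokens []int
	pos    []int
}

func newBatch(capacity int) *batch {
	return &batch{
		tokens: make([]int, 0, capacity),
		pos:    make([]int, 0, capacity),
	}
}

// reset empties the batch while keeping the backing arrays (and their
// cache lines) alive for the next iteration.
func (b *batch) reset() {
	b.tokens = b.tokens[:0]
	b.pos = b.pos[:0]
}

func (b *batch) add(token, pos int) {
	b.tokens = append(b.tokens, token)
	b.pos = append(b.pos, pos)
}

func main() {
	b := newBatch(512) // batch size is fixed for the runner's lifetime
	for step := 0; step < 3; step++ {
		b.reset() // no allocation from here on
		b.add(42, step)
	}
}
```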

* runner.go: Default to classic input cache policy

The input cache in the Go runner implements a cache policy that aims
to maximize hit rate in both single- and multi-user scenarios. When
there is a cache hit, the response is very fast.

However, performance is actually slower when there is an input
cache miss due to worse GPU VRAM locality. This means that
performance is generally better overall for multi-user scenarios
(better input cache hit rate; locality was relatively poor already),
but worse for single users (input cache hit rate is about the same;
locality is now worse).

This defaults the policy back to the old one to avoid a regression
but keeps the new one available through an environment variable
OLLAMA_MULTIUSER_CACHE. This is left undocumented as the goal is
to improve this in the future to get the best of both worlds
without user configuration.

For inputs that result in cache misses, on Nvidia/Linux this
change improves performance by 31% for prompt processing and
13% for token generation.
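
The switch is read through the new `envconfig` option shown in the diff
below; a simplified sketch of the selection:

```go
package main

import (
	"fmt"
	"os"
)

// multiUserCache approximates envconfig.MultiUserCache: the real helper
// parses several truthy spellings, this sketch only checks for non-empty.
func multiUserCache() bool {
	return os.Getenv("OLLAMA_MULTIUSER_CACHE") != ""
}

func main() {
	if multiUserCache() {
		fmt.Println("slot selection: maximize prefix hit rate across users")
	} else {
		fmt.Println("slot selection: classic policy, better VRAM locality")
	}
}
```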

* runner.go: Increase size of response channel

Generally the CPU can easily keep up with handling responses as they
are generated, but there's no reason not to let generation run ahead
and handle things in larger batches if needed.
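
Conceptually this is just a larger channel buffer between the generation
loop and the response writer; the size below is illustrative:

```go
package main

import "fmt"

type response struct{ token string }

func main() {
	// A generously buffered channel lets generation run ahead of response
	// handling; the consumer drains in larger batches when it falls behind.
	responses := make(chan response, 100)
	go func() {
		defer close(responses)
		for _, t := range []string{"Hello", ",", " world"} {
			responses <- response{token: t} // rarely blocks with a large buffer
		}
	}()
	for r := range responses {
		fmt.Print(r.token)
	}
	fmt.Println()
}
```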

* llama: Add CI to verify all vendored changes have patches (#7066)

Make sure we don't accidentally merge changes in the vendored code
that aren't also reflected in the patches.

* llama: adjust clip patch for mingw utf-16 (#7065)

* llama: adjust clip patch for mingw utf-16

* llama: ensure static linking of runtime libs

Avoid runtime dependencies on non-standard libraries

* runner.go: Enable llamafile (all platforms) and BLAS (macOS)

These are two features shown in llama.cpp's system info output that
currently differ between the two runners. On my test systems the
performance difference is very small to negligible, but it is
probably still good to equalize the features.

* llm: Don't add BOS/EOS for tokenize requests

This is consistent with what server.cpp currently does. It affects
things like token processing counts for embedding requests.
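
A sketch of the distinction, with an illustrative flag name
(`addSpecial`) modeled on llama.cpp's tokenizer option; this is not the
actual binding signature:

```go
package main

import "fmt"

// tokenize mimics a binding over llama.cpp tokenization; addSpecial
// controls whether special tokens such as BOS are inserted.
func tokenize(text string, addSpecial bool) []int {
	tokens := []int{15496, 995} // stand-in token IDs for the text
	if addSpecial {
		tokens = append([]int{1}, tokens...) // 1 standing in for a BOS ID
	}
	return tokens
}

func main() {
	// Tokenize API requests pass addSpecial=false, so counts reflect only
	// the text itself; this matters for embedding token accounting.
	fmt.Println(len(tokenize("Hello world", false))) // 2
	fmt.Println(len(tokenize("Hello world", true)))  // 3
}
```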

* runner.go: Don't cache prompts for embeddings

Our integration with server.cpp implicitly disables prompt caching
because it is not part of the JSON object being parsed; this makes
the Go runner behave similarly.

Prompt caching has been seen to affect the results of text completions
on certain hardware. The results are not wrong either way but they
are non-deterministic. However, embeddings seem to be affected even
on hardware that does not show this behavior for completions. For
now, it is best to maintain consistency with the existing behavior.
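
The behavior reduces to a single condition; a sketch with illustrative
field names (the real runner's request type differs):

```go
package main

import "fmt"

type request struct {
	CachePrompt bool
	Embedding   bool
}

// effectiveCachePrompt mirrors the described behavior: embedding requests
// never use the prompt cache, whatever the request asks for.
func effectiveCachePrompt(req request) bool {
	return req.CachePrompt && !req.Embedding
}

func main() {
	fmt.Println(effectiveCachePrompt(request{CachePrompt: true, Embedding: true}))  // false
	fmt.Println(effectiveCachePrompt(request{CachePrompt: true, Embedding: false})) // true
}
```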

* runner.go: Adjust debug log levels

Add system info printed at startup and quiet down noisier logging.
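
A sketch of the pattern using Go's log/slog, keyed off the existing
OLLAMA_DEBUG variable; the printed fields are illustrative:

```go
package main

import (
	"log/slog"
	"os"
)

func main() {
	// Raise verbosity only when OLLAMA_DEBUG is set; per-token details
	// stay at debug level so the default output is quiet.
	level := slog.LevelInfo
	if os.Getenv("OLLAMA_DEBUG") != "" {
		level = slog.LevelDebug
	}
	handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})
	slog.SetDefault(slog.New(handler))

	slog.Info("system info", "avx", true, "threads", 8) // once at startup
	slog.Debug("per-token detail")                      // hidden by default
}
```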

* llama: fix compiler flag differences (#7082)

Adjust the flags for the new Go server to more closely match the
generate flow

* llama: refine developer docs (#7121)

* llama: doc and example clean up (#7122)

* llama: doc and example clean up

* llama: Move new dockerfile into llama dir

Temporary home until we fully transition to the Go server

* llama: runner doc cleanup

* llama.go: Add description for Tokenize error case

---------

Co-authored-by: Jesse Gross <jesse@ollama.com>
Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
Co-authored-by: Daniel Hiltgen <dhiltgen@users.noreply.github.com>
Jeffrey Morgan committed 6 months ago
commit 96efd9052f
100 changed files with 33,610 additions and 14 deletions
  1. .gitattributes (+2 -0)
  2. .github/workflows/test.yaml (+54 -0)
  3. .gitignore (+0 -1)
  4. Dockerfile (+2 -10)
  5. docs/development.md (+182 -1)
  6. envconfig/config.go (+3 -0)
  7. integration/concurrency_test.go (+1 -1)
  8. integration/utils_test.go (+1 -1)
  9. llama/.gitignore (+3 -0)
  10. llama/Dockerfile (+221 -0)
  11. llama/Makefile (+54 -0)
  12. llama/README.md (+100 -0)
  13. llama/base64.hpp (+392 -0)
  14. llama/build-info.cpp (+30 -0)
  15. llama/clip.cpp (+2689 -0)
  16. llama/clip.h (+120 -0)
  17. llama/common.cpp (+3688 -0)
  18. llama/common.h (+514 -0)
  19. llama/ggml-aarch64.c (+2206 -0)
  20. llama/ggml-aarch64.h (+65 -0)
  21. llama/ggml-alloc.c (+1062 -0)
  22. llama/ggml-alloc.h (+102 -0)
  23. llama/ggml-backend-impl.h (+179 -0)
  24. llama/ggml-backend.c (+2288 -0)
  25. llama/ggml-backend.h (+266 -0)
  26. llama/ggml-blas.cpp (+397 -0)
  27. llama/ggml-blas.h (+49 -0)
  28. llama/ggml-common.h (+1859 -0)
  29. llama/ggml-cuda.cu (+3147 -0)
  30. llama/ggml-cuda.h (+75 -0)
  31. llama/ggml-cuda/acc.cu (+73 -0)
  32. llama/ggml-cuda/acc.cuh (+31 -0)
  33. llama/ggml-cuda/arange.cu (+60 -0)
  34. llama/ggml-cuda/arange.cuh (+31 -0)
  35. llama/ggml-cuda/argsort.cu (+130 -0)
  36. llama/ggml-cuda/argsort.cuh (+29 -0)
  37. llama/ggml-cuda/binbcast.cu (+314 -0)
  38. llama/ggml-cuda/binbcast.cuh (+33 -0)
  39. llama/ggml-cuda/clamp.cu (+60 -0)
  40. llama/ggml-cuda/clamp.cuh (+31 -0)
  41. llama/ggml-cuda/common.cuh (+695 -0)
  42. llama/ggml-cuda/concat.cu (+222 -0)
  43. llama/ggml-cuda/concat.cuh (+31 -0)
  44. llama/ggml-cuda/conv-transpose-1d.cu (+113 -0)
  45. llama/ggml-cuda/conv-transpose-1d.cuh (+31 -0)
  46. llama/ggml-cuda/convert.cu (+712 -0)
  47. llama/ggml-cuda/convert.cuh (+39 -0)
  48. llama/ggml-cuda/cpy.cu (+515 -0)
  49. llama/ggml-cuda/cpy.cuh (+35 -0)
  50. llama/ggml-cuda/cross-entropy-loss.cu (+132 -0)
  51. llama/ggml-cuda/cross-entropy-loss.cuh (+31 -0)
  52. llama/ggml-cuda/dequantize.cuh (+129 -0)
  53. llama/ggml-cuda/diagmask.cu (+66 -0)
  54. llama/ggml-cuda/diagmask.cuh (+31 -0)
  55. llama/ggml-cuda/dmmv.cu (+709 -0)
  56. llama/ggml-cuda/dmmv.cuh (+46 -0)
  57. llama/ggml-cuda/fattn-common.cuh (+734 -0)
  58. llama/ggml-cuda/fattn-tile-f16.cu (+379 -0)
  59. llama/ggml-cuda/fattn-tile-f16.cuh (+29 -0)
  60. llama/ggml-cuda/fattn-tile-f32.cu (+371 -0)
  61. llama/ggml-cuda/fattn-tile-f32.cuh (+29 -0)
  62. llama/ggml-cuda/fattn-vec-f16.cuh (+468 -0)
  63. llama/ggml-cuda/fattn-vec-f32.cuh (+446 -0)
  64. llama/ggml-cuda/fattn-wmma-f16.cuh (+569 -0)
  65. llama/ggml-cuda/fattn.cu (+371 -0)
  66. llama/ggml-cuda/fattn.cuh (+29 -0)
  67. llama/ggml-cuda/getrows.cu (+203 -0)
  68. llama/ggml-cuda/getrows.cuh (+31 -0)
  69. llama/ggml-cuda/im2col.cu (+130 -0)
  70. llama/ggml-cuda/im2col.cuh (+31 -0)
  71. llama/ggml-cuda/mma.cuh (+247 -0)
  72. llama/ggml-cuda/mmq.cu (+176 -0)
  73. llama/ggml-cuda/mmq.cuh (+2962 -0)
  74. llama/ggml-cuda/mmvq.cu (+451 -0)
  75. llama/ggml-cuda/mmvq.cuh (+35 -0)
  76. llama/ggml-cuda/norm.cu (+250 -0)
  77. llama/ggml-cuda/norm.cuh (+33 -0)
  78. llama/ggml-cuda/pad.cu (+75 -0)
  79. llama/ggml-cuda/pad.cuh (+31 -0)
  80. llama/ggml-cuda/pool2d.cu (+120 -0)
  81. llama/ggml-cuda/pool2d.cuh (+31 -0)
  82. llama/ggml-cuda/quantize.cu (+195 -0)
  83. llama/ggml-cuda/quantize.cuh (+50 -0)
  84. llama/ggml-cuda/rope.cu (+297 -0)
  85. llama/ggml-cuda/rope.cuh (+31 -0)
  86. llama/ggml-cuda/scale.cu (+57 -0)
  87. llama/ggml-cuda/scale.cuh (+31 -0)
  88. llama/ggml-cuda/softmax.cu (+232 -0)
  89. llama/ggml-cuda/softmax.cuh (+31 -0)
  90. llama/ggml-cuda/sumrows.cu (+65 -0)
  91. llama/ggml-cuda/sumrows.cuh (+31 -0)
  92. llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu (+31 -0)
  93. llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu (+31 -0)
  94. llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu (+31 -0)
  95. llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu (+31 -0)
  96. llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu (+31 -0)
  97. llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu (+31 -0)
  98. llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu (+31 -0)
  99. llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu (+31 -0)
  100. llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu (+31 -0)

.gitattributes (+2 -0)

@@ -1,3 +1,5 @@
 llm/ext_server/* linguist-vendored
+llama/**/*.{cpp,hpp,h,c,cu,cuh,m} linguist-vendored
+
 * text=auto
 *.go text eol=lf

.github/workflows/test.yaml (+54 -0)

@@ -24,6 +24,7 @@ jobs:
       GENERATE: ${{ steps.changes.outputs.GENERATE }}
       GENERATE_CUDA: ${{ steps.changes.outputs.GENERATE_CUDA }}
       GENERATE_ROCM: ${{ steps.changes.outputs.GENERATE_ROCM }}
+      RUNNERS: ${{ steps.changes.outputs.RUNNERS }}
     steps:
       - uses: actions/checkout@v4
         with:
@@ -41,6 +42,7 @@ jobs:
             echo GENERATE=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**')
             echo GENERATE_CUDA=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**')
             echo GENERATE_ROCM=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**')
+            echo RUNNERS=$(changed 'llama/**')
           } >>$GITHUB_OUTPUT
 
   generate:
@@ -213,6 +215,46 @@ jobs:
         env:
           OLLAMA_SKIP_CPU_GENERATE: '1'
 
+  runners:
+    needs: [changes]
+    if: ${{ needs.changes.outputs.RUNNERS == 'True' }}
+    strategy:
+      matrix:
+        os: [ubuntu-latest, macos-latest, windows-2019]
+        arch: [amd64, arm64]
+        exclude:
+          - os: ubuntu-latest
+            arch: arm64
+          - os: windows-2019
+            arch: arm64
+    runs-on: ${{ matrix.os }}
+    env:
+      GOARCH: ${{ matrix.arch }}
+      ARCH: ${{ matrix.arch }}
+      CGO_ENABLED: '1'
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/setup-go@v5
+        with:
+          go-version-file: go.mod
+          cache: true
+      - run: go get ./...
+      - name: 'Build Windows Go Runners'
+        if: ${{ startsWith(matrix.os, 'windows-') }}
+        run: |
+          $gopath=(get-command go).source | split-path -parent
+          $gccpath=(get-command gcc).source | split-path -parent
+          & "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1"
+          cd $env:GITHUB_WORKSPACE
+          $env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
+          $env:PATH="$gopath;$gccpath;$env:PATH"
+          echo $env:PATH
+          make -C llama -j 4      
+      - name: 'Build Unix Go Runners'
+        if: ${{ ! startsWith(matrix.os, 'windows-') }}
+        run: make -C llama -j 4
+      - run: go build .
+
   lint:
     strategy:
       matrix:
@@ -280,3 +322,15 @@ jobs:
       - run: go generate ./...
       - run: go build
       - run: go test -v ./...
+
+  patches:
+    needs: [changes]
+    if: ${{ needs.changes.outputs.RUNNERS == 'True' }}
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          submodules: recursive
+      - name: Verify patches carry all the changes
+        run: |
+          cd llama && ./sync.sh && git diff --compact-summary --exit-code .

.gitignore (+0 -1)

@@ -5,7 +5,6 @@
 .swp
 dist
 ollama
-ggml-metal.metal
 .cache
 *.exe
 .idea

Dockerfile (+2 -10)

@@ -110,9 +110,6 @@ ARG CGO_CFLAGS
 ENV GOARCH=amd64
 WORKDIR /go/src/github.com/ollama/ollama/llm/generate
 
-FROM --platform=linux/amd64 cpu-builder-amd64 AS static-build-amd64
-RUN --mount=type=cache,target=/root/.ccache \
-    OLLAMA_CPU_TARGET="static" bash gen_linux.sh
 FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu-build-amd64
 RUN --mount=type=cache,target=/root/.ccache \
     OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu" bash gen_linux.sh
@@ -135,9 +132,6 @@ ARG CGO_CFLAGS
 ENV GOARCH=arm64
 WORKDIR /go/src/github.com/ollama/ollama/llm/generate
 
-FROM --platform=linux/arm64 cpu-builder-arm64 AS static-build-arm64
-RUN --mount=type=cache,target=/root/.ccache \
-    OLLAMA_CPU_TARGET="static" bash gen_linux.sh
 FROM --platform=linux/arm64 cpu-builder-arm64 AS cpu-build-arm64
 RUN --mount=type=cache,target=/root/.ccache \
     OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu" bash gen_linux.sh
@@ -148,7 +142,6 @@ FROM --platform=linux/amd64 cpu-build-amd64 AS build-amd64
 ENV CGO_ENABLED=1
 WORKDIR /go/src/github.com/ollama/ollama
 COPY . .
-COPY --from=static-build-amd64 /go/src/github.com/ollama/ollama/llm/build/ llm/build/
 COPY --from=cpu_avx-build-amd64 /go/src/github.com/ollama/ollama/build/ build/
 COPY --from=cpu_avx2-build-amd64 /go/src/github.com/ollama/ollama/build/ build/
 COPY --from=cuda-11-build-amd64 /go/src/github.com/ollama/ollama/dist/ dist/
@@ -171,7 +164,6 @@ ENV CGO_ENABLED=1
 ARG GOLANG_VERSION
 WORKDIR /go/src/github.com/ollama/ollama
 COPY . .
-COPY --from=static-build-arm64 /go/src/github.com/ollama/ollama/llm/build/ llm/build/
 COPY --from=cuda-11-build-runner-arm64 /go/src/github.com/ollama/ollama/dist/ dist/
 COPY --from=cuda-11-build-runner-arm64 /go/src/github.com/ollama/ollama/build/ build/
 COPY --from=cuda-12-build-runner-arm64 /go/src/github.com/ollama/ollama/dist/ dist/
@@ -191,7 +183,7 @@ FROM dist-$TARGETARCH as dist
 
 
 # Optimized container images do not carry nested payloads
-FROM --platform=linux/amd64 static-build-amd64 AS container-build-amd64
+FROM --platform=linux/amd64 cpu-builder-amd64 AS container-build-amd64
 WORKDIR /go/src/github.com/ollama/ollama
 COPY . .
 ARG GOFLAGS
@@ -199,7 +191,7 @@ ARG CGO_CFLAGS
 RUN --mount=type=cache,target=/root/.ccache \
     go build -trimpath -o dist/linux-amd64/bin/ollama .
 
-FROM --platform=linux/arm64 static-build-arm64 AS container-build-arm64
+FROM --platform=linux/arm64 cpu-builder-arm64 AS container-build-arm64
 WORKDIR /go/src/github.com/ollama/ollama
 COPY . .
 ARG GOFLAGS

docs/development.md (+182 -1)

@@ -1,5 +1,8 @@
 # Development
 
+> [!IMPORTANT]
+> The `llm` package that loads and runs models is being updated to use a new [Go runner](#transition-to-go-runner): this should only impact a small set of PRs; however, it does change how the project is built.
+
 Install required tools:
 
 - cmake version 3.24 or higher
@@ -166,4 +169,182 @@ Follow the instructions at https://www.msys2.org/wiki/arm64/ to set up an arm64
 pacman -S mingw-w64-clang-aarch64-clang mingw-w64-clang-aarch64-gcc-compat mingw-w64-clang-aarch64-make make
 ```
 
-You will need to ensure your PATH includes go, cmake, gcc and clang mingw32-make to build ollama from source. (typically `C:\msys64\clangarm64\bin\`)
+You will need to ensure your PATH includes go, cmake, gcc and clang mingw32-make to build ollama from source. (typically `C:\msys64\clangarm64\bin\`)
+
+
+## Transition to Go runner
+
+The Ollama team is working on moving to a new Go-based runner that loads and runs models in a subprocess to replace the previous code under `ext_server`. During this transition period, this new Go runner is "opt in" at build time, and requires using a different approach to build.
+
+After the transition to use the Go server exclusively, both `make` and `go generate` will build the Go runner.
+
+Install required tools:
+
+- go version 1.22 or higher
+- gcc version 11.4.0 or higher
+
+
+### macOS
+
+[Download Go](https://go.dev/dl/)
+
+Optionally enable debugging and more verbose logging:
+
+```bash
+# At build time
+export CGO_CFLAGS="-g"
+
+# At runtime
+export OLLAMA_DEBUG=1
+```
+
+Get the required libraries and build the native LLM code:  (Adjust the job count based on your number of processors for a faster build)
+
+```bash
+make -C llama -j 5
+```
+
+Then build ollama:
+
+```bash
+go build .
+```
+
+Now you can run `ollama`:
+
+```bash
+./ollama
+```
+
+#### Xcode 15 warnings
+
+If you are using Xcode newer than version 14, you may see a warning during `go build` about `ld: warning: ignoring duplicate libraries: '-lobjc'` due to Golang issue https://github.com/golang/go/issues/67799 which can be safely ignored.  You can suppress the warning with `export CGO_LDFLAGS="-Wl,-no_warn_duplicate_libraries"`
+
+### Linux
+
+#### Linux CUDA (NVIDIA)
+
+_Your operating system distribution may already have packages for NVIDIA CUDA. Distro packages are often preferable, but instructions are distro-specific. Please consult distro-specific docs for dependencies if available!_
+
+Install `make`, `gcc` and `golang` as well as [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads)
+development and runtime packages.
+
+Typically the build scripts will auto-detect CUDA, however, if your Linux distro
+or installation approach uses unusual paths, you can specify the location by
+specifying an environment variable `CUDA_LIB_DIR` to the location of the shared
+libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize
+a set of target CUDA architectures by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70")
+
+Then generate dependencies:  (Adjust the job count based on your number of processors for a faster build)
+
+```
+make -C llama -j 5
+```
+
+Then build the binary:
+
+```
+go build .
+```
+
+#### Linux ROCm (AMD)
+
+_Your operating system distribution may already have packages for AMD ROCm and CLBlast. Distro packages are often preferable, but instructions are distro-specific. Please consult distro-specific docs for dependencies if available!_
+
+Install [CLBlast](https://github.com/CNugteren/CLBlast/blob/master/doc/installation.md) and [ROCm](https://rocm.docs.amd.com/en/latest/) development packages first, as well as `make`, `gcc`, and `golang`.
+
+Typically the build scripts will auto-detect ROCm, however, if your Linux distro
+or installation approach uses unusual paths, you can specify the location by
+specifying an environment variable `ROCM_PATH` to the location of the ROCm
+install (typically `/opt/rocm`), and `CLBlast_DIR` to the location of the
+CLBlast install (typically `/usr/lib/cmake/CLBlast`). You can also customize
+the AMD GPU targets by setting AMDGPU_TARGETS (e.g. `AMDGPU_TARGETS="gfx1101;gfx1102"`)
+
+Then generate dependencies:  (Adjust the job count based on your number of processors for a faster build)
+
+```
+make -C llama -j 5
+```
+
+Then build the binary:
+
+```
+go build .
+```
+
+ROCm requires elevated privileges to access the GPU at runtime. On most distros you can add your user account to the `render` group, or run as root.
+
+#### Advanced CPU Settings
+
+By default, running `make` will compile a few different variations
+of the LLM library based on common CPU families and vector math capabilities,
+including a lowest-common-denominator which should run on almost any 64 bit CPU
+somewhat slowly. At runtime, Ollama will auto-detect the optimal variation to
+load. 
+
+Custom CPU settings are not currently supported in the new Go server build but will be added back after we complete the transition.
+
+#### Containerized Linux Build
+
+If you have Docker available, you can build linux binaries with `OLLAMA_NEW_RUNNERS=1 ./scripts/build_linux.sh` which has the CUDA and ROCm dependencies included. The resulting binary is placed in `./dist`
+
+### Windows
+
+The following tools are required as a minimal development environment to build CPU inference support.
+
+- Go version 1.22 or higher
+  - https://go.dev/dl/
+- Git
+  - https://git-scm.com/download/win
+- GCC and Make.  There are multiple options on how to go about installing these tools on Windows.  We have verified the following, but others may work as well:  
+  - [MSYS2](https://www.msys2.org/)
+    - After installing, from an MSYS2 terminal, run `pacman -S mingw-w64-ucrt-x86_64-gcc make` to install the required tools
+  - Assuming you used the default install prefix for msys2 above, add `c:\msys64\ucrt64\bin` and `c:\msys64\usr\bin` to your environment variable `PATH` where you will perform the build steps below (e.g. system-wide, account-level, powershell, cmd, etc.)
+
+Then, build the `ollama` binary:
+
+```powershell
+$env:CGO_ENABLED="1"
+make -C llama -j 8
+go build .
+```
+
+#### GPU Support
+
+The GPU tools require the Microsoft native build tools.  To build either CUDA or ROCm, you must first install MSVC via Visual Studio:
+
+- Make sure to select `Desktop development with C++` as a Workload during the Visual Studio install
+- You must complete the Visual Studio install and run it once **BEFORE** installing CUDA or ROCm for the tools to properly register
+- Add the location of the **64 bit (x64)** compiler (`cl.exe`) to your `PATH`
+- Note: the default Developer Shell may configure the 32 bit (x86) compiler which will lead to build failures.  Ollama requires a 64 bit toolchain.
+
+#### Windows CUDA (NVIDIA)
+
+In addition to the common Windows development tools and MSVC described above:
+
+- [NVIDIA CUDA](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html)
+
+#### Windows ROCm (AMD Radeon)
+
+In addition to the common Windows development tools and MSVC described above:
+
+- [AMD HIP](https://www.amd.com/en/developer/resources/rocm-hub/hip-sdk.html)
+
+#### Windows arm64
+
+The default `Developer PowerShell for VS 2022` may default to x86 which is not what you want.  To ensure you get an arm64 development environment, start a plain PowerShell terminal and run:
+
+```powershell
+import-module 'C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\Microsoft.VisualStudio.DevShell.dll'
+Enter-VsDevShell -Arch arm64 -vsinstallpath 'C:\\Program Files\\Microsoft Visual Studio\\2022\\Community' -skipautomaticlocation
+```
+
+You can confirm with `write-host $env:VSCMD_ARG_TGT_ARCH`
+
+Follow the instructions at https://www.msys2.org/wiki/arm64/ to set up an arm64 msys2 environment.  Ollama requires gcc and mingw32-make to compile, which is not currently available on Windows arm64, but a gcc compatibility adapter is available via `mingw-w64-clang-aarch64-gcc-compat`. At a minimum you will need to install the following:
+
+```
+pacman -S mingw-w64-clang-aarch64-clang mingw-w64-clang-aarch64-gcc-compat mingw-w64-clang-aarch64-make make
+```
+
+You will need to ensure your PATH includes go, cmake, gcc and clang mingw32-make to build ollama from source. (typically `C:\msys64\clangarm64\bin\`)

envconfig/config.go (+3 -0)

@@ -160,6 +160,8 @@ var (
 	SchedSpread = Bool("OLLAMA_SCHED_SPREAD")
 	// IntelGPU enables experimental Intel GPU detection.
 	IntelGPU = Bool("OLLAMA_INTEL_GPU")
+	// MultiUserCache optimizes prompt caching for multi-user scenarios
+	MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
 )
 
 func String(s string) func() string {
@@ -245,6 +247,7 @@ func AsMap() map[string]EnvVar {
 		"OLLAMA_ORIGINS":           {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
 		"OLLAMA_SCHED_SPREAD":      {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
 		"OLLAMA_TMPDIR":            {"OLLAMA_TMPDIR", TmpDir(), "Location for temporary files"},
+		"OLLAMA_MULTIUSER_CACHE":   {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
 
 		// Informational
 		"HTTP_PROXY":  {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},

integration/concurrency_test.go (+1 -1)

@@ -42,7 +42,7 @@ func TestMultiModelConcurrency(t *testing.T) {
 		}
 		resp = [2][]string{
 			{"sunlight"},
-			{"england", "english", "massachusetts", "pilgrims", "british"},
+			{"england", "english", "massachusetts", "pilgrims", "british", "festival"},
 		}
 	)
 	var wg sync.WaitGroup

integration/utils_test.go (+1 -1)

@@ -275,7 +275,7 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
 				break
 			}
 		}
-		require.True(t, atLeastOne, "none of %v found in %s", anyResp, response)
+		require.True(t, atLeastOne, "%s: none of %v found in %s", genReq.Model, anyResp, response)
 		slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
 	case <-ctx.Done():
 		t.Error("outer test context done while waiting for generate")

llama/.gitignore (+3 -0)

@@ -0,0 +1,3 @@
+*.bin
+*.gguf
+build/

llama/Dockerfile (+221 -0)

@@ -0,0 +1,221 @@
+# Note: once we have fully transitioned to the Go server, this will replace the old Dockerfile at the top of the tree
+ARG GOLANG_VERSION=1.22.5
+ARG CMAKE_VERSION=3.22.1
+ARG CUDA_VERSION_11=11.3.1
+ARG CUDA_V11_ARCHITECTURES="50;52;53;60;61;62;70;72;75;80;86"
+ARG CUDA_VERSION_12=12.4.0
+ARG CUDA_V12_ARCHITECTURES="60;61;62;70;72;75;80;86;87;89;90;90a"
+ARG ROCM_VERSION=6.1.2
+
+### To create a local image for building linux binaries on mac or windows with efficient incremental builds
+#
+# docker build --platform linux/amd64 -t builder-amd64 -f Dockerfile.new --target unified-builder-amd64 .
+# docker run --platform linux/amd64 --rm -it -v $(pwd):/go/src/github.com/ollama/ollama/ builder-amd64
+#
+### Then incremental builds will be much faster in this container
+#
+# make -C llama -j 10 && go build -trimpath -o dist/linux-amd64/ollama .
+#
+FROM --platform=linux/amd64 rocm/dev-centos-7:${ROCM_VERSION}-complete AS unified-builder-amd64
+ARG CMAKE_VERSION
+ARG GOLANG_VERSION
+ARG CUDA_VERSION_11
+ARG CUDA_VERSION_12
+COPY ./scripts/rh_linux_deps.sh /
+ENV PATH /opt/rh/devtoolset-10/root/usr/bin:/usr/local/cuda/bin:$PATH
+ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/cuda/lib64
+ENV LIBRARY_PATH=/usr/local/cuda/lib64/stubs:/opt/amdgpu/lib64
+RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
+RUN yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo && \
+    dnf clean all && \
+    dnf install -y \
+    zsh \
+    cuda-$(echo ${CUDA_VERSION_11} | cut -f1-2 -d. | sed -e "s/\./-/g") \
+    cuda-$(echo ${CUDA_VERSION_12} | cut -f1-2 -d. | sed -e "s/\./-/g")
+# TODO intel oneapi goes here...
+ENV GOARCH amd64
+ENV CGO_ENABLED 1
+WORKDIR /go/src/github.com/ollama/ollama/
+ENTRYPOINT [ "zsh" ]
+
+### To create a local image for building linux binaries on mac or linux/arm64 with efficient incremental builds
+# Note: this does not contain jetson variants
+#
+# docker build --platform linux/arm64 -t builder-arm64 -f Dockerfile.new --target unified-builder-arm64 .
+# docker run --platform linux/arm64 --rm -it -v $(pwd):/go/src/github.com/ollama/ollama/ builder-arm64
+#
+FROM --platform=linux/arm64 rockylinux:8 AS unified-builder-arm64
+ARG CMAKE_VERSION
+ARG GOLANG_VERSION
+ARG CUDA_VERSION_11
+ARG CUDA_VERSION_12
+COPY ./scripts/rh_linux_deps.sh /
+RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
+RUN yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo && \
+    dnf config-manager --set-enabled appstream && \
+    dnf clean all && \
+    dnf install -y \
+    zsh \
+    cuda-toolkit-$(echo ${CUDA_VERSION_11} | cut -f1-2 -d. | sed -e "s/\./-/g") \
+    cuda-toolkit-$(echo ${CUDA_VERSION_12} | cut -f1-2 -d. | sed -e "s/\./-/g")
+ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH:/usr/local/cuda/bin
+ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/cuda/lib64
+ENV LIBRARY_PATH=/usr/local/cuda/lib64/stubs:/opt/amdgpu/lib64
+ENV GOARCH arm64
+ENV CGO_ENABLED 1
+WORKDIR /go/src/github.com/ollama/ollama/
+ENTRYPOINT [ "zsh" ]
+
+FROM --platform=linux/amd64 unified-builder-amd64 AS runners-amd64
+COPY . .
+ARG OLLAMA_SKIP_CUDA_GENERATE
+ARG OLLAMA_SKIP_CUDA_11_GENERATE
+ARG OLLAMA_SKIP_CUDA_12_GENERATE
+ARG OLLAMA_SKIP_ROCM_GENERATE
+ARG CUDA_V11_ARCHITECTURES
+ARG CUDA_V12_ARCHITECTURES
+ARG OLLAMA_FAST_BUILD
+RUN --mount=type=cache,target=/root/.ccache \
+    if grep "^flags" /proc/cpuinfo|grep avx>/dev/null; then \
+        make -C llama -j $(expr $(nproc) / 2 ) ; \
+    else \
+        make -C llama -j 5 ; \
+    fi
+
+FROM --platform=linux/arm64 unified-builder-arm64 AS runners-arm64
+COPY . .
+ARG OLLAMA_SKIP_CUDA_GENERATE
+ARG OLLAMA_SKIP_CUDA_11_GENERATE
+ARG OLLAMA_SKIP_CUDA_12_GENERATE
+ARG CUDA_V11_ARCHITECTURES
+ARG CUDA_V12_ARCHITECTURES
+ARG OLLAMA_FAST_BUILD
+RUN --mount=type=cache,target=/root/.ccache \
+    make -C llama -j 8
+
+
+# Intermediate stages used for ./scripts/build_linux.sh
+FROM --platform=linux/amd64 centos:7 AS builder-amd64
+ARG CMAKE_VERSION
+ARG GOLANG_VERSION
+COPY ./scripts/rh_linux_deps.sh /
+RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
+ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
+ENV CGO_ENABLED 1
+ENV GOARCH amd64
+WORKDIR /go/src/github.com/ollama/ollama
+
+FROM --platform=linux/amd64 builder-amd64 AS build-amd64
+COPY . .
+COPY --from=runners-amd64 /go/src/github.com/ollama/ollama/dist/ dist/
+COPY --from=runners-amd64 /go/src/github.com/ollama/ollama/build/ build/
+ARG GOFLAGS
+ARG CGO_CFLAGS
+ARG OLLAMA_SKIP_ROCM_GENERATE
+RUN --mount=type=cache,target=/root/.ccache \
+    go build -trimpath -o dist/linux-amd64/bin/ollama .
+RUN cd dist/linux-$GOARCH && \
+    tar --exclude runners -cf - . | pigz --best > ../ollama-linux-$GOARCH.tgz
+RUN if [ -z ${OLLAMA_SKIP_ROCM_GENERATE} ] ; then \
+    cd dist/linux-$GOARCH-rocm && \
+    tar -cf - . | pigz --best > ../ollama-linux-$GOARCH-rocm.tgz ;\
+    fi
+
+FROM --platform=linux/arm64 rockylinux:8 AS builder-arm64
+ARG CMAKE_VERSION
+ARG GOLANG_VERSION
+COPY ./scripts/rh_linux_deps.sh /
+RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
+ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH
+ENV CGO_ENABLED 1
+ENV GOARCH arm64
+WORKDIR /go/src/github.com/ollama/ollama
+
+FROM --platform=linux/arm64 builder-arm64 AS build-arm64
+COPY . .
+COPY --from=runners-arm64 /go/src/github.com/ollama/ollama/dist/ dist/
+COPY --from=runners-arm64 /go/src/github.com/ollama/ollama/build/ build/
+ARG GOFLAGS
+ARG CGO_CFLAGS
+RUN --mount=type=cache,target=/root/.ccache \
+    go build -trimpath -o dist/linux-arm64/bin/ollama .
+RUN cd dist/linux-$GOARCH && \
+    tar --exclude runners -cf - . | pigz --best > ../ollama-linux-$GOARCH.tgz
+
+FROM --platform=linux/amd64 scratch AS dist-amd64
+COPY --from=build-amd64 /go/src/github.com/ollama/ollama/dist/ollama-linux-*.tgz /
+FROM --platform=linux/arm64 scratch AS dist-arm64
+COPY --from=build-arm64 /go/src/github.com/ollama/ollama/dist/ollama-linux-*.tgz /
+FROM dist-$TARGETARCH AS dist
+
+
+# Optimized container images do not carry nested payloads
+FROM --platform=linux/amd64 builder-amd64 AS container-build-amd64
+WORKDIR /go/src/github.com/ollama/ollama
+COPY . .
+ARG GOFLAGS
+ARG CGO_CFLAGS
+RUN --mount=type=cache,target=/root/.ccache \
+    go build -trimpath -o dist/linux-amd64/bin/ollama .
+
+FROM --platform=linux/arm64 builder-arm64 AS container-build-arm64
+WORKDIR /go/src/github.com/ollama/ollama
+COPY . .
+ARG GOFLAGS
+ARG CGO_CFLAGS
+RUN --mount=type=cache,target=/root/.ccache \
+    go build -trimpath -o dist/linux-arm64/bin/ollama .
+
+# For amd64 container images, filter out cuda/rocm to minimize size
+FROM runners-amd64 AS runners-cuda-amd64
+RUN rm -rf \
+    ./dist/linux-amd64/lib/ollama/libggml_hipblas.so \
+    ./dist/linux-amd64/lib/ollama/runners/rocm*
+
+FROM runners-amd64 AS runners-rocm-amd64
+RUN rm -rf \
+    ./dist/linux-amd64/lib/ollama/libggml_cuda*.so \
+    ./dist/linux-amd64/lib/ollama/libcu*.so* \
+    ./dist/linux-amd64/lib/ollama/runners/cuda*
+
+FROM --platform=linux/amd64 ubuntu:22.04 AS runtime-amd64
+RUN apt-get update && \
+    apt-get install -y ca-certificates && \
+    rm -rf /var/lib/apt/lists/*
+COPY --from=container-build-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/bin/ /bin/
+COPY --from=runners-cuda-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/lib/ /lib/
+
+FROM --platform=linux/arm64 ubuntu:22.04 AS runtime-arm64
+RUN apt-get update && \
+    apt-get install -y ca-certificates && \
+    rm -rf /var/lib/apt/lists/*
+COPY --from=container-build-arm64 /go/src/github.com/ollama/ollama/dist/linux-arm64/bin/ /bin/
+COPY --from=runners-arm64 /go/src/github.com/ollama/ollama/dist/linux-arm64/lib/ /lib/
+
+# ROCm libraries are larger so we keep them distinct from the CPU/CUDA image
+FROM --platform=linux/amd64 ubuntu:22.04 AS runtime-rocm
+# Frontload the rocm libraries which are large, and rarely change to increase chance of a common layer
+# across releases
+COPY --from=build-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64-rocm/lib/ /lib/
+RUN apt-get update && \
+    apt-get install -y ca-certificates && \
+    rm -rf /var/lib/apt/lists/*
+COPY --from=container-build-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/bin/ /bin/
+COPY --from=runners-rocm-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/lib/ /lib/
+
+EXPOSE 11434
+ENV OLLAMA_HOST 0.0.0.0
+
+ENTRYPOINT ["/bin/ollama"]
+CMD ["serve"]
+
+FROM runtime-$TARGETARCH
+EXPOSE 11434
+ENV OLLAMA_HOST 0.0.0.0
+ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
+ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
+ENV NVIDIA_VISIBLE_DEVICES=all
+
+ENTRYPOINT ["/bin/ollama"]
+CMD ["serve"]

llama/Makefile (+54 -0)

@@ -0,0 +1,54 @@
+# top level makefile for Go server
+include make/common-defs.make
+
+RUNNER_TARGETS := default
+
+# Determine which if any GPU runners we should build
+ifeq ($(OS),windows)
+	CUDA_PATH?=$(shell cygpath -m -s "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\" 2>/dev/null)unknown
+	CUDA_BASE_DIR := $(dir $(shell cygpath -m -s "$(CUDA_PATH)\\.." 2>/dev/null))
+	CUDA_11:=$(shell ls -d $(CUDA_BASE_DIR)/v11.? 2>/dev/null)
+	CUDA_12:=$(shell ls -d $(CUDA_BASE_DIR)/v12.? 2>/dev/null)
+	HIP_PATH_83 := $(shell cygpath -m -s "$(subst \,/,$(HIP_PATH))" 2>/dev/null)
+	HIP_LIB_DIR := $(shell ls -d $(HIP_PATH_83)/lib 2>/dev/null)
+else ifeq ($(OS),linux)
+	HIP_PATH?=/opt/rocm
+	HIP_LIB_DIR := $(shell ls -d $(HIP_PATH)/lib 2>/dev/null)
+	CUDA_PATH?=/usr/local/cuda
+	CUDA_11:=$(shell ls -d $(CUDA_PATH)-11 2>/dev/null)
+	CUDA_12:=$(shell ls -d $(CUDA_PATH)-12 2>/dev/null)
+endif
+
+ifeq ($(OLLAMA_SKIP_CUDA_GENERATE),)
+ifneq ($(CUDA_11),)
+	RUNNER_TARGETS += cuda_v11
+endif
+ifneq ($(CUDA_12),)
+	RUNNER_TARGETS += cuda_v12
+endif
+endif
+ifeq ($(OLLAMA_SKIP_ROCM_GENERATE),)
+ifneq ($(HIP_LIB_DIR),)
+	RUNNER_TARGETS += rocm
+endif
+endif
+
+
+all: clean-payload .WAIT runners
+
+runners: $(RUNNER_TARGETS)
+
+$(RUNNER_TARGETS):
+	$(MAKE) -f make/Makefile.$@
+
+clean:
+	rm -rf $(BUILD_DIR) $(DIST_RUNNERS) $(PAYLOAD_RUNNERS) $(RUNNERS_PAYLOAD_DIR)
+
+clean-payload:
+	rm -rf $(addprefix $(RUNNERS_PAYLOAD_DIR)/, $(RUNNER_TARGETS) metal cpu cpu_avx cpu_avx2)
+
+.PHONY: all runners clean clean-payload $(RUNNER_TARGETS) .WAIT
+
+# Handy debugging for make variables
+print-%:
+	@echo '$*=$($*)'

llama/README.md (+100 -0)

@@ -0,0 +1,100 @@
+# `llama`
+
+This package integrates the [llama.cpp](https://github.com/ggerganov/llama.cpp) library as a Go package and makes it easy to build it with tags for different CPU and GPU processors.
+
+Supported:
+
+- [x] CPU
+- [x] avx, avx2
+- [x] macOS Metal
+- [x] Windows CUDA
+- [x] Windows ROCm
+- [x] Linux CUDA
+- [x] Linux ROCm
+- [x] Llava
+
+Extra build steps are required for CUDA and ROCm on Windows since `nvcc` and `hipcc` both require using MSVC as the host compiler. For these, the following shared libraries are created:
+
+- `ggml_cuda.dll` on Windows or `ggml_cuda.so` on Linux
+- `ggml_hipblas.dll` on Windows or `ggml_hipblas.so` on Linux
+
+> Note: it's important that memory is allocated and freed by the same compiler (e.g. entirely by code compiled with msvc or mingw). Issues from this should be rare, but there are some places where pointers are returned by the CUDA or HIP runtimes and freed elsewhere, causing a crash. In a future change the same runtime should be used in both cases to avoid crashes.
+
+## Building
+
+```
+go build .
+```
+
+### AVX
+
+```shell
+go build -tags avx .
+```
+
+### AVX2
+
+```shell
+# go doesn't recognize `-mfma` as a valid compiler flag
+# see https://github.com/golang/go/issues/17895
+go env -w "CGO_CFLAGS_ALLOW=-mfma|-mf16c"
+go env -w "CGO_CXXFLAGS_ALLOW=-mfma|-mf16c"
+go build -tags=avx,avx2 .
+```
+
+## Linux
+
+### CUDA
+
+Install the [CUDA toolkit v11.3.1](https://developer.nvidia.com/cuda-11-3-1-download-archive):
+
+```shell
+make ggml_cuda.so
+go build -tags avx,cuda .
+```
+
+### ROCm
+
+Install [ROCm 5.7.1](https://rocm.docs.amd.com/en/docs-5.7.1/):
+
+```shell
+make ggml_hipblas.so
+go build -tags avx,rocm .
+```
+
+## Windows
+
+Download [w64devkit](https://github.com/skeeto/w64devkit/releases/latest) for a simple MinGW development environment.
+
+### CUDA
+
+Install the [CUDA toolkit v11.3.1](https://developer.nvidia.com/cuda-11-3-1-download-archive) then build the cuda code:
+
+```shell
+make ggml_cuda.dll
+go build -tags avx,cuda .
+```
+
+### ROCm
+
+Install [ROCm 5.7.1](https://rocm.docs.amd.com/en/docs-5.7.1/).
+
+```shell
+make ggml_hipblas.dll
+go build -tags avx,rocm .
+```
+
+## Building runners
+
+```shell
+# build all runners for this platform
+make -j
+```
+
+## Syncing with llama.cpp
+
+To update this package to the latest llama.cpp code, use the `sync.sh` script:
+
+```
+./sync.sh ../../llama.cpp
+```

llama/base64.hpp (+392 -0)

@@ -0,0 +1,392 @@
+/*
+This is free and unencumbered software released into the public domain.
+
+Anyone is free to copy, modify, publish, use, compile, sell, or
+distribute this software, either in source code form or as a compiled
+binary, for any purpose, commercial or non-commercial, and by any
+means.
+
+In jurisdictions that recognize copyright laws, the author or authors
+of this software dedicate any and all copyright interest in the
+software to the public domain. We make this dedication for the benefit
+of the public at large and to the detriment of our heirs and
+successors. We intend this dedication to be an overt act of
+relinquishment in perpetuity of all present and future rights to this
+software under copyright law.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
+OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
+ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+OTHER DEALINGS IN THE SOFTWARE.
+
+For more information, please refer to <http://unlicense.org>
+*/
+
+#ifndef PUBLIC_DOMAIN_BASE64_HPP_
+#define PUBLIC_DOMAIN_BASE64_HPP_
+
+#include <cstdint>
+#include <iterator>
+#include <stdexcept>
+#include <string>
+
+class base64_error : public std::runtime_error
+{
+public:
+    using std::runtime_error::runtime_error;
+};
+
+class base64
+{
+public:
+    enum class alphabet
+    {
+        /** the alphabet is detected automatically */
+        auto_,
+        /** the standard base64 alphabet is used */
+        standard,
+        /** like `standard` except that the characters `+` and `/` are replaced by `-` and `_` respectively*/
+        url_filename_safe
+    };
+
+    enum class decoding_behavior
+    {
+        /** if the input is not padded, the remaining bits are ignored */
+        moderate,
+        /** if a padding character is encounter decoding is finished */
+        loose
+    };
+
+    /**
+     Encodes all the elements from `in_begin` to `in_end` to `out`.
+
+     @warning The source and destination cannot overlap. The destination must be able to hold at least
+     `required_encode_size(std::distance(in_begin, in_end))`, otherwise the behavior depends on the output iterator.
+
+     @tparam Input_iterator the source; the returned elements are cast to `std::uint8_t` and should not be greater than
+     8 bits
+     @tparam Output_iterator the destination; the elements written to it are from the type `char`
+     @param in_begin the beginning of the source
+     @param in_end the ending of the source
+     @param out the destination iterator
+     @param alphabet which alphabet should be used
+     @returns the iterator to the next element past the last element copied
+     @throws see `Input_iterator` and `Output_iterator`
+    */
+    template<typename Input_iterator, typename Output_iterator>
+    static Output_iterator encode(Input_iterator in_begin, Input_iterator in_end, Output_iterator out,
+                                  alphabet alphabet = alphabet::standard)
+    {
+        constexpr auto pad = '=';
+        const char* alpha  = alphabet == alphabet::url_filename_safe
+                                ? "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
+                                : "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
+
+        while (in_begin != in_end) {
+            std::uint8_t i0 = 0, i1 = 0, i2 = 0;
+
+            // first character
+            i0 = static_cast<std::uint8_t>(*in_begin);
+            ++in_begin;
+
+            *out = alpha[i0 >> 2 & 0x3f];
+            ++out;
+
+            // part of first character and second
+            if (in_begin != in_end) {
+                i1 = static_cast<std::uint8_t>(*in_begin);
+                ++in_begin;
+
+                *out = alpha[((i0 & 0x3) << 4) | (i1 >> 4 & 0x0f)];
+                ++out;
+            } else {
+                *out = alpha[(i0 & 0x3) << 4];
+                ++out;
+
+                // last padding
+                *out = pad;
+                ++out;
+
+                // last padding
+                *out = pad;
+                ++out;
+
+                break;
+            }
+
+            // part of second character and third
+            if (in_begin != in_end) {
+                i2 = static_cast<std::uint8_t>(*in_begin);
+                ++in_begin;
+
+                *out = alpha[((i1 & 0xf) << 2) | (i2 >> 6 & 0x03)];
+                ++out;
+            } else {
+                *out = alpha[(i1 & 0xf) << 2];
+                ++out;
+
+                // last padding
+                *out = pad;
+                ++out;
+
+                break;
+            }
+
+            // rest of third
+            *out = alpha[i2 & 0x3f];
+            ++out;
+        }
+
+        return out;
+    }
+    /**
+     Encodes a string.
+
+     @param str the string that should be encoded
+     @param alphabet which alphabet should be used
+     @returns the encoded base64 string
+     @throws see base64::encode()
+    */
+    static std::string encode(const std::string& str, alphabet alphabet = alphabet::standard)
+    {
+        std::string result;
+
+        result.reserve(required_encode_size(str.length()) + 1);
+
+        encode(str.begin(), str.end(), std::back_inserter(result), alphabet);
+
+        return result;
+    }
+    /**
+     Encodes a char array.
+
+     @param buffer the char array
+     @param size the size of the array
+     @param alphabet which alphabet should be used
+     @returns the encoded string
+    */
+    static std::string encode(const char* buffer, std::size_t size, alphabet alphabet = alphabet::standard)
+    {
+        std::string result;
+
+        result.reserve(required_encode_size(size) + 1);
+
+        encode(buffer, buffer + size, std::back_inserter(result), alphabet);
+
+        return result;
+    }
+    /**
+     Decodes all the elements from `in_begin` to `in_end` to `out`. `in_begin` may point to the same location as `out`,
+     in other words: inplace decoding is possible.
+
+     @warning The destination must be able to hold at least `required_decode_size(std::distance(in_begin, in_end))`,
+     otherwise the behavior depends on the output iterator.
+
+     @tparam Input_iterator the source; the returned elements are cast to `char`
+     @tparam Output_iterator the destination; the elements written to it are from the type `std::uint8_t`
+     @param in_begin the beginning of the source
+     @param in_end the ending of the source
+     @param out the destination iterator
+     @param alphabet which alphabet should be used
+     @param behavior the behavior when an error was detected
+     @returns the iterator to the next element past the last element copied
+     @throws base64_error depending on the set behavior
+     @throws see `Input_iterator` and `Output_iterator`
+    */
+    template<typename Input_iterator, typename Output_iterator>
+    static Output_iterator decode(Input_iterator in_begin, Input_iterator in_end, Output_iterator out,
+                                  alphabet alphabet          = alphabet::auto_,
+                                  decoding_behavior behavior = decoding_behavior::moderate)
+    {
+        //constexpr auto pad = '=';
+        std::uint8_t last  = 0;
+        auto bits          = 0;
+
+        while (in_begin != in_end) {
+            auto c = *in_begin;
+            ++in_begin;
+
+            if (c == '=') {
+                break;
+            }
+
+            auto part = _base64_value(alphabet, c);
+
+            // enough bits for one byte
+            if (bits + 6 >= 8) {
+                *out = (last << (8 - bits)) | (part >> (bits - 2));
+                ++out;
+
+                bits -= 2;
+            } else {
+                bits += 6;
+            }
+
+            last = part;
+        }
+
+        // check padding
+        if (behavior != decoding_behavior::loose) {
+            while (in_begin != in_end) {
+                auto c = *in_begin;
+                ++in_begin;
+
+                if (c != '=') {
+                    throw base64_error("invalid base64 character.");
+                }
+            }
+        }
+
+        return out;
+    }
+    /**
+     Decodes a string.
+
+     @param str the base64 encoded string
+     @param alphabet which alphabet should be used
+     @param behavior the behavior when an error was detected
+     @returns the decoded string
+     @throws see base64::decode()
+    */
+    static std::string decode(const std::string& str, alphabet alphabet = alphabet::auto_,
+                              decoding_behavior behavior = decoding_behavior::moderate)
+    {
+        std::string result;
+
+        result.reserve(max_decode_size(str.length()));
+
+        decode(str.begin(), str.end(), std::back_inserter(result), alphabet, behavior);
+
+        return result;
+    }
+    /**
+     Decodes a string.
+
+     @param buffer the base64 encoded buffer
+     @param size the size of the buffer
+     @param alphabet which alphabet should be used
+     @param behavior the behavior when an error was detected
+     @returns the decoded string
+     @throws see base64::decode()
+    */
+    static std::string decode(const char* buffer, std::size_t size, alphabet alphabet = alphabet::auto_,
+                              decoding_behavior behavior = decoding_behavior::moderate)
+    {
+        std::string result;
+
+        result.reserve(max_decode_size(size));
+
+        decode(buffer, buffer + size, std::back_inserter(result), alphabet, behavior);
+
+        return result;
+    }
+    /**
+     Decodes a string inplace.
+
+     @param[in,out] str the base64 encoded string
+     @param alphabet which alphabet should be used
+     @param behavior the behavior when an error was detected
+     @throws base64::decode_inplace()
+    */
+    static void decode_inplace(std::string& str, alphabet alphabet = alphabet::auto_,
+                               decoding_behavior behavior = decoding_behavior::moderate)
+    {
+        str.resize(decode(str.begin(), str.end(), str.begin(), alphabet, behavior) - str.begin());
+    }
+    /**
+     Decodes a char array inplace.
+
+     @param[in,out] str the string array
+     @param size the length of the array
+     @param alphabet which alphabet should be used
+     @param behavior the behavior when an error was detected
+     @returns the pointer to the next element past the last element decoded
+     @throws base64::decode_inplace()
+    */
+    static char* decode_inplace(char* str, std::size_t size, alphabet alphabet = alphabet::auto_,
+                                decoding_behavior behavior = decoding_behavior::moderate)
+    {
+        return decode(str, str + size, str, alphabet, behavior);
+    }
+    /**
+     Returns the required decoding size for a given size. The value is calculated with the following formula:
+
+     $$
+     \lceil \frac{size}{4} \rceil \cdot 3
+     $$
+
+     @param size the size of the encoded input
+     @returns the size of the resulting decoded buffer; this the absolute maximum
+    */
+    static std::size_t max_decode_size(std::size_t size) noexcept
+    {
+        return (size / 4 + (size % 4 ? 1 : 0)) * 3;
+    }
+    /**
+     Returns the required encoding size for a given size. The value is calculated with the following formula:
+
+     $$
+     \lceil \frac{size}{3} \rceil \cdot 4
+     $$
+
+     @param size the size of the decoded input
+     @returns the size of the resulting encoded buffer
+    */
+    static std::size_t required_encode_size(std::size_t size) noexcept
+    {
+        return (size / 3 + (size % 3 ? 1 : 0)) * 4;
+    }
+
+private:
+    static std::uint8_t _base64_value(alphabet& alphabet, char c)
+    {
+        if (c >= 'A' && c <= 'Z') {
+            return c - 'A';
+        } else if (c >= 'a' && c <= 'z') {
+            return c - 'a' + 26;
+        } else if (c >= '0' && c <= '9') {
+            return c - '0' + 52;
+        }
+
+        // comes down to alphabet
+        if (alphabet == alphabet::standard) {
+            if (c == '+') {
+                return 62;
+            } else if (c == '/') {
+                return 63;
+            }
+        } else if (alphabet == alphabet::url_filename_safe) {
+            if (c == '-') {
+                return 62;
+            } else if (c == '_') {
+                return 63;
+            }
+        } // auto detect
+        else {
+            if (c == '+') {
+                alphabet = alphabet::standard;
+
+                return 62;
+            } else if (c == '/') {
+                alphabet = alphabet::standard;
+
+                return 63;
+            } else if (c == '-') {
+                alphabet = alphabet::url_filename_safe;
+
+                return 62;
+            } else if (c == '_') {
+                alphabet = alphabet::url_filename_safe;
+
+                return 63;
+            }
+        }
+
+        throw base64_error("invalid base64 character.");
+    }
+};
+
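+// Illustrative usage sketch (not part of the upstream header); assumes the encode()
+// overloads defined earlier in this file:
+//
+//     std::string encoded = base64::encode("hello");   // "aGVsbG8="
+//     std::string decoded = base64::decode(encoded);   // "hello"
+//     base64::decode_inplace(encoded);                 // encoded now holds "hello"
+//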
+#endif // !PUBLIC_DOMAIN_BASE64_HPP_

+ 30 - 0
llama/build-info.cpp

@@ -0,0 +1,30 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+int LLAMA_BUILD_NUMBER = 0;
+char const *LLAMA_COMMIT = "";
+char const *LLAMA_COMPILER = "";
+char const *LLAMA_BUILD_TARGET = "";

+ 2689 - 0
llama/clip.cpp

@@ -0,0 +1,2689 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// NOTE: This is modified from clip.cpp only for LLaVA,
+// so there might still be unnecessary artifacts hanging around.
+// I'll gradually clean and extend it.
+// Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()), the resulting embeddings differ significantly from the PyTorch implementation.
+#include "clip.h"
+#include "log.h"
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+
+#ifdef GGML_USE_CUDA
+#include "ggml-cuda.h"
+#endif
+
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
+#ifdef GGML_USE_CANN
+#include "ggml-cann.h"
+#endif
+
+#ifdef GGML_USE_VULKAN
+#include "ggml-vulkan.h"
+#endif
+
+#define STB_IMAGE_IMPLEMENTATION
+#include "stb_image.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstdlib>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <regex>
+#include <stdexcept>
+#include <vector>
+#include <sstream>
+#include <cinttypes>
+#include <limits>
+
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+    #define NOMINMAX
+#endif
+#include <windows.h>
+#if __GLIBCXX__
+#include <cstdio>
+#include <ext/stdio_filebuf.h>
+#include <fcntl.h>
+#endif
+#endif
+
+//#define CLIP_DEBUG_FUNCTIONS
+
+// RGB uint8 image
+struct clip_image_u8 {
+    int nx;
+    int ny;
+
+    std::vector<uint8_t> buf;
+};
+
+// RGB float32 image (NHWC)
+// Memory layout: RGBRGBRGB...
+struct clip_image_f32 {
+    int nx;
+    int ny;
+
+    std::vector<float> buf;
+};
+
+static std::string format(const char * fmt, ...) {
+    va_list ap;
+    va_list ap2;
+    va_start(ap, fmt);
+    va_copy(ap2, ap);
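+    // two-pass vsnprintf: the first call with a NULL buffer only computes the required size; the second writes into buf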
+    int size = vsnprintf(NULL, 0, fmt, ap);
+    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+    std::vector<char> buf(size + 1);
+    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
+    GGML_ASSERT(size2 == size);
+    va_end(ap2);
+    va_end(ap);
+    return std::string(buf.data(), buf.size());
+}
+
+//
+// key constants
+//
+
+#define KEY_FTYPE               "general.file_type"
+#define KEY_NAME                "general.name"
+#define KEY_DESCRIPTION         "general.description"
+#define KEY_HAS_TEXT_ENC        "clip.has_text_encoder"
+#define KEY_HAS_VIS_ENC         "clip.has_vision_encoder"
+#define KEY_HAS_LLAVA_PROJ      "clip.has_llava_projector"
+#define KEY_HAS_MINICPMV_PROJ   "clip.has_minicpmv_projector"
+#define KEY_MINICPMV_VERSION    "clip.minicpmv_version"
+#define KEY_USE_GELU            "clip.use_gelu"
+#define KEY_N_EMBD              "clip.%s.embedding_length"
+#define KEY_N_FF                "clip.%s.feed_forward_length"
+#define KEY_N_BLOCK             "clip.%s.block_count"
+#define KEY_N_HEAD              "clip.%s.attention.head_count"
+#define KEY_LAYER_NORM_EPS      "clip.%s.attention.layer_norm_epsilon"
+#define KEY_PROJ_DIM            "clip.%s.projection_dim"
+#define KEY_TOKENS              "tokenizer.ggml.tokens"
+#define KEY_N_POSITIONS         "clip.text.context_length"
+#define KEY_IMAGE_SIZE          "clip.vision.image_size"
+#define KEY_PATCH_SIZE          "clip.vision.patch_size"
+#define KEY_IMAGE_MEAN          "clip.vision.image_mean"
+#define KEY_IMAGE_STD           "clip.vision.image_std"
+#define KEY_PROJ_TYPE           "clip.projector_type"
+
+#define KEY_MM_PATCH_MERGE_TYPE   "clip.vision.mm_patch_merge_type"
+#define KEY_IMAGE_GRID_PINPOINTS  "clip.vision.image_grid_pinpoints"
+#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
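+// note: the "%s" placeholders above are filled in with format(), e.g.
+// format(KEY_N_EMBD, "vision") yields "clip.vision.embedding_length"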
+
+
+//
+// tensor name constants
+//
+
+#define TN_TOKEN_EMBD      "%s.token_embd.weight"
+#define TN_POS_EMBD        "%s.position_embd.weight"
+#define TN_CLASS_EMBD      "v.class_embd"
+#define TN_PATCH_EMBD      "v.patch_embd.weight"
+#define TN_PATCH_BIAS      "v.patch_embd.bias"
+#define TN_ATTN_K          "%s.blk.%d.attn_k.%s"
+#define TN_ATTN_Q          "%s.blk.%d.attn_q.%s"
+#define TN_ATTN_V          "%s.blk.%d.attn_v.%s"
+#define TN_ATTN_OUTPUT     "%s.blk.%d.attn_out.%s"
+#define TN_FFN_DOWN        "%s.blk.%d.ffn_down.%s"
+#define TN_FFN_UP          "%s.blk.%d.ffn_up.%s"
+#define TN_LN_1            "%s.blk.%d.ln1.%s"
+#define TN_LN_2            "%s.blk.%d.ln2.%s"
+#define TN_LN_PRE          "%s.pre_ln.%s"
+#define TN_LN_POST         "%s.post_ln.%s"
+#define TN_TEXT_PROJ       "text_projection.weight"
+#define TN_VIS_PROJ        "visual_projection.weight"
+#define TN_LLAVA_PROJ      "mm.%d.%s"
+#define TN_MVLM_PROJ_MLP   "mm.model.mlp.%d.%s"
+#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
+#define TN_MVLM_PROJ_PEG   "mm.model.peg.%d.%s"
+#define TN_IMAGE_NEWLINE   "model.image_newline"
+
+#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
+#define TN_MINICPMV_QUERY "resampler.query"
+#define TN_MINICPMV_PROJ "resampler.proj.weight"
+#define TN_MINICPMV_KV_PROJ "resampler.kv.weight"
+#define TN_MINICPMV_ATTN "resampler.attn.%s.%s"
+#define TN_MINICPMV_LN "resampler.ln_%s.%s"
+
+
+enum projector_type {
+    PROJECTOR_TYPE_MLP,
+    PROJECTOR_TYPE_MLP_NORM,
+    PROJECTOR_TYPE_LDP,
+    PROJECTOR_TYPE_LDPV2,
+    PROJECTOR_TYPE_RESAMPLER,
+    PROJECTOR_TYPE_UNKNOWN,
+};
+
+static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
+    { PROJECTOR_TYPE_MLP, "mlp" },
+    { PROJECTOR_TYPE_LDP, "ldp" },
+    { PROJECTOR_TYPE_LDPV2, "ldpv2"},
+    { PROJECTOR_TYPE_RESAMPLER, "resampler"},
+};
+
+
+//
+// utilities to get data from a gguf file
+//
+
+static int get_key_idx(const gguf_context * ctx, const char * key) {
+    int i = gguf_find_key(ctx, key);
+    if (i == -1) {
+        LOG_TEE("key %s not found in file\n", key);
+        throw std::runtime_error(format("Missing required key: %s", key));
+    }
+
+    return i;
+}
+
+static uint32_t get_u32(const gguf_context * ctx, const std::string & key) {
+    const int i = get_key_idx(ctx, key.c_str());
+
+    return gguf_get_val_u32(ctx, i);
+}
+
+static float get_f32(const gguf_context * ctx, const std::string & key) {
+    const int i = get_key_idx(ctx, key.c_str());
+
+    return gguf_get_val_f32(ctx, i);
+}
+
+static struct ggml_tensor * get_tensor(struct ggml_context * ctx, const std::string & name) {
+    struct ggml_tensor * cur = ggml_get_tensor(ctx, name.c_str());
+    if (!cur) {
+        throw std::runtime_error(format("%s: unable to find tensor %s\n", __func__, name.c_str()));
+    }
+
+    return cur;
+}
+
+static std::string get_ftype(int ftype) {
+    return ggml_type_name(static_cast<ggml_type>(ftype));
+}
+
+static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) {
+    switch (type) {
+        case GGUF_TYPE_UINT8:   return std::to_string(((const uint8_t  *)data)[i]);
+        case GGUF_TYPE_INT8:    return std::to_string(((const int8_t   *)data)[i]);
+        case GGUF_TYPE_UINT16:  return std::to_string(((const uint16_t *)data)[i]);
+        case GGUF_TYPE_INT16:   return std::to_string(((const int16_t  *)data)[i]);
+        case GGUF_TYPE_UINT32:  return std::to_string(((const uint32_t *)data)[i]);
+        case GGUF_TYPE_INT32:   return std::to_string(((const int32_t  *)data)[i]);
+        case GGUF_TYPE_UINT64:  return std::to_string(((const uint64_t *)data)[i]);
+        case GGUF_TYPE_INT64:   return std::to_string(((const int64_t  *)data)[i]);
+        case GGUF_TYPE_FLOAT32: return std::to_string(((const float    *)data)[i]);
+        case GGUF_TYPE_FLOAT64: return std::to_string(((const double   *)data)[i]);
+        case GGUF_TYPE_BOOL:    return ((const bool *)data)[i] ? "true" : "false";
+        default:                return format("unknown type %d", type);
+    }
+}
+
+static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
+    if (search.empty()) {
+        return;
+    }
+    std::string builder;
+    builder.reserve(s.length());
+    size_t pos = 0;
+    size_t last_pos = 0;
+    while ((pos = s.find(search, last_pos)) != std::string::npos) {
+        builder.append(s, last_pos, pos - last_pos);
+        builder.append(replace);
+        last_pos = pos + search.length();
+    }
+    builder.append(s, last_pos, std::string::npos);
+    s = std::move(builder);
+}
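+// (used below to keep logged values on one line, e.g. replace_all(value, "\n", "\\n"))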
+
+static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
+    const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
+
+    switch (type) {
+        case GGUF_TYPE_STRING:
+            return gguf_get_val_str(ctx_gguf, i);
+        case GGUF_TYPE_ARRAY:
+            {
+                const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i);
+                int arr_n = gguf_get_arr_n(ctx_gguf, i);
+                const void * data = gguf_get_arr_data(ctx_gguf, i);
+                std::stringstream ss;
+                ss << "[";
+                for (int j = 0; j < arr_n; j++) {
+                    if (arr_type == GGUF_TYPE_STRING) {
+                        std::string val = gguf_get_arr_str(ctx_gguf, i, j);
+                        // escape quotes
+                        replace_all(val, "\\", "\\\\");
+                        replace_all(val, "\"", "\\\"");
+                        ss << '"' << val << '"';
+                    } else if (arr_type == GGUF_TYPE_ARRAY) {
+                        ss << "???";
+                    } else {
+                        ss << gguf_data_to_str(arr_type, data, j);
+                    }
+                    if (j < arr_n - 1) {
+                        ss << ", ";
+                    }
+                }
+                ss << "]";
+                return ss.str();
+            }
+        default:
+            return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0);
+    }
+}
+
+static void print_tensor_info(const ggml_tensor * tensor, const char * prefix = "") {
+    size_t tensor_size = ggml_nbytes(tensor);
+    LOG_TEE("%s: n_dims = %d, name = %s, tensor_size=%zu, shape:[%" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "], type = %s\n",
+            prefix, ggml_n_dims(tensor), tensor->name, tensor_size,
+            tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], ggml_type_name(tensor->type));
+}
+
+static projector_type clip_projector_type_from_string(const std::string & name) {
+    for (const auto & kv : PROJECTOR_TYPE_NAMES) { // NOLINT
+        if (kv.second == name) {
+            return kv.first;
+        }
+    }
+    return PROJECTOR_TYPE_UNKNOWN;
+}
+
+#ifdef CLIP_DEBUG_FUNCTIONS
+static void clip_image_write_image_to_ppm(const clip_image_u8& img, const std::string& filename) {
+    std::ofstream file(filename, std::ios::binary);
+    if (!file.is_open()) {
+        LOG_TEE("Failed to open file for writing: %s\n", filename.c_str());
+        return;
+    }
+
+    // PPM header: P6 format, width, height, and max color value
+    file << "P6\n" << img.nx << " " << img.ny << "\n255\n";
+
+    // Write pixel data
+    for (size_t i = 0; i < img.buf.size(); i += 3) {
+        // PPM expects binary data in RGB format, which matches our image buffer
+        file.write(reinterpret_cast<const char*>(&img.buf[i]), 3);
+    }
+
+    file.close();
+}
+
+static void clip_image_save_to_bmp(const clip_image_u8& img, const std::string& filename) {
+    std::ofstream file(filename, std::ios::binary);
+    if (!file.is_open()) {
+        LOG_TEE("Failed to open file for writing: %s\n", filename.c_str());
+        return;
+    }
+
+    int fileSize = 54 + 3 * img.nx * img.ny; // File header + info header + pixel data
+    int bytesPerPixel = 3;
+    int widthInBytes = img.nx * bytesPerPixel;
+    int paddingAmount = (4 - (widthInBytes % 4)) % 4;
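+    // e.g. nx = 5: widthInBytes = 15, paddingAmount = (4 - 15 % 4) % 4 = 1, padding each row to a 4-byte boundary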
+    int stride = widthInBytes + paddingAmount;
+
+    // Bitmap file header
+    unsigned char fileHeader[14] = {
+        'B','M',     // Signature
+        0,0,0,0,    // Image file size in bytes
+        0,0,0,0,    // Reserved
+        54,0,0,0    // Start of pixel array
+    };
+
+    // Total file size
+    fileSize = 54 + (stride * img.ny);
+    fileHeader[2] = (unsigned char)(fileSize);
+    fileHeader[3] = (unsigned char)(fileSize >> 8);
+    fileHeader[4] = (unsigned char)(fileSize >> 16);
+    fileHeader[5] = (unsigned char)(fileSize >> 24);
+
+    // Bitmap information header (BITMAPINFOHEADER)
+    unsigned char infoHeader[40] = {
+        40,0,0,0,   // Size of this header (40 bytes)
+        0,0,0,0,    // Image width
+        0,0,0,0,    // Image height
+        1,0,        // Number of color planes
+        24,0,       // Bits per pixel
+        0,0,0,0,    // No compression
+        0,0,0,0,    // Image size (can be 0 for no compression)
+        0,0,0,0,    // X pixels per meter (not specified)
+        0,0,0,0,    // Y pixels per meter (not specified)
+        0,0,0,0,    // Total colors (color table not used)
+        0,0,0,0     // Important colors (all are important)
+    };
+
+    // Width and height in the information header
+    infoHeader[4] = (unsigned char)(img.nx);
+    infoHeader[5] = (unsigned char)(img.nx >> 8);
+    infoHeader[6] = (unsigned char)(img.nx >> 16);
+    infoHeader[7] = (unsigned char)(img.nx >> 24);
+    infoHeader[8] = (unsigned char)(img.ny);
+    infoHeader[9] = (unsigned char)(img.ny >> 8);
+    infoHeader[10] = (unsigned char)(img.ny >> 16);
+    infoHeader[11] = (unsigned char)(img.ny >> 24);
+
+    // Write file headers
+    file.write(reinterpret_cast<char*>(fileHeader), sizeof(fileHeader));
+    file.write(reinterpret_cast<char*>(infoHeader), sizeof(infoHeader));
+
+    // Pixel data
+    std::vector<unsigned char> padding(3, 0); // Max padding size to be added to each row
+    for (int y = img.ny - 1; y >= 0; --y) { // BMP files are stored bottom-to-top
+        for (int x = 0; x < img.nx; ++x) {
+            // Each pixel
+            size_t pixelIndex = (y * img.nx + x) * 3;
+            unsigned char pixel[3] = {
+                img.buf[pixelIndex + 2], // BMP stores pixels in BGR format
+                img.buf[pixelIndex + 1],
+                img.buf[pixelIndex]
+            };
+            file.write(reinterpret_cast<char*>(pixel), 3);
+        }
+        // Write padding for the row
+        file.write(reinterpret_cast<char*>(padding.data()), paddingAmount);
+    }
+
+    file.close();
+}
+
+// debug function to convert f32 to u8
+static void clip_image_convert_f32_to_u8(const clip_image_f32& src, clip_image_u8& dst) {
+    dst.nx = src.nx;
+    dst.ny = src.ny;
+    dst.buf.resize(3 * src.nx * src.ny);
+    for (size_t i = 0; i < src.buf.size(); ++i) {
+        dst.buf[i] = static_cast<uint8_t>(std::min(std::max(int(src.buf[i] * 255.0f), 0), 255));
+    }
+}
+#endif
+
+
+//
+// clip layers
+//
+
+struct clip_hparams {
+    int32_t image_size;
+    int32_t patch_size;
+    int32_t hidden_size;
+    int32_t n_intermediate;
+    int32_t projection_dim;
+    int32_t n_head;
+    int32_t n_layer;
+
+    float eps;
+
+    char mm_patch_merge_type[32] = "flat"; // spatial_unpad or flat (default)
+
+    int32_t image_grid_pinpoints[32];
+    int32_t image_crop_resolution;
+};
+
+struct clip_layer {
+    // attention
+    struct ggml_tensor * k_w;
+    struct ggml_tensor * k_b;
+    struct ggml_tensor * q_w;
+    struct ggml_tensor * q_b;
+    struct ggml_tensor * v_w;
+    struct ggml_tensor * v_b;
+
+    struct ggml_tensor * o_w;
+    struct ggml_tensor * o_b;
+
+    // layernorm 1
+    struct ggml_tensor * ln_1_w;
+    struct ggml_tensor * ln_1_b;
+
+    // ff
+    struct ggml_tensor * ff_i_w;
+    struct ggml_tensor * ff_i_b;
+
+    struct ggml_tensor * ff_o_w;
+    struct ggml_tensor * ff_o_b;
+
+    // layernorm 2
+    struct ggml_tensor * ln_2_w;
+    struct ggml_tensor * ln_2_b;
+};
+
+struct clip_vision_model {
+    struct clip_hparams hparams;
+
+    // embeddings
+    struct ggml_tensor * class_embedding;
+    struct ggml_tensor * patch_embeddings;
+    struct ggml_tensor * patch_bias;
+    struct ggml_tensor * position_embeddings;
+
+    struct ggml_tensor * pre_ln_w;
+    struct ggml_tensor * pre_ln_b;
+
+    std::vector<clip_layer> layers;
+
+    struct ggml_tensor * post_ln_w;
+    struct ggml_tensor * post_ln_b;
+
+    struct ggml_tensor * projection;
+
+    // LLaVA projection
+    struct ggml_tensor * mm_0_w = NULL;
+    struct ggml_tensor * mm_0_b = NULL;
+    struct ggml_tensor * mm_2_w = NULL;
+    struct ggml_tensor * mm_2_b = NULL;
+
+    struct ggml_tensor * image_newline = NULL;
+
+    // Yi type models with mlp+normalization projection
+    struct ggml_tensor * mm_1_w = NULL; // Yi type models have 0, 1, 3, 4
+    struct ggml_tensor * mm_1_b = NULL;
+    struct ggml_tensor * mm_3_w = NULL;
+    struct ggml_tensor * mm_3_b = NULL;
+    struct ggml_tensor * mm_4_w = NULL;
+    struct ggml_tensor * mm_4_b = NULL;
+
+    // MobileVLM projection
+    struct ggml_tensor * mm_model_mlp_1_w;
+    struct ggml_tensor * mm_model_mlp_1_b;
+    struct ggml_tensor * mm_model_mlp_3_w;
+    struct ggml_tensor * mm_model_mlp_3_b;
+    struct ggml_tensor * mm_model_block_1_block_0_0_w;
+    struct ggml_tensor * mm_model_block_1_block_0_1_w;
+    struct ggml_tensor * mm_model_block_1_block_0_1_b;
+    struct ggml_tensor * mm_model_block_1_block_1_fc1_w;
+    struct ggml_tensor * mm_model_block_1_block_1_fc1_b;
+    struct ggml_tensor * mm_model_block_1_block_1_fc2_w;
+    struct ggml_tensor * mm_model_block_1_block_1_fc2_b;
+    struct ggml_tensor * mm_model_block_1_block_2_0_w;
+    struct ggml_tensor * mm_model_block_1_block_2_1_w;
+    struct ggml_tensor * mm_model_block_1_block_2_1_b;
+    struct ggml_tensor * mm_model_block_2_block_0_0_w;
+    struct ggml_tensor * mm_model_block_2_block_0_1_w;
+    struct ggml_tensor * mm_model_block_2_block_0_1_b;
+    struct ggml_tensor * mm_model_block_2_block_1_fc1_w;
+    struct ggml_tensor * mm_model_block_2_block_1_fc1_b;
+    struct ggml_tensor * mm_model_block_2_block_1_fc2_w;
+    struct ggml_tensor * mm_model_block_2_block_1_fc2_b;
+    struct ggml_tensor * mm_model_block_2_block_2_0_w;
+    struct ggml_tensor * mm_model_block_2_block_2_1_w;
+    struct ggml_tensor * mm_model_block_2_block_2_1_b;
+
+    // MobileVLM_V2 projection
+    struct ggml_tensor * mm_model_mlp_0_w;
+    struct ggml_tensor * mm_model_mlp_0_b;
+    struct ggml_tensor * mm_model_mlp_2_w;
+    struct ggml_tensor * mm_model_mlp_2_b;
+    struct ggml_tensor * mm_model_peg_0_w;
+    struct ggml_tensor * mm_model_peg_0_b;
+
+    // MINICPMV projection
+    struct ggml_tensor * mm_model_pos_embed_k;
+    struct ggml_tensor * mm_model_query;
+    struct ggml_tensor * mm_model_proj;
+    struct ggml_tensor * mm_model_kv_proj;
+    struct ggml_tensor * mm_model_attn_q_w;
+    struct ggml_tensor * mm_model_attn_q_b;
+    struct ggml_tensor * mm_model_attn_k_w;
+    struct ggml_tensor * mm_model_attn_k_b;
+    struct ggml_tensor * mm_model_attn_v_w;
+    struct ggml_tensor * mm_model_attn_v_b;
+    struct ggml_tensor * mm_model_attn_o_w;
+    struct ggml_tensor * mm_model_attn_o_b;
+    struct ggml_tensor * mm_model_ln_q_w;
+    struct ggml_tensor * mm_model_ln_q_b;
+    struct ggml_tensor * mm_model_ln_kv_w;
+    struct ggml_tensor * mm_model_ln_kv_b;
+    struct ggml_tensor * mm_model_ln_post_w;
+    struct ggml_tensor * mm_model_ln_post_b;
+};
+
+struct clip_ctx {
+    bool has_text_encoder    = false;
+    bool has_vision_encoder  = false;
+    bool has_llava_projector = false;
+    bool has_minicpmv_projector = false;
+    int minicpmv_version = 2;
+
+    struct clip_vision_model vision_model;
+    projector_type proj_type = PROJECTOR_TYPE_MLP;
+
+    float image_mean[3];
+    float image_std[3];
+    bool use_gelu = false;
+    int32_t ftype = 1;
+
+    bool has_class_embedding = true;
+    bool has_pre_norm = true;
+    bool has_post_norm = false;
+    bool has_patch_bias = false;
+
+    struct gguf_context * ctx_gguf;
+    struct ggml_context * ctx_data;
+
+    std::vector<uint8_t> buf_compute_meta;
+
+    // memory buffers to evaluate the model
+    ggml_backend_buffer_t params_buffer  = NULL;
+
+    ggml_backend_t backend       = NULL;
+    ggml_gallocr_t compute_alloc = NULL;
+
+    struct clip_image_size * load_image_size;
+};
+
+static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
+    if (!ctx->has_vision_encoder) {
+        LOG_TEE("This gguf file seems to have no vision encoder\n");
+        return nullptr;
+    }
+
+    const auto & model = ctx->vision_model;
+    const auto & hparams = model.hparams;
+
+    const int image_size = hparams.image_size;
+    int image_size_width  = image_size;
+    int image_size_height = image_size;
+    if (ctx->has_minicpmv_projector) {
+        if (load_image_size == nullptr) {
+            load_image_size = clip_image_size_init();
+        }
+        LOG_TEE("%s: %d %d\n", __func__, load_image_size->width, load_image_size->height);
+        image_size_width  = load_image_size->width;
+        image_size_height = load_image_size->height;
+        if (is_inf) {
+            image_size_width  = imgs->data->nx;
+            image_size_height = imgs->data->ny;
+        }
+    }
+    const int patch_size           = hparams.patch_size;
+    const int num_patches          = ((image_size_width / patch_size) * (image_size_height / patch_size));
+    const int num_positions        = num_patches + (ctx->has_class_embedding ? 1 : 0);
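+    // e.g. for the 336x336 / patch 14 ViT used by LLaVA: num_patches = (336/14)*(336/14) = 576,
+    // num_positions = 577 when the class embedding is present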
+    const int hidden_size          = hparams.hidden_size;
+    const int n_head               = hparams.n_head;
+    const int d_head               = hidden_size / n_head;
+    int n_layer                    = hparams.n_layer;
+    const float eps                = hparams.eps;
+
+    const int batch_size = imgs->size;
+
+    if (ctx->has_llava_projector || ctx->has_minicpmv_projector) {
+        GGML_ASSERT(batch_size == 1);
+    }
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ ctx->buf_compute_meta.size(),
+        /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size);
+    ggml_set_name(inp_raw, "inp_raw");
+    ggml_set_input(inp_raw);
+
+    struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+
+    inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
+    inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
+
+    if (ctx->has_patch_bias) {
+        // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
+        inp = ggml_add(ctx0, inp, model.patch_bias);
+    }
+    struct ggml_tensor * embeddings = inp;
+    struct ggml_tensor * pos_embed = nullptr;
+
+    if (ctx->has_llava_projector) {
+        // concat class_embeddings and patch_embeddings
+        if (ctx->has_class_embedding) {
+            embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
+            ggml_set_name(embeddings, "embeddings");
+            ggml_set_input(embeddings);
+            embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
+                    embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
+            embeddings = ggml_acc(ctx0, embeddings, inp,
+                    embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
+        }
+    }
+
+    struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
+    ggml_set_name(positions, "positions");
+    ggml_set_input(positions);
+
+    embeddings =
+        ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
+
+    if (ctx->has_minicpmv_projector) {
+        int pos_w = image_size_width/patch_size;
+        int pos_h = image_size_height/patch_size;
+        if (ctx->minicpmv_version == 2) {
+            pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1);
+        }
+        else if (ctx->minicpmv_version == 3) {
+            pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
+        }
+        ggml_set_name(pos_embed, "pos_embed");
+        ggml_set_input(pos_embed);
+    }
+
+    // pre-layernorm
+    if (ctx->has_pre_norm) {
+        embeddings = ggml_norm(ctx0, embeddings, eps);
+        ggml_set_name(embeddings, "pre_ln");
+
+        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
+    }
+
+    // loop over layers
+    if (ctx->has_minicpmv_projector) {
+        n_layer += 1;
+    }
+    for (int il = 0; il < n_layer - 1; il++) {
+        struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
+
+        //const size_t nb_q_w = model.layers[il].q_w->nb[0];
+
+        // layernorm1
+        {
+            cur = ggml_norm(ctx0, cur, eps);
+
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w),
+                           model.layers[il].ln_1_b);
+        }
+
+        // self-attention
+        {
+
+            struct ggml_tensor * Q =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
+
+            Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
+            Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
+            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+            Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
+
+            struct ggml_tensor * K =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
+
+            K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
+            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+            K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
+
+            struct ggml_tensor * V =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
+
+            V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
+            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+            V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
+
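+            // attention weights: softmax(Q K^T); the 1/sqrt(d_head) scale was already folded into Q above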
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            KQ = ggml_soft_max_inplace(ctx0, KQ);
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
+            KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
+            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size);
+        }
+
+        // attention output
+        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
+
+        // re-add the layer input, i.e., the residual connection
+        cur = ggml_add(ctx0, cur, embeddings);
+
+        embeddings = cur; // embeddings = residual, cur = hidden_states
+
+        // layernorm2
+        {
+            cur = ggml_norm(ctx0, cur, eps);
+
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
+        }
+
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+
+        if (ctx->use_gelu) {
+            cur = ggml_gelu_inplace(ctx0, cur);
+        } else {
+            cur = ggml_gelu_quick_inplace(ctx0, cur);
+        }
+
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
+
+        // residual 2
+        cur = ggml_add(ctx0, embeddings, cur);
+
+        embeddings = cur;
+    }
+
+    // post-layernorm
+    if (ctx->has_post_norm) {
+        embeddings = ggml_norm(ctx0, embeddings, eps);
+        ggml_set_name(embeddings, "post_ln");
+
+        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
+    }
+
+    // llava projector
+    if (ctx->has_llava_projector) {
+        embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
+
+        struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
+        ggml_set_name(patches, "patches");
+        ggml_set_input(patches);
+
+        // shape [1, 576, 1024]
+        // ne is whcn, ne = [1024, 576, 1, 1]
+        embeddings = ggml_get_rows(ctx0, embeddings, patches);
+
+        // print_tensor_info(embeddings, "embeddings");
+
+        // llava projector
+        if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
+            embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+            embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
+
+            embeddings = ggml_gelu(ctx0, embeddings);
+            embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
+            embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+        }
+        else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
+            embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+            embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
+            // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
+            // First LayerNorm
+            embeddings = ggml_norm(ctx0, embeddings, eps);
+            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_1_w),
+                                model.mm_1_b);
+
+            // GELU activation
+            embeddings = ggml_gelu(ctx0, embeddings);
+
+            // Second linear layer
+            embeddings = ggml_mul_mat(ctx0, model.mm_3_w, embeddings);
+            embeddings = ggml_add(ctx0, embeddings, model.mm_3_b);
+
+            // Second LayerNorm
+            embeddings = ggml_norm(ctx0, embeddings, eps);
+            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_4_w),
+                                model.mm_4_b);
+        }
+        else if (ctx->proj_type == PROJECTOR_TYPE_LDP) {
+            // MobileVLM projector
+            int n_patch = 24;
+            struct ggml_tensor * mlp_1 = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, embeddings);
+            mlp_1 = ggml_add(ctx0, mlp_1, model.mm_model_mlp_1_b);
+            mlp_1 = ggml_gelu(ctx0, mlp_1);
+            struct ggml_tensor * mlp_3 = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, mlp_1);
+            mlp_3 = ggml_add(ctx0, mlp_3, model.mm_model_mlp_3_b);
+            // mlp_3 shape = [1, 576, 2048], ne = [2048, 576, 1, 1]
+
+            // block 1
+            struct ggml_tensor * block_1 = nullptr;
+            {
+                // transpose from [1, 576, 2048] --> [1, 2048, 576] --> [1, 2048, 24, 24]
+                mlp_3 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_3, 1, 0, 2, 3));
+                mlp_3 = ggml_reshape_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]);
+                // stride = 1, padding = 1, bias is nullptr
+                block_1 = ggml_conv_depthwise_2d(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1);
+
+                // layer norm
+                // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
+                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3));
+                // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
+                block_1 = ggml_norm(ctx0, block_1, eps);
+                block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_0_1_w), model.mm_model_block_1_block_0_1_b);
+                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
+
+                // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
+                // hardswish
+                struct ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1);
+
+                block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0);
+                // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
+                // pointwise conv
+                block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
+                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc1_w, block_1);
+                block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc1_b);
+                block_1 = ggml_relu(ctx0, block_1);
+                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc2_w, block_1);
+                block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc2_b);
+                block_1 = ggml_hardsigmoid(ctx0, block_1);
+                // block_1_hw shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1], block_1 shape = [1, 2048], ne = [2048, 1, 1, 1]
+                block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]);
+                block_1 = ggml_mul(ctx0, block_1_hw, block_1);
+
+                int w = block_1->ne[0], h = block_1->ne[1];
+                block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
+                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
+
+                // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
+                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_2_0_w, block_1);
+                block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
+
+                // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
+                block_1 = ggml_norm(ctx0, block_1, eps);
+                block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_2_1_w), model.mm_model_block_1_block_2_1_b);
+                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
+                // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
+                // residual
+                block_1 = ggml_add(ctx0, mlp_3, block_1);
+            }
+
+            // block_2
+            {
+                // stride = 2
+                block_1 = ggml_conv_depthwise_2d(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1);
+
+                // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
+                // layer norm
+                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3));
+                // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1]
+                block_1 = ggml_norm(ctx0, block_1, eps);
+                block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_0_1_w), model.mm_model_block_2_block_0_1_b);
+                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
+                // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
+                // hardswish
+                struct ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1);
+
+                // not sure the parameters are right for global average pooling
+                block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0);
+                // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
+                // pointwise conv
+                block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
+                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc1_w, block_1);
+                block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc1_b);
+                block_1 = ggml_relu(ctx0, block_1);
+                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc2_w, block_1);
+                block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc2_b);
+                block_1 = ggml_hardsigmoid(ctx0, block_1);
+
+                // block_1_hw shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1], block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
+                block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]);
+                block_1 = ggml_mul(ctx0, block_1_hw, block_1);
+
+                int w = block_1->ne[0], h = block_1->ne[1];
+                block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
+                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
+                // block_1 shape = [1, 12*12, 2048], ne = [12*12, 2048, 1]
+                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_2_0_w, block_1);
+                block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
+
+
+                // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1]
+                block_1 = ggml_norm(ctx0, block_1, eps);
+                block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_2_1_w), model.mm_model_block_2_block_2_1_b);
+                block_1 = ggml_reshape_3d(ctx0, block_1, block_1->ne[0], block_1->ne[1] * block_1->ne[2], block_1->ne[3]);
+                // block_1 shape = [1, 144, 2048], ne = [2048, 144, 1]
+            }
+            embeddings = block_1;
+        }
+        else if (ctx->proj_type == PROJECTOR_TYPE_LDPV2)
+        {
+            int n_patch = 24;
+            struct ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
+            mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b);
+            mlp_0 = ggml_gelu(ctx0, mlp_0);
+            struct ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0);
+            mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b);
+            // mlp_2 ne = [2048, 576, 1, 1]
+            // AVG pool layer 2*2, strides = 2
+            mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 0, 2, 3));
+            // mlp_2 ne = [576, 2048, 1, 1]
+            mlp_2 = ggml_reshape_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]);
+            // mlp_2 ne [24, 24, 2048, 1]
+            mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0);
+            // weight ne = [3, 3, 2048, 1]
+            struct ggml_tensor * peg_0 = ggml_conv_depthwise_2d(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1);
+            peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3));
+            peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b);
+            mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3));
+            peg_0 = ggml_add(ctx0, peg_0, mlp_2);
+            peg_0 = ggml_reshape_3d(ctx0, peg_0, peg_0->ne[0], peg_0->ne[1] * peg_0->ne[2], peg_0->ne[3]);
+            embeddings = peg_0;
+        }
+        else {
+            GGML_ABORT("fatal error");
+        }
+    }
+    // minicpmv projector
+    else if (ctx->has_minicpmv_projector)
+    {
+        if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
+            struct ggml_tensor * q = model.mm_model_query;
+            { // layernorm
+                q = ggml_norm(ctx0, q, eps);
+                q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
+            }
+            struct ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings);
+            { // layernorm
+                v = ggml_norm(ctx0, v, eps);
+                v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b);
+            }
+            struct ggml_tensor * k;
+            { // position
+                // q = ggml_add(ctx0, q, model.mm_model_pos_embed);
+                k = ggml_add(ctx0, v, pos_embed);
+            }
+
+            { // attention
+                int hidden_size = 4096;
+                const int d_head = 128;
+                int n_head = hidden_size/d_head;
+                int num_query = 96;
+                if (ctx->minicpmv_version == 2) {
+                    hidden_size = 4096;
+                    n_head = hidden_size/d_head;
+                    num_query = 96;
+                }
+                else if (ctx->minicpmv_version == 3) {
+                    hidden_size = 3584;
+                    n_head = hidden_size/d_head;
+                    num_query = 64;
+                }
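+                // e.g. version 3: n_head = 3584/128 = 28 attention heads over num_query = 64 query tokens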
+
+                struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
+                Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
+                struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b);
+                struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b);
+                // permute
+                Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size);
+                Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+                Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size);
+                K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
+                K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+                K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
+                V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
+                V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+                V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
+                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+                KQ = ggml_soft_max_inplace(ctx0, KQ);
+                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
+                KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
+                KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+                KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size);
+
+                embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
+            }
+            { // layernorm
+                embeddings = ggml_norm(ctx0, embeddings, eps);
+                embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_post_w), model.mm_model_ln_post_b);
+            }
+            embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
+        }
+        else {
+            GGML_ASSERT(false);
+        }
+    }
+
+    // build the graph
+    ggml_build_forward_expand(gf, embeddings);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
+// read and create ggml_context containing the tensors and their data
+struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
+    struct ggml_context * meta = NULL;
+
+    struct gguf_init_params params = {
+        /*.no_alloc = */ true,
+        /*.ctx      = */ &meta,
+    };
+
+    struct gguf_context * ctx = gguf_init_from_file(fname, params);
+    if (!ctx) {
+        throw std::runtime_error(format("%s: failed to load CLIP model from %s. Does this file exist?\n", __func__, fname));
+    }
+
+    if (verbosity >= 1) {
+        const int n_tensors = gguf_get_n_tensors(ctx);
+        const int n_kv = gguf_get_n_kv(ctx);
+        const int ftype = get_u32(ctx, KEY_FTYPE);
+        const std::string ftype_str = get_ftype(ftype);
+        const int idx_desc = get_key_idx(ctx, KEY_DESCRIPTION);
+        const std::string description = gguf_get_val_str(ctx, idx_desc);
+        const int idx_name = gguf_find_key(ctx, KEY_NAME);
+        if (idx_name != -1) { // make name optional temporarily as some of the uploaded models are missing it due to a bug
+            const std::string name = gguf_get_val_str(ctx, idx_name);
+            LOG_TEE("%s: model name:   %s\n", __func__, name.c_str());
+        }
+        LOG_TEE("%s: description:  %s\n", __func__, description.c_str());
+        LOG_TEE("%s: GGUF version: %d\n", __func__, gguf_get_version(ctx));
+        LOG_TEE("%s: alignment:    %zu\n", __func__, gguf_get_alignment(ctx));
+        LOG_TEE("%s: n_tensors:    %d\n", __func__, n_tensors);
+        LOG_TEE("%s: n_kv:         %d\n", __func__, n_kv);
+        LOG_TEE("%s: ftype:        %s\n", __func__, ftype_str.c_str());
+        LOG_TEE("\n");
+    }
+    const int n_tensors = gguf_get_n_tensors(ctx);
+
+    // kv
+    const int n_kv = gguf_get_n_kv(ctx);
+    LOG_TEE("%s: loaded meta data with %d key-value pairs and %d tensors from %s\n",
+        __func__, n_kv, n_tensors, fname);
+    {
+        std::map<enum ggml_type, uint32_t> n_type;
+
+        for (int i = 0; i < n_tensors; i++) {
+            enum ggml_type type = gguf_get_tensor_type(ctx, i);
+
+            n_type[type]++;
+        }
+
+        LOG_TEE("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__);
+        for (int i = 0; i < n_kv; i++) {
+            const char * name           = gguf_get_key(ctx, i);
+            const enum gguf_type type   = gguf_get_kv_type(ctx, i);
+            const std::string type_name =
+                type == GGUF_TYPE_ARRAY
+                ? format("%s[%s,%d]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(ctx, i)), gguf_get_arr_n(ctx, i))
+                : gguf_type_name(type);
+
+            std::string value          = gguf_kv_to_str(ctx, i);
+            const size_t MAX_VALUE_LEN = 40;
+            if (value.size() > MAX_VALUE_LEN) {
+                value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str());
+            }
+            replace_all(value, "\n", "\\n");
+
+            LOG_TEE("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str());
+        }
+
+        // print type counts
+        for (auto & kv : n_type) {
+            if (kv.second == 0) {
+                continue;
+            }
+
+            LOG_TEE("%s: - type %4s: %4d tensors\n", __func__, ggml_type_name(kv.first), kv.second);
+        }
+    }
+
+    // data
+    size_t model_size = 0;
+    {
+        for (int i = 0; i < n_tensors; ++i) {
+            const char * name = gguf_get_tensor_name(ctx, i);
+            const size_t offset = gguf_get_tensor_offset(ctx, i);
+            enum ggml_type type = gguf_get_tensor_type(ctx, i);
+            struct ggml_tensor * cur = ggml_get_tensor(meta, name);
+            size_t tensor_size = ggml_nbytes(cur);
+            model_size += tensor_size;
+            if (verbosity >= 3) {
+                LOG_TEE("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, offset=%zu, shape:[%" PRIu64 ", %" PRIu64 ", %" PRIu64 ", %" PRIu64 "], type = %s\n",
+                       __func__, i, ggml_n_dims(cur), cur->name, tensor_size, offset, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_type_name(type));
+            }
+        }
+    }
+
+    clip_ctx * new_clip = new clip_ctx{};
+
+    // update projector type
+    {
+        int idx = gguf_find_key(ctx, KEY_PROJ_TYPE);
+        if (idx != -1) {
+            const std::string proj_type = gguf_get_val_str(ctx, idx);
+            new_clip->proj_type = clip_projector_type_from_string(proj_type);
+        } else {
+            new_clip->proj_type = PROJECTOR_TYPE_MLP;
+        }
+
+        if (new_clip->proj_type == PROJECTOR_TYPE_MLP) {
+            if (gguf_find_tensor(ctx, format(TN_LLAVA_PROJ, 3, "weight").c_str()) != -1) {
+                new_clip->proj_type = PROJECTOR_TYPE_MLP_NORM;
+            }
+        }
+    }
+
+#ifdef GGML_USE_CUDA
+    new_clip->backend = ggml_backend_cuda_init(0);
+    LOG_TEE("%s: CLIP using CUDA backend\n", __func__);
+#endif
+
+#ifdef GGML_USE_METAL
+    new_clip->backend = ggml_backend_metal_init();
+    LOG_TEE("%s: CLIP using Metal backend\n", __func__);
+#endif
+
+#ifdef GGML_USE_CANN
+    new_clip->backend = ggml_backend_cann_init(0);
+    LOG_TEE("%s: CLIP using CANN backend\n", __func__);
+#endif
+
+#ifdef GGML_USE_VULKAN
+    new_clip->backend = ggml_backend_vk_init(0);
+    LOG_TEE("%s: CLIP using Vulkan backend\n", __func__);
+#endif
+
+    if (!new_clip->backend) {
+        new_clip->backend = ggml_backend_cpu_init();
+        LOG_TEE("%s: CLIP using CPU backend\n", __func__);
+    }
+
+    // model size and capabilities
+    {
+        int idx = get_key_idx(ctx, KEY_HAS_TEXT_ENC);
+        new_clip->has_text_encoder = gguf_get_val_bool(ctx, idx);
+
+        idx = get_key_idx(ctx, KEY_HAS_VIS_ENC);
+        new_clip->has_vision_encoder = gguf_get_val_bool(ctx, idx);
+
+        idx = gguf_find_key(ctx, KEY_HAS_LLAVA_PROJ);
+        if (idx != -1) {
+            new_clip->has_llava_projector = gguf_get_val_bool(ctx, idx);
+        }
+
+        idx = gguf_find_key(ctx, KEY_HAS_MINICPMV_PROJ);
+        if (idx != -1) {
+            new_clip->has_minicpmv_projector = gguf_get_val_bool(ctx, idx);
+        }
+
+        idx = gguf_find_key(ctx, KEY_MINICPMV_VERSION);
+        if (idx != -1) {
+            new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx);
+        }
+
+        // GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search
+
+        GGML_ASSERT(new_clip->has_vision_encoder);
+        GGML_ASSERT(!new_clip->has_text_encoder);
+
+        idx = get_key_idx(ctx, KEY_USE_GELU);
+        new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
+
+        if (verbosity >= 1) {
+            LOG_TEE("%s: text_encoder:   %d\n", __func__, new_clip->has_text_encoder);
+            LOG_TEE("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
+            LOG_TEE("%s: llava_projector:  %d\n", __func__, new_clip->has_llava_projector);
+            LOG_TEE("%s: minicpmv_projector:  %d\n", __func__, new_clip->has_minicpmv_projector);
+            LOG_TEE("%s: model size:     %.2f MB\n", __func__, model_size / 1024.0 / 1024.0);
+            LOG_TEE("%s: metadata size:  %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0);
+        }
+    }
+
+    LOG_TEE("%s: params backend buffer size = % 6.2f MB (%i tensors)\n", __func__, model_size / (1024.0 * 1024.0), n_tensors);
+
+    // load tensors
+    {
+        std::vector<uint8_t> read_buf;
+        struct ggml_init_params params = {
+            /*.mem_size =*/ (n_tensors + 1) * ggml_tensor_overhead(),
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc =*/ true,
+        };
+
+        new_clip->ctx_data = ggml_init(params);
+        if (!new_clip->ctx_data) {
+            LOG_TEE("%s: ggml_init() failed\n", __func__);
+            clip_free(new_clip);
+            gguf_free(ctx);
+            return nullptr;
+        }
+
+#ifdef _WIN32
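+        // convert the UTF-8 path to UTF-16 so model files with non-ASCII paths open correctly on Windows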
+        int wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, NULL, 0);
+        if (!wlen) {
+            return NULL;
+        }
+        wchar_t * wbuf = (wchar_t *) malloc(wlen * sizeof(wchar_t));
+        wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, wbuf, wlen);
+        if (!wlen) {
+            free(wbuf);
+            return NULL;
+        }
+#if __GLIBCXX__
+        int fd = _wopen(wbuf, _O_RDONLY | _O_BINARY);
+        __gnu_cxx::stdio_filebuf<char> buffer(fd, std::ios_base::in);
+        std::istream fin(&buffer);
+#else // MSVC
+        // unused in our current build
+        auto fin = std::ifstream(wbuf, std::ios::binary);
+#endif
+        free(wbuf);
+#else
+        auto fin = std::ifstream(fname, std::ios::binary);
+#endif
+        if (!fin) {
+            LOG_TEE("cannot open model file for loading tensors\n");
+            clip_free(new_clip);
+            gguf_free(ctx);
+            return nullptr;
+        }
+
+        // add tensors to context
+        for (int i = 0; i < n_tensors; ++i) {
+            const char * name = gguf_get_tensor_name(ctx, i);
+            struct ggml_tensor * t = ggml_get_tensor(meta, name);
+            struct ggml_tensor * cur = ggml_dup_tensor(new_clip->ctx_data, t);
+            ggml_set_name(cur, name);
+        }
+
+        // alloc memory and offload data
+        new_clip->params_buffer = ggml_backend_alloc_ctx_tensors(new_clip->ctx_data, new_clip->backend);
+        for (int i = 0; i < n_tensors; ++i) {
+            const char * name = gguf_get_tensor_name(ctx, i);
+            struct ggml_tensor * cur = ggml_get_tensor(new_clip->ctx_data, name);
+            const size_t offset = gguf_get_data_offset(ctx) + gguf_get_tensor_offset(ctx, i);
+            fin.seekg(offset, std::ios::beg);
+            if (!fin) {
+                LOG_TEE("%s: failed to seek for tensor %s\n", __func__, name);
+                clip_free(new_clip);
+                gguf_free(ctx);
+                return nullptr;
+            }
+            int num_bytes = ggml_nbytes(cur);
+            if (ggml_backend_buffer_is_host(new_clip->params_buffer)) {
+                // for the CPU and Metal backend, we can read directly into the tensor
+                fin.read(reinterpret_cast<char *>(cur->data), num_bytes);
+            } else {
+                // read into a temporary buffer first, then copy to device memory
+                read_buf.resize(num_bytes);
+                fin.read(reinterpret_cast<char *>(read_buf.data()), num_bytes);
+                ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
+            }
+        }
+#if defined(_WIN32) && defined(__GLIBCXX__)
+        close(fd);
+#else
+        fin.close();
+#endif
+    }
+
+    // vision model
+    if (new_clip->has_vision_encoder) {
+        // load vision model
+        auto & vision_model = new_clip->vision_model;
+        auto & hparams = vision_model.hparams;
+        hparams.hidden_size    = get_u32(ctx, format(KEY_N_EMBD, "vision"));
+        hparams.n_head         = get_u32(ctx, format(KEY_N_HEAD, "vision"));
+        hparams.n_intermediate = get_u32(ctx, format(KEY_N_FF, "vision"));
+        hparams.n_layer        = get_u32(ctx, format(KEY_N_BLOCK, "vision"));
+        hparams.image_size     = get_u32(ctx, KEY_IMAGE_SIZE);
+        hparams.patch_size     = get_u32(ctx, KEY_PATCH_SIZE);
+        hparams.projection_dim = get_u32(ctx, format(KEY_PROJ_DIM, "vision"));
+        hparams.eps            = get_f32(ctx, format(KEY_LAYER_NORM_EPS, "vision"));
+
+        try {
+            int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS);
+            int n = gguf_get_arr_n(ctx, idx);
+            const int32_t * pinpoints = (const int32_t *)gguf_get_arr_data(ctx, idx);
+            for (int i = 0; i < 32 && i < n && pinpoints[i] != 0; ++i) {
+                hparams.image_grid_pinpoints[i] = pinpoints[i];
+            }
+            if (n < 32)
+                hparams.image_grid_pinpoints[n] = 0;
+        } catch (std::runtime_error & /*e*/) {
+            hparams.image_grid_pinpoints[0] = 0;
+        }
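+        // image_grid_pinpoints is a zero-terminated list of (width, height)
+        // pairs (up to 16 pairs in the 32-entry array); llava-1.6 "anyres"
+        // preprocessing later picks the best-fitting pair via select_best_resolution()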
+
+        try {
+            int idx = get_key_idx(ctx, KEY_MM_PATCH_MERGE_TYPE);
+            strcpy(hparams.mm_patch_merge_type, gguf_get_val_str(ctx, idx));
+        } catch (std::runtime_error & /*e*/) {
+            strcpy(hparams.mm_patch_merge_type, "flat");
+        }
+
+        try {
+            hparams.image_crop_resolution = get_u32(ctx, KEY_IMAGE_CROP_RESOLUTION); // llava-1.6
+        } catch(const std::exception& /*e*/) {
+            hparams.image_crop_resolution = hparams.image_size;
+        }
+
+        int idx_mean = get_key_idx(ctx, KEY_IMAGE_MEAN);
+        int idx_std  = get_key_idx(ctx, KEY_IMAGE_STD);
+
+        const float * mean_data = (const float *)gguf_get_arr_data(ctx, idx_mean);
+        const float * std_data  = (const float *)gguf_get_arr_data(ctx, idx_std);
+
+        for (int i = 0; i < 3; ++i) {
+            new_clip->image_mean[i] = mean_data[i];
+            new_clip->image_std[i]  = std_data[i];
+        }
+
+        if (verbosity >= 2) {
+            LOG_TEE("\n%s: vision model hparams\n", __func__);
+            LOG_TEE("image_size         %d\n", hparams.image_size);
+            LOG_TEE("patch_size         %d\n", hparams.patch_size);
+            LOG_TEE("v_hidden_size      %d\n", hparams.hidden_size);
+            LOG_TEE("v_n_intermediate   %d\n", hparams.n_intermediate);
+            LOG_TEE("v_projection_dim   %d\n", hparams.projection_dim);
+            LOG_TEE("v_n_head           %d\n", hparams.n_head);
+            LOG_TEE("v_n_layer          %d\n", hparams.n_layer);
+            LOG_TEE("v_eps              %f\n", hparams.eps);
+            LOG_TEE("v_image_mean       %f %f %f\n", new_clip->image_mean[0], new_clip->image_mean[1], new_clip->image_mean[2]);
+            LOG_TEE("v_image_std        %f %f %f\n", new_clip->image_std[0], new_clip->image_std[1], new_clip->image_std[2]);
+            LOG_TEE("v_image_grid_pinpoints: ");
+            for (int i = 0; i < 32 && (hparams.image_grid_pinpoints[i] != 0); ++i) {
+                LOG_TEE("%d ", hparams.image_grid_pinpoints[i]);
+            }
+            LOG_TEE("\n");
+            LOG_TEE("v_mm_patch_merge_type: %s\n", hparams.mm_patch_merge_type);
+
+        }
+
+        try {
+            vision_model.class_embedding  = get_tensor(new_clip->ctx_data, TN_CLASS_EMBD);
+            new_clip->has_class_embedding = true;
+        } catch (const std::exception& /*e*/) {
+            new_clip->has_class_embedding = false;
+        }
+
+        try {
+            vision_model.pre_ln_w  = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "weight"));
+            vision_model.pre_ln_b  = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "bias"));
+            new_clip->has_pre_norm = true;
+        } catch (std::exception & /*e*/) {
+            new_clip->has_pre_norm = false;
+        }
+
+        try {
+            vision_model.post_ln_w  = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "weight"));
+            vision_model.post_ln_b  = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "bias"));
+            new_clip->has_post_norm = true;
+        } catch (std::exception & /*e*/) {
+            new_clip->has_post_norm = false;
+        }
+
+        try {
+            vision_model.patch_bias = get_tensor(new_clip->ctx_data, TN_PATCH_BIAS);
+            new_clip->has_patch_bias = true;
+        } catch (std::exception & /*e*/) {
+            new_clip->has_patch_bias = false;
+        }
+
+        try {
+            vision_model.patch_embeddings    = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
+            vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v"));
+        } catch(const std::exception& /*e*/) {
+            LOG_TEE("%s: failed to load vision model tensors\n", __func__);
+        }
+
+        // LLaVA projection
+        if (new_clip->proj_type == PROJECTOR_TYPE_MLP || new_clip->proj_type == PROJECTOR_TYPE_MLP_NORM) {
+            vision_model.mm_0_w              = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight"));
+            vision_model.mm_0_b              = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias"));
+            try {
+                // Yi-type llava
+                vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 1, "weight"));
+                vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 1, "bias"));
+            } catch (std::runtime_error & /*e*/) { }
+            try {
+                // missing in Yi-type llava
+                vision_model.mm_2_w              = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight"));
+                vision_model.mm_2_b              = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias"));
+            } catch (std::runtime_error & /*e*/) { }
+            try {
+                // Yi-type llava
+                vision_model.mm_3_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 3, "weight"));
+                vision_model.mm_3_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 3, "bias"));
+            } catch (std::runtime_error & /*e*/) { }
+            try {
+                // Yi-type llava
+                vision_model.mm_4_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 4, "weight"));
+                vision_model.mm_4_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 4, "bias"));
+            } catch (std::runtime_error & /*e*/) { }
+            try {
+                vision_model.image_newline = get_tensor(new_clip->ctx_data, TN_IMAGE_NEWLINE);
+                // LOG_TEE("%s: image_newline tensor (llava-1.6) found\n", __func__);
+            } catch (std::runtime_error & /*e*/) { }
+        } else if (new_clip->proj_type == PROJECTOR_TYPE_LDP) {
+            // MobileVLM projection
+            vision_model.mm_model_mlp_1_w               = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 1, "weight"));
+            vision_model.mm_model_mlp_1_b               = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 1, "bias"));
+            vision_model.mm_model_mlp_3_w               = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 3, "weight"));
+            vision_model.mm_model_mlp_3_b               = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 3, "bias"));
+            vision_model.mm_model_block_1_block_0_0_w   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 0, "0.weight"));
+            vision_model.mm_model_block_1_block_0_1_w   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.weight"));
+            vision_model.mm_model_block_1_block_0_1_b   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.bias"));
+            vision_model.mm_model_block_1_block_1_fc1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.weight"));
+            vision_model.mm_model_block_1_block_1_fc1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.bias"));
+            vision_model.mm_model_block_1_block_1_fc2_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.weight"));
+            vision_model.mm_model_block_1_block_1_fc2_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.bias"));
+            vision_model.mm_model_block_1_block_2_0_w   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 2, "0.weight"));
+            vision_model.mm_model_block_1_block_2_1_w   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.weight"));
+            vision_model.mm_model_block_1_block_2_1_b   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.bias"));
+            vision_model.mm_model_block_2_block_0_0_w   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 0, "0.weight"));
+            vision_model.mm_model_block_2_block_0_1_w   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.weight"));
+            vision_model.mm_model_block_2_block_0_1_b   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.bias"));
+            vision_model.mm_model_block_2_block_1_fc1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.weight"));
+            vision_model.mm_model_block_2_block_1_fc1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.bias"));
+            vision_model.mm_model_block_2_block_1_fc2_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.weight"));
+            vision_model.mm_model_block_2_block_1_fc2_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.bias"));
+            vision_model.mm_model_block_2_block_2_0_w   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "0.weight"));
+            vision_model.mm_model_block_2_block_2_1_w   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.weight"));
+            vision_model.mm_model_block_2_block_2_1_b   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.bias"));
+        } else if (new_clip->proj_type == PROJECTOR_TYPE_LDPV2) {
+            // MobileVLM_V2 projection
+            vision_model.mm_model_mlp_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 0, "weight"));
+            vision_model.mm_model_mlp_0_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 0, "bias"));
+            vision_model.mm_model_mlp_2_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 2, "weight"));
+            vision_model.mm_model_mlp_2_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 2, "bias"));
+            vision_model.mm_model_peg_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_PEG, 0, "weight"));
+            vision_model.mm_model_peg_0_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_PEG, 0, "bias"));
+        }
+        else if (new_clip->proj_type == PROJECTOR_TYPE_RESAMPLER) {
+            // vision_model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD);
+            vision_model.mm_model_pos_embed_k = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD_K);
+            vision_model.mm_model_query = get_tensor(new_clip->ctx_data, TN_MINICPMV_QUERY);
+            vision_model.mm_model_proj = get_tensor(new_clip->ctx_data, TN_MINICPMV_PROJ);
+            vision_model.mm_model_kv_proj = get_tensor(new_clip->ctx_data, TN_MINICPMV_KV_PROJ);
+            vision_model.mm_model_attn_q_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "q", "weight"));
+            vision_model.mm_model_attn_k_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "k", "weight"));
+            vision_model.mm_model_attn_v_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "v", "weight"));
+            vision_model.mm_model_attn_q_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "q", "bias"));
+            vision_model.mm_model_attn_k_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "k", "bias"));
+            vision_model.mm_model_attn_v_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "v", "bias"));
+            vision_model.mm_model_attn_o_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "out", "weight"));
+            vision_model.mm_model_attn_o_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "out", "bias"));
+            vision_model.mm_model_ln_q_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "q", "weight"));
+            vision_model.mm_model_ln_q_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "q", "bias"));
+            vision_model.mm_model_ln_kv_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "kv", "weight"));
+            vision_model.mm_model_ln_kv_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "kv", "bias"));
+            vision_model.mm_model_ln_post_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "weight"));
+            vision_model.mm_model_ln_post_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "bias"));
+        } else {
+            std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type];
+            throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
+        }
+
+        vision_model.layers.resize(hparams.n_layer);
+
+        for (int il = 0; il < hparams.n_layer; ++il) {
+            auto & layer = vision_model.layers[il];
+            layer.k_w    = get_tensor(new_clip->ctx_data, format(TN_ATTN_K,      "v", il, "weight"));
+            layer.q_w    = get_tensor(new_clip->ctx_data, format(TN_ATTN_Q,      "v", il, "weight"));
+            layer.v_w    = get_tensor(new_clip->ctx_data, format(TN_ATTN_V,      "v", il, "weight"));
+            layer.o_w    = get_tensor(new_clip->ctx_data, format(TN_ATTN_OUTPUT, "v", il, "weight"));
+            layer.ln_1_w = get_tensor(new_clip->ctx_data, format(TN_LN_1,        "v", il, "weight"));
+            layer.ln_2_w = get_tensor(new_clip->ctx_data, format(TN_LN_2,        "v", il, "weight"));
+            layer.ff_i_w = get_tensor(new_clip->ctx_data, format(TN_FFN_DOWN,    "v", il, "weight"));
+            layer.ff_o_w = get_tensor(new_clip->ctx_data, format(TN_FFN_UP,      "v", il, "weight"));
+            layer.k_b    = get_tensor(new_clip->ctx_data, format(TN_ATTN_K,      "v", il, "bias"));
+            layer.q_b    = get_tensor(new_clip->ctx_data, format(TN_ATTN_Q,      "v", il, "bias"));
+            layer.v_b    = get_tensor(new_clip->ctx_data, format(TN_ATTN_V,      "v", il, "bias"));
+            layer.o_b    = get_tensor(new_clip->ctx_data, format(TN_ATTN_OUTPUT, "v", il, "bias"));
+            layer.ln_1_b = get_tensor(new_clip->ctx_data, format(TN_LN_1,        "v", il, "bias"));
+            layer.ln_2_b = get_tensor(new_clip->ctx_data, format(TN_LN_2,        "v", il, "bias"));
+            layer.ff_i_b = get_tensor(new_clip->ctx_data, format(TN_FFN_DOWN,    "v", il, "bias"));
+            layer.ff_o_b = get_tensor(new_clip->ctx_data, format(TN_FFN_UP,      "v", il, "bias"));
+        }
+    }
+
+    ggml_free(meta);
+
+    new_clip->ctx_gguf = ctx;
+
+    // measure mem requirement and allocate
+    {
+        new_clip->buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
+        new_clip->compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(new_clip->backend));
+        clip_image_f32_batch batch{};
+        batch.size = 1;
+        ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch, nullptr, false);
+        ggml_gallocr_reserve(new_clip->compute_alloc, gf);
+        size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0);
+        LOG_TEE("%s: compute allocated memory: %.2f MB\n", __func__, compute_memory_buffer_size /1024.0/1024.0);
+    }
+
+    return new_clip;
+}
+
+void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size) {
+    ctx_clip->load_image_size = load_image_size;
+}
+
+struct clip_image_size * clip_image_size_init() {
+    struct clip_image_size * load_image_size = new struct clip_image_size();
+    load_image_size->width = 448;
+    load_image_size->height = 448;
+    return load_image_size;
+}
+
+struct clip_image_u8 * clip_image_u8_init() {
+    return new clip_image_u8();
+}
+
+struct clip_image_f32 * clip_image_f32_init() {
+    return new clip_image_f32();
+}
+
+void clip_image_u8_free(struct clip_image_u8  * img) { delete img; }
+void clip_image_f32_free(struct clip_image_f32 * img) { delete img; }
+void clip_image_u8_batch_free(struct clip_image_u8_batch  * batch) {
+    if (batch->size > 0) {
+        delete[] batch->data;
+        batch->size = 0;
+    }
+}
+void clip_image_f32_batch_free(struct clip_image_f32_batch  * batch) {
+    if (batch->size > 0) {
+        delete[] batch->data;
+        batch->size = 0;
+    }
+}
+
+static void build_clip_img_from_data(const stbi_uc * data, int nx, int ny, clip_image_u8 * img) {
+    img->nx = nx;
+    img->ny = ny;
+    img->buf.resize(3 * nx * ny);
+    memcpy(img->buf.data(), data, img->buf.size());
+}
+
+bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
+    int nx, ny, nc;
+    auto * data = stbi_load(fname, &nx, &ny, &nc, 3);
+    if (!data) {
+        LOG_TEE("%s: failed to load image '%s'\n", __func__, fname);
+        return false;
+    }
+    build_clip_img_from_data(data, nx, ny, img);
+    stbi_image_free(data);
+    return true;
+}
+
+bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img) {
+    int nx, ny, nc;
+    auto * data = stbi_load_from_memory(bytes, bytes_length, &nx, &ny, &nc, 3);
+    if (!data) {
+        LOG_TEE("%s: failed to decode image bytes\n", __func__);
+        return false;
+    }
+    build_clip_img_from_data(data, nx, ny, img);
+    stbi_image_free(data);
+    return true;
+}
+
+// Linear interpolation between two points
+inline float clip_lerp(float s, float e, float t) {
+    return s + (e - s) * t;
+}
+// Bilinear resize function
+static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) {
+    dst.nx = target_width;
+    dst.ny = target_height;
+    dst.buf.resize(3 * target_width * target_height);
+
+    float x_ratio = static_cast<float>(src.nx - 1) / target_width;
+    float y_ratio = static_cast<float>(src.ny - 1) / target_height;
+
+    for (int y = 0; y < target_height; y++) {
+        for (int x = 0; x < target_width; x++) {
+            float px = x_ratio * x;
+            float py = y_ratio * y;
+            int x_floor = static_cast<int>(px);
+            int y_floor = static_cast<int>(py);
+            float x_lerp = px - x_floor;
+            float y_lerp = py - y_floor;
+
+            for (int c = 0; c < 3; c++) {
+                float top = clip_lerp(
+                    static_cast<float>(src.buf[3 * (y_floor * src.nx + x_floor) + c]),
+                    static_cast<float>(src.buf[3 * (y_floor * src.nx + (x_floor + 1)) + c]),
+                    x_lerp
+                );
+                float bottom = clip_lerp(
+                    static_cast<float>(src.buf[3 * ((y_floor + 1) * src.nx + x_floor) + c]),
+                    static_cast<float>(src.buf[3 * ((y_floor + 1) * src.nx + (x_floor + 1)) + c]),
+                    x_lerp
+                );
+                dst.buf[3 * (y * target_width + x) + c] = static_cast<uint8_t>(clip_lerp(top, bottom, y_lerp));
+            }
+        }
+    }
+}
+
+// Normalize image to float32. Careful with PyTorch's .to(model.device, dtype=torch.float16):
+// the 32->16->32 round trip sometimes reduces precision, sometimes not.
+static void normalize_image_u8_to_f32(const clip_image_u8* src, clip_image_f32* dst, const float mean[3], const float std[3]) {
+    dst->nx = src->nx;
+    dst->ny = src->ny;
+    dst->buf.resize(src->buf.size());
+
+    for (size_t i = 0; i < src->buf.size(); ++i) {
+        int c = i % 3; // rgb
+        dst->buf[i] = (static_cast<float>(src->buf[i]) / 255.0f - mean[c]) / std[c];
+    }
+}
+
+inline int clip(int x, int lower, int upper) {
+    return std::max(lower, std::min(x, upper));
+}
+
+static bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_width, int target_height) {
+    const int nx = img.nx;
+    const int ny = img.ny;
+
+    dst.nx = target_width;
+    dst.ny = target_height;
+    dst.buf.resize(3 * target_width * target_height);
+
+    float Cc;
+    float C[5];
+    float d0, d2, d3, a0, a1, a2, a3;
+    int i, j, k, jj;
+    int x, y;
+    float dx, dy;
+    float tx, ty;
+
+    tx = (float)nx / (float)target_width;
+    ty = (float)ny / (float)target_height;
+
+    // Bicubic interpolation; adapted from ViT.cpp, inspired from :
+    //    -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36
+    //    -> https://en.wikipedia.org/wiki/Bicubic_interpolation
+
+    for (i = 0; i < target_height; i++) {
+        for (j = 0; j < target_width; j++) {
+            x = (int)(tx * j);
+            y = (int)(ty * i);
+
+            dx = tx * j - x;
+            dy = ty * i - y;
+
+            for (k = 0; k < 3; k++) {
+                for (jj = 0; jj <= 3; jj++) {
+                    d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+                    d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+                    d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+                    a0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+
+                    a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
+                    a2 =  1.0 / 2 * d0 +      1.0 / 2 * d2;
+                    a3 = -1.0 / 6 * d0 -      1.0 / 2 * d2 + 1.0 / 6 * d3;
+
+                    C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx;
+
+                    d0 = C[0] - C[1];
+                    d2 = C[2] - C[1];
+                    d3 = C[3] - C[1];
+                    a0 = C[1];
+                    a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
+                    a2 =  1.0 / 2 * d0 +      1.0 / 2 * d2;
+                    a3 = -1.0 / 6 * d0 -      1.0 / 2 * d2 + 1.0 / 6 * d3;
+                    Cc = a0 + a1 * dy + a2 * dy * dy + a3 * dy * dy * dy;
+
+                    const uint8_t Cc2 = std::min(std::max(std::round(Cc), 0.0f), 255.0f);
+                    dst.buf[(i * target_width + j) * 3 + k] = float(Cc2);
+                }
+            }
+        }
+    }
+
+    return true;
+}
+
+// llava-1.6 type of resize_and_pad (black)
+static void resize_and_pad_image(const clip_image_u8& image, clip_image_u8 &image_output, const std::pair<int, int>& target_resolution) {
+    int target_width = target_resolution.first;
+    int target_height = target_resolution.second;
+
+    float scale_w = static_cast<float>(target_width) / image.nx;
+    float scale_h = static_cast<float>(target_height) / image.ny;
+
+    int new_width, new_height;
+
+    if (scale_w < scale_h) {
+        new_width = target_width;
+        new_height = std::min(static_cast<int>(std::ceil(image.ny * scale_w)), target_height);
+    } else {
+        new_height = target_height;
+        new_width = std::min(static_cast<int>(std::ceil(image.nx * scale_h)), target_width);
+    }
+
+    clip_image_u8 resized_image;
+    // bilinear_resize(image, resized_image, new_width, new_height);
+    bicubic_resize(image, resized_image, new_width, new_height);
+
+    clip_image_u8 padded_image;
+    padded_image.nx = target_width;
+    padded_image.ny = target_height;
+    padded_image.buf.resize(3 * target_width * target_height, 0); // Initialize with black
+
+    // Calculate padding offsets
+    int pad_x = (target_width - new_width) / 2;
+    int pad_y = (target_height - new_height) / 2;
+
+    // Copy the resized image into the center of the padded buffer
+    for (int y = 0; y < new_height; ++y) {
+        for (int x = 0; x < new_width; ++x) {
+            for (int c = 0; c < 3; ++c) {
+                padded_image.buf[3 * ((y + pad_y) * target_width + (x + pad_x)) + c] = resized_image.buf[3 * (y * new_width + x) + c];
+            }
+        }
+    }
+    image_output = std::move(padded_image);
+}
+
+/**
+ * Selects the best resolution from a list of possible resolutions based on the original size.
+ *
+ * @param original_size The original size of the image in the format (width, height).
+ * @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
+ * @return The best fit resolution in the format (width, height).
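+ *
+ * Example: for a 500x300 image with candidates {(336,672), (672,336), (672,672)},
+ * (672,336) and (672,672) both preserve all 150000 effective source pixels,
+ * but (672,336) wastes fewer (75792 vs 301584), so it is selected.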
+ */
+static std::pair<int, int> select_best_resolution(const std::pair<int, int> & original_size, const std::vector<std::pair<int, int>> & possible_resolutions) {
+    int original_width = original_size.first;
+    int original_height = original_size.second;
+    std::pair<int, int> best_fit;
+    int max_effective_resolution = 0;
+    int min_wasted_resolution = std::numeric_limits<int>::max();
+
+    for (const auto& resolution : possible_resolutions) {
+        int width = resolution.first;
+        int height = resolution.second;
+        float scale = std::min(static_cast<float>(width) / original_width, static_cast<float>(height) / original_height);
+        int downscaled_width = static_cast<int>(original_width * scale);
+        int downscaled_height = static_cast<int>(original_height * scale);
+        int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height);
+        int wasted_resolution = (width * height) - effective_resolution;
+        // LOG_TEE("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution);
+        if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) {
+            max_effective_resolution = effective_resolution;
+            min_wasted_resolution = wasted_resolution;
+            best_fit = resolution;
+        }
+    }
+
+    return best_fit;
+}
+
+static std::vector<clip_image_u8*> divide_to_patches_u8(const clip_image_u8 & image, int patch_size) {
+    std::vector<clip_image_u8*> patches;
+    int width = image.nx;
+    int height = image.ny;
+    for (int i = 0; i < height; i += patch_size) {
+        for (int j = 0; j < width; j += patch_size) {
+            clip_image_u8 *patch = clip_image_u8_init();
+            patch->nx = std::min(patch_size, width - j);
+            patch->ny = std::min(patch_size, height - i);
+            patch->buf.resize(3 * patch->nx * patch->ny);
+            for (int y = 0; y < patch->ny; ++y) {
+                for (int x = 0; x < patch->nx; ++x) {
+                    for (int c = 0; c < 3; ++c) {
+                        patch->buf[3 * (y * patch->nx + x) + c] = image.buf[3 * ((i + y) * width + (j + x)) + c];
+                    }
+                }
+            }
+            patches.push_back(patch);
+        }
+    }
+    return patches;
+}
+
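+// round length to the nearest multiple of patch_size, but never below patch_size
+// (e.g. ensure_divide(450, 14) -> round(32.14) * 14 = 448)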
+static int ensure_divide(int length, int patch_size) {
+    return std::max(static_cast<int>(std::round(static_cast<float>(length) / patch_size) * patch_size), patch_size);
+}
+
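+// shrink an image whose area exceeds scale_resolution^2 to roughly that area
+// while keeping the aspect ratio, then snap both sides to multiples of patch_size
+// (e.g. a 1000x1000 input becomes 448x448, since 448 = 32 * 14 already divides evenly)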
+static std::pair<int, int> uhd_find_best_resize(std::pair<int, int> original_size, int scale_resolution, int patch_size, bool allow_upscale = false) {
+    int width = original_size.first;
+    int height = original_size.second;
+    if ((width * height > scale_resolution * scale_resolution) || allow_upscale) {
+        float r = static_cast<float>(width) / height;
+        height = static_cast<int>(scale_resolution / std::sqrt(r));
+        width = static_cast<int>(height * r);
+    }
+    int best_width = ensure_divide(width, patch_size);
+    int best_height = ensure_divide(height, patch_size);
+    return std::make_pair(best_width, best_height);
+}
+
+static std::pair<int, int> uhd_get_refine_size(std::pair<int, int> original_size, std::pair<int, int> grid, int scale_resolution, int patch_size, bool allow_upscale = false) {
+    int width, height;
+    std::tie(width, height) = original_size;
+    int grid_x, grid_y;
+    std::tie(grid_x, grid_y) = grid;
+
+    int refine_width = ensure_divide(width, grid_x);
+    int refine_height = ensure_divide(height, grid_y);
+
+    int grid_width = refine_width / grid_x;
+    int grid_height = refine_height / grid_y;
+
+    auto best_grid_size = uhd_find_best_resize(std::make_pair(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); // fixed: std::make_pair, was std::make_tuple
+    int best_grid_width, best_grid_height;
+    std::tie(best_grid_width, best_grid_height) = best_grid_size;
+
+  //  std::pair<int, int> refine_size = std::make_tuple(best_grid_width * grid_x, best_grid_height * grid_y); (old line)
+    std::pair<int, int> refine_size = std::make_pair(best_grid_width * grid_x, best_grid_height * grid_y); // (new line)
+    return refine_size;
+}
+
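+// choose the slice grid (columns, rows) whose aspect ratio best matches the image:
+// e.g. with multiple = 4 and a 16:9 image (log_ratio ~= 0.575), grid (3,1)
+// (log 3 ~= 1.099, error ~= 0.523) beats (2,2) (log 1 = 0, error ~= 0.575)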
+static std::pair<int, int> uhd_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) {
+    std::vector<int> candidate_split_grids_nums;
+    for (int i : {multiple - 1, multiple, multiple + 1}) {
+        if (i == 1 || i > max_slice_nums) {
+            continue;
+        }
+        candidate_split_grids_nums.push_back(i);
+    }
+
+    std::vector<std::pair<int, int>> candidate_grids;
+    for (int split_grids_nums : candidate_split_grids_nums) {
+        int m = 1;
+        while (m <= split_grids_nums) {
+            if (split_grids_nums % m == 0) {
+                candidate_grids.emplace_back(m, split_grids_nums / m);
+            }
+            ++m;
+        }
+    }
+
+    std::pair<int, int> best_grid{1, 1};
+    float min_error = std::numeric_limits<float>::infinity();
+    for (const auto& grid : candidate_grids) {
+        float error = std::abs(log_ratio - std::log(1.0 * grid.first / grid.second));
+        if (error < min_error) {
+            best_grid = grid;
+            min_error = error;
+        }
+    }
+    return best_grid;
+}
+
+// inspired from LLaVA-UHD:
+//    -> https://arxiv.org/pdf/2403.11703
+//    -> https://github.com/thunlp/LLaVA-UHD
+//    -> https://github.com/thunlp/LLaVA-UHD/blob/302301bc2175f7e717fb8548516188e89f649753/llava_uhd/train/llava-uhd/slice_logic.py#L118
+static std::vector<std::vector<clip_image_u8 *>> uhd_slice_image(const clip_image_u8 * img, const int max_slice_nums=9, const int scale_resolution=448, const int patch_size=14) {
+    const std::pair<int, int> original_size={img->nx,img->ny};
+    const int original_width = img->nx;
+    const int original_height = img->ny;
+    const float log_ratio = log(1.0*original_width/original_height);
+    const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution);
+    const int multiple = fmin(ceil(ratio), max_slice_nums);
+
+    std::vector<std::vector<clip_image_u8 *>> images;
+    LOG_TEE("%s: multiple %d\n", __func__, multiple);
+    images.push_back(std::vector<clip_image_u8 *>());
+
+    if (multiple <= 1) {
+        auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size, true);
+        clip_image_u8 * source_image = clip_image_u8_init();
+        bicubic_resize(*img, *source_image, best_size.first, best_size.second);
+        // source_image = image.resize(best_size, Image.Resampling.BICUBIC)
+        images[images.size()-1].push_back(source_image);
+    }
+    else if (multiple > 1) {
+        auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size);
+        clip_image_u8 * source_image = clip_image_u8_init();
+        bicubic_resize(*img, *source_image, best_size.first, best_size.second);
+        // source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
+        LOG_TEE("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img->nx, img->ny, best_size.first, best_size.second);
+        images[images.size()-1].push_back(source_image);
+
+        std::pair<int, int> best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio);
+        LOG_TEE("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img->nx, img->ny, best_grid.first, best_grid.second);
+
+        auto refine_size = uhd_get_refine_size(original_size, best_grid, scale_resolution, patch_size, true);
+        clip_image_u8 * refine_image = clip_image_u8_init();
+        bicubic_resize(*img, *refine_image, refine_size.first, refine_size.second);
+
+        LOG_TEE("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image->nx, refine_image->ny, refine_size.first, refine_size.second);
+
+        // split_to_patches
+        int width = refine_image->nx;
+        int height = refine_image->ny;
+        int grid_x = int(width / best_grid.first);
+        int grid_y = int(height / best_grid.second);
+        for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.second; patches_i += grid_y, ic += 1) {
+            images.push_back(std::vector<clip_image_u8 *>());
+            for (int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.first; patches_j += grid_x, jc += 1) {
+                clip_image_u8 * patch = clip_image_u8_init();
+                patch->nx = grid_x;
+                patch->ny = grid_y;
+                patch->buf.resize(3 * patch->nx * patch->ny);
+                for (int y = patches_i; y < patches_i + grid_y; ++y) {
+                    for (int x = patches_j; x < patches_j + grid_x; ++x) {
+                        const int i = 3 * (y * refine_image->nx + x);
+                        const int j = 3 * ((y-patches_i) * patch->nx + (x-patches_j));
+                        patch->buf[j]   = refine_image->buf[i];
+                        patch->buf[j+1] = refine_image->buf[i+1];
+                        patch->buf[j+2] = refine_image->buf[i+2];
+                    }
+                }
+                images[images.size()-1].push_back(patch);
+            }
+        }
+    }
+    return images;
+}
+
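+// number of grid columns (best_grid.first) the minicpmv slicer will produce
+// for the currently loaded image size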
+int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
+    const int max_slice_nums = 9;
+    const int scale_resolution = 448;
+    const int original_width = ctx_clip->load_image_size->width;
+    const int original_height = ctx_clip->load_image_size->height;
+    const float log_ratio = log(1.0 * original_width / original_height);
+    const float ratio = 1.0 * original_width * original_height / (scale_resolution * scale_resolution);
+    const int multiple = fmin(ceil(ratio), max_slice_nums);
+    std::pair<int, int> best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio);
+    return best_grid.first;
+}
+
+// Returns the normalized float tensor for llava-1.5; for llava-1.6
+// (spatial_unpad with anyres processing) it returns the normalized image patch tensors as a batch.
+// res_imgs memory is allocated here; any previous allocation is freed first.
+bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch * res_imgs) {
+    if (clip_is_minicpmv(ctx)) {
+        int max_slice_nums = 9;
+        std::vector<std::vector<clip_image_u8 *>> imgs = uhd_slice_image(img, max_slice_nums);
+        res_imgs->size = 0;
+        for (size_t i = 0; i < imgs.size(); ++i){
+            res_imgs->size += imgs[i].size();
+        }
+        res_imgs->data = new clip_image_f32[res_imgs->size];
+        int idx = 0;
+        for (size_t i = 0; i < imgs.size(); ++i) {
+            for (size_t j = 0; j < imgs[i].size(); ++j) {
+                LOG_TEE("%s: %d %d\n", __func__,imgs[i][j]->nx,imgs[i][j]->ny);
+                clip_image_f32 * res = clip_image_f32_init();
+                normalize_image_u8_to_f32(imgs[i][j], res, ctx->image_mean, ctx->image_std);
+                res_imgs->data[idx++] = *res;
+                clip_image_f32_free(res);
+            }
+        }
+        // free the u8 slices now that they have been normalized into res_imgs
+        for (size_t i = 0; i < imgs.size(); ++i) {
+            for (size_t j = 0; j < imgs[i].size(); ++j) {
+                clip_image_u8_free(imgs[i][j]);
+            }
+        }
+        return true;
+    }
+
+    bool pad_to_square = true;
+    if (!ctx->has_vision_encoder) {
+        LOG_TEE("This gguf file seems to have no vision encoder\n");
+        return false;
+    }
+    auto & params = ctx->vision_model.hparams;
+    // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing
+    if (strcmp(params.mm_patch_merge_type, "spatial_unpad") == 0) {
+        pad_to_square = false;
+    }
+    // free the previous res_imgs if any set
+    if (res_imgs->size > 0) {
+        clip_image_f32_batch_free(res_imgs);
+    }
+    res_imgs->data = nullptr;
+    res_imgs->size = 0;
+
+    // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
+    // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
+
+    clip_image_u8 * temp = clip_image_u8_init(); // we will keep the input image data here temporarily
+    if (pad_to_square && img->nx != img->ny) {
+        int longer_side = std::max(img->nx, img->ny);
+        temp->nx = longer_side;
+        temp->ny = longer_side;
+        temp->buf.resize(3 * longer_side * longer_side);
+        const uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA (this is the mean rgb color * 255)
+
+        // fill with background color
+        for (size_t i = 0; i < temp->buf.size(); i++) {
+            temp->buf[i] = bc[i % 3];
+        }
+
+        // copy from the input image
+        for (int y = 0; y < img->ny; y++) {
+            for (int x = 0; x < img->nx; x++) {
+                const int i = 3 * (y * img->nx + x);
+                const int j = 3 * (y * temp->nx + x);
+                temp->buf[j]   = img->buf[i];
+                temp->buf[j+1] = img->buf[i+1];
+                temp->buf[j+2] = img->buf[i+2];
+            }
+        }
+    } else {
+        if (params.image_grid_pinpoints[0] != 0) {
+            // "spatial_unpad" with "anyres" processing for llava-1.6
+            std::vector<std::pair<int, int>> possible_resolutions;
+            for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) {
+                possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]});
+            }
+            std::pair<int, int> best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions);
+            // clip_image_save_to_bmp(*img, "input.bmp");
+            resize_and_pad_image(*img, *temp, best_resolution);  // we do not pad with mean-bg color anymore in llava-1.6
+            // clip_image_save_to_bmp(*temp, "resized.bmp");
+            // visually verify normalized image:
+            // normalize_image_u8_to_f32(*temp, *res, ctx->image_mean, ctx->image_std);
+            // {
+            //     clip_image_u8 * temp2 = clip_image_u8_init();
+            //     clip_image_convert_f32_to_u8(*res, *temp2);
+            //     clip_image_save_to_bmp(*temp2, "resized_normalized_f32.bmp");
+            //     clip_image_u8_free(temp2);
+            // }
+
+            std::vector<clip_image_u8 *> patches = divide_to_patches_u8(*temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6)
+
+            clip_image_u8 *image_original_resize = clip_image_u8_init();
+            // bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
+            bicubic_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
+            patches.insert(patches.begin(), image_original_resize);
+            // clip_image_f32_batch_init(patches.size());
+            res_imgs->size = patches.size();
+            res_imgs->data = new clip_image_f32[res_imgs->size];
+            int num = 0;
+            for (auto& patch : patches) {
+                normalize_image_u8_to_f32(patch, &res_imgs->data[num], ctx->image_mean, ctx->image_std);
+                num++;
+            }
+
+            for (size_t i = 0; i < patches.size(); i++) {
+                // LOG_TEE("patch %d: %d %d\n", i, patches[i]->nx, patches[i]->ny);
+                clip_image_u8_free(patches[i]);
+            }
+
+            clip_image_u8_free(temp);
+
+            return true;
+        } else {
+            temp->nx = img->nx;
+            temp->ny = img->ny;
+            temp->buf.resize(img->buf.size());
+            memcpy(temp->buf.data(), img->buf.data(), temp->buf.size());
+        }
+    }
+
+    const int nx = temp->nx;
+    const int ny = temp->ny;
+    // clip_image_save_to_bmp(*temp, "resized_vanilla.bmp");
+
+    const int nx2 = ctx->vision_model.hparams.image_size;
+    const int ny2 = ctx->vision_model.hparams.image_size;
+    clip_image_f32 * res = clip_image_f32_init();
+    res->nx = nx2;
+    res->ny = ny2;
+    res->buf.resize(3 * nx2 * ny2);
+
+    const float scale = std::max(nx, ny) / (float)ctx->vision_model.hparams.image_size;
+
+    const int nx3 = int(nx / scale + 0.5f);
+    const int ny3 = int(ny / scale + 0.5f);
+
+    const auto & m3 = ctx->image_mean; // {0.48145466f, 0.4578275f, 0.40821073f};
+    const auto & s3 = ctx->image_std;  // {0.26862954f, 0.26130258f, 0.27577711f};
+
+    for (int y = 0; y < ny3; y++) {
+        for (int x = 0; x < nx3; x++) {
+            for (int c = 0; c < 3; c++) {
+                // linear interpolation
+                const float sx = (x + 0.5f) * scale - 0.5f;
+                const float sy = (y + 0.5f) * scale - 0.5f;
+
+                const int x0 = std::max(0, (int)std::floor(sx));
+                const int y0 = std::max(0, (int)std::floor(sy));
+
+                const int x1 = std::min(x0 + 1, nx - 1);
+                const int y1 = std::min(y0 + 1, ny - 1);
+
+                const float dx = sx - x0;
+                const float dy = sy - y0;
+
+                const int j00 = 3 * (y0 * nx + x0) + c;
+                const int j01 = 3 * (y0 * nx + x1) + c;
+                const int j10 = 3 * (y1 * nx + x0) + c;
+                const int j11 = 3 * (y1 * nx + x1) + c;
+
+                const float v00 = temp->buf[j00];
+                const float v01 = temp->buf[j01];
+                const float v10 = temp->buf[j10];
+                const float v11 = temp->buf[j11];
+
+                const float v0 = v00 * (1.0f - dx) + v01 * dx;
+                const float v1 = v10 * (1.0f - dx) + v11 * dx;
+
+                const float v = v0 * (1.0f - dy) + v1 * dy;
+
+                const uint8_t v2 = std::min(std::max(std::round(v), 0.0f), 255.0f);
+
+                const int i = 3 * (y * nx3 + x) + c;
+
+                res->buf[i] = ((float(v2) / 255.0f) - m3[c]) / s3[c];
+            }
+        }
+    }
+    clip_image_u8_free(temp);
+
+    // {
+    //     clip_image_u8 * temp2 = clip_image_u8_init();
+    //     clip_image_convert_f32_to_u8(*res, *temp2);
+    //     clip_image_save_to_bmp(*temp2, "resized_normalized_f32_vanilla.bmp");
+    //     clip_image_u8_free(temp2);
+    // }
+    // res_imgs.push_back(res);
+
+    res_imgs->size = 1;
+    res_imgs->data = new clip_image_f32[res_imgs->size];
+    res_imgs->data[0] = *res;
+    clip_image_f32_free(res);
+
+    return true;
+}
+
+ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) {
+    return ctx->vision_model.image_newline;
+}
+
+void clip_free(clip_ctx * ctx) {
+    ggml_free(ctx->ctx_data);
+    gguf_free(ctx->ctx_gguf);
+
+    ggml_backend_buffer_free(ctx->params_buffer);
+    ggml_backend_free(ctx->backend);
+    ggml_gallocr_free(ctx->compute_alloc);
+    delete ctx;
+}
+
+size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
+    return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
+}
+
+int32_t clip_image_size(const struct clip_ctx * ctx) {
+    return ctx->vision_model.hparams.image_size;
+}
+
+int32_t clip_patch_size(const struct clip_ctx * ctx) {
+    return ctx->vision_model.hparams.patch_size;
+}
+
+int32_t clip_hidden_size(const struct clip_ctx * ctx) {
+    return ctx->vision_model.hparams.hidden_size;
+}
+
+const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
+    return ctx->vision_model.hparams.mm_patch_merge_type;
+}
+
+const int32_t * clip_image_grid(const struct clip_ctx * ctx) {
+    return ctx->vision_model.hparams.image_grid_pinpoints;
+}
+
+int clip_n_patches(const struct clip_ctx * ctx) {
+    const auto & params = ctx->vision_model.hparams;
+
+    int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
+
+    if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2) {
+        n_patches /= 4;
+    } else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
+        if (ctx->minicpmv_version == 2) {
+            n_patches = 96;
+        }
+        else if (ctx->minicpmv_version == 3) {
+            n_patches = 64;
+        }
+    }
+
+    return n_patches;
+}
+
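+// 1D sin/cos positional embedding: omega_i = 1 / 10000^(i / (D/2)),
+// emb[h][w][i] = sin(pos * omega_i), emb[h][w][i + D/2] = cos(pos * omega_i)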
+static std::vector<std::vector<std::vector<float>>> get_1d_sincos_pos_embed_from_grid_new(int embed_dim, const std::vector<std::vector<float>> & pos) {
+    assert(embed_dim % 2 == 0);
+    int H = pos.size();
+    int W = pos[0].size();
+
+    std::vector<float> omega(embed_dim / 2);
+    for (int i = 0; i < embed_dim / 2; ++i) {
+        omega[i] = 1.0 / pow(10000.0, static_cast<float>(i) / (embed_dim / 2));
+    }
+
+    std::vector<std::vector<std::vector<float>>> emb(H, std::vector<std::vector<float>>(W, std::vector<float>(embed_dim)));
+    for (int h = 0; h < H; ++h) {
+        for (int w = 0; w < W; ++w) {
+            for (int d = 0; d < embed_dim / 2; ++d) {
+                float out_value = pos[h][w] * omega[d];
+                emb[h][w][d] = sin(out_value);
+                emb[h][w][d + embed_dim / 2] = cos(out_value);
+            }
+        }
+    }
+
+    return emb;
+}
+
+static std::vector<std::vector<std::vector<float>>> get_2d_sincos_pos_embed_from_grid(int embed_dim, const std::vector<std::vector<std::vector<float>>> & grid) {
+    assert(embed_dim % 2 == 0);
+    std::vector<std::vector<std::vector<float>>> emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[0]); // (H, W, D/2)
+    std::vector<std::vector<std::vector<float>>> emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[1]); // (H, W, D/2)
+
+    int H = emb_h.size();
+    int W = emb_h[0].size();
+    std::vector<std::vector<std::vector<float>>> emb(H, std::vector<std::vector<float>>(W, std::vector<float>(embed_dim)));
+
+    for (int h = 0; h < H; ++h) {
+        for (int w = 0; w < W; ++w) {
+            for (int d = 0; d < embed_dim / 2; ++d) {
+                emb[h][w][d] = emb_h[h][w][d];
+                emb[h][w][d + embed_dim / 2] = emb_w[h][w][d];
+            }
+        }
+    }
+    return emb;
+}
+
+static std::vector<std::vector<float>> get_2d_sincos_pos_embed(int embed_dim, const std::pair<int, int> image_size) {
+    int grid_h_size = image_size.first;
+    int grid_w_size = image_size.second;
+
+    std::vector<float> grid_h(grid_h_size);
+    std::vector<float> grid_w(grid_w_size);
+
+    for (int i = 0; i < grid_h_size; ++i) {
+        grid_h[i] = static_cast<float>(i);
+    }
+    for (int i = 0; i < grid_w_size; ++i) {
+        grid_w[i] = static_cast<float>(i);
+    }
+
+    std::vector<std::vector<float>> grid(grid_h_size, std::vector<float>(grid_w_size));
+    for (int h = 0; h < grid_h_size; ++h) {
+        for (int w = 0; w < grid_w_size; ++w) {
+            grid[h][w] = grid_w[w];
+        }
+    }
+    std::vector<std::vector<std::vector<float>>> grid_2d = {grid, grid};
+    for (int h = 0; h < grid_h_size; ++h) {
+        for (int w = 0; w < grid_w_size; ++w) {
+            grid_2d[0][h][w] = grid_h[h];
+            grid_2d[1][h][w] = grid_w[w];
+        }
+    }
+
+    std::vector<std::vector<std::vector<float>>> pos_embed_3d = get_2d_sincos_pos_embed_from_grid(embed_dim, grid_2d);
+
+    int H = image_size.first;
+    int W = image_size.second;
+    std::vector<std::vector<float>> pos_embed_2d(H * W, std::vector<float>(embed_dim));
+    for (int h = 0; h < H; ++h) {
+        for (int w = 0; w < W; ++w) {
+            pos_embed_2d[w * H + h] = pos_embed_3d[h][w];
+        }
+    }
+
+    return pos_embed_2d;
+}
+
+bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) {
+    if (!ctx->has_vision_encoder) {
+        LOG_TEE("This gguf file seems to have no vision encoder\n");
+        return false;
+    }
+
+    clip_image_f32_batch imgs{};
+    imgs.size = 1;
+    imgs.data = img;
+    return clip_image_batch_encode(ctx, n_threads, &imgs, vec);
+}
+
+bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs, float * vec) {
+    if (!ctx->has_vision_encoder) {
+        LOG_TEE("This gguf file seems to have no vision encoder\n");
+        return false;
+    }
+
+    int batch_size = imgs->size;
+    if (ctx->has_llava_projector) {
+        GGML_ASSERT(batch_size == 1); // TODO: support multiple images
+    }
+    if (ctx->has_minicpmv_projector) {
+        GGML_ASSERT(batch_size == 1);
+    }
+
+    // build the inference graph
+    ggml_cgraph * gf = clip_image_build_graph(ctx, imgs, ctx->load_image_size, true);
+    ggml_gallocr_alloc_graph(ctx->compute_alloc, gf);
+
+    // set inputs
+    const auto & model = ctx->vision_model;
+    const auto & hparams = model.hparams;
+
+    const int image_size = hparams.image_size;
+    int image_size_width  = image_size;
+    int image_size_height = image_size;
+    if (ctx->has_minicpmv_projector) {
+        image_size_width  = imgs->data[0].nx;
+        image_size_height = imgs->data[0].ny;
+    }
+    const int patch_size    = hparams.patch_size;
+    const int num_patches   = ((image_size_width / patch_size) * (image_size_height / patch_size));
+    const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0);
+    if(ctx->load_image_size==nullptr){
+        ctx->load_image_size= clip_image_size_init();
+    }
+    const int pos_w = ctx->load_image_size->width/patch_size;
+    const int pos_h = ctx->load_image_size->height/patch_size;
+
+    {
+        struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
+        float * data = (float *)malloc(ggml_nbytes(inp_raw));
+
+        for (size_t i = 0; i < imgs->size; i++) {
+            const int nx = imgs->data[i].nx;
+            const int ny = imgs->data[i].ny;
+            if (!ctx->has_minicpmv_projector) {
+                GGML_ASSERT(nx == image_size && ny == image_size);
+            }
+
+            const int n = nx * ny;
+
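+            // convert interleaved RGB (HWC) into the planar CHW layout the
+            // graph input expects, one plane of n pixels per channel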
+            for (int b = 0; b < batch_size; b++) {
+                for (int k = 0; k < 3; k++) {
+                    for (int y = 0; y < ny; y++) {
+                        for (int x = 0; x < nx; x++) {
+                            data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].buf[3 * (y * nx + x) + k];
+                        }
+                    }
+                }
+            }
+        }
+        ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
+        free(data);
+    }
+    if (ctx->has_minicpmv_projector) {
+        {
+            // inspired from siglip:
+            //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
+            //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
+            struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
+            int* positions_data = (int*)malloc(ggml_nbytes(positions));
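+            // bucket patch coordinates into a fixed 70x70 position grid
+            // (assumes pos_h and pos_w are at most 70)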
+            int bucket_coords_h[70];
+            int bucket_coords_w[70];
+            for (int i = 0; i < pos_h; i++){
+                bucket_coords_h[i] = std::floor(70.0*i/pos_h);
+            }
+            for (int i = 0; i < pos_w; i++){
+                bucket_coords_w[i] = std::floor(70.0*i/pos_w);
+            }
+            for (int i = 0, id = 0; i < pos_h; i++){
+                for (int j = 0; j < pos_w; j++){
+                    positions_data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
+                }
+            }
+            ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
+            free(positions_data);
+        }
+
+        {
+            // inspired from resampler of Qwen-VL:
+            //    -> https://huggingface.co/Qwen/Qwen-VL/tree/main
+            //    -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23
+            struct ggml_tensor * pos_embed = ggml_graph_get_tensor(gf, "pos_embed");
+            int embed_dim = 4096;
+            if (ctx->minicpmv_version == 2) {
+                embed_dim = 4096;
+            }
+            else if (ctx->minicpmv_version == 3) {
+                embed_dim = 3584;
+            }
+            auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
+
+            float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed));
+            for(int i=0;i<pos_w * pos_h;++i){
+                for(int j=0;j<embed_dim;++j){
+                    pos_embed_data[i*embed_dim+j]=pos_embed_t[i][j];
+                }
+            }
+
+            ggml_backend_tensor_set(pos_embed, pos_embed_data, 0, ggml_nbytes(pos_embed));
+            free(pos_embed_data);
+        }
+    }
+    else {
+        {
+            if (ctx->has_class_embedding) {
+                struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings");
+
+                void* zero_mem = malloc(ggml_nbytes(embeddings));
+                memset(zero_mem, 0, ggml_nbytes(embeddings));
+                ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
+                free(zero_mem);
+            }
+        }
+
+        {
+            struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
+
+            int* positions_data = (int*)malloc(ggml_nbytes(positions));
+            for (int i = 0; i < num_positions; i++) {
+                positions_data[i] = i;
+            }
+            ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
+            free(positions_data);
+        }
+
+        {
+            struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
+            int* patches_data = (int*)malloc(ggml_nbytes(patches));
+            for (int i = 0; i < num_patches; i++) {
+                patches_data[i] = i + 1;
+            }
+            ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
+            free(patches_data);
+        }
+    }
+
+    if (ggml_backend_is_cpu(ctx->backend)) {
+        ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
+    }
+
+#ifdef GGML_USE_METAL
+    if (ggml_backend_is_metal(ctx->backend)) {
+        ggml_backend_metal_set_n_cb(ctx->backend, n_threads);
+    }
+#endif
+
+    ggml_backend_graph_compute(ctx->backend, gf);
+
+    // the last node is the embedding tensor
+    struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 1];
+
+    // copy the embeddings to the location passed by the user
+    ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
+
+    return true;
+}
+
+bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype) {
+    assert(itype < GGML_TYPE_COUNT);
+    const ggml_type type = static_cast<ggml_type>(itype);
+
+    auto * ctx_clip = clip_model_load(fname_inp, 2);
+    if (!ctx_clip) {
+        LOG_TEE("%s: failed to load model from %s\n", __func__, fname_inp);
+        return false;
+    }
+
+    const auto & ctx_src = ctx_clip->ctx_gguf;
+    const auto & ctx_data = ctx_clip->ctx_data;
+
+    auto * ctx_out = gguf_init_empty();
+    gguf_set_kv(ctx_out, ctx_src);
+    gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION);
+    gguf_set_val_u32(ctx_out, "general.file_type", itype);
+
+    auto fout = std::ofstream(fname_out, std::ios::binary);
+    if (!fout) {
+        LOG_TEE("%s: failed to open %s for writing\n", __func__, fname_out);
+        clip_free(ctx_clip);
+        gguf_free(ctx_out);
+        return false;
+    }
+
+    const int n_tensors = gguf_get_n_tensors(ctx_src);
+
+    for (int i = 0; i < n_tensors; ++i) {
+        const char * name = gguf_get_tensor_name(ctx_src, i);
+        struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name);
+        gguf_add_tensor(ctx_out, cur);
+    }
+
+    const size_t meta_size = gguf_get_meta_size(ctx_out);
+    for (size_t i = 0; i < meta_size; ++i) {
+        fout.put(0);
+    }
+
+    // regexes of tensor names to be quantized
+    const std::vector<std::string> k_names = {
+        ".*weight",
+    };
+
+    std::vector<uint8_t> work(512);
+    std::vector<float> conv_buf(512);
+    size_t total_size_org = 0;
+    size_t total_size_new = 0;
+
+    for (int i = 0; i < n_tensors; ++i) {
+        const std::string name = gguf_get_tensor_name(ctx_src, i);
+        struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name.c_str());
+
+        enum ggml_type new_type;
+        void * new_data;
+        size_t new_size;
+
+        bool quantize = false;
+        for (const auto & s : k_names) {
+            if (std::regex_match(name, std::regex(s))) {
+                quantize = true;
+                break;
+            }
+        }
+
+        // quantize only 2D tensors
+        quantize &= (ggml_n_dims(cur) == 2);
+
+        if (quantize) {
+            new_type = type;
+            if (new_type >= GGML_TYPE_Q2_K && name.find("embd") != std::string::npos) {
+                new_type = GGML_TYPE_Q8_0; // ggml_get_rows needs non K type
+                // LOG_TEE("%s: quantizing %s to %s\n", __func__, name.c_str(), ggml_type_name(new_type));
+            }
+            const size_t n_elms = ggml_nelements(cur);
+            float * f32_data;
+
+            switch (cur->type) {
+            case GGML_TYPE_F32:
+                f32_data = (float *)cur->data;
+                break;
+            case GGML_TYPE_F16:
+                if (conv_buf.size() < n_elms) {
+                    conv_buf.resize(n_elms);
+                }
+                for (size_t j = 0; j < n_elms; ++j) {
+                    conv_buf[j] = ggml_fp16_to_fp32(((ggml_fp16_t *)cur->data)[j]);
+                }
+                f32_data = (float *)conv_buf.data();
+                break;
+            default:
+                LOG_TEE("Please use an input file in f32 or f16\n");
+                gguf_free(ctx_out);
+                return false;
+            }
+
+            if (work.size() < n_elms * 4) {
+                work.resize(n_elms * 4);
+            }
+            new_data = work.data();
+
+            new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, n_elms/cur->ne[0], cur->ne[0], nullptr);
+        } else {
+            new_type = cur->type;
+            new_data = cur->data;
+            new_size = ggml_nbytes(cur);
+        }
+        const size_t orig_size = ggml_nbytes(cur);
+        total_size_org += orig_size;
+        total_size_new += new_size;
+        gguf_set_tensor_type(ctx_out, name.c_str(), new_type);
+        gguf_set_tensor_data(ctx_out, name.c_str(), new_data, new_size);
+        fout.write((const char *)new_data, new_size);
+        size_t pad = GGML_PAD(new_size, gguf_get_alignment(ctx_out)) - new_size;
+        for (size_t j = 0; j < pad; ++j) {
+            fout.put(0);
+        }
+
+        LOG_TEE("%s: n_dims = %d | quantize=%d | size = %f MB -> %f MB\n", name.c_str(), ggml_n_dims(cur), quantize,
+               orig_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
+    }
+
+    // go back to beginning of file and write the updated metadata
+    fout.seekp(0, std::ios::beg);
+    std::vector<uint8_t> meta(meta_size);
+    gguf_get_meta_data(ctx_out, meta.data());
+    fout.write((const char *)meta.data(), meta_size);
+
+    fout.close();
+
+    clip_free(ctx_clip);
+    gguf_free(ctx_out);
+
+    {
+        LOG_TEE("%s: original  size = %8.2f MB\n", __func__, total_size_org / 1024.0 / 1024.0);
+        LOG_TEE("%s: quantized size = %8.2f MB\n", __func__, total_size_new / 1024.0 / 1024.0);
+    }
+
+    return true;
+}
+
+int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
+    if (ctx->proj_type == PROJECTOR_TYPE_LDP) {
+        return ctx->vision_model.mm_model_block_1_block_2_1_b->ne[0];
+    }
+    if (ctx->proj_type == PROJECTOR_TYPE_LDPV2) {
+        return ctx->vision_model.mm_model_peg_0_b->ne[0];
+    }
+    if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
+        return ctx->vision_model.mm_2_b->ne[0];
+    }
+    if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
+        return ctx->vision_model.mm_3_b->ne[0];
+    }
+    if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
+        if (ctx->minicpmv_version == 2) {
+            return 4096;
+        }
+        else if (ctx->minicpmv_version == 3) {
+            return 3584;
+        }
+    }
+
+    std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
+    throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
+}
+
+int clip_is_minicpmv(const struct clip_ctx * ctx) {
+    if (ctx->has_minicpmv_projector) {
+        return ctx->minicpmv_version;
+    }
+    return 0;
+}
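
A minimal sketch of driving clip_model_quantize above from a standalone tool (hypothetical main; assumes the itype argument follows the ggml_type enum, e.g. 2 for GGML_TYPE_Q4_0):

    #include <cstdio>
    #include "clip.h"

    int main(int argc, char ** argv) {
        if (argc != 3) {
            fprintf(stderr, "usage: %s <f16-or-f32-mmproj.gguf> <out.gguf>\n", argv[0]);
            return 1;
        }
        // quantize all matching 2D weight tensors to q4_0
        if (!clip_model_quantize(argv[1], argv[2], /*itype=*/2)) {
            fprintf(stderr, "quantization failed\n");
            return 1;
        }
        return 0;
    }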

+ 120 - 0
llama/clip.h

@@ -0,0 +1,120 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef CLIP_H
+#define CLIP_H
+
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef LLAMA_SHARED
+#    if defined(_WIN32) && !defined(__MINGW32__)
+#        ifdef LLAMA_BUILD
+#            define CLIP_API __declspec(dllexport)
+#        else
+#            define CLIP_API __declspec(dllimport)
+#        endif
+#    else
+#        define CLIP_API __attribute__ ((visibility ("default")))
+#    endif
+#else
+#    define CLIP_API
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct clip_ctx;
+
+struct clip_image_size {
+    int width;
+    int height;
+};
+
+struct clip_image_u8_batch {
+    struct clip_image_u8 * data;
+    size_t size;
+};
+
+struct clip_image_f32_batch {
+    struct clip_image_f32 * data;
+    size_t size;
+};
+
+CLIP_API struct clip_ctx * clip_model_load    (const char * fname, int verbosity);
+CLIP_API struct clip_ctx * clip_model_load_cpu(const char * fname, int verbosity);
+
+CLIP_API void clip_free(struct clip_ctx * ctx);
+
+CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx);
+
+CLIP_API int32_t clip_image_size (const struct clip_ctx * ctx);
+CLIP_API int32_t clip_patch_size (const struct clip_ctx * ctx);
+CLIP_API int32_t clip_hidden_size(const struct clip_ctx * ctx);
+
+// TODO: should be enum, not string
+CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
+
+CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
+
+CLIP_API int clip_n_patches    (const struct clip_ctx * ctx);
+CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx);
+
+CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
+CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
+
+CLIP_API struct clip_image_size * clip_image_size_init();
+CLIP_API struct clip_image_u8  * clip_image_u8_init ();
+CLIP_API struct clip_image_f32 * clip_image_f32_init();
+
+CLIP_API void clip_image_u8_free (struct clip_image_u8  * img);
+CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
+CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch  * batch);
+CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
+
+CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
+
+/** interpret bytes as an image file with length bytes_length, and use the result to populate img */
+CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img);
+
+/** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */
+CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs);
+
+CLIP_API struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx);
+
+CLIP_API bool clip_image_encode      (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec);
+CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec);
+
+CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype);
+
+CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // CLIP_H
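
A minimal end-to-end sketch of the API above (hypothetical file names, error handling trimmed): load a projector model, preprocess an image, and encode the first preprocessed slice into a buffer sized with clip_embd_nbytes.

    #include <cstdio>
    #include <cstdlib>
    #include "clip.h"

    int main(void) {
        struct clip_ctx * ctx = clip_model_load("mmproj.gguf", /*verbosity=*/1);
        if (!ctx) return 1;

        struct clip_image_u8 * img = clip_image_u8_init();
        if (!clip_image_load_from_file("input.jpg", img)) return 1;

        // preprocess allocates batch.data; the batch must be freed afterwards
        struct clip_image_f32_batch batch = {/*data=*/nullptr, /*size=*/0};
        if (!clip_image_preprocess(ctx, img, &batch)) return 1;

        float * embd = (float *) malloc(clip_embd_nbytes(ctx));
        clip_image_encode(ctx, /*n_threads=*/4, batch.data, embd);

        free(embd);
        clip_image_u8_free(img);
        clip_image_f32_batch_free(&batch);
        clip_free(ctx);
        return 0;
    }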

+ 3688 - 0
llama/common.cpp

@@ -0,0 +1,3688 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#if defined(_MSC_VER)
+#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
+#endif
+
+#include "common.h"
+// Change JSON_ASSERT from assert() to GGML_ASSERT:
+#define JSON_ASSERT GGML_ASSERT
+#include "json.hpp"
+#include "json-schema-to-grammar.h"
+#include "llama.h"
+
+#include <algorithm>
+#include <cinttypes>
+#include <cmath>
+#include <codecvt>
+#include <cstdarg>
+#include <cstring>
+#include <ctime>
+#include <fstream>
+#include <iostream>
+#include <iterator>
+#include <regex>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#if defined(__APPLE__) && defined(__MACH__)
+#include <sys/types.h>
+#include <sys/sysctl.h>
+#endif
+
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+#   define NOMINMAX
+#endif
+#include <locale>
+#include <windows.h>
+#include <fcntl.h>
+#include <io.h>
+#else
+#include <sys/ioctl.h>
+#include <sys/stat.h>
+#include <unistd.h>
+#endif
+#if defined(LLAMA_USE_CURL)
+#include <curl/curl.h>
+#include <curl/easy.h>
+#include <thread>
+#include <future>
+#endif
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+#if (defined(GGML_USE_CUDA) || defined(GGML_USE_SYCL))
+#define GGML_USE_CUDA_SYCL
+#endif
+
+#if (defined(GGML_USE_CUDA) || defined(GGML_USE_SYCL)) || defined(GGML_USE_VULKAN)
+#define GGML_USE_CUDA_SYCL_VULKAN
+#endif
+
+#if defined(LLAMA_USE_CURL)
+#ifdef __linux__
+#include <linux/limits.h>
+#elif defined(_WIN32)
+#define PATH_MAX MAX_PATH
+#else
+#include <sys/syslimits.h>
+#endif
+#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
+#endif // LLAMA_USE_CURL
+
+using json = nlohmann::ordered_json;
+
+//
+// Environment variable utils
+//
+
+template<typename T>
+static typename std::enable_if<std::is_same<T, std::string>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    target = value ? std::string(value) : target;
+}
+
+template<typename T>
+static typename std::enable_if<!std::is_same<T, bool>::value && std::is_integral<T>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    target = value ? std::stoi(value) : target;
+}
+
+template<typename T>
+static typename std::enable_if<std::is_floating_point<T>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    target = value ? std::stof(value) : target;
+}
+
+template<typename T>
+static typename std::enable_if<std::is_same<T, bool>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    if (value) {
+        std::string val(value);
+        target = val == "1" || val == "true";
+    }
+}
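+
+// the overloads above dispatch on the type of `target`, e.g. (illustrative):
+//   std::string model;  get_env("LLAMA_ARG_MODEL", model);     // copied verbatim
+//   int32_t n_ctx = 0;  get_env("LLAMA_ARG_CTX_SIZE", n_ctx);  // parsed via stoi
+//   bool embd = false;  get_env("LLAMA_ARG_EMBEDDINGS", embd); // "1"/"true" -> true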
+
+//
+// CPU utils
+//
+
+int32_t cpu_get_num_physical_cores() {
+#ifdef __linux__
+    // enumerate the set of thread siblings; the number of distinct entries is the number of physical cores
+    std::unordered_set<std::string> siblings;
+    for (uint32_t cpu=0; cpu < UINT32_MAX; ++cpu) {
+        std::ifstream thread_siblings("/sys/devices/system/cpu/cpu"
+            + std::to_string(cpu) + "/topology/thread_siblings");
+        if (!thread_siblings.is_open()) {
+            break; // no more cpus
+        }
+        std::string line;
+        if (std::getline(thread_siblings, line)) {
+            siblings.insert(line);
+        }
+    }
+    if (!siblings.empty()) {
+        return static_cast<int32_t>(siblings.size());
+    }
+#elif defined(__APPLE__) && defined(__MACH__)
+    int32_t num_physical_cores;
+    size_t len = sizeof(num_physical_cores);
+    int result = sysctlbyname("hw.perflevel0.physicalcpu", &num_physical_cores, &len, NULL, 0);
+    if (result == 0) {
+        return num_physical_cores;
+    }
+    result = sysctlbyname("hw.physicalcpu", &num_physical_cores, &len, NULL, 0);
+    if (result == 0) {
+        return num_physical_cores;
+    }
+#elif defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later
+    // TODO: windows + arm64 + mingw64
+    unsigned int n_threads_win = std::thread::hardware_concurrency();
+    unsigned int default_threads = n_threads_win > 0 ? (n_threads_win <= 4 ? n_threads_win : n_threads_win / 2) : 4;
+
+    DWORD buffer_size = 0;
+    if (!GetLogicalProcessorInformationEx(RelationProcessorCore, nullptr, &buffer_size)) {
+        if (GetLastError() != ERROR_INSUFFICIENT_BUFFER) {
+            return default_threads;
+        }
+    }
+
+    std::vector<char> buffer(buffer_size);
+    if (!GetLogicalProcessorInformationEx(RelationProcessorCore, reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(buffer.data()), &buffer_size)) {
+        return default_threads;
+    }
+
+    int32_t num_physical_cores = 0;
+    PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX info = reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(buffer.data());
+    while (buffer_size > 0) {
+        if (info->Relationship == RelationProcessorCore) {
+            num_physical_cores += info->Processor.GroupCount;
+        }
+        buffer_size -= info->Size;
+        info = reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(reinterpret_cast<char*>(info) + info->Size);
+    }
+
+    return num_physical_cores > 0 ? num_physical_cores : default_threads;
+#endif
+    unsigned int n_threads = std::thread::hardware_concurrency();
+    return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
+}
+
+#if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__)
+#include <pthread.h>
+
+static void cpuid(unsigned leaf, unsigned subleaf,
+                  unsigned *eax, unsigned *ebx, unsigned *ecx, unsigned *edx) {
+    __asm__("movq\t%%rbx,%%rsi\n\t"
+            "cpuid\n\t"
+            "xchgq\t%%rbx,%%rsi"
+            : "=a"(*eax), "=S"(*ebx), "=c"(*ecx), "=d"(*edx)
+            : "0"(leaf), "2"(subleaf));
+}
+
+static int pin_cpu(int cpu) {
+    cpu_set_t mask;
+    CPU_ZERO(&mask);
+    CPU_SET(cpu, &mask);
+    return pthread_setaffinity_np(pthread_self(), sizeof(mask), &mask);
+}
+
+static bool is_hybrid_cpu(void) {
+    unsigned eax, ebx, ecx, edx;
+    cpuid(7, 0, &eax, &ebx, &ecx, &edx);
+    return !!(edx & (1u << 15));
+}
+
+static bool is_running_on_efficiency_core(void) {
+    unsigned eax, ebx, ecx, edx;
+    cpuid(0x1a, 0, &eax, &ebx, &ecx, &edx);
+    int intel_atom = 0x20;
+    int core_type = (eax & 0xff000000u) >> 24;
+    return core_type == intel_atom;
+}
+
+static int cpu_count_math_cpus(int n_cpu) {
+    int result = 0;
+    for (int cpu = 0; cpu < n_cpu; ++cpu) {
+        if (pin_cpu(cpu)) {
+            return -1;
+        }
+        if (is_running_on_efficiency_core()) {
+            continue; // efficiency cores harm lockstep threading
+        }
+        ++cpu; // skip the SMT sibling: hyperthreading isn't useful for linear algebra
+        ++result;
+    }
+    return result;
+}
+
+#endif // __x86_64__ && __linux__
+
+/**
+ * Returns number of CPUs on system that are useful for math.
+ */
+int32_t cpu_get_num_math() {
+#if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__)
+    int n_cpu = sysconf(_SC_NPROCESSORS_ONLN);
+    if (n_cpu < 1) {
+        return cpu_get_num_physical_cores();
+    }
+    if (is_hybrid_cpu()) {
+        cpu_set_t affinity;
+        if (!pthread_getaffinity_np(pthread_self(), sizeof(affinity), &affinity)) {
+            int result = cpu_count_math_cpus(n_cpu);
+            pthread_setaffinity_np(pthread_self(), sizeof(affinity), &affinity);
+            if (result > 0) {
+                return result;
+            }
+        }
+    }
+#endif
+    return cpu_get_num_physical_cores();
+}
+
+// Helper for setting process priority
+
+#if defined(_WIN32)
+
+bool set_process_priority(enum ggml_sched_priority prio) {
+    if (prio == GGML_SCHED_PRIO_NORMAL) {
+        return true;
+    }
+
+    DWORD p = NORMAL_PRIORITY_CLASS;
+    switch (prio) {
+        case GGML_SCHED_PRIO_NORMAL:   p = NORMAL_PRIORITY_CLASS;       break;
+        case GGML_SCHED_PRIO_MEDIUM:   p = ABOVE_NORMAL_PRIORITY_CLASS; break;
+        case GGML_SCHED_PRIO_HIGH:     p = HIGH_PRIORITY_CLASS;         break;
+        case GGML_SCHED_PRIO_REALTIME: p = REALTIME_PRIORITY_CLASS;     break;
+    }
+
+    if (!SetPriorityClass(GetCurrentProcess(), p)) {
+        fprintf(stderr, "warn: failed to set process priority class %d : (%d)\n", prio, (int) GetLastError());
+        return false;
+    }
+
+    return true;
+}
+
+#else // macOS and other POSIX systems
+#include <sys/types.h>
+#include <sys/resource.h>
+
+bool set_process_priority(enum ggml_sched_priority prio) {
+    if (prio == GGML_SCHED_PRIO_NORMAL) {
+        return true;
+    }
+
+    int p = 0;
+    switch (prio) {
+        case GGML_SCHED_PRIO_NORMAL:   p =  0;  break;
+        case GGML_SCHED_PRIO_MEDIUM:   p = -5;  break;
+        case GGML_SCHED_PRIO_HIGH:     p = -10; break;
+        case GGML_SCHED_PRIO_REALTIME: p = -20; break;
+    }
+
+    // setpriority() returns 0 on success and -1 on error
+    if (setpriority(PRIO_PROCESS, 0, p) != 0) {
+        fprintf(stderr, "warn: failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno);
+        return false;
+    }
+    return true;
+}
+
+#endif
+
+//
+// CLI argument parsing
+//
+
+void gpt_params_handle_model_default(gpt_params & params) {
+    if (!params.hf_repo.empty()) {
+        // short-hand to avoid specifying --hf-file -> default it to --model
+        if (params.hf_file.empty()) {
+            if (params.model.empty()) {
+                throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n");
+            }
+            params.hf_file = params.model;
+        } else if (params.model.empty()) {
+            params.model = fs_get_cache_file(string_split(params.hf_file, '/').back());
+        }
+    } else if (!params.model_url.empty()) {
+        if (params.model.empty()) {
+            auto f = string_split(params.model_url, '#').front();
+            f = string_split(f, '?').front();
+            params.model = fs_get_cache_file(string_split(f, '/').back());
+        }
+    } else if (params.model.empty()) {
+        params.model = DEFAULT_MODEL_PATH;
+    }
+}
+
+void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model) {
+    int32_t n_set = 0;
+
+    if (cpuparams.n_threads < 0) {
+        // Assuming everything about cpuparams is invalid
+        if (role_model != nullptr) {
+            cpuparams = *role_model;
+        } else {
+            cpuparams.n_threads = cpu_get_num_math();
+        }
+    }
+
+    for (int32_t i = 0; i < GGML_MAX_N_THREADS; i++) {
+        if (cpuparams.cpumask[i]) {
+            n_set++;
+        }
+    }
+
+    if (n_set && n_set < cpuparams.n_threads) {
+        // Not enough set bits, may experience performance issues.
+        fprintf(stderr, "warn: Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads);
+    }
+}
+
+bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
+    bool invalid_param = false;
+    std::string arg;
+    const std::string arg_prefix = "--";
+    llama_sampling_params & sparams = params.sparams;
+
+    for (int i = 1; i < argc; i++) {
+        arg = argv[i];
+        if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
+            std::replace(arg.begin(), arg.end(), '_', '-');
+        }
+        if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) {
+            throw std::invalid_argument("error: unknown argument: " + arg);
+        }
+        if (invalid_param) {
+            throw std::invalid_argument("error: invalid parameter for argument: " + arg);
+        }
+    }
+
+    postprocess_cpu_params(params.cpuparams, nullptr);
+    postprocess_cpu_params(params.cpuparams_batch, &params.cpuparams);
+    postprocess_cpu_params(params.draft_cpuparams, &params.cpuparams);
+    postprocess_cpu_params(params.draft_cpuparams_batch, &params.cpuparams_batch);
+
+    if (params.prompt_cache_all && (params.interactive || params.interactive_first)) {
+        throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
+    }
+
+    gpt_params_handle_model_default(params);
+
+    if (params.hf_token.empty()) {
+        get_env("HF_TOKEN", params.hf_token);
+    }
+
+    if (params.escape) {
+        string_process_escapes(params.prompt);
+        string_process_escapes(params.input_prefix);
+        string_process_escapes(params.input_suffix);
+        string_process_escapes(sparams.cfg_negative_prompt);
+        for (auto & antiprompt : params.antiprompt) {
+            string_process_escapes(antiprompt);
+        }
+    }
+
+    if (!params.kv_overrides.empty()) {
+        params.kv_overrides.emplace_back();
+        params.kv_overrides.back().key[0] = 0;
+    }
+
+    return true;
+}
+
+void gpt_params_parse_from_env(gpt_params & params) {
+    // we only care about server-related params for now
+    get_env("LLAMA_ARG_MODEL",            params.model);
+    get_env("LLAMA_ARG_MODEL_URL",        params.model_url);
+    get_env("LLAMA_ARG_MODEL_ALIAS",      params.model_alias);
+    get_env("LLAMA_ARG_HF_REPO",          params.hf_repo);
+    get_env("LLAMA_ARG_HF_FILE",          params.hf_file);
+    get_env("LLAMA_ARG_THREADS",          params.cpuparams.n_threads);
+    get_env("LLAMA_ARG_CTX_SIZE",         params.n_ctx);
+    get_env("LLAMA_ARG_N_PARALLEL",       params.n_parallel);
+    get_env("LLAMA_ARG_BATCH",            params.n_batch);
+    get_env("LLAMA_ARG_UBATCH",           params.n_ubatch);
+    get_env("LLAMA_ARG_N_GPU_LAYERS",     params.n_gpu_layers);
+    get_env("LLAMA_ARG_THREADS_HTTP",     params.n_threads_http);
+    get_env("LLAMA_ARG_CHAT_TEMPLATE",    params.chat_template);
+    get_env("LLAMA_ARG_N_PREDICT",        params.n_predict);
+    get_env("LLAMA_ARG_ENDPOINT_METRICS", params.endpoint_metrics);
+    get_env("LLAMA_ARG_ENDPOINT_SLOTS",   params.endpoint_slots);
+    get_env("LLAMA_ARG_EMBEDDINGS",       params.embedding);
+    get_env("LLAMA_ARG_FLASH_ATTN",       params.flash_attn);
+    get_env("LLAMA_ARG_DEFRAG_THOLD",     params.defrag_thold);
+    get_env("LLAMA_ARG_CONT_BATCHING",    params.cont_batching);
+    get_env("LLAMA_ARG_HOST",             params.hostname);
+    get_env("LLAMA_ARG_PORT",             params.port);
+}
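+
+// illustrative invocation (hypothetical binary name):
+//   LLAMA_ARG_MODEL=model.gguf LLAMA_ARG_CTX_SIZE=4096 ./llama-server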
+
+bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
+    const auto params_org = params; // the example can modify the default params
+
+    try {
+        if (!gpt_params_parse_ex(argc, argv, params) || params.usage) {
+            params = params_org;
+            params.usage = true;
+            return false;
+        }
+    } catch (const std::invalid_argument & ex) {
+        fprintf(stderr, "%s\n", ex.what());
+        params = params_org;
+        return false;
+    }
+
+    return true;
+}
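+
+// typical caller sketch: on failure, `params` has been restored to the values
+// it had on entry and usage should be printed, e.g.
+//   gpt_params params;
+//   if (!gpt_params_parse(argc, argv, params)) {
+//       return 1;
+//   }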
+
+bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THREADS]) {
+    size_t dash_loc = range.find('-');
+    if (dash_loc == std::string::npos) {
+        fprintf(stderr, "Format of CPU range is invalid! Expected [<start>]-[<end>].\n");
+        return false;
+    }
+
+    size_t start_i;
+    size_t end_i;
+
+    if (dash_loc == 0) {
+        start_i = 0;
+    } else {
+        start_i = std::stoull(range.substr(0, dash_loc));
+        if (start_i >= GGML_MAX_N_THREADS) {
+            fprintf(stderr, "Start index out of bounds!\n");
+            return false;
+        }
+    }
+
+    if (dash_loc == range.length() - 1) {
+        end_i = GGML_MAX_N_THREADS - 1;
+    } else {
+        end_i = std::stoull(range.substr(dash_loc + 1));
+        if (end_i >= GGML_MAX_N_THREADS) {
+            fprintf(stderr, "End index out of bounds!\n");
+            return false;
+        }
+    }
+
+    for (size_t i = start_i; i <= end_i; i++) {
+        boolmask[i] = true;
+    }
+
+    return true;
+}
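+
+// e.g. parse_cpu_range("0-3", mask) sets CPUs 0..3; "4-" fills from CPU 4 up to
+// GGML_MAX_N_THREADS - 1, and "-7" starts from CPU 0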
+
+bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREADS]) {
+    // Discard potential 0x prefix
+    size_t start_i = 0;
+    if (mask.length() >= 2 && mask.substr(0, 2) == "0x") {
+        start_i = 2;
+    }
+
+    size_t num_digits = mask.length() - start_i;
+    if (num_digits > 128) num_digits = 128;
+
+    size_t end_i = num_digits + start_i;
+
+    for (size_t i = start_i, n = (num_digits*4 - 1); i < end_i; i++, n-=4) {
+        char c = mask.at(i);
+        int8_t id = c;
+
+        if ((c >= '0' && c <= '9')) {
+            id -= '0';
+        } else if (c >= 'a' && c <= 'f') {
+            id -= 'a' - 10;
+        } else if (c >= 'A' && c <= 'F') {
+            id -= 'A' - 10;
+        } else {
+            fprintf(stderr, "Invalid hex character '%c' at position %d\n", c, int32_t(i));
+            return false;
+        }
+
+        boolmask[  n  ] = boolmask[  n  ] || ((id & 8) != 0);
+        boolmask[n - 1] = boolmask[n - 1] || ((id & 4) != 0);
+        boolmask[n - 2] = boolmask[n - 2] || ((id & 2) != 0);
+        boolmask[n - 3] = boolmask[n - 3] || ((id & 1) != 0);
+    }
+
+    return true;
+}
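+
+// e.g. parse_cpu_mask("0x5", mask) enables CPUs 0 and 2: each hex digit maps to
+// four CPUs, with the leftmost digit covering the highest CPU indices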
+
+#define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; }
+
+bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
+    const char split_delim = ',';
+
+    llama_sampling_params & sparams = params.sparams;
+
+    if (arg == "-s" || arg == "--seed") {
+        CHECK_ARG
+        // TODO: this is temporary, in the future the sampling state will be moved fully to llama_sampling_context.
+        params.seed = std::stoul(argv[i]);
+        sparams.seed = std::stoul(argv[i]);
+        return true;
+    }
+    if (arg == "-t" || arg == "--threads") {
+        CHECK_ARG
+        params.cpuparams.n_threads = std::stoi(argv[i]);
+        if (params.cpuparams.n_threads <= 0) {
+            params.cpuparams.n_threads = std::thread::hardware_concurrency();
+        }
+        return true;
+    }
+    if (arg == "-C" || arg == "--cpu-mask") {
+        CHECK_ARG
+        std::string mask = argv[i];
+        params.cpuparams.mask_valid = true;
+        invalid_param = !parse_cpu_mask(mask, params.cpuparams.cpumask);
+        return true;
+    }
+    if (arg == "-Cr" || arg == "--cpu-range") {
+        CHECK_ARG
+        std::string range = argv[i];
+        params.cpuparams.mask_valid = true;
+        invalid_param = !parse_cpu_range(range, params.cpuparams.cpumask);
+        return true;
+    }
+    if (arg == "--prio") {
+        CHECK_ARG
+        params.cpuparams.priority = (enum ggml_sched_priority) std::stoul(argv[i]);
+        return true;
+    }
+    if (arg == "--cpu-strict") {
+        CHECK_ARG
+        params.cpuparams.strict_cpu = std::stoul(argv[i]);
+        return true;
+    }
+    if (arg == "--poll") {
+        CHECK_ARG
+        params.cpuparams.poll = std::stoul(argv[i]);
+        return true;
+    }
+    if (arg == "-tb" || arg == "--threads-batch") {
+        CHECK_ARG
+        params.cpuparams_batch.n_threads = std::stoi(argv[i]);
+        if (params.cpuparams_batch.n_threads <= 0) {
+            params.cpuparams_batch.n_threads = std::thread::hardware_concurrency();
+        }
+        return true;
+    }
+    if (arg == "-Cb" || arg == "--cpu-mask-batch") {
+        CHECK_ARG
+        std::string mask = argv[i];
+        params.cpuparams_batch.mask_valid = true;
+        invalid_param = !parse_cpu_mask(mask, params.cpuparams_batch.cpumask);
+        return true;
+    }
+    if (arg == "-Crb" || arg == "--cpu-range_batch") {
+        CHECK_ARG
+        std::string range = argv[i];
+        params.cpuparams_batch.mask_valid = true;
+        invalid_param = !parse_cpu_range(range, params.cpuparams_batch.cpumask);
+        return true;
+    }
+    if (arg == "--prio-batch") {
+        CHECK_ARG
+        params.cpuparams_batch.priority = (enum ggml_sched_priority) std::stoul(argv[i]);
+        return true;
+    }
+    if (arg == "--cpu-strict-batch") {
+        params.cpuparams_batch.strict_cpu = true;
+        return true;
+    }
+    if (arg == "--poll-batch") {
+        CHECK_ARG
+        params.cpuparams_batch.poll = std::stoul(argv[i]);
+        return true;
+    }
+    if (arg == "-td" || arg == "--threads-draft") {
+        CHECK_ARG
+        params.draft_cpuparams.n_threads = std::stoi(argv[i]);
+        if (params.draft_cpuparams.n_threads <= 0) {
+            params.draft_cpuparams.n_threads = std::thread::hardware_concurrency();
+        }
+        return true;
+    }
+        if (arg == "-Cd" || arg == "--cpu-mask-draft") {
+        CHECK_ARG
+        std::string mask = argv[i];
+        params.draft_cpuparams.mask_valid = true;
+        invalid_param = !parse_cpu_mask(mask, params.draft_cpuparams.cpumask);
+        return true;
+    }
+    if (arg == "-Crd" || arg == "--cpu-range-draft") {
+        CHECK_ARG
+        std::string range = argv[i];
+        params.draft_cpuparams.mask_valid = true;
+        invalid_param = !parse_cpu_range(range, params.draft_cpuparams.cpumask);
+        return true;
+    }
+    if (arg == "--prio-draft") {
+        CHECK_ARG
+        params.draft_cpuparams.priority = (enum ggml_sched_priority) std::stoul(argv[i]);
+        return true;
+    }
+    if (arg == "--cpu-strict-draft") {
+        params.draft_cpuparams.strict_cpu = true;
+        return true;
+    }
+    if (arg == "--poll-draft") {
+        CHECK_ARG
+        params.draft_cpuparams.poll = std::stoul(argv[i]);
+        return true;
+    }
+    if (arg == "-tbd" || arg == "--threads-batch-draft") {
+        CHECK_ARG
+        params.draft_cpuparams_batch.n_threads = std::stoi(argv[i]);
+        if (params.draft_cpuparams_batch.n_threads <= 0) {
+            params.draft_cpuparams_batch.n_threads = std::thread::hardware_concurrency();
+        }
+        return true;
+    }
+    if (arg == "-Crbd" || arg == "--cpu-range-batch-draft") {
+        CHECK_ARG
+        std::string range = argv[i];
+        params.draft_cpuparams_batch.mask_valid = true;
+        invalid_param = !parse_cpu_range(range, params.draft_cpuparams_batch.cpumask);
+        return true;
+    }
+    if (arg == "--prio-batch-draft") {
+        CHECK_ARG
+        params.draft_cpuparams_batch.priority = (enum ggml_sched_priority) std::stoul(argv[i]);
+        return true;
+    }
+    if (arg == "--cpu-strict-batch-draft") {
+        params.draft_cpuparams_batch.strict_cpu = true;
+        return true;
+    }
+    if (arg == "--poll-batch-draft") {
+        CHECK_ARG
+        params.draft_cpuparams_batch.poll = std::stoul(argv[i]);
+        return true;
+    }
+    if (arg == "-p" || arg == "--prompt") {
+        CHECK_ARG
+        params.prompt = argv[i];
+        return true;
+    }
+    if (arg == "-e" || arg == "--escape") {
+        params.escape = true;
+        return true;
+    }
+    if (arg == "--no-escape") {
+        params.escape = false;
+        return true;
+    }
+    if (arg == "--prompt-cache") {
+        CHECK_ARG
+        params.path_prompt_cache = argv[i];
+        return true;
+    }
+    if (arg == "--prompt-cache-all") {
+        params.prompt_cache_all = true;
+        return true;
+    }
+    if (arg == "--prompt-cache-ro") {
+        params.prompt_cache_ro = true;
+        return true;
+    }
+    if (arg == "-bf" || arg == "--binary-file") {
+        CHECK_ARG
+        std::ifstream file(argv[i], std::ios::binary);
+        if (!file) {
+            fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+            invalid_param = true;
+            return true;
+        }
+        // store the external file name in params
+        params.prompt_file = argv[i];
+        std::ostringstream ss;
+        ss << file.rdbuf();
+        params.prompt = ss.str();
+        fprintf(stderr, "Read %zu bytes from binary file %s\n", params.prompt.size(), argv[i]);
+        return true;
+    }
+    if (arg == "-f" || arg == "--file") {
+        CHECK_ARG
+        std::ifstream file(argv[i]);
+        if (!file) {
+            fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+            invalid_param = true;
+            return true;
+        }
+        // store the external file name in params
+        params.prompt_file = argv[i];
+        std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
+        if (!params.prompt.empty() && params.prompt.back() == '\n') {
+            params.prompt.pop_back();
+        }
+        return true;
+    }
+    if (arg == "--in-file") {
+        CHECK_ARG
+        std::ifstream file(argv[i]);
+        if (!file) {
+            fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+            invalid_param = true;
+            return true;
+        }
+        params.in_files.push_back(argv[i]);
+        return true;
+    }
+    if (arg == "-n" || arg == "--predict" || arg == "--n-predict") {
+        CHECK_ARG
+        params.n_predict = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--top-k") {
+        CHECK_ARG
+        sparams.top_k = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "-c" || arg == "--ctx-size") {
+        CHECK_ARG
+        params.n_ctx = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--grp-attn-n" || arg == "-gan") {
+        CHECK_ARG
+        params.grp_attn_n = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--grp-attn-w" || arg == "-gaw") {
+        CHECK_ARG
+        params.grp_attn_w = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--rope-freq-base") {
+        CHECK_ARG
+        params.rope_freq_base = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--rope-freq-scale") {
+        CHECK_ARG
+        params.rope_freq_scale = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--rope-scaling") {
+        CHECK_ARG
+        std::string value(argv[i]);
+        /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
+        else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }
+        else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
+        else { invalid_param = true; }
+        return true;
+    }
+    if (arg == "--rope-scale") {
+        CHECK_ARG
+        params.rope_freq_scale = 1.0f / std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--yarn-orig-ctx") {
+        CHECK_ARG
+        params.yarn_orig_ctx = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--yarn-ext-factor") {
+        CHECK_ARG
+        params.yarn_ext_factor = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--yarn-attn-factor") {
+        CHECK_ARG
+        params.yarn_attn_factor = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--yarn-beta-fast") {
+        CHECK_ARG
+        params.yarn_beta_fast = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--yarn-beta-slow") {
+        CHECK_ARG
+        params.yarn_beta_slow = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--pooling") {
+        CHECK_ARG
+        std::string value(argv[i]);
+        /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
+        else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
+        else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
+        else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
+        else { invalid_param = true; }
+        return true;
+    }
+    if (arg == "--attention") {
+        CHECK_ARG
+        std::string value(argv[i]);
+        /**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; }
+        else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL; }
+        else { invalid_param = true; }
+        return true;
+    }
+    if (arg == "--defrag-thold" || arg == "-dt") {
+        CHECK_ARG
+        params.defrag_thold = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--samplers") {
+        CHECK_ARG
+        const auto sampler_names = string_split(argv[i], ';');
+        sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, true);
+        return true;
+    }
+    if (arg == "--sampling-seq") {
+        CHECK_ARG
+        sparams.samplers_sequence = llama_sampling_types_from_chars(argv[i]);
+        return true;
+    }
+    if (arg == "--top-p") {
+        CHECK_ARG
+        sparams.top_p = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--min-p") {
+        CHECK_ARG
+        sparams.min_p = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--temp") {
+        CHECK_ARG
+        sparams.temp = std::stof(argv[i]);
+        sparams.temp = std::max(sparams.temp, 0.0f);
+        return true;
+    }
+    if (arg == "--tfs") {
+        CHECK_ARG
+        sparams.tfs_z = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--typical") {
+        CHECK_ARG
+        sparams.typical_p = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--repeat-last-n") {
+        CHECK_ARG
+        sparams.penalty_last_n = std::stoi(argv[i]);
+        sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n);
+        return true;
+    }
+    if (arg == "--repeat-penalty") {
+        CHECK_ARG
+        sparams.penalty_repeat = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--frequency-penalty") {
+        CHECK_ARG
+        sparams.penalty_freq = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--presence-penalty") {
+        CHECK_ARG
+        sparams.penalty_present = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--dynatemp-range") {
+        CHECK_ARG
+        sparams.dynatemp_range = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--dynatemp-exp") {
+        CHECK_ARG
+        sparams.dynatemp_exponent = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--mirostat") {
+        CHECK_ARG
+        sparams.mirostat = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--mirostat-lr") {
+        CHECK_ARG
+        sparams.mirostat_eta = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--mirostat-ent") {
+        CHECK_ARG
+        sparams.mirostat_tau = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--cfg-negative-prompt") {
+        CHECK_ARG
+        sparams.cfg_negative_prompt = argv[i];
+        return true;
+    }
+    if (arg == "--cfg-negative-prompt-file") {
+        CHECK_ARG
+        std::ifstream file(argv[i]);
+        if (!file) {
+            fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+            invalid_param = true;
+            return true;
+        }
+        std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(sparams.cfg_negative_prompt));
+        if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') {
+            sparams.cfg_negative_prompt.pop_back();
+        }
+        return true;
+    }
+    if (arg == "--cfg-scale") {
+        CHECK_ARG
+        sparams.cfg_scale = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "-b" || arg == "--batch-size") {
+        CHECK_ARG
+        params.n_batch = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "-ub" || arg == "--ubatch-size") {
+        CHECK_ARG
+        params.n_ubatch = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--keep") {
+        CHECK_ARG
+        params.n_keep = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--draft") {
+        CHECK_ARG
+        params.n_draft = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--chunks") {
+        CHECK_ARG
+        params.n_chunks = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "-np" || arg == "--parallel") {
+        CHECK_ARG
+        params.n_parallel = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "-ns" || arg == "--sequences") {
+        CHECK_ARG
+        params.n_sequences = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--p-split" || arg == "-ps") {
+        CHECK_ARG
+        params.p_split = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "-m" || arg == "--model") {
+        CHECK_ARG
+        params.model = argv[i];
+        return true;
+    }
+    if (arg == "-md" || arg == "--model-draft") {
+        CHECK_ARG
+        params.model_draft = argv[i];
+        return true;
+    }
+    if (arg == "-a" || arg == "--alias") {
+        CHECK_ARG
+        params.model_alias = argv[i];
+        return true;
+    }
+    if (arg == "-mu" || arg == "--model-url") {
+        CHECK_ARG
+        params.model_url = argv[i];
+        return true;
+    }
+    if (arg == "-hft" || arg == "--hf-token") {
+        CHECK_ARG
+        params.hf_token = argv[i];
+        return true;
+    }
+    if (arg == "-hfr" || arg == "--hf-repo") {
+        CHECK_ARG
+        params.hf_repo = argv[i];
+        return true;
+    }
+    if (arg == "-hff" || arg == "--hf-file") {
+        CHECK_ARG
+        params.hf_file = argv[i];
+        return true;
+    }
+    if (arg == "--lora") {
+        CHECK_ARG
+        params.lora_adapters.push_back({
+            std::string(argv[i]),
+            1.0,
+        });
+        return true;
+    }
+    if (arg == "--lora-scaled") {
+        CHECK_ARG
+        std::string lora_adapter = argv[i];
+        CHECK_ARG
+        params.lora_adapters.push_back({
+            lora_adapter,
+            std::stof(argv[i]),
+        });
+        return true;
+    }
+    if (arg == "--lora-init-without-apply") {
+        params.lora_init_without_apply = true;
+        return true;
+    }
+    if (arg == "--control-vector") {
+        CHECK_ARG
+        params.control_vectors.push_back({ 1.0f, argv[i], });
+        return true;
+    }
+    if (arg == "--control-vector-scaled") {
+        CHECK_ARG
+        const char* fname = argv[i];
+        CHECK_ARG
+        params.control_vectors.push_back({ std::stof(argv[i]), fname, });
+        return true;
+    }
+    if (arg == "--control-vector-layer-range") {
+        CHECK_ARG
+        params.control_vector_layer_start = std::stoi(argv[i]);
+        CHECK_ARG
+        params.control_vector_layer_end = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--mmproj") {
+        CHECK_ARG
+        params.mmproj = argv[i];
+        return true;
+    }
+    if (arg == "--image") {
+        CHECK_ARG
+        params.image.emplace_back(argv[i]);
+        return true;
+    }
+    if (arg == "-i" || arg == "--interactive") {
+        params.interactive = true;
+        return true;
+    }
+    if (arg == "-sp" || arg == "--special") {
+        params.special = true;
+        return true;
+    }
+    if (arg == "--embedding" || arg == "--embeddings") {
+        params.embedding = true;
+        return true;
+    }
+    if (arg == "--embd-normalize") {
+        CHECK_ARG
+        params.embd_normalize = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--embd-output-format") {
+        CHECK_ARG
+        params.embd_out = argv[i];
+        return true;
+    }
+    if (arg == "--embd-separator") {
+        CHECK_ARG
+        params.embd_sep = argv[i];
+        return true;
+    }
+    if (arg == "-if" || arg == "--interactive-first") {
+        params.interactive_first = true;
+        return true;
+    }
+    if (arg == "-cnv" || arg == "--conversation") {
+        params.conversation = true;
+        return true;
+    }
+    if (arg == "--infill") {
+        params.infill = true;
+        return true;
+    }
+    if (arg == "-dkvc" || arg == "--dump-kv-cache") {
+        params.dump_kv_cache = true;
+        return true;
+    }
+    if (arg == "-nkvo" || arg == "--no-kv-offload") {
+        params.no_kv_offload = true;
+        return true;
+    }
+    if (arg == "-ctk" || arg == "--cache-type-k") {
+        params.cache_type_k = argv[++i];
+        return true;
+    }
+    if (arg == "-ctv" || arg == "--cache-type-v") {
+        params.cache_type_v = argv[++i];
+        return true;
+    }
+    if (arg == "-mli" || arg == "--multiline-input") {
+        params.multiline_input = true;
+        return true;
+    }
+    if (arg == "--simple-io") {
+        params.simple_io = true;
+        return true;
+    }
+    if (arg == "-cb" || arg == "--cont-batching") {
+        params.cont_batching = true;
+        return true;
+    }
+    if (arg == "-nocb" || arg == "--no-cont-batching") {
+        params.cont_batching = false;
+        return true;
+    }
+    if (arg == "-fa" || arg == "--flash-attn") {
+        params.flash_attn = true;
+        return true;
+    }
+    if (arg == "-co" || arg == "--color") {
+        params.use_color = true;
+        return true;
+    }
+    if (arg == "--mlock") {
+        params.use_mlock = true;
+        return true;
+    }
+    if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers") {
+        CHECK_ARG
+        params.n_gpu_layers = std::stoi(argv[i]);
+        if (!llama_supports_gpu_offload()) {
+            fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers option will be ignored\n");
+            fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
+        }
+        return true;
+    }
+    if (arg == "-ngld" || arg == "--gpu-layers-draft" || arg == "--n-gpu-layers-draft") {
+        CHECK_ARG
+        params.n_gpu_layers_draft = std::stoi(argv[i]);
+        if (!llama_supports_gpu_offload()) {
+            fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers-draft option will be ignored\n");
+            fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
+        }
+        return true;
+    }
+    if (arg == "--main-gpu" || arg == "-mg") {
+        CHECK_ARG
+        params.main_gpu = std::stoi(argv[i]);
+#ifndef GGML_USE_CUDA_SYCL_VULKAN
+        fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL/Vulkan. Setting the main GPU has no effect.\n");
+#endif // GGML_USE_CUDA_SYCL_VULKAN
+        return true;
+    }
+    if (arg == "--split-mode" || arg == "-sm") {
+        CHECK_ARG
+        std::string arg_next = argv[i];
+        if (arg_next == "none") {
+            params.split_mode = LLAMA_SPLIT_MODE_NONE;
+        }
+        else if (arg_next == "layer") {
+            params.split_mode = LLAMA_SPLIT_MODE_LAYER;
+        }
+        else if (arg_next == "row") {
+#ifdef GGML_USE_SYCL
+            fprintf(stderr, "warning: The split mode value:[row] is not supported by llama.cpp with SYCL. It's developing.\nExit!\n");
+            exit(1);
+#endif // GGML_USE_SYCL
+            params.split_mode = LLAMA_SPLIT_MODE_ROW;
+        }
+        else {
+            invalid_param = true;
+            return true;
+        }
+#ifndef GGML_USE_CUDA_SYCL_VULKAN
+        fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL/Vulkan. Setting the split mode has no effect.\n");
+#endif // GGML_USE_CUDA_SYCL_VULKAN
+        return true;
+    }
+    if (arg == "--tensor-split" || arg == "-ts") {
+        CHECK_ARG
+        std::string arg_next = argv[i];
+
+        // split string by , and /
+        const std::regex regex{ R"([,/]+)" };
+        std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 };
+        std::vector<std::string> split_arg{ it, {} };
+        if (split_arg.size() >= llama_max_devices()) {
+            invalid_param = true;
+            return true;
+        }
+        for (size_t i = 0; i < llama_max_devices(); ++i) {
+            if (i < split_arg.size()) {
+                params.tensor_split[i] = std::stof(split_arg[i]);
+            }
+            else {
+                params.tensor_split[i] = 0.0f;
+            }
+        }
+#ifndef GGML_USE_CUDA_SYCL_VULKAN
+        fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL/Vulkan. Setting a tensor split has no effect.\n");
+#endif // GGML_USE_CUDA_SYCL_VULKAN
+        return true;
+    }
+    if (arg == "--rpc") {
+        CHECK_ARG
+        params.rpc_servers = argv[i];
+        return true;
+    }
+    if (arg == "--no-mmap") {
+        params.use_mmap = false;
+        return true;
+    }
+    if (arg == "--numa") {
+        CHECK_ARG
+        std::string value(argv[i]);
+        /**/ if (value == "distribute" || value == "") { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; }
+        else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; }
+        else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; }
+        else { invalid_param = true; }
+        return true;
+    }
+    if (arg == "-v" || arg == "--verbose") {
+        params.verbosity = 1;
+        return true;
+    }
+    if (arg == "--verbosity") {
+        CHECK_ARG
+        params.verbosity = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--verbose-prompt") {
+        params.verbose_prompt = true;
+        return true;
+    }
+    if (arg == "--no-display-prompt") {
+        params.display_prompt = false;
+        return true;
+    }
+    if (arg == "-r" || arg == "--reverse-prompt") {
+        CHECK_ARG
+        params.antiprompt.emplace_back(argv[i]);
+        return true;
+    }
+    if (arg == "-ld" || arg == "--logdir") {
+        CHECK_ARG
+        params.logdir = argv[i];
+
+        if (params.logdir.back() != DIRECTORY_SEPARATOR) {
+            params.logdir += DIRECTORY_SEPARATOR;
+        }
+        return true;
+    }
+    if (arg == "-lcs" || arg == "--lookup-cache-static") {
+        CHECK_ARG
+        params.lookup_cache_static = argv[i];
+        return true;
+    }
+    if (arg == "-lcd" || arg == "--lookup-cache-dynamic") {
+        CHECK_ARG
+        params.lookup_cache_dynamic = argv[i];
+        return true;
+    }
+    if (arg == "--save-all-logits" || arg == "--kl-divergence-base") {
+        CHECK_ARG
+        params.logits_file = argv[i];
+        return true;
+    }
+    if (arg == "--perplexity" || arg == "--all-logits") {
+        params.logits_all = true;
+        return true;
+    }
+    if (arg == "--ppl-stride") {
+        CHECK_ARG
+        params.ppl_stride = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--ppl-output-type") {
+        CHECK_ARG
+        params.ppl_output_type = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "-ptc" || arg == "--print-token-count") {
+        CHECK_ARG
+        params.n_print = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--check-tensors") {
+        params.check_tensors = true;
+        return true;
+    }
+    if (arg == "--hellaswag") {
+        params.hellaswag = true;
+        return true;
+    }
+    if (arg == "--hellaswag-tasks") {
+        CHECK_ARG
+        params.hellaswag_tasks = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--winogrande") {
+        params.winogrande = true;
+        return true;
+    }
+    if (arg == "--winogrande-tasks") {
+        CHECK_ARG
+        params.winogrande_tasks = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--multiple-choice") {
+        params.multiple_choice = true;
+        return true;
+    }
+    if (arg == "--multiple-choice-tasks") {
+        CHECK_ARG
+        params.multiple_choice_tasks = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--kl-divergence") {
+        params.kl_divergence = true;
+        return true;
+    }
+    if (arg == "--ignore-eos") {
+        params.ignore_eos = true;
+        return true;
+    }
+    if (arg == "--penalize-nl") {
+        sparams.penalize_nl = true;
+        return true;
+    }
+    if (arg == "-l" || arg == "--logit-bias") {
+        CHECK_ARG
+        std::stringstream ss(argv[i]);
+        llama_token key;
+        char sign;
+        std::string value_str;
+        try {
+            if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
+                sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+            }
+            else {
+                throw std::exception();
+            }
+        }
+        catch (const std::exception&) {
+            invalid_param = true;
+            return true;
+        }
+        return true;
+    }
+    if (arg == "-h" || arg == "--help" || arg == "--usage"  ) {
+        params.usage = true;
+        return true;
+    }
+    if (arg == "--version") {
+        fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
+        fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET);
+        exit(0);
+    }
+    if (arg == "--in-prefix-bos") {
+        params.input_prefix_bos = true;
+        params.enable_chat_template = false;
+        return true;
+    }
+    if (arg == "--in-prefix") {
+        CHECK_ARG
+        params.input_prefix = argv[i];
+        params.enable_chat_template = false;
+        return true;
+    }
+    if (arg == "--in-suffix") {
+        CHECK_ARG
+        params.input_suffix = argv[i];
+        params.enable_chat_template = false;
+        return true;
+    }
+    if (arg == "--spm-infill") {
+        params.spm_infill = true;
+        return true;
+    }
+    if (arg == "--grammar") {
+        CHECK_ARG
+        sparams.grammar = argv[i];
+        return true;
+    }
+    if (arg == "--grammar-file") {
+        CHECK_ARG
+        std::ifstream file(argv[i]);
+        if (!file) {
+            fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+            invalid_param = true;
+            return true;
+        }
+        std::copy(
+            std::istreambuf_iterator<char>(file),
+            std::istreambuf_iterator<char>(),
+            std::back_inserter(sparams.grammar)
+        );
+        return true;
+    }
+    if (arg == "-j" || arg == "--json-schema") {
+        CHECK_ARG
+        sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
+        return true;
+    }
+    if (arg == "--override-kv") {
+        CHECK_ARG
+        if (!string_parse_kv_override(argv[i], params.kv_overrides)) {
+            fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]);
+            invalid_param = true;
+            return true;
+        }
+        return true;
+    }
+    if (arg == "--host") {
+        CHECK_ARG
+        params.hostname = argv[i];
+        return true;
+    }
+    if (arg == "--port") {
+        CHECK_ARG
+        params.port = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--path") {
+        CHECK_ARG
+        params.public_path = argv[i];
+        return true;
+    }
+    if (arg == "--api-key") {
+        CHECK_ARG
+        params.api_keys.push_back(argv[i]);
+        return true;
+    }
+    if (arg == "--api-key-file") {
+        CHECK_ARG
+        std::ifstream key_file(argv[i]);
+        if (!key_file) {
+            fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+            invalid_param = true;
+            return true;
+        }
+        std::string key;
+        while (std::getline(key_file, key)) {
+            if (!key.empty()) {
+                params.api_keys.push_back(key);
+            }
+        }
+        key_file.close();
+        return true;
+    }
+    if (arg == "--ssl-key-file") {
+        CHECK_ARG
+        params.ssl_file_key = argv[i];
+        return true;
+    }
+    if (arg == "--ssl-cert-file") {
+        CHECK_ARG
+        params.ssl_file_cert = argv[i];
+        return true;
+    }
+    if (arg == "--timeout" || arg == "-to") {
+        CHECK_ARG
+        params.timeout_read  = std::stoi(argv[i]);
+        params.timeout_write = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--threads-http") {
+        CHECK_ARG
+        params.n_threads_http = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "-spf" || arg == "--system-prompt-file") {
+        CHECK_ARG
+        std::ifstream file(argv[i]);
+        if (!file) {
+            fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+            invalid_param = true;
+            return true;
+        }
+        std::string system_prompt;
+        std::copy(
+                std::istreambuf_iterator<char>(file),
+                std::istreambuf_iterator<char>(),
+                std::back_inserter(system_prompt)
+                );
+        params.system_prompt = system_prompt;
+        return true;
+    }
+    if (arg == "--log-format") {
+        CHECK_ARG
+        if (std::strcmp(argv[i], "json") == 0) {
+            params.log_json = true;
+        } else if (std::strcmp(argv[i], "text") == 0) {
+            params.log_json = false;
+        } else {
+            invalid_param = true;
+            return true;
+        }
+        return true;
+    }
+    if (arg == "--no-slots") {
+        params.endpoint_slots = false;
+        return true;
+    }
+    if (arg == "--metrics") {
+        params.endpoint_metrics = true;
+        return true;
+    }
+    if (arg == "--slot-save-path") {
+        CHECK_ARG
+        params.slot_save_path = argv[i];
+        // if it doesn't end with DIRECTORY_SEPARATOR, add it
+        if (!params.slot_save_path.empty() && params.slot_save_path[params.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) {
+            params.slot_save_path += DIRECTORY_SEPARATOR;
+        }
+        return true;
+    }
+    if (arg == "--chat-template") {
+        CHECK_ARG
+        if (!llama_chat_verify_template(argv[i])) {
+            fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]);
+            fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n");
+            invalid_param = true;
+            return true;
+        }
+        params.chat_template = argv[i];
+        return true;
+    }
+    if (arg == "--slot-prompt-similarity" || arg == "-sps") {
+        CHECK_ARG
+        params.slot_prompt_similarity = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "-pps") {
+        params.is_pp_shared = true;
+        return true;
+    }
+    if (arg == "-npp") {
+        CHECK_ARG
+        auto p = string_split<int>(argv[i], split_delim);
+        params.n_pp.insert(params.n_pp.end(), p.begin(), p.end());
+        return true;
+    }
+    if (arg == "-ntg") {
+        CHECK_ARG
+        auto p = string_split<int>(argv[i], split_delim);
+        params.n_tg.insert(params.n_tg.end(), p.begin(), p.end());
+        return true;
+    }
+    if (arg == "-npl") {
+        CHECK_ARG
+        auto p = string_split<int>(argv[i], split_delim);
+        params.n_pl.insert(params.n_pl.end(), p.begin(), p.end());
+        return true;
+    }
+    if (arg == "--context-file") {
+        CHECK_ARG
+        std::ifstream file(argv[i], std::ios::binary);
+        if (!file) {
+            fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+            invalid_param = true;
+            return true;
+        }
+        params.context_files.push_back(argv[i]);
+        return true;
+    }
+    if (arg == "--chunk-size") {
+        CHECK_ARG
+        params.chunk_size = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--chunk-separator") {
+        CHECK_ARG
+        params.chunk_separator = argv[i];
+        return true;
+    }
+    if (arg == "--junk") {
+        CHECK_ARG
+        params.n_junk = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--pos") {
+        CHECK_ARG
+        params.i_pos = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "-o" || arg == "--output" || arg == "--output-file") {
+        CHECK_ARG
+        params.out_file = argv[i];
+        params.cvector_outfile = argv[i];
+        params.lora_outfile = argv[i];
+        return true;
+    }
+    if (arg == "-ofreq" || arg == "--output-frequency") {
+        CHECK_ARG
+        params.n_out_freq = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--save-frequency") {
+        CHECK_ARG
+        params.n_save_freq = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--process-output") {
+        params.process_output = true;
+        return true;
+    }
+    if (arg == "--no-ppl") {
+        params.compute_ppl = false;
+        return true;
+    }
+    if (arg == "--chunk" || arg == "--from-chunk") {
+        CHECK_ARG
+        params.i_chunk = std::stoi(argv[i]);
+        return true;
+    }
+    // cvector params
+    if (arg == "--positive-file") {
+        CHECK_ARG
+        params.cvector_positive_file = argv[i];
+        return true;
+    }
+    if (arg == "--negative-file") {
+        CHECK_ARG
+        params.cvector_negative_file = argv[i];
+        return true;
+    }
+    if (arg == "--pca-batch") {
+        CHECK_ARG
+        params.n_pca_batch = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--pca-iter") {
+        CHECK_ARG
+        params.n_pca_iterations = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--method") {
+        CHECK_ARG
+        std::string value(argv[i]);
+        /**/ if (value == "pca") { params.cvector_dimre_method = DIMRE_METHOD_PCA; }
+        else if (value == "mean") { params.cvector_dimre_method = DIMRE_METHOD_MEAN; }
+        else { invalid_param = true; }
+        return true;
+    }
+    if (arg == "--no-warmup") {
+        params.warmup = false;
+        return true;
+    }
+#ifndef LOG_DISABLE_LOGS
+    // Parse args for logging parameters
+    if (log_param_single_parse(argv[i])) {
+        // Do nothing, log_param_single_parse automatically does its thing
+        //  and returns if a match was found and parsed.
+        return true;
+    }
+    if (log_param_pair_parse( /*check_but_dont_parse*/ true, argv[i])) {
+        // We have a matching known parameter requiring an argument,
+        //  now we need to check if there is anything after this argv
+        //  and flag invalid_param or parse it.
+        CHECK_ARG
+        if (!log_param_pair_parse( /*check_but_dont_parse*/ false, argv[i - 1], argv[i])) {
+            invalid_param = true;
+            return true;
+        }
+        return true;
+    }
+    // End of Parse args for logging parameters
+#endif // LOG_DISABLE_LOGS
+
+    return false;
+}
+
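+// enable compile-time printf-format checking for the variadic option descriptions below;
+// MinGW targets use gnu_printf since the MS printf format semantics differ from glibc's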
+#ifdef __GNUC__
+#ifdef __MINGW32__
+#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#else
+#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+#endif
+#else
+#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
+#endif
+
+void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
+    const llama_sampling_params & sparams = params.sparams;
+
+    std::string sampler_type_chars;
+    std::string sampler_type_names;
+    for (const auto sampler_type : sparams.samplers_sequence) {
+        sampler_type_chars += static_cast<char>(sampler_type);
+        sampler_type_names += llama_sampling_type_to_str(sampler_type) + ";";
+    }
+    sampler_type_names.pop_back();
+
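+    // each entry is either a group header (only grp set) or an option row:
+    //  tags name the tools the option applies to ("*" = all), args is the flag
+    //  spelling, and desc is a printf-style format expanded at construction time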
+    struct option_info {
+        LLAMA_COMMON_ATTRIBUTE_FORMAT(4, 5)
+        option_info(const std::string & tags, const char * args, const char * desc, ...) : tags(tags), args(args), desc(desc) {
+            va_list args_list;
+            va_start(args_list, desc);
+            char buffer[1024];
+            vsnprintf(buffer, sizeof(buffer), desc, args_list);
+            va_end(args_list);
+            this->desc = buffer;
+        }
+
+        option_info(const std::string & grp) : grp(grp) {}
+
+        std::string tags;
+        std::string args;
+        std::string desc;
+        std::string grp;
+    };
+
+    std::vector<option_info> options;
+
+    // TODO: filter by tags
+
+    options.push_back({ "general" });
+    options.push_back({ "*",           "-h,    --help, --usage",        "print usage and exit" });
+    options.push_back({ "*",           "       --version",              "show version and build info" });
+    options.push_back({ "*",           "-v,    --verbose",              "print verbose information" });
+    options.push_back({ "*",           "       --verbosity N",          "set specific verbosity level (default: %d)", params.verbosity });
+    options.push_back({ "*",           "       --verbose-prompt",       "print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false" });
+    options.push_back({ "*",           "       --no-display-prompt",    "don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false" });
+    options.push_back({ "*",           "-co,   --color",                "colorise output to distinguish prompt and user input from generations (default: %s)", params.use_color ? "true" : "false" });
+    options.push_back({ "*",           "-s,    --seed SEED",            "RNG seed (default: %d, use random seed for < 0)", params.seed });
+    options.push_back({ "*",           "-t,    --threads N",            "number of threads to use during generation (default: %d)", params.cpuparams.n_threads });
+    options.push_back({ "*",           "-tb,   --threads-batch N",      "number of threads to use during batch and prompt processing (default: same as --threads)" });
+    options.push_back({ "speculative", "-td,   --threads-draft N",      "number of threads to use during generation (default: same as --threads)" });
+    options.push_back({ "speculative", "-tbd,  --threads-batch-draft N","number of threads to use during batch and prompt processing (default: same as --threads-draft)" });
+
+#ifndef GGML_USE_OPENMP
+    // these options are available only with the internal threadpool
+    options.push_back({ "*",           "-C,    --cpu-mask M",            "CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: \"\")"});
+    options.push_back({ "*",           "-Cr,   --cpu-range lo-hi",       "range of CPUs for affinity. Complements --cpu-mask"});
+    options.push_back({ "*",           "       --cpu-strict <0|1>",      "use strict CPU placement (default: %u)\n", (unsigned) params.cpuparams.strict_cpu});
+    options.push_back({ "*",           "       --priority N",            "set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.cpuparams.priority});
+    options.push_back({ "*",           "       --poll <0...100>",        "use polling level to wait for work (0 - no polling, default: %u)\n", (unsigned) params.cpuparams.poll});
+
+    options.push_back({ "*",           "-Cb,   --cpu-mask-batch M",      "CPU affinity mask: arbitrarily long hex. Complements cpu-range-batch (default: same as --cpu-mask)"});
+    options.push_back({ "*",           "-Crb,  --cpu-range-batch lo-hi", "ranges of CPUs for affinity. Complements --cpu-mask-batch"});
+    options.push_back({ "*",           "       --cpu-strict-batch <0|1>","use strict CPU placement (default: same as --cpu-strict)"});
+    options.push_back({ "*",           "       --priority-batch N",      "set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: --priority)"});
+    options.push_back({ "*",           "       --poll-batch <0|1>",      "use polling to wait for work (default: same as --poll"});
+
+    options.push_back({ "speculative", "-Cd,   --cpu-mask-draft M",      "Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)"});
+    options.push_back({ "speculative", "-Crd,  --cpu-range-draft lo-hi", "Ranges of CPUs for affinity. Complements --cpu-mask-draft"});
+    options.push_back({ "speculative", "       --cpu-strict-draft <0|1>","Use strict CPU placement for draft model (default: same as --cpu-strict)"});
+    options.push_back({ "speculative", "       --priority-draft N",      "Set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: same as --priority)"});
+    options.push_back({ "speculative", "       --poll-draft <0|1>",      "Use polling to wait for draft model work (default: same as --poll])"});
+
+    options.push_back({ "speculative", "-Cbd,  --cpu-mask-batch-draft M","Draft model CPU affinity mask. Complements cpu-range-draft-batch (default: same as --cpu-mask-draft)"});
+    options.push_back({ "speculative", "-Crbd, --cpu-range-batch-draft lo-hi",
+                                                                         "Ranges of CPUs for affinity. Complements --cpu-mask-draft-batch"});
+    options.push_back({ "speculative", "       --cpu-strict-batch-draft <0|1>",
+                                                                         "Use strict CPU placement for draft model (default: --cpu-strict-draft)"});
+    options.push_back({ "speculative", "       --priority-batch-draft N","Set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: --priority-draft)"});
+    options.push_back({ "speculative", "       --poll-batch-draft <0|1>","Use polling to wait for draft model work (default: --poll-draft)"});
+#endif // GGML_USE_OPENMP
+
+    options.push_back({ "speculative", "       --draft N",              "number of tokens to draft for speculative decoding (default: %d)", params.n_draft });
+    options.push_back({ "speculative", "-ps,   --p-split N",            "speculative decoding split probability (default: %.1f)", (double)params.p_split });
+    options.push_back({ "*",           "-lcs,  --lookup-cache-static FNAME",
+                                                                        "path to static lookup cache to use for lookup decoding (not updated by generation)" });
+    options.push_back({ "*",           "-lcd,  --lookup-cache-dynamic FNAME",
+                                                                        "path to dynamic lookup cache to use for lookup decoding (updated by generation)" });
+
+    options.push_back({ "*",           "-c,    --ctx-size N",           "size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx });
+    options.push_back({ "*",           "-n,    --predict N",            "number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict });
+    options.push_back({ "*",           "-b,    --batch-size N",         "logical maximum batch size (default: %d)", params.n_batch });
+    options.push_back({ "*",           "-ub,   --ubatch-size N",        "physical maximum batch size (default: %d)", params.n_ubatch });
+    options.push_back({ "*",           "       --keep N",               "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep });
+    options.push_back({ "*",           "       --chunks N",             "max number of chunks to process (default: %d, -1 = all)", params.n_chunks });
+    options.push_back({ "*",           "-fa,   --flash-attn",           "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" });
+    options.push_back({ "*",           "-p,    --prompt PROMPT",        "prompt to start generation with\n"
+                                                                        "in conversation mode, this will be used as system prompt\n"
+                                                                        "(default: '%s')", params.prompt.c_str() });
+    options.push_back({ "*",           "-f,    --file FNAME",           "a file containing the prompt (default: none)" });
+    options.push_back({ "*",           "       --in-file FNAME",        "an input file (repeat to specify multiple files)" });
+    options.push_back({ "*",           "-bf,   --binary-file FNAME",    "binary file containing the prompt (default: none)" });
+    options.push_back({ "*",           "-e,    --escape",               "process escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\) (default: %s)", params.escape ? "true" : "false" });
+    options.push_back({ "*",           "       --no-escape",            "do not process escape sequences" });
+    options.push_back({ "main",        "-ptc,  --print-token-count N",  "print token count every N tokens (default: %d)", params.n_print });
+    options.push_back({ "main",        "       --prompt-cache FNAME",   "file to cache prompt state for faster startup (default: none)" });
+    options.push_back({ "main",        "       --prompt-cache-all",     "if specified, saves user input and generations to cache as well\n"
+                                                                        "not supported with --interactive or other interactive options" });
+    options.push_back({ "main",        "       --prompt-cache-ro",      "if specified, uses the prompt cache but does not update it" });
+    options.push_back({ "main",        "-r,    --reverse-prompt PROMPT",
+                                                                        "halt generation at PROMPT, return control in interactive mode\n"
+                                                                        "can be specified more than once for multiple prompts" });
+    options.push_back({ "main",        "-sp,   --special",              "special tokens output enabled (default: %s)", params.special ? "true" : "false" });
+    options.push_back({ "main",        "-cnv,  --conversation",         "run in conversation mode, does not print special tokens and suffix/prefix\n"
+                                                                        "if suffix/prefix are not specified, default chat template will be used\n"
+                                                                        "(default: %s)", params.conversation ? "true" : "false" });
+    options.push_back({ "main infill", "-i,    --interactive",          "run in interactive mode (default: %s)", params.interactive ? "true" : "false" });
+    options.push_back({ "main infill", "-if,   --interactive-first",    "run in interactive mode and wait for input right away (default: %s)", params.interactive_first ? "true" : "false" });
+    options.push_back({ "main infill", "-mli,  --multiline-input",      "allows you to write or paste multiple lines without ending each in '\\'" });
+    options.push_back({ "main infill", "       --in-prefix-bos",        "prefix BOS to user inputs, preceding the `--in-prefix` string" });
+    options.push_back({ "main infill", "       --in-prefix STRING",     "string to prefix user inputs with (default: empty)" });
+    options.push_back({ "main infill", "       --in-suffix STRING",     "string to suffix after user inputs with (default: empty)" });
+    options.push_back({ "main",        "       --no-warmup",            "skip warming up the model with an empty run" });
+    options.push_back({ "server infill",
+                                       "       --spm-infill",           "use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this (default: %s)", params.spm_infill ? "enabled" : "disabled" });
+
+    options.push_back({ "sampling" });
+    options.push_back({ "*",           "       --samplers SAMPLERS",    "samplers that will be used for generation in the order, separated by \';\'\n"
+                                                                        "(default: %s)", sampler_type_names.c_str() });
+    options.push_back({ "*",           "       --sampling-seq SEQUENCE",
+                                                                        "simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str() });
+    options.push_back({ "*",           "       --ignore-eos",           "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)" });
+    options.push_back({ "*",           "       --penalize-nl",          "penalize newline tokens (default: %s)", sparams.penalize_nl ? "true" : "false" });
+    options.push_back({ "*",           "       --temp N",               "temperature (default: %.1f)", (double)sparams.temp });
+    options.push_back({ "*",           "       --top-k N",              "top-k sampling (default: %d, 0 = disabled)", sparams.top_k });
+    options.push_back({ "*",           "       --top-p N",              "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p });
+    options.push_back({ "*",           "       --min-p N",              "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p });
+    options.push_back({ "*",           "       --tfs N",                "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z });
+    options.push_back({ "*",           "       --typical N",            "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typical_p });
+    options.push_back({ "*",           "       --repeat-last-n N",      "last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", sparams.penalty_last_n });
+    options.push_back({ "*",           "       --repeat-penalty N",     "penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)sparams.penalty_repeat });
+    options.push_back({ "*",           "       --presence-penalty N",   "repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_present });
+    options.push_back({ "*",           "       --frequency-penalty N",  "repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_freq });
+    options.push_back({ "*",           "       --dynatemp-range N",     "dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)sparams.dynatemp_range });
+    options.push_back({ "*",           "       --dynatemp-exp N",       "dynamic temperature exponent (default: %.1f)", (double)sparams.dynatemp_exponent });
+    options.push_back({ "*",           "       --mirostat N",           "use Mirostat sampling.\n"
+                                                                        "Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"
+                                                                        "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", sparams.mirostat });
+    options.push_back({ "*",           "       --mirostat-lr N",        "Mirostat learning rate, parameter eta (default: %.1f)", (double)sparams.mirostat_eta });
+    options.push_back({ "*",           "       --mirostat-ent N",       "Mirostat target entropy, parameter tau (default: %.1f)", (double)sparams.mirostat_tau });
+    options.push_back({ "*",           "       -l TOKEN_ID(+/-)BIAS",   "modifies the likelihood of token appearing in the completion,\n"
+                                                                        "i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"
+                                                                        "or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" });
+    options.push_back({ "main",        "       --cfg-negative-prompt PROMPT",
+                                                                        "negative prompt to use for guidance (default: '%s')", sparams.cfg_negative_prompt.c_str() });
+    options.push_back({ "main",        "       --cfg-negative-prompt-file FNAME",
+                                                                        "negative prompt file to use for guidance" });
+    options.push_back({ "main",        "       --cfg-scale N",          "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale });
+    options.push_back({ "main",        "       --chat-template JINJA_TEMPLATE",
+                                                                        "set custom jinja chat template (default: template taken from model's metadata)\n"
+                                                                        "if suffix/prefix are specified, template will be disabled\n"
+                                                                        "only commonly used templates are accepted:\n"
+                                                                        "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
+    options.push_back({ "grammar" });
+    options.push_back({ "*",           "       --grammar GRAMMAR",      "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() });
+    options.push_back({ "*",           "       --grammar-file FNAME",   "file to read grammar from" });
+    options.push_back({ "*",           "-j,    --json-schema SCHEMA",
+                                                                        "JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\n"
+                                                                        "For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead" });
+
+    options.push_back({ "embedding" });
+    options.push_back({ "embedding",   "       --pooling {none,mean,cls,last}",
+                                                                        "pooling type for embeddings, use model default if unspecified" });
+    options.push_back({ "embedding",   "       --attention {causal,non-causal}",
+                                                                        "attention type for embeddings, use model default if unspecified" });
+
+    options.push_back({ "context hacking" });
+    options.push_back({ "*",           "       --rope-scaling {none,linear,yarn}",
+                                                                        "RoPE frequency scaling method, defaults to linear unless specified by the model" });
+    options.push_back({ "*",           "       --rope-scale N",         "RoPE context scaling factor, expands context by a factor of N" });
+    options.push_back({ "*",           "       --rope-freq-base N",     "RoPE base frequency, used by NTK-aware scaling (default: loaded from model)" });
+    options.push_back({ "*",           "       --rope-freq-scale N",    "RoPE frequency scaling factor, expands context by a factor of 1/N" });
+    options.push_back({ "*",           "       --yarn-orig-ctx N",      "YaRN: original context size of model (default: %d = model training context size)", params.yarn_orig_ctx });
+    options.push_back({ "*",           "       --yarn-ext-factor N",    "YaRN: extrapolation mix factor (default: %.1f, 0.0 = full interpolation)", (double)params.yarn_ext_factor });
+    options.push_back({ "*",           "       --yarn-attn-factor N",   "YaRN: scale sqrt(t) or attention magnitude (default: %.1f)", (double)params.yarn_attn_factor });
+    options.push_back({ "*",           "       --yarn-beta-slow N",     "YaRN: high correction dim or alpha (default: %.1f)", (double)params.yarn_beta_slow });
+    options.push_back({ "*",           "       --yarn-beta-fast N",     "YaRN: low correction dim or beta (default: %.1f)", (double)params.yarn_beta_fast });
+    options.push_back({ "*",           "-gan,  --grp-attn-n N",         "group-attention factor (default: %d)", params.grp_attn_n });
+    options.push_back({ "*",           "-gaw,  --grp-attn-w N",         "group-attention width (default: %.1f)", (double)params.grp_attn_w });
+    options.push_back({ "*",           "-dkvc, --dump-kv-cache",        "verbose print of the KV cache" });
+    options.push_back({ "*",           "-nkvo, --no-kv-offload",        "disable KV offload" });
+    options.push_back({ "*",           "-ctk,  --cache-type-k TYPE",    "KV cache data type for K (default: %s)", params.cache_type_k.c_str() });
+    options.push_back({ "*",           "-ctv,  --cache-type-v TYPE",    "KV cache data type for V (default: %s)", params.cache_type_v.c_str() });
+
+    options.push_back({ "perplexity" });
+    options.push_back({ "perplexity",  "       --all-logits",           "return logits for all tokens in the batch (default: %s)", params.logits_all ? "true" : "false" });
+    options.push_back({ "perplexity",  "       --hellaswag",            "compute HellaSwag score over random tasks from datafile supplied with -f" });
+    options.push_back({ "perplexity",  "       --hellaswag-tasks N",    "number of tasks to use when computing the HellaSwag score (default: %zu)", params.hellaswag_tasks });
+    options.push_back({ "perplexity",  "       --winogrande",           "compute Winogrande score over random tasks from datafile supplied with -f" });
+    options.push_back({ "perplexity",  "       --winogrande-tasks N",   "number of tasks to use when computing the Winogrande score (default: %zu)", params.winogrande_tasks });
+    options.push_back({ "perplexity",  "       --multiple-choice",      "compute multiple choice score over random tasks from datafile supplied with -f" });
+    options.push_back({ "perplexity",  "       --multiple-choice-tasks N",
+                                                                        "number of tasks to use when computing the multiple choice score (default: %zu)", params.multiple_choice_tasks });
+    options.push_back({ "perplexity",  "       --kl-divergence",        "computes KL-divergence to logits provided via --kl-divergence-base" });
+    options.push_back({ "perplexity",  "       --ppl-stride N",         "stride for perplexity calculation (default: %d)", params.ppl_stride });
+    options.push_back({ "perplexity",  "       --ppl-output-type {0,1}",
+                                                                        "output type for perplexity calculation (default: %d)", params.ppl_output_type });
+
+    options.push_back({ "parallel" });
+    options.push_back({ "*",           "-dt,   --defrag-thold N",       "KV cache defragmentation threshold (default: %.1f, < 0 - disabled)", (double)params.defrag_thold });
+    options.push_back({ "*",           "-np,   --parallel N",           "number of parallel sequences to decode (default: %d)", params.n_parallel });
+    options.push_back({ "*",           "-ns,   --sequences N",          "number of sequences to decode (default: %d)", params.n_sequences });
+    options.push_back({ "*",           "-cb,   --cont-batching",        "enable continuous batching (a.k.a dynamic batching) (default: %s)", params.cont_batching ? "enabled" : "disabled" });
+    options.push_back({ "*",           "-nocb, --no-cont-batching",     "disable continuous batching" });
+
+    options.push_back({ "multi-modality" });
+    options.push_back({ "*",           "       --mmproj FILE",          "path to a multimodal projector file for LLaVA. see examples/llava/README.md" });
+    options.push_back({ "*",           "       --image FILE",           "path to an image file. use with multimodal models. Specify multiple times for batching" });
+
+    options.push_back({ "backend" });
+    options.push_back({ "*",           "       --rpc SERVERS",          "comma separated list of RPC servers" });
+
+    if (llama_supports_mlock()) {
+        options.push_back({ "*",           "       --mlock",                "force system to keep model in RAM rather than swapping or compressing" });
+    }
+    if (llama_supports_mmap()) {
+        options.push_back({ "*",           "       --no-mmap",              "do not memory-map model (slower load but may reduce pageouts if not using mlock)" });
+    }
+    options.push_back({ "*",           "       --numa TYPE",            "attempt optimizations that help on some NUMA systems\n"
+                                                                        "  - distribute: spread execution evenly over all nodes\n"
+                                                                        "  - isolate: only spawn threads on CPUs on the node that execution started on\n"
+                                                                        "  - numactl: use the CPU map provided by numactl\n"
+                                                                        "if this has not been used before, it is recommended to drop the system page cache before using it\n"
+                                                                        "see https://github.com/ggerganov/llama.cpp/issues/1437" });
+
+    if (llama_supports_gpu_offload()) {
+        options.push_back({ "*",           "-ngl,  --gpu-layers N",
+                                                                        "number of layers to store in VRAM" });
+        options.push_back({ "*",           "-ngld, --gpu-layers-draft N",
+                                                                        "number of layers to store in VRAM for the draft model" });
+        options.push_back({ "*",           "-sm,   --split-mode SPLIT_MODE",
+                                                                        "how to split the model across multiple GPUs, one of:\n"
+                                                                        "  - none: use one GPU only\n"
+                                                                        "  - layer (default): split layers and KV across GPUs\n"
+                                                                        "  - row: split rows across GPUs" });
+        options.push_back({ "*",           "-ts,   --tensor-split SPLIT",
+                                                                        "fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1" });
+        options.push_back({ "*",           "-mg,   --main-gpu i",       "the GPU to use for the model (with split-mode = none),\n"
+                                                                        "or for intermediate results and KV (with split-mode = row) (default: %d)", params.main_gpu });
+    }
+
+    options.push_back({ "model" });
+    options.push_back({ "*",           "       --check-tensors",        "check model tensor data for invalid values (default: %s)", params.check_tensors ? "true" : "false" });
+    options.push_back({ "*",           "       --override-kv KEY=TYPE:VALUE",
+                                                                        "advanced option to override model metadata by key. may be specified multiple times.\n"
+                                                                        "types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false" });
+    options.push_back({ "*",           "       --lora FNAME",           "apply LoRA adapter (can be repeated to use multiple adapters)" });
+    options.push_back({ "*",           "       --lora-scaled FNAME S",  "apply LoRA adapter with user defined scaling S (can be repeated to use multiple adapters)" });
+    options.push_back({ "*",           "       --control-vector FNAME", "add a control vector\n"
+                                                                        "note: this argument can be repeated to add multiple control vectors" });
+    options.push_back({ "*",           "       --control-vector-scaled FNAME SCALE",
+                                                                        "add a control vector with user defined scaling SCALE\n"
+                                                                        "note: this argument can be repeated to add multiple scaled control vectors" });
+    options.push_back({ "*",           "       --control-vector-layer-range START END",
+                                                                        "layer range to apply the control vector(s) to, start and end inclusive" });
+    options.push_back({ "*",           "-m,    --model FNAME",          "model path (default: models/$filename with filename from --hf-file\n"
+                                                                        "or --model-url if set, otherwise %s)", DEFAULT_MODEL_PATH });
+    options.push_back({ "*",           "-md,   --model-draft FNAME",    "draft model for speculative decoding (default: unused)" });
+    options.push_back({ "*",           "-mu,   --model-url MODEL_URL",  "model download url (default: unused)" });
+    options.push_back({ "*",           "-hfr,  --hf-repo REPO",         "Hugging Face model repository (default: unused)" });
+    options.push_back({ "*",           "-hff,  --hf-file FILE",         "Hugging Face model file (default: unused)" });
+    options.push_back({ "*",           "-hft,  --hf-token TOKEN",       "Hugging Face access token (default: value from HF_TOKEN environment variable)" });
+
+    options.push_back({ "retrieval" });
+    options.push_back({ "retrieval",   "       --context-file FNAME",   "file to load context from (repeat to specify multiple files)" });
+    options.push_back({ "retrieval",   "       --chunk-size N",         "minimum length of embedded text chunks (default: %d)", params.chunk_size });
+    options.push_back({ "retrieval",   "       --chunk-separator STRING",
+                                                                        "separator between chunks (default: '%s')", params.chunk_separator.c_str() });
+
+    options.push_back({ "passkey" });
+    options.push_back({ "passkey",     "       --junk N",               "number of times to repeat the junk text (default: %d)", params.n_junk });
+    options.push_back({ "passkey",     "       --pos N",                "position of the passkey in the junk text (default: %d)", params.i_pos });
+
+    options.push_back({ "imatrix" });
+    options.push_back({ "imatrix",     "-o,    --output FNAME",         "output file (default: '%s')", params.out_file.c_str() });
+    options.push_back({ "imatrix",     "       --output-frequency N",   "output the imatrix every N iterations (default: %d)", params.n_out_freq });
+    options.push_back({ "imatrix",     "       --save-frequency N",     "save an imatrix copy every N iterations (default: %d)", params.n_save_freq });
+    options.push_back({ "imatrix",     "       --process-output",       "collect data for the output tensor (default: %s)", params.process_output ? "true" : "false" });
+    options.push_back({ "imatrix",     "       --no-ppl",               "do not compute perplexity (default: %s)", params.compute_ppl ? "true" : "false" });
+    options.push_back({ "imatrix",     "       --chunk N",              "start processing the input from chunk N (default: %d)", params.i_chunk });
+
+    options.push_back({ "bench" });
+    options.push_back({ "bench",       "-pps",                          "is the prompt shared across parallel sequences (default: %s)", params.is_pp_shared ? "true" : "false" });
+    options.push_back({ "bench",       "-npp n0,n1,...",                "number of prompt tokens" });
+    options.push_back({ "bench",       "-ntg n0,n1,...",                "number of text generation tokens" });
+    options.push_back({ "bench",       "-npl n0,n1,...",                "number of parallel prompts" });
+
+    options.push_back({ "embedding" });
+    options.push_back({ "embedding",   "       --embd-normalize",       "normalisation for embendings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", params.embd_normalize });
+    options.push_back({ "embedding",   "       --embd-output-format",   "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix" });
+    options.push_back({ "embedding",   "       --embd-separator",       "separator of embendings (default \\n) for example \"<#sep#>\"" });
+
+    options.push_back({ "server" });
+    options.push_back({ "server",      "       --host HOST",            "ip address to listen (default: %s)", params.hostname.c_str() });
+    options.push_back({ "server",      "       --port PORT",            "port to listen (default: %d)", params.port });
+    options.push_back({ "server",      "       --path PATH",            "path to serve static files from (default: %s)", params.public_path.c_str() });
+    options.push_back({ "server",      "       --embedding(s)",         "restrict to only support embedding use case; use only with dedicated embedding models (default: %s)", params.embedding ? "enabled" : "disabled" });
+    options.push_back({ "server",      "       --api-key KEY",          "API key to use for authentication (default: none)" });
+    options.push_back({ "server",      "       --api-key-file FNAME",   "path to file containing API keys (default: none)" });
+    options.push_back({ "server",      "       --ssl-key-file FNAME",   "path to file a PEM-encoded SSL private key" });
+    options.push_back({ "server",      "       --ssl-cert-file FNAME",  "path to file a PEM-encoded SSL certificate" });
+    options.push_back({ "server",      "       --timeout N",            "server read/write timeout in seconds (default: %d)", params.timeout_read });
+    options.push_back({ "server",      "       --threads-http N",       "number of threads used to process HTTP requests (default: %d)", params.n_threads_http });
+    options.push_back({ "server",      "       --system-prompt-file FNAME",
+                                                                        "set a file to load a system prompt from (the initial prompt of all slots); useful for chat applications" });
+    options.push_back({ "server",      "       --log-format {text,json}",
+                                                                        "log output format: json or text (default: json)" });
+    options.push_back({ "server",      "       --metrics",              "enable prometheus compatible metrics endpoint (default: %s)", params.endpoint_metrics ? "enabled" : "disabled" });
+    options.push_back({ "server",      "       --no-slots",             "disables slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled" });
+    options.push_back({ "server",      "       --slot-save-path PATH",  "path to save slot kv cache (default: disabled)" });
+    options.push_back({ "server",      "       --chat-template JINJA_TEMPLATE",
+                                                                        "set custom jinja chat template (default: template taken from model's metadata)\n"
+                                                                        "only commonly used templates are accepted:\n"
+                                                                        "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
+    options.push_back({ "server",      "-sps,  --slot-prompt-similarity SIMILARITY",
+                                                                        "how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)", params.slot_prompt_similarity });
+    options.push_back({ "server",      "       --lora-init-without-apply",     "load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"});
+
+#ifndef LOG_DISABLE_LOGS
+    options.push_back({ "logging" });
+    options.push_back({ "*",           "       --simple-io",            "use basic IO for better compatibility in subprocesses and limited consoles" });
+    options.push_back({ "*",           "-ld,   --logdir LOGDIR",        "path under which to save YAML logs (no logging if unset)" });
+    options.push_back({ "logging",     "       --log-test",             "Run simple logging test" });
+    options.push_back({ "logging",     "       --log-disable",          "Disable trace logs" });
+    options.push_back({ "logging",     "       --log-enable",           "Enable trace logs" });
+    options.push_back({ "logging",     "       --log-file FNAME",       "Specify a log filename (without extension)" });
+    options.push_back({ "logging",     "       --log-new",              "Create a separate new log file on start. "
+                                                                        "Each log file will have a unique name: \"<name>.<ID>.log\"" });
+    options.push_back({ "logging",     "       --log-append",           "Don't truncate the old log file." });
+#endif // LOG_DISABLE_LOGS
+
+    options.push_back({ "cvector" });
+    options.push_back({ "cvector",     "-o,    --output FNAME",         "output file (default: '%s')", params.cvector_outfile.c_str() });
+    options.push_back({ "cvector",     "       --positive-file FNAME",  "positive prompts file, one prompt per line (default: '%s')", params.cvector_positive_file.c_str() });
+    options.push_back({ "cvector",     "       --negative-file FNAME",  "negative prompts file, one prompt per line (default: '%s')", params.cvector_negative_file.c_str() });
+    options.push_back({ "cvector",     "       --pca-batch N",          "batch size used for PCA. Larger batch runs faster, but uses more memory (default: %d)", params.n_pca_batch });
+    options.push_back({ "cvector",     "       --pca-iter N",           "number of iterations used for PCA (default: %d)", params.n_pca_iterations });
+    options.push_back({ "cvector",     "       --method {pca,mean}",    "dimensionality reduction method to be used (default: pca)" });
+
+    options.push_back({ "export-lora" });
+    options.push_back({ "export-lora", "-m,    --model",                "model path from which to load base model (default '%s')", params.model.c_str() });
+    options.push_back({ "export-lora", "       --lora FNAME",           "path to LoRA adapter  (can be repeated to use multiple adapters)" });
+    options.push_back({ "export-lora", "       --lora-scaled FNAME S",  "path to LoRA adapter with user defined scaling S  (can be repeated to use multiple adapters)" });
+    options.push_back({ "export-lora", "-o,    --output FNAME",         "output file (default: '%s')", params.lora_outfile.c_str() });
+
+    printf("usage: %s [options]\n", argv[0]);
+
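+    // two-column layout: flags are padded to 32 characters and multi-line
+    // descriptions continue with a 34-space indent on the following lines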
+    for (const auto & o : options) {
+        if (!o.grp.empty()) {
+            printf("\n%s:\n\n", o.grp.c_str());
+            continue;
+        }
+        printf("  %-32s", o.args.c_str());
+        if (o.args.length() > 30) {
+            printf("\n%34s", "");
+        }
+
+        const auto desc = o.desc;
+        size_t start = 0;
+        size_t end = desc.find('\n');
+        while (end != std::string::npos) {
+            printf("%s\n%34s", desc.substr(start, end - start).c_str(), "");
+            start = end + 1;
+            end = desc.find('\n', start);
+        }
+
+        printf("%s\n", desc.substr(start).c_str());
+    }
+    printf("\n");
+}
+
+std::string gpt_params_get_system_info(const gpt_params & params) {
+    std::ostringstream os;
+
+    os << "system_info: n_threads = " << params.cpuparams.n_threads;
+    if (params.cpuparams_batch.n_threads != -1) {
+        os << " (n_threads_batch = " << params.cpuparams_batch.n_threads << ")";
+    }
+#if defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later
+    // TODO: windows + arm64 + mingw64
+    DWORD logicalProcessorCount = GetActiveProcessorCount(ALL_PROCESSOR_GROUPS);
+    os << " / " << logicalProcessorCount << " | " << llama_print_system_info();
+#else
+    os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info();
+#endif
+
+    return os.str();
+}
+
+//
+// String utils
+//
+
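+// split input on every occurrence of separator; empty fields are preserved and the
+// trailing segment is always appended, so the result always has at least one element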
+std::vector<std::string> string_split(std::string input, char separator) {
+    std::vector<std::string> parts;
+    size_t separator_pos = input.find(separator);
+    while (separator_pos != std::string::npos) {
+        std::string part = input.substr(0, separator_pos);
+        parts.emplace_back(part);
+        input = input.substr(separator_pos + 1);
+        separator_pos = input.find(separator);
+    }
+    parts.emplace_back(input);
+    return parts;
+}
+
+std::string string_strip(const std::string & str) {
+    size_t start = 0;
+    size_t end = str.size();
+    while (start < end && std::isspace(str[start])) {
+        start++;
+    }
+    while (end > start && std::isspace(str[end - 1])) {
+        end--;
+    }
+    return str.substr(start, end - start);
+}
+
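+// returns a timestamp of the form YYYY_MM_DD-HH_MM_SS.nnnnnnnnn (with nanoseconds),
+// zero-padded so that lexicographic order matches chronological order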
+std::string string_get_sortable_timestamp() {
+    using clock = std::chrono::system_clock;
+
+    const clock::time_point current_time = clock::now();
+    const time_t as_time_t = clock::to_time_t(current_time);
+    char timestamp_no_ns[100];
+    std::strftime(timestamp_no_ns, 100, "%Y_%m_%d-%H_%M_%S", std::localtime(&as_time_t));
+
+    const int64_t ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
+        current_time.time_since_epoch() % 1000000000).count();
+    char timestamp_ns[11];
+    snprintf(timestamp_ns, 11, "%09" PRId64, ns);
+
+    return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns);
+}
+
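+// replace every non-overlapping occurrence of search in s with replace, building the
+// result in a scratch buffer so that replace may itself contain search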
+void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
+    if (search.empty()) {
+        return;
+    }
+    std::string builder;
+    builder.reserve(s.length());
+    size_t pos = 0;
+    size_t last_pos = 0;
+    while ((pos = s.find(search, last_pos)) != std::string::npos) {
+        builder.append(s, last_pos, pos - last_pos);
+        builder.append(replace);
+        last_pos = pos + search.length();
+    }
+    builder.append(s, last_pos, std::string::npos);
+    s = std::move(builder);
+}
+
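+// rewrite input in place, collapsing escape sequences (\n, \r, \t, \', \", \\ and \xHH);
+// the write index never overtakes the read index, so reusing the same buffer is safe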
+void string_process_escapes(std::string & input) {
+    std::size_t input_len = input.length();
+    std::size_t output_idx = 0;
+
+    for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) {
+        if (input[input_idx] == '\\' && input_idx + 1 < input_len) {
+            switch (input[++input_idx]) {
+                case 'n':  input[output_idx++] = '\n'; break;
+                case 'r':  input[output_idx++] = '\r'; break;
+                case 't':  input[output_idx++] = '\t'; break;
+                case '\'': input[output_idx++] = '\''; break;
+                case '\"': input[output_idx++] = '\"'; break;
+                case '\\': input[output_idx++] = '\\'; break;
+                case 'x':
+                    // Handle \x12, etc
+                    if (input_idx + 2 < input_len) {
+                        const char x[3] = { input[input_idx + 1], input[input_idx + 2], 0 };
+                        char *err_p = nullptr;
+                        const long val = std::strtol(x, &err_p, 16);
+                        if (err_p == x + 2) {
+                            input_idx += 2;
+                            input[output_idx++] = char(val);
+                            break;
+                        }
+                    }
+                    // fall through
+                default:   input[output_idx++] = '\\';
+                           input[output_idx++] = input[input_idx]; break;
+            }
+        } else {
+            input[output_idx++] = input[input_idx];
+        }
+    }
+
+    input.resize(output_idx);
+}
+
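+// parse a KEY=TYPE:VALUE override where TYPE is one of int, float, bool or str,
+// e.g. string_parse_kv_override("tokenizer.ggml.add_bos_token=bool:false", overrides)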
+bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) {
+    const char * sep = strchr(data, '=');
+    if (sep == nullptr || sep - data >= 128) {
+        fprintf(stderr, "%s: malformed KV override '%s'\n", __func__, data);
+        return false;
+    }
+    llama_model_kv_override kvo;
+    std::strncpy(kvo.key, data, sep - data);
+    kvo.key[sep - data] = 0;
+    sep++;
+    if (strncmp(sep, "int:", 4) == 0) {
+        sep += 4;
+        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
+        kvo.val_i64 = std::atol(sep);
+    } else if (strncmp(sep, "float:", 6) == 0) {
+        sep += 6;
+        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
+        kvo.val_f64 = std::atof(sep);
+    } else if (strncmp(sep, "bool:", 5) == 0) {
+        sep += 5;
+        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
+        if (std::strcmp(sep, "true") == 0) {
+            kvo.val_bool = true;
+        } else if (std::strcmp(sep, "false") == 0) {
+            kvo.val_bool = false;
+        } else {
+            fprintf(stderr, "%s: invalid boolean value for KV override '%s'\n", __func__, data);
+            return false;
+        }
+    } else if (strncmp(sep, "str:", 4) == 0) {
+        sep += 4;
+        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
+        if (strlen(sep) > 127) {
+            fprintf(stderr, "%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data);
+            return false;
+        }
+        strncpy(kvo.val_str, sep, 127);
+        kvo.val_str[127] = '\0';
+    } else {
+        fprintf(stderr, "%s: invalid type for KV override '%s'\n", __func__, data);
+        return false;
+    }
+    overrides.emplace_back(std::move(kvo));
+    return true;
+}
+
+//
+// Filesystem utils
+//
+
+// Validate if a filename is safe to use
+// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function
+bool fs_validate_filename(const std::string & filename) {
+    if (!filename.length()) {
+        // Empty filename invalid
+        return false;
+    }
+    if (filename.length() > 255) {
+        // Limit at common largest possible filename on Linux filesystems
+        // to avoid unnecessary further validation
+        // (On systems with smaller limits it will be caught by the OS)
+        return false;
+    }
+
+    std::u32string filename_utf32;
+    try {
+        std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
+        filename_utf32 = converter.from_bytes(filename);
+
+        // If the reverse conversion mismatches, it means overlong UTF-8 sequences were used,
+        // or invalid encodings were encountered. Reject such attempts
+        std::string filename_reencoded = converter.to_bytes(filename_utf32);
+        if (filename_reencoded != filename) {
+            return false;
+        }
+    } catch (const std::exception &) {
+        return false;
+    }
+
+    // Check for forbidden codepoints:
+    // - Control characters
+    // - Unicode equivalents of illegal characters
+    // - UTF-16 surrogate pairs
+    // - UTF-8 replacement character
+    // - Byte order mark (BOM)
+    // - Illegal characters: / \ : * ? " < > |
+    for (char32_t c : filename_utf32) {
+        if (c <= 0x1F // Control characters (C0)
+            || c == 0x7F // Control characters (DEL)
+            || (c >= 0x80 && c <= 0x9F) // Control characters (C1)
+            || c == 0xFF0E // Fullwidth Full Stop (period equivalent)
+            || c == 0x2215 // Division Slash (forward slash equivalent)
+            || c == 0x2216 // Set Minus (backslash equivalent)
+            || (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
+            || c == 0xFFFD // Replacement Character (UTF-8)
+            || c == 0xFEFF // Byte Order Mark (BOM)
+            || c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters
+            || c == '?' || c == '"' || c == '<' || c == '>' || c == '|') {
+            return false;
+        }
+    }
+
+    // Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
+    // Unicode and other whitespace is not affected, only 0x20 space
+    if (filename.front() == ' ' || filename.back() == ' ' || filename.back() == '.') {
+        return false;
+    }
+
+    // Reject any ".." (currently stricter than necessary, it should be fine to just check for == ".." instead)
+    if (filename.find("..") != std::string::npos) {
+        return false;
+    }
+
+    // Reject "."
+    if (filename == ".") {
+        return false;
+    }
+
+    return true;
+}
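+
+// For illustration (assumed behavior implied by the checks above):
+// fs_validate_filename("model.gguf") returns true, while "", "..", "a:b",
+// "model.gguf " (trailing space) and anything containing '/' or '\\' are all
+// rejected.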
+
+// returns true if successful, false otherwise
+bool fs_create_directory_with_parents(const std::string & path) {
+#ifdef _WIN32
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+    std::wstring wpath = converter.from_bytes(path);
+
+    // if the path already exists, check whether it's a directory
+    const DWORD attributes = GetFileAttributesW(wpath.c_str());
+    if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) {
+        return true;
+    }
+
+    size_t pos_slash = 0;
+
+    // process path from front to back, procedurally creating directories
+    while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) {
+        const std::wstring subpath = wpath.substr(0, pos_slash);
+        const wchar_t * test = subpath.c_str();
+
+        const bool success = CreateDirectoryW(test, NULL);
+        if (!success) {
+            const DWORD error = GetLastError();
+
+            // if the path already exists, ensure that it's a directory
+            if (error == ERROR_ALREADY_EXISTS) {
+                const DWORD attributes = GetFileAttributesW(subpath.c_str());
+                if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) {
+                    return false;
+                }
+            } else {
+                return false;
+            }
+        }
+
+        pos_slash += 1;
+    }
+
+    return true;
+#else
+    // if the path already exists, check whether it's a directory
+    struct stat info;
+    if (stat(path.c_str(), &info) == 0) {
+        return S_ISDIR(info.st_mode);
+    }
+
+    size_t pos_slash = 1; // skip leading slashes for directory creation
+
+    // process path from front to back, procedurally creating directories
+    while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) {
+        const std::string subpath = path.substr(0, pos_slash);
+        struct stat info;
+
+        // if the path already exists, ensure that it's a directory
+        if (stat(subpath.c_str(), &info) == 0) {
+            if (!S_ISDIR(info.st_mode)) {
+                return false;
+            }
+        } else {
+            // create parent directories
+            const int ret = mkdir(subpath.c_str(), 0755);
+            if (ret != 0) {
+                return false;
+            }
+        }
+
+        pos_slash += 1;
+    }
+
+    return true;
+#endif // _WIN32
+}
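+
+// Editor's note: both branches above only create components up to each path
+// separator, so callers are expected to pass a trailing separator (as
+// fs_get_cache_file() below does via fs_get_cache_directory()); "a/b/c" would
+// create "a" and "a/b" but not "a/b/c" itself.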
+
+std::string fs_get_cache_directory() {
+    std::string cache_directory = "";
+    auto ensure_trailing_slash = [](std::string p) {
+        // Make sure to add trailing slash
+        if (p.back() != DIRECTORY_SEPARATOR) {
+            p += DIRECTORY_SEPARATOR;
+        }
+        return p;
+    };
+    if (getenv("LLAMA_CACHE")) {
+        cache_directory = std::getenv("LLAMA_CACHE");
+    } else {
+#ifdef __linux__
+        if (std::getenv("XDG_CACHE_HOME")) {
+            cache_directory = std::getenv("XDG_CACHE_HOME");
+        } else {
+            cache_directory = std::getenv("HOME") + std::string("/.cache/");
+        }
+#elif defined(__APPLE__)
+        cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
+#elif defined(_WIN32)
+        cache_directory = std::getenv("LOCALAPPDATA");
+#endif // __linux__
+        cache_directory = ensure_trailing_slash(cache_directory);
+        cache_directory += "llama.cpp";
+    }
+    return ensure_trailing_slash(cache_directory);
+}
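+
+// Sketch of expected results (assuming LLAMA_CACHE is unset and the relevant
+// platform variables are set): Linux -> "$XDG_CACHE_HOME/llama.cpp/" or
+// "$HOME/.cache/llama.cpp/", macOS -> "$HOME/Library/Caches/llama.cpp/",
+// Windows -> "%LOCALAPPDATA%\llama.cpp\".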
+
+std::string fs_get_cache_file(const std::string & filename) {
+    GGML_ASSERT(filename.find(DIRECTORY_SEPARATOR) == std::string::npos);
+    std::string cache_directory = fs_get_cache_directory();
+    const bool success = fs_create_directory_with_parents(cache_directory);
+    if (!success) {
+        throw std::runtime_error("failed to create cache directory: " + cache_directory);
+    }
+    return cache_directory + filename;
+}
+
+
+//
+// Model utils
+//
+struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
+    llama_init_result iparams;
+    auto mparams = llama_model_params_from_gpt_params(params);
+
+    llama_model * model = nullptr;
+
+    if (!params.hf_repo.empty() && !params.hf_file.empty()) {
+        model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
+    } else if (!params.model_url.empty()) {
+        model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
+    } else {
+        model = llama_load_model_from_file(params.model.c_str(), mparams);
+    }
+
+    if (model == NULL) {
+        fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
+        return iparams;
+    }
+
+    auto cparams = llama_context_params_from_gpt_params(params);
+
+    llama_context * lctx = llama_new_context_with_model(model, cparams);
+    if (lctx == NULL) {
+        fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
+        llama_free_model(model);
+        return iparams;
+    }
+
+    if (!params.control_vectors.empty()) {
+        if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
+        if (params.control_vector_layer_end   <= 0) params.control_vector_layer_end   = llama_n_layer(model);
+
+        const auto cvec = llama_control_vector_load(params.control_vectors);
+        if (cvec.n_embd == -1) {
+            llama_free(lctx);
+            llama_free_model(model);
+            return iparams;
+        }
+
+        int err = llama_control_vector_apply(lctx,
+                                             cvec.data.data(),
+                                             cvec.data.size(),
+                                             cvec.n_embd,
+                                             params.control_vector_layer_start,
+                                             params.control_vector_layer_end);
+        if (err) {
+            llama_free(lctx);
+            llama_free_model(model);
+            return iparams;
+        }
+    }
+
+    // load and optionally apply lora adapters
+    for (auto & la : params.lora_adapters) {
+        llama_lora_adapter_container loaded_la;
+        loaded_la.path = la.path;
+        loaded_la.scale = la.scale;
+        loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str());
+        if (loaded_la.adapter == nullptr) {
+            fprintf(stderr, "%s: error: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
+            llama_free(lctx);
+            llama_free_model(model);
+            return iparams;
+        }
+        iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters
+    }
+    if (!params.lora_init_without_apply) {
+        llama_lora_adapters_apply(lctx, iparams.lora_adapters);
+    }
+
+    if (params.ignore_eos) {
+        params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+    }
+
+    if (params.warmup) {
+        LOG("warming up the model with an empty run\n");
+
+        std::vector<llama_token> tmp;
+        llama_token bos = llama_token_bos(model);
+        llama_token eos = llama_token_eos(model);
+        // some models (e.g. T5) don't have a BOS token
+        if (bos != -1) {
+            tmp.push_back(bos);
+        }
+        tmp.push_back(eos);
+
+        if (llama_model_has_encoder(model)) {
+            llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
+            llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
+            if (decoder_start_token_id == -1) {
+                decoder_start_token_id = bos;
+            }
+            tmp.clear();
+            tmp.push_back(decoder_start_token_id);
+        }
+        if (llama_model_has_decoder(model)) {
+            llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
+        }
+        llama_kv_cache_clear(lctx);
+        llama_synchronize(lctx);
+        llama_reset_timings(lctx);
+    }
+
+    iparams.model   = model;
+    iparams.context = lctx;
+    return iparams;
+}
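+
+// Minimal usage sketch (hypothetical caller, not vendored code):
+//
+//     gpt_params params;
+//     params.model = "models/7B/ggml-model-f16.gguf";
+//     llama_init_result init = llama_init_from_gpt_params(params);
+//     if (init.model == nullptr || init.context == nullptr) {
+//         // loading failed; see stderr for the reason
+//     }
+//     // ... run inference ...
+//     llama_free(init.context);
+//     llama_free_model(init.model);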
+
+void llama_lora_adapters_apply(struct llama_context * ctx, std::vector<llama_lora_adapter_container> & lora_adapters) {
+    llama_lora_adapter_clear(ctx);
+    for (auto & la : lora_adapters) {
+        if (la.scale != 0.0f) {
+            llama_lora_adapter_set(ctx, la.adapter, la.scale);
+        }
+    }
+}
+
+struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) {
+    auto mparams = llama_model_default_params();
+
+    if (params.n_gpu_layers != -1) {
+        mparams.n_gpu_layers = params.n_gpu_layers;
+    }
+    mparams.rpc_servers     = params.rpc_servers.c_str();
+    mparams.main_gpu        = params.main_gpu;
+    mparams.split_mode      = params.split_mode;
+    mparams.tensor_split    = params.tensor_split;
+    mparams.use_mmap        = params.use_mmap;
+    mparams.use_mlock       = params.use_mlock;
+    mparams.check_tensors   = params.check_tensors;
+    if (params.kv_overrides.empty()) {
+        mparams.kv_overrides = NULL;
+    } else {
+        GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key");
+        mparams.kv_overrides = params.kv_overrides.data();
+    }
+
+    return mparams;
+}
+
+static ggml_type kv_cache_type_from_str(const std::string & s) {
+    if (s == "f32") {
+        return GGML_TYPE_F32;
+    }
+    if (s == "f16") {
+        return GGML_TYPE_F16;
+    }
+    if (s == "q8_0") {
+        return GGML_TYPE_Q8_0;
+    }
+    if (s == "q4_0") {
+        return GGML_TYPE_Q4_0;
+    }
+    if (s == "q4_1") {
+        return GGML_TYPE_Q4_1;
+    }
+    if (s == "iq4_nl") {
+        return GGML_TYPE_IQ4_NL;
+    }
+    if (s == "q5_0") {
+        return GGML_TYPE_Q5_0;
+    }
+    if (s == "q5_1") {
+        return GGML_TYPE_Q5_1;
+    }
+
+    throw std::runtime_error("Invalid cache type: " + s);
+}
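+
+// For example, kv_cache_type_from_str("q8_0") returns GGML_TYPE_Q8_0, while an
+// unrecognized string such as "q6_k" throws std::runtime_error, which callers
+// presumably treat as a configuration error.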
+
+struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
+    auto cparams = llama_context_default_params();
+
+    cparams.n_ctx             = params.n_ctx;
+    cparams.n_seq_max         = params.n_parallel;
+    cparams.n_batch           = params.n_batch;
+    cparams.n_ubatch          = params.n_ubatch;
+    cparams.n_threads         = params.cpuparams.n_threads;
+    cparams.n_threads_batch   = params.cpuparams_batch.n_threads == -1 ?
+                                    params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
+    cparams.seed              = params.seed;
+    cparams.logits_all        = params.logits_all;
+    cparams.embeddings        = params.embedding;
+    cparams.rope_scaling_type = params.rope_scaling_type;
+    cparams.rope_freq_base    = params.rope_freq_base;
+    cparams.rope_freq_scale   = params.rope_freq_scale;
+    cparams.yarn_ext_factor   = params.yarn_ext_factor;
+    cparams.yarn_attn_factor  = params.yarn_attn_factor;
+    cparams.yarn_beta_fast    = params.yarn_beta_fast;
+    cparams.yarn_beta_slow    = params.yarn_beta_slow;
+    cparams.yarn_orig_ctx     = params.yarn_orig_ctx;
+    cparams.pooling_type      = params.pooling_type;
+    cparams.attention_type    = params.attention_type;
+    cparams.defrag_thold      = params.defrag_thold;
+    cparams.cb_eval           = params.cb_eval;
+    cparams.cb_eval_user_data = params.cb_eval_user_data;
+    cparams.offload_kqv       = !params.no_kv_offload;
+    cparams.flash_attn        = params.flash_attn;
+
+    cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
+    cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
+
+    return cparams;
+}
+
+struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params) {
+    struct ggml_threadpool_params tpp;
+
+    ggml_threadpool_params_init(&tpp, params.n_threads); // setup the defaults
+
+    if (params.mask_valid) {
+        std::memcpy(&tpp.cpumask, &params.cpumask, GGML_MAX_N_THREADS);
+    }
+
+    tpp.prio       = params.priority;
+    tpp.poll       = params.poll;
+    tpp.strict_cpu = params.strict_cpu;
+
+    return tpp;
+}
+
+#ifdef LLAMA_USE_CURL
+
+static bool starts_with(const std::string & str, const std::string & prefix) {
+    // While we wait for C++20's std::string::starts_with...
+    return str.rfind(prefix, 0) == 0;
+}
+
+static bool llama_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
+
+    // Initialize libcurl
+    std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
+    if (!curl) {
+        fprintf(stderr, "%s: error initializing libcurl\n", __func__);
+        return false;
+    }
+
+    bool force_download = false;
+
+    // Set the URL, allow to follow http redirection
+    curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
+    curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
+
+    // Check if hf-token or bearer-token was specified
+    if (!hf_token.empty()) {
+        std::string auth_header = "Authorization: Bearer ";
+        auth_header += hf_token;
+        struct curl_slist * http_headers = NULL;
+        http_headers = curl_slist_append(http_headers, auth_header.c_str());
+        curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers);
+    }
+
+#if defined(_WIN32)
+    // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
+    //   operating system. Currently implemented under MS-Windows.
+    curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
+#endif
+
+    // Check if the file already exists locally
+    struct stat model_file_info;
+    auto file_exists = (stat(path.c_str(), &model_file_info) == 0);
+
+    // If the file exists, check its JSON metadata companion file.
+    std::string metadata_path = path + ".json";
+    nlohmann::json metadata;
+    std::string etag;
+    std::string last_modified;
+
+    if (file_exists) {
+        // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
+        std::ifstream metadata_in(metadata_path);
+        if (metadata_in.good()) {
+            try {
+                metadata_in >> metadata;
+                fprintf(stderr, "%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
+                if (metadata.contains("url") && metadata.at("url").is_string()) {
+                    auto previous_url = metadata.at("url").get<std::string>();
+                    if (previous_url != url) {
+                        fprintf(stderr, "%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str());
+                        return false;
+                    }
+                }
+                if (metadata.contains("etag") && metadata.at("etag").is_string()) {
+                    etag = metadata.at("etag");
+                }
+                if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
+                    last_modified = metadata.at("lastModified");
+                }
+            } catch (const nlohmann::json::exception & e) {
+                fprintf(stderr, "%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
+                return false;
+            }
+        }
+    } else {
+        fprintf(stderr, "%s: no previous model file found %s\n", __func__, path.c_str());
+    }
+
+    // Send a HEAD request to retrieve the etag and last-modified headers
+    struct llama_load_model_from_url_headers {
+        std::string etag;
+        std::string last_modified;
+    };
+    llama_load_model_from_url_headers headers;
+    {
+        typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
+        auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
+            llama_load_model_from_url_headers *headers = (llama_load_model_from_url_headers *) userdata;
+
+            static std::regex header_regex("([^:]+): (.*)\r\n");
+            static std::regex etag_regex("ETag", std::regex_constants::icase);
+            static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
+
+            std::string header(buffer, n_items);
+            std::smatch match;
+            if (std::regex_match(header, match, header_regex)) {
+                const std::string & key = match[1];
+                const std::string & value = match[2];
+                if (std::regex_match(key, match, etag_regex)) {
+                    headers->etag = value;
+                } else if (std::regex_match(key, match, last_modified_regex)) {
+                    headers->last_modified = value;
+                }
+            }
+            return n_items;
+        };
+
+        curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
+        curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
+        curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
+        curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
+
+        CURLcode res = curl_easy_perform(curl.get());
+        if (res != CURLE_OK) {
+            fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res));
+            return false;
+        }
+
+        long http_code = 0;
+        curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
+        if (http_code != 200) {
+            // HEAD not supported, we don't know if the file has changed
+            // force trigger downloading
+            force_download = true;
+            fprintf(stderr, "%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
+        }
+    }
+
+    bool should_download = !file_exists || force_download;
+    if (!should_download) {
+        if (!etag.empty() && etag != headers.etag) {
+            fprintf(stderr, "%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
+            should_download = true;
+        } else if (!last_modified.empty() && last_modified != headers.last_modified) {
+            fprintf(stderr, "%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
+            should_download = true;
+        }
+    }
+    if (should_download) {
+        std::string path_temporary = path + ".downloadInProgress";
+        if (file_exists) {
+            fprintf(stderr, "%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
+            if (remove(path.c_str()) != 0) {
+                fprintf(stderr, "%s: unable to delete file: %s\n", __func__, path.c_str());
+                return false;
+            }
+        }
+
+        // Set the output file
+
+        struct FILE_deleter {
+            void operator()(FILE * f) const {
+                fclose(f);
+            }
+        };
+
+        std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
+        if (!outfile) {
+            fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path.c_str());
+            return false;
+        }
+
+        typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
+        auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
+            return fwrite(data, size, nmemb, (FILE *)fd);
+        };
+        curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
+        curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
+        curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());
+
+        //  display download progress
+        curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);
+
+        // helper function to hide password in URL
+        auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
+            std::size_t protocol_pos = url.find("://");
+            if (protocol_pos == std::string::npos) {
+                return url;  // Malformed URL
+            }
+
+            std::size_t at_pos = url.find('@', protocol_pos + 3);
+            if (at_pos == std::string::npos) {
+                return url;  // No password in URL
+            }
+
+            return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
+        };
+
+        // start the download
+        fprintf(stderr, "%s: downloading from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
+                llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
+        auto res = curl_easy_perform(curl.get());
+        if (res != CURLE_OK) {
+            fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res));
+            return false;
+        }
+
+        long http_code = 0;
+        curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
+        if (http_code < 200 || http_code >= 400) {
+            fprintf(stderr, "%s: invalid http status code received: %ld\n", __func__, http_code);
+            return false;
+        }
+
+        // Causes file to be closed explicitly here before we rename it.
+        outfile.reset();
+
+        // Write the updated JSON metadata file.
+        metadata.update({
+            {"url", url},
+            {"etag", headers.etag},
+            {"lastModified", headers.last_modified}
+        });
+        std::ofstream(metadata_path) << metadata.dump(4);
+        fprintf(stderr, "%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
+
+        if (rename(path_temporary.c_str(), path.c_str()) != 0) {
+            fprintf(stderr, "%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
+            return false;
+        }
+    }
+
+    return true;
+}
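+
+// Editor's summary of the protocol above: a "<path>.json" sidecar stores
+// {url, etag, lastModified}; a HEAD request decides whether the local copy is
+// stale; fresh downloads go to "<path>.downloadInProgress" and are renamed
+// into place only after the transfer and metadata update succeed.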
+
+struct llama_model * llama_load_model_from_url(
+        const char * model_url,
+        const char * path_model,
+        const char * hf_token,
+        const struct llama_model_params & params) {
+    // Basic validation of the model_url
+    if (!model_url || strlen(model_url) == 0) {
+        fprintf(stderr, "%s: invalid model_url\n", __func__);
+        return NULL;
+    }
+
+    if (!llama_download_file(model_url, path_model, hf_token)) {
+        return NULL;
+    }
+
+    // check for additional GGUFs split to download
+    int n_split = 0;
+    {
+        struct gguf_init_params gguf_params = {
+            /*.no_alloc = */ true,
+            /*.ctx      = */ NULL,
+        };
+        auto * ctx_gguf = gguf_init_from_file(path_model, gguf_params);
+        if (!ctx_gguf) {
+            fprintf(stderr, "\n%s:  failed to load input GGUF from %s\n", __func__, path_model);
+            return NULL;
+        }
+
+        auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
+        if (key_n_split >= 0) {
+            n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
+        }
+
+        gguf_free(ctx_gguf);
+    }
+
+    if (n_split > 1) {
+        char split_prefix[PATH_MAX] = {0};
+        char split_url_prefix[LLAMA_CURL_MAX_URL_LENGTH] = {0};
+
+        // Verify the first split file format
+        // and extract split URL and PATH prefixes
+        {
+            if (!llama_split_prefix(split_prefix, sizeof(split_prefix), path_model, 0, n_split)) {
+                fprintf(stderr, "\n%s: unexpected model file name: %s"
+                                " n_split=%d\n", __func__, path_model, n_split);
+                return NULL;
+            }
+
+            if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url, 0, n_split)) {
+                fprintf(stderr, "\n%s: unexpected model url: %s"
+                                " n_split=%d\n", __func__, model_url, n_split);
+                return NULL;
+            }
+        }
+
+        // Prepare download in parallel
+        std::vector<std::future<bool>> futures_download;
+        for (int idx = 1; idx < n_split; idx++) {
+            futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool {
+                char split_path[PATH_MAX] = {0};
+                llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split);
+
+                char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
+                llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);
+
+                return llama_download_file(split_url, split_path, hf_token);
+            }, idx));
+        }
+
+        // Wait for all downloads to complete
+        for (auto & f : futures_download) {
+            if (!f.get()) {
+                return NULL;
+            }
+        }
+    }
+
+    return llama_load_model_from_file(path_model, params);
+}
+
+struct llama_model * llama_load_model_from_hf(
+        const char * repo,
+        const char * model,
+        const char * path_model,
+        const char * hf_token,
+        const struct llama_model_params & params) {
+    // construct hugging face model url:
+    //
+    //  --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf
+    //    https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf
+    //
+    //  --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf
+    //    https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf
+    //
+
+    std::string model_url = "https://huggingface.co/";
+    model_url += repo;
+    model_url += "/resolve/main/";
+    model_url += model;
+
+    return llama_load_model_from_url(model_url.c_str(), path_model, hf_token, params);
+}
+
+#else
+
+struct llama_model * llama_load_model_from_url(
+        const char * /*model_url*/,
+        const char * /*path_model*/,
+        const char * /*hf_token*/,
+        const struct llama_model_params & /*params*/) {
+    fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
+    return nullptr;
+}
+
+struct llama_model * llama_load_model_from_hf(
+        const char * /*repo*/,
+        const char * /*model*/,
+        const char * /*path_model*/,
+        const char * /*hf_token*/,
+        const struct llama_model_params & /*params*/) {
+    fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
+    return nullptr;
+}
+
+#endif // LLAMA_USE_CURL
+
+//
+// Batch utils
+//
+
+void llama_batch_clear(struct llama_batch & batch) {
+    batch.n_tokens = 0;
+}
+
+void llama_batch_add(
+                 struct llama_batch & batch,
+                        llama_token   id,
+                          llama_pos   pos,
+    const std::vector<llama_seq_id> & seq_ids,
+                               bool   logits) {
+    batch.token   [batch.n_tokens] = id;
+    batch.pos     [batch.n_tokens] = pos;
+    batch.n_seq_id[batch.n_tokens] = seq_ids.size();
+    for (size_t i = 0; i < seq_ids.size(); ++i) {
+        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
+    }
+    batch.logits  [batch.n_tokens] = logits;
+
+    batch.n_tokens++;
+}
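+
+// Usage sketch (hypothetical; assumes `batch` was allocated with enough
+// capacity, e.g. via llama_batch_init): feed a prompt into sequence 0 and
+// request logits only for the final token:
+//
+//     llama_batch_clear(batch);
+//     for (size_t i = 0; i < tokens.size(); i++) {
+//         llama_batch_add(batch, tokens[i], (llama_pos) i, { 0 }, i == tokens.size() - 1);
+//     }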
+
+//
+// Vocab utils
+//
+
+std::vector<llama_token> llama_tokenize(
+  const struct llama_context * ctx,
+           const std::string & text,
+                        bool   add_special,
+                        bool   parse_special) {
+    return llama_tokenize(llama_get_model(ctx), text, add_special, parse_special);
+}
+
+std::vector<llama_token> llama_tokenize(
+    const struct llama_model * model,
+           const std::string & text,
+                        bool   add_special,
+                        bool   parse_special) {
+    // upper limit for the number of tokens
+    int n_tokens = text.length() + 2 * add_special;
+    std::vector<llama_token> result(n_tokens);
+    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+    if (n_tokens < 0) {
+        result.resize(-n_tokens);
+        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+        GGML_ASSERT(check == -n_tokens);
+    } else {
+        result.resize(n_tokens);
+    }
+    return result;
+}
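+
+// Note: a negative return value from the C API is the negated number of tokens
+// actually required, so the single resize-and-retry above is guaranteed to
+// succeed (hence the GGML_ASSERT); llama_token_to_piece() and
+// llama_detokenize() below follow the same convention.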
+
+std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+    std::string piece;
+    piece.resize(piece.capacity());  // use the string's internal (SSO) cache, typically 15 bytes + '\0'
+    const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+    if (n_chars < 0) {
+        piece.resize(-n_chars);
+        int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+        GGML_ASSERT(check == -n_chars);
+    }
+    else {
+        piece.resize(n_chars);
+    }
+
+    return piece;
+}
+
+std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
+    std::string text;
+    text.resize(std::max(text.capacity(), tokens.size()));
+    int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+    if (n_chars < 0) {
+        text.resize(-n_chars);
+        n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+        GGML_ASSERT(n_chars <= (int32_t)text.size());  // whitespace trimming is performed after per-token detokenization
+    }
+
+    text.resize(n_chars);
+
+    // NOTE: the original tokenizer decodes bytes after collecting the pieces.
+    return text;
+}
+
+//
+// Chat template utils
+//
+
+bool llama_chat_verify_template(const std::string & tmpl) {
+    llama_chat_message chat[] = {{"user", "test"}};
+    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
+    return res >= 0;
+}
+
+std::string llama_chat_apply_template(const struct llama_model * model,
+        const std::string & tmpl,
+        const std::vector<llama_chat_msg> & msgs,
+        bool add_ass) {
+    int alloc_size = 0;
+    bool fallback = false; // indicate if we must fallback to default chatml
+    std::vector<llama_chat_message> chat;
+    for (auto & msg : msgs) {
+        chat.push_back({msg.role.c_str(), msg.content.c_str()});
+        alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
+    }
+
+    const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
+    std::vector<char> buf(alloc_size);
+
+    // run the first time to get the total output length
+    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+
+    // error: chat template is not supported
+    if (res < 0) {
+        if (ptr_tmpl != nullptr) {
+            // if the custom "tmpl" is not supported, we throw an error
+            // this is intentionally redundant, since we cannot be sure the user validated the custom template with llama_chat_verify_template()
+            throw std::runtime_error("this custom template is not supported");
+        } else {
+            // If the built-in template is not supported, we default to chatml
+            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+            fallback = true;
+        }
+    }
+
+    // if it turns out that our buffer is too small, we resize it
+    if ((size_t) res > buf.size()) {
+        buf.resize(res);
+        res = llama_chat_apply_template(
+            fallback ? nullptr : model,
+            fallback ? "chatml" : ptr_tmpl,
+            chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    }
+
+    std::string formatted_chat(buf.data(), res);
+    return formatted_chat;
+}
+
+std::string llama_chat_format_single(const struct llama_model * model,
+        const std::string & tmpl,
+        const std::vector<llama_chat_msg> & past_msg,
+        const llama_chat_msg & new_msg,
+        bool add_ass) {
+    std::ostringstream ss;
+    auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false);
+    std::vector<llama_chat_msg> chat_new(past_msg);
+    // if the past_msg ends with a newline, we must preserve it in the formatted version
+    if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
+        ss << "\n";
+    }
+    // format chat with new_msg
+    chat_new.push_back(new_msg);
+    auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass);
+    // get the diff part
+    ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
+    return ss.str();
+}
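+
+// Illustrative example (assuming a chatml-style template): if past_msg formats
+// to "<|im_start|>user\nHi<|im_end|>\n", the return value is only the suffix
+// contributed by new_msg (plus the assistant prompt when add_ass is set), so a
+// caller can stream output without re-emitting the whole transcript.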
+
+std::string llama_chat_format_example(const struct llama_model * model,
+        const std::string & tmpl) {
+    std::vector<llama_chat_msg> msgs = {
+        {"system",    "You are a helpful assistant"},
+        {"user",      "Hello"},
+        {"assistant", "Hi there"},
+        {"user",      "How are you?"},
+    };
+    return llama_chat_apply_template(model, tmpl, msgs, true);
+}
+
+//
+// KV cache utils
+//
+
+void llama_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size) {
+    static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+";
+
+    printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d",
+        view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
+
+    llama_kv_cache_view_cell * c_curr = view.cells;
+    llama_seq_id * cs_curr = view.cells_sequences;
+
+    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
+        if (i % row_size == 0) {
+            printf("\n%5d: ", i);
+        }
+        int seq_count = 0;
+        for (int j = 0; j < view.n_seq_max; j++) {
+            if (cs_curr[j] >= 0) { seq_count++; }
+        }
+        putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]);
+    }
+
+    printf("\n=== Done dumping\n");
+}
+
+void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size) {
+    static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
+
+    printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n",
+        view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
+
+    std::unordered_map<llama_seq_id, size_t> seqs;
+    llama_kv_cache_view_cell * c_curr = view.cells;
+    llama_seq_id * cs_curr = view.cells_sequences;
+
+    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
+        for (int j = 0; j < view.n_seq_max; j++) {
+            if (cs_curr[j] < 0) { continue; }
+            if (seqs.find(cs_curr[j]) == seqs.end()) {
+                if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
+                const size_t sz = seqs.size();
+                seqs[cs_curr[j]] = sz;
+            }
+        }
+        if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
+    }
+
+    printf("=== Sequence legend: ");
+    for (const auto & it : seqs) {
+        printf("%zu=%d, ", it.second, it.first);
+    }
+    printf("'+'=other sequence ids");
+
+    c_curr = view.cells;
+    cs_curr = view.cells_sequences;
+    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
+        if (i % row_size == 0) {
+            printf("\n%5d: ", i);
+        }
+        for (int j = 0; j < view.n_seq_max; j++) {
+            if (cs_curr[j] >= 0) {
+                const auto & it = seqs.find(cs_curr[j]);
+                putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+');
+            } else {
+                putchar('.');
+            }
+        }
+        putchar(' ');
+    }
+
+    printf("\n=== Done dumping\n");
+}
+
+//
+// Embedding utils
+//
+
+void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm) {
+    double sum = 0.0;
+
+    switch (embd_norm) {
+        case -1: // no normalisation
+            sum = 1.0;
+            break;
+        case 0: // max absolute
+            for (int i = 0; i < n; i++) {
+                if (sum < std::abs(inp[i])) sum = std::abs(inp[i]);
+            }
+            sum /= 32760.0; // scale into the int16 range
+            break;
+        case 2: // euclidean
+            for (int i = 0; i < n; i++) {
+                sum += inp[i] * inp[i];
+            }
+            sum = std::sqrt(sum);
+            break;
+        default: // p-norm (euclidean is p-norm p=2)
+            for (int i = 0; i < n; i++) {
+                sum += std::pow(std::abs(inp[i]), embd_norm);
+            }
+            sum = std::pow(sum, 1.0 / embd_norm);
+            break;
+    }
+
+    const float norm = sum > 0.0 ? 1.0 / sum : 0.0f;
+
+    for (int i = 0; i < n; i++) {
+        out[i] = inp[i] * norm;
+    }
+}
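+
+// For example, embd_norm == 2 gives out = inp / ||inp||_2 (unit length),
+// embd_norm == -1 copies the input unchanged, and embd_norm == 0 rescales so
+// the largest magnitude maps to roughly the int16 range (32760).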
+
+float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n){
+    double sum  = 0.0;
+    double sum1 = 0.0;
+    double sum2 = 0.0;
+
+    for (int i = 0; i < n; i++) {
+        sum  += embd1[i] * embd2[i];
+        sum1 += embd1[i] * embd1[i];
+        sum2 += embd2[i] * embd2[i];
+    }
+
+    // Handle the case where one or both vectors are zero vectors
+    if (sum1 == 0.0 || sum2 == 0.0) {
+        if (sum1 == 0.0 && sum2 == 0.0) {
+            return 1.0f; // two zero vectors are similar
+        }
+        return 0.0f;
+    }
+
+    return sum / (sqrt(sum1) * sqrt(sum2));
+}
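+
+// This is the standard cosine similarity dot(e1, e2) / (||e1|| * ||e2||),
+// yielding a value in [-1, 1]; the zero-vector special cases above keep the
+// function total instead of dividing by zero.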
+
+//
+// Control vector utils
+//
+
+static llama_control_vector_data llama_control_vector_load_one(const llama_control_vector_load_info & load_info) {
+    llama_control_vector_data result = { -1, {} };
+
+    ggml_context * ctx = nullptr;
+    struct gguf_init_params meta_gguf_params = {
+        /* .no_alloc = */ false,
+        /* .ctx      = */ &ctx,
+    };
+    struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params);
+    if (!ctx_gguf) {
+        fprintf(stderr, "%s: failed to load control vector file from %s\n", __func__, load_info.fname.c_str());
+        return result;
+    }
+
+    int32_t n_tensors = gguf_get_n_tensors(ctx_gguf);
+    if (n_tensors == 0) {
+        fprintf(stderr, "%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str());
+    }
+
+    for (int i = 0; i < n_tensors; i++) {
+        std::string name = gguf_get_tensor_name(ctx_gguf, i);
+
+        int layer_idx = -1;
+
+        // split on '.'
+        size_t dotpos = name.find('.');
+        if (dotpos != std::string::npos && name.substr(0, dotpos) == "direction") {
+            try {
+                layer_idx = std::stoi(name.substr(dotpos + 1));
+            } catch (...) {
+                layer_idx = -1;
+            }
+        }
+        if (layer_idx < 0) {
+            fprintf(stderr, "%s: invalid/unparsable direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
+            result.n_embd = -1;
+            break;
+        } else if (layer_idx == 0) {
+            fprintf(stderr, "%s: invalid (zero) direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
+            result.n_embd = -1;
+            break;
+        }
+
+        struct ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str());
+        if (tensor->type != GGML_TYPE_F32) {
+            fprintf(stderr, "%s: invalid (non-F32) direction tensor type in %s\n", __func__, load_info.fname.c_str());
+            result.n_embd = -1;
+            break;
+        }
+        if (ggml_n_dims(tensor) != 1) {
+            fprintf(stderr, "%s: invalid (non-1D) direction tensor shape in %s\n", __func__, load_info.fname.c_str());
+            result.n_embd = -1;
+            break;
+        }
+
+        if (result.n_embd == -1) {
+            result.n_embd = ggml_nelements(tensor);
+        } else if (ggml_nelements(tensor) != result.n_embd) {
+            fprintf(stderr, "%s: direction tensor in %s does not match previous dimensions\n", __func__, load_info.fname.c_str());
+            result.n_embd = -1;
+            break;
+        }
+
+        // extend if necessary - do not store data for layer 0 (it's not used)
+        result.data.resize(std::max(result.data.size(), static_cast<size_t>(result.n_embd * layer_idx)), 0.0f);
+
+        const float * src = (const float *) tensor->data;
+        float * dst = result.data.data() + result.n_embd * (layer_idx - 1);  // layer 1 at [0]
+        for (int j = 0; j < result.n_embd; j++) {
+            dst[j] += src[j] * load_info.strength;  // allows multiple directions for same layer in same file
+        }
+
+    }
+
+    if (result.n_embd == -1) {
+        fprintf(stderr, "%s: skipping %s due to invalid direction tensors\n", __func__, load_info.fname.c_str());
+        result.data.clear();
+    }
+
+    gguf_free(ctx_gguf);
+    ggml_free(ctx);
+
+    return result;
+}
+
+llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos) {
+    llama_control_vector_data result = { -1, {} };
+
+    for (const auto & info : load_infos) {
+        auto cur = llama_control_vector_load_one(info);
+
+        if (cur.n_embd == -1) {
+            result.n_embd = -1;
+            break;
+        }
+        if (result.n_embd != -1 && result.n_embd != cur.n_embd) {
+            fprintf(stderr, "%s: control vectors in %s does not match previous dimensions\n", __func__, info.fname.c_str());
+            result.n_embd = -1;
+            break;
+        }
+
+        if (result.n_embd == -1) {
+            result = std::move(cur);
+        } else {
+            result.data.resize(std::max(result.data.size(), cur.data.size()), 0.0f);  // extend if necessary
+            for (size_t i = 0; i < cur.data.size(); i++) {
+                result.data[i] += cur.data[i];
+            }
+        }
+    }
+
+    if (result.n_embd == -1) {
+        fprintf(stderr, "%s: no valid control vector files passed\n", __func__);
+        result.data.clear();
+    }
+
+    return result;
+}
+
+//
+// YAML utils
+//
+
+void yaml_dump_vector_float(FILE * stream, const char * prop_name, const std::vector<float> & data) {
+    if (data.empty()) {
+        fprintf(stream, "%s:\n", prop_name);
+        return;
+    }
+
+    fprintf(stream, "%s: [", prop_name);
+    for (size_t i = 0; i < data.size() - 1; ++i) {
+        fprintf(stream, "%e, ", data[i]);
+    }
+    fprintf(stream, "%e]\n", data.back());
+}
+
+void yaml_dump_vector_int(FILE * stream, const char * prop_name, const std::vector<int> & data) {
+    if (data.empty()) {
+        fprintf(stream, "%s:\n", prop_name);
+        return;
+    }
+
+    fprintf(stream, "%s: [", prop_name);
+    for (size_t i = 0; i < data.size() - 1; ++i) {
+        fprintf(stream, "%d, ", data[i]);
+    }
+    fprintf(stream, "%d]\n", data.back());
+}
+
+void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const char * data) {
+    std::string data_str(data == NULL ? "" : data);
+
+    if (data_str.empty()) {
+        fprintf(stream, "%s:\n", prop_name);
+        return;
+    }
+
+    size_t pos_start = 0;
+    size_t pos_found = 0;
+
+    if (std::isspace(static_cast<unsigned char>(data_str[0])) || std::isspace(static_cast<unsigned char>(data_str.back()))) {
+        data_str = std::regex_replace(data_str, std::regex("\n"), "\\n");
+        data_str = std::regex_replace(data_str, std::regex("\""), "\\\"");
+        data_str = std::regex_replace(data_str, std::regex(R"(\\[^n"])"), R"(\$&)");
+        data_str = "\"" + data_str + "\"";
+        fprintf(stream, "%s: %s\n", prop_name, data_str.c_str());
+        return;
+    }
+
+    if (data_str.find('\n') == std::string::npos) {
+        fprintf(stream, "%s: %s\n", prop_name, data_str.c_str());
+        return;
+    }
+
+    fprintf(stream, "%s: |\n", prop_name);
+    while ((pos_found = data_str.find('\n', pos_start)) != std::string::npos) {
+        fprintf(stream, "  %s\n", data_str.substr(pos_start, pos_found-pos_start).c_str());
+        pos_start = pos_found + 1;
+    }
+}
+
+void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx,
+                               const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
+    const llama_sampling_params & sparams = params.sparams;
+
+    fprintf(stream, "build_commit: %s\n",        LLAMA_COMMIT);
+    fprintf(stream, "build_number: %d\n",        LLAMA_BUILD_NUMBER);
+    fprintf(stream, "cpu_has_arm_fma: %s\n",     ggml_cpu_has_arm_fma()     ? "true" : "false");
+    fprintf(stream, "cpu_has_avx: %s\n",         ggml_cpu_has_avx()         ? "true" : "false");
+    fprintf(stream, "cpu_has_avx_vnni: %s\n",    ggml_cpu_has_avx_vnni()    ? "true" : "false");
+    fprintf(stream, "cpu_has_avx2: %s\n",        ggml_cpu_has_avx2()        ? "true" : "false");
+    fprintf(stream, "cpu_has_avx512: %s\n",      ggml_cpu_has_avx512()      ? "true" : "false");
+    fprintf(stream, "cpu_has_avx512_vbmi: %s\n", ggml_cpu_has_avx512_vbmi() ? "true" : "false");
+    fprintf(stream, "cpu_has_avx512_vnni: %s\n", ggml_cpu_has_avx512_vnni() ? "true" : "false");
+    fprintf(stream, "cpu_has_cuda: %s\n",        ggml_cpu_has_cuda()        ? "true" : "false");
+    fprintf(stream, "cpu_has_vulkan: %s\n",      ggml_cpu_has_vulkan()      ? "true" : "false");
+    fprintf(stream, "cpu_has_kompute: %s\n",     ggml_cpu_has_kompute()     ? "true" : "false");
+    fprintf(stream, "cpu_has_fma: %s\n",         ggml_cpu_has_fma()         ? "true" : "false");
+    fprintf(stream, "cpu_has_gpublas: %s\n",     ggml_cpu_has_gpublas()     ? "true" : "false");
+    fprintf(stream, "cpu_has_neon: %s\n",        ggml_cpu_has_neon()        ? "true" : "false");
+    fprintf(stream, "cpu_has_sve: %s\n",         ggml_cpu_has_sve()         ? "true" : "false");
+    fprintf(stream, "cpu_has_f16c: %s\n",        ggml_cpu_has_f16c()        ? "true" : "false");
+    fprintf(stream, "cpu_has_fp16_va: %s\n",     ggml_cpu_has_fp16_va()     ? "true" : "false");
+    fprintf(stream, "cpu_has_wasm_simd: %s\n",   ggml_cpu_has_wasm_simd()   ? "true" : "false");
+    fprintf(stream, "cpu_has_blas: %s\n",        ggml_cpu_has_blas()        ? "true" : "false");
+    fprintf(stream, "cpu_has_sse3: %s\n",        ggml_cpu_has_sse3()        ? "true" : "false");
+    fprintf(stream, "cpu_has_vsx: %s\n",         ggml_cpu_has_vsx()         ? "true" : "false");
+    fprintf(stream, "cpu_has_matmul_int8: %s\n", ggml_cpu_has_matmul_int8() ? "true" : "false");
+
+#ifdef NDEBUG
+    fprintf(stream, "debug: false\n");
+#else
+    fprintf(stream, "debug: true\n");
+#endif // NDEBUG
+
+    fprintf(stream, "model_desc: %s\n", model_desc);
+    fprintf(stream, "n_vocab: %d  # output size of the final layer, 32001 for some models\n", llama_n_vocab(llama_get_model(lctx)));
+
+#ifdef __OPTIMIZE__
+    fprintf(stream, "optimize: true\n");
+#else
+    fprintf(stream, "optimize: false\n");
+#endif // __OPTIMIZE__
+
+    fprintf(stream, "time: %s\n", timestamp.c_str());
+
+    fprintf(stream, "\n");
+    fprintf(stream, "###############\n");
+    fprintf(stream, "# User Inputs #\n");
+    fprintf(stream, "###############\n");
+    fprintf(stream, "\n");
+
+    fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str());
+    fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch);
+    yaml_dump_string_multiline(stream, "cfg_negative_prompt", sparams.cfg_negative_prompt.c_str());
+    fprintf(stream, "cfg_scale: %f # default: 1.0\n", sparams.cfg_scale);
+    fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
+    fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
+    fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
+    fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
+    fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
+    fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
+    yaml_dump_string_multiline(stream, "grammar", sparams.grammar.c_str());
+    fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
+    fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
+    fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
+
+    const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
+    const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
+    fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");
+
+    yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str());
+    fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false");
+    yaml_dump_string_multiline(stream, "in_suffix", params.input_prefix.c_str());
+    fprintf(stream, "interactive: %s # default: false\n", params.interactive ? "true" : "false");
+    fprintf(stream, "interactive_first: %s # default: false\n", params.interactive_first ? "true" : "false");
+    fprintf(stream, "keep: %d # default: 0\n", params.n_keep);
+    fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str());
+
+    fprintf(stream, "logit_bias:\n");
+    for (std::pair<llama_token, float> lb : sparams.logit_bias) {
+        if (ignore_eos && lb.first == logit_bias_eos->first) {
+            continue;
+        }
+        fprintf(stream, "  %d: %f", lb.first, lb.second);
+    }
+
+    fprintf(stream, "lora:\n");
+    for (auto & la : params.lora_adapters) {
+        if (la.scale == 1.0f) {
+            fprintf(stream, "  - %s\n", la.path.c_str());
+        }
+    }
+    fprintf(stream, "lora_scaled:\n");
+    for (auto & la : params.lora_adapters) {
+        if (la.scale != 1.0f) {
+            fprintf(stream, "  - %s: %f\n", la.path.c_str(), la.scale);
+        }
+    }
+    fprintf(stream, "lora_init_without_apply: %s # default: false\n", params.lora_init_without_apply ? "true" : "false");
+    fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
+    fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep);
+    fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
+    fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
+    fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);
+    fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
+    fprintf(stream, "model: %s # default: %s\n", params.model.c_str(), DEFAULT_MODEL_PATH);
+    fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
+    fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
+    fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers);
+    fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict);
+    fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs);
+    fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false");
+    fprintf(stream, "penalize_nl: %s # default: false\n", sparams.penalize_nl ? "true" : "false");
+    fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
+    fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
+    fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present);
+    yaml_dump_string_multiline(stream, "prompt", params.prompt.c_str());
+    fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str());
+    fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false");
+    fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false");
+    yaml_dump_vector_int(stream, "prompt_tokens", prompt_tokens);
+    fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat);
+
+    fprintf(stream, "reverse_prompt:\n");
+    for (std::string ap : params.antiprompt) {
+        size_t pos = 0;
+        while ((pos = ap.find('\n', pos)) != std::string::npos) {
+            ap.replace(pos, 1, "\\n");
+            pos += 1;
+        }
+
+        fprintf(stream, "  - %s\n", ap.c_str());
+    }
+
+    fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base);
+    fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale);
+    fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed);
+    fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
+    fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
+    fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
+    fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
+
+    const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
+    yaml_dump_vector_float(stream, "tensor_split", tensor_split_vector);
+
+    fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z);
+    fprintf(stream, "threads: %d # default: %u\n", params.cpuparams.n_threads, std::thread::hardware_concurrency());
+    fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
+    fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
+    fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
+    fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
+    fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
+    fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
+}

+ 514 - 0
llama/common.h

@@ -0,0 +1,514 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// Various helper functions and utilities
+
+#pragma once
+
+#include "llama.h"
+
+#include "sampling.h"
+
+#define LOG_NO_FILE_LINE_FUNCTION
+#include "log.h"
+
+#include <cmath>
+#include <string>
+#include <vector>
+#include <random>
+#include <thread>
+#include <unordered_map>
+#include <tuple>
+#include <sstream> // std::istringstream, used by the string_split template below
+
+#ifdef _WIN32
+#define DIRECTORY_SEPARATOR '\\'
+#else
+#define DIRECTORY_SEPARATOR '/'
+#endif // _WIN32
+
+#define die(msg)          do { fputs("error: " msg "\n", stderr);                exit(1); } while (0)
+#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0)
+
+#define print_build_info() do {                                                                     \
+    fprintf(stderr, "%s: build = %d (%s)\n",      __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);      \
+    fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);    \
+} while(0)
+
+#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
+
+struct llama_lora_adapter_info {
+    std::string path;
+    float scale;
+};
+
+struct llama_lora_adapter_container : llama_lora_adapter_info {
+    struct llama_lora_adapter * adapter;
+};
+
+// build info
+extern int LLAMA_BUILD_NUMBER;
+extern char const * LLAMA_COMMIT;
+extern char const * LLAMA_COMPILER;
+extern char const * LLAMA_BUILD_TARGET;
+
+struct llama_control_vector_load_info;
+
+//
+// CPU utils
+//
+
+int32_t cpu_get_num_physical_cores();
+int32_t cpu_get_num_math();
+
+//
+// CLI argument parsing
+//
+
+// dimensionality reduction methods, used by cvector-generator
+enum dimre_method {
+    DIMRE_METHOD_PCA,
+    DIMRE_METHOD_MEAN,
+};
+
+struct cpu_params {
+    int      n_threads                   = -1;
+    bool     cpumask[GGML_MAX_N_THREADS] = {false}; // CPU affinity mask.
+    bool     mask_valid                  = false;   // Default: any CPU
+    enum ggml_sched_priority  priority   = GGML_SCHED_PRIO_NORMAL;  // Scheduling prio: (0 - normal, 1 - medium, 2 - high, 3 - realtime)
+    bool     strict_cpu                  = false;   // Use strict CPU placement
+    uint32_t poll                        = 50;      // Polling (busywait) level (0 - no polling, 100 - mostly polling)
+};
+
+struct gpt_params {
+    uint32_t seed                 = LLAMA_DEFAULT_SEED; // RNG seed
+
+    int32_t n_predict             =    -1; // new tokens to predict
+    int32_t n_ctx                 =     0; // context size
+    int32_t n_batch               =  2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
+    int32_t n_ubatch              =   512; // physical batch size for prompt processing (must be >=32 to use BLAS)
+    int32_t n_keep                =     0; // number of tokens to keep from initial prompt
+    int32_t n_draft               =     5; // number of tokens to draft during speculative decoding
+    int32_t n_chunks              =    -1; // max number of chunks to process (-1 = unlimited)
+    int32_t n_parallel            =     1; // number of parallel sequences to decode
+    int32_t n_sequences           =     1; // number of sequences to decode
+    float   p_split               =  0.1f; // speculative decoding split probability
+    int32_t n_gpu_layers          =    -1; // number of layers to store in VRAM (-1 - use default)
+    int32_t n_gpu_layers_draft    =    -1; // number of layers to store in VRAM for the draft model (-1 - use default)
+    int32_t main_gpu              =     0; // the GPU that is used for scratch and small tensors
+    float   tensor_split[128]     =   {0}; // how split tensors should be distributed across GPUs
+    int32_t grp_attn_n            =     1; // group-attention factor
+    int32_t grp_attn_w            =   512; // group-attention width
+    int32_t n_print               =    -1; // print token count every n tokens (-1 = disabled)
+    float   rope_freq_base        =  0.0f; // RoPE base frequency
+    float   rope_freq_scale       =  0.0f; // RoPE frequency scaling factor
+    float   yarn_ext_factor       = -1.0f; // YaRN extrapolation mix factor
+    float   yarn_attn_factor      =  1.0f; // YaRN magnitude scaling factor
+    float   yarn_beta_fast        = 32.0f; // YaRN low correction dim
+    float   yarn_beta_slow        =  1.0f; // YaRN high correction dim
+    int32_t yarn_orig_ctx         =     0; // YaRN original context length
+    float   defrag_thold          = -1.0f; // KV cache defragmentation threshold
+
+    struct cpu_params cpuparams;
+    struct cpu_params cpuparams_batch;
+    struct cpu_params draft_cpuparams;
+    struct cpu_params draft_cpuparams_batch;
+
+    ggml_backend_sched_eval_callback cb_eval = nullptr;
+    void * cb_eval_user_data                 = nullptr;
+
+    ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
+
+    enum llama_split_mode        split_mode        = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
+    enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
+    enum llama_pooling_type      pooling_type      = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
+    enum llama_attention_type    attention_type    = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
+
+    // sampling parameters
+    struct llama_sampling_params sparams;
+
+    std::string model                = ""; // model path
+    std::string model_draft          = ""; // draft model for speculative decoding
+    std::string model_alias          = "unknown"; // model alias
+    std::string model_url            = ""; // model url to download
+    std::string hf_token             = ""; // HF token
+    std::string hf_repo              = ""; // HF repo
+    std::string hf_file              = ""; // HF file
+    std::string prompt               = "";
+    std::string prompt_file          = ""; // store the external prompt file name
+    std::string path_prompt_cache    = ""; // path to file for saving/loading prompt eval state
+    std::string input_prefix         = ""; // string to prefix user inputs with
+    std::string input_suffix         = ""; // string to suffix user inputs with
+    std::string logdir               = ""; // directory in which to save YAML log files
+    std::string lookup_cache_static  = ""; // path of static ngram cache file for lookup decoding
+    std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding
+    std::string logits_file          = ""; // file for saving *all* logits
+    std::string rpc_servers          = ""; // comma separated list of RPC servers
+
+    std::vector<std::string> in_files;   // all input files
+    std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
+    std::vector<llama_model_kv_override> kv_overrides;
+
+    bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply)
+    std::vector<llama_lora_adapter_info> lora_adapters; // lora adapter path with user defined scale
+
+    std::vector<llama_control_vector_load_info> control_vectors; // control vector with user defined scale
+
+    int32_t verbosity                  = 0;
+    int32_t control_vector_layer_start = -1; // layer range for control vector
+    int32_t control_vector_layer_end   = -1; // layer range for control vector
+
+    int32_t ppl_stride      = 0;     // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
+    int32_t ppl_output_type = 0;     // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
+                                     //                                       (which is more convenient to use for plotting)
+    bool   hellaswag        = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
+    size_t hellaswag_tasks  = 400;   // number of tasks to use when computing the HellaSwag score
+
+    bool   winogrande       = false; // compute Winogrande score over random tasks from datafile supplied in prompt
+    size_t winogrande_tasks = 0;     // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed
+
+    bool   multiple_choice  = false;  // compute TruthfulQA score over random tasks from datafile supplied in prompt
+    size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed
+
+    bool   kl_divergence    = false; // compute KL divergence
+
+    bool usage             = false; // print usage
+    bool use_color         = false; // use color to distinguish generations and inputs
+    bool special           = false; // enable special token output
+    bool interactive       = false; // interactive mode
+    bool interactive_first = false; // wait for user input immediately
+    bool conversation      = false; // conversation mode (does not print special tokens and suffix/prefix)
+    bool prompt_cache_all  = false; // save user input and generations to prompt cache
+    bool prompt_cache_ro   = false; // open the prompt cache read-only and do not update it
+
+    bool escape            = true;  // escape "\n", "\r", "\t", "\'", "\"", and "\\"
+    bool multiline_input   = false; // reverse the usage of `\`
+    bool simple_io         = false; // improves compatibility with subprocesses and limited consoles
+    bool cont_batching     = true;  // insert new sequences for decoding on-the-fly
+    bool flash_attn        = false; // flash attention
+
+    bool input_prefix_bos  = false; // prefix BOS to user inputs, preceding input_prefix
+    bool ignore_eos        = false; // ignore generated EOS tokens
+    bool logits_all        = false; // return logits for all tokens in the batch
+    bool use_mmap          = true;  // use mmap for faster loads
+    bool use_mlock         = false; // use mlock to keep model in memory
+    bool verbose_prompt    = false; // print prompt tokens before generation
+    bool display_prompt    = true;  // print prompt before generation
+    bool infill            = false; // use infill mode
+    bool dump_kv_cache     = false; // dump the KV cache contents for debugging purposes
+    bool no_kv_offload     = false; // disable KV offloading
+    bool warmup            = true;  // warmup run
+    bool check_tensors     = false; // validate tensor data
+
+    std::string cache_type_k = "f16"; // KV cache data type for the K
+    std::string cache_type_v = "f16"; // KV cache data type for the V
+
+    // multimodal models (see examples/llava)
+    std::string mmproj = "";        // path to multimodal projector
+    std::vector<std::string> image; // path to image file(s)
+
+    // embedding
+    bool embedding         = false; // get only sentence embedding
+    int32_t embd_normalize = 2;     // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
+    std::string embd_out   = "";    // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
+    std::string embd_sep   = "\n";  // separator of embeddings
+
+    // server params
+    int32_t port           = 8080;         // server listens on this network port
+    int32_t timeout_read   = 600;          // http read timeout in seconds
+    int32_t timeout_write  = timeout_read; // http write timeout in seconds
+    int     n_threads_http = -1;           // number of threads to process HTTP requests (TODO: support threadpool)
+
+    std::string hostname      = "127.0.0.1";
+    std::string public_path   = "";
+    std::string chat_template = "";
+    std::string system_prompt = "";
+    bool enable_chat_template = true;
+
+    std::vector<std::string> api_keys;
+
+    std::string ssl_file_key  = "";
+    std::string ssl_file_cert = "";
+
+    bool endpoint_slots   = true;
+    bool endpoint_metrics = false;
+
+    bool log_json = false;
+
+    std::string slot_save_path;
+
+    float slot_prompt_similarity = 0.5f;
+
+    // batched-bench params
+    bool is_pp_shared = false;
+
+    std::vector<int32_t> n_pp;
+    std::vector<int32_t> n_tg;
+    std::vector<int32_t> n_pl;
+
+    // retrieval params
+    std::vector<std::string> context_files; // context files to embed
+
+    int32_t chunk_size = 64; // chunk size for context embedding
+
+    std::string chunk_separator = "\n"; // chunk separator for context embedding
+
+    // passkey params
+    int32_t n_junk = 250; // number of times to repeat the junk text
+    int32_t i_pos  = -1;  // position of the passkey in the junk text
+
+    // imatrix params
+    std::string out_file = "imatrix.dat"; // save the resulting imatrix to this file
+
+    int32_t n_out_freq  = 10; // output the imatrix every n_out_freq iterations
+    int32_t n_save_freq =  0; // save the imatrix every n_save_freq iterations
+    int32_t i_chunk     =  0; // start processing from this chunk
+
+    bool process_output = false; // collect data for the output tensor
+    bool compute_ppl    = true;  // whether to compute perplexity
+
+    // cvector-generator params
+    int n_pca_batch = 100;
+    int n_pca_iterations = 1000;
+    dimre_method cvector_dimre_method = DIMRE_METHOD_PCA;
+    std::string cvector_outfile       = "control_vector.gguf";
+    std::string cvector_positive_file = "examples/cvector-generator/positive.txt";
+    std::string cvector_negative_file = "examples/cvector-generator/negative.txt";
+
+    bool spm_infill = false; // suffix/prefix/middle pattern for infill
+
+    std::string lora_outfile = "ggml-lora-merged-f16.gguf";
+};
+
+void gpt_params_parse_from_env(gpt_params & params);
+void gpt_params_handle_model_default(gpt_params & params);
+
+bool gpt_params_parse_ex   (int argc, char ** argv, gpt_params & params);
+bool gpt_params_parse      (int argc, char ** argv, gpt_params & params);
+bool gpt_params_find_arg   (int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param);
+void gpt_params_print_usage(int argc, char ** argv, const gpt_params & params);
+
+std::string gpt_params_get_system_info(const gpt_params & params);
+
+bool parse_cpu_range(const std::string& range, bool(&boolmask)[GGML_MAX_N_THREADS]);
+bool parse_cpu_mask(const std::string& mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
+void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model = nullptr);
+bool set_process_priority(enum ggml_sched_priority prio);
+
+//
+// String utils
+//
+
+std::vector<std::string> string_split(std::string input, char separator);
+
+std::string string_strip(const std::string & str);
+std::string string_get_sortable_timestamp();
+
+void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
+
+template<class T>
+static std::vector<T> string_split(const std::string & str, char delim) {
+    std::vector<T> values;
+    std::istringstream str_stream(str);
+    std::string token;
+    while (std::getline(str_stream, token, delim)) {
+        T value;
+        std::istringstream token_stream(token);
+        token_stream >> value;
+        values.push_back(value);
+    }
+    return values;
+}
+
+bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
+void string_process_escapes(std::string & input);
+
+//
+// Filesystem utils
+//
+
+bool fs_validate_filename(const std::string & filename);
+bool fs_create_directory_with_parents(const std::string & path);
+
+std::string fs_get_cache_directory();
+std::string fs_get_cache_file(const std::string & filename);
+
+//
+// Model utils
+//
+
+struct llama_init_result {
+    struct llama_model   * model   = nullptr;
+    struct llama_context * context = nullptr;
+    std::vector<llama_lora_adapter_container> lora_adapters;
+};
+
+struct llama_init_result    llama_init_from_gpt_params(gpt_params & params);
+
+struct llama_model_params     llama_model_params_from_gpt_params    (const gpt_params & params);
+struct llama_context_params   llama_context_params_from_gpt_params  (const gpt_params & params);
+struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
+
+struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params);
+struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params);
+
+// clear LoRA adapters from context, then apply new list of adapters
+void llama_lora_adapters_apply(struct llama_context * ctx, std::vector<llama_lora_adapter_container> & lora_adapters);
+
+// Batch utils
+
+void llama_batch_clear(struct llama_batch & batch);
+
+void llama_batch_add(
+                 struct llama_batch & batch,
+                        llama_token   id,
+                          llama_pos   pos,
+    const std::vector<llama_seq_id> & seq_ids,
+                               bool   logits);
+
+//
+// Vocab utils
+//
+
+// tokenizes a string into a vector of tokens
+// should work similarly to Python's `tokenizer.encode`
+std::vector<llama_token> llama_tokenize(
+  const struct llama_context * ctx,
+           const std::string & text,
+                        bool   add_special,
+                        bool   parse_special = false);
+
+std::vector<llama_token> llama_tokenize(
+    const struct llama_model * model,
+           const std::string & text,
+                        bool   add_special,
+                        bool   parse_special = false);
+
+// converts a single token into a piece of text, optionally rendering special/control tokens
+// should work similarly to Python's `tokenizer.id_to_piece`
+std::string llama_token_to_piece(
+        const struct llama_context * ctx,
+                       llama_token   token,
+                       bool          special = true);
+
+// detokenizes a vector of tokens into a string
+// should work similarly to Python's `tokenizer.decode`
+// optionally renders special/control tokens
+std::string llama_detokenize(
+                         llama_context * ctx,
+        const std::vector<llama_token> & tokens,
+                                  bool   special = true);
+
+//
+// Chat template utils
+//
+
+// same as llama_chat_message, but uses std::string
+struct llama_chat_msg {
+    std::string role;
+    std::string content;
+};
+
+// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
+bool llama_chat_verify_template(const std::string & tmpl);
+
+// CPP wrapper for llama_chat_apply_template
+// If the built-in template is not supported, we default to chatml
+// If the custom "tmpl" is not supported, we throw an error
+std::string llama_chat_apply_template(const struct llama_model * model,
+        const std::string & tmpl,
+        const std::vector<llama_chat_msg> & chat,
+        bool add_ass);
+
+// Format single message, while taking into account the position of that message in chat history
+std::string llama_chat_format_single(const struct llama_model * model,
+        const std::string & tmpl,
+        const std::vector<llama_chat_msg> & past_msg,
+        const llama_chat_msg & new_msg,
+        bool add_ass);
+
+// Returns an example of formatted chat
+std::string llama_chat_format_example(const struct llama_model * model,
+        const std::string & tmpl);
+
+//
+// KV cache utils
+//
+
+// Dump the KV cache view with the number of sequences per cell.
+void llama_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80);
+
+// Dump the KV cache view showing individual sequences in each cell (long output).
+void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
+
+//
+// Embedding utils
+//
+
+void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
+
+float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n);
+
+//
+// Control vector utils
+//
+
+struct llama_control_vector_data {
+    int n_embd;
+
+    // stores data for layers [1, n_layer] where n_layer = data.size() / n_embd
+    std::vector<float> data;
+};
+
+struct llama_control_vector_load_info {
+    float strength;
+
+    std::string fname;
+};
+
+// Load control vectors, scale each by strength, and add them together.
+// On error, returns {-1, empty}
+llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos);
+
+//
+// Split utils
+//
+
+static const char * const LLM_KV_SPLIT_NO            = "split.no";
+static const char * const LLM_KV_SPLIT_COUNT         = "split.count";
+static const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
+
+//
+// YAML utils
+//
+
+void yaml_dump_vector_float    (FILE * stream, const char * prop_name, const std::vector<float> & data);
+void yaml_dump_vector_int      (FILE * stream, const char * prop_name, const std::vector<int> & data);
+void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const char * data);
+
+void yaml_dump_non_result_info(
+    FILE * stream, const gpt_params & params, const llama_context * lctx,
+    const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);
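
The header above only declares the helper API; as a quick orientation, here is a minimal sketch (editorial, not part of the vendored sources) of how the vocab and batch utilities are typically combined to evaluate a prompt. It assumes ctx is an already initialized llama_context (for example, obtained via llama_init_from_gpt_params), and that llama_batch_init, llama_decode, and llama_batch_free are available from the vendored llama.h:

    #include "common.h"

    static bool eval_prompt(llama_context * ctx, const std::string & prompt) {
        // tokenize, letting the model add its special prefix tokens (BOS, etc.)
        std::vector<llama_token> tokens = llama_tokenize(ctx, prompt, /*add_special=*/true);
        if (tokens.empty()) {
            return false;
        }

        // stage the tokens in a batch on sequence 0, requesting logits
        // only for the final position
        llama_batch batch = llama_batch_init((int32_t) tokens.size(), 0, 1);
        llama_batch_clear(batch);
        for (size_t i = 0; i < tokens.size(); i++) {
            llama_batch_add(batch, tokens[i], (llama_pos) i, {0}, i == tokens.size() - 1);
        }

        const bool ok = llama_decode(ctx, batch) == 0;
        llama_batch_free(batch);
        return ok;
    }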

+ 2206 - 0
llama/ggml-aarch64.c

@@ -0,0 +1,2206 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
+#include "ggml-quants.h"
+#include "ggml-impl.h"
+
+#include <math.h>
+#include <string.h>
+#include <assert.h>
+#include <float.h>
+#include <stdlib.h> // for qsort
+#include <stdio.h>  // for GGML_ASSERT
+
+#include "ggml-aarch64.h"
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic ignored "-Woverlength-strings"
+#elif defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+#define UNUSED GGML_UNUSED
+
+// Functions to create the interleaved data layout formats
+
+// interleave 4 block_q4_0s in blocks of blck_size_interleave
+// returns an interleaved block_q4_0x4
+// in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks
+// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave
+//
+// - in                  : an array of block_q4_0 pointers
+// - blck_size_interleave : the bytes of the block_q4_0 quants are interleaved in
+//                         blocks of blck_size_interleave bytes
+// - xor_mask            : the mask to convert the nibbles in the block_q4_0 quants
+//                         from bias offset form to pure sign form (this saves subtract
+//                         operations during unpacking)
+//
+static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
+    block_q4_0x4 out;
+
+    for (int i = 0; i < 4; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    for (int i = 0; i < QK4_0 * 2; i++) {
+        int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave;
+        int src_id = (i % (4 * blck_size_interleave)) / blck_size_interleave;
+        src_offset += (i % blck_size_interleave);
+
+        out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
+    }
+
+    return out;
+}
+
+// interleave 8 block_q4_0s in blocks of blck_size_interleave
+// returns an interleaved block_q4_0x8
+// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
+// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
+static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
+    block_q4_0x8 out;
+
+    for (int i = 0; i < 8; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    for (int i = 0; i < QK4_0 * 4; i++) {
+        int src_offset = (i / (8 * blck_size_interleave)) * blck_size_interleave;
+        int src_id = (i % (8 * blck_size_interleave)) / blck_size_interleave;
+        src_offset += (i % blck_size_interleave);
+
+        out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
+    }
+
+    return out;
+}
+
+void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(QK8_0 == 32);
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    block_q8_0x4 * restrict y = (block_q8_0x4 *) vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t srcv[4][8];
+    float id[4];
+
+    for (int i = 0; i < nb; i++) {
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);
+            for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
+
+            for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
+            for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
+            for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
+
+            const float amax = vmaxvq_f32(amaxv[0]);
+
+            const float d = amax / ((1 << 7) - 1);
+            id[row_iter] = d ? 1.0f / d : 0.0f;
+
+            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
+        }
+
+        for (int j = 0; j < 8; j++) {
+            float32x4_t v = vmulq_n_f32(srcv[0][j], id[0]);
+            int32x4_t vi = vcvtnq_s32_f32(v);
+            y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[1][j], id[1]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0);
+            y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1);
+            y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2);
+            y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[2][j], id[2]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0);
+            y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1);
+            y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2);
+            y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[3][j], id[3]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0);
+            y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1);
+            y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2);
+            y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3);
+        }
+    }
+#else
+    // scalar
+    const int blck_size_interleave = 4;
+    float srcv[4][QK8_0];
+    float id[4];
+
+    for (int i = 0; i < nb; i++) {
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            float amax = 0.0f; // absolute max
+
+            for (int j = 0; j < QK8_0; j++) {
+                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
+                amax = MAX(amax, fabsf(srcv[row_iter][j]));
+            }
+
+            const float d = amax / ((1 << 7) - 1);
+            id[row_iter] = d ? 1.0f / d : 0.0f;
+
+            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
+        }
+
+        for (int j = 0; j < QK8_0 * 4; j++) {
+            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
+            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
+            src_offset += (j % blck_size_interleave);
+
+            float x0 = srcv[src_id][src_offset] * id[src_id];
+            y[i].qs[j] = roundf(x0);
+        }
+    }
+#endif
+}
+
+void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(QK8_0 == 32);
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    block_q8_0x4 * restrict y = (block_q8_0x4 *) vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t srcv[4][8];
+    float id[4];
+
+    for (int i = 0; i < nb; i++) {
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);
+            for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
+
+            for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
+            for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
+            for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
+
+            const float amax = vmaxvq_f32(amaxv[0]);
+
+            const float d = amax / ((1 << 7) - 1);
+            id[row_iter] = d ? 1.0f / d : 0.0f;
+
+            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
+        }
+
+        for (int j = 0; j < 4; j++) {
+            float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]);
+            int32x4_t vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3);
+            v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[1][2 * j], id[1]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3);
+            v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[2][2 * j], id[2]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3);
+            v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[3][2 * j], id[3]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3);
+            v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3);
+        }
+    }
+#else
+    // scalar
+    const int blck_size_interleave = 8;
+    float srcv[4][QK8_0];
+    float id[4];
+
+    for (int i = 0; i < nb; i++) {
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            float amax = 0.0f; // absolute max
+
+            for (int j = 0; j < QK8_0; j++) {
+                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
+                amax = MAX(amax, fabsf(srcv[row_iter][j]));
+            }
+
+            const float d = amax / ((1 << 7) - 1);
+            id[row_iter] = d ? 1.0f / d : 0.0f;
+
+            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
+        }
+
+        for (int j = 0; j < QK8_0 * 4; j++) {
+            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
+            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
+            src_offset += (j % blck_size_interleave);
+
+            float x0 = srcv[src_id][src_offset] * id[src_id];
+            y[i].qs[j] = roundf(x0);
+        }
+    }
+#endif
+}
+
+void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
+    assert(nrow == 4);
+    UNUSED(nrow);
+    if (blck_size_interleave == 4) {
+        quantize_q8_0_4x4(x, vy, n_per_row);
+    } else if (blck_size_interleave == 8) {
+        quantize_q8_0_4x8(x, vy, n_per_row);
+    } else {
+        assert(false);
+    }
+}
+
+static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int blck_size_interleave) {
+    assert(n_per_row % QK4_0 == 0);
+    const int nb = n_per_row / QK4_0;
+
+    void * out_ptr = NULL;
+    if (nrows_interleaved == 8) {
+        out_ptr = (block_q4_0x8 *) dst;
+    }
+    else if (nrows_interleaved == 4) {
+        out_ptr = (block_q4_0x4 *) dst;
+    }
+    assert(nrows_interleaved <= 8);
+    block_q4_0 dst_tmp[8];
+
+    for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) {
+
+        for (int64_t x = 0; x < nb; x++) {
+
+            for (int i = 0; i < nrows_interleaved; i++) {
+                quantize_row_q4_0_ref(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0);
+            }
+
+            if (nrows_interleaved == 8) {
+                *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave, 0x88);
+                out_ptr = (block_q4_0x8 *) out_ptr + 1;
+            }
+            else if (nrows_interleaved == 4) {
+                *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave, 0x88);
+                out_ptr = (block_q4_0x4 *) out_ptr + 1;
+            }
+        }
+    }
+
+    return ((nrow * n_per_row) / QK4_0 * sizeof(block_q4_0));
+}
+
+size_t quantize_q4_0_4x4(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    UNUSED(quant_weights);
+    return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4);
+}
+
+size_t quantize_q4_0_4x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    UNUSED(quant_weights);
+    return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8);
+}
+
+size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    UNUSED(quant_weights);
+    return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8);
+}
+
+void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__ARM_FEATURE_SVE)
+    if (ggml_sve_cnt_b == QK8_0) {
+        GGML_ASSERT(!(ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
+                    "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
+    }
+#endif
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    GGML_ASSERT(!(ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) &&
+                "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance");
+#elif defined(__ARM_NEON) && defined(__aarch64__) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
+    const void * b_ptr = vx;
+    const void * a_ptr = vy;
+    float * res_ptr = s;
+
+    __asm__ __volatile__(
+        "movi v31.16b, #0x4\n"
+        "movi v30.16b, #0xf0\n"
+        "add %x[b_ptr], %x[b_ptr], #0x8\n"
+        "1:"  // Column loop
+        "add x22, %x[a_ptr], #0x2\n"
+        "movi v29.16b, #0x0\n"
+        "mov x21, %x[nb]\n"
+        "2:"  // Block loop
+        "ldr q28, [%x[b_ptr], #0x0]\n"
+        "ldr q27, [x22, #0x0]\n"
+        "movi v26.4s, #0x0\n"
+        "sub x20, x22, #0x2\n"
+        "ldr q25, [x22, #0x10]\n"
+        "ldr q24, [%x[b_ptr], #0x10]\n"
+        "sub x21, x21, #0x1\n"
+        "add x22, x22, #0x22\n"
+        "ldr q23, [%x[b_ptr], #0x20]\n"
+        "ldr q22, [%x[b_ptr], #0x30]\n"
+        "ld1r { v21.8h }, [x20]\n"
+        "ldr q20, [%x[b_ptr], #-0x8]\n"
+        "sshl v16.16b, v28.16b, v31.16b\n"
+        "and v28.16b, v28.16b, v30.16b\n"
+        "sshl v19.16b, v24.16b, v31.16b\n"
+        "and v24.16b, v24.16b, v30.16b\n"
+        "add %x[b_ptr], %x[b_ptr], #0x48\n"
+        "sshl v18.16b, v23.16b, v31.16b\n"
+        "and v23.16b, v23.16b, v30.16b\n"
+        ".inst 0x4f9be21a  // sdot v26.4s, v16.16b, v27.4b[0]\n"
+        "sshl v17.16b, v22.16b, v31.16b\n"
+        "and v22.16b, v22.16b, v30.16b\n"
+        "fcvtl v21.4s, v21.4h\n"
+        "fcvtl v16.4s, v20.4h\n"
+        ".inst 0x4f99e39a  // sdot v26.4s, v28.16b, v25.4b[0]\n"
+        "fmul v16.4s, v16.4s, v21.4s\n"
+        ".inst 0x4fbbe27a  // sdot v26.4s, v19.16b, v27.4b[1]\n"
+        ".inst 0x4fb9e31a  // sdot v26.4s, v24.16b, v25.4b[1]\n"
+        ".inst 0x4f9bea5a  // sdot v26.4s, v18.16b, v27.4b[2]\n"
+        ".inst 0x4f99eafa  // sdot v26.4s, v23.16b, v25.4b[2]\n"
+        ".inst 0x4fbbea3a  // sdot v26.4s, v17.16b, v27.4b[3]\n"
+        ".inst 0x4fb9eada  // sdot v26.4s, v22.16b, v25.4b[3]\n"
+        "scvtf v26.4s, v26.4s, #0x4\n"
+        "fmla v29.4s, v26.4s, v16.4s\n"
+        "cbnz x21, 2b\n"
+        "sub %x[nc], %x[nc], #0x4\n"
+        "str q29, [%x[res_ptr], #0x0]\n"
+        "add %x[res_ptr], %x[res_ptr], #0x10\n"
+        "cbnz %x[nc], 1b\n"
+        : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
+        : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
+        : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22"
+    );
+#else
+    float sumf[4];
+    int sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
+                    }
+                    sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+    }
+#endif
+}
+
+void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__ARM_FEATURE_SVE)
+    if (ggml_sve_cnt_b == QK8_0) {
+        GGML_ASSERT(!(ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
+                    "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
+    }
+#endif
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
+    const void * b_ptr = vx;
+    const void * a_ptr = vy;
+    float * res_ptr = s;
+
+    __asm__ __volatile__(
+        "movi v2.16b, #0x4\n"
+        "movi v1.16b, #0xf0\n"
+        "add %x[b_ptr], %x[b_ptr], #0x8\n"
+        "1:"  // Column loop
+        "add x23, %x[a_ptr], #0x2\n"
+        "movi v0.16b, #0x0\n"
+        "mov x22, %x[nb]\n"
+        "2:"  // Block loop
+        "ldr q31, [%x[b_ptr], #0x0]\n"
+        "ldr q30, [%x[b_ptr], #0x10]\n"
+        "mov x21, x23\n"
+        "movi v29.4s, #0x0\n"
+        "ldr q28, [%x[b_ptr], #0x20]\n"
+        "ldr q27, [%x[b_ptr], #0x30]\n"
+        "movi v26.4s, #0x0\n"
+        "sub x20, x23, #0x2\n"
+        "ld1r { v25.8h }, [x20]\n"
+        "ldr q24, [%x[b_ptr], #-0x8]\n"
+        "sub x22, x22, #0x1\n"
+        "add x23, x23, #0x22\n"
+        "ld1r { v23.2d }, [x21], #0x8\n"
+        "sshl v22.16b, v31.16b, v2.16b\n"
+        "sshl v16.16b, v30.16b, v2.16b\n"
+        "add %x[b_ptr], %x[b_ptr], #0x48\n"
+        "ld1r { v21.2d }, [x21], #0x8\n"
+        "sshl v20.16b, v28.16b, v2.16b\n"
+        "sshl v19.16b, v27.16b, v2.16b\n"
+        "ld1r { v18.2d }, [x21], #0x8\n"
+        "ld1r { v17.2d }, [x21], #0x8\n"
+        "and v31.16b, v31.16b, v1.16b\n"
+        "and v30.16b, v30.16b, v1.16b\n"
+        ".inst 0x4e9796dd  // sdot v29.4s, v22.16b, v23.16b\n"
+        ".inst 0x4e97961a  // sdot v26.4s, v16.16b, v23.16b\n"
+        "and v28.16b, v28.16b, v1.16b\n"
+        "and v27.16b, v27.16b, v1.16b\n"
+        "fcvtl v25.4s, v25.4h\n"
+        "fcvtl v16.4s, v24.4h\n"
+        ".inst 0x4e95969d  // sdot v29.4s, v20.16b, v21.16b\n"
+        ".inst 0x4e95967a  // sdot v26.4s, v19.16b, v21.16b\n"
+        "fmul v16.4s, v16.4s, v25.4s\n"
+        ".inst 0x4e9297fd  // sdot v29.4s, v31.16b, v18.16b\n"
+        ".inst 0x4e9297da  // sdot v26.4s, v30.16b, v18.16b\n"
+        ".inst 0x4e91979d  // sdot v29.4s, v28.16b, v17.16b\n"
+        ".inst 0x4e91977a  // sdot v26.4s, v27.16b, v17.16b\n"
+        "addp v29.4s, v29.4s, v26.4s\n"
+        "scvtf v29.4s, v29.4s, #0x4\n"
+        "fmla v0.4s, v29.4s, v16.4s\n"
+        "cbnz x22, 2b\n"
+        "sub %x[nc], %x[nc], #0x4\n"
+        "str q0, [%x[res_ptr], #0x0]\n"
+        "add %x[res_ptr], %x[res_ptr], #0x10\n"
+        "cbnz %x[nc], 1b\n"
+        : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
+        : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
+        : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23"
+    );
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
+                "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
+                "performance");
+#else
+    float sumf[4];
+    int sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
+                    }
+                    sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+    }
+#endif
+}
+
+void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__ARM_FEATURE_SVE) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
+    if (ggml_sve_cnt_b == QK8_0) {
+        const void * b_ptr = vx;
+        const void * a_ptr = vy;
+        float * res_ptr = s;
+
+        __asm__ __volatile__(
+            "ptrue p0.b\n"
+            "add %x[b_ptr], %x[b_ptr], #0x10\n"
+            "1:"  // Column loop
+            "add x22, %x[a_ptr], #0x2\n"
+            "mov z31.b, #0x0\n"
+            "mov x21, %x[nb]\n"
+            "2:"  // Block loop
+            "ld1b { z30.b }, p0/Z, [%x[b_ptr]]\n"
+            "ld1b { z29.b }, p0/Z, [%x[b_ptr], #1, MUL VL]\n"
+            "mov z28.s, #0x0\n"
+            "mov z27.s, #0x0\n"
+            "ld1rd { z26.d }, p0/Z, [x22]\n"
+            "ld1b { z25.b }, p0/Z, [%x[b_ptr], #2, MUL VL]\n"
+            "sub x20, x22, #0x2\n"
+            "sub x21, x21, #0x1\n"
+            "ld1b { z24.b }, p0/Z, [%x[b_ptr], #3, MUL VL]\n"
+            "ld1rd { z23.d }, p0/Z, [x22, #8]\n"
+            "lsl z22.b, z30.b, #0x4\n"
+            "lsl z16.b, z29.b, #0x4\n"
+            "and z30.b, z30.b, #0xf0\n"
+            "and z29.b, z29.b, #0xf0\n"
+            "ld1rd { z21.d }, p0/Z, [x22, #16]\n"
+            "ld1rd { z20.d }, p0/Z, [x22, #24]\n"
+            "lsl z19.b, z25.b, #0x4\n"
+            "and z25.b, z25.b, #0xf0\n"
+            "ld1rh { z17.h }, p0/Z, [x20]\n"
+            "ld1h { z18.s }, p0/Z, [%x[b_ptr], #-1, MUL VL]\n"
+            "sdot z28.s, z22.b, z26.b\n"
+            "sdot z27.s, z16.b, z26.b\n"
+            "lsl z16.b, z24.b, #0x4\n"
+            "add x22, x22, #0x22\n"
+            "and z24.b, z24.b, #0xf0\n"
+            "add %x[b_ptr], %x[b_ptr], #0x90\n"
+            "fcvt z17.s, p0/m, z17.h\n"
+            "fcvt z18.s, p0/m, z18.h\n"
+            "sdot z28.s, z19.b, z23.b\n"
+            "sdot z27.s, z16.b, z23.b\n"
+            "fmul z18.s, z18.s, z17.s\n"
+            "sdot z28.s, z30.b, z21.b\n"
+            "sdot z27.s, z29.b, z21.b\n"
+            "sdot z28.s, z25.b, z20.b\n"
+            "sdot z27.s, z24.b, z20.b\n"
+            "uzp1 z17.s, z28.s, z27.s\n"
+            "uzp2 z16.s, z28.s, z27.s\n"
+            "add z17.s, z17.s, z16.s\n"
+            "asr z17.s, z17.s, #0x4\n"
+            "scvtf z17.s, p0/m, z17.s\n"
+            "fmla z31.s, p0/M, z17.s, z18.s\n"
+            "cbnz x21, 2b\n"
+            "sub %x[nc], %x[nc], #0x8\n"
+            "st1w { z31.s }, p0, [%x[res_ptr]]\n"
+            "add %x[res_ptr], %x[res_ptr], #0x20\n"
+            "cbnz %x[nc], 1b\n"
+            : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
+            : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
+            : "memory", "p0", "x20", "x21", "x22", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
+        );
+        return;
+    }
+    else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+        GGML_ASSERT((ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
+                    "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal "
+                    "performance");
+    }
+    else if (ggml_cpu_has_neon()) {
+        GGML_ASSERT(((ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) || ggml_cpu_has_matmul_int8()) &&
+                    "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 "
+                    "quantization format for optimal performance");
+    }
+#endif
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    GGML_ASSERT(ggml_cpu_has_sve() &&
+                "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance");
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
+                "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
+                "performance");
+#else
+    float sumf[8];
+    int sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
+                    }
+                    sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+    }
+#endif
+}
+
+void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+    if (ggml_sve_cnt_b == QK8_0) {
+        GGML_ASSERT(!(ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
+                    "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
+    }
+#endif
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    GGML_ASSERT(!(ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) &&
+                "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance");
+#elif defined(__ARM_NEON) && defined(__aarch64__) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
+    const void * b_ptr = vx;
+    const void * a_ptr = vy;
+    float * res_ptr = s;
+    size_t res_stride = bs * sizeof(float);
+
+    __asm__ __volatile__(
+        "mov x10, %x[nr]\n"
+        "mov x9, #0x88\n"
+        "cmp x10, #0x10\n"
+        "mul x9, %x[nb], x9\n"
+        "blt 4f\n"
+        "1:"  // Row loop
+        "add x28, %x[b_ptr], #0x8\n"
+        "mov x27, %x[nc]\n"
+        "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
+        "2:"  // Column loop
+        "add x25, %x[a_ptr], #0x8\n"
+        "movi v15.16b, #0x0\n"
+        "movi v19.16b, #0x0\n"
+        "mov x24, %x[nb]\n"
+        "add x23, x25, x9\n"
+        "movi v18.16b, #0x0\n"
+        "movi v14.16b, #0x0\n"
+        "add x22, x23, x9\n"
+        "movi v11.16b, #0x0\n"
+        "movi v13.16b, #0x0\n"
+        "add x21, x22, x9\n"
+        "movi v23.16b, #0x0\n"
+        "movi v16.16b, #0x0\n"
+        "movi v25.16b, #0x0\n"
+        "movi v7.16b, #0x0\n"
+        "movi v0.16b, #0x0\n"
+        "movi v4.16b, #0x0\n"
+        "movi v5.16b, #0x0\n"
+        "movi v21.16b, #0x0\n"
+        "movi v8.16b, #0x0\n"
+        "movi v1.16b, #0x0\n"
+        "3:"  // Block loop
+        "ldr q3, [x28, #0x0]\n"
+        "ldr q31, [x25, #0x0]\n"
+        "movi v28.16b, #0x4\n"
+        "movi v10.4s, #0x0\n"
+        "ldr q22, [x28, #0x10]\n"
+        "ldr q6, [x25, #0x10]\n"
+        "movi v29.4s, #0x0\n"
+        "movi v9.4s, #0x0\n"
+        "ldr q27, [x28, #0x20]\n"
+        "ldr q30, [x28, #0x30]\n"
+        "movi v20.4s, #0x0\n"
+        "movi v24.16b, #0xf0\n"
+        "ldr d2, [x25, #-0x8]\n"
+        "ldr d26, [x23, #-0x8]\n"
+        "sshl v12.16b, v3.16b, v28.16b\n"
+        "sub x20, x28, #0x8\n"
+        "ldr d17, [x20, #0x0]\n"
+        "and v3.16b, v3.16b, v24.16b\n"
+        "subs x24, x24, #0x1\n"
+        "add x28, x28, #0x48\n"
+        ".inst 0x4f9fe18a  // sdot v10.4s, v12.16b, v31.4b[0]\n"
+        ".inst 0x4fbfe19d  // sdot v29.4s, v12.16b, v31.4b[1]\n"
+        ".inst 0x4f9fe989  // sdot v9.4s, v12.16b, v31.4b[2]\n"
+        ".inst 0x4fbfe994  // sdot v20.4s, v12.16b, v31.4b[3]\n"
+        "sshl v31.16b, v22.16b, v28.16b\n"
+        "and v22.16b, v22.16b, v24.16b\n"
+        "fcvtl v17.4s, v17.4h\n"
+        "fcvtl v2.4s, v2.4h\n"
+        "fcvtl v26.4s, v26.4h\n"
+        ".inst 0x4f86e3ea  // sdot v10.4s, v31.16b, v6.4b[0]\n"
+        ".inst 0x4fa6e3fd  // sdot v29.4s, v31.16b, v6.4b[1]\n"
+        ".inst 0x4f86ebe9  // sdot v9.4s, v31.16b, v6.4b[2]\n"
+        ".inst 0x4fa6ebf4  // sdot v20.4s, v31.16b, v6.4b[3]\n"
+        "sshl v6.16b, v27.16b, v28.16b\n"
+        "sshl v28.16b, v30.16b, v28.16b\n"
+        "and v27.16b, v27.16b, v24.16b\n"
+        "and v30.16b, v30.16b, v24.16b\n"
+        "ldr q24, [x25, #0x20]\n"
+        ".inst 0x4f98e0ca  // sdot v10.4s, v6.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e0dd  // sdot v29.4s, v6.16b, v24.4b[1]\n"
+        ".inst 0x4f98e8c9  // sdot v9.4s, v6.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e8d4  // sdot v20.4s, v6.16b, v24.4b[3]\n"
+        "ldr q24, [x25, #0x30]\n"
+        ".inst 0x4f98e38a  // sdot v10.4s, v28.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e39d  // sdot v29.4s, v28.16b, v24.4b[1]\n"
+        ".inst 0x4f98eb89  // sdot v9.4s, v28.16b, v24.4b[2]\n"
+        ".inst 0x4fb8eb94  // sdot v20.4s, v28.16b, v24.4b[3]\n"
+        "ldr q24, [x25, #0x40]\n"
+        ".inst 0x4f98e06a  // sdot v10.4s, v3.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e07d  // sdot v29.4s, v3.16b, v24.4b[1]\n"
+        ".inst 0x4f98e869  // sdot v9.4s, v3.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e874  // sdot v20.4s, v3.16b, v24.4b[3]\n"
+        "ldr q24, [x25, #0x50]\n"
+        ".inst 0x4f98e2ca  // sdot v10.4s, v22.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e2dd  // sdot v29.4s, v22.16b, v24.4b[1]\n"
+        ".inst 0x4f98eac9  // sdot v9.4s, v22.16b, v24.4b[2]\n"
+        ".inst 0x4fb8ead4  // sdot v20.4s, v22.16b, v24.4b[3]\n"
+        "ldr q24, [x25, #0x60]\n"
+        ".inst 0x4f98e36a  // sdot v10.4s, v27.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e37d  // sdot v29.4s, v27.16b, v24.4b[1]\n"
+        ".inst 0x4f98eb69  // sdot v9.4s, v27.16b, v24.4b[2]\n"
+        ".inst 0x4fb8eb74  // sdot v20.4s, v27.16b, v24.4b[3]\n"
+        "ldr q24, [x25, #0x70]\n"
+        "add x25, x25, #0x88\n"
+        ".inst 0x4f98e3ca  // sdot v10.4s, v30.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e3dd  // sdot v29.4s, v30.16b, v24.4b[1]\n"
+        ".inst 0x4f98ebc9  // sdot v9.4s, v30.16b, v24.4b[2]\n"
+        ".inst 0x4fb8ebd4  // sdot v20.4s, v30.16b, v24.4b[3]\n"
+        "fmul v24.4s, v17.4s, v2.s[0]\n"
+        "scvtf v10.4s, v10.4s, #0x4\n"
+        "scvtf v29.4s, v29.4s, #0x4\n"
+        "scvtf v9.4s, v9.4s, #0x4\n"
+        "scvtf v20.4s, v20.4s, #0x4\n"
+        "fmla v15.4s, v10.4s, v24.4s\n"
+        "ldr q24, [x23, #0x0]\n"
+        "fmul v10.4s, v17.4s, v2.s[1]\n"
+        "fmla v19.4s, v29.4s, v10.4s\n"
+        "ldr q10, [x23, #0x10]\n"
+        "fmul v29.4s, v17.4s, v2.s[2]\n"
+        "fmul v2.4s, v17.4s, v2.s[3]\n"
+        "fmla v18.4s, v9.4s, v29.4s\n"
+        "movi v9.4s, #0x0\n"
+        "movi v29.4s, #0x0\n"
+        ".inst 0x4f98e189  // sdot v9.4s, v12.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e19d  // sdot v29.4s, v12.16b, v24.4b[1]\n"
+        "fmla v14.4s, v20.4s, v2.4s\n"
+        "movi v20.4s, #0x0\n"
+        "movi v2.4s, #0x0\n"
+        ".inst 0x4f98e994  // sdot v20.4s, v12.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e982  // sdot v2.4s, v12.16b, v24.4b[3]\n"
+        "ldr q24, [x23, #0x20]\n"
+        ".inst 0x4f8ae3e9  // sdot v9.4s, v31.16b, v10.4b[0]\n"
+        ".inst 0x4faae3fd  // sdot v29.4s, v31.16b, v10.4b[1]\n"
+        ".inst 0x4f8aebf4  // sdot v20.4s, v31.16b, v10.4b[2]\n"
+        ".inst 0x4faaebe2  // sdot v2.4s, v31.16b, v10.4b[3]\n"
+        "ldr q10, [x23, #0x30]\n"
+        ".inst 0x4f98e0c9  // sdot v9.4s, v6.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e0dd  // sdot v29.4s, v6.16b, v24.4b[1]\n"
+        ".inst 0x4f98e8d4  // sdot v20.4s, v6.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e8c2  // sdot v2.4s, v6.16b, v24.4b[3]\n"
+        "ldr q24, [x23, #0x40]\n"
+        ".inst 0x4f8ae389  // sdot v9.4s, v28.16b, v10.4b[0]\n"
+        ".inst 0x4faae39d  // sdot v29.4s, v28.16b, v10.4b[1]\n"
+        ".inst 0x4f8aeb94  // sdot v20.4s, v28.16b, v10.4b[2]\n"
+        ".inst 0x4faaeb82  // sdot v2.4s, v28.16b, v10.4b[3]\n"
+        "ldr q10, [x23, #0x50]\n"
+        ".inst 0x4f98e069  // sdot v9.4s, v3.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e07d  // sdot v29.4s, v3.16b, v24.4b[1]\n"
+        ".inst 0x4f98e874  // sdot v20.4s, v3.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e862  // sdot v2.4s, v3.16b, v24.4b[3]\n"
+        "ldr q24, [x23, #0x60]\n"
+        ".inst 0x4f8ae2c9  // sdot v9.4s, v22.16b, v10.4b[0]\n"
+        ".inst 0x4faae2dd  // sdot v29.4s, v22.16b, v10.4b[1]\n"
+        ".inst 0x4f8aead4  // sdot v20.4s, v22.16b, v10.4b[2]\n"
+        ".inst 0x4faaeac2  // sdot v2.4s, v22.16b, v10.4b[3]\n"
+        "ldr q10, [x23, #0x70]\n"
+        "add x23, x23, #0x88\n"
+        ".inst 0x4f98e369  // sdot v9.4s, v27.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e37d  // sdot v29.4s, v27.16b, v24.4b[1]\n"
+        ".inst 0x4f98eb74  // sdot v20.4s, v27.16b, v24.4b[2]\n"
+        ".inst 0x4fb8eb62  // sdot v2.4s, v27.16b, v24.4b[3]\n"
+        "ldr q24, [x22, #0x0]\n"
+        ".inst 0x4f8ae3c9  // sdot v9.4s, v30.16b, v10.4b[0]\n"
+        ".inst 0x4faae3dd  // sdot v29.4s, v30.16b, v10.4b[1]\n"
+        ".inst 0x4f8aebd4  // sdot v20.4s, v30.16b, v10.4b[2]\n"
+        ".inst 0x4faaebc2  // sdot v2.4s, v30.16b, v10.4b[3]\n"
+        "fmul v10.4s, v17.4s, v26.s[0]\n"
+        "scvtf v9.4s, v9.4s, #0x4\n"
+        "scvtf v29.4s, v29.4s, #0x4\n"
+        "scvtf v20.4s, v20.4s, #0x4\n"
+        "scvtf v2.4s, v2.4s, #0x4\n"
+        "fmla v11.4s, v9.4s, v10.4s\n"
+        "ldr q9, [x22, #0x10]\n"
+        "fmul v10.4s, v17.4s, v26.s[1]\n"
+        "fmla v13.4s, v29.4s, v10.4s\n"
+        "ldr d29, [x22, #-0x8]\n"
+        "fmul v10.4s, v17.4s, v26.s[2]\n"
+        "fmul v26.4s, v17.4s, v26.s[3]\n"
+        "fcvtl v29.4s, v29.4h\n"
+        "fmla v23.4s, v20.4s, v10.4s\n"
+        "movi v20.4s, #0x0\n"
+        "movi v10.4s, #0x0\n"
+        "fmla v16.4s, v2.4s, v26.4s\n"
+        "movi v26.4s, #0x0\n"
+        "movi v2.4s, #0x0\n"
+        ".inst 0x4f98e194  // sdot v20.4s, v12.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e18a  // sdot v10.4s, v12.16b, v24.4b[1]\n"
+        ".inst 0x4f98e99a  // sdot v26.4s, v12.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e982  // sdot v2.4s, v12.16b, v24.4b[3]\n"
+        "ldr q24, [x22, #0x20]\n"
+        ".inst 0x4f89e3f4  // sdot v20.4s, v31.16b, v9.4b[0]\n"
+        ".inst 0x4fa9e3ea  // sdot v10.4s, v31.16b, v9.4b[1]\n"
+        ".inst 0x4f89ebfa  // sdot v26.4s, v31.16b, v9.4b[2]\n"
+        ".inst 0x4fa9ebe2  // sdot v2.4s, v31.16b, v9.4b[3]\n"
+        "ldr q9, [x22, #0x30]\n"
+        ".inst 0x4f98e0d4  // sdot v20.4s, v6.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e0ca  // sdot v10.4s, v6.16b, v24.4b[1]\n"
+        ".inst 0x4f98e8da  // sdot v26.4s, v6.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e8c2  // sdot v2.4s, v6.16b, v24.4b[3]\n"
+        "ldr q24, [x22, #0x40]\n"
+        ".inst 0x4f89e394  // sdot v20.4s, v28.16b, v9.4b[0]\n"
+        ".inst 0x4fa9e38a  // sdot v10.4s, v28.16b, v9.4b[1]\n"
+        ".inst 0x4f89eb9a  // sdot v26.4s, v28.16b, v9.4b[2]\n"
+        ".inst 0x4fa9eb82  // sdot v2.4s, v28.16b, v9.4b[3]\n"
+        "ldr q9, [x22, #0x50]\n"
+        ".inst 0x4f98e074  // sdot v20.4s, v3.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e06a  // sdot v10.4s, v3.16b, v24.4b[1]\n"
+        ".inst 0x4f98e87a  // sdot v26.4s, v3.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e862  // sdot v2.4s, v3.16b, v24.4b[3]\n"
+        "ldr q24, [x22, #0x60]\n"
+        ".inst 0x4f89e2d4  // sdot v20.4s, v22.16b, v9.4b[0]\n"
+        ".inst 0x4fa9e2ca  // sdot v10.4s, v22.16b, v9.4b[1]\n"
+        ".inst 0x4f89eada  // sdot v26.4s, v22.16b, v9.4b[2]\n"
+        ".inst 0x4fa9eac2  // sdot v2.4s, v22.16b, v9.4b[3]\n"
+        "ldr q9, [x22, #0x70]\n"
+        "add x22, x22, #0x88\n"
+        ".inst 0x4f98e374  // sdot v20.4s, v27.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e36a  // sdot v10.4s, v27.16b, v24.4b[1]\n"
+        ".inst 0x4f98eb7a  // sdot v26.4s, v27.16b, v24.4b[2]\n"
+        ".inst 0x4fb8eb62  // sdot v2.4s, v27.16b, v24.4b[3]\n"
+        "ldr q24, [x21, #0x0]\n"
+        ".inst 0x4f89e3d4  // sdot v20.4s, v30.16b, v9.4b[0]\n"
+        ".inst 0x4fa9e3ca  // sdot v10.4s, v30.16b, v9.4b[1]\n"
+        ".inst 0x4f89ebda  // sdot v26.4s, v30.16b, v9.4b[2]\n"
+        ".inst 0x4fa9ebc2  // sdot v2.4s, v30.16b, v9.4b[3]\n"
+        "fmul v9.4s, v17.4s, v29.s[0]\n"
+        "scvtf v20.4s, v20.4s, #0x4\n"
+        "scvtf v10.4s, v10.4s, #0x4\n"
+        "scvtf v26.4s, v26.4s, #0x4\n"
+        "scvtf v2.4s, v2.4s, #0x4\n"
+        "fmla v25.4s, v20.4s, v9.4s\n"
+        "ldr q9, [x21, #0x10]\n"
+        "fmul v20.4s, v17.4s, v29.s[1]\n"
+        "fmla v7.4s, v10.4s, v20.4s\n"
+        "ldr d20, [x21, #-0x8]\n"
+        "fmul v10.4s, v17.4s, v29.s[2]\n"
+        "fmul v29.4s, v17.4s, v29.s[3]\n"
+        "fcvtl v20.4s, v20.4h\n"
+        "fmla v0.4s, v26.4s, v10.4s\n"
+        "movi v26.4s, #0x0\n"
+        "movi v10.4s, #0x0\n"
+        "fmla v4.4s, v2.4s, v29.4s\n"
+        "movi v2.4s, #0x0\n"
+        "movi v29.4s, #0x0\n"
+        ".inst 0x4f98e19a  // sdot v26.4s, v12.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e18a  // sdot v10.4s, v12.16b, v24.4b[1]\n"
+        ".inst 0x4f98e982  // sdot v2.4s, v12.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e99d  // sdot v29.4s, v12.16b, v24.4b[3]\n"
+        "ldr q12, [x21, #0x20]\n"
+        "fmul v24.4s, v17.4s, v20.s[0]\n"
+        ".inst 0x4f89e3fa  // sdot v26.4s, v31.16b, v9.4b[0]\n"
+        ".inst 0x4fa9e3ea  // sdot v10.4s, v31.16b, v9.4b[1]\n"
+        ".inst 0x4f89ebe2  // sdot v2.4s, v31.16b, v9.4b[2]\n"
+        ".inst 0x4fa9ebfd  // sdot v29.4s, v31.16b, v9.4b[3]\n"
+        "ldr q9, [x21, #0x30]\n"
+        "fmul v31.4s, v17.4s, v20.s[1]\n"
+        ".inst 0x4f8ce0da  // sdot v26.4s, v6.16b, v12.4b[0]\n"
+        ".inst 0x4face0ca  // sdot v10.4s, v6.16b, v12.4b[1]\n"
+        ".inst 0x4f8ce8c2  // sdot v2.4s, v6.16b, v12.4b[2]\n"
+        ".inst 0x4face8dd  // sdot v29.4s, v6.16b, v12.4b[3]\n"
+        "ldr q12, [x21, #0x40]\n"
+        "fmul v6.4s, v17.4s, v20.s[2]\n"
+        "fmul v20.4s, v17.4s, v20.s[3]\n"
+        ".inst 0x4f89e39a  // sdot v26.4s, v28.16b, v9.4b[0]\n"
+        ".inst 0x4fa9e38a  // sdot v10.4s, v28.16b, v9.4b[1]\n"
+        ".inst 0x4f89eb82  // sdot v2.4s, v28.16b, v9.4b[2]\n"
+        ".inst 0x4fa9eb9d  // sdot v29.4s, v28.16b, v9.4b[3]\n"
+        "ldr q9, [x21, #0x50]\n"
+        ".inst 0x4f8ce07a  // sdot v26.4s, v3.16b, v12.4b[0]\n"
+        ".inst 0x4face06a  // sdot v10.4s, v3.16b, v12.4b[1]\n"
+        ".inst 0x4f8ce862  // sdot v2.4s, v3.16b, v12.4b[2]\n"
+        ".inst 0x4face87d  // sdot v29.4s, v3.16b, v12.4b[3]\n"
+        "ldr q12, [x21, #0x60]\n"
+        ".inst 0x4f89e2da  // sdot v26.4s, v22.16b, v9.4b[0]\n"
+        ".inst 0x4fa9e2ca  // sdot v10.4s, v22.16b, v9.4b[1]\n"
+        ".inst 0x4f89eac2  // sdot v2.4s, v22.16b, v9.4b[2]\n"
+        ".inst 0x4fa9eadd  // sdot v29.4s, v22.16b, v9.4b[3]\n"
+        "ldr q17, [x21, #0x70]\n"
+        "add x21, x21, #0x88\n"
+        ".inst 0x4f8ce37a  // sdot v26.4s, v27.16b, v12.4b[0]\n"
+        ".inst 0x4face36a  // sdot v10.4s, v27.16b, v12.4b[1]\n"
+        ".inst 0x4f8ceb62  // sdot v2.4s, v27.16b, v12.4b[2]\n"
+        ".inst 0x4faceb7d  // sdot v29.4s, v27.16b, v12.4b[3]\n"
+        ".inst 0x4f91e3da  // sdot v26.4s, v30.16b, v17.4b[0]\n"
+        ".inst 0x4fb1e3ca  // sdot v10.4s, v30.16b, v17.4b[1]\n"
+        ".inst 0x4f91ebc2  // sdot v2.4s, v30.16b, v17.4b[2]\n"
+        ".inst 0x4fb1ebdd  // sdot v29.4s, v30.16b, v17.4b[3]\n"
+        "scvtf v26.4s, v26.4s, #0x4\n"
+        "scvtf v10.4s, v10.4s, #0x4\n"
+        "fmla v5.4s, v26.4s, v24.4s\n"
+        "scvtf v2.4s, v2.4s, #0x4\n"
+        "scvtf v29.4s, v29.4s, #0x4\n"
+        "fmla v21.4s, v10.4s, v31.4s\n"
+        "fmla v8.4s, v2.4s, v6.4s\n"
+        "fmla v1.4s, v29.4s, v20.4s\n"
+        "bgt 3b\n"
+        "mov x20, %x[res_ptr]\n"
+        "subs x27, x27, #0x4\n"
+        "add %x[res_ptr], %x[res_ptr], #0x10\n"
+        "str q15, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q19, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q18, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q14, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q11, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q13, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q23, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q16, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q25, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q7, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q0, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q4, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q5, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q21, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q8, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q1, [x20, #0x0]\n"
+        "bne 2b\n"
+        "mov x20, #0x4\n"
+        "sub x10, x10, #0x10\n"
+        "cmp x10, #0x10\n"
+        "mov %x[res_ptr], x26\n"
+        "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
+        "bge 1b\n"
+        "4:"  // Row loop skip
+        "cbz x10, 9f\n"
+        "5:"  // Row tail: Row loop
+        "add x24, %x[b_ptr], #0x8\n"
+        "mov x23, %x[nc]\n"
+        "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
+        "6:"  // Row tail: Column loop
+        "movi v15.16b, #0x0\n"
+        "movi v19.16b, #0x0\n"
+        "add x25, %x[a_ptr], #0x8\n"
+        "mov x21, %x[nb]\n"
+        "movi v18.16b, #0x0\n"
+        "movi v14.16b, #0x0\n"
+        "7:"  // Row tail: Block loop
+        "ldr q7, [x24, #0x0]\n"
+        "ldr q5, [x25, #0x0]\n"
+        "movi v9.16b, #0x4\n"
+        "movi v4.4s, #0x0\n"
+        "ldr q3, [x24, #0x10]\n"
+        "ldr q2, [x25, #0x10]\n"
+        "movi v1.4s, #0x0\n"
+        "movi v0.4s, #0x0\n"
+        "ldr q13, [x24, #0x20]\n"
+        "ldr q31, [x25, #0x20]\n"
+        "movi v30.4s, #0x0\n"
+        "movi v29.16b, #0xf0\n"
+        "ldr q28, [x24, #0x30]\n"
+        "ldr q27, [x25, #0x30]\n"
+        "sshl v20.16b, v7.16b, v9.16b\n"
+        "sub x20, x24, #0x8\n"
+        "ldr q26, [x25, #0x40]\n"
+        "ldr q25, [x25, #0x50]\n"
+        "sshl v17.16b, v3.16b, v9.16b\n"
+        "and v7.16b, v7.16b, v29.16b\n"
+        "ldr q24, [x25, #0x60]\n"
+        "ldr q16, [x25, #0x70]\n"
+        "sshl v22.16b, v13.16b, v9.16b\n"
+        "and v3.16b, v3.16b, v29.16b\n"
+        "ldr d21, [x20, #0x0]\n"
+        "ldr d12, [x25, #-0x8]\n"
+        ".inst 0x4f85e284  // sdot v4.4s, v20.16b, v5.4b[0]\n"
+        ".inst 0x4fa5e281  // sdot v1.4s, v20.16b, v5.4b[1]\n"
+        ".inst 0x4f85ea80  // sdot v0.4s, v20.16b, v5.4b[2]\n"
+        ".inst 0x4fa5ea9e  // sdot v30.4s, v20.16b, v5.4b[3]\n"
+        "sshl v9.16b, v28.16b, v9.16b\n"
+        "subs x21, x21, #0x1\n"
+        "and v13.16b, v13.16b, v29.16b\n"
+        "and v28.16b, v28.16b, v29.16b\n"
+        "add x25, x25, #0x88\n"
+        "add x24, x24, #0x48\n"
+        "fcvtl v21.4s, v21.4h\n"
+        "fcvtl v12.4s, v12.4h\n"
+        ".inst 0x4f82e224  // sdot v4.4s, v17.16b, v2.4b[0]\n"
+        ".inst 0x4fa2e221  // sdot v1.4s, v17.16b, v2.4b[1]\n"
+        ".inst 0x4f82ea20  // sdot v0.4s, v17.16b, v2.4b[2]\n"
+        ".inst 0x4fa2ea3e  // sdot v30.4s, v17.16b, v2.4b[3]\n"
+        "fmul v11.4s, v21.4s, v12.s[0]\n"
+        "fmul v23.4s, v21.4s, v12.s[1]\n"
+        "fmul v17.4s, v21.4s, v12.s[2]\n"
+        ".inst 0x4f9fe2c4  // sdot v4.4s, v22.16b, v31.4b[0]\n"
+        "fmul v6.4s, v21.4s, v12.s[3]\n"
+        ".inst 0x4fbfe2c1  // sdot v1.4s, v22.16b, v31.4b[1]\n"
+        ".inst 0x4f9feac0  // sdot v0.4s, v22.16b, v31.4b[2]\n"
+        ".inst 0x4fbfeade  // sdot v30.4s, v22.16b, v31.4b[3]\n"
+        ".inst 0x4f9be124  // sdot v4.4s, v9.16b, v27.4b[0]\n"
+        ".inst 0x4fbbe121  // sdot v1.4s, v9.16b, v27.4b[1]\n"
+        ".inst 0x4f9be920  // sdot v0.4s, v9.16b, v27.4b[2]\n"
+        ".inst 0x4fbbe93e  // sdot v30.4s, v9.16b, v27.4b[3]\n"
+        ".inst 0x4f9ae0e4  // sdot v4.4s, v7.16b, v26.4b[0]\n"
+        ".inst 0x4fbae0e1  // sdot v1.4s, v7.16b, v26.4b[1]\n"
+        ".inst 0x4f9ae8e0  // sdot v0.4s, v7.16b, v26.4b[2]\n"
+        ".inst 0x4fbae8fe  // sdot v30.4s, v7.16b, v26.4b[3]\n"
+        ".inst 0x4f99e064  // sdot v4.4s, v3.16b, v25.4b[0]\n"
+        ".inst 0x4fb9e061  // sdot v1.4s, v3.16b, v25.4b[1]\n"
+        ".inst 0x4f99e860  // sdot v0.4s, v3.16b, v25.4b[2]\n"
+        ".inst 0x4fb9e87e  // sdot v30.4s, v3.16b, v25.4b[3]\n"
+        ".inst 0x4f98e1a4  // sdot v4.4s, v13.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e1a1  // sdot v1.4s, v13.16b, v24.4b[1]\n"
+        ".inst 0x4f98e9a0  // sdot v0.4s, v13.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e9be  // sdot v30.4s, v13.16b, v24.4b[3]\n"
+        ".inst 0x4f90e384  // sdot v4.4s, v28.16b, v16.4b[0]\n"
+        ".inst 0x4fb0e381  // sdot v1.4s, v28.16b, v16.4b[1]\n"
+        ".inst 0x4f90eb80  // sdot v0.4s, v28.16b, v16.4b[2]\n"
+        ".inst 0x4fb0eb9e  // sdot v30.4s, v28.16b, v16.4b[3]\n"
+        "scvtf v4.4s, v4.4s, #0x4\n"
+        "scvtf v1.4s, v1.4s, #0x4\n"
+        "scvtf v0.4s, v0.4s, #0x4\n"
+        "fmla v15.4s, v4.4s, v11.4s\n"
+        "scvtf v30.4s, v30.4s, #0x4\n"
+        "fmla v19.4s, v1.4s, v23.4s\n"
+        "fmla v18.4s, v0.4s, v17.4s\n"
+        "fmla v14.4s, v30.4s, v6.4s\n"
+        "bgt 7b\n"
+        "mov x20, %x[res_ptr]\n"
+        "cmp x10, #0x1\n"
+        "str q15, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "ble 8f\n"
+        "cmp x10, #0x2\n"
+        "str q19, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "ble 8f\n"
+        "cmp x10, #0x3\n"
+        "str q18, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "ble 8f\n"
+        "str q14, [x20, #0x0]\n"
+        "8:"  // Row tail: Accumulator store skip
+        "subs x23, x23, #0x4\n"
+        "add %x[res_ptr], %x[res_ptr], #0x10\n"
+        "bne 6b\n"
+        "subs x10, x10, #0x4\n"
+        "add %x[a_ptr], %x[a_ptr], x9\n"
+        "mov %x[res_ptr], x22\n"
+        "bgt 5b\n"
+        "9:"  // Row tail: Row loop skip
+        : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
+        : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
+        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+    );
+#else
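+    // Reference scalar fallback, used when the NEON dotprod kernel above is
+    // compiled out. The shift trick splits each packed byte: (int8_t)(qs << 4)
+    // sign-extends the low nibble scaled by 16, (qs & 0xF0) keeps the high
+    // nibble in place, and the >> 4 after the multiply-accumulate removes that
+    // factor of 16 before the fp16 deltas are applied.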
+    float sumf[4][4];
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+                            }
+                            sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++)
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+            }
+        }
+    }
+#endif
+}
+
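+// Q4_0 weights interleaved 4 columns wide with an 8-byte blocklen
+// (block_q4_0x4) multiplied against 4-row Q8_0 activation groups
+// (block_q8_0x4); writes an nr x nc float tile to s with row stride bs,
+// as the scalar fallback below spells out.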
+void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
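+    // n must be a whole number of Q8_0 blocks, nr a multiple of the 4-row
+    // activation grouping, and nc a multiple of the interleave width.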
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+    if (ggml_sve_cnt_b == QK8_0) {
+        GGML_ASSERT(!(ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
+                    "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
+    }
+#endif
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
+    const void * b_ptr = vx;
+    const void * a_ptr = vy;
+    float * res_ptr = s;
+    size_t res_stride = bs * sizeof(float);
+
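+    // i8mm path: the encoded .inst words below are smmla instructions, each
+    // multiplying a 2x8 int8 tile by an 8x2 int8 tile into 2x2 int32
+    // accumulators; uzp1/uzp2 then split the paired rows apart before the
+    // fixed-point scvtf and per-column delta scaling.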
+    __asm__ __volatile__(
+        "mov x10, %x[nr]\n"
+        "mov x9, #0x88\n"
+        "cmp x10, #0x10\n"
+        "mul x9, %x[nb], x9\n"
+        "blt 4f\n"
+        "1:"  // Row loop
+        "add x28, %x[b_ptr], #0x8\n"
+        "mov x27, %x[nc]\n"
+        "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
+        "2:"  // Column loop
+        "add x25, %x[a_ptr], #0x8\n"
+        "movi v2.16b, #0x0\n"
+        "movi v10.16b, #0x0\n"
+        "mov x24, %x[nb]\n"
+        "add x23, x25, x9\n"
+        "movi v12.16b, #0x0\n"
+        "movi v28.16b, #0x0\n"
+        "add x22, x23, x9\n"
+        "movi v11.16b, #0x0\n"
+        "movi v13.16b, #0x0\n"
+        "add x21, x22, x9\n"
+        "movi v22.16b, #0x0\n"
+        "movi v23.16b, #0x0\n"
+        "movi v25.16b, #0x0\n"
+        "movi v5.16b, #0x0\n"
+        "movi v7.16b, #0x0\n"
+        "movi v4.16b, #0x0\n"
+        "movi v6.16b, #0x0\n"
+        "movi v30.16b, #0x0\n"
+        "movi v24.16b, #0x0\n"
+        "movi v14.16b, #0x0\n"
+        "3:"  // Block loop
+        "ldr q21, [x28, #0x0]\n"
+        "ldr q16, [x28, #0x10]\n"
+        "movi v1.16b, #0x4\n"
+        "movi v19.4s, #0x0\n"
+        "ldr q27, [x25, #0x0]\n"
+        "ldr q15, [x25, #0x10]\n"
+        "movi v26.4s, #0x0\n"
+        "movi v18.4s, #0x0\n"
+        "ldr q29, [x28, #0x20]\n"
+        "ldr q3, [x28, #0x30]\n"
+        "movi v17.4s, #0x0\n"
+        "movi v0.16b, #0xf0\n"
+        "ldr d20, [x25, #-0x8]\n"
+        "ldr d9, [x23, #-0x8]\n"
+        "sshl v8.16b, v21.16b, v1.16b\n"
+        "sshl v31.16b, v16.16b, v1.16b\n"
+        "and v21.16b, v21.16b, v0.16b\n"
+        "and v16.16b, v16.16b, v0.16b\n"
+        "sub x20, x28, #0x8\n"
+        "subs x24, x24, #0x1\n"
+        "add x28, x28, #0x48\n"
+        ".inst 0x4e88a773  // smmla v19.4s, v27.16b, v8.16b\n"
+        ".inst 0x4e9fa77a  // smmla v26.4s, v27.16b, v31.16b\n"
+        "ldr q27, [x25, #0x20]\n"
+        ".inst 0x4e88a5f2  // smmla v18.4s, v15.16b, v8.16b\n"
+        ".inst 0x4e9fa5f1  // smmla v17.4s, v15.16b, v31.16b\n"
+        "sshl v15.16b, v29.16b, v1.16b\n"
+        "sshl v1.16b, v3.16b, v1.16b\n"
+        "and v29.16b, v29.16b, v0.16b\n"
+        "and v3.16b, v3.16b, v0.16b\n"
+        "ldr q0, [x25, #0x30]\n"
+        "fcvtl v20.4s, v20.4h\n"
+        ".inst 0x4e8fa773  // smmla v19.4s, v27.16b, v15.16b\n"
+        "fcvtl v9.4s, v9.4h\n"
+        ".inst 0x4e81a77a  // smmla v26.4s, v27.16b, v1.16b\n"
+        "ldr q27, [x25, #0x40]\n"
+        ".inst 0x4e8fa412  // smmla v18.4s, v0.16b, v15.16b\n"
+        ".inst 0x4e81a411  // smmla v17.4s, v0.16b, v1.16b\n"
+        "ldr q0, [x25, #0x50]\n"
+        ".inst 0x4e95a773  // smmla v19.4s, v27.16b, v21.16b\n"
+        ".inst 0x4e90a77a  // smmla v26.4s, v27.16b, v16.16b\n"
+        "ldr q27, [x25, #0x60]\n"
+        ".inst 0x4e95a412  // smmla v18.4s, v0.16b, v21.16b\n"
+        ".inst 0x4e90a411  // smmla v17.4s, v0.16b, v16.16b\n"
+        "ldr q0, [x25, #0x70]\n"
+        "add x25, x25, #0x88\n"
+        ".inst 0x4e9da773  // smmla v19.4s, v27.16b, v29.16b\n"
+        ".inst 0x4e83a77a  // smmla v26.4s, v27.16b, v3.16b\n"
+        "ldr d27, [x20, #0x0]\n"
+        ".inst 0x4e9da412  // smmla v18.4s, v0.16b, v29.16b\n"
+        ".inst 0x4e83a411  // smmla v17.4s, v0.16b, v3.16b\n"
+        "fcvtl v27.4s, v27.4h\n"
+        "uzp1 v0.2d, v19.2d, v26.2d\n"
+        "uzp2 v26.2d, v19.2d, v26.2d\n"
+        "fmul v19.4s, v27.4s, v20.s[0]\n"
+        "scvtf v0.4s, v0.4s, #0x4\n"
+        "scvtf v26.4s, v26.4s, #0x4\n"
+        "fmla v2.4s, v0.4s, v19.4s\n"
+        "ldr q19, [x23, #0x0]\n"
+        "uzp1 v0.2d, v18.2d, v17.2d\n"
+        "uzp2 v18.2d, v18.2d, v17.2d\n"
+        "fmul v17.4s, v27.4s, v20.s[1]\n"
+        "scvtf v0.4s, v0.4s, #0x4\n"
+        "scvtf v18.4s, v18.4s, #0x4\n"
+        "fmla v10.4s, v26.4s, v17.4s\n"
+        "ldr q17, [x23, #0x10]\n"
+        "fmul v26.4s, v27.4s, v20.s[2]\n"
+        "fmul v20.4s, v27.4s, v20.s[3]\n"
+        "fmla v12.4s, v0.4s, v26.4s\n"
+        "ldr d0, [x22, #-0x8]\n"
+        "ldr d26, [x21, #-0x8]\n"
+        "fcvtl v0.4s, v0.4h\n"
+        "fmla v28.4s, v18.4s, v20.4s\n"
+        "movi v20.4s, #0x0\n"
+        "movi v18.4s, #0x0\n"
+        ".inst 0x4e88a674  // smmla v20.4s, v19.16b, v8.16b\n"
+        ".inst 0x4e9fa672  // smmla v18.4s, v19.16b, v31.16b\n"
+        "ldr q19, [x23, #0x20]\n"
+        "fcvtl v26.4s, v26.4h\n"
+        ".inst 0x4e8fa674  // smmla v20.4s, v19.16b, v15.16b\n"
+        ".inst 0x4e81a672  // smmla v18.4s, v19.16b, v1.16b\n"
+        "ldr q19, [x23, #0x40]\n"
+        ".inst 0x4e95a674  // smmla v20.4s, v19.16b, v21.16b\n"
+        ".inst 0x4e90a672  // smmla v18.4s, v19.16b, v16.16b\n"
+        "ldr q19, [x23, #0x60]\n"
+        ".inst 0x4e9da674  // smmla v20.4s, v19.16b, v29.16b\n"
+        ".inst 0x4e83a672  // smmla v18.4s, v19.16b, v3.16b\n"
+        "uzp1 v19.2d, v20.2d, v18.2d\n"
+        "scvtf v19.4s, v19.4s, #0x4\n"
+        "uzp2 v20.2d, v20.2d, v18.2d\n"
+        "fmul v18.4s, v27.4s, v9.s[0]\n"
+        "scvtf v20.4s, v20.4s, #0x4\n"
+        "fmla v11.4s, v19.4s, v18.4s\n"
+        "ldr q18, [x22, #0x0]\n"
+        "fmul v19.4s, v27.4s, v9.s[1]\n"
+        "fmla v13.4s, v20.4s, v19.4s\n"
+        "movi v19.4s, #0x0\n"
+        "movi v20.4s, #0x0\n"
+        ".inst 0x4e88a633  // smmla v19.4s, v17.16b, v8.16b\n"
+        ".inst 0x4e9fa634  // smmla v20.4s, v17.16b, v31.16b\n"
+        "ldr q17, [x23, #0x30]\n"
+        ".inst 0x4e8fa633  // smmla v19.4s, v17.16b, v15.16b\n"
+        ".inst 0x4e81a634  // smmla v20.4s, v17.16b, v1.16b\n"
+        "ldr q17, [x23, #0x50]\n"
+        ".inst 0x4e95a633  // smmla v19.4s, v17.16b, v21.16b\n"
+        ".inst 0x4e90a634  // smmla v20.4s, v17.16b, v16.16b\n"
+        "ldr q17, [x23, #0x70]\n"
+        "add x23, x23, #0x88\n"
+        ".inst 0x4e9da633  // smmla v19.4s, v17.16b, v29.16b\n"
+        ".inst 0x4e83a634  // smmla v20.4s, v17.16b, v3.16b\n"
+        "uzp1 v17.2d, v19.2d, v20.2d\n"
+        "scvtf v17.4s, v17.4s, #0x4\n"
+        "uzp2 v20.2d, v19.2d, v20.2d\n"
+        "fmul v19.4s, v27.4s, v9.s[2]\n"
+        "fmul v9.4s, v27.4s, v9.s[3]\n"
+        "scvtf v20.4s, v20.4s, #0x4\n"
+        "fmla v22.4s, v17.4s, v19.4s\n"
+        "ldr q17, [x22, #0x10]\n"
+        "movi v19.4s, #0x0\n"
+        ".inst 0x4e88a653  // smmla v19.4s, v18.16b, v8.16b\n"
+        "fmla v23.4s, v20.4s, v9.4s\n"
+        "movi v20.4s, #0x0\n"
+        "movi v9.4s, #0x0\n"
+        ".inst 0x4e9fa654  // smmla v20.4s, v18.16b, v31.16b\n"
+        "ldr q18, [x22, #0x20]\n"
+        ".inst 0x4e88a629  // smmla v9.4s, v17.16b, v8.16b\n"
+        ".inst 0x4e8fa653  // smmla v19.4s, v18.16b, v15.16b\n"
+        ".inst 0x4e81a654  // smmla v20.4s, v18.16b, v1.16b\n"
+        "ldr q18, [x22, #0x40]\n"
+        ".inst 0x4e95a653  // smmla v19.4s, v18.16b, v21.16b\n"
+        ".inst 0x4e90a654  // smmla v20.4s, v18.16b, v16.16b\n"
+        "ldr q18, [x22, #0x60]\n"
+        ".inst 0x4e9da653  // smmla v19.4s, v18.16b, v29.16b\n"
+        ".inst 0x4e83a654  // smmla v20.4s, v18.16b, v3.16b\n"
+        "movi v18.4s, #0x0\n"
+        ".inst 0x4e9fa632  // smmla v18.4s, v17.16b, v31.16b\n"
+        "ldr q17, [x22, #0x30]\n"
+        ".inst 0x4e8fa629  // smmla v9.4s, v17.16b, v15.16b\n"
+        ".inst 0x4e81a632  // smmla v18.4s, v17.16b, v1.16b\n"
+        "ldr q17, [x22, #0x50]\n"
+        ".inst 0x4e95a629  // smmla v9.4s, v17.16b, v21.16b\n"
+        ".inst 0x4e90a632  // smmla v18.4s, v17.16b, v16.16b\n"
+        "ldr q17, [x22, #0x70]\n"
+        "add x22, x22, #0x88\n"
+        ".inst 0x4e9da629  // smmla v9.4s, v17.16b, v29.16b\n"
+        ".inst 0x4e83a632  // smmla v18.4s, v17.16b, v3.16b\n"
+        "uzp1 v17.2d, v19.2d, v20.2d\n"
+        "uzp2 v20.2d, v19.2d, v20.2d\n"
+        "fmul v19.4s, v27.4s, v0.s[0]\n"
+        "scvtf v17.4s, v17.4s, #0x4\n"
+        "scvtf v20.4s, v20.4s, #0x4\n"
+        "fmla v25.4s, v17.4s, v19.4s\n"
+        "ldr q19, [x21, #0x0]\n"
+        "fmul v17.4s, v27.4s, v0.s[1]\n"
+        "fmla v5.4s, v20.4s, v17.4s\n"
+        "ldr q17, [x21, #0x10]\n"
+        "uzp1 v20.2d, v9.2d, v18.2d\n"
+        "uzp2 v9.2d, v9.2d, v18.2d\n"
+        "fmul v18.4s, v27.4s, v0.s[2]\n"
+        "fmul v0.4s, v27.4s, v0.s[3]\n"
+        "scvtf v20.4s, v20.4s, #0x4\n"
+        "scvtf v9.4s, v9.4s, #0x4\n"
+        "fmla v7.4s, v20.4s, v18.4s\n"
+        "movi v20.4s, #0x0\n"
+        "movi v18.4s, #0x0\n"
+        ".inst 0x4e88a674  // smmla v20.4s, v19.16b, v8.16b\n"
+        ".inst 0x4e9fa672  // smmla v18.4s, v19.16b, v31.16b\n"
+        "ldr q19, [x21, #0x20]\n"
+        "fmla v4.4s, v9.4s, v0.4s\n"
+        "movi v9.4s, #0x0\n"
+        "movi v0.4s, #0x0\n"
+        ".inst 0x4e88a629  // smmla v9.4s, v17.16b, v8.16b\n"
+        "fmul v8.4s, v27.4s, v26.s[0]\n"
+        ".inst 0x4e9fa620  // smmla v0.4s, v17.16b, v31.16b\n"
+        "ldr q17, [x21, #0x30]\n"
+        ".inst 0x4e8fa674  // smmla v20.4s, v19.16b, v15.16b\n"
+        "fmul v31.4s, v27.4s, v26.s[1]\n"
+        ".inst 0x4e81a672  // smmla v18.4s, v19.16b, v1.16b\n"
+        "ldr q19, [x21, #0x40]\n"
+        ".inst 0x4e8fa629  // smmla v9.4s, v17.16b, v15.16b\n"
+        "fmul v15.4s, v27.4s, v26.s[2]\n"
+        "fmul v27.4s, v27.4s, v26.s[3]\n"
+        ".inst 0x4e81a620  // smmla v0.4s, v17.16b, v1.16b\n"
+        "ldr q1, [x21, #0x50]\n"
+        ".inst 0x4e95a674  // smmla v20.4s, v19.16b, v21.16b\n"
+        ".inst 0x4e90a672  // smmla v18.4s, v19.16b, v16.16b\n"
+        "ldr q26, [x21, #0x60]\n"
+        ".inst 0x4e95a429  // smmla v9.4s, v1.16b, v21.16b\n"
+        ".inst 0x4e90a420  // smmla v0.4s, v1.16b, v16.16b\n"
+        "ldr q21, [x21, #0x70]\n"
+        "add x21, x21, #0x88\n"
+        ".inst 0x4e9da754  // smmla v20.4s, v26.16b, v29.16b\n"
+        ".inst 0x4e83a752  // smmla v18.4s, v26.16b, v3.16b\n"
+        ".inst 0x4e9da6a9  // smmla v9.4s, v21.16b, v29.16b\n"
+        ".inst 0x4e83a6a0  // smmla v0.4s, v21.16b, v3.16b\n"
+        "uzp1 v29.2d, v20.2d, v18.2d\n"
+        "uzp2 v21.2d, v20.2d, v18.2d\n"
+        "scvtf v29.4s, v29.4s, #0x4\n"
+        "uzp1 v18.2d, v9.2d, v0.2d\n"
+        "uzp2 v16.2d, v9.2d, v0.2d\n"
+        "scvtf v21.4s, v21.4s, #0x4\n"
+        "fmla v6.4s, v29.4s, v8.4s\n"
+        "scvtf v18.4s, v18.4s, #0x4\n"
+        "scvtf v16.4s, v16.4s, #0x4\n"
+        "fmla v30.4s, v21.4s, v31.4s\n"
+        "fmla v24.4s, v18.4s, v15.4s\n"
+        "fmla v14.4s, v16.4s, v27.4s\n"
+        "bgt 3b\n"
+        "mov x20, %x[res_ptr]\n"
+        "subs x27, x27, #0x4\n"
+        "add %x[res_ptr], %x[res_ptr], #0x10\n"
+        "str q2, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q10, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q12, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q28, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q11, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q13, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q22, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q23, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q25, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q5, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q7, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q4, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q6, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q30, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q24, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "str q14, [x20, #0x0]\n"
+        "bne 2b\n"
+        "mov x20, #0x4\n"
+        "sub x10, x10, #0x10\n"
+        "cmp x10, #0x10\n"
+        "mov %x[res_ptr], x26\n"
+        "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
+        "bge 1b\n"
+        "4:"  // Row loop skip
+        "cbz x10, 9f\n"
+        "5:"  // Row tail: Row loop
+        "add x24, %x[b_ptr], #0x8\n"
+        "mov x23, %x[nc]\n"
+        "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
+        "6:"  // Row tail: Column loop
+        "movi v2.16b, #0x0\n"
+        "movi v10.16b, #0x0\n"
+        "add x25, %x[a_ptr], #0x8\n"
+        "mov x21, %x[nb]\n"
+        "movi v12.16b, #0x0\n"
+        "movi v28.16b, #0x0\n"
+        "7:"  // Row tail: Block loop
+        "ldr q6, [x24, #0x0]\n"
+        "ldr q5, [x24, #0x10]\n"
+        "movi v17.16b, #0x4\n"
+        "movi v8.4s, #0x0\n"
+        "ldr q4, [x25, #0x0]\n"
+        "ldr q13, [x25, #0x10]\n"
+        "movi v27.4s, #0x0\n"
+        "movi v0.4s, #0x0\n"
+        "ldr q31, [x24, #0x20]\n"
+        "ldr q14, [x24, #0x30]\n"
+        "movi v29.4s, #0x0\n"
+        "movi v22.16b, #0xf0\n"
+        "ldr q11, [x25, #0x20]\n"
+        "ldr q23, [x25, #0x30]\n"
+        "sshl v21.16b, v6.16b, v17.16b\n"
+        "sshl v16.16b, v5.16b, v17.16b\n"
+        "ldr q20, [x25, #0x40]\n"
+        "ldr q26, [x25, #0x50]\n"
+        "and v6.16b, v6.16b, v22.16b\n"
+        "and v5.16b, v5.16b, v22.16b\n"
+        "ldr q25, [x25, #0x60]\n"
+        "ldr q3, [x25, #0x70]\n"
+        "sshl v19.16b, v31.16b, v17.16b\n"
+        "sshl v18.16b, v14.16b, v17.16b\n"
+        "ldr d17, [x25, #-0x8]\n"
+        ".inst 0x4e95a488  // smmla v8.4s, v4.16b, v21.16b\n"
+        ".inst 0x4e90a49b  // smmla v27.4s, v4.16b, v16.16b\n"
+        "and v31.16b, v31.16b, v22.16b\n"
+        ".inst 0x4e95a5a0  // smmla v0.4s, v13.16b, v21.16b\n"
+        ".inst 0x4e90a5bd  // smmla v29.4s, v13.16b, v16.16b\n"
+        "and v14.16b, v14.16b, v22.16b\n"
+        "sub x20, x24, #0x8\n"
+        "ldr d16, [x20, #0x0]\n"
+        "subs x21, x21, #0x1\n"
+        "add x25, x25, #0x88\n"
+        "fcvtl v17.4s, v17.4h\n"
+        "add x24, x24, #0x48\n"
+        ".inst 0x4e93a568  // smmla v8.4s, v11.16b, v19.16b\n"
+        ".inst 0x4e92a57b  // smmla v27.4s, v11.16b, v18.16b\n"
+        ".inst 0x4e93a6e0  // smmla v0.4s, v23.16b, v19.16b\n"
+        ".inst 0x4e92a6fd  // smmla v29.4s, v23.16b, v18.16b\n"
+        "fcvtl v16.4s, v16.4h\n"
+        ".inst 0x4e86a688  // smmla v8.4s, v20.16b, v6.16b\n"
+        ".inst 0x4e85a69b  // smmla v27.4s, v20.16b, v5.16b\n"
+        "fmul v23.4s, v16.4s, v17.s[0]\n"
+        "fmul v21.4s, v16.4s, v17.s[1]\n"
+        "fmul v1.4s, v16.4s, v17.s[2]\n"
+        "fmul v20.4s, v16.4s, v17.s[3]\n"
+        ".inst 0x4e86a740  // smmla v0.4s, v26.16b, v6.16b\n"
+        ".inst 0x4e85a75d  // smmla v29.4s, v26.16b, v5.16b\n"
+        ".inst 0x4e9fa728  // smmla v8.4s, v25.16b, v31.16b\n"
+        ".inst 0x4e8ea73b  // smmla v27.4s, v25.16b, v14.16b\n"
+        ".inst 0x4e9fa460  // smmla v0.4s, v3.16b, v31.16b\n"
+        ".inst 0x4e8ea47d  // smmla v29.4s, v3.16b, v14.16b\n"
+        "uzp1 v19.2d, v8.2d, v27.2d\n"
+        "uzp2 v18.2d, v8.2d, v27.2d\n"
+        "scvtf v19.4s, v19.4s, #0x4\n"
+        "uzp1 v17.2d, v0.2d, v29.2d\n"
+        "uzp2 v16.2d, v0.2d, v29.2d\n"
+        "scvtf v18.4s, v18.4s, #0x4\n"
+        "fmla v2.4s, v19.4s, v23.4s\n"
+        "scvtf v17.4s, v17.4s, #0x4\n"
+        "scvtf v16.4s, v16.4s, #0x4\n"
+        "fmla v10.4s, v18.4s, v21.4s\n"
+        "fmla v12.4s, v17.4s, v1.4s\n"
+        "fmla v28.4s, v16.4s, v20.4s\n"
+        "bgt 7b\n"
+        "mov x20, %x[res_ptr]\n"
+        "cmp x10, #0x1\n"
+        "str q2, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "ble 8f\n"
+        "cmp x10, #0x2\n"
+        "str q10, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "ble 8f\n"
+        "cmp x10, #0x3\n"
+        "str q12, [x20, #0x0]\n"
+        "add x20, x20, %x[res_stride]\n"
+        "ble 8f\n"
+        "str q28, [x20, #0x0]\n"
+        "8:"  // Row tail: Accumulator store skip
+        "subs x23, x23, #0x4\n"
+        "add %x[res_ptr], %x[res_ptr], #0x10\n"
+        "bne 6b\n"
+        "subs x10, x10, #0x4\n"
+        "add %x[a_ptr], %x[a_ptr], x9\n"
+        "mov %x[res_ptr], x22\n"
+        "bgt 5b\n"
+        "9:"  // Row tail: Row loop skip
+        : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
+        : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
+        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+    );
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
+                "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
+                "performance");
+#else
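+    // Same reference scalar fallback as the 4x4 kernel above; only the
+    // blocklen (8 here) differs, the nibble arithmetic is unchanged.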
+    float sumf[4][4];
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+                            }
+                            sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++)
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+            }
+        }
+    }
+#endif
+}
+
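+// 8x8 variant of the kernel above: block_q4_0x8 weights (8 interleaved
+// columns, 8-byte blocklen) against block_q8_0x4 activations, writing an
+// nr x nc float tile to s with row stride bs.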
+void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
+    if (ggml_sve_cnt_b == QK8_0) {
+        const void * b_ptr = vx;
+        const void * a_ptr = vy;
+        float * res_ptr = s;
+        size_t res_stride = bs * sizeof(float);
+
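+        // 256-bit SVE path (ggml_sve_cnt_b == QK8_0, i.e. 32-byte vectors):
+        // same smmla tiling as the NEON i8mm kernel, using predicated
+        // ld1b/ld1rqb loads and st1w stores over 8-column result tiles.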
+        __asm__ __volatile__(
+            "mov x20, #0x4\n"
+            "mov x13, %x[nr]\n"
+            "mov z28.s, #-0x4\n"
+            "mov x12, #0x88\n"
+            "ptrue p1.b\n"
+            "whilelt p0.s, XZR, x20\n"
+            "cmp x13, #0x10\n"
+            "mul x12, %x[nb], x12\n"
+            "blt 4f\n"
+            "1:"  // Row loop
+            "add x11, %x[b_ptr], #0x10\n"
+            "mov x10, %x[nc]\n"
+            "add x9, %x[res_ptr], %x[res_stride], LSL #4\n"
+            "2:"  // Column loop
+            "add x28, %x[a_ptr], #0x8\n"
+            "mov z24.b, #0x0\n"
+            "mov z15.b, #0x0\n"
+            "mov x27, %x[nb]\n"
+            "add x26, x28, x12\n"
+            "mov z12.b, #0x0\n"
+            "mov z0.b, #0x0\n"
+            "add x25, x26, x12\n"
+            "mov z13.b, #0x0\n"
+            "mov z1.b, #0x0\n"
+            "add x24, x25, x12\n"
+            "mov z20.b, #0x0\n"
+            "mov z25.b, #0x0\n"
+            "mov z11.b, #0x0\n"
+            "mov z16.b, #0x0\n"
+            "mov z19.b, #0x0\n"
+            "mov z26.b, #0x0\n"
+            "mov z8.b, #0x0\n"
+            "mov z29.b, #0x0\n"
+            "mov z27.b, #0x0\n"
+            "mov z10.b, #0x0\n"
+            "3:"  // Block loop
+            "ld1b { z30.b }, p1/Z, [x11]\n"
+            "ld1b { z21.b }, p1/Z, [x11, #1, MUL VL]\n"
+            "mov z18.s, #0x0\n"
+            "mov z7.s, #0x0\n"
+            "ld1rqb { z3.b }, p1/Z, [x28]\n"
+            "ld1rqb { z5.b }, p1/Z, [x28, #16]\n"
+            "mov z9.s, #0x0\n"
+            "mov z22.s, #0x0\n"
+            "ld1b { z4.b }, p1/Z, [x11, #2, MUL VL]\n"
+            "ld1b { z17.b }, p1/Z, [x11, #3, MUL VL]\n"
+            "sub x20, x11, #0x10\n"
+            "sub x23, x28, #0x8\n"
+            "lsl z31.b, z30.b, #0x4\n"
+            "lsl z6.b, z21.b, #0x4\n"
+            "ld1h { z23.s }, p1/Z, [x20]\n"
+            "sub x22, x26, #0x8\n"
+            "and z30.b, z30.b, #0xf0\n"
+            "and z21.b, z21.b, #0xf0\n"
+            "sub x21, x25, #0x8\n"
+            "sub x20, x24, #0x8\n"
+            "lsl z14.b, z4.b, #0x4\n"
+            "lsl z2.b, z17.b, #0x4\n"
+            "subs x27, x27, #0x1\n"
+            "add x11, x11, #0x90\n"
+            ".inst 0x451f9872  // smmla z18.s, z3.b, z31.b\n"
+            ".inst 0x45069867  // smmla z7.s, z3.b, z6.b\n"
+            "ld1rqb { z3.b }, p1/Z, [x28, #32]\n"
+            "and z4.b, z4.b, #0xf0\n"
+            ".inst 0x451f98a9  // smmla z9.s, z5.b, z31.b\n"
+            ".inst 0x450698b6  // smmla z22.s, z5.b, z6.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x28, #48]\n"
+            "and z17.b, z17.b, #0xf0\n"
+            "fcvt z23.s, p1/m, z23.h\n"
+            ".inst 0x450e9872  // smmla z18.s, z3.b, z14.b\n"
+            ".inst 0x45029867  // smmla z7.s, z3.b, z2.b\n"
+            "ld1rqb { z3.b }, p1/Z, [x28, #64]\n"
+            ".inst 0x450e98a9  // smmla z9.s, z5.b, z14.b\n"
+            ".inst 0x450298b6  // smmla z22.s, z5.b, z2.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x28, #80]\n"
+            "fscale z23.s, p1/m, z23.s, z28.s\n"
+            ".inst 0x451e9872  // smmla z18.s, z3.b, z30.b\n"
+            ".inst 0x45159867  // smmla z7.s, z3.b, z21.b\n"
+            "ld1rqb { z3.b }, p1/Z, [x28, #96]\n"
+            ".inst 0x451e98a9  // smmla z9.s, z5.b, z30.b\n"
+            ".inst 0x451598b6  // smmla z22.s, z5.b, z21.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x28, #112]\n"
+            "add x28, x28, #0x88\n"
+            ".inst 0x45049872  // smmla z18.s, z3.b, z4.b\n"
+            ".inst 0x45119867  // smmla z7.s, z3.b, z17.b\n"
+            "ld1h { z3.s }, p0/Z, [x23]\n"
+            ".inst 0x450498a9  // smmla z9.s, z5.b, z4.b\n"
+            ".inst 0x451198b6  // smmla z22.s, z5.b, z17.b\n"
+            "fcvt z3.s, p1/m, z3.h\n"
+            "uzp1 z5.d, z18.d, z7.d\n"
+            "uzp2 z18.d, z18.d, z7.d\n"
+            "mov z3.q, z3.q[0]\n"
+            "uzp1 z7.d, z9.d, z22.d\n"
+            "uzp2 z22.d, z9.d, z22.d\n"
+            "fmul z9.s, z23.s, z3.s[0]\n"
+            "scvtf z5.s, p1/m, z5.s\n"
+            "scvtf z18.s, p1/m, z18.s\n"
+            "scvtf z7.s, p1/m, z7.s\n"
+            "scvtf z22.s, p1/m, z22.s\n"
+            "fmla z24.s, p1/M, z5.s, z9.s\n"
+            "ld1rqb { z5.b }, p1/Z, [x26]\n"
+            "fmul z9.s, z23.s, z3.s[1]\n"
+            "fmla z15.s, p1/M, z18.s, z9.s\n"
+            "ld1rqb { z18.b }, p1/Z, [x26, #16]\n"
+            "fmul z9.s, z23.s, z3.s[2]\n"
+            "fmul z3.s, z23.s, z3.s[3]\n"
+            "fmla z12.s, p1/M, z7.s, z9.s\n"
+            "mov z9.s, #0x0\n"
+            "ld1h { z7.s }, p0/Z, [x22]\n"
+            ".inst 0x451f98a9  // smmla z9.s, z5.b, z31.b\n"
+            "fmla z0.s, p1/M, z22.s, z3.s\n"
+            "mov z22.s, #0x0\n"
+            "ld1h { z3.s }, p0/Z, [x21]\n"
+            ".inst 0x450698b6  // smmla z22.s, z5.b, z6.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x26, #32]\n"
+            "fcvt z7.s, p1/m, z7.h\n"
+            "fcvt z3.s, p1/m, z3.h\n"
+            ".inst 0x450e98a9  // smmla z9.s, z5.b, z14.b\n"
+            ".inst 0x450298b6  // smmla z22.s, z5.b, z2.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x26, #64]\n"
+            "mov z7.q, z7.q[0]\n"
+            "mov z3.q, z3.q[0]\n"
+            ".inst 0x451e98a9  // smmla z9.s, z5.b, z30.b\n"
+            ".inst 0x451598b6  // smmla z22.s, z5.b, z21.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x26, #96]\n"
+            ".inst 0x450498a9  // smmla z9.s, z5.b, z4.b\n"
+            ".inst 0x451198b6  // smmla z22.s, z5.b, z17.b\n"
+            "uzp1 z5.d, z9.d, z22.d\n"
+            "scvtf z5.s, p1/m, z5.s\n"
+            "uzp2 z22.d, z9.d, z22.d\n"
+            "fmul z9.s, z23.s, z7.s[0]\n"
+            "scvtf z22.s, p1/m, z22.s\n"
+            "fmla z13.s, p1/M, z5.s, z9.s\n"
+            "ld1rqb { z9.b }, p1/Z, [x25]\n"
+            "fmul z5.s, z23.s, z7.s[1]\n"
+            "fmla z1.s, p1/M, z22.s, z5.s\n"
+            "mov z5.s, #0x0\n"
+            "mov z22.s, #0x0\n"
+            ".inst 0x451f9a45  // smmla z5.s, z18.b, z31.b\n"
+            ".inst 0x45069a56  // smmla z22.s, z18.b, z6.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x26, #48]\n"
+            ".inst 0x450e9a45  // smmla z5.s, z18.b, z14.b\n"
+            ".inst 0x45029a56  // smmla z22.s, z18.b, z2.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x26, #80]\n"
+            ".inst 0x451e9a45  // smmla z5.s, z18.b, z30.b\n"
+            ".inst 0x45159a56  // smmla z22.s, z18.b, z21.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x26, #112]\n"
+            "add x26, x26, #0x88\n"
+            ".inst 0x45049a45  // smmla z5.s, z18.b, z4.b\n"
+            ".inst 0x45119a56  // smmla z22.s, z18.b, z17.b\n"
+            "uzp1 z18.d, z5.d, z22.d\n"
+            "scvtf z18.s, p1/m, z18.s\n"
+            "uzp2 z22.d, z5.d, z22.d\n"
+            "fmul z5.s, z23.s, z7.s[2]\n"
+            "fmul z7.s, z23.s, z7.s[3]\n"
+            "scvtf z22.s, p1/m, z22.s\n"
+            "fmla z20.s, p1/M, z18.s, z5.s\n"
+            "ld1rqb { z18.b }, p1/Z, [x25, #16]\n"
+            "ld1h { z5.s }, p0/Z, [x20]\n"
+            "fcvt z5.s, p1/m, z5.h\n"
+            "fmla z25.s, p1/M, z22.s, z7.s\n"
+            "mov z22.s, #0x0\n"
+            "mov z7.s, #0x0\n"
+            ".inst 0x451f9936  // smmla z22.s, z9.b, z31.b\n"
+            ".inst 0x45069927  // smmla z7.s, z9.b, z6.b\n"
+            "ld1rqb { z9.b }, p1/Z, [x25, #32]\n"
+            "mov z5.q, z5.q[0]\n"
+            ".inst 0x450e9936  // smmla z22.s, z9.b, z14.b\n"
+            ".inst 0x45029927  // smmla z7.s, z9.b, z2.b\n"
+            "ld1rqb { z9.b }, p1/Z, [x25, #64]\n"
+            ".inst 0x451e9936  // smmla z22.s, z9.b, z30.b\n"
+            ".inst 0x45159927  // smmla z7.s, z9.b, z21.b\n"
+            "ld1rqb { z9.b }, p1/Z, [x25, #96]\n"
+            ".inst 0x45049936  // smmla z22.s, z9.b, z4.b\n"
+            ".inst 0x45119927  // smmla z7.s, z9.b, z17.b\n"
+            "uzp1 z9.d, z22.d, z7.d\n"
+            "scvtf z9.s, p1/m, z9.s\n"
+            "uzp2 z22.d, z22.d, z7.d\n"
+            "fmul z7.s, z23.s, z3.s[0]\n"
+            "scvtf z22.s, p1/m, z22.s\n"
+            "fmla z11.s, p1/M, z9.s, z7.s\n"
+            "ld1rqb { z9.b }, p1/Z, [x24]\n"
+            "fmul z7.s, z23.s, z3.s[1]\n"
+            "fmla z16.s, p1/M, z22.s, z7.s\n"
+            "mov z22.s, #0x0\n"
+            "mov z7.s, #0x0\n"
+            ".inst 0x451f9a56  // smmla z22.s, z18.b, z31.b\n"
+            ".inst 0x45069a47  // smmla z7.s, z18.b, z6.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x25, #48]\n"
+            ".inst 0x450e9a56  // smmla z22.s, z18.b, z14.b\n"
+            ".inst 0x45029a47  // smmla z7.s, z18.b, z2.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x25, #80]\n"
+            ".inst 0x451e9a56  // smmla z22.s, z18.b, z30.b\n"
+            ".inst 0x45159a47  // smmla z7.s, z18.b, z21.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x25, #112]\n"
+            "add x25, x25, #0x88\n"
+            ".inst 0x45049a56  // smmla z22.s, z18.b, z4.b\n"
+            ".inst 0x45119a47  // smmla z7.s, z18.b, z17.b\n"
+            "uzp1 z18.d, z22.d, z7.d\n"
+            "scvtf z18.s, p1/m, z18.s\n"
+            "uzp2 z7.d, z22.d, z7.d\n"
+            "fmul z22.s, z23.s, z3.s[2]\n"
+            "fmul z3.s, z23.s, z3.s[3]\n"
+            "scvtf z7.s, p1/m, z7.s\n"
+            "fmla z19.s, p1/M, z18.s, z22.s\n"
+            "ld1rqb { z18.b }, p1/Z, [x24, #16]\n"
+            "fmul z22.s, z23.s, z5.s[0]\n"
+            "fmla z26.s, p1/M, z7.s, z3.s\n"
+            "mov z3.s, #0x0\n"
+            "mov z7.s, #0x0\n"
+            ".inst 0x451f9923  // smmla z3.s, z9.b, z31.b\n"
+            ".inst 0x45069927  // smmla z7.s, z9.b, z6.b\n"
+            "ld1rqb { z9.b }, p1/Z, [x24, #32]\n"
+            ".inst 0x450e9923  // smmla z3.s, z9.b, z14.b\n"
+            ".inst 0x45029927  // smmla z7.s, z9.b, z2.b\n"
+            "mov z9.s, #0x0\n"
+            ".inst 0x451f9a49  // smmla z9.s, z18.b, z31.b\n"
+            "mov z31.s, #0x0\n"
+            ".inst 0x45069a5f  // smmla z31.s, z18.b, z6.b\n"
+            "ld1rqb { z6.b }, p1/Z, [x24, #48]\n"
+            "ld1rqb { z18.b }, p1/Z, [x24, #64]\n"
+            ".inst 0x450e98c9  // smmla z9.s, z6.b, z14.b\n"
+            "fmul z14.s, z23.s, z5.s[1]\n"
+            ".inst 0x450298df  // smmla z31.s, z6.b, z2.b\n"
+            "ld1rqb { z6.b }, p1/Z, [x24, #80]\n"
+            "fmul z2.s, z23.s, z5.s[2]\n"
+            "fmul z23.s, z23.s, z5.s[3]\n"
+            ".inst 0x451e9a43  // smmla z3.s, z18.b, z30.b\n"
+            ".inst 0x45159a47  // smmla z7.s, z18.b, z21.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x24, #96]\n"
+            ".inst 0x451e98c9  // smmla z9.s, z6.b, z30.b\n"
+            ".inst 0x451598df  // smmla z31.s, z6.b, z21.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x24, #112]\n"
+            "add x24, x24, #0x88\n"
+            ".inst 0x450498a3  // smmla z3.s, z5.b, z4.b\n"
+            ".inst 0x451198a7  // smmla z7.s, z5.b, z17.b\n"
+            ".inst 0x45049a49  // smmla z9.s, z18.b, z4.b\n"
+            ".inst 0x45119a5f  // smmla z31.s, z18.b, z17.b\n"
+            "uzp1 z18.d, z3.d, z7.d\n"
+            "uzp2 z5.d, z3.d, z7.d\n"
+            "scvtf z18.s, p1/m, z18.s\n"
+            "uzp1 z6.d, z9.d, z31.d\n"
+            "uzp2 z9.d, z9.d, z31.d\n"
+            "scvtf z5.s, p1/m, z5.s\n"
+            "fmla z8.s, p1/M, z18.s, z22.s\n"
+            "scvtf z6.s, p1/m, z6.s\n"
+            "scvtf z9.s, p1/m, z9.s\n"
+            "fmla z29.s, p1/M, z5.s, z14.s\n"
+            "fmla z27.s, p1/M, z6.s, z2.s\n"
+            "fmla z10.s, p1/M, z9.s, z23.s\n"
+            "bgt 3b\n"
+            "mov x20, %x[res_ptr]\n"
+            "subs x10, x10, #0x8\n"
+            "add %x[res_ptr], %x[res_ptr], #0x20\n"
+            "st1w { z24.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z15.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z12.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z0.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z13.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z1.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z20.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z25.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z11.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z16.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z19.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z26.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z8.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z29.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z27.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z10.s }, p1, [x20]\n"
+            "bne 2b\n"
+            "mov x20, #0x4\n"
+            "sub x13, x13, #0x10\n"
+            "cmp x13, #0x10\n"
+            "mov %x[res_ptr], x9\n"
+            "madd %x[a_ptr], x20, x12, %x[a_ptr]\n"
+            "bge 1b\n"
+            "4:"  // Row loop skip
+            "cbz x13, 9f\n"
+            "5:"  // Row tail: Row loop
+            "add x25, %x[b_ptr], #0x10\n"
+            "mov x24, %x[nc]\n"
+            "add x23, %x[res_ptr], %x[res_stride], LSL #2\n"
+            "6:"  // Row tail: Column loop
+            "mov z24.b, #0x0\n"
+            "mov z15.b, #0x0\n"
+            "add x28, %x[a_ptr], #0x8\n"
+            "mov x22, %x[nb]\n"
+            "mov z12.b, #0x0\n"
+            "mov z0.b, #0x0\n"
+            "7:"  // Row tail: Block loop
+            "ld1b { z3.b }, p1/Z, [x25]\n"
+            "ld1b { z6.b }, p1/Z, [x25, #1, MUL VL]\n"
+            "mov z2.s, #0x0\n"
+            "mov z25.s, #0x0\n"
+            "ld1rqb { z26.b }, p1/Z, [x28]\n"
+            "ld1rqb { z21.b }, p1/Z, [x28, #16]\n"
+            "mov z27.s, #0x0\n"
+            "mov z19.s, #0x0\n"
+            "ld1b { z29.b }, p1/Z, [x25, #2, MUL VL]\n"
+            "ld1b { z16.b }, p1/Z, [x25, #3, MUL VL]\n"
+            "sub x21, x25, #0x10\n"
+            "sub x20, x28, #0x8\n"
+            "lsl z20.b, z3.b, #0x4\n"
+            "lsl z4.b, z6.b, #0x4\n"
+            "ld1rqb { z10.b }, p1/Z, [x28, #32]\n"
+            "ld1rqb { z23.b }, p1/Z, [x28, #48]\n"
+            "and z3.b, z3.b, #0xf0\n"
+            "and z6.b, z6.b, #0xf0\n"
+            "ld1rqb { z11.b }, p1/Z, [x28, #64]\n"
+            "ld1rqb { z7.b }, p1/Z, [x28, #80]\n"
+            "lsl z8.b, z29.b, #0x4\n"
+            "lsl z14.b, z16.b, #0x4\n"
+            "ld1rqb { z18.b }, p1/Z, [x28, #96]\n"
+            "ld1rqb { z30.b }, p1/Z, [x28, #112]\n"
+            ".inst 0x45149b42  // smmla z2.s, z26.b, z20.b\n"
+            ".inst 0x45049b59  // smmla z25.s, z26.b, z4.b\n"
+            "and z29.b, z29.b, #0xf0\n"
+            "ld1h { z17.s }, p1/Z, [x21]\n"
+            ".inst 0x45149abb  // smmla z27.s, z21.b, z20.b\n"
+            ".inst 0x45049ab3  // smmla z19.s, z21.b, z4.b\n"
+            "and z16.b, z16.b, #0xf0\n"
+            "ld1h { z4.s }, p0/Z, [x20]\n"
+            "subs x22, x22, #0x1\n"
+            "add x28, x28, #0x88\n"
+            "fcvt z17.s, p1/m, z17.h\n"
+            "add x25, x25, #0x90\n"
+            ".inst 0x45089942  // smmla z2.s, z10.b, z8.b\n"
+            ".inst 0x450e9959  // smmla z25.s, z10.b, z14.b\n"
+            "fcvt z4.s, p1/m, z4.h\n"
+            ".inst 0x45089afb  // smmla z27.s, z23.b, z8.b\n"
+            ".inst 0x450e9af3  // smmla z19.s, z23.b, z14.b\n"
+            "fscale z17.s, p1/m, z17.s, z28.s\n"
+            "mov z4.q, z4.q[0]\n"
+            ".inst 0x45039962  // smmla z2.s, z11.b, z3.b\n"
+            ".inst 0x45069979  // smmla z25.s, z11.b, z6.b\n"
+            "fmul z23.s, z17.s, z4.s[0]\n"
+            "fmul z9.s, z17.s, z4.s[1]\n"
+            "fmul z21.s, z17.s, z4.s[2]\n"
+            "fmul z4.s, z17.s, z4.s[3]\n"
+            ".inst 0x450398fb  // smmla z27.s, z7.b, z3.b\n"
+            ".inst 0x450698f3  // smmla z19.s, z7.b, z6.b\n"
+            ".inst 0x451d9a42  // smmla z2.s, z18.b, z29.b\n"
+            ".inst 0x45109a59  // smmla z25.s, z18.b, z16.b\n"
+            ".inst 0x451d9bdb  // smmla z27.s, z30.b, z29.b\n"
+            ".inst 0x45109bd3  // smmla z19.s, z30.b, z16.b\n"
+            "uzp1 z31.d, z2.d, z25.d\n"
+            "uzp2 z13.d, z2.d, z25.d\n"
+            "scvtf z31.s, p1/m, z31.s\n"
+            "uzp1 z17.d, z27.d, z19.d\n"
+            "uzp2 z18.d, z27.d, z19.d\n"
+            "scvtf z13.s, p1/m, z13.s\n"
+            "fmla z24.s, p1/M, z31.s, z23.s\n"
+            "scvtf z17.s, p1/m, z17.s\n"
+            "scvtf z18.s, p1/m, z18.s\n"
+            "fmla z15.s, p1/M, z13.s, z9.s\n"
+            "fmla z12.s, p1/M, z17.s, z21.s\n"
+            "fmla z0.s, p1/M, z18.s, z4.s\n"
+            "bgt 7b\n"
+            "mov x20, %x[res_ptr]\n"
+            "cmp x13, #0x1\n"
+            "st1w { z24.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "cmp x13, #0x2\n"
+            "st1w { z15.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "cmp x13, #0x3\n"
+            "st1w { z12.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "st1w { z0.s }, p1, [x20]\n"
+            "8:"  // Row tail: Accumulator store skip
+            "subs x24, x24, #0x8\n"
+            "add %x[res_ptr], %x[res_ptr], #0x20\n"
+            "bne 6b\n"
+            "subs x13, x13, #0x4\n"
+            "add %x[a_ptr], %x[a_ptr], x12\n"
+            "mov %x[res_ptr], x23\n"
+            "bgt 5b\n"
+            "9:"  // Row tail: Row loop skip
+            : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
+            : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
+            : "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
+        );
+        return;
+    }
+    else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+        GGML_ASSERT((ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
+                    "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal "
+                    "performance");
+    }
+    else if (ggml_cpu_has_neon()) {
+        GGML_ASSERT(((ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) || ggml_cpu_has_matmul_int8()) &&
+                    "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 "
+                    "quantization format for optimal performance");
+    }
+#endif
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    GGML_ASSERT(ggml_cpu_has_sve() &&
+                "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance");
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
+                "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
+                "performance");
+#else
+    float sumf[4][8];
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+                            }
+                            sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++)
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+            }
+        }
+    }
+#endif
+}
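
The scalar fallback above recovers two signed 4-bit weights from each packed byte with a shift-and-mask trick: the low nibble is shifted into the top bits so the int8_t cast sign-extends it, the high nibble is masked in place, and both products carry a spare factor of 16 that the final >> 4 removes. A minimal standalone sketch of just that trick (illustrative only, not part of the vendored file; like the kernel itself, it relies on arithmetic right shift of negative values):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        uint8_t packed = 0x2D;             // high nibble 0x2, low nibble 0xD
        int v0 = (int8_t)(packed << 4);    // low nibble, pre-scaled: 16 * (-3) = -48
        int v1 = (int8_t)(packed & 0xF0);  // high nibble, pre-scaled: 16 * 2  =  32
        // the kernel applies the >> 4 only after multiplying by the activations
        printf("low = %d, high = %d\n", v0 >> 4, v1 >> 4);  // prints -3, 2
        return 0;
    }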

+ 65 - 0
llama/ggml-aarch64.h

@@ -0,0 +1,65 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
+#pragma once
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+
+#include "ggml.h"
+
+// GGML internal header
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Quantization
+void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave);
+
+// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
+size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q4_0_4x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q4_0_8x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+
+// GEMV
+void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+
+// GEMM
+void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+
+#ifdef __cplusplus
+}
+#endif
+
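
Taken together, these declarations describe a two-step pipeline: quantize the activations into an interleaved Q8_0 layout with quantize_mat_q8_0, then multiply them against pre-interleaved Q4_0 weights with one of the ggml_gemv_*/ggml_gemm_* kernels. A hedged sketch of the call shape; the 34-byte Q8_0 block size (32 int8 quants plus one ggml_half scale), the interleave factor of 8, and the use of nc as the output row stride are assumptions read off the reference fallback in ggml-aarch64.c:

    #include <stdlib.h>
    #include "ggml-aarch64.h"

    // n: values per row (multiple of 32), nr: rows (multiple of 4),
    // nc: output columns (multiple of 8 for the 8x8 kernel)
    void q4_0_8x8_gemm_sketch(const float * act, const void * w_q4_0x8,
                              float * out, int n, int nr, int nc) {
        // 34 bytes per 32 values: 32 int8 quants + one ggml_half scale
        void * act_q8 = malloc((size_t)nr * (n / 32) * 34);
        quantize_mat_q8_0(act, act_q8, nr, n, 8);  // 8 = block interleave for the x8 layout
        ggml_gemm_q4_0_8x8_q8_0(n, out, nc, w_q4_0x8, act_q8, nr, nc);
        free(act_q8);
    }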

+ 1062 - 0
llama/ggml-alloc.c

@@ -0,0 +1,1062 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "ggml-alloc.h"
+#include "ggml-backend-impl.h"
+#include "ggml.h"
+#include "ggml-impl.h"
+#include <assert.h>
+#include <limits.h>
+#include <stdarg.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MAX_FREE_BLOCKS 256
+
+//#define GGML_ALLOCATOR_DEBUG
+
+//#define AT_PRINTF(...) fprintf(stderr, __VA_ARGS__)
+#define AT_PRINTF(...)
+
+
+static bool ggml_is_view(const struct ggml_tensor * t) {
+    return t->view_src != NULL;
+}
+
+static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
+    if (a->type != b->type) {
+        return false;
+    }
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (a->ne[i] != b->ne[i]) {
+            return false;
+        }
+        if (a->nb[i] != b->nb[i]) {
+            return false;
+        }
+    }
+    return true;
+}
+
+static bool ggml_op_can_inplace(enum ggml_op op) {
+    switch (op) {
+        case GGML_OP_SCALE:
+        case GGML_OP_DIAG_MASK_ZERO:
+        case GGML_OP_DIAG_MASK_INF:
+        case GGML_OP_ADD:
+        case GGML_OP_ADD1:
+        case GGML_OP_SUB:
+        case GGML_OP_MUL:
+        case GGML_OP_DIV:
+        case GGML_OP_SQR:
+        case GGML_OP_SQRT:
+        case GGML_OP_LOG:
+        case GGML_OP_UNARY:
+        case GGML_OP_ROPE:
+        case GGML_OP_RMS_NORM:
+        case GGML_OP_SOFT_MAX:
+            return true;
+
+        default:
+            return false;
+    }
+}
+
+static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) {
+    assert(alignment && !(alignment & (alignment - 1))); // power of 2
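+    // bytes needed to round (buffer + offset) up to the next multiple of alignment (0 if already aligned)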
+    size_t align = (alignment - (((uintptr_t)buffer + offset) % alignment)) % alignment;
+    return offset + align;
+}
+
+// tallocr
+
+struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer) {
+    void * base = ggml_backend_buffer_get_base(buffer);
+    size_t align = ggml_backend_buffer_get_alignment(buffer);
+
+    assert(align && !(align & (align - 1))); // power of 2
+
+    struct ggml_tallocr talloc = (struct ggml_tallocr) {
+        /*.buffer    = */ buffer,
+        /*.base      = */ base,
+        /*.alignment = */ align,
+        /*.offset    = */ aligned_offset(base, 0, align),
+    };
+    return talloc;
+}
+
+void ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor) {
+    size_t size = ggml_backend_buffer_get_alloc_size(talloc->buffer, tensor);
+    size = GGML_PAD(size, talloc->alignment);
+
+    if (talloc->offset + size > ggml_backend_buffer_get_size(talloc->buffer)) {
+        fprintf(stderr, "%s: not enough space in the buffer to allocate %s (needed %zu, available %zu)\n",
+                __func__, tensor->name, size, ggml_backend_buffer_get_size(talloc->buffer) - talloc->offset);
+        GGML_ABORT("not enough space in the buffer");
+    }
+
+    void * addr = (char *)ggml_backend_buffer_get_base(talloc->buffer) + talloc->offset;
+    talloc->offset += size;
+
+    assert(((uintptr_t)addr % talloc->alignment) == 0);
+
+    ggml_backend_tensor_alloc(talloc->buffer, tensor, addr);
+}
+
+// dynamic tensor allocator
+
+struct free_block {
+    size_t offset;
+    size_t size;
+};
+
+struct ggml_dyn_tallocr {
+    size_t alignment;
+    int n_free_blocks;
+    struct free_block free_blocks[MAX_FREE_BLOCKS];
+    size_t max_size;
+
+#ifdef GGML_ALLOCATOR_DEBUG
+    struct {
+        const struct ggml_tensor * tensor;
+        size_t offset;
+    } allocated_tensors[1024];
+#endif
+};
+
+#ifdef GGML_ALLOCATOR_DEBUG
+static void add_allocated_tensor(struct ggml_dyn_tallocr * alloc, size_t offset, const struct ggml_tensor * tensor) {
+    for (int i = 0; i < 1024; i++) {
+        if (alloc->allocated_tensors[i].tensor == NULL) {
+            alloc->allocated_tensors[i].tensor = tensor;
+            alloc->allocated_tensors[i].offset = offset;
+            return;
+        }
+    }
+    GGML_ABORT("out of allocated_tensors");
+}
+static void remove_allocated_tensor(struct ggml_dyn_tallocr * alloc, size_t offset, const struct ggml_tensor * tensor) {
+    for (int i = 0; i < 1024; i++) {
+        if (alloc->allocated_tensors[i].offset == offset) {
+            alloc->allocated_tensors[i].tensor = NULL;
+            return;
+        }
+    }
+    GGML_ABORT("tried to free tensor %s not found\n", tensor->name);
+}
+#endif
+
+static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t size, const struct ggml_tensor * tensor) {
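+    // with a NULL base, aligned_offset just rounds size up to a multiple of the allocator's alignment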
+    size = aligned_offset(NULL, size, alloc->alignment);
+
+    AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
+
+    size_t max_avail = 0;
+
+    // find the best fitting free block besides the last block
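+    // (the last block sits at the highest addresses, so carving it grows max_size; keep it as a last resort)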
+    int best_fit_block = -1;
+    size_t best_fit_size = SIZE_MAX;
+    for (int i = 0; i < alloc->n_free_blocks - 1; i++) {
+        struct free_block * block = &alloc->free_blocks[i];
+        max_avail = MAX(max_avail, block->size);
+        if (block->size >= size && block->size <= best_fit_size) {
+            best_fit_block = i;
+            best_fit_size = block->size;
+        }
+    }
+
+    if (best_fit_block == -1) {
+        // the last block is our last resort
+        struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1];
+        max_avail = MAX(max_avail, block->size);
+        if (block->size >= size) {
+            best_fit_block = alloc->n_free_blocks - 1;
+        } else {
+            // this should never happen
+            fprintf(stderr, "%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n",
+                    __func__, size, max_avail);
+            GGML_ABORT("not enough space in the buffer");
+        }
+    }
+
+    struct free_block * block = &alloc->free_blocks[best_fit_block];
+    size_t offset = block->offset;
+    block->offset = offset + size;
+    block->size -= size;
+    if (block->size == 0) {
+        // remove block if empty
+        alloc->n_free_blocks--;
+        for (int j = best_fit_block; j < alloc->n_free_blocks; j++) {
+            alloc->free_blocks[j] = alloc->free_blocks[j+1];
+        }
+    }
+
+    AT_PRINTF("block %d, offset %zu\n", best_fit_block, offset);
+
+#ifdef GGML_ALLOCATOR_DEBUG
+    add_allocated_tensor(alloc, offset, tensor);
+    size_t cur_max = offset + size;
+    if (cur_max > alloc->max_size) {
+        // sort allocated_tensors by offset
+        for (int i = 0; i < 1024; i++) {
+            for (int j = i + 1; j < 1024; j++) {
+                if (alloc->allocated_tensors[i].offset > alloc->allocated_tensors[j].offset) {
+                    const struct ggml_tensor * tmp_tensor = alloc->allocated_tensors[i].tensor;
+                    size_t tmp_offset = alloc->allocated_tensors[i].offset;
+                    alloc->allocated_tensors[i].tensor = alloc->allocated_tensors[j].tensor;
+                    alloc->allocated_tensors[i].offset = alloc->allocated_tensors[j].offset;
+                    alloc->allocated_tensors[j].tensor = tmp_tensor;
+                    alloc->allocated_tensors[j].offset = tmp_offset;
+                }
+            }
+        }
+        fprintf(stderr, "max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
+        for (int i = 0; i < 1024; i++) {
+            if (alloc->allocated_tensors[i].tensor) {
+                fprintf(stderr, "%s [%zx-%zx] (%.2f MB) ", alloc->allocated_tensors[i].tensor->name,
+                    alloc->allocated_tensors[i].offset,
+                    alloc->allocated_tensors[i].offset + ggml_nbytes(alloc->allocated_tensors[i].tensor),
+                    ggml_nbytes(alloc->allocated_tensors[i].tensor) / 1024.0 / 1024.0);
+            }
+        }
+        fprintf(stderr, "\n");
+    }
+#endif
+
+    alloc->max_size = MAX(alloc->max_size, offset + size);
+
+    return offset;
+
+    GGML_UNUSED(tensor);
+}
+
+// this is a very naive implementation, but for our case the number of free blocks should be very small
+static void ggml_dyn_tallocr_free_tensor(struct ggml_dyn_tallocr * alloc, size_t offset, size_t size, const struct ggml_tensor * tensor) {
+    size = aligned_offset(NULL, size, alloc->alignment);
+
+    AT_PRINTF("%s: freeing %s at %zu (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, offset, size, alloc->n_free_blocks);
+
+#ifdef GGML_ALLOCATOR_DEBUG
+    remove_allocated_tensor(alloc, offset, tensor);
+#endif
+
+    // see if we can merge with an existing block
+    for (int i = 0; i < alloc->n_free_blocks; i++) {
+        struct free_block * block = &alloc->free_blocks[i];
+        // check if ptr is at the end of the block
+        if (block->offset + block->size == offset) {
+            block->size += size;
+            // check if we can merge with the next block
+            if (i < alloc->n_free_blocks - 1 && block->offset + block->size == alloc->free_blocks[i+1].offset) {
+                block->size += alloc->free_blocks[i+1].size;
+                alloc->n_free_blocks--;
+                for (int j = i+1; j < alloc->n_free_blocks; j++) {
+                    alloc->free_blocks[j] = alloc->free_blocks[j+1];
+                }
+            }
+            return;
+        }
+        // check if ptr is at the beginning of the block
+        if (offset + size == block->offset) {
+            block->offset = offset;
+            block->size += size;
+            // check if we can merge with the previous block
+            if (i > 0 && alloc->free_blocks[i-1].offset + alloc->free_blocks[i-1].size == block->offset) {
+                alloc->free_blocks[i-1].size += block->size;
+                alloc->n_free_blocks--;
+                for (int j = i; j < alloc->n_free_blocks; j++) {
+                    alloc->free_blocks[j] = alloc->free_blocks[j+1];
+                }
+            }
+            return;
+        }
+    }
+    // otherwise, add a new block
+    GGML_ASSERT(alloc->n_free_blocks < MAX_FREE_BLOCKS && "out of free blocks");
+    // insert the new block in the correct position to keep the array sorted by address (to make merging blocks faster)
+    int insert_pos = 0;
+    while (insert_pos < alloc->n_free_blocks && alloc->free_blocks[insert_pos].offset < offset) {
+        insert_pos++;
+    }
+    // shift all blocks from insert_pos onward to make room for the new block
+    for (int i = alloc->n_free_blocks; i > insert_pos; i--) {
+        alloc->free_blocks[i] = alloc->free_blocks[i-1];
+    }
+    // insert the new block
+    alloc->free_blocks[insert_pos].offset = offset;
+    alloc->free_blocks[insert_pos].size = size;
+    alloc->n_free_blocks++;
+
+    GGML_UNUSED(tensor);
+}
+
+static void ggml_dyn_tallocr_reset(struct ggml_dyn_tallocr * alloc) {
+    alloc->n_free_blocks = 1;
+    alloc->free_blocks[0].offset = 0;
+    alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
+    alloc->max_size = 0;
+}
+
+static struct ggml_dyn_tallocr * ggml_dyn_tallocr_new(size_t alignment) {
+    struct ggml_dyn_tallocr * alloc = (struct ggml_dyn_tallocr *)malloc(sizeof(struct ggml_dyn_tallocr));
+
+    *alloc = (struct ggml_dyn_tallocr) {
+        /*.alignment     = */ alignment,
+        /*.n_free_blocks = */ 0,
+        /*.free_blocks   = */ {{0}},
+        /*.max_size      = */ 0,
+#ifdef GGML_ALLOCATOR_DEBUG
+        /*.allocated_tensors = */ {{0}},
+#endif
+    };
+
+    ggml_dyn_tallocr_reset(alloc);
+
+    return alloc;
+}
+
+static void ggml_dyn_tallocr_free(struct ggml_dyn_tallocr * alloc) {
+    free(alloc);
+}
+
+static size_t ggml_dyn_tallocr_max_size(struct ggml_dyn_tallocr * alloc) {
+    return alloc->max_size;
+}
+
+
+/////////////////////////////////////
+
+// graph allocator
+
+struct hash_node {
+    int n_children;
+    int n_views;
+    int buffer_id;
+    size_t offset; // offset within the buffer
+    bool allocated;
+};
+
+struct tensor_alloc {
+    int buffer_id;
+    size_t offset;
+    size_t size_max; // 0 = pre-allocated, unused, or view
+};
+
+struct leaf_alloc {
+    int buffer_id;
+    struct tensor_alloc leaf;
+};
+
+struct node_alloc {
+    struct tensor_alloc dst;
+    struct tensor_alloc src[GGML_MAX_SRC];
+};
+
+struct ggml_gallocr {
+    ggml_backend_buffer_type_t * bufts; // [n_buffers]
+    ggml_backend_buffer_t * buffers; // [n_buffers]
+    struct ggml_dyn_tallocr ** buf_tallocs; // [n_buffers]
+    int n_buffers;
+
+    struct ggml_hash_set hash_set;
+    struct hash_node * hash_values; // [hash_set.size]
+
+    struct node_alloc * node_allocs; // [n_nodes]
+    int n_nodes;
+
+    struct leaf_alloc * leaf_allocs; // [n_leafs]
+    int n_leafs;
+};
+
+ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs) {
+    ggml_gallocr_t galloc = (ggml_gallocr_t)calloc(1, sizeof(struct ggml_gallocr));
+    GGML_ASSERT(galloc != NULL);
+
+    galloc->bufts = calloc(n_bufs, sizeof(ggml_backend_buffer_type_t));
+    GGML_ASSERT(galloc->bufts != NULL);
+
+    galloc->buffers = calloc(n_bufs, sizeof(ggml_backend_buffer_t));
+    GGML_ASSERT(galloc->buffers != NULL);
+
+    galloc->buf_tallocs = calloc(n_bufs, sizeof(struct ggml_dyn_tallocr *));
+    GGML_ASSERT(galloc->buf_tallocs != NULL);
+
+    for (int i = 0; i < n_bufs; i++) {
+        galloc->bufts[i] = bufts[i];
+        galloc->buffers[i] = NULL;
+
+        // check if the same buffer type is used multiple times and reuse the same allocator
+        for (int j = 0; j < i; j++) {
+            if (bufts[i] == bufts[j]) {
+                galloc->buf_tallocs[i] = galloc->buf_tallocs[j];
+                break;
+            }
+        }
+
+        if (galloc->buf_tallocs[i] == NULL) {
+            size_t alignment = ggml_backend_buft_get_alignment(bufts[i]);
+            galloc->buf_tallocs[i] = ggml_dyn_tallocr_new(alignment);
+        }
+    }
+    galloc->n_buffers = n_bufs;
+
+    return galloc;
+}
+
+ggml_gallocr_t ggml_gallocr_new(ggml_backend_buffer_type_t buft) {
+    return ggml_gallocr_new_n(&buft, 1);
+}
+
+void ggml_gallocr_free(ggml_gallocr_t galloc) {
+    if (galloc == NULL) {
+        return;
+    }
+
+    for (int i = 0; i < galloc->n_buffers; i++) {
+        if (galloc->buffers != NULL) {
+            // skip if already freed
+            bool freed = false;
+            for (int j = 0; j < i; j++) {
+                if (galloc->buffers[j] == galloc->buffers[i]) {
+                    freed = true;
+                    break;
+                }
+            }
+            if (!freed) {
+                ggml_backend_buffer_free(galloc->buffers[i]);
+            }
+        }
+        if (galloc->buf_tallocs != NULL) {
+            // skip if already freed
+            bool freed = false;
+            for (int j = 0; j < i; j++) {
+                if (galloc->buf_tallocs[j] == galloc->buf_tallocs[i]) {
+                    freed = true;
+                    break;
+                }
+            }
+            if (!freed) {
+                ggml_dyn_tallocr_free(galloc->buf_tallocs[i]);
+            }
+        }
+    }
+
+    ggml_hash_set_free(&galloc->hash_set);
+    free(galloc->hash_values);
+    free(galloc->bufts);
+    free(galloc->buffers);
+    free(galloc->buf_tallocs);
+    free(galloc->node_allocs);
+    free(galloc->leaf_allocs);
+    free(galloc);
+}
+
+typedef struct ggml_gallocr * ggml_gallocr_t;
+
+static struct hash_node * ggml_gallocr_hash_get(ggml_gallocr_t galloc, struct ggml_tensor * t) {
+    size_t i = ggml_hash_find_or_insert(&galloc->hash_set, t);
+    return &galloc->hash_values[i];
+}
+
+static bool ggml_gallocr_is_own(ggml_gallocr_t galloc, struct ggml_tensor * t) {
+    return ggml_gallocr_hash_get(galloc, t)->allocated;
+}
+
+static void ggml_gallocr_set_node_offset(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id, size_t offset) {
+    struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
+    hn->buffer_id = buffer_id;
+    hn->offset = offset;
+    hn->allocated = true;
+}
+
+static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor * t) {
+    return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated;
+}
+
+static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) {
+    struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
+
+    if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) {
+        hn->allocated = true;
+        assert(hn->offset == 0);
+
+        // try to reuse a parent's buffer (inplace)
+        if (ggml_op_can_inplace(node->op)) {
+            for (int i = 0; i < GGML_MAX_SRC; i++) {
+                struct ggml_tensor * parent = node->src[i];
+                if (parent == NULL) {
+                    continue;
+                }
+
+                // if the node's data is external, then we cannot re-use it
+                if (!ggml_gallocr_is_own(galloc, parent)) {
+                    AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
+                    continue;
+                }
+
+                // outputs cannot be reused
+                if (parent->flags & GGML_TENSOR_FLAG_OUTPUT || (parent->view_src != NULL && parent->view_src->flags & GGML_TENSOR_FLAG_OUTPUT)) {
+                    AT_PRINTF("not reusing parent %s for %s as it is an output\n", parent->name, node->name);
+                    continue;
+                }
+
+                if (!ggml_are_same_layout(node, parent)) {
+                    AT_PRINTF("not reusing parent %s for %s as layouts are different\n", parent->name, node->name);
+                    continue;
+                }
+
+                struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
+                if (p_hn->n_children == 1 && p_hn->n_views == 0) {
+                    if (ggml_is_view(parent)) {
+                        struct ggml_tensor * view_src = parent->view_src;
+                        struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
+                        if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
+                            AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
+                            assert(view_src_hn->offset == p_hn->offset);
+                            hn->buffer_id = p_hn->buffer_id;
+                            hn->offset = p_hn->offset;
+                            p_hn->allocated = false; // avoid freeing the parent
+                            view_src_hn->allocated = false;
+                            return;
+                        }
+                    } else {
+                        AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
+                        hn->buffer_id = p_hn->buffer_id;
+                        hn->offset = p_hn->offset;
+                        p_hn->allocated = false; // avoid freeing the parent
+                        return;
+                    }
+                }
+            }
+        }
+        // allocate tensor from the buffer
+        struct ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id];
+        ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id];
+        size_t size = ggml_backend_buft_get_alloc_size(buft, node);
+        size_t offset = ggml_dyn_tallocr_alloc(alloc, size, node);
+        hn->buffer_id = buffer_id;
+        hn->offset = offset;
+        return;
+    }
+}
+
+static void ggml_gallocr_free_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {
+    // graph outputs are never freed
+    if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
+        AT_PRINTF("not freeing output %s\n", node->name);
+        return;
+    }
+
+    struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
+    size_t offset = hn->offset;
+    int buffer_id = hn->buffer_id;
+    struct ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id];
+    ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id];
+    size_t size = ggml_backend_buft_get_alloc_size(buft, node);
+    ggml_dyn_tallocr_free_tensor(alloc, offset, size, node);
+    hn->allocated = false;
+}
+
+static int get_node_buffer_id(const int * node_buffer_ids, int i) {
+    return node_buffer_ids ? node_buffer_ids[i] : 0;
+}
+
+static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids) {
+    // clear hash tables
+    ggml_hash_set_reset(&galloc->hash_set);
+    memset(galloc->hash_values, 0, sizeof(struct hash_node) * galloc->hash_set.size);
+
+    // allocate leafs
+    // these may be tensors that the application is not using in the graph, but may still want to allocate for other purposes
+    for (int i = 0; i < graph->n_leafs; i++) {
+        struct ggml_tensor * leaf = graph->leafs[i];
+        ggml_gallocr_allocate_node(galloc, leaf, get_node_buffer_id(leaf_buffer_ids, i));
+    }
+
+    // count number of children and views
+    // allocate other graph inputs and leafs first to avoid overwriting them
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+
+        // TODO: better way to add external dependencies
+        // GGML_OP_NONE does not appear normally in the graph nodes, but is used by ggml-backend to add dependencies to
+        // control when some tensors are allocated and freed. in this case, the dependencies are in `src`, but the node
+        // itself is never used and should not be considered a dependency
+        if (ggml_is_view(node) && node->op != GGML_OP_NONE) {
+            struct ggml_tensor * view_src = node->view_src;
+            ggml_gallocr_hash_get(galloc, view_src)->n_views += 1;
+        }
+
+        if (node->flags & GGML_TENSOR_FLAG_INPUT) {
+            ggml_gallocr_allocate_node(galloc, graph->nodes[i], get_node_buffer_id(node_buffer_ids, i));
+        }
+
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * src = node->src[j];
+            if (src == NULL) {
+                continue;
+            }
+
+            ggml_gallocr_hash_get(galloc, src)->n_children += 1;
+
+            // allocate explicit inputs
+            if (src->flags & GGML_TENSOR_FLAG_INPUT) {
+                ggml_gallocr_allocate_node(galloc, src, get_node_buffer_id(node_buffer_ids, i));
+            }
+        }
+    }
+
+    // allocate tensors
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        int buffer_id = get_node_buffer_id(node_buffer_ids, i);
+
+        // allocate parents (only leafs need to be allocated at this point)
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * parent = node->src[j];
+            if (parent == NULL) {
+                continue;
+            }
+            ggml_gallocr_allocate_node(galloc, parent, buffer_id);
+        }
+
+        // allocate node
+        ggml_gallocr_allocate_node(galloc, node, buffer_id);
+
+        AT_PRINTF("exec: %s (%s) <= ", ggml_op_desc(node), node->name);
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * parent = node->src[j];
+            if (parent == NULL) {
+                continue;
+            }
+            AT_PRINTF("%s", parent->name);
+            if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
+                AT_PRINTF(", ");
+            }
+        }
+        AT_PRINTF("\n");
+
+        // update parents
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * parent = node->src[j];
+            if (parent == NULL) {
+                continue;
+            }
+            struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
+            p_hn->n_children -= 1;
+
+            AT_PRINTF("parent %s: %d children, %d views, allocated: %d\n",
+                parent->name, p_hn->n_children, p_hn->n_views, p_hn->allocated);
+
+            if (p_hn->n_children == 0 && p_hn->n_views == 0) {
+                if (ggml_is_view(parent)) {
+                    struct ggml_tensor * view_src = parent->view_src;
+                    struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
+                    view_src_hn->n_views -= 1;
+                    AT_PRINTF("view_src %s: %d children, %d views\n",
+                        view_src->name, view_src_hn->n_children, view_src_hn->n_views);
+                    if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src_hn->allocated) {
+                        ggml_gallocr_free_node(galloc, view_src);
+                    }
+                }
+                else if (p_hn->allocated) {
+                    ggml_gallocr_free_node(galloc, parent);
+                }
+            }
+            AT_PRINTF("\n");
+        }
+    }
+}
+
+bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids) {
+    size_t min_hash_size = graph->n_nodes + graph->n_leafs;
+    // add 25% margin to avoid hash collisions
+    min_hash_size += min_hash_size / 4;
+
+    // initialize hash table
+    if (galloc->hash_set.size < min_hash_size) {
+        ggml_hash_set_free(&galloc->hash_set);
+        galloc->hash_set = ggml_hash_set_new(min_hash_size);
+        GGML_ASSERT(galloc->hash_set.keys != NULL);
+
+        free(galloc->hash_values);
+        galloc->hash_values = malloc(sizeof(struct hash_node) * galloc->hash_set.size);
+        GGML_ASSERT(galloc->hash_values != NULL);
+    }
+
+    // reset allocators
+    for (int i = 0; i < galloc->n_buffers; i++) {
+        ggml_dyn_tallocr_reset(galloc->buf_tallocs[i]);
+    }
+
+    // allocate in hash table
+    ggml_gallocr_alloc_graph_impl(galloc, graph, node_buffer_ids, leaf_buffer_ids);
+
+    // set the node_allocs from the hash table
+    if (galloc->n_nodes < graph->n_nodes) {
+        free(galloc->node_allocs);
+        galloc->node_allocs = calloc(graph->n_nodes, sizeof(struct node_alloc));
+        GGML_ASSERT(galloc->node_allocs != NULL);
+    }
+    galloc->n_nodes = graph->n_nodes;
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        struct node_alloc * node_alloc = &galloc->node_allocs[i];
+        if (node->view_src || node->data) {
+            node_alloc->dst.buffer_id = -1;
+            node_alloc->dst.offset = SIZE_MAX;
+            node_alloc->dst.size_max = 0;
+        } else {
+            struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
+            node_alloc->dst.buffer_id = hn->buffer_id;
+            node_alloc->dst.offset    = hn->offset;
+            node_alloc->dst.size_max  = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node);
+        }
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * src = node->src[j];
+            if (!src || src->view_src || src->data) {
+                node_alloc->src[j].buffer_id = -1;
+                node_alloc->src[j].offset = SIZE_MAX;
+                node_alloc->src[j].size_max = 0;
+            } else {
+                struct hash_node * hn = ggml_gallocr_hash_get(galloc, src);
+                node_alloc->src[j].buffer_id = hn->buffer_id;
+                node_alloc->src[j].offset   = hn->offset;
+                node_alloc->src[j].size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], src);
+            }
+        }
+    }
+    if (galloc->n_leafs < graph->n_leafs) {
+        free(galloc->leaf_allocs);
+        galloc->leaf_allocs = calloc(graph->n_leafs, sizeof(galloc->leaf_allocs[0]));
+        GGML_ASSERT(galloc->leaf_allocs != NULL);
+    }
+    galloc->n_leafs = graph->n_leafs;
+    for (int i = 0; i < graph->n_leafs; i++) {
+        struct ggml_tensor * leaf = graph->leafs[i];
+        struct hash_node * hn = ggml_gallocr_hash_get(galloc, leaf);
+        galloc->leaf_allocs[i].buffer_id = hn->buffer_id;
+        if (leaf->view_src || leaf->data) {
+            galloc->leaf_allocs[i].leaf.buffer_id = -1;
+            galloc->leaf_allocs[i].leaf.offset = SIZE_MAX;
+            galloc->leaf_allocs[i].leaf.size_max = 0;
+        } else {
+            galloc->leaf_allocs[i].leaf.buffer_id = hn->buffer_id;
+            galloc->leaf_allocs[i].leaf.offset = hn->offset;
+            galloc->leaf_allocs[i].leaf.size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], leaf);
+        }
+    }
+
+    // reallocate buffers if needed
+    for (int i = 0; i < galloc->n_buffers; i++) {
+        // if the buffer type is used multiple times, we reuse the same buffer
+        for (int j = 0; j < i; j++) {
+            if (galloc->buf_tallocs[j] == galloc->buf_tallocs[i]) {
+                galloc->buffers[i] = galloc->buffers[j];
+                break;
+            }
+        }
+
+        size_t cur_size = galloc->buffers[i] ? ggml_backend_buffer_get_size(galloc->buffers[i]) : 0;
+        size_t new_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i]);
+
+        // even if there are no tensors allocated in this buffer, we still need to allocate it to initialize views
+        if (new_size > cur_size || galloc->buffers[i] == NULL) {
+#ifndef NDEBUG
+            fprintf(stderr, "%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
+#endif
+
+            ggml_backend_buffer_free(galloc->buffers[i]);
+            galloc->buffers[i] = ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size);
+            if (galloc->buffers[i] == NULL) {
+                fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size);
+                return false;
+            }
+            ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
+        }
+    }
+
+    return true;
+}
+
+bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) {
+    return ggml_gallocr_reserve_n(galloc, graph, NULL, NULL);
+}
+
+static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor * tensor, struct tensor_alloc * tensor_alloc) {
+    int buffer_id = tensor_alloc->buffer_id;
+    assert(tensor->data || tensor->view_src || ggml_backend_buffer_get_alloc_size(galloc->buffers[buffer_id], tensor) <= tensor_alloc->size_max);
+
+    if (tensor->view_src != NULL) {
+        if (tensor->buffer == NULL) {
+            assert(tensor_alloc->offset == SIZE_MAX);
+            if (tensor->view_src->buffer == NULL) {
+                // this tensor was allocated without ggml-backend
+                return;
+            }
+            ggml_backend_view_init(tensor);
+        }
+    } else {
+        if (tensor->data == NULL) {
+            assert(tensor_alloc->offset != SIZE_MAX);
+            assert(ggml_backend_buffer_get_alloc_size(galloc->buffers[buffer_id], tensor) <= tensor_alloc->size_max);
+            void * base = ggml_backend_buffer_get_base(galloc->buffers[buffer_id]);
+            void * addr = (char *)base + tensor_alloc->offset;
+            ggml_backend_tensor_alloc(galloc->buffers[buffer_id], tensor, addr);
+        } else {
+            if (tensor->buffer == NULL) {
+                // this tensor was allocated without ggml-backend
+                return;
+            }
+        }
+    }
+}
+
+static bool ggml_gallocr_node_needs_realloc(ggml_gallocr_t galloc, struct ggml_tensor * node, struct tensor_alloc * talloc) {
+    size_t node_size = (node->data || node->view_src) ? 0 : ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node);
+    return talloc->size_max >= node_size;
+}
+
+static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph * graph) {
+    if (galloc->n_nodes != graph->n_nodes) {
+#ifndef NDEBUG
+        fprintf(stderr, "%s: graph has different number of nodes\n", __func__);
+#endif
+        return true;
+    }
+
+    if (galloc->n_leafs != graph->n_leafs) {
+#ifndef NDEBUG
+        fprintf(stderr, "%s: graph has different number of leafs\n", __func__);
+#endif
+        return true;
+    }
+
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        struct node_alloc * node_alloc = &galloc->node_allocs[i];
+
+        if (!ggml_gallocr_node_needs_realloc(galloc, node, &node_alloc->dst)) {
+#ifndef NDEBUG
+            fprintf(stderr, "%s: node %s is not valid\n", __func__, node->name);
+#endif
+            return true;
+        }
+
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * src = node->src[j];
+            if (src == NULL) {
+                continue;
+            }
+            if (!ggml_gallocr_node_needs_realloc(galloc, src, &node_alloc->src[j])) {
+#ifndef NDEBUG
+                fprintf(stderr, "%s: src %d (%s) of node %s is not valid\n", __func__, j, src->name, node->name);
+#endif
+                return true;
+            }
+        }
+    }
+
+    return false;
+}
+
+bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph) {
+    if (ggml_gallocr_needs_realloc(galloc, graph)) {
+        if (galloc->n_buffers == 1) {
+#ifndef NDEBUG
+            fprintf(stderr, "%s: reallocating buffers automatically\n", __func__);
+#endif
+            if (!ggml_gallocr_reserve(galloc, graph)) {
+                return false;
+            }
+        } else {
+#ifndef NDEBUG
+            fprintf(stderr, "%s: cannot reallocate multi buffer graph automatically, call reserve\n", __func__);
+#endif
+            return false;
+        }
+    }
+
+    // reset buffers
+    for (int i = 0; i < galloc->n_buffers; i++) {
+        if (galloc->buffers[i] != NULL) {
+            ggml_backend_buffer_reset(galloc->buffers[i]);
+        }
+    }
+
+    // allocate the graph tensors from the previous assignments
+    // leafs
+    for (int i = 0; i < graph->n_leafs; i++) {
+        struct ggml_tensor * leaf = graph->leafs[i];
+        struct leaf_alloc * leaf_alloc = &galloc->leaf_allocs[i];
+        ggml_gallocr_init_tensor(galloc, leaf, &leaf_alloc->leaf);
+    }
+    // nodes
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        struct node_alloc * node_alloc = &galloc->node_allocs[i];
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * src = node->src[j];
+            if (src == NULL) {
+                continue;
+            }
+            ggml_gallocr_init_tensor(galloc, src, &node_alloc->src[j]);
+        }
+        ggml_gallocr_init_tensor(galloc, node, &node_alloc->dst);
+    }
+
+    return true;
+}
+
+size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) {
+    GGML_ASSERT(buffer_id >= 0 && buffer_id < galloc->n_buffers);
+
+    if (galloc->buffers[buffer_id] == NULL) {
+        return 0;
+    }
+
+    for (int i = 0; i < buffer_id; i++) {
+        if (galloc->buffers[i] == galloc->buffers[buffer_id]) {
+            // this buffer is the same as a previous one due to the same buffer type being used multiple times
+            // only return the buffer size the first time it appears to avoid double counting
+            return 0;
+        }
+    }
+
+    return ggml_backend_buffer_get_size(galloc->buffers[buffer_id]);
+}
+
+// utils
+
+static bool alloc_tensor_range(struct ggml_context * ctx,
+        struct ggml_tensor * first, struct ggml_tensor * last,
+        ggml_backend_buffer_type_t buft, size_t size,
+        ggml_backend_buffer_t ** buffers, size_t * n_buffers) {
+    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
+    if (buffer == NULL) {
+#ifndef NDEBUG
+        fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(buft), size);
+#endif
+        for (size_t i = 0; i < *n_buffers; i++) {
+            ggml_backend_buffer_free((*buffers)[i]);
+        }
+        free(*buffers);
+        return false;
+    }
+
+    struct ggml_tallocr tallocr = ggml_tallocr_new(buffer);
+
+    for (struct ggml_tensor * t = first; t != last; t = ggml_get_next_tensor(ctx, t)) {
+        if (t->data == NULL) {
+            if (t->view_src == NULL) {
+                ggml_tallocr_alloc(&tallocr, t);
+            } else if (t->buffer == NULL) {
+                ggml_backend_view_init(t);
+            }
+        } else {
+            if (t->view_src != NULL && t->buffer == NULL) {
+                // view of a pre-allocated tensor
+                ggml_backend_view_init(t);
+            }
+        }
+    }
+
+    *buffers = realloc(*buffers, sizeof(ggml_backend_buffer_t) * (*n_buffers + 1));
+    (*buffers)[(*n_buffers)++] = buffer;
+
+    return true;
+}
+
+ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
+    GGML_ASSERT(ggml_get_no_alloc(ctx) == true);
+
+    size_t alignment = ggml_backend_buft_get_alignment(buft);
+    size_t max_size = ggml_backend_buft_get_max_size(buft);
+
+    ggml_backend_buffer_t * buffers = NULL;
+    size_t n_buffers = 0;
+
+    size_t cur_buf_size = 0;
+    struct ggml_tensor * first = ggml_get_first_tensor(ctx);
+    for (struct ggml_tensor * t = first; t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+        size_t this_size = 0;
+        if (t->data == NULL && t->view_src == NULL) {
+            this_size = GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment);
+        }
+
+        if (this_size > max_size) {
+            fprintf(stderr, "%s: tensor %s is too large to fit in a %s buffer (tensor size: %zu, max buffer size: %zu)\n",
+                    __func__, t->name,
+                    ggml_backend_buft_name(buft),
+                    this_size, max_size);
+            for (size_t i = 0; i < n_buffers; i++) {
+                ggml_backend_buffer_free(buffers[i]);
+            }
+            free(buffers);
+            return NULL;
+        }
+
+        if ((cur_buf_size + this_size) > max_size) {
+            // allocate tensors in the current buffer
+            if (!alloc_tensor_range(ctx, first, t, buft, cur_buf_size, &buffers, &n_buffers)) {
+                return NULL;
+            }
+            first = t;
+            cur_buf_size = this_size;
+        } else {
+            cur_buf_size += this_size;
+        }
+    }
+
+    // allocate remaining tensors
+    if (cur_buf_size > 0) {
+        if (!alloc_tensor_range(ctx, first, NULL, buft, cur_buf_size, &buffers, &n_buffers)) {
+            return NULL;
+        }
+    }
+
+    if (n_buffers == 0) {
+#ifndef NDEBUG
+        fprintf(stderr, "%s: all tensors in the context are already allocated\n", __func__);
+#endif
+        return NULL;
+    }
+
+    ggml_backend_buffer_t buffer;
+    if (n_buffers == 1) {
+        buffer = buffers[0];
+    } else {
+        buffer = ggml_backend_multi_buffer_alloc_buffer(buffers, n_buffers);
+    }
+    free(buffers);
+    return buffer;
+}
+
+ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) {
+    return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend));
+}
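
The utils at the end of the file are the usual entry point for weight allocation: build a no_alloc context, define the tensors, then hand the whole context to ggml_backend_alloc_ctx_tensors. A hedged sketch with placeholder shapes (the context still owns the tensor metadata and must be freed by the caller after the buffer is no longer needed):

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"

    static ggml_backend_buffer_t alloc_weights(ggml_backend_t backend) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 2 * ggml_tensor_overhead(),
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true,  // metadata only; data lives in the backend buffer
        };
        struct ggml_context * ctx = ggml_init(params);
        ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 4096);  // placeholder shapes
        ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4096);
        // splits into a multi-buffer automatically if the backend's
        // max buffer size would be exceeded
        return ggml_backend_alloc_ctx_tensors(ctx, backend);
    }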

+ 102 - 0
llama/ggml-alloc.h

@@ -0,0 +1,102 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "ggml.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
+typedef struct      ggml_backend_buffer * ggml_backend_buffer_t;
+typedef struct             ggml_backend * ggml_backend_t;
+
+// Tensor allocator
+struct ggml_tallocr {
+    ggml_backend_buffer_t buffer;
+    void * base;
+    size_t alignment;
+    size_t offset;
+};
+
+GGML_API struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer);
+GGML_API void                ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor);
+
+// Graph allocator
+/*
+  Example usage:
+    ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
+
+    // optional: create a worst-case graph and reserve the buffers to avoid reallocations
+    ggml_gallocr_reserve(galloc, build_graph(max_batch));
+
+    // allocate the graph
+    struct ggml_cgraph * graph = build_graph(batch);
+    ggml_gallocr_alloc_graph(galloc, graph);
+
+    printf("compute buffer size: %zu bytes\n", ggml_gallocr_get_buffer_size(galloc, 0));
+
+    // evaluate the graph
+    ggml_backend_graph_compute(backend, graph);
+*/
+
+// special tensor flags for use with the graph allocator:
+//   ggml_set_input(): all input tensors are allocated at the beginning of the graph in non-overlapping addresses
+//   ggml_set_output(): output tensors are never freed and never overwritten
+
+typedef struct ggml_gallocr * ggml_gallocr_t;
+
+GGML_API ggml_gallocr_t ggml_gallocr_new(ggml_backend_buffer_type_t buft);
+GGML_API ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs);
+GGML_API void           ggml_gallocr_free(ggml_gallocr_t galloc);
+
+// pre-allocate buffers from a measure graph - does not allocate or modify the graph
+// call with a worst-case graph to avoid buffer reallocations
+// not strictly required for single buffer usage: ggml_gallocr_alloc_graph will reallocate the buffers automatically if needed
+// returns false if the buffer allocation failed
+GGML_API bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
+GGML_API bool ggml_gallocr_reserve_n(
+    ggml_gallocr_t galloc,
+    struct ggml_cgraph * graph,
+    const int * node_buffer_ids,
+    const int * leaf_buffer_ids);
+
+// automatic reallocation if the topology changes when using a single buffer
+// returns false if using multiple buffers and a re-allocation is needed (call ggml_gallocr_reserve_n first to set the node buffers)
+GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
+
+GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id);
+
+// Utils
+// Create a buffer and allocate all the tensors in a ggml_context
+GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
+GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend);
+
+#ifdef  __cplusplus
+}
+#endif
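
A compilable version of the usage pattern in the header comment above, reserving once with a worst-case graph and then allocating per batch (build_graph stands in for application code):

    #include <stdio.h>
    #include "ggml-alloc.h"
    #include "ggml-backend.h"

    struct ggml_cgraph * build_graph(int n_batch);  // application-provided

    void run(ggml_backend_t backend, int max_batch, int batch) {
        ggml_gallocr_t galloc =
            ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));

        // reserve with the worst case once to avoid mid-run reallocations
        ggml_gallocr_reserve(galloc, build_graph(max_batch));

        struct ggml_cgraph * graph = build_graph(batch);
        ggml_gallocr_alloc_graph(galloc, graph);
        printf("compute buffer size: %zu bytes\n",
               ggml_gallocr_get_buffer_size(galloc, 0));

        ggml_backend_graph_compute(backend, graph);
        ggml_gallocr_free(galloc);
    }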

+ 179 - 0
llama/ggml-backend-impl.h

@@ -0,0 +1,179 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+// ggml-backend internal header
+
+#include "ggml-backend.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+    //
+    // Backend buffer
+    //
+
+    // buffer type
+    typedef void * ggml_backend_buffer_type_context_t;
+
+    struct ggml_backend_buffer_type_i {
+        const char *          (*GGML_CALL get_name)        (ggml_backend_buffer_type_t buft);
+        // allocate a buffer of this type
+        ggml_backend_buffer_t (*GGML_CALL alloc_buffer)    (ggml_backend_buffer_type_t buft, size_t size);
+        // tensor alignment
+        size_t                (*GGML_CALL get_alignment)   (ggml_backend_buffer_type_t buft);
+        // max buffer size that can be allocated
+        size_t                (*GGML_CALL get_max_size)    (ggml_backend_buffer_type_t buft);
+        // data size needed to allocate the tensor, including padding
+        size_t                (*GGML_CALL get_alloc_size)  (ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);
+        // check if tensor data is in host memory
+        bool                  (*GGML_CALL is_host)         (ggml_backend_buffer_type_t buft);
+    };
+
+    struct ggml_backend_buffer_type {
+        struct ggml_backend_buffer_type_i  iface;
+        ggml_backend_buffer_type_context_t context;
+    };
+
+    // buffer
+    typedef void * ggml_backend_buffer_context_t;
+
+    struct ggml_backend_buffer_i {
+        const char * (*GGML_CALL get_name)   (ggml_backend_buffer_t buffer);
+        void         (*GGML_CALL free_buffer)(ggml_backend_buffer_t buffer);
+        void *       (*GGML_CALL get_base)   (ggml_backend_buffer_t buffer);
+        void         (*GGML_CALL init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+        void         (*GGML_CALL set_tensor) (ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+        void         (*GGML_CALL get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+        bool         (*GGML_CALL cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer
+        void         (*GGML_CALL clear)      (ggml_backend_buffer_t buffer, uint8_t value);
+        void         (*GGML_CALL reset)      (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
+    };
+
+    struct ggml_backend_buffer {
+        struct ggml_backend_buffer_i  iface;
+        ggml_backend_buffer_type_t    buft;
+        ggml_backend_buffer_context_t context;
+        size_t size;
+        enum ggml_backend_buffer_usage usage;
+    };
+
+    GGML_CALL ggml_backend_buffer_t ggml_backend_buffer_init(
+                   ggml_backend_buffer_type_t      buft,
+            struct ggml_backend_buffer_i           iface,
+                   ggml_backend_buffer_context_t   context,
+                   size_t                          size);
+
+    // do not use directly, use ggml_backend_tensor_copy instead
+    bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst);
+
+    // buffer that contains a collection of buffers
+    GGML_CALL ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers);
+    GGML_CALL bool                  ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer);
+    GGML_CALL void                  ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
+
+    //
+    // Backend
+    //
+
+    typedef void * ggml_backend_context_t;
+
+    struct ggml_backend_i {
+        const char * (*GGML_CALL get_name)(ggml_backend_t backend);
+
+        void (*GGML_CALL free)(ggml_backend_t backend);
+
+        // buffer allocation
+        ggml_backend_buffer_type_t (*GGML_CALL get_default_buffer_type)(ggml_backend_t backend);
+
+        // (optional) asynchronous tensor data access
+        void (*GGML_CALL set_tensor_async)(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+        void (*GGML_CALL get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+        bool (*GGML_CALL cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst);
+
+        // (optional) complete all pending operations
+        void (*GGML_CALL synchronize)(ggml_backend_t backend);
+
+        // compute graph with a plan (not used currently)
+        // create a new plan for a graph
+        ggml_backend_graph_plan_t (*GGML_CALL graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph);
+        void                      (*GGML_CALL graph_plan_free)   (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+        // update the plan with a new graph - this should be faster than creating a new plan when the graph has the same topology
+        void                      (*GGML_CALL graph_plan_update) (ggml_backend_t backend, ggml_backend_graph_plan_t plan, const struct ggml_cgraph * cgraph);
+        // compute the graph with the plan
+        enum ggml_status          (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+
+        // compute graph without a plan (async)
+        enum ggml_status (*GGML_CALL graph_compute)     (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+
+        // check if the backend can compute an operation
+        bool (*GGML_CALL supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
+
+        // check if the backend can use tensors allocated in a buffer type
+        bool (*GGML_CALL supports_buft)(ggml_backend_t backend, ggml_backend_buffer_type_t buft);
+
+        // check if the backend wants to run an operation, even if the weights are allocated in a CPU buffer
+        // these should be expensive operations with large batch sizes that may benefit from running on this backend
+        // even if the weight has to be copied from the CPU temporarily
+        bool (*GGML_CALL offload_op)(ggml_backend_t backend, const struct ggml_tensor * op);
+
+        // (optional) event synchronization
+        // create a new event that can record events on this backend instance
+        ggml_backend_event_t (*GGML_CALL event_new)         (ggml_backend_t backend);
+        void                 (*GGML_CALL event_free)        (ggml_backend_event_t event);
+        // record an event on the backend instance that created it
+        void                 (*GGML_CALL event_record)      (ggml_backend_event_t event);
+        // wait for an event on a different backend instance
+        void                 (*GGML_CALL event_wait)        (ggml_backend_t backend, ggml_backend_event_t event);
+        // block until an event is recorded
+        void                 (*GGML_CALL event_synchronize) (ggml_backend_event_t event);
+    };
+
+    struct ggml_backend {
+        ggml_guid_t guid;
+
+        struct ggml_backend_i iface;
+        ggml_backend_context_t context;
+    };
+
+    struct ggml_backend_event {
+        ggml_backend_t backend;
+        void * context;
+    };
+
+    //
+    // Backend registry
+    //
+
+    typedef ggml_backend_t (*GGML_CALL ggml_backend_init_fn)(const char * params, void * user_data);
+
+    GGML_CALL void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data);
+
+#ifdef  __cplusplus
+}
+#endif
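
For orientation, here is a minimal sketch of how a backend could implement the buffer-type and buffer interfaces declared above. It mirrors the CPU implementation in `ggml-backend.c` below; every name prefixed with `example_` is hypothetical and not part of the vendored file, and `GGML_CALL` is omitted since it expands to nothing on most builds.

    #include "ggml-backend-impl.h"
    #include <stdint.h>
    #include <stdlib.h>
    #include <string.h>

    static const char * example_buffer_get_name(ggml_backend_buffer_t buffer) {
        (void) buffer;
        return "EXAMPLE";
    }

    static void example_buffer_free_buffer(ggml_backend_buffer_t buffer) {
        free(buffer->context); // context is the raw allocation
    }

    static void * example_buffer_get_base(ggml_backend_buffer_t buffer) {
        // malloc does not guarantee 32-byte alignment, so pad the base up to 32
        return (void *) GGML_PAD((uintptr_t) buffer->context, 32);
    }

    static void example_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
        (void) buffer;
        memcpy((char *) tensor->data + offset, data, size);
    }

    static void example_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
        (void) buffer;
        memcpy(data, (const char *) tensor->data + offset, size);
    }

    static void example_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
        memset(buffer->context, value, buffer->size);
    }

    static struct ggml_backend_buffer_i example_buffer_i = {
        /* .get_name    = */ example_buffer_get_name,
        /* .free_buffer = */ example_buffer_free_buffer,
        /* .get_base    = */ example_buffer_get_base,
        /* .init_tensor = */ NULL, // optional
        /* .set_tensor  = */ example_buffer_set_tensor,
        /* .get_tensor  = */ example_buffer_get_tensor,
        /* .cpy_tensor  = */ NULL, // optional; a staged host copy is used instead
        /* .clear       = */ example_buffer_clear,
        /* .reset       = */ NULL, // optional
    };

    static const char * example_buft_get_name(ggml_backend_buffer_type_t buft) {
        (void) buft;
        return "EXAMPLE";
    }

    static ggml_backend_buffer_t example_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
        size += 32; // over-allocate so get_base can align the base pointer
        void * data = malloc(size);
        if (data == NULL) {
            return NULL;
        }
        return ggml_backend_buffer_init(buft, example_buffer_i, data, size);
    }

    static size_t example_buft_get_alignment(ggml_backend_buffer_type_t buft) {
        (void) buft;
        return 32;
    }

    static bool example_buft_is_host(ggml_backend_buffer_type_t buft) {
        (void) buft;
        return true;
    }

    // pass &example_buft wherever a ggml_backend_buffer_type_t is expected
    static struct ggml_backend_buffer_type example_buft = {
        /* .iface   = */ {
            /* .get_name       = */ example_buft_get_name,
            /* .alloc_buffer   = */ example_buft_alloc_buffer,
            /* .get_alignment  = */ example_buft_get_alignment,
            /* .get_max_size   = */ NULL, // defaults to SIZE_MAX
            /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
            /* .is_host        = */ example_buft_is_host,
        },
        /* .context = */ NULL,
    };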

+ 2288 - 0
llama/ggml-backend.c

@@ -0,0 +1,2288 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "ggml-backend-impl.h"
+#include "ggml-alloc.h"
+#include "ggml-impl.h"
+
+#include <assert.h>
+#include <limits.h>
+#include <stdarg.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
+// backend buffer type
+
+const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name(buft);
+}
+
+GGML_CALL ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    return buft->iface.alloc_buffer(buft, size);
+}
+
+size_t ggml_backend_buft_get_alignment(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_alignment(buft);
+}
+
+size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) {
+    // get_max_size is optional, defaults to SIZE_MAX
+    if (buft->iface.get_max_size) {
+        return buft->iface.get_max_size(buft);
+    }
+    return SIZE_MAX;
+}
+
+GGML_CALL size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor) {
+    // get_alloc_size is optional, defaults to ggml_nbytes
+    if (buft->iface.get_alloc_size) {
+        size_t size = buft->iface.get_alloc_size(buft, tensor);
+        assert(size >= ggml_nbytes(tensor));
+        return size;
+    }
+    return ggml_nbytes(tensor);
+}
+
+bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) {
+    if (buft->iface.is_host) {
+        return buft->iface.is_host(buft);
+    }
+    return false;
+}
+
+// backend buffer
+
+GGML_CALL ggml_backend_buffer_t ggml_backend_buffer_init(
+               ggml_backend_buffer_type_t      buft,
+        struct ggml_backend_buffer_i           iface,
+               ggml_backend_buffer_context_t   context,
+               size_t                          size) {
+    ggml_backend_buffer_t buffer = malloc(sizeof(struct ggml_backend_buffer));
+
+    (*buffer) = (struct ggml_backend_buffer) {
+        /* .iface     = */ iface,
+        /* .buft      = */ buft,
+        /* .context   = */ context,
+        /* .size      = */ size,
+        /* .usage     = */ GGML_BACKEND_BUFFER_USAGE_ANY
+    };
+
+    return buffer;
+}
+
+const char * ggml_backend_buffer_name(ggml_backend_buffer_t buffer) {
+    return buffer->iface.get_name(buffer);
+}
+
+void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
+    if (buffer == NULL) {
+        return;
+    }
+
+    if (buffer->iface.free_buffer != NULL) {
+        buffer->iface.free_buffer(buffer);
+    }
+
+// TODO: this needs to be freed in the cuda and hipblas backends because
+// the cuda backend implementation is compiled with msvc
+#if !defined(GGML_USE_CUDA) && !defined(GGML_USE_HIPBLAS)
+    free(buffer);
+#endif
+}
+
+size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
+    return buffer->size;
+}
+
+void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
+    void * base = buffer->iface.get_base(buffer);
+
+    GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL");
+
+    return base;
+}
+
+GGML_CALL void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    // init_tensor is optional
+    if (buffer->iface.init_tensor) {
+        buffer->iface.init_tensor(buffer, tensor);
+    }
+}
+
+size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer) {
+    return ggml_backend_buft_get_alignment(ggml_backend_buffer_get_type(buffer));
+}
+
+size_t ggml_backend_buffer_get_max_size(ggml_backend_buffer_t buffer) {
+    return ggml_backend_buft_get_max_size(ggml_backend_buffer_get_type(buffer));
+}
+
+size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_get_type(buffer), tensor);
+}
+
+void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    buffer->iface.clear(buffer, value);
+}
+
+bool ggml_backend_buffer_is_host(ggml_backend_buffer_t buffer) {
+    return ggml_backend_buft_is_host(ggml_backend_buffer_get_type(buffer));
+}
+
+void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {
+    buffer->usage = usage;
+
+    // FIXME: add a generic callback to the buffer interface
+    if (ggml_backend_buffer_is_multi_buffer(buffer)) {
+        ggml_backend_multi_buffer_set_usage(buffer, usage);
+    }
+}
+
+enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage(ggml_backend_buffer_t buffer) {
+    return buffer->usage;
+}
+
+ggml_backend_buffer_type_t ggml_backend_buffer_get_type(ggml_backend_buffer_t buffer) {
+    return buffer->buft;
+}
+
+void ggml_backend_buffer_reset(ggml_backend_buffer_t buffer) {
+    if (buffer->iface.reset) {
+        buffer->iface.reset(buffer);
+    }
+}
+
+bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst) {
+    ggml_backend_buffer_t dst_buf = dst->view_src ? dst->view_src->buffer : dst->buffer;
+    if (dst_buf->iface.cpy_tensor) {
+        return dst_buf->iface.cpy_tensor(dst_buf, src, dst);
+    }
+    return false;
+}
+
+// backend
+
+ggml_guid_t ggml_backend_guid(ggml_backend_t backend) {
+    if (backend == NULL) {
+        return NULL;
+    }
+    return backend->guid;
+}
+
+const char * ggml_backend_name(ggml_backend_t backend) {
+    if (backend == NULL) {
+        return "NULL";
+    }
+    return backend->iface.get_name(backend);
+}
+
+void ggml_backend_free(ggml_backend_t backend) {
+    if (backend == NULL) {
+        return;
+    }
+
+    backend->iface.free(backend);
+}
+
+ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) {
+    return backend->iface.get_default_buffer_type(backend);
+}
+
+ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size) {
+    return ggml_backend_buft_alloc_buffer(ggml_backend_get_default_buffer_type(backend), size);
+}
+
+size_t ggml_backend_get_alignment(ggml_backend_t backend) {
+    return ggml_backend_buft_get_alignment(ggml_backend_get_default_buffer_type(backend));
+}
+
+size_t ggml_backend_get_max_size(ggml_backend_t backend) {
+    return ggml_backend_buft_get_max_size(ggml_backend_get_default_buffer_type(backend));
+}
+
+void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+
+    if (backend->iface.set_tensor_async == NULL) {
+        ggml_backend_tensor_set(tensor, data, offset, size);
+    } else {
+        backend->iface.set_tensor_async(backend, tensor, data, offset, size);
+    }
+}
+
+void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+
+    if (backend->iface.get_tensor_async == NULL) {
+        ggml_backend_tensor_get(tensor, data, offset, size);
+    } else {
+        backend->iface.get_tensor_async(backend, tensor, data, offset, size);
+    }
+}
+
+GGML_CALL void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+    GGML_ASSERT(buf != NULL && "tensor buffer not set");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+
+    if (!size) {
+        return;
+    }
+
+    buf->iface.set_tensor(buf, tensor, data, offset, size);
+}
+
+GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+    GGML_ASSERT(buf != NULL && "tensor buffer not set");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+
+    if (!size) {
+        return;
+    }
+
+    buf->iface.get_tensor(buf, tensor, data, offset, size);
+}
+
+void ggml_backend_synchronize(ggml_backend_t backend) {
+    if (backend->iface.synchronize == NULL) {
+        return;
+    }
+
+    backend->iface.synchronize(backend);
+}
+
+ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    GGML_ASSERT(backend->iface.graph_plan_create != NULL);
+
+    return backend->iface.graph_plan_create(backend, cgraph);
+}
+
+void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+    GGML_ASSERT(backend->iface.graph_plan_free != NULL);
+
+    backend->iface.graph_plan_free(backend, plan);
+}
+
+enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+    GGML_ASSERT(backend->iface.graph_plan_compute != NULL);
+
+    return backend->iface.graph_plan_compute(backend, plan);
+}
+
+enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    enum ggml_status err = ggml_backend_graph_compute_async(backend, cgraph);
+    ggml_backend_synchronize(backend);
+    return err;
+}
+
+enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    return backend->iface.graph_compute(backend, cgraph);
+}
+
+bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+    return backend->iface.supports_op(backend, op);
+}
+
+bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+    return backend->iface.supports_buft(backend, buft);
+}
+
+bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+    if (backend->iface.offload_op != NULL) {
+        return backend->iface.offload_op(backend, op);
+    }
+    return false;
+}
+
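
To see where `graph_compute` fits into the overall flow, a minimal end-to-end sketch follows. It builds a tiny graph with the `ggml.h` API and places the tensors with `ggml_backend_alloc_ctx_tensors` from `ggml-alloc.h` (both outside this file); names and sizes are arbitrary and the example is illustrative only.

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"
    #include <stdio.h>

    int main(void) {
        ggml_backend_t backend = ggml_backend_cpu_init();

        // no_alloc: the context holds only metadata, the backend owns the data
        struct ggml_init_params params = {
            /* .mem_size   = */ ggml_tensor_overhead() * 8 + ggml_graph_overhead(),
            /* .mem_buffer = */ NULL,
            /* .no_alloc   = */ true,
        };
        struct ggml_context * ctx = ggml_init(params);

        // c = a + b
        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        struct ggml_tensor * c = ggml_add(ctx, a, b);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, c);

        // allocate all context tensors in one buffer from the backend's default buffer type
        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

        float av[4] = {1, 2, 3, 4};
        float bv[4] = {10, 20, 30, 40};
        ggml_backend_tensor_set(a, av, 0, sizeof(av));
        ggml_backend_tensor_set(b, bv, 0, sizeof(bv));

        if (ggml_backend_graph_compute(backend, gf) == GGML_STATUS_SUCCESS) {
            float cv[4];
            ggml_backend_tensor_get(c, cv, 0, sizeof(cv));
            printf("c = [%g %g %g %g]\n", cv[0], cv[1], cv[2], cv[3]);
        }

        ggml_backend_buffer_free(buf);
        ggml_free(ctx);
        ggml_backend_free(backend);
        return 0;
    }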
+// backend copy
+
+static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
+    if (a->type != b->type) {
+        return false;
+    }
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (a->ne[i] != b->ne[i]) {
+            return false;
+        }
+        if (a->nb[i] != b->nb[i]) {
+            return false;
+        }
+    }
+    return true;
+}
+
+void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
+
+    if (src == dst) {
+        return;
+    }
+
+    if (ggml_backend_buffer_is_host(src->buffer)) {
+        ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
+    } else if (ggml_backend_buffer_is_host(dst->buffer)) {
+        ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+    } else if (!ggml_backend_buffer_copy_tensor(src, dst)) {
+#ifndef NDEBUG
+        fprintf(stderr, "%s: warning: slow copy from %s to %s\n", __func__, ggml_backend_buffer_name(src->buffer), ggml_backend_buffer_name(dst->buffer));
+#endif
+        size_t nbytes = ggml_nbytes(src);
+        void * data = malloc(nbytes);
+        ggml_backend_tensor_get(src, data, 0, nbytes);
+        ggml_backend_tensor_set(dst, data, 0, nbytes);
+        free(data);
+    }
+}
+
+void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
+
+    if (src == dst) {
+        return;
+    }
+
+    if (backend_dst->iface.cpy_tensor_async != NULL) {
+        if (backend_dst->iface.cpy_tensor_async(backend_src, backend_dst, src, dst)) {
+            return;
+        }
+    }
+
+    // an async copy would normally happen after all the queued operations on both backends are completed
+    // to simulate the same behavior, we need to synchronize both backends first, and do a blocking copy
+    ggml_backend_synchronize(backend_src);
+    ggml_backend_synchronize(backend_dst);
+    ggml_backend_tensor_copy(src, dst);
+}
+
+// events
+
+ggml_backend_event_t ggml_backend_event_new(ggml_backend_t backend) {
+    if (backend->iface.event_new == NULL) {
+        return NULL;
+    }
+    return backend->iface.event_new(backend);
+}
+
+void ggml_backend_event_free(ggml_backend_event_t event) {
+    if (event == NULL) {
+        return;
+    }
+    event->backend->iface.event_free(event);
+}
+
+void ggml_backend_event_record(ggml_backend_event_t event) {
+    GGML_ASSERT(event->backend->iface.event_record != NULL);
+
+    event->backend->iface.event_record(event);
+}
+
+void ggml_backend_event_synchronize(ggml_backend_event_t event) {
+    GGML_ASSERT(event->backend->iface.event_synchronize != NULL);
+
+    event->backend->iface.event_synchronize(event);
+}
+
+void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
+    GGML_ASSERT(backend->iface.event_wait != NULL);
+
+    backend->iface.event_wait(backend, event);
+}
+
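
A sketch of the intended event flow (illustrative; `example_sync` is hypothetical): record a point in the source backend's queue, have the destination backend wait on it, and fall back to blocking synchronization when events are unsupported, as on the CPU backend where `event_new` is NULL.

    #include "ggml-backend.h"

    static void example_sync(ggml_backend_t backend_src, ggml_backend_t backend_dst) {
        ggml_backend_event_t event = ggml_backend_event_new(backend_src);
        if (event == NULL) {
            // no event support: block until backend_src drains its queue
            ggml_backend_synchronize(backend_src);
            return;
        }
        // ... queue async work on backend_src here ...
        ggml_backend_event_record(event);            // mark this point in backend_src's queue
        ggml_backend_event_wait(backend_dst, event); // backend_dst waits without blocking the host
        // ... queue dependent work on backend_dst here ...
        ggml_backend_event_synchronize(event);       // optional host-side wait
        ggml_backend_event_free(event);
    }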
+// backend registry
+
+#define GGML_REG_MAX_BACKENDS 64
+
+struct ggml_backend_reg {
+    char name[128];
+    ggml_backend_init_fn init_fn;
+    ggml_backend_buffer_type_t default_buffer_type;
+    void * user_data;
+};
+
+static struct ggml_backend_reg ggml_backend_registry[GGML_REG_MAX_BACKENDS];
+static size_t ggml_backend_registry_count = 0;
+
+GGML_CALL static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data);
+
+GGML_CALL static void ggml_backend_registry_init(void) {
+    static bool initialized = false;
+
+    if (initialized) {
+        return;
+    }
+
+    initialized = true;
+
+    ggml_backend_register("CPU", ggml_backend_reg_cpu_init, ggml_backend_cpu_buffer_type(), NULL);
+
+    // add forward decls here to avoid including the backend headers
+#ifdef GGML_USE_CUDA
+    extern GGML_CALL void ggml_backend_cuda_reg_devices(void);
+    ggml_backend_cuda_reg_devices();
+#endif
+
+#ifdef GGML_USE_SYCL
+    extern void ggml_backend_sycl_reg_devices(void);
+    ggml_backend_sycl_reg_devices();
+#endif
+
+#ifdef GGML_USE_METAL
+    extern GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data);
+    extern GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
+    ggml_backend_register("Metal", ggml_backend_reg_metal_init, ggml_backend_metal_buffer_type(), NULL);
+#endif
+
+#ifdef GGML_USE_VULKAN
+    extern GGML_CALL int ggml_backend_vk_reg_devices(void);
+    ggml_backend_vk_reg_devices();
+#endif
+
+#ifdef GGML_USE_KOMPUTE
+    extern GGML_CALL void ggml_backend_kompute_reg_devices(void);
+    ggml_backend_kompute_reg_devices();
+#endif
+
+#ifdef GGML_USE_CANN
+    extern GGML_CALL int ggml_backend_cann_reg_devices(void);
+    ggml_backend_cann_reg_devices();
+#endif
+}
+
+GGML_CALL void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
+    GGML_ASSERT(ggml_backend_registry_count < GGML_REG_MAX_BACKENDS);
+
+    size_t id = ggml_backend_registry_count;
+
+    ggml_backend_registry[id] = (struct ggml_backend_reg) {
+        /* .name                = */ {0},
+        /* .init_fn             = */ init_fn,
+        /* .default_buffer_type = */ default_buffer_type,
+        /* .user_data           = */ user_data,
+    };
+
+    snprintf(ggml_backend_registry[id].name, sizeof(ggml_backend_registry[id].name), "%s", name);
+
+#ifndef NDEBUG
+    fprintf(stderr, "%s: registered backend %s\n", __func__, name);
+#endif
+
+    ggml_backend_registry_count++;
+}
+
+size_t ggml_backend_reg_get_count(void) {
+    ggml_backend_registry_init();
+
+    return ggml_backend_registry_count;
+}
+
+size_t ggml_backend_reg_find_by_name(const char * name) {
+    ggml_backend_registry_init();
+
+    for (size_t i = 0; i < ggml_backend_registry_count; i++) {
+        // TODO: case insensitive in a portable way
+        if (strcmp(ggml_backend_registry[i].name, name) == 0) {
+            return i;
+        }
+    }
+
+    // not found
+    return SIZE_MAX;
+}
+
+// init from backend:params string
+ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str) {
+    ggml_backend_registry_init();
+
+    const char * params = strchr(backend_str, ':');
+    char backend_name[128];
+    if (params == NULL) {
+        snprintf(backend_name, sizeof(backend_name), "%s", backend_str);
+        params = "";
+    } else {
+        snprintf(backend_name, sizeof(backend_name), "%.*s", (int)(params - backend_str), backend_str);
+        params++;
+    }
+
+    size_t backend_i = ggml_backend_reg_find_by_name(backend_name);
+
+    if (backend_i == SIZE_MAX) {
+        fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name);
+        return NULL;
+    }
+
+    return ggml_backend_reg_init_backend(backend_i, params);
+}
+
+const char * ggml_backend_reg_get_name(size_t i) {
+    ggml_backend_registry_init();
+
+    GGML_ASSERT(i < ggml_backend_registry_count);
+    return ggml_backend_registry[i].name;
+}
+
+ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params) {
+    ggml_backend_registry_init();
+
+    GGML_ASSERT(i < ggml_backend_registry_count);
+    return ggml_backend_registry[i].init_fn(params, ggml_backend_registry[i].user_data);
+}
+
+ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i) {
+    ggml_backend_registry_init();
+
+    GGML_ASSERT(i < ggml_backend_registry_count);
+    return ggml_backend_registry[i].default_buffer_type;
+}
+
+ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size) {
+    ggml_backend_registry_init();
+
+    GGML_ASSERT(i < ggml_backend_registry_count);
+    return ggml_backend_buft_alloc_buffer(ggml_backend_registry[i].default_buffer_type, size);
+}
+
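
A minimal sketch of consuming the registry (illustrative only): enumerate the registered backends, then initialize one by name.

    #include "ggml-backend.h"
    #include <stdio.h>

    int main(void) {
        // the first registry call triggers lazy registration of all compiled-in backends
        for (size_t i = 0; i < ggml_backend_reg_get_count(); i++) {
            printf("backend %zu: %s\n", i, ggml_backend_reg_get_name(i));
        }

        // "CPU" alone means empty params; "CPU:foo" would pass "foo" as params
        ggml_backend_t backend = ggml_backend_reg_init_backend_from_str("CPU");
        if (backend != NULL) {
            printf("initialized: %s\n", ggml_backend_name(backend));
            ggml_backend_free(backend);
        }
        return 0;
    }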
+// backend CPU
+
+static const size_t TENSOR_ALIGNMENT = 32; // required for mmap as gguf only guarantees 32-byte alignment
+
+GGML_CALL static const char * ggml_backend_cpu_buffer_name(ggml_backend_buffer_t buffer) {
+    return "CPU";
+
+    GGML_UNUSED(buffer);
+}
+
+GGML_CALL static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
+    uintptr_t data = (uintptr_t)buffer->context;
+
+    // align the buffer
+    if (data % TENSOR_ALIGNMENT != 0) {
+        data = GGML_PAD(data, TENSOR_ALIGNMENT);
+    }
+
+    return (void *)data;
+}
+
+GGML_CALL static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    free(buffer->context);
+}
+
+GGML_CALL static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    memcpy((char *)tensor->data + offset, data, size);
+
+    GGML_UNUSED(buffer);
+}
+
+GGML_CALL static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    memcpy(data, (const char *)tensor->data + offset, size);
+
+    GGML_UNUSED(buffer);
+}
+
+GGML_CALL static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+    if (ggml_backend_buffer_is_host(src->buffer)) {
+        memcpy(dst->data, src->data, ggml_nbytes(src));
+        return true;
+    }
+    return false;
+
+    GGML_UNUSED(buffer);
+}
+
+GGML_CALL static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    memset(buffer->context, value, buffer->size);
+}
+
+static struct ggml_backend_buffer_i cpu_backend_buffer_i = {
+    /* .get_name        = */ ggml_backend_cpu_buffer_name,
+    /* .free_buffer     = */ ggml_backend_cpu_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_cpu_buffer_get_base,
+    /* .init_tensor     = */ NULL, // no initialization required
+    /* .set_tensor      = */ ggml_backend_cpu_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_cpu_buffer_get_tensor,
+    /* .cpy_tensor      = */ ggml_backend_cpu_buffer_cpy_tensor,
+    /* .clear           = */ ggml_backend_cpu_buffer_clear,
+    /* .reset           = */ NULL,
+};
+
+// for buffers from ptr, free is not called
+static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
+    /* .get_name        = */ ggml_backend_cpu_buffer_name,
+    /* .free_buffer     = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
+    /* .get_base        = */ ggml_backend_cpu_buffer_get_base,
+    /* .init_tensor     = */ NULL, // no initialization required
+    /* .set_tensor      = */ ggml_backend_cpu_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_cpu_buffer_get_tensor,
+    /* .cpy_tensor      = */ ggml_backend_cpu_buffer_cpy_tensor,
+    /* .clear           = */ ggml_backend_cpu_buffer_clear,
+    /* .reset           = */ NULL,
+};
+
+GGML_CALL static const char * ggml_backend_cpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    return "CPU";
+
+    GGML_UNUSED(buft);
+}
+
+GGML_CALL static ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    size += TENSOR_ALIGNMENT;   // malloc may return an address that is not aligned
+    void * data = malloc(size); // TODO: use GGML_ALIGNED_MALLOC (move to ggml-impl.h)
+    if (data == NULL) {
+        fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
+        return NULL;
+    }
+
+    return ggml_backend_buffer_init(buft, cpu_backend_buffer_i, data, size);
+}
+
+GGML_CALL static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    return TENSOR_ALIGNMENT;
+
+    GGML_UNUSED(buft);
+}
+
+GGML_CALL static bool ggml_backend_cpu_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+    return true;
+
+    GGML_UNUSED(buft);
+}
+
+GGML_CALL ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type = {
+        /* .iface = */ {
+            /* .get_name         = */ ggml_backend_cpu_buffer_type_get_name,
+            /* .alloc_buffer     = */ ggml_backend_cpu_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_cpu_buffer_type_get_alignment,
+            /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
+            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
+            /* .is_host          = */ ggml_backend_cpu_buffer_type_is_host,
+        },
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_cpu_buffer_type;
+}
+
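
As a sketch of the buffer-type accessors defined earlier in this file (`example_buft_usage` is a hypothetical helper):

    #include "ggml-backend.h"
    #include <stdio.h>

    static void example_buft_usage(void) {
        ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
        printf("%s: alignment=%zu max_size=%zu is_host=%d\n",
               ggml_backend_buft_name(buft),
               ggml_backend_buft_get_alignment(buft), // TENSOR_ALIGNMENT (32)
               ggml_backend_buft_get_max_size(buft),  // SIZE_MAX: no get_max_size callback
               ggml_backend_buft_is_host(buft));      // 1

        ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(buft, 1u << 20);
        ggml_backend_buffer_clear(buf, 0);
        ggml_backend_buffer_free(buf);
    }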
+#ifdef GGML_USE_CPU_HBM
+
+// buffer type HBM
+
+#include <hbwmalloc.h>
+
+GGML_CALL static const char * ggml_backend_cpu_hbm_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    return "CPU_HBM";
+
+    GGML_UNUSED(buft);
+}
+
+GGML_CALL static const char * ggml_backend_cpu_hbm_buffer_get_name(ggml_backend_buffer_t buf) {
+    return "CPU_HBM";
+
+    GGML_UNUSED(buf);
+}
+
+GGML_CALL static void ggml_backend_cpu_hbm_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    hbw_free(buffer->context);
+}
+
+GGML_CALL static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    //void * ptr = hbw_malloc(size);
+    void * ptr;
+    int result = hbw_posix_memalign(&ptr, ggml_backend_cpu_buffer_type_get_alignment(buft), size);
+    if (result != 0) {
+        fprintf(stderr, "failed to allocate HBM buffer of size %zu\n", size);
+        return NULL;
+    }
+
+    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
+    buffer->buft = buft;
+    buffer->iface.get_name = ggml_backend_cpu_hbm_buffer_get_name;
+    buffer->iface.free_buffer = ggml_backend_cpu_hbm_buffer_free_buffer;
+
+    return buffer;
+}
+
+ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_hbm = {
+        /* .iface    = */ {
+            /* .get_name         = */ ggml_backend_cpu_hbm_buffer_type_get_name,
+            /* .alloc_buffer     = */ ggml_backend_cpu_hbm_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_cpu_buffer_type_get_alignment,
+            /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
+            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
+            /* .is_host          = */ ggml_backend_cpu_buffer_type_is_host,
+        },
+        /* .context  = */ NULL,
+    };
+
+    return &ggml_backend_cpu_buffer_type_hbm;
+}
+#endif
+
+struct ggml_backend_cpu_context {
+    int                 n_threads;
+    ggml_threadpool_t   threadpool;
+
+    void *              work_data;
+    size_t              work_size;
+
+    ggml_abort_callback abort_callback;
+    void *              abort_callback_data;
+};
+
+GGML_CALL static const char * ggml_backend_cpu_name(ggml_backend_t backend) {
+    return "CPU";
+
+    GGML_UNUSED(backend);
+}
+
+GGML_CALL static void ggml_backend_cpu_free(ggml_backend_t backend) {
+    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
+    free(cpu_ctx->work_data);
+    free(cpu_ctx);
+    free(backend);
+}
+
+GGML_CALL static ggml_backend_buffer_type_t ggml_backend_cpu_get_default_buffer_type(ggml_backend_t backend) {
+    return ggml_backend_cpu_buffer_type();
+
+    GGML_UNUSED(backend);
+}
+
+struct ggml_backend_plan_cpu {
+    struct ggml_cplan cplan;
+    struct ggml_cgraph cgraph;
+};
+
+GGML_CALL static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) {
+    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
+
+    struct ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct ggml_backend_plan_cpu));
+
+    cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);
+    cpu_plan->cgraph = *cgraph; // FIXME: deep copy
+
+    if (cpu_plan->cplan.work_size > 0) {
+        cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size);
+        if (cpu_plan->cplan.work_data == NULL) {
+            free(cpu_plan);
+            return NULL;
+        }
+    }
+
+    cpu_plan->cplan.abort_callback      = cpu_ctx->abort_callback;
+    cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data;
+
+    return cpu_plan;
+}
+
+GGML_CALL static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+    struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
+
+    free(cpu_plan->cplan.work_data);
+    free(cpu_plan);
+
+    GGML_UNUSED(backend);
+}
+
+GGML_CALL static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+    struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
+
+    return ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
+
+    GGML_UNUSED(backend);
+}
+
+GGML_CALL static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
+
+    struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);
+
+    if (cpu_ctx->work_size < cplan.work_size) {
+        free(cpu_ctx->work_data);
+        cpu_ctx->work_data = malloc(cplan.work_size);
+        if (cpu_ctx->work_data == NULL) {
+            cpu_ctx->work_size = 0;
+            return GGML_STATUS_ALLOC_FAILED;
+        }
+        cpu_ctx->work_size = cplan.work_size;
+    }
+    cplan.work_data = cpu_ctx->work_data;
+
+    cplan.abort_callback      = cpu_ctx->abort_callback;
+    cplan.abort_callback_data = cpu_ctx->abort_callback_data;
+
+    return ggml_graph_compute(cgraph, &cplan);
+}
+
+GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+    switch (op->op) {
+        case GGML_OP_CPY:
+            return
+                op->type != GGML_TYPE_IQ2_XXS &&
+                op->type != GGML_TYPE_IQ2_XS  &&
+                op->type != GGML_TYPE_IQ1_S   &&
+                op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
+        case GGML_OP_MUL_MAT:
+            return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
+        default:
+            return true;
+    }
+
+    GGML_UNUSED(backend);
+}
+
+GGML_CALL static bool ggml_backend_cpu_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+    return ggml_backend_buft_is_host(buft);
+
+    GGML_UNUSED(backend);
+}
+
+static struct ggml_backend_i cpu_backend_i = {
+    /* .get_name                = */ ggml_backend_cpu_name,
+    /* .free                    = */ ggml_backend_cpu_free,
+    /* .get_default_buffer_type = */ ggml_backend_cpu_get_default_buffer_type,
+    /* .set_tensor_async        = */ NULL,
+    /* .get_tensor_async        = */ NULL,
+    /* .cpy_tensor_async        = */ NULL,
+    /* .synchronize             = */ NULL,
+    /* .graph_plan_create       = */ ggml_backend_cpu_graph_plan_create,
+    /* .graph_plan_free         = */ ggml_backend_cpu_graph_plan_free,
+    /* .graph_plan_update       = */ NULL,
+    /* .graph_plan_compute      = */ ggml_backend_cpu_graph_plan_compute,
+    /* .graph_compute           = */ ggml_backend_cpu_graph_compute,
+    /* .supports_op             = */ ggml_backend_cpu_supports_op,
+    /* .supports_buft           = */ ggml_backend_cpu_supports_buft,
+    /* .offload_op              = */ NULL,
+    /* .event_new               = */ NULL,
+    /* .event_free              = */ NULL,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+    /* .event_synchronize       = */ NULL,
+};
+
+static ggml_guid_t ggml_backend_cpu_guid(void) {
+    static ggml_guid guid = { 0xaa, 0x67, 0xc7, 0x43, 0x96, 0xe6, 0xa3, 0x8a, 0xe3, 0xaf, 0xea, 0x92, 0x36, 0xbc, 0xfc, 0x89 };
+    return &guid;
+}
+
+ggml_backend_t ggml_backend_cpu_init(void) {
+    struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context));
+    if (ctx == NULL) {
+        return NULL;
+    }
+
+    ctx->n_threads           = GGML_DEFAULT_N_THREADS;
+    ctx->threadpool          = NULL;
+    ctx->work_data           = NULL;
+    ctx->work_size           = 0;
+    ctx->abort_callback      = NULL;
+    ctx->abort_callback_data = NULL;
+
+    ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend));
+    if (cpu_backend == NULL) {
+        free(ctx);
+        return NULL;
+    }
+
+    *cpu_backend = (struct ggml_backend) {
+        /* .guid      = */ ggml_backend_cpu_guid(),
+        /* .iface     = */ cpu_backend_i,
+        /* .context   = */ ctx
+    };
+    return cpu_backend;
+}
+
+GGML_CALL bool ggml_backend_is_cpu(ggml_backend_t backend) {
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cpu_guid());
+}
+
+void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
+    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+    ctx->n_threads = n_threads;
+}
+
+void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_threadpool_t threadpool) {
+    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+
+    if (ctx->threadpool && ctx->threadpool != threadpool) {
+        // already had a different threadpool, pause/suspend it before switching
+        ggml_threadpool_pause(ctx->threadpool);
+    }
+    ctx->threadpool = threadpool;
+}
+
+void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) {
+    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+    ctx->abort_callback = abort_callback;
+    ctx->abort_callback_data = abort_callback_data;
+}
+
+GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
+    GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned");
+    return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size);
+}
+
+GGML_CALL static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data) {
+    return ggml_backend_cpu_init();
+
+    GGML_UNUSED(params);
+    GGML_UNUSED(user_data);
+}
+
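
Putting the CPU backend pieces together, a minimal lifecycle sketch (illustrative; it relies only on functions from this file plus C11 aligned_alloc):

    #include "ggml-backend.h"
    #include <stdio.h>
    #include <stdlib.h>

    int main(void) {
        ggml_backend_t cpu = ggml_backend_cpu_init();
        if (cpu == NULL) {
            return 1;
        }
        ggml_backend_cpu_set_n_threads(cpu, 4);

        // wrap caller-owned, 32-byte-aligned memory; the buffer does not take
        // ownership (its free_buffer is NULL), so the caller frees ptr afterwards
        void * ptr = aligned_alloc(32, 4096);
        ggml_backend_buffer_t buf = ggml_backend_cpu_buffer_from_ptr(ptr, 4096);
        printf("buffer %s: %zu bytes, host=%d\n",
               ggml_backend_buffer_name(buf),
               ggml_backend_buffer_get_size(buf),
               ggml_backend_buffer_is_host(buf));

        ggml_backend_buffer_free(buf); // frees the wrapper only, not ptr
        free(ptr);
        ggml_backend_free(cpu);
        return 0;
    }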
+// multi-buffer buffer
+
+struct ggml_backend_multi_buffer_context {
+    ggml_backend_buffer_t * buffers;
+    size_t n_buffers;
+};
+
+typedef struct ggml_backend_multi_buffer_context * ggml_backend_multi_buffer_context_t;
+
+GGML_CALL static const char * ggml_backend_multi_buffer_get_name(ggml_backend_buffer_t buffer) {
+    ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context;
+
+    return ctx->buffers[0]->iface.get_name(ctx->buffers[0]);
+}
+
+GGML_CALL static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context;
+    for (size_t i = 0; i < ctx->n_buffers; i++) {
+        ggml_backend_buffer_free(ctx->buffers[i]);
+    }
+
+    free(ctx->buffers);
+    free(ctx);
+}
+
+GGML_CALL static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context;
+    for (size_t i = 0; i < ctx->n_buffers; i++) {
+        ggml_backend_buffer_clear(ctx->buffers[i], value);
+    }
+}
+
+static struct ggml_backend_buffer_i ggml_backend_multi_buffer_context_interface(void) {
+    static struct ggml_backend_buffer_i multi_backend_buffer_i = {
+        /* .get_name        = */ ggml_backend_multi_buffer_get_name,
+        /* .free_buffer     = */ ggml_backend_multi_buffer_free_buffer,
+        /* .get_base        = */ NULL,
+        /* .init_tensor     = */ NULL,
+        /* .set_tensor      = */ NULL,
+        /* .get_tensor      = */ NULL,
+        /* .cpy_tensor      = */ NULL,
+        /* .clear           = */ ggml_backend_multi_buffer_clear,
+        /* .reset           = */ NULL,
+    };
+
+    return multi_backend_buffer_i;
+}
+
+GGML_CALL ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers) {
+    ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) malloc(sizeof(struct ggml_backend_multi_buffer_context));
+    ctx->n_buffers = n_buffers;
+    ctx->buffers = (ggml_backend_buffer_t *) malloc(n_buffers * sizeof(ggml_backend_buffer_t));
+
+    GGML_ASSERT(ctx->buffers != NULL);
+
+    size_t total_size = 0;
+    for (size_t i = 0; i < n_buffers; i++) {
+        ctx->buffers[i] = buffers[i];
+        total_size += ggml_backend_buffer_get_size(buffers[i]);
+    }
+
+    return ggml_backend_buffer_init(buffers[0]->buft, ggml_backend_multi_buffer_context_interface(), ctx, total_size);
+}
+
+GGML_CALL bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer) {
+    return buffer->iface.get_name == ggml_backend_multi_buffer_get_name;
+}
+
+GGML_CALL void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {
+    GGML_ASSERT(ggml_backend_buffer_is_multi_buffer(buffer));
+    ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context;
+    for (size_t i = 0; i < ctx->n_buffers; i++) {
+        ggml_backend_buffer_set_usage(ctx->buffers[i], usage);
+    }
+}
+
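
A sketch of how the multi-buffer is meant to be used (`example_multi_buffer` is hypothetical): group several buffers so that usage changes and freeing apply to all of them at once.

    #include "ggml-backend.h"
    #include <assert.h>

    static void example_multi_buffer(void) {
        ggml_backend_buffer_t parts[2] = {
            ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), 1024),
            ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), 2048),
        };
        ggml_backend_buffer_t multi = ggml_backend_multi_buffer_alloc_buffer(parts, 2);

        assert(ggml_backend_buffer_get_size(multi) == 1024 + 2048); // sum of the parts
        // propagates to every constituent buffer
        ggml_backend_buffer_set_usage(multi, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);

        ggml_backend_buffer_free(multi); // also frees both parts
    }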
+// creates a copy of the tensor with the same memory layout
+static struct ggml_tensor * ggml_dup_tensor_layout(struct ggml_context * ctx, const struct ggml_tensor * tensor) {
+    struct ggml_tensor * dup = ggml_dup_tensor(ctx, tensor);
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        dup->nb[i] = tensor->nb[i];
+    }
+    return dup;
+}
+
+static bool ggml_is_view_op(enum ggml_op op) {
+    return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE;
+}
+
+// scheduler
+
+#ifndef GGML_SCHED_MAX_BACKENDS
+#define GGML_SCHED_MAX_BACKENDS 16
+#endif
+
+#ifndef GGML_SCHED_MAX_SPLIT_INPUTS
+#define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC
+#endif
+
+#ifndef GGML_SCHED_MAX_COPIES
+#define GGML_SCHED_MAX_COPIES 4
+#endif
+
+struct ggml_backend_sched_split {
+    int backend_id;
+    int i_start;
+    int i_end;
+    struct ggml_tensor * inputs[GGML_SCHED_MAX_SPLIT_INPUTS];
+    int n_inputs;
+    // graph view of this split
+    struct ggml_cgraph graph;
+};
+
+struct ggml_backend_sched {
+    bool is_reset; // true if the scheduler has been reset since the last graph split
+    bool is_alloc;
+
+    int n_backends;
+
+    ggml_backend_t backends[GGML_SCHED_MAX_BACKENDS];
+    ggml_backend_buffer_type_t bufts[GGML_SCHED_MAX_BACKENDS];
+    ggml_gallocr_t galloc;
+
+    // hash map of the nodes in the graph
+    struct ggml_hash_set  hash_set;
+    int                 * hv_tensor_backend_ids; // [hash_set.size]
+    struct ggml_tensor ** hv_tensor_copies;      // [hash_set.size][n_backends][n_copies]
+
+    int * node_backend_ids; // [graph_size]
+    int * leaf_backend_ids; // [graph_size]
+
+    int * prev_node_backend_ids; // [graph_size]
+    int * prev_leaf_backend_ids; // [graph_size]
+
+    // copy of the graph with modified inputs
+    struct ggml_cgraph graph;
+
+    // graph splits
+    struct ggml_backend_sched_split * splits;
+    int n_splits;
+    int splits_capacity;
+
+    // pipeline parallelism support
+    int n_copies;
+    int cur_copy;
+    ggml_backend_event_t events[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPIES];
+    struct ggml_tensor * graph_inputs[GGML_SCHED_MAX_SPLIT_INPUTS];
+    int n_graph_inputs;
+
+    struct ggml_context * ctx;
+
+    ggml_backend_sched_eval_callback callback_eval;
+    void * callback_eval_user_data;
+
+    char * context_buffer;
+    size_t context_buffer_size;
+
+    bool debug;
+};
+
+#define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor)
+#define tensor_backend_id(tensor) sched->hv_tensor_backend_ids[hash_id(tensor)]
+#define tensor_id_copy(id, backend_id, copy_id) sched->hv_tensor_copies[(id) * sched->n_backends * sched->n_copies + (backend_id) * sched->n_copies + (copy_id)]
+#define tensor_copy(tensor, backend_id, copy_id) tensor_id_copy(hash_id(tensor), backend_id, copy_id)
+
+// returns the priority of the backend, lower id is higher priority
+static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backend_t backend) {
+    for (int i = 0; i < sched->n_backends; i++) {
+        if (sched->backends[i] == backend) {
+            return i;
+        }
+    }
+    return -1;
+}
+
+static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor, const struct ggml_tensor * op) {
+    ggml_backend_buffer_t buffer = tensor->buffer;
+    if (buffer == NULL) {
+        return -1;
+    }
+
+    // find highest prio backend that supports the buffer type and the op
+    for (int i = 0; i < sched->n_backends; i++) {
+        if (ggml_backend_supports_buft(sched->backends[i], buffer->buft) &&
+            ggml_backend_supports_op(sched->backends[i], op)) {
+            return i;
+        }
+    }
+
+#ifndef NDEBUG
+    fprintf(stderr, "%s: warning: no backend supports op %s with a weight with buffer type %s used in tensor %s, the weight will need to be copied\n",
+        __func__, ggml_op_desc(tensor), ggml_backend_buffer_name(buffer), tensor->name);
+#endif
+
+    return -1;
+}
+
+#if 0
+#define GGML_SCHED_MAX_SPLITS_DEBUG 4096
+static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS_DEBUG*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only
+#define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)
+#define GET_CAUSE(node) causes[hash_id(node)]
+#else
+#define SET_CAUSE(node, ...)
+#define GET_CAUSE(node) ""
+#endif
+
+// returns the backend that should be used for the node based on the current locations of its data and sources
+static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * tensor) {
+    // TODO: use supports_op to check if the backend supports the op
+
+    // assign pre-allocated nodes to their backend
+    int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor, tensor);
+    if (cur_backend_id != -1) {
+        SET_CAUSE(tensor, "1.dst");
+        return cur_backend_id;
+    }
+
+    // view_src
+    if (tensor->view_src != NULL) {
+        cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src, tensor);
+        if (cur_backend_id != -1) {
+            SET_CAUSE(tensor, "1.vsrc");
+            return cur_backend_id;
+        }
+    }
+
+    // graph input
+    if (tensor->flags & GGML_TENSOR_FLAG_INPUT) {
+        cur_backend_id = sched->n_backends - 1; // last backend (assumed CPU)
+        SET_CAUSE(tensor, "1.inp");
+        return cur_backend_id;
+    }
+
+    // operations with weights are preferably run on the same backend as the weights
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        const struct ggml_tensor * src = tensor->src[i];
+        if (src == NULL) {
+            continue;
+        }
+        if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
+            int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
+            // check if a backend with higher prio wants to offload the op
+            if (src_backend_id == sched->n_backends - 1) {
+                for (int b = 0; b < src_backend_id; b++) {
+                    if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) {
+                        SET_CAUSE(tensor, "1.off");
+                        return b;
+                    }
+                }
+            }
+            SET_CAUSE(tensor, "1.wgt%d", i);
+            return src_backend_id;
+        }
+    }
+
+    return -1;
+}
+
+static char * fmt_size(size_t size) {
+    static char buffer[128];
+    if (size >= 1024*1024) {
+        snprintf(buffer, sizeof(buffer), "%zuM", size/1024/1024);
+    } else {
+        snprintf(buffer, sizeof(buffer), "%zuK", size/1024);
+    }
+    return buffer;
+}
+
+static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+    int cur_split = 0;
+    for (int i = 0; i < graph->n_nodes; i++) {
+        if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {
+            ggml_backend_t split_backend = sched->backends[sched->splits[cur_split].backend_id];
+            fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend),
+                sched->splits[cur_split].n_inputs);
+            for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {
+                fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name,
+                    fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j])));
+            }
+            fprintf(stderr, "\n");
+            cur_split++;
+        }
+        struct ggml_tensor * node = graph->nodes[i];
+        if (ggml_is_view_op(node->op)) {
+            continue;
+        }
+        ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
+        fprintf(stderr, "node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name,
+            fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node));
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * src = node->src[j];
+            if (src == NULL) {
+                continue;
+            }
+            ggml_backend_t src_backend = ggml_backend_sched_get_tensor_backend(sched, src);
+            fprintf(stderr, " %20.20s (%5.5s) [%5.5s %8.8s]", src->name,
+                fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src));
+        }
+        fprintf(stderr, "\n");
+    }
+}
+
+static bool ggml_backend_sched_buffer_supported(ggml_backend_sched_t sched, struct ggml_tensor * t, int backend_id) {
+    ggml_backend_buffer_t buf = t->view_src ? t->view_src->buffer : t->buffer;
+    ggml_backend_buffer_type_t buft = NULL;
+
+    if (buf) {
+        // the tensor is already allocated
+        buft = buf->buft;
+    } else {
+        // see if the tensor already has a backend assigned, and use the buffer type of that backend
+        int tensor_backend_id = tensor_backend_id(t);
+        if (tensor_backend_id == -1 && t->view_src) {
+            tensor_backend_id = tensor_backend_id(t->view_src);
+        }
+        if (tensor_backend_id != -1) {
+            buft = sched->bufts[tensor_backend_id];
+        }
+    }
+
+    return buft != NULL && ggml_backend_supports_buft(sched->backends[backend_id], buft);
+}
+
+static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, struct ggml_tensor * node, int cur_backend_id, int * node_backend_id) {
+    if (ggml_backend_supports_op(sched->backends[cur_backend_id], node)) {
+        *node_backend_id = cur_backend_id;
+        SET_CAUSE(node, "2.sup");
+    }
+}
+
+// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
+static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+    // reset splits
+    sched->n_splits = 0;
+    sched->n_graph_inputs = 0;
+    sched->is_reset = false;
+
+    struct ggml_init_params params = {
+        /* .mem_size =   */ sched->context_buffer_size,
+        /* .mem_buffer = */ sched->context_buffer,
+        /* .no_alloc =   */ true
+    };
+
+    ggml_free(sched->ctx);
+
+    sched->ctx = ggml_init(params);
+    if (sched->ctx == NULL) {
+        GGML_ABORT("%s: failed to initialize context\n", __func__);
+    }
+
+    // pass 1: assign backends to ops with pre-allocated inputs
+    for (int i = 0; i < graph->n_leafs; i++) {
+        struct ggml_tensor * leaf = graph->leafs[i];
+        int * leaf_backend_id = &tensor_backend_id(leaf);
+        // do not overwrite user assignments
+        if (*leaf_backend_id == -1) {
+            *leaf_backend_id = ggml_backend_sched_backend_id_from_cur(sched, leaf);
+        }
+    }
+
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        int * node_backend_id = &tensor_backend_id(node);
+        // do not overwrite user assignments
+        if (*node_backend_id == -1) {
+            *node_backend_id = ggml_backend_sched_backend_id_from_cur(sched, node);
+
+#if 0
+            // src
+            if (node->op == GGML_OP_NONE) {
+                continue;
+            }
+
+            for (int j = 0; j < GGML_MAX_SRC; j++) {
+                struct ggml_tensor * src = node->src[j];
+                if (src == NULL) {
+                    continue;
+                }
+                int * src_backend_id = &tensor_backend_id(src);
+                if (*src_backend_id == -1) {
+                    *src_backend_id = ggml_backend_sched_backend_id_from_cur(sched, src);
+                }
+            }
+#endif
+        }
+    }
+
+    // pass 2: expand current backend assignments
+    // assign the same backend to adjacent nodes
+    // expand gpu backends (i.e. any backend other than the last, lowest-priority one) up and down, ignoring cpu (the lowest priority backend)
+    // thus, cpu will never be used unless weights are on cpu, or there are no gpu ops between cpu ops
+    // ops unsupported by the backend being expanded are left unassigned so that they can be assigned later, once the locations of their inputs are known
+    // expand gpu down
+    {
+        int cur_backend_id = -1;
+        for (int i = 0; i < graph->n_nodes; i++) {
+            struct ggml_tensor * node = graph->nodes[i];
+            if (ggml_is_view_op(node->op)) {
+                continue;
+            }
+            int * node_backend_id = &tensor_backend_id(node);
+            if (*node_backend_id != -1) {
+                if (*node_backend_id == sched->n_backends - 1) {
+                    // skip cpu (lowest prio backend)
+                    cur_backend_id = -1;
+                } else {
+                    cur_backend_id = *node_backend_id;
+                }
+            } else if (cur_backend_id != -1) {
+                ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
+            }
+        }
+    }
+    // expand gpu up
+    {
+        int cur_backend_id = -1;
+        for (int i = graph->n_nodes - 1; i >= 0; i--) {
+            struct ggml_tensor * node = graph->nodes[i];
+            if (ggml_is_view_op(node->op)) {
+                continue;
+            }
+            int * node_backend_id = &tensor_backend_id(node);
+            if (*node_backend_id != -1) {
+                if (*node_backend_id == sched->n_backends - 1) {
+                    // skip cpu (lowest prio backend)
+                    cur_backend_id = -1;
+                } else {
+                    cur_backend_id = *node_backend_id;
+                }
+            } else if (cur_backend_id != -1) {
+                ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
+            }
+        }
+    }
+    // expand rest down
+    {
+        int cur_backend_id = -1;
+        for (int i = 0; i < graph->n_nodes; i++) {
+            struct ggml_tensor * node = graph->nodes[i];
+            if (ggml_is_view_op(node->op)) {
+                continue;
+            }
+            int * node_backend_id = &tensor_backend_id(node);
+            if (*node_backend_id != -1) {
+                cur_backend_id = *node_backend_id;
+            } else if (cur_backend_id != -1) {
+                ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
+            }
+        }
+    }
+    // expand rest up
+    {
+        int cur_backend_id = -1;
+        for (int i = graph->n_nodes - 1; i >= 0; i--) {
+            struct ggml_tensor * node = graph->nodes[i];
+            if (ggml_is_view_op(node->op)) {
+                continue;
+            }
+            int * node_backend_id = &tensor_backend_id(node);
+            if (*node_backend_id != -1) {
+                cur_backend_id = *node_backend_id;
+            } else if (cur_backend_id != -1) {
+                ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
+            }
+        }
+    }
+
+    // pass 3: upgrade nodes to higher prio backends with compatible buffer types
+    // if the tensor is already in the same buffer type (*) as another higher priority backend, we should move it there
+    // however, we also need to verify that the sources are in compatible buffer types
+    // (*) the actual requirement is more relaxed, the buffer type of the backend should be supported by all the users of this tensor further down the graph
+    // however, this is slow to verify, so we have a more strict requirement that the buffer type is the same
+    // this is not uncommon since multiple backends can use host memory with the same buffer type (e.g. BLAS and CPU)
+    // additionally, set remaining unassigned nodes to the backend with the most supported inputs
+    // only nodes that could not be assigned during expansion due to the backend not supporting the op should be unassigned at this point
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        if (ggml_is_view_op(node->op)) {
+            continue;
+        }
+        int * node_backend_id = &tensor_backend_id(node);
+        if (*node_backend_id == -1) {
+            // unassigned node: find the backend with the most supported inputs
+            int n_supported_best = -1;
+            for (int b = 0; b < sched->n_backends; b++) {
+                if (ggml_backend_supports_op(sched->backends[b], node)) {
+                    int n_supported = 0;
+                    for (int j = 0; j < GGML_MAX_SRC; j++) {
+                        struct ggml_tensor * src = node->src[j];
+                        if (src == NULL) {
+                            continue;
+                        }
+                        if ((tensor_backend_id(src) != -1 || tensor_backend_id(src->view_src) != -1) && ggml_backend_sched_buffer_supported(sched, src, b)) {
+                            n_supported++;
+                        }
+                    }
+                    if (n_supported > n_supported_best) {
+                        n_supported_best = n_supported;
+                        *node_backend_id = b;
+                        SET_CAUSE(node, "3.best");
+                    }
+                }
+            }
+        } else {
+            // assigned node: upgrade to higher prio backend if possible
+            for (int b = 0; b < *node_backend_id; b++) {
+                if (sched->bufts[b] == sched->bufts[*node_backend_id] && ggml_backend_supports_op(sched->backends[b], node)) {
+                    bool supported = true;
+                    for (int j = 0; j < GGML_MAX_SRC; j++) {
+                        struct ggml_tensor * src = node->src[j];
+                        if (src == NULL) {
+                            continue;
+                        }
+                        if (!ggml_backend_sched_buffer_supported(sched, src, b)) {
+                            supported = false;
+                            break;
+                        }
+                    }
+                    if (supported) {
+                        *node_backend_id = b;
+                        SET_CAUSE(node, "3.upg");
+                        break;
+                    }
+                }
+            }
+        }
+    }
+
+    // pass 4: assign backends to remaining src from dst and view_src
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        int * cur_backend_id = &tensor_backend_id(node);
+        if (node->view_src != NULL && *cur_backend_id == -1) {
+            *cur_backend_id = tensor_backend_id(node->view_src);
+            SET_CAUSE(node, "4.vsrc");
+        }
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * src = node->src[j];
+            if (src == NULL) {
+                continue;
+            }
+            int * src_backend_id = &tensor_backend_id(src);
+            if (*src_backend_id == -1) {
+                if (src->view_src != NULL) {
+                    // views are always on the same backend as the source
+                    *src_backend_id = tensor_backend_id(src->view_src);
+                    SET_CAUSE(src, "4.vsrc");
+                } else {
+                    *src_backend_id = *cur_backend_id;
+                    SET_CAUSE(src, "4.cur");
+                }
+            }
+        }
+    }
+
+    // pass 5: split graph, find tensors that need to be copied
+    {
+        int i_split = 0;
+        struct ggml_backend_sched_split * split = &sched->splits[0];
+        // find the backend of the first split, skipping view ops
+        int i = 0;
+        for (; i < graph->n_nodes; i++) {
+            struct ggml_tensor * node = graph->nodes[i];
+            if (!ggml_is_view_op(node->op)) {
+                split->backend_id = tensor_backend_id(node);
+                break;
+            }
+        }
+        split->i_start = 0;
+        split->n_inputs = 0;
+        int cur_backend_id = split->backend_id;
+        for (; i < graph->n_nodes; i++) {
+            struct ggml_tensor * node = graph->nodes[i];
+
+            if (ggml_is_view_op(node->op)) {
+                continue;
+            }
+
+            const int node_backend_id = tensor_backend_id(node);
+
+            assert(node_backend_id != -1); // all nodes should be assigned by now
+
+            // check if we should start a new split based on the sources of the current node
+            bool need_new_split = false;
+            if (node_backend_id == cur_backend_id && split->n_inputs > 0) {
+                for (int j = 0; j < GGML_MAX_SRC; j++) {
+                    struct ggml_tensor * src = node->src[j];
+                    if (src == NULL) {
+                        continue;
+                    }
+                    // check if a weight is on a different backend
+                    // by starting a new split, the memory of the previously offloaded weights can be reused
+                    if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
+                        int src_backend_id = tensor_backend_id(src);
+                        if (src_backend_id != cur_backend_id) {
+                            need_new_split = true;
+                            break;
+                        }
+                    }
+                    // check if the split has too many inputs
+                    // FIXME: count the number of inputs instead of only checking when full
+                    if (split->n_inputs == GGML_SCHED_MAX_SPLIT_INPUTS) {
+                        const size_t id = hash_id(src);
+                        int src_backend_id = sched->hv_tensor_backend_ids[id];
+                        bool supported = ggml_backend_sched_buffer_supported(sched, src, cur_backend_id);
+                        if (src_backend_id != cur_backend_id && tensor_id_copy(id, cur_backend_id, 0) == NULL && !supported) {
+                            //printf("starting new split because of too many inputs: node %s, input %s\n", node->name, src->name);
+                            need_new_split = true;
+                            break;
+                        }
+                    }
+                }
+            }
+
+            if (node_backend_id != cur_backend_id || need_new_split) {
+                split->i_end = i;
+                i_split++;
+                if (i_split >= sched->splits_capacity) {
+                    sched->splits_capacity *= 2;
+                    sched->splits = realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split));
+                    GGML_ASSERT(sched->splits != NULL);
+                }
+                split = &sched->splits[i_split];
+                split->backend_id = node_backend_id;
+                split->i_start = i;
+                split->n_inputs = 0;
+                cur_backend_id = node_backend_id;
+            }
+
+            // find inputs that are not on the same backend
+            for (int j = 0; j < GGML_MAX_SRC; j++) {
+                struct ggml_tensor * src = node->src[j];
+                if (src == NULL) {
+                    continue;
+                }
+
+                size_t src_id = hash_id(src);
+                const int src_backend_id = sched->hv_tensor_backend_ids[src_id];
+                assert(src_backend_id != -1); // all inputs should be assigned by now
+
+                if (src->flags & GGML_TENSOR_FLAG_INPUT && sched->n_copies > 1) {
+                    if (tensor_id_copy(src_id, src_backend_id, 0) == NULL) {
+                        ggml_backend_t backend = sched->backends[src_backend_id];
+                        for (int c = 0; c < sched->n_copies; c++) {
+                            struct ggml_tensor * tensor_copy;
+                            if (c == sched->cur_copy) {
+                                tensor_copy = src; // use the original tensor as the current copy
+                            } else {
+                                tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
+                                ggml_format_name(tensor_copy, "%s#%s#%d", ggml_backend_name(backend), src->name, c);
+                            }
+                            if (sched->n_copies > 1) {
+                                ggml_set_input(tensor_copy);
+                                ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor
+                            }
+                            tensor_id_copy(src_id, src_backend_id, c) = tensor_copy;
+                            SET_CAUSE(tensor_copy, "4.cpy");
+                        }
+                        int n_graph_inputs = sched->n_graph_inputs++;
+                        GGML_ASSERT(n_graph_inputs < GGML_SCHED_MAX_SPLIT_INPUTS);
+                        sched->graph_inputs[n_graph_inputs] = src;
+                    }
+                }
+
+                if (src_backend_id != cur_backend_id && !ggml_backend_sched_buffer_supported(sched, src, cur_backend_id)) {
+                    // create a copy of the input in the split's backend
+                    if (tensor_id_copy(src_id, cur_backend_id, 0) == NULL) {
+                        ggml_backend_t backend = sched->backends[cur_backend_id];
+                        for (int c = 0; c < sched->n_copies; c++) {
+                            struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
+                            ggml_format_name(tensor_copy, "%s#%s#%d", ggml_backend_name(backend), src->name, c);
+                            if (sched->n_copies > 1) {
+                                ggml_set_input(tensor_copy);
+                                ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor
+                            }
+                            tensor_id_copy(src_id, cur_backend_id, c) = tensor_copy;
+                            SET_CAUSE(tensor_copy, "4.cpy");
+                        }
+                        int n_inputs = split->n_inputs++;
+                        GGML_ASSERT(n_inputs < GGML_SCHED_MAX_SPLIT_INPUTS);
+                        split->inputs[n_inputs] = src;
+                    }
+                    node->src[j] = tensor_id_copy(src_id, cur_backend_id, sched->cur_copy);
+                }
+            }
+        }
+        split->i_end = graph->n_nodes;
+        sched->n_splits = i_split + 1;
+    }
+
+    if (sched->debug) {
+        ggml_backend_sched_print_assignments(sched, graph);
+    }
+
+    // swap node_backend_ids and leaf_backend_ids with prevs
+    {
+        int * tmp = sched->node_backend_ids;
+        sched->node_backend_ids = sched->prev_node_backend_ids;
+        sched->prev_node_backend_ids = tmp;
+
+        tmp = sched->leaf_backend_ids;
+        sched->leaf_backend_ids = sched->prev_leaf_backend_ids;
+        sched->prev_leaf_backend_ids = tmp;
+    }
+
+    int graph_size = graph->n_nodes + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2;
+    if (sched->graph.size < graph_size) {
+        sched->graph.size = graph_size;
+        sched->graph.nodes = realloc(sched->graph.nodes, graph_size * sizeof(struct ggml_tensor *));
+        sched->graph.leafs = realloc(sched->graph.leafs, graph_size * sizeof(struct ggml_tensor *));
+        GGML_ASSERT(sched->graph.nodes != NULL);
+        GGML_ASSERT(sched->graph.leafs != NULL);
+    }
+    sched->graph.n_nodes = 0;
+    sched->graph.n_leafs = 0;
+
+    struct ggml_cgraph * graph_copy = &sched->graph;
+
+    for (int i = 0; i < sched->n_splits; i++) {
+        struct ggml_backend_sched_split * split = &sched->splits[i];
+        split->graph = ggml_graph_view(graph, split->i_start, split->i_end);
+
+        // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
+        for (int j = 0; j < split->n_inputs; j++) {
+            assert(graph_copy->size > (graph_copy->n_nodes + 1));
+
+            struct ggml_tensor * input = split->inputs[j];
+            const size_t input_id = hash_id(input);
+            struct ggml_tensor * input_cpy = tensor_id_copy(input_id, split->backend_id, sched->cur_copy);
+
+            // add a dependency to the input source so that it is not freed before the copy is done
+            struct ggml_tensor * input_dep = ggml_view_tensor(sched->ctx, input);
+            input_dep->src[0] = input;
+            sched->node_backend_ids[graph_copy->n_nodes] = sched->hv_tensor_backend_ids[input_id];
+            graph_copy->nodes[graph_copy->n_nodes++] = input_dep;
+
+            // add a dependency to the input copy so that it is allocated at the start of the split
+            sched->node_backend_ids[graph_copy->n_nodes] = split->backend_id;
+            graph_copy->nodes[graph_copy->n_nodes++] = input_cpy;
+        }
+
+        for (int j = split->i_start; j < split->i_end; j++) {
+            assert(graph_copy->size > graph_copy->n_nodes);
+            sched->node_backend_ids[graph_copy->n_nodes] = tensor_backend_id(graph->nodes[j]);
+            graph_copy->nodes[graph_copy->n_nodes++] = graph->nodes[j];
+        }
+    }
+
+    if (sched->n_copies > 1) {
+        // add input copies as leafs so that they are allocated first
+        for (int i = 0; i < sched->n_graph_inputs; i++) {
+            struct ggml_tensor * input = sched->graph_inputs[i];
+            size_t id = hash_id(input);
+            int backend_id = tensor_backend_id(input);
+            for (int c = 0; c < sched->n_copies; c++) {
+                struct ggml_tensor * input_cpy = tensor_id_copy(id, backend_id, c);
+                sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id;
+                graph_copy->leafs[graph_copy->n_leafs++] = input_cpy;
+            }
+        }
+
+        for (int i = 0; i < sched->n_splits; i++) {
+            struct ggml_backend_sched_split * split = &sched->splits[i];
+            int backend_id = split->backend_id;
+            for (int j = 0; j < split->n_inputs; j++) {
+                struct ggml_tensor * input = split->inputs[j];
+                size_t id = hash_id(input);
+                for (int c = 0; c < sched->n_copies; c++) {
+                    struct ggml_tensor * input_cpy = tensor_id_copy(id, backend_id, c);
+                    sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id;
+                    graph_copy->leafs[graph_copy->n_leafs++] = input_cpy;
+                }
+            }
+        }
+    }
+
+    // add leafs from the original graph
+    for (int i = 0; i < graph->n_leafs; i++) {
+        struct ggml_tensor * leaf = graph->leafs[i];
+        sched->leaf_backend_ids[graph_copy->n_leafs] = tensor_backend_id(leaf);
+        graph_copy->leafs[graph_copy->n_leafs++] = leaf;
+    }
+}
+
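The five passes above leave every node with a backend assignment and the graph cut into contiguous per-backend splits. As a rough health check, the split count of the last computed graph can be queried afterwards; a minimal sketch, assuming `sched` and `graph` are already set up (fewer splits generally means fewer cross-backend copies):

    // sketch: report how fragmented the last computed graph was
    ggml_backend_sched_graph_compute(sched, graph);
    int n_splits   = ggml_backend_sched_get_n_splits(sched);
    int n_backends = ggml_backend_sched_get_n_backends(sched);
    fprintf(stderr, "graph ran as %d split(s) across %d backend(s)\n", n_splits, n_backends);

Setting the `GGML_SCHED_DEBUG` environment variable (checked in `ggml_backend_sched_new` below) prints the full per-node assignment as well.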
+static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
+    bool backend_ids_changed = false;
+    for (int i = 0; i < sched->graph.n_nodes; i++) {
+        if (sched->node_backend_ids[i] != sched->prev_node_backend_ids[i] &&
+            sched->bufts[sched->node_backend_ids[i]] != sched->bufts[sched->prev_node_backend_ids[i]]) {
+            backend_ids_changed = true;
+            break;
+        }
+    }
+    if (!backend_ids_changed) {
+        for (int i = 0; i < sched->graph.n_leafs; i++) {
+            if (sched->leaf_backend_ids[i] != sched->prev_leaf_backend_ids[i] &&
+                sched->bufts[sched->leaf_backend_ids[i]] != sched->bufts[sched->prev_leaf_backend_ids[i]]) {
+                backend_ids_changed = true;
+                break;
+            }
+        }
+    }
+
+    // allocate graph
+    if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
+        // the re-allocation may cause the split inputs to be moved to a different address
+        ggml_backend_sched_synchronize(sched);
+#ifndef NDEBUG
+        fprintf(stderr, "%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
+#endif
+        ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids);
+        if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
+            fprintf(stderr, "%s: failed to allocate graph\n", __func__);
+            return false;
+        }
+    }
+
+    return true;
+}
+
+static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
+    struct ggml_backend_sched_split * splits = sched->splits;
+
+    for (int i = 0; i < sched->n_splits; i++) {
+        struct ggml_backend_sched_split * split = &splits[i];
+        int split_backend_id = split->backend_id;
+        ggml_backend_t split_backend = sched->backends[split_backend_id];
+
+        // copy the input tensors to the split backend
+        for (int j = 0; j < split->n_inputs; j++) {
+            ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]);
+            struct ggml_tensor * input = split->inputs[j];
+            struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy);
+
+            if (input->flags & GGML_TENSOR_FLAG_INPUT) {
+                // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
+                if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
+                    ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
+                } else {
+                    ggml_backend_synchronize(split_backend);
+                }
+                ggml_backend_tensor_copy(input, input_cpy);
+            } else {
+                // wait for the split backend to finish using the input before overwriting it
+                if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
+                    ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]);
+                } else {
+                    ggml_backend_synchronize(split_backend);
+                }
+                // try async copy, but if not possible, we can still use a sync copy without synchronizing the dst backend, since we handle the synchronization here with multiple copies and events
+                // TODO: add public function to facilitate this, since applications do not have direct access to the backend interface
+                if (!split_backend->iface.cpy_tensor_async || !split_backend->iface.cpy_tensor_async(input_backend, split_backend, input, input_cpy)) {
+                    ggml_backend_synchronize(input_backend);
+                    if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
+                        ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
+                    } else {
+                        ggml_backend_synchronize(split_backend);
+                    }
+                    ggml_backend_tensor_copy(input, input_cpy);
+                }
+            }
+        }
+
+        if (!sched->callback_eval) {
+            enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
+            if (ec != GGML_STATUS_SUCCESS) {
+                return ec;
+            }
+        } else {
+            // similar to ggml_backend_compare_graph_backend
+            for (int j0 = 0; j0 < split->graph.n_nodes; j0++) {
+                struct ggml_tensor * t = split->graph.nodes[j0];
+
+                // check if the user needs data from this node
+                bool need = sched->callback_eval(t, true, sched->callback_eval_user_data);
+
+                int j1 = j0;
+
+                // determine the range [j0, j1] of nodes that can be computed together
+                while (!need && j1 < split->graph.n_nodes - 1) {
+                    t = split->graph.nodes[++j1];
+                    need = sched->callback_eval(t, true, sched->callback_eval_user_data);
+                }
+
+                struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);
+
+                enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &gv);
+                if (ec != GGML_STATUS_SUCCESS) {
+                    return ec;
+                }
+
+                // TODO: pass backend to the callback, then the user can decide if they want to synchronize
+                ggml_backend_synchronize(split_backend);
+
+                if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) {
+                    break;
+                }
+
+                j0 = j1;
+            }
+        }
+
+        // record the event of this copy
+        if (split->n_inputs > 0) {
+            if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
+                ggml_backend_event_record(sched->events[split_backend_id][sched->cur_copy]);
+            }
+        }
+    }
+
+    sched->cur_copy = (sched->cur_copy + 1) % sched->n_copies;
+
+    return GGML_STATUS_SUCCESS;
+}
+
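The event record/wait pairs above are what make `parallel` mode safe: with `n_copies > 1` the scheduler cycles `cur_copy`, so inputs for the next iteration can be uploaded while the previous one is still computing. A sketch of opting in, assuming `backends` and `n_backends` are already initialized (a NULL `bufts` selects each backend's default buffer type):

    ggml_backend_sched_t sched = ggml_backend_sched_new(
        backends, /* bufts = */ NULL, n_backends,
        /* graph_size = */ GGML_DEFAULT_GRAPH_SIZE, /* parallel = */ true);
    // ggml_backend_sched_get_n_copies(sched) now reports GGML_SCHED_MAX_COPIES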
+ggml_backend_sched_t ggml_backend_sched_new(
+        ggml_backend_t * backends,
+        ggml_backend_buffer_type_t * bufts,
+        int n_backends,
+        size_t graph_size,
+        bool parallel) {
+    GGML_ASSERT(n_backends > 0);
+    GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS);
+    GGML_ASSERT(ggml_backend_is_cpu(backends[n_backends - 1])); // last backend must be CPU
+
+    struct ggml_backend_sched * sched = calloc(1, sizeof(struct ggml_backend_sched));
+
+    sched->debug = getenv("GGML_SCHED_DEBUG") != NULL;
+    sched->n_backends = n_backends;
+    sched->n_copies = parallel ? GGML_SCHED_MAX_COPIES : 1;
+
+    // initialize hash table
+    // FIXME: needs to be size*2 to account for leafs (do it in graph_split instead)
+    sched->hash_set    = ggml_hash_set_new(graph_size);
+    sched->hv_tensor_backend_ids = malloc(sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0]));
+    sched->hv_tensor_copies      = malloc(sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *));
+
+    const size_t ggml_sched_max_splits = graph_size; // at most there is one split for each node in the graph
+    const size_t nodes_size = graph_size + ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2;
+    sched->node_backend_ids = calloc(nodes_size, sizeof(sched->node_backend_ids[0]));
+    sched->leaf_backend_ids = calloc(nodes_size, sizeof(sched->leaf_backend_ids[0]));
+    sched->prev_node_backend_ids = calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0]));
+    sched->prev_leaf_backend_ids = calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0]));
+
+    sched->context_buffer_size = ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false);
+    sched->context_buffer = malloc(sched->context_buffer_size);
+
+    const int initial_splits_capacity = 16;
+    sched->splits = calloc(initial_splits_capacity, sizeof(sched->splits[0]));
+    sched->splits_capacity = initial_splits_capacity;
+
+    for (int b = 0; b < n_backends; b++) {
+        sched->backends[b] = backends[b];
+        sched->bufts[b] = bufts ? bufts[b] : ggml_backend_get_default_buffer_type(backends[b]);
+        GGML_ASSERT(ggml_backend_supports_buft(backends[b], sched->bufts[b]));
+        if (sched->n_copies > 1) {
+            for (int c = 0; c < sched->n_copies; c++) {
+                sched->events[b][c] = ggml_backend_event_new(backends[b]);
+            }
+        }
+    }
+
+    sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends);
+
+    ggml_backend_sched_reset(sched);
+
+    return sched;
+}
+
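Note the asserted invariant above: the last backend passed in must be the CPU backend, which the assignment passes treat as the lowest-priority fallback. A minimal construction sketch, where `backend_gpu` is a hypothetical, already-initialized GPU backend:

    ggml_backend_t backend_cpu = ggml_backend_cpu_init();
    ggml_backend_t backends[2] = { backend_gpu, backend_cpu }; // CPU must come last
    ggml_backend_sched_t sched = ggml_backend_sched_new(
        backends, NULL, 2, GGML_DEFAULT_GRAPH_SIZE, /* parallel = */ false);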
+void ggml_backend_sched_free(ggml_backend_sched_t sched) {
+    if (sched == NULL) {
+        return;
+    }
+    for (int b = 0; b < sched->n_backends; b++) {
+        for (int c = 0; c < sched->n_copies; c++) {
+            ggml_backend_event_free(sched->events[b][c]);
+        }
+    }
+    ggml_gallocr_free(sched->galloc);
+    ggml_free(sched->ctx);
+    ggml_hash_set_free(&sched->hash_set);
+    free(sched->splits);
+    free(sched->hv_tensor_backend_ids);
+    free(sched->hv_tensor_copies);
+    free(sched->node_backend_ids);
+    free(sched->leaf_backend_ids);
+    free(sched->prev_node_backend_ids);
+    free(sched->prev_leaf_backend_ids);
+    free(sched->context_buffer);
+    free(sched->graph.nodes);
+    free(sched->graph.leafs);
+    free(sched);
+}
+
+void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
+    // reset state for the next run
+    if (!sched->is_reset) {
+        ggml_hash_set_reset(&sched->hash_set);
+        memset(sched->hv_tensor_backend_ids, -1, sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0]));
+        memset(sched->hv_tensor_copies,       0, sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *));
+        sched->is_reset = true;
+    }
+    sched->is_alloc = false;
+}
+
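Since allocation consumes the reset state, graphs with user inputs follow a fixed ordering (the header below shows the same sequence): reset, allocate, set inputs, compute. A sketch, assuming `input_tensor` is a graph input and `data` points at its host-side contents:

    ggml_backend_sched_reset(sched);               // clear previous assignments
    ggml_backend_sched_alloc_graph(sched, graph);  // split and allocate
    ggml_backend_tensor_set(input_tensor, data, 0, ggml_nbytes(input_tensor));
    ggml_backend_sched_graph_compute(sched, graph);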
+bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
+    GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
+
+    ggml_backend_sched_split_graph(sched, measure_graph);
+
+    if (!ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) {
+        return false;
+    }
+
+    ggml_backend_sched_reset(sched);
+    ggml_backend_sched_synchronize(sched);
+
+    return true;
+}
+
+bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+    GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs);
+
+    ggml_backend_sched_split_graph(sched, graph);
+
+    if (!ggml_backend_sched_alloc_splits(sched)) {
+        return false;
+    }
+
+    sched->is_alloc = true;
+
+    return true;
+}
+
+enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+    enum ggml_status err = ggml_backend_sched_graph_compute_async(sched, graph);
+    ggml_backend_sched_synchronize(sched);
+    return err;
+}
+
+enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+    if (!sched->is_reset && !sched->is_alloc) {
+        ggml_backend_sched_reset(sched);
+    }
+
+    if (!sched->is_alloc) {
+        if (!ggml_backend_sched_alloc_graph(sched, graph)) {
+            return GGML_STATUS_ALLOC_FAILED;
+        }
+    }
+
+    return ggml_backend_sched_compute_splits(sched);
+}
+
+void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
+    for (int i = 0; i < sched->n_backends; i++) {
+        ggml_backend_synchronize(sched->backends[i]);
+    }
+}
+
+void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
+    sched->callback_eval = callback;
+    sched->callback_eval_user_data = user_data;
+}
+
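The callback installed here follows the ask/observe protocol documented in the header: the scheduler first asks whether a node should be observed, then hands back the computed tensor. A sketch of an observer that forces per-node evaluation and logs each result (illustrative only):

    static bool observe_cb(struct ggml_tensor * t, bool ask, void * user_data) {
        (void) user_data; // unused in this sketch
        if (ask) {
            return true;  // observe every node (disables batching of adjacent nodes)
        }
        fprintf(stderr, "computed %s (op = %s)\n", t->name, ggml_op_name(t->op));
        return true;      // returning false here would cancel the graph compute
    }
    // ...
    ggml_backend_sched_set_eval_callback(sched, observe_cb, NULL);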
+int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {
+    return sched->n_splits;
+}
+
+int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched) {
+    return sched->n_copies;
+}
+
+int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched) {
+    return sched->n_backends;
+}
+
+ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i) {
+    GGML_ASSERT(i >= 0 && i < sched->n_backends);
+    return sched->backends[i];
+}
+
+size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) {
+    int backend_index = ggml_backend_sched_backend_id(sched, backend);
+    GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
+
+    return ggml_gallocr_get_buffer_size(sched->galloc, backend_index);
+}
+
+void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
+    int backend_index = ggml_backend_sched_backend_id(sched, backend);
+    GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
+    tensor_backend_id(node) = backend_index;
+    SET_CAUSE(node, "usr");
+    sched->is_reset = false;
+}
+
+ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node) {
+    int backend_index = tensor_backend_id(node);
+    if (backend_index == -1) {
+        return NULL;
+    }
+    return sched->backends[backend_index];
+}
+
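A sketch of the manual pinning hook in use; the expansion passes then grow assignments outward from the pinned node. `ctx`, `a`, `b`, `sched`, and `backend_cpu` are assumed to already exist:

    struct ggml_tensor * node = ggml_mul_mat(ctx, a, b);
    ggml_backend_sched_set_tensor_backend(sched, node, backend_cpu); // pin before allocating
    assert(ggml_backend_sched_get_tensor_backend(sched, node) == backend_cpu);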
+// utils
+
+void ggml_backend_view_init(struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->buffer == NULL);
+    GGML_ASSERT(tensor->view_src != NULL);
+    GGML_ASSERT(tensor->view_src->buffer != NULL);
+    GGML_ASSERT(tensor->view_src->data != NULL);
+
+    tensor->buffer = tensor->view_src->buffer;
+    tensor->data = (char *)tensor->view_src->data + tensor->view_offs;
+    ggml_backend_buffer_init_tensor(tensor->buffer, tensor);
+}
+
+void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) {
+    GGML_ASSERT(tensor->buffer == NULL);
+    GGML_ASSERT(tensor->data == NULL);
+    GGML_ASSERT(tensor->view_src == NULL);
+    GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer));
+    GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <=
+                (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer));
+
+    tensor->buffer = buffer;
+    tensor->data = addr;
+    ggml_backend_buffer_init_tensor(buffer, tensor);
+}
+
+static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies,
+    struct ggml_context * ctx_allocated, struct ggml_context * ctx_unallocated, struct ggml_tensor * src) {
+
+    GGML_ASSERT(src != NULL);
+    GGML_ASSERT(src->data && "graph must be allocated");
+
+    size_t id = ggml_hash_insert(&hash_set, src);
+    if (id == GGML_HASHSET_ALREADY_EXISTS) {
+        return node_copies[ggml_hash_find(&hash_set, src)];
+    }
+
+    struct ggml_tensor * dst = ggml_dup_tensor_layout(src->data && !src->view_src ? ctx_allocated : ctx_unallocated, src);
+    if (src->view_src != NULL) {
+        dst->view_src = graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, src->view_src);
+        dst->view_offs = src->view_offs;
+    }
+    dst->op = src->op;
+    memcpy(dst->op_params, src->op_params, sizeof(dst->op_params));
+    ggml_set_name(dst, src->name);
+
+    // copy src
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        struct ggml_tensor * s = src->src[i];
+        if (s == NULL) {
+            continue;
+        }
+        dst->src[i] = graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, s);
+    }
+
+    node_copies[id] = dst;
+    return dst;
+}
+
+static void graph_copy_init_tensor(struct ggml_hash_set * hash_set, struct ggml_tensor ** node_copies, bool * node_init, struct ggml_tensor * src) {
+    size_t id = ggml_hash_find(hash_set, src);
+    if (node_init[id]) {
+        return;
+    }
+    node_init[id] = true;
+
+    struct ggml_tensor * dst = node_copies[id];
+    if (dst->view_src != NULL) {
+        graph_copy_init_tensor(hash_set, node_copies, node_init, src->view_src);
+        ggml_backend_view_init(dst);
+    }
+    else {
+        ggml_backend_tensor_copy(src, dst);
+    }
+
+    // init src
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        struct ggml_tensor * s = src->src[i];
+        if (s == NULL) {
+            continue;
+        }
+        graph_copy_init_tensor(hash_set, node_copies, node_init, s);
+    }
+}
+
+struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) {
+    struct ggml_hash_set hash_set = ggml_hash_set_new(graph->visited_hash_set.size);
+    struct ggml_tensor ** node_copies = calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT
+    bool * node_init = calloc(hash_set.size, sizeof(node_init[0]));
+
+    struct ggml_init_params params = {
+        /* .mem_size   = */ ggml_tensor_overhead()*hash_set.size + ggml_graph_overhead_custom(graph->size, false),
+        /* .mem_buffer = */ NULL,
+        /* .no_alloc   = */ true
+    };
+
+    struct ggml_context * ctx_allocated = ggml_init(params);
+    struct ggml_context * ctx_unallocated = ggml_init(params);
+
+    if (ctx_allocated == NULL || ctx_unallocated == NULL) {
+        fprintf(stderr, "failed to allocate context for graph copy\n");
+        ggml_hash_set_free(&hash_set);
+        free(node_copies);
+        free(node_init);
+        ggml_free(ctx_allocated);
+        ggml_free(ctx_unallocated);
+        return (struct ggml_backend_graph_copy) {
+            /* .buffer           = */ NULL,
+            /* .ctx_allocated    = */ NULL,
+            /* .ctx_unallocated  = */ NULL,
+            /* .graph            = */ NULL,
+        };
+    }
+
+    // dup nodes
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, node);
+    }
+
+    // allocate nodes
+    ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx_allocated, backend);
+    if (buffer == NULL) {
+        fprintf(stderr, "failed to allocate buffer for graph copy\n");
+        ggml_hash_set_free(&hash_set);
+        free(node_copies);
+        free(node_init);
+        ggml_free(ctx_allocated);
+        ggml_free(ctx_unallocated);
+        return (struct ggml_backend_graph_copy) {
+            /* .buffer           = */ NULL,
+            /* .ctx_allocated    = */ NULL,
+            /* .ctx_unallocated  = */ NULL,
+            /* .graph            = */ NULL,
+        };
+    }
+
+    //printf("copy buffer size: %zu MB\n", ggml_backend_buffer_get_size(buffer) / 1024 / 1024);
+
+    // copy data and init views
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        graph_copy_init_tensor(&hash_set, node_copies, node_init, node);
+    }
+
+    // build graph copy
+    struct ggml_cgraph * graph_copy = ggml_new_graph_custom(ctx_allocated, graph->size, false);
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        struct ggml_tensor * node_copy = node_copies[ggml_hash_find(&hash_set, node)];
+        graph_copy->nodes[i] = node_copy;
+    }
+    graph_copy->n_nodes = graph->n_nodes;
+
+    ggml_hash_set_free(&hash_set);
+    free(node_copies);
+    free(node_init);
+
+    return (struct ggml_backend_graph_copy) {
+        /* .buffer           = */ buffer,
+        /* .ctx_allocated    = */ ctx_allocated,
+        /* .ctx_unallocated  = */ ctx_unallocated,
+        /* .graph            = */ graph_copy,
+    };
+}
+
+void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
+    ggml_backend_buffer_free(copy.buffer);
+    ggml_free(copy.ctx_allocated);
+    ggml_free(copy.ctx_unallocated);
+}
+
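The copy/free pair above can also be used stand-alone to replay a graph on another backend; a sketch, assuming `graph` is an allocated graph and `backend2` is the target:

    struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
    if (copy.buffer != NULL) { // a NULL buffer indicates the copy failed
        ggml_backend_graph_compute(backend2, copy.graph);
        ggml_backend_graph_copy_free(copy);
    }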
+bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) {
+    struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
+    if (copy.buffer == NULL) {
+        return false;
+    }
+
+    struct ggml_cgraph * g1 = graph;
+    struct ggml_cgraph * g2 = copy.graph;
+
+    assert(g1->n_nodes == g2->n_nodes);
+
+    for (int i = 0; i < g1->n_nodes; i++) {
+        //printf("eval %d/%d\n", i, g1->n_nodes);
+        struct ggml_tensor * t1 = g1->nodes[i];
+        struct ggml_tensor * t2 = g2->nodes[i];
+
+        assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
+
+        struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
+        struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
+
+        ggml_backend_graph_compute(backend1, &g1v);
+        ggml_backend_graph_compute(backend2, &g2v);
+
+        if (ggml_is_view_op(t1->op)) {
+            continue;
+        }
+
+        // compare results, calculate rms etc
+        if (!callback(i, t1, t2, user_data)) {
+            break;
+        }
+    }
+
+    ggml_backend_graph_copy_free(copy);
+
+    return true;
+}
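A sketch of a callback for the comparison above, reading both outputs back and reporting the maximum absolute difference. It assumes contiguous F32 tensors and the usual <stdio.h>/<stdlib.h>/<math.h> includes; it is illustrative, not part of the API:

    static bool cmp_cb(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data) {
        (void) user_data;
        if (t1->type != GGML_TYPE_F32) {
            return true; // this simple sketch only checks F32 nodes
        }
        const int64_t n = ggml_nelements(t1);
        float * a = malloc(n * sizeof(float));
        float * b = malloc(n * sizeof(float));
        ggml_backend_tensor_get(t1, a, 0, n * sizeof(float));
        ggml_backend_tensor_get(t2, b, 0, n * sizeof(float));
        float max_diff = 0.0f;
        for (int64_t i = 0; i < n; i++) {
            const float d = fabsf(a[i] - b[i]);
            max_diff = d > max_diff ? d : max_diff;
        }
        fprintf(stderr, "node %d (%s): max |t1 - t2| = %g\n", node_index, t1->name, max_diff);
        free(a);
        free(b);
        return max_diff < 1e-3f; // returning false stops the comparison early
    }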

+ 266 - 0
llama/ggml-backend.h

@@ -0,0 +1,266 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "ggml.h"
+#include "ggml-alloc.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+    typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
+    typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
+    typedef struct ggml_backend_event * ggml_backend_event_t;
+    typedef struct ggml_backend * ggml_backend_t;
+    typedef void * ggml_backend_graph_plan_t;
+
+    //
+    // Backend buffer
+    //
+
+    // buffer type
+    GGML_API           const char *          ggml_backend_buft_name            (ggml_backend_buffer_type_t buft);
+    GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_buft_alloc_buffer    (ggml_backend_buffer_type_t buft, size_t size);
+    GGML_API           size_t                ggml_backend_buft_get_alignment   (ggml_backend_buffer_type_t buft);
+    GGML_API           size_t                ggml_backend_buft_get_max_size    (ggml_backend_buffer_type_t buft);
+    GGML_API GGML_CALL size_t                ggml_backend_buft_get_alloc_size  (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
+    GGML_API           bool                  ggml_backend_buft_is_host         (ggml_backend_buffer_type_t buft);
+
+    // buffer
+    enum ggml_backend_buffer_usage {
+        GGML_BACKEND_BUFFER_USAGE_ANY = 0,
+        GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1,
+        GGML_BACKEND_BUFFER_USAGE_COMPUTE = 2,
+    };
+
+    GGML_API           const char *                   ggml_backend_buffer_name          (ggml_backend_buffer_t buffer);
+    GGML_API           void                           ggml_backend_buffer_free          (ggml_backend_buffer_t buffer);
+    GGML_API           void *                         ggml_backend_buffer_get_base      (ggml_backend_buffer_t buffer);
+    GGML_API           size_t                         ggml_backend_buffer_get_size      (ggml_backend_buffer_t buffer);
+    GGML_API GGML_CALL void                           ggml_backend_buffer_init_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API           size_t                         ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
+    GGML_API           size_t                         ggml_backend_buffer_get_max_size  (ggml_backend_buffer_t buffer);
+    GGML_API           size_t                         ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API           void                           ggml_backend_buffer_clear         (ggml_backend_buffer_t buffer, uint8_t value);
+    GGML_API           bool                           ggml_backend_buffer_is_host       (ggml_backend_buffer_t buffer);
+    GGML_API           void                           ggml_backend_buffer_set_usage     (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
+    GGML_API           enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage     (ggml_backend_buffer_t buffer);
+    GGML_API           ggml_backend_buffer_type_t     ggml_backend_buffer_get_type      (ggml_backend_buffer_t buffer);
+    GGML_API           void                           ggml_backend_buffer_reset         (ggml_backend_buffer_t buffer);
+
+    //
+    // Backend
+    //
+
+    GGML_API ggml_guid_t  ggml_backend_guid(ggml_backend_t backend);
+    GGML_API const char * ggml_backend_name(ggml_backend_t backend);
+    GGML_API void         ggml_backend_free(ggml_backend_t backend);
+
+    GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend);
+    GGML_API ggml_backend_buffer_t      ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
+    GGML_API size_t                     ggml_backend_get_alignment(ggml_backend_t backend);
+    GGML_API size_t                     ggml_backend_get_max_size(ggml_backend_t backend);
+
+    GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+
+    // "offset" refers to the offset of the tensor data for setting/getting data
+    GGML_API GGML_CALL void ggml_backend_tensor_set(      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+    GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+
+    GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
+
+    GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph);
+    GGML_API void                      ggml_backend_graph_plan_free  (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+
+    GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+    GGML_API enum ggml_status ggml_backend_graph_compute      (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+    GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph);
+    GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);
+    GGML_API bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft);
+    GGML_API bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op);
+
+    // tensor copy between different backends
+    GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
+
+    // asynchronous copy
+    // the copy is performed after all the currently queued operations in backend_src
+    // backend_dst will wait for the copy to complete before performing other operations
+    // automatic fallback to sync copy if async is not supported
+    GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst);
+
+    // events
+    GGML_API ggml_backend_event_t   ggml_backend_event_new        (ggml_backend_t backend);
+    GGML_API void                   ggml_backend_event_free       (ggml_backend_event_t event);
+    GGML_API void                   ggml_backend_event_record     (ggml_backend_event_t event);
+    GGML_API void                   ggml_backend_event_synchronize(ggml_backend_event_t event);
+    GGML_API void                   ggml_backend_event_wait       (ggml_backend_t backend, ggml_backend_event_t event);
+
+    //
+    // CPU backend
+    //
+
+    GGML_API ggml_backend_t ggml_backend_cpu_init(void);
+
+    GGML_API GGML_CALL bool ggml_backend_is_cpu                (ggml_backend_t backend);
+    GGML_API           void ggml_backend_cpu_set_n_threads     (ggml_backend_t backend_cpu, int n_threads);
+    GGML_API           void ggml_backend_cpu_set_threadpool    (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
+    GGML_API           void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
+
+    // Create a backend buffer from an existing pointer
+    GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
+
+    GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
+
+#ifdef GGML_USE_CPU_HBM
+    GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
+#endif
+
+    //
+    // Backend registry
+    //
+
+    // The backend registry is a registry of all the available backends, and allows initializing backends in a generic way
+
+    GGML_API size_t                     ggml_backend_reg_get_count(void);
+    GGML_API size_t                     ggml_backend_reg_find_by_name(const char * name);
+    GGML_API ggml_backend_t             ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is backend_name:params (params is optional)
+    GGML_API const char *               ggml_backend_reg_get_name(size_t i);
+    GGML_API ggml_backend_t             ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
+    GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i);
+    GGML_API ggml_backend_buffer_t      ggml_backend_reg_alloc_buffer(size_t i, size_t size);
+
+    //
+    // Backend scheduler
+    //
+
+    // The backend scheduler allows for multiple backends to be used together
+    // Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors between backends
+    // The backends are selected based on:
+    // - the backend that supports the operation
+    // - the location of the pre-allocated tensors (e.g. the weights)
+    /*
+      Example usage:
+
+        // operations that use tensors allocated in a buffer with USAGE_WEIGHTS will be assigned
+        // preferably to run on the same backend as the buffer
+        ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+
+        sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false);
+
+        // initialize buffers from a max size graph (optional)
+        reserve_graph = build_graph(sched, max_batch_size);
+
+        // manually assign nodes to a backend (optional, should not be needed in most cases)
+        struct ggml_tensor * node = ggml_mul_mat(ctx, ...);
+        ggml_backend_sched_set_tensor_backend(sched, node, backend_gpu);
+
+        ggml_backend_sched_reserve(sched, reserve_graph);
+
+        // compute
+        graph = build_graph(sched);
+        ggml_backend_sched_graph_compute(sched, graph);
+
+        // if there are graph inputs:
+        ggml_backend_sched_reset(sched);
+        ggml_backend_sched_alloc_graph(sched, graph);
+        ggml_backend_tensor_set(input_tensor, ...);
+        ggml_backend_sched_graph_compute(sched, graph);
+    */
+
+    struct ggml_backend_sched;
+    typedef struct ggml_backend_sched * ggml_backend_sched_t;
+
+    // when ask == true, the scheduler wants to know if the user wants to observe this node
+    // this allows the scheduler to batch nodes together in order to evaluate them in a single call
+    //
+    // when ask == false, the scheduler is passing the node tensor to the user for observation
+    // if the user returns false, the scheduler will cancel the graph compute
+    //
+    typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
+
+    // Initialize a backend scheduler
+    GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
+    GGML_API void                 ggml_backend_sched_free(ggml_backend_sched_t sched);
+
+    // Initialize backend buffers from a measure graph
+    GGML_API bool                 ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
+
+    GGML_API int                  ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched);
+    GGML_API ggml_backend_t       ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i);
+
+    // Get the number of splits of the last graph
+    GGML_API int                  ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
+    GGML_API int                  ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched);
+
+    GGML_API size_t               ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
+
+    GGML_API void                 ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
+    GGML_API ggml_backend_t       ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
+
+    // Allocate and compute graph on the backend scheduler
+    GGML_API bool                 ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+    GGML_API enum ggml_status     ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+    GGML_API enum ggml_status     ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+    GGML_API void                 ggml_backend_sched_synchronize(ggml_backend_sched_t sched);
+
+    // Reset all assignments and allocators - must be called before changing the node backends
+    GGML_API void                 ggml_backend_sched_reset(ggml_backend_sched_t sched);
+
+    // Set a callback to be called for each resulting node during graph compute
+    GGML_API void                 ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
+
+    //
+    // Utils
+    //
+
+    struct ggml_backend_graph_copy {
+        ggml_backend_buffer_t buffer;
+        struct ggml_context * ctx_allocated;
+        struct ggml_context * ctx_unallocated;
+        struct ggml_cgraph * graph;
+    };
+
+    // Copy a graph to a different backend
+    GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
+    GGML_API void                           ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
+
+    typedef bool (*GGML_CALL ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
+
+    // Compare the output of two backends
+    GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
+
+    // Tensor initialization
+    GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
+    GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);
+
+#ifdef  __cplusplus
+}
+#endif

+ 397 - 0
llama/ggml-blas.cpp

@@ -0,0 +1,397 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifdef GGML_USE_BLAS
+
+#include "ggml-blas.h"
+#include "ggml-backend-impl.h"
+
+#include <future>
+#include <vector>
+
+#if defined(GGML_USE_ACCELERATE)
+#   include <Accelerate/Accelerate.h>
+#elif defined(GGML_BLAS_USE_MKL)
+#   include <mkl.h>
+#elif defined(GGML_BLAS_USE_BLIS)
+#   include <blis.h>
+#elif defined(GGML_BLAS_USE_NVPL)
+#   include <nvpl_blas.h>
+#else
+#   include <cblas.h>
+#endif
+
+struct ggml_backend_blas_context {
+    int n_threads = GGML_DEFAULT_N_THREADS;
+    std::unique_ptr<char[]> work_data;
+    size_t work_size = 0;
+#ifndef GGML_USE_OPENMP
+    std::vector<std::future<void>> tasks;
+#endif
+};
+
+// helper function to determine if it is better to use BLAS or not
+// for large matrices, BLAS is faster
+static bool ggml_backend_blas_use_blas(const struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    const int64_t ne10 = src1->ne[0];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+
+    // TODO: find the optimal values for these
+    if (ggml_is_contiguous(src0) &&
+        ggml_is_contiguous(src1) &&
+        src1->type == GGML_TYPE_F32 &&
+        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
+
+        /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
+        return true;
+    }
+
+    return false;
+}
+
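For intuition: with a typical F32 weight matrix the check passes during batched prompt processing but not during single-token decode, where ne1 == 1. The same predicate on raw dimensions, as a hypothetical stand-alone helper (the real check above additionally requires contiguous tensors and an F32 src1):

    static bool would_use_blas_dims(int64_t ne0, int64_t ne1, int64_t ne10) {
        return ne0 >= 32 && ne1 >= 32 && ne10 >= 32;
    }
    // would_use_blas_dims(4096, 512, 4096) -> true  (weights x 512-token batch)
    // would_use_blas_dims(4096,   1, 4096) -> false (matrix-vector decode step)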
+static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const enum ggml_type type = src0->type;
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
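+    // e.g. (illustrative) src0 with ne02 = 1 broadcast against src1 with
+    // ne12 = 8 gives r2 = 8, so batches i12 = 0..7 in the loop below all
+    // read the same src0 plane i02 = i12/r2 = 0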
+
+    const int64_t ne_plane      = ne01*ne00;
+    const size_t  desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);
+
+    if (ctx->work_size < desired_wsize) {
+        ctx->work_data.reset(new char[desired_wsize]);
+        ctx->work_size = desired_wsize;
+    }
+    void * wdata = ctx->work_data.get();
+
+    // convert src0 to float
+    if (type != GGML_TYPE_F32) {
+        ggml_type_traits_t type_traits = ggml_internal_get_type_traits(type);
+        ggml_to_float_t const to_float = type_traits.to_float;
+
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                const void  *       x      = (char *)  src0->data + i02*nb02          + i03*nb03;
+                      float * const wplane = (float *) wdata      + i02*ne_plane      + i03*ne02*ne_plane;
+
+                const int min_cols_per_thread = 4096;
+                const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1);
+                const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01/min_rows_per_thread)), 1);
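+                // e.g. (illustrative) ne00 = 4096 gives min_rows_per_thread = 1,
+                // so the conversion uses min(ctx->n_threads, ne01) threads;
+                // shorter rows raise min_rows_per_thread and shrink that count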
+
+#ifdef GGML_USE_OPENMP
+                #pragma omp parallel for num_threads(n_threads)
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
+                }
+#else
+                for (int i = 1; i < n_threads; i++) {
+                    const int64_t start =       i*ne01/n_threads;
+                    const int64_t end   = (i + 1)*ne01/n_threads;
+                    if (start < end) {
+                        ctx->tasks.push_back(std::async(std::launch::async, [=]() {
+                            for (int64_t i01 = start; i01 < end; i01++) {
+                                to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
+                            }
+                        }));
+                    }
+                }
+                {
+                    // reuse the current thread for the first task
+                    const int64_t start = 0;
+                    const int64_t end   = ne01/n_threads;
+                    for (int64_t i01 = start; i01 < end; i01++) {
+                        to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
+                    }
+                }
+#endif
+            }
+        }
+
+#ifndef GGML_USE_OPENMP
+        // wait for all tasks to finish
+        for (auto & task : ctx->tasks) {
+            task.get();
+        }
+        ctx->tasks.clear();
+#endif
+    }
+
+#if defined(OPENBLAS_VERSION)
+    openblas_set_num_threads(ctx->n_threads);
+#endif
+
+#if defined(GGML_BLAS_USE_BLIS)
+    bli_thread_set_num_threads(ctx->n_threads);
+#endif
+
+#if defined(GGML_BLAS_USE_NVPL)
+    nvpl_blas_set_num_threads(ctx->n_threads);
+#endif
+
+    for (int64_t i13 = 0; i13 < ne13; i13++) {
+        for (int64_t i12 = 0; i12 < ne12; i12++) {
+            const int64_t i03 = i13/r3;
+            const int64_t i02 = i12/r2;
+
+            const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
+            const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
+                  float * d = (float *) ((char *)  dst->data + i12*nb2  + i13*nb3);
+
+            if (type != GGML_TYPE_F32) {
+                x = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
+            }
+
+            cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                        ne1, ne01, ne10,
+                        1.0f,   y, ne10,
+                                x, ne00,
+                        0.0f,   d, ne01);
+        }
+    }
+}
+
+static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(ne0  == ne00);
+    GGML_ASSERT(ne1  == ne10);
+    GGML_ASSERT(ne2  == ne02);
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne3  == ne13);
+    GGML_ASSERT(ne03 == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
+    // src0: (k,n)
+    // src1: (k,m)
+    // dst:  (m,n)
+    //
+    // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
+    // Also expressed as (major,minor)
+    // a: (m,k): so src1 transposed
+    // b: (k,n): so src0
+    // c: (m,n)
+    //
+    // However, if ggml_is_transposed(src1) is true, then
+    // src1->data already contains a transposed version, so sgemm mustn't
+    // transpose it further.
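+    //
+    // Worked example (illustrative): src0 with ne = {128, 64} gives n = 128
+    // and k = 64, and a contiguous src1 with ne = {32, 64} gives m = 32; the
+    // call below becomes sgemm(RowMajor, Trans, NoTrans, 32, 128, 64, ...)
+    // with lda = m = 32, writing the (m,n) = (32,128) row-major result to dst.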
+
+    int n = src0->ne[0];
+    int k = src0->ne[1];
+    int m = src1->ne[0];
+
+    CBLAS_TRANSPOSE transposeA;
+    int lda;
+
+    if (!ggml_is_transposed(src1)) {
+        transposeA = CblasTrans;
+        lda = m;
+    } else {
+        transposeA = CblasNoTrans;
+        lda = k;
+    }
+
+    float * a = (float *) ((char *) src1->data);
+    float * b = (float *) ((char *) src0->data);
+    float * c = (float *) ((char *) dst->data);
+
+    cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
+
+    GGML_UNUSED(ctx);
+}
+
+// backend interface
+
+GGML_CALL static const char * ggml_backend_blas_name(ggml_backend_t backend) {
+    return "BLAS";
+
+    GGML_UNUSED(backend);
+}
+
+GGML_CALL static void ggml_backend_blas_free(ggml_backend_t backend) {
+    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
+    delete ctx;
+    delete backend;
+}
+
+GGML_CALL static ggml_backend_buffer_type_t ggml_backend_blas_get_default_buffer_type(ggml_backend_t backend) {
+    return ggml_backend_cpu_buffer_type();
+
+    GGML_UNUSED(backend);
+}
+
+GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
+
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * node = cgraph->nodes[i];
+
+        switch (node->op) {
+            case GGML_OP_MUL_MAT:
+                ggml_backend_blas_mul_mat(ctx, node);
+                break;
+
+            case GGML_OP_OUT_PROD:
+                ggml_backend_blas_out_prod(ctx, node);
+                break;
+
+            case GGML_OP_NONE:
+            case GGML_OP_RESHAPE:
+            case GGML_OP_VIEW:
+            case GGML_OP_PERMUTE:
+            case GGML_OP_TRANSPOSE:
+                break;
+
+            default:
+                GGML_ABORT("%s: unsupported op %s\n", __func__, ggml_op_desc(node));
+        }
+    }
+
+    return GGML_STATUS_SUCCESS;
+
+    GGML_UNUSED(backend);
+}
+
+GGML_CALL static bool ggml_backend_blas_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+
+    return (op->op == GGML_OP_MUL_MAT  && ggml_backend_blas_use_blas(op)) ||
+           (op->op == GGML_OP_OUT_PROD && op->src[0]->type == GGML_TYPE_F32 &&
+                                          op->src[1]->type == GGML_TYPE_F32 &&
+                                          ggml_is_matrix(src0) &&
+                                          ggml_is_matrix(src1) &&
+                                          ggml_is_contiguous(src0) &&
+                                          (ggml_is_contiguous(src1) || ggml_is_transposed(src1)));
+
+    GGML_UNUSED(backend);
+}
+
+GGML_CALL static bool ggml_backend_blas_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+    return ggml_backend_buft_is_host(buft);
+
+    GGML_UNUSED(backend);
+}
+
+static struct ggml_backend_i blas_backend_i = {
+    /* .get_name                = */ ggml_backend_blas_name,
+    /* .free                    = */ ggml_backend_blas_free,
+    /* .get_default_buffer_type = */ ggml_backend_blas_get_default_buffer_type,
+    /* .set_tensor_async        = */ NULL,
+    /* .get_tensor_async        = */ NULL,
+    /* .cpy_tensor_async        = */ NULL,
+    /* .synchronize             = */ NULL,
+    /* .graph_plan_create       = */ NULL,
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_update       = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ ggml_backend_blas_graph_compute,
+    /* .supports_op             = */ ggml_backend_blas_supports_op,
+    /* .supports_buft           = */ ggml_backend_blas_supports_buft,
+    /* .offload_op              = */ NULL,
+    /* .event_new               = */ NULL,
+    /* .event_free              = */ NULL,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+    /* .event_synchronize       = */ NULL,
+};
+
+static ggml_guid_t ggml_backend_blas_guid(void) {
+    static ggml_guid guid = { 0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d };
+    return &guid;
+}
+
+ggml_backend_t ggml_backend_blas_init(void) {
+    ggml_backend_blas_context * ctx = new ggml_backend_blas_context;
+
+    ggml_backend_t backend = new ggml_backend {
+        /* .guid      = */ ggml_backend_blas_guid(),
+        /* .interface = */ blas_backend_i,
+        /* .context   = */ ctx,
+    };
+
+#if !defined(NDEBUG) && defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
+    if (openblas_get_parallel() != OPENBLAS_OPENMP) {
+        fprintf(stderr, "%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__);
+    }
+#endif
+
+#if !defined(NDEBUG) && defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP)
+    fprintf(stderr, "%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__);
+#endif
+
+    return backend;
+}
+
+GGML_CALL bool ggml_backend_is_blas(ggml_backend_t backend) {
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_blas_guid());
+}
+
+void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads) {
+    GGML_ASSERT(ggml_backend_is_blas(backend_blas));
+
+    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context;
+    ctx->n_threads = n_threads;
+}
+
+#endif

+ 49 - 0
llama/ggml-blas.h

@@ -0,0 +1,49 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+// backend API
+GGML_API GGML_CALL ggml_backend_t ggml_backend_blas_init(void);
+
+GGML_API GGML_CALL bool ggml_backend_is_blas(ggml_backend_t backend);
+
+// number of threads used for conversion to float
+// for openblas and blis, this will also set the number of threads used for blas operations
+GGML_API GGML_CALL void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads);
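+
+// minimal usage sketch (illustrative, assuming the usual ggml-backend entry
+// points ggml_backend_graph_compute and ggml_backend_free):
+//
+//   ggml_backend_t backend = ggml_backend_blas_init();
+//   ggml_backend_blas_set_n_threads(backend, 8);
+//   assert(ggml_backend_is_blas(backend));
+//   // ... ggml_backend_graph_compute(backend, graph); ...
+//   ggml_backend_free(backend);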
+
+
+#ifdef  __cplusplus
+}
+#endif

+ 1859 - 0
llama/ggml-common.h

@@ -0,0 +1,1859 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef GGML_COMMON_DECL
+
+#if defined(GGML_COMMON_DECL_C)
+#include <stdint.h>
+
+typedef uint16_t ggml_half;
+typedef uint32_t ggml_half2;
+
+#define GGML_COMMON_AGGR
+
+#define GGML_COMMON_DECL
+#elif defined(GGML_COMMON_DECL_METAL)
+#include <metal_stdlib>
+
+typedef half  ggml_half;
+typedef half2 ggml_half2;
+
+#define GGML_COMMON_AGGR
+
+#define GGML_COMMON_DECL
+#elif defined(GGML_COMMON_DECL_CUDA)
+#if defined(GGML_COMMON_DECL_MUSA)
+#include <musa_fp16.h>
+#else
+#include <cuda_fp16.h>
+#endif
+#include <cstdint>
+
+typedef half  ggml_half;
+typedef half2 ggml_half2;
+
+#define GGML_COMMON_AGGR data
+
+#define GGML_COMMON_DECL
+#elif defined(GGML_COMMON_DECL_HIP)
+#include <hip/hip_fp16.h>
+#include <cstdint>
+
+typedef half  ggml_half;
+typedef half2 ggml_half2;
+
+#define GGML_COMMON_AGGR data
+
+#define GGML_COMMON_DECL
+#elif defined(GGML_COMMON_DECL_SYCL)
+#include <sycl/half_type.hpp>
+#include <cstdint>
+
+typedef sycl::half  ggml_half;
+typedef sycl::half2 ggml_half2;
+
+#define GGML_COMMON_AGGR data
+
+#define GGML_COMMON_DECL
+#endif
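+
+// note: GGML_COMMON_AGGR names the struct member inside the block unions
+// below; for C and Metal it expands to nothing, making the struct an
+// anonymous union member (d/m are accessed directly), while for
+// CUDA/HIP/SYCL it expands to "data" (the same fields read as block.data.d)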
+
+#if defined(GGML_COMMON_DECL)
+
+#ifndef __cplusplus
+#ifndef static_assert
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
+#define static_assert(cond, msg) _Static_assert(cond, msg)
+#else
+#define static_assert(cond, msg) struct global_scope_noop_trick
+#endif
+#endif
+#endif // __cplusplus
+
+// QK = number of values after dequantization
+// QK_K = super-block size
+
+#define QK_K 256
+#define K_SCALE_SIZE 12
+
+#if defined(GGML_COMMON_DECL_CUDA) || defined(GGML_COMMON_DECL_HIP) || defined(GGML_COMMON_DECL_SYCL)
+// QR = QK / number of values before dequantization
+// QI = number of 32 bit integers before dequantization
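+//
+// e.g. for Q4_0 below: QK4_0 = 32 and QR4_0 = 2, so one block unpacks from
+// QI4_0 = QK4_0/(4*QR4_0) = 32/8 = 4 32-bit integers into 32 values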
+
+#define QI4_0 (QK4_0 / (4 * QR4_0))
+#define QR4_0 2
+
+#define QI4_1 (QK4_1 / (4 * QR4_1))
+#define QR4_1 2
+
+#define QI5_0 (QK5_0 / (4 * QR5_0))
+#define QR5_0 2
+
+#define QI5_1 (QK5_1 / (4 * QR5_1))
+#define QR5_1 2
+
+#define QI8_0 (QK8_0 / (4 * QR8_0))
+#define QR8_0 1
+
+#define QI8_1 (QK8_1 / (4 * QR8_1))
+#define QR8_1 1
+
+#define QI2_K (QK_K / (4*QR2_K))
+#define QR2_K 4
+
+#define QI3_K (QK_K / (4*QR3_K))
+#define QR3_K 4
+
+#define QI4_K (QK_K / (4*QR4_K))
+#define QR4_K 2
+
+#define QI5_K (QK_K / (4*QR5_K))
+#define QR5_K 2
+
+#define QI6_K (QK_K / (4*QR6_K))
+#define QR6_K 2
+
+#define QI2_XXS (QK_K / (4*QR2_XXS))
+#define QR2_XXS 4
+
+#define QI2_XS (QK_K / (4*QR2_XS))
+#define QR2_XS 4
+
+#define QI2_S (QK_K / (4*QR2_S))
+#define QR2_S 4
+
+#define QI3_XXS (QK_K / (4*QR3_XXS))
+#define QR3_XXS 4
+
+#define QI3_XS (QK_K / (4*QR3_XS))
+#define QR3_XS 4
+
+#define QI1_S (QK_K / (4*QR1_S))
+#define QR1_S 8
+
+#define QI1_M (QK_K / (4*QR1_M))
+#define QR1_M 8
+
+#define QI4_NL (QK4_NL / (4*QR4_NL))
+#define QR4_NL 2
+
+#define QI4_XS (QK_K / (4*QR4_XS))
+#define QR4_XS 2
+
+#define QI3_S (QK_K / (4*QR3_S))
+#define QR3_S 4
+
+#endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP || GGML_COMMON_DECL_SYCL
+
+#define QK4_0 32
+typedef struct {
+    ggml_half d;           // delta
+    uint8_t qs[QK4_0 / 2]; // nibbles / quants
+} block_q4_0;
+static_assert(sizeof(block_q4_0) == sizeof(ggml_half) + QK4_0 / 2, "wrong q4_0 block size/padding");
+
+#define QK4_1 32
+typedef struct {
+    union {
+        struct {
+            ggml_half d; // delta
+            ggml_half m; // min
+        } GGML_COMMON_AGGR;
+        ggml_half2 dm;
+    };
+    uint8_t qs[QK4_1 / 2]; // nibbles / quants
+} block_q4_1;
+static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK5_0 32
+typedef struct {
+    ggml_half d;           // delta
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+static_assert(sizeof(block_q5_0) == sizeof(ggml_half) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
+
+#define QK5_1 32
+typedef struct {
+    union {
+        struct {
+            ggml_half d; // delta
+            ggml_half m; // min
+        } GGML_COMMON_AGGR;
+        ggml_half2 dm;
+    };
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_half) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
+
+#define QK8_0 32
+typedef struct {
+    ggml_half d;       // delta
+    int8_t  qs[QK8_0]; // quants
+} block_q8_0;
+static_assert(sizeof(block_q8_0) == sizeof(ggml_half) + QK8_0, "wrong q8_0 block size/padding");
+
+#define QK8_1 32
+typedef struct {
+    union {
+        struct {
+            ggml_half d; // delta
+            ggml_half s; // d * sum(qs[i])
+        } GGML_COMMON_AGGR;
+        ggml_half2 ds;
+    };
+    int8_t qs[QK8_1]; // quants
+} block_q8_1;
+static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_half) + QK8_1, "wrong q8_1 block size/padding");
+
+typedef struct {
+    ggml_half d[4];        // deltas for 4 q4_0 blocks
+    uint8_t qs[QK4_0 * 2]; // nibbles / quants for 4 q4_0 blocks
+} block_q4_0x4;
+static_assert(sizeof(block_q4_0x4) == 4 * sizeof(ggml_half) + QK4_0 * 2, "wrong q4_0x4 block size/padding");
+
+typedef struct {
+    ggml_half d[8];        // deltas for 8 q4_0 blocks
+    uint8_t qs[QK4_0 * 4]; // nibbles / quants for 8 q4_0 blocks
+} block_q4_0x8;
+static_assert(sizeof(block_q4_0x8) == 8 * sizeof(ggml_half) + QK4_0 * 4, "wrong q4_0x8 block size/padding");
+
+typedef struct {
+    ggml_half d[4];        // deltas for 4 q8_0 blocks
+    int8_t qs[QK8_0 * 4];  // quants for 4 q8_0 blocks
+} block_q8_0x4;
+static_assert(sizeof(block_q8_0x4) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong q8_0x4 block size/padding");
+
+typedef struct {
+    ggml_half d[8];        // deltas for 8 q8_0 blocks
+    int8_t qs[QK8_0 * 8];  // quants for 8 q8_0 blocks
+} block_q8_0x8;
+static_assert(sizeof(block_q8_0x8) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong q8_0x8 block size/padding");
+
+//
+// Super-block quantization structures
+//
+
+// 2-bit quantization
+// weight is represented as x = a * q + b
+// 16 blocks of 16 elements each
+// Effectively 2.625 bits per weight
+typedef struct {
+    uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
+    uint8_t qs[QK_K/4];      // quants
+    union {
+        struct {
+            ggml_half d;    // super-block scale for quantized scales
+            ggml_half dmin; // super-block scale for quantized mins
+        } GGML_COMMON_AGGR;
+        ggml_half2 dm;
+    };
+} block_q2_K;
+static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_half) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
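+// arithmetic check for the 2.625 bpw figure: 16 (scales) + 64 (qs) + 4 (d, dmin)
+// = 84 bytes per super-block of 256 weights, and 84*8/256 = 2.625 bits/weight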
+
+// 3-bit quantization
+// weight is represented as x = a * q
+// 16 blocks of 16 elements each
+// Effectively 3.4375 bits per weight
+typedef struct {
+    uint8_t hmask[QK_K/8]; // quants - high bit
+    uint8_t qs[QK_K/4];    // quants - low 2 bits
+    uint8_t scales[12];    // scales, quantized with 6 bits
+    ggml_half d;           // super-block scale
+} block_q3_K;
+static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
+
+// 4-bit quantization
+// 8 blocks of 32 elements each
+// weight is represented as x = a * q + b
+// Effectively 4.5 bits per weight
+typedef struct {
+    union {
+        struct {
+            ggml_half d;    // super-block scale for quantized scales
+            ggml_half dmin; // super-block scale for quantized mins
+        } GGML_COMMON_AGGR;
+        ggml_half2 dm;
+    };
+    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
+    uint8_t qs[QK_K/2];           // 4-bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
+
+// 5-bit quantization
+// 8 blocks of 32 elements each
+// weight is represented as x = a * q + b
+// Effectively 5.5 bits per weight
+typedef struct {
+    union {
+        struct {
+            ggml_half d;    // super-block scale for quantized scales
+            ggml_half dmin; // super-block scale for quantized mins
+        } GGML_COMMON_AGGR;
+        ggml_half2 dm;
+    };
+    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
+    uint8_t qh[QK_K/8];           // quants, high bit
+    uint8_t qs[QK_K/2];           // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+
+// 6-bit quantization
+// weight is represented as x = a * q
+// 16 blocks of 16 elements each
+// Effectively 6.5625 bits per weight
+typedef struct {
+    uint8_t ql[QK_K/2];      // quants, lower 4 bits
+    uint8_t qh[QK_K/4];      // quants, upper 2 bits
+    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
+    ggml_half d;             // super-block scale
+} block_q6_K;
+static_assert(sizeof(block_q6_K) == sizeof(ggml_half) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding");
+
+// This is only used for intermediate quantization and dot products
+typedef struct {
+    float   d;              // delta
+    int8_t  qs[QK_K];       // quants
+    int16_t bsums[QK_K/16]; // sum of quants in groups of 16
+} block_q8_K;
+static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
+
+// (Almost) "true" 2-bit quantization.
+// Due to the need to use blocks as per ggml design, it ends up using
+// 2.0625 bpw because of the 16-bit scale for each block of 256.
+typedef struct {
+    ggml_half d;
+    uint16_t qs[QK_K/8];
+} block_iq2_xxs;
+static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_half) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
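+// arithmetic check: 2 (d) + 64 (qs) = 66 bytes per 256 weights, and
+// 66*8/256 = 2.0625 bits per weight, matching the figure above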
+
+// 2.3125 bpw quants
+typedef struct {
+    ggml_half d;
+    uint16_t qs[QK_K/8];
+    uint8_t  scales[QK_K/32];
+} block_iq2_xs;
+static_assert(sizeof(block_iq2_xs) == sizeof(ggml_half) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
+
+// 2.5625 bpw quants
+typedef struct {
+    ggml_half d;
+    uint8_t qs[QK_K/4];
+    uint8_t qh[QK_K/32];
+    uint8_t scales[QK_K/32];
+} block_iq2_s;
+static_assert(sizeof(block_iq2_s) == sizeof(ggml_half) + QK_K/4 + QK_K/16, "wrong iq2_s block size/padding");
+
+// (Almost) "true" 3-bit quantization.
+// Due to the need to use blocks as per ggml design, it ends up using
+// 3.0625 bpw because of the 16-bit scale for each block of 256.
+typedef struct {
+    ggml_half d;
+    uint8_t qs[3*QK_K/8];
+} block_iq3_xxs;
+static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_half) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
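+// arithmetic check: 2 (d) + 96 (qs) = 98 bytes per 256 weights, and
+// 98*8/256 = 3.0625 bits per weight, matching the figure above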
+
+// 3.4375 bpw
+#define IQ3S_N_SCALE (QK_K/64)
+typedef struct {
+    ggml_half d;
+    uint8_t qs[QK_K/4];
+    uint8_t qh[QK_K/32];
+    uint8_t signs[QK_K/8];
+    uint8_t scales[IQ3S_N_SCALE];
+} block_iq3_s;
+static_assert(sizeof(block_iq3_s) == sizeof(ggml_half) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");
+
+typedef struct {
+    ggml_half d;
+    uint8_t  qs[QK_K/8];
+    uint16_t qh[QK_K/32];
+} block_iq1_s;
+static_assert(sizeof(block_iq1_s) == sizeof(ggml_half) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
+
+// 1.75 bpw
+typedef struct {
+    uint8_t  qs[QK_K/8];      // grid index, low 8 bits
+    uint8_t  qh[QK_K/16];     // grid index, high 3 bits + grid shift bit (for two groups of 8)
+    uint8_t  scales[QK_K/32]; // 3-bit block scales (4-bit if QK_K == 64)
+} block_iq1_m;
+static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding");
+
+// Used by IQ1_M quants
+typedef union {
+    ggml_half f16;
+    uint16_t  u16;
+} iq1m_scale_t;
+
+// Non-linear quants
+#define QK4_NL 32
+typedef struct {
+    ggml_half d;
+    uint8_t qs[QK4_NL/2];
+} block_iq4_nl;
+static_assert(sizeof(block_iq4_nl) == sizeof(ggml_half) + QK4_NL/2, "wrong iq4_nl block size/padding");
+
+typedef struct {
+    ggml_half d;
+    uint16_t scales_h;
+    uint8_t  scales_l[QK_K/64];
+    uint8_t  qs[QK_K/2];
+} block_iq4_xs;
+static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
+
+#endif // defined(GGML_COMMON_DECL)
+#endif // GGML_COMMON_DECL (include guard)
+
+////////////////////////////////////////////////////////////////////////////////
+
+#ifndef GGML_COMMON_IMPL
+
+#if defined(GGML_COMMON_IMPL_C)
+#include <stdint.h>
+
+#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
+#define GGML_TABLE_END() };
+
+#define GGML_COMMON_IMPL
+#elif defined(GGML_COMMON_IMPL_METAL)
+#include <metal_stdlib>
+
+#define GGML_TABLE_BEGIN(type, name, size) static const constant type name[size] = {
+#define GGML_TABLE_END() };
+
+#define GGML_COMMON_IMPL
+#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) || defined(GGML_COMMON_IMPL_MUSA)
+#include <cstdint>
+
+#define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = {
+#define GGML_TABLE_END() };
+
+#define GGML_COMMON_IMPL
+#elif defined(GGML_COMMON_IMPL_SYCL)
+
+#include <cstdint>
+
+#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
+#define GGML_TABLE_END() };
+
+#define GGML_COMMON_IMPL
+#endif
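+
+// e.g. in the C case, GGML_TABLE_BEGIN(uint8_t, kmask_iq2xs, 8) ... GGML_TABLE_END()
+// expands to: static const uint8_t kmask_iq2xs[8] = { ... };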
+
+#if defined(GGML_COMMON_IMPL)
+
+GGML_TABLE_BEGIN(uint8_t, kmask_iq2xs, 8)
+    1, 2, 4, 8, 16, 32, 64, 128
+GGML_TABLE_END()
+
+GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128)
+      0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
+    144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
+    160,  33,  34, 163,  36, 165, 166,  39,  40, 169, 170,  43, 172,  45,  46, 175,
+     48, 177, 178,  51, 180,  53,  54, 183, 184,  57,  58, 187,  60, 189, 190,  63,
+    192,  65,  66, 195,  68, 197, 198,  71,  72, 201, 202,  75, 204,  77,  78, 207,
+     80, 209, 210,  83, 212,  85,  86, 215, 216,  89,  90, 219,  92, 221, 222,  95,
+     96, 225, 226,  99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
+    240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
+GGML_TABLE_END()
+
+//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
+    0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
+    0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
+    0xff000000ff000000, 0x00000000ff0000ff, 0x00000000ff00ff00, 0xff000000ff00ffff,
+    0x00000000ffff0000, 0xff000000ffff00ff, 0xff000000ffffff00, 0x00000000ffffffff,
+    0xff0000ff00000000, 0x000000ff000000ff, 0x000000ff0000ff00, 0xff0000ff0000ffff,
+    0x000000ff00ff0000, 0xff0000ff00ff00ff, 0xff0000ff00ffff00, 0x000000ff00ffffff,
+    0x000000ffff000000, 0xff0000ffff0000ff, 0xff0000ffff00ff00, 0x000000ffff00ffff,
+    0xff0000ffffff0000, 0x000000ffffff00ff, 0x000000ffffffff00, 0xff0000ffffffffff,
+    0xff00ff0000000000, 0x0000ff00000000ff, 0x0000ff000000ff00, 0xff00ff000000ffff,
+    0x0000ff0000ff0000, 0xff00ff0000ff00ff, 0xff00ff0000ffff00, 0x0000ff0000ffffff,
+    0x0000ff00ff000000, 0xff00ff00ff0000ff, 0xff00ff00ff00ff00, 0x0000ff00ff00ffff,
+    0xff00ff00ffff0000, 0x0000ff00ffff00ff, 0x0000ff00ffffff00, 0xff00ff00ffffffff,
+    0x0000ffff00000000, 0xff00ffff000000ff, 0xff00ffff0000ff00, 0x0000ffff0000ffff,
+    0xff00ffff00ff0000, 0x0000ffff00ff00ff, 0x0000ffff00ffff00, 0xff00ffff00ffffff,
+    0xff00ffffff000000, 0x0000ffffff0000ff, 0x0000ffffff00ff00, 0xff00ffffff00ffff,
+    0x0000ffffffff0000, 0xff00ffffffff00ff, 0xff00ffffffffff00, 0x0000ffffffffffff,
+    0xffff000000000000, 0x00ff0000000000ff, 0x00ff00000000ff00, 0xffff00000000ffff,
+    0x00ff000000ff0000, 0xffff000000ff00ff, 0xffff000000ffff00, 0x00ff000000ffffff,
+    0x00ff0000ff000000, 0xffff0000ff0000ff, 0xffff0000ff00ff00, 0x00ff0000ff00ffff,
+    0xffff0000ffff0000, 0x00ff0000ffff00ff, 0x00ff0000ffffff00, 0xffff0000ffffffff,
+    0x00ff00ff00000000, 0xffff00ff000000ff, 0xffff00ff0000ff00, 0x00ff00ff0000ffff,
+    0xffff00ff00ff0000, 0x00ff00ff00ff00ff, 0x00ff00ff00ffff00, 0xffff00ff00ffffff,
+    0xffff00ffff000000, 0x00ff00ffff0000ff, 0x00ff00ffff00ff00, 0xffff00ffff00ffff,
+    0x00ff00ffffff0000, 0xffff00ffffff00ff, 0xffff00ffffffff00, 0x00ff00ffffffffff,
+    0x00ffff0000000000, 0xffffff00000000ff, 0xffffff000000ff00, 0x00ffff000000ffff,
+    0xffffff0000ff0000, 0x00ffff0000ff00ff, 0x00ffff0000ffff00, 0xffffff0000ffffff,
+    0xffffff00ff000000, 0x00ffff00ff0000ff, 0x00ffff00ff00ff00, 0xffffff00ff00ffff,
+    0x00ffff00ffff0000, 0xffffff00ffff00ff, 0xffffff00ffffff00, 0x00ffff00ffffffff,
+    0xffffffff00000000, 0x00ffffff000000ff, 0x00ffffff0000ff00, 0xffffffff0000ffff,
+    0x00ffffff00ff0000, 0xffffffff00ff00ff, 0xffffffff00ffff00, 0x00ffffff00ffffff,
+    0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
+    0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
+GGML_TABLE_END()
+//#endif
+
+
+GGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256)
+    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
+    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
+    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
+    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
+    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
+    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
+    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
+    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
+    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
+    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
+    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
+    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
+    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
+    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
+    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
+    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
+    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
+    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
+    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
+    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
+    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
+    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
+    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
+    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
+    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
+    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
+    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
+    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
+    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
+    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
+    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
+    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
+    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
+    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
+    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
+    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
+    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
+    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
+    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
+    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
+    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
+    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
+    0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
+    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
+    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
+    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
+    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
+    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
+    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
+    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
+    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
+    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
+    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
+    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
+    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
+    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
+    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
+    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
+    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
+    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
+    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
+    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
+    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
+GGML_TABLE_END()
+
+GGML_TABLE_BEGIN(uint64_t, iq2xs_grid, 512)
+    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
+    0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
+    0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
+    0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
+    0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
+    0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
+    0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
+    0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
+    0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
+    0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
+    0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
+    0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
+    0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
+    0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
+    0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
+    0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
+    0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
+    0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
+    0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
+    0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
+    0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
+    0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
+    0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
+    0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
+    0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
+    0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
+    0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
+    0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
+    0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
+    0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
+    0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
+    0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
+    0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
+    0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
+    0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
+    0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
+    0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
+    0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
+    0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
+    0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
+    0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
+    0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
+    0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
+    0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
+    0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
+    0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
+    0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
+    0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
+    0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
+    0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
+    0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
+    0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
+    0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
+    0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
+    0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
+    0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
+    0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
+    0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
+    0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
+    0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
+    0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
+    0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
+    0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
+    0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
+    0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
+    0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
+    0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
+    0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
+    0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
+    0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
+    0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
+    0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
+    0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
+    0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
+    0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
+    0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
+    0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
+    0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
+    0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
+    0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
+    0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
+    0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
+    0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
+    0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
+    0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
+    0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
+    0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
+    0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
+    0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
+    0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
+    0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
+    0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
+    0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
+    0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
+    0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
+    0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
+    0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
+    0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
+    0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
+    0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
+    0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
+    0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
+    0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
+    0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
+    0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
+    0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
+    0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
+    0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
+    0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
+    0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
+    0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
+    0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
+    0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
+    0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
+    0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
+    0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
+    0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
+    0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
+    0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
+    0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
+    0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
+    0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
+    0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
+    0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
+    0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
+    0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
+    0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
+GGML_TABLE_END()
+
+GGML_TABLE_BEGIN(uint64_t, iq2s_grid, 1024)
+    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
+    0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
+    0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
+    0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
+    0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b,
+    0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919,
+    0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808,
+    0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908,
+    0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b,
+    0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908,
+    0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08,
+    0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19,
+    0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819,
+    0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919,
+    0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b,
+    0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908,
+    0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908,
+    0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919,
+    0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908,
+    0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b,
+    0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919,
+    0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b,
+    0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919,
+    0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908,
+    0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b,
+    0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b,
+    0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08,
+    0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b,
+    0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819,
+    0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808,
+    0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908,
+    0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b,
+    0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908,
+    0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08,
+    0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819,
+    0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808,
+    0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08,
+    0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819,
+    0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b,
+    0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908,
+    0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919,
+    0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b,
+    0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919,
+    0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808,
+    0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819,
+    0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919,
+    0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919,
+    0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808,
+    0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819,
+    0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b,
+    0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908,
+    0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b,
+    0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908,
+    0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919,
+    0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b,
+    0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919,
+    0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b,
+    0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819,
+    0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919,
+    0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908,
+    0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b,
+    0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908,
+    0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b,
+    0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908,
+    0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08,
+    0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908,
+    0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819,
+    0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819,
+    0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808,
+    0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08,
+    0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19,
+    0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819,
+    0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808,
+    0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819,
+    0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919,
+    0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808,
+    0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19,
+    0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08,
+    0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b,
+    0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908,
+    0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808,
+    0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819,
+    0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908,
+    0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819,
+    0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808,
+    0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808,
+    0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819,
+    0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908,
+    0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08,
+    0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819,
+    0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b,
+    0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b,
+    0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08,
+    0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19,
+    0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819,
+    0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919,
+    0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908,
+    0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808,
+    0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808,
+    0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908,
+    0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808,
+    0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08,
+    0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08,
+    0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908,
+    0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919,
+    0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808,
+    0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819,
+    0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908,
+    0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08,
+    0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819,
+    0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808,
+    0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808,
+    0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819,
+    0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808,
+    0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908,
+    0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b,
+    0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819,
+    0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808,
+    0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b,
+    0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808,
+    0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b,
+    0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19,
+    0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819,
+    0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08,
+    0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b,
+    0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908,
+    0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b,
+    0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b,
+    0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919,
+    0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808,
+    0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819,
+    0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908,
+    0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08,
+    0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08,
+    0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819,
+    0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919,
+    0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908,
+    0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b,
+    0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908,
+    0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b,
+    0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908,
+    0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08,
+    0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819,
+    0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808,
+    0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819,
+    0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919,
+    0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808,
+    0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808,
+    0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08,
+    0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819,
+    0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919,
+    0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808,
+    0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819,
+    0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919,
+    0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808,
+    0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b,
+    0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908,
+    0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808,
+    0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908,
+    0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b,
+    0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908,
+    0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b,
+    0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908,
+    0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b,
+    0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908,
+    0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08,
+    0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908,
+    0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b,
+    0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908,
+    0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08,
+    0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819,
+    0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919,
+    0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808,
+    0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19,
+    0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b,
+    0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919,
+    0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808,
+    0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819,
+    0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908,
+    0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919,
+    0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808,
+    0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808,
+    0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b,
+    0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919,
+    0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808,
+    0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b,
+    0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808,
+    0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919,
+    0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b,
+    0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08,
+    0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919,
+    0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808,
+    0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b,
+    0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908,
+    0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808,
+    0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808,
+    0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808,
+    0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908,
+    0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808,
+    0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808,
+    0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b,
+    0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908,
+    0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808,
+    0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808,
+    0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819,
+    0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919,
+    0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b,
+    0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808,
+    0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819,
+    0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b,
+    0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908,
+    0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08,
+    0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908,
+    0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919,
+    0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819,
+    0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908,
+    0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b,
+    0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808,
+    0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819,
+    0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908,
+    0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919,
+    0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808,
+    0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808,
+    0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808,
+    0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919,
+    0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908,
+    0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908,
+    0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08,
+    0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819,
+    0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b,
+    0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808,
+    0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819,
+    0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908,
+    0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819,
+    0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808,
+    0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808,
+    0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b,
+    0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908,
+    0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808,
+    0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908,
+    0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819,
+    0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819,
+    0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808,
+    0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b,
+    0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b,
+    0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819,
+    0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b,
+    0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b,
+    0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b,
+    0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819,
+    0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19,
+    0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819,
+    0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908,
+    0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808,
+    0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b,
+GGML_TABLE_END()
+
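Note on the GGML_TABLE_BEGIN/GGML_TABLE_END markers: these are macros defined near the top of ggml-common.h so that the same table data can be compiled either as plain C arrays or as device-side arrays for the GPU backends. A rough sketch of the C branch, assuming GGML_COMMON_IMPL_C is the active implementation flag (the authoritative definitions live earlier in this header):

#if defined(GGML_COMMON_IMPL_C)
#include <stdint.h>
// Expands "GGML_TABLE_BEGIN(uint64_t, grid, N) ... GGML_TABLE_END()"
// into "static const uint64_t grid[N] = { ... };"
#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
#define GGML_TABLE_END() };
#endif

The CUDA/HIP branch emits the same arrays with a __device__ qualifier, so the lookup tables in this hunk are usable unchanged from both CPU and GPU dequantization kernels.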
+GGML_TABLE_BEGIN(uint32_t, iq3xxs_grid, 256)
+    0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
+    0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
+    0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
+    0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
+    0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
+    0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
+    0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
+    0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
+    0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
+    0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
+    0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
+    0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
+    0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
+    0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
+    0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
+    0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
+    0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
+    0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
+    0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
+    0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
+    0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
+    0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
+    0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
+    0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
+    0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
+    0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
+    0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
+    0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
+    0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
+    0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
+    0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
+    0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
+GGML_TABLE_END()
+
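Each uint32_t entry of iq3xxs_grid packs four quantized magnitudes, one per byte, drawn from the eight levels visible above (0x04, 0x0c, 0x14, 0x1c, 0x24, 0x2c, 0x34, 0x3e). Dequantization indexes the grid with an 8-bit code from the quantized block and reads the selected entry back byte by byte; sign bits and the block scale are applied separately. A minimal sketch of that lookup, assuming the C expansion of the table (iq3xxs_expand is a hypothetical helper, not part of the vendored API):

#include <stdint.h>

// Expand one 8-bit IQ3_XXS code into four unsigned magnitudes via the
// grid above. The caller applies sign bits and the block scale, as the
// real dequantizers do. Reading the uint32_t through a byte pointer
// matches the little-endian layout ggml assumes elsewhere.
static inline void iq3xxs_expand(uint8_t code, uint8_t mags[4]) {
    const uint8_t * g = (const uint8_t *)(iq3xxs_grid + code);
    for (int j = 0; j < 4; ++j) {
        mags[j] = g[j];
    }
}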
+GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
+    0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
+    0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
+    0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
+    0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
+    0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
+    0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
+    0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
+    0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
+    0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
+    0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
+    0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
+    0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
+    0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
+    0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
+    0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
+    0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
+    0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
+    0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
+    0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
+    0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
+    0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
+    0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
+    0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
+    0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
+    0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
+    0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
+    0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
+    0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
+    0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
+    0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
+    0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
+    0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
+    0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
+    0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
+    0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
+    0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
+    0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
+    0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
+    0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
+    0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
+    0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
+    0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
+    0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
+    0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
+    0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
+    0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
+    0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
+    0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
+    0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
+    0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
+    0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
+    0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
+    0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
+    0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
+    0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
+    0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
+    0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
+    0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
+    0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
+    0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
+    0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
+    0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
+    0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
+    0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
+GGML_TABLE_END()
+
+#define NGRID_IQ1S 2048
+#define IQ1S_DELTA 0.125f
+#define IQ1M_DELTA 0.125f
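NGRID_IQ1S sizes the iq1s_grid table that follows: 2048 uint64_t entries, each packing eight ternary weights as signed bytes (0xff = -1, 0x00 = 0, 0x01 = +1). IQ1S_DELTA and IQ1M_DELTA are the constant shifts added to those ternary values before scaling when dequantizing IQ1_S and IQ1_M blocks. A minimal sketch of how one grid entry is consumed (iq1s_expand is a hypothetical helper; the sign of the delta is selected per group by a bit in the quantized block, omitted here):

#include <stdint.h>

// Dequantize the eight weights encoded by one iq1s_grid entry.
// 'scale' is the block scale; 'delta' is +IQ1S_DELTA or -IQ1S_DELTA.
static inline void iq1s_expand(uint64_t entry, float scale, float delta,
                               float out[8]) {
    const int8_t * q = (const int8_t *)&entry;  // each byte is -1, 0 or +1
    for (int j = 0; j < 8; ++j) {
        out[j] = scale * (q[j] + delta);
    }
}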
+#if defined(GGML_COMMON_IMPL_C)
+GGML_TABLE_BEGIN(uint64_t, iq1s_grid, NGRID_IQ1S)
+    0xffffffffffffffff, 0xffffffffffffff01, 0xffffffffffff0000, 0xffffffffffff01ff,
+    0xffffffffffff0101, 0xffffffffff00ff00, 0xffffffffff000000, 0xffffffffff01ffff,
+    0xffffffffff01ff01, 0xffffffffff0101ff, 0xffffffffff010101, 0xffffffff00ff0000,
+    0xffffffff0000ff00, 0xffffffff000000ff, 0xffffffff00000001, 0xffffffff00010000,
+    0xffffffff01ffffff, 0xffffffff01ffff01, 0xffffffff01ff01ff, 0xffffffff01ff0101,
+    0xffffffff01000000, 0xffffffff0101ffff, 0xffffffff0101ff01, 0xffffffff010101ff,
+    0xffffffff01010101, 0xffffff00ffff00ff, 0xffffff00ffff0000, 0xffffff00ff00ff00,
+    0xffffff00ff0000ff, 0xffffff00ff000001, 0xffffff00ff000100, 0xffffff00ff000101,
+    0xffffff00ff010000, 0xffffff0000ffff00, 0xffffff0000ff0001, 0xffffff0000ff0100,
+    0xffffff000000ff01, 0xffffff0000000000, 0xffffff0000000101, 0xffffff000001ff00,
+    0xffffff00000100ff, 0xffffff0000010001, 0xffffff00000101ff, 0xffffff0001ff0000,
+    0xffffff000100ff00, 0xffffff00010000ff, 0xffffff0001000001, 0xffffff0001010000,
+    0xffffff01ffffffff, 0xffffff01ffffff01, 0xffffff01ffff01ff, 0xffffff01ffff0101,
+    0xffffff01ff000000, 0xffffff01ff01ffff, 0xffffff01ff01ff01, 0xffffff01ff0101ff,
+    0xffffff01ff010101, 0xffffff0100ff0000, 0xffffff010000ff00, 0xffffff0100000100,
+    0xffffff01000100ff, 0xffffff0100010100, 0xffffff0101ffffff, 0xffffff0101ffff01,
+    0xffffff0101ff01ff, 0xffffff0101ff0101, 0xffffff010100ff00, 0xffffff0101000000,
+    0xffffff0101000100, 0xffffff010101ffff, 0xffffff010101ff01, 0xffffff01010101ff,
+    0xffffff0101010101, 0xffff00ffff00ff00, 0xffff00ffff0000ff, 0xffff00ffff000001,
+    0xffff00ffff010000, 0xffff00ff00ffff00, 0xffff00ff00ff0100, 0xffff00ff00000000,
+    0xffff00ff00000101, 0xffff00ff000100ff, 0xffff00ff00010000, 0xffff00ff0100ff00,
+    0xffff00ff01000100, 0xffff00ff01010000, 0xffff0000ffffff00, 0xffff0000ffff00ff,
+    0xffff0000ffff0000, 0xffff0000ffff0001, 0xffff0000ff000000, 0xffff0000ff0001ff,
+    0xffff0000ff000101, 0xffff0000ff010100, 0xffff000000ffffff, 0xffff000000ff0000,
+    0xffff000000ff0101, 0xffff00000000ffff, 0xffff00000000ff00, 0xffff0000000000ff,
+    0xffff000000000000, 0xffff000000000001, 0xffff000000000100, 0xffff00000001ffff,
+    0xffff00000001ff01, 0xffff000000010000, 0xffff0000000101ff, 0xffff000000010101,
+    0xffff000001ffff00, 0xffff00000100ff00, 0xffff000001000000, 0xffff0000010001ff,
+    0xffff000001000101, 0xffff00000101ff00, 0xffff0000010100ff, 0xffff000001010000,
+    0xffff000001010001, 0xffff000001010100, 0xffff0001ff0000ff, 0xffff0001ff000100,
+    0xffff000100ffff00, 0xffff000100ff00ff, 0xffff00010000ffff, 0xffff00010000ff01,
+    0xffff000100000000, 0xffff0001000001ff, 0xffff00010001ffff, 0xffff00010001ff00,
+    0xffff000100010001, 0xffff000100010100, 0xffff000101ff0000, 0xffff00010100ff00,
+    0xffff0001010000ff, 0xffff000101000100, 0xffff01ffffffffff, 0xffff01ffffffff01,
+    0xffff01ffffff01ff, 0xffff01ffffff0101, 0xffff01ffff000000, 0xffff01ffff01ffff,
+    0xffff01ffff01ff01, 0xffff01ffff0101ff, 0xffff01ffff010101, 0xffff01ff00ff0000,
+    0xffff01ff0000ff00, 0xffff01ff00000001, 0xffff01ff00010000, 0xffff01ff01ffffff,
+    0xffff01ff01ffff01, 0xffff01ff01ff01ff, 0xffff01ff01ff0101, 0xffff01ff01000000,
+    0xffff01ff0101ffff, 0xffff01ff0101ff01, 0xffff01ff010101ff, 0xffff01ff01010101,
+    0xffff0100ffff0000, 0xffff0100ff00ff00, 0xffff0100ff0000ff, 0xffff0100ff000100,
+    0xffff0100ff0100ff, 0xffff0100ff010000, 0xffff010000ffff00, 0xffff01000000ffff,
+    0xffff01000000ff00, 0xffff010000000000, 0xffff01000001ff00, 0xffff0100000100ff,
+    0xffff010000010100, 0xffff01000100ff00, 0xffff0100010000ff, 0xffff010001000001,
+    0xffff010001000100, 0xffff010001010000, 0xffff0101ffffffff, 0xffff0101ffffff01,
+    0xffff0101ffff01ff, 0xffff0101ffff0101, 0xffff0101ff000000, 0xffff0101ff01ffff,
+    0xffff0101ff01ff01, 0xffff0101ff0101ff, 0xffff0101ff010101, 0xffff010100ff0000,
+    0xffff01010000ff00, 0xffff010100000100, 0xffff01010001ff00, 0xffff010100010000,
+    0xffff010101ffffff, 0xffff010101ffff01, 0xffff010101ff0000, 0xffff010101ff01ff,
+    0xffff010101ff0101, 0xffff010101000000, 0xffff01010101ffff, 0xffff01010101ff01,
+    0xffff0101010101ff, 0xffff010101010101, 0xff00ffffff00ffff, 0xff00ffffff00ff00,
+    0xff00ffffff0000ff, 0xff00ffffff000100, 0xff00ffffff0100ff, 0xff00ffffff010000,
+    0xff00ffff00ffff00, 0xff00ffff00ff00ff, 0xff00ffff0000ffff, 0xff00ffff00000000,
+    0xff00ffff000001ff, 0xff00ffff0001ff00, 0xff00ffff000100ff, 0xff00ffff00010000,
+    0xff00ffff00010100, 0xff00ffff0100ff00, 0xff00ffff010000ff, 0xff00ffff01000001,
+    0xff00ffff0101ff00, 0xff00ffff01010000, 0xff00ff00ffffff00, 0xff00ff00ffff00ff,
+    0xff00ff00ffff0001, 0xff00ff00ffff0100, 0xff00ff00ff00ffff, 0xff00ff00ff00ff01,
+    0xff00ff00ff000000, 0xff00ff00ff0001ff, 0xff00ff00ff01ff00, 0xff00ff00ff0100ff,
+    0xff00ff00ff010100, 0xff00ff0000ff0000, 0xff00ff0000ff0101, 0xff00ff000000ffff,
+    0xff00ff000000ff00, 0xff00ff000000ff01, 0xff00ff00000000ff, 0xff00ff0000000000,
+    0xff00ff0000000001, 0xff00ff0000000100, 0xff00ff000001ffff, 0xff00ff0000010000,
+    0xff00ff0001ff00ff, 0xff00ff000100ff01, 0xff00ff0001000000, 0xff00ff000101ff00,
+    0xff00ff00010100ff, 0xff00ff01ff00ff00, 0xff00ff01ff0000ff, 0xff00ff01ff000001,
+    0xff00ff01ff010000, 0xff00ff0100ffffff, 0xff00ff0100ff0001, 0xff00ff0100ff0100,
+    0xff00ff010000ff01, 0xff00ff0100000000, 0xff00ff01000001ff, 0xff00ff0100000101,
+    0xff00ff01000100ff, 0xff00ff0100010001, 0xff00ff0101ff0000, 0xff00ff010100ff00,
+    0xff00ff01010000ff, 0xff00ff0101000001, 0xff00ff0101010000, 0xff0000ffffffff00,
+    0xff0000ffffff0001, 0xff0000ffffff0100, 0xff0000ffff0000ff, 0xff0000ffff000000,
+    0xff0000ffff0001ff, 0xff0000ffff000100, 0xff0000ffff01ff00, 0xff0000ffff010001,
+    0xff0000ff00ffff00, 0xff0000ff00ff0000, 0xff0000ff00ff0001, 0xff0000ff00ff01ff,
+    0xff0000ff00ff0101, 0xff0000ff0000ff00, 0xff0000ff000000ff, 0xff0000ff00000000,
+    0xff0000ff00000001, 0xff0000ff00000100, 0xff0000ff0001ff01, 0xff0000ff00010000,
+    0xff0000ff000101ff, 0xff0000ff01ff00ff, 0xff0000ff01ff0100, 0xff0000ff0100ffff,
+    0xff0000ff010000ff, 0xff0000ff01000000, 0xff0000ff010001ff, 0xff0000ff01000100,
+    0xff0000ff01000101, 0xff0000ff0101ff00, 0xff0000ff010100ff, 0xff0000ff01010000,
+    0xff0000ff01010100, 0xff000000ffffff01, 0xff000000ffff0000, 0xff000000ffff0101,
+    0xff000000ff00ff00, 0xff000000ff0000ff, 0xff000000ff000000, 0xff000000ff000001,
+    0xff000000ff000100, 0xff000000ff01ffff, 0xff000000ff01ff01, 0xff000000ff010000,
+    0xff000000ff0101ff, 0xff000000ff010101, 0xff00000000ffff00, 0xff00000000ff00ff,
+    0xff00000000ff0000, 0xff00000000ff0001, 0xff0000000000ff00, 0xff0000000000ff01,
+    0xff000000000000ff, 0xff00000000000000, 0xff00000000000001, 0xff00000000000100,
+    0xff00000000000101, 0xff0000000001ff00, 0xff000000000100ff, 0xff00000000010000,
+    0xff00000000010001, 0xff00000000010100, 0xff00000001ffffff, 0xff00000001ffff01,
+    0xff00000001ff00ff, 0xff00000001ff0000, 0xff00000001ff01ff, 0xff00000001ff0101,
+    0xff0000000100ffff, 0xff0000000100ff00, 0xff000000010000ff, 0xff00000001000000,
+    0xff00000001000001, 0xff00000001000100, 0xff00000001000101, 0xff0000000101ffff,
+    0xff0000000101ff01, 0xff00000001010000, 0xff000001ffffff00, 0xff000001ffff00ff,
+    0xff000001ffff0000, 0xff000001ffff0001, 0xff000001ff000000, 0xff000001ff000001,
+    0xff000001ff0001ff, 0xff000001ff000101, 0xff000001ff01ff00, 0xff000001ff010001,
+    0xff00000100ffffff, 0xff00000100ffff01, 0xff00000100ff00ff, 0xff00000100ff0000,
+    0xff00000100ff01ff, 0xff00000100ff0101, 0xff0000010000ff00, 0xff00000100000000,
+    0xff00000100000001, 0xff000001000001ff, 0xff00000100000100, 0xff0000010001ff00,
+    0xff000001000100ff, 0xff00000100010000, 0xff000001000101ff, 0xff00000100010100,
+    0xff00000100010101, 0xff00000101ff0001, 0xff00000101ff0101, 0xff0000010100ff01,
+    0xff00000101000000, 0xff000001010100ff, 0xff00000101010100, 0xff0001ffff00ff00,
+    0xff0001ffff000001, 0xff0001ffff010000, 0xff0001ff00ffff00, 0xff0001ff00ff00ff,
+    0xff0001ff00ff0001, 0xff0001ff00ff0100, 0xff0001ff0000ffff, 0xff0001ff00000000,
+    0xff0001ff000001ff, 0xff0001ff00000101, 0xff0001ff0001ffff, 0xff0001ff0001ff00,
+    0xff0001ff000100ff, 0xff0001ff00010001, 0xff0001ff00010100, 0xff0001ff01ff0000,
+    0xff0001ff0100ff00, 0xff0001ff010000ff, 0xff0001ff01010000, 0xff000100ff00ffff,
+    0xff000100ff00ff01, 0xff000100ff000000, 0xff000100ff000101, 0xff000100ff01ff00,
+    0xff000100ff010000, 0xff00010000ffff01, 0xff00010000ff00ff, 0xff00010000ff0000,
+    0xff00010000ff01ff, 0xff0001000000ff00, 0xff000100000000ff, 0xff00010000000000,
+    0xff00010000000001, 0xff00010000000100, 0xff00010000000101, 0xff0001000001ffff,
+    0xff00010000010000, 0xff00010000010101, 0xff00010001ff0100, 0xff0001000100ff00,
+    0xff0001000100ff01, 0xff00010001000000, 0xff000100010001ff, 0xff0001000101ff00,
+    0xff00010001010001, 0xff00010001010100, 0xff000101ffff0100, 0xff000101ff000001,
+    0xff000101ff0100ff, 0xff000101ff010001, 0xff00010100ff00ff, 0xff00010100ff0001,
+    0xff00010100ff0100, 0xff0001010000ffff, 0xff0001010000ff01, 0xff00010100000000,
+    0xff000101000001ff, 0xff0001010001ff00, 0xff00010100010001, 0xff00010100010100,
+    0xff00010101ff0000, 0xff0001010100ff00, 0xff00010101000001, 0xff00010101000101,
+    0xff01ffffffffffff, 0xff01ffffffffff01, 0xff01ffffffff01ff, 0xff01ffffffff0101,
+    0xff01ffffff000000, 0xff01ffffff01ffff, 0xff01ffffff01ff01, 0xff01ffffff010000,
+    0xff01ffffff0101ff, 0xff01ffffff010101, 0xff01ffff00ff0000, 0xff01ffff0000ff00,
+    0xff01ffff00000100, 0xff01ffff0001ff00, 0xff01ffff00010000, 0xff01ffff01ffffff,
+    0xff01ffff01ffff01, 0xff01ffff01ff01ff, 0xff01ffff01ff0101, 0xff01ffff01000000,
+    0xff01ffff0101ffff, 0xff01ffff0101ff01, 0xff01ffff01010000, 0xff01ffff010101ff,
+    0xff01ffff01010101, 0xff01ff00ffff0000, 0xff01ff00ff00ff00, 0xff01ff00ff0000ff,
+    0xff01ff00ff000100, 0xff01ff00ff010000, 0xff01ff0000ffff01, 0xff01ff0000ff00ff,
+    0xff01ff0000ff0100, 0xff01ff0000000000, 0xff01ff00000001ff, 0xff01ff0000000101,
+    0xff01ff000001ff00, 0xff01ff00000100ff, 0xff01ff0000010000, 0xff01ff0000010001,
+    0xff01ff0001ff0000, 0xff01ff000100ffff, 0xff01ff0001000001, 0xff01ff0001000100,
+    0xff01ff0001010000, 0xff01ff01ffffff00, 0xff01ff01ffff01ff, 0xff01ff01ffff0101,
+    0xff01ff01ff00ff00, 0xff01ff01ff000000, 0xff01ff01ff01ffff, 0xff01ff01ff01ff01,
+    0xff01ff01ff0101ff, 0xff01ff01ff010101, 0xff01ff0100ff0000, 0xff01ff010000ff00,
+    0xff01ff0100000001, 0xff01ff0100000100, 0xff01ff0100010000, 0xff01ff0101ffff00,
+    0xff01ff0101ff01ff, 0xff01ff0101ff0101, 0xff01ff010100ff00, 0xff01ff0101000000,
+    0xff01ff010101ffff, 0xff01ff010101ff01, 0xff01ff01010101ff, 0xff01ff0101010101,
+    0xff0100ffffff0000, 0xff0100ffff0000ff, 0xff0100ffff000001, 0xff0100ffff000100,
+    0xff0100ffff010000, 0xff0100ff00ff00ff, 0xff0100ff00ff0000, 0xff0100ff00ff0001,
+    0xff0100ff00ff0100, 0xff0100ff0000ff01, 0xff0100ff00000000, 0xff0100ff000001ff,
+    0xff0100ff00000101, 0xff0100ff00010001, 0xff0100ff01ff0000, 0xff0100ff0100ff00,
+    0xff0100ff010000ff, 0xff0100ff01000100, 0xff0100ff0101ff00, 0xff0100ff01010000,
+    0xff010000ffff0100, 0xff010000ff000000, 0xff010000ff01ff00, 0xff010000ff010100,
+    0xff01000000ffffff, 0xff01000000ff0000, 0xff01000000ff01ff, 0xff0100000000ff00,
+    0xff010000000000ff, 0xff01000000000000, 0xff01000000000100, 0xff0100000001ff01,
+    0xff01000000010000, 0xff010000000101ff, 0xff01000001ff0100, 0xff0100000100ffff,
+    0xff010000010000ff, 0xff01000001000000, 0xff010000010001ff, 0xff01000001000101,
+    0xff0100000101ff00, 0xff010000010100ff, 0xff01000001010001, 0xff01000001010100,
+    0xff010001ffff0000, 0xff010001ff00ffff, 0xff010001ff00ff01, 0xff010001ff000100,
+    0xff010001ff010000, 0xff01000100ffff00, 0xff01000100ff0100, 0xff01000100000000,
+    0xff0100010001ffff, 0xff0100010001ff00, 0xff01000100010100, 0xff01000101ff00ff,
+    0xff01000101ff0001, 0xff0100010100ffff, 0xff01000101000101, 0xff0101ffffffffff,
+    0xff0101ffffffff01, 0xff0101ffffff01ff, 0xff0101ffffff0101, 0xff0101ffff000000,
+    0xff0101ffff01ffff, 0xff0101ffff01ff01, 0xff0101ffff0101ff, 0xff0101ffff010101,
+    0xff0101ff00ff0000, 0xff0101ff0000ff00, 0xff0101ff000000ff, 0xff0101ff00010000,
+    0xff0101ff01ffffff, 0xff0101ff01ffff01, 0xff0101ff01ff01ff, 0xff0101ff01ff0101,
+    0xff0101ff0101ffff, 0xff0101ff0101ff01, 0xff0101ff010101ff, 0xff0101ff01010101,
+    0xff010100ffff0100, 0xff010100ff00ff00, 0xff010100ff0000ff, 0xff010100ff000100,
+    0xff010100ff010000, 0xff01010000ff0001, 0xff01010000ff0100, 0xff0101000000ff01,
+    0xff01010000000000, 0xff0101000001ff00, 0xff010100000100ff, 0xff01010000010001,
+    0xff01010000010100, 0xff01010001ff0000, 0xff0101000100ffff, 0xff01010001000001,
+    0xff01010001000100, 0xff010100010100ff, 0xff01010001010000, 0xff010101ffffffff,
+    0xff010101ffffff01, 0xff010101ffff01ff, 0xff010101ffff0101, 0xff010101ff01ffff,
+    0xff010101ff01ff01, 0xff010101ff0101ff, 0xff010101ff010101, 0xff01010100ff0000,
+    0xff0101010000ff00, 0xff01010100000001, 0xff01010100000100, 0xff01010100010000,
+    0xff01010101ffffff, 0xff01010101ffff01, 0xff01010101ff01ff, 0xff01010101ff0101,
+    0xff01010101000000, 0xff0101010101ffff, 0xff0101010101ff01, 0xff010101010101ff,
+    0xff01010101010101, 0x00ffffffffff0000, 0x00ffffffff00ff00, 0x00ffffffff000001,
+    0x00ffffffff010000, 0x00ffffff00ff0100, 0x00ffffff0000ff01, 0x00ffffff00000000,
+    0x00ffffff000001ff, 0x00ffffff00000101, 0x00ffffff0001ff00, 0x00ffffff000100ff,
+    0x00ffffff00010001, 0x00ffffff010000ff, 0x00ffffff01000100, 0x00ffffff0101ff00,
+    0x00ffffff01010001, 0x00ffff00ffffffff, 0x00ffff00ffffff00, 0x00ffff00ffff00ff,
+    0x00ffff00ffff0001, 0x00ffff00ffff0100, 0x00ffff00ff00ff01, 0x00ffff00ff000000,
+    0x00ffff00ff000001, 0x00ffff00ff0001ff, 0x00ffff00ff000101, 0x00ffff00ff01ff00,
+    0x00ffff00ff010001, 0x00ffff00ff010100, 0x00ffff0000ff0000, 0x00ffff0000ff01ff,
+    0x00ffff0000ff0101, 0x00ffff000000ff00, 0x00ffff00000000ff, 0x00ffff0000000000,
+    0x00ffff0000000001, 0x00ffff0000000100, 0x00ffff0000000101, 0x00ffff0000010000,
+    0x00ffff00000101ff, 0x00ffff0000010101, 0x00ffff0001ffff00, 0x00ffff0001ff00ff,
+    0x00ffff0001ff0001, 0x00ffff000100ffff, 0x00ffff000100ff01, 0x00ffff0001000000,
+    0x00ffff000101ffff, 0x00ffff000101ff00, 0x00ffff000101ff01, 0x00ffff01ffff0000,
+    0x00ffff01ff00ff00, 0x00ffff01ff0000ff, 0x00ffff01ff000001, 0x00ffff01ff010000,
+    0x00ffff0100ffff00, 0x00ffff010000ff01, 0x00ffff0100000000, 0x00ffff0100000101,
+    0x00ffff01000100ff, 0x00ffff0100010100, 0x00ffff0101ff0100, 0x00ffff01010000ff,
+    0x00ffff0101010000, 0x00ff00ffffffff00, 0x00ff00ffff000000, 0x00ff00ffff000100,
+    0x00ff00ffff010100, 0x00ff00ff00ff0000, 0x00ff00ff00ff01ff, 0x00ff00ff00ff0101,
+    0x00ff00ff0000ff00, 0x00ff00ff000000ff, 0x00ff00ff00000000, 0x00ff00ff00000001,
+    0x00ff00ff0001ff00, 0x00ff00ff0001ff01, 0x00ff00ff00010000, 0x00ff00ff000101ff,
+    0x00ff00ff00010101, 0x00ff00ff01ffff00, 0x00ff00ff01ff0001, 0x00ff00ff01ff0100,
+    0x00ff00ff0100ffff, 0x00ff00ff0100ff01, 0x00ff00ff01000000, 0x00ff00ff0101ffff,
+    0x00ff00ff0101ff00, 0x00ff00ff01010100, 0x00ff0000ffffff00, 0x00ff0000ffffff01,
+    0x00ff0000ffff0000, 0x00ff0000ffff0101, 0x00ff0000ff00ff00, 0x00ff0000ff0000ff,
+    0x00ff0000ff000000, 0x00ff0000ff000001, 0x00ff0000ff000100, 0x00ff0000ff01ffff,
+    0x00ff0000ff010000, 0x00ff0000ff010101, 0x00ff000000ffff00, 0x00ff000000ff00ff,
+    0x00ff000000ff0000, 0x00ff000000ff0001, 0x00ff000000ff0100, 0x00ff00000000ffff,
+    0x00ff00000000ff00, 0x00ff0000000000ff, 0x00ff000000000000, 0x00ff000000000001,
+    0x00ff0000000001ff, 0x00ff000000000100, 0x00ff00000001ff00, 0x00ff0000000100ff,
+    0x00ff000000010000, 0x00ff000000010001, 0x00ff000000010100, 0x00ff000001ffff01,
+    0x00ff000001ff00ff, 0x00ff000001ff0000, 0x00ff000001ff01ff, 0x00ff00000100ff00,
+    0x00ff0000010000ff, 0x00ff000001000000, 0x00ff000001000001, 0x00ff000001000100,
+    0x00ff000001000101, 0x00ff000001010000, 0x00ff0000010101ff, 0x00ff000001010101,
+    0x00ff0001ffffff00, 0x00ff0001ffff0000, 0x00ff0001ffff0100, 0x00ff0001ff0000ff,
+    0x00ff0001ff000000, 0x00ff0001ff0001ff, 0x00ff0001ff000101, 0x00ff0001ff01ff00,
+    0x00ff0001ff0100ff, 0x00ff0001ff010100, 0x00ff000100ffffff, 0x00ff000100ffff01,
+    0x00ff000100ff0000, 0x00ff000100ff01ff, 0x00ff00010000ffff, 0x00ff00010000ff00,
+    0x00ff00010000ff01, 0x00ff000100000000, 0x00ff000100000001, 0x00ff000100000100,
+    0x00ff00010001ff01, 0x00ff000100010000, 0x00ff0001000101ff, 0x00ff000101ffff00,
+    0x00ff000101ff0000, 0x00ff000101ff0101, 0x00ff0001010000ff, 0x00ff000101000000,
+    0x00ff00010101ff00, 0x00ff0001010100ff, 0x00ff000101010001, 0x00ff01ffffff0000,
+    0x00ff01ffff00ff00, 0x00ff01ffff000000, 0x00ff01ffff000101, 0x00ff01ffff010000,
+    0x00ff01ff00ffff01, 0x00ff01ff00ff0100, 0x00ff01ff0000ffff, 0x00ff01ff00000000,
+    0x00ff01ff000001ff, 0x00ff01ff0001ff00, 0x00ff01ff000100ff, 0x00ff01ff00010001,
+    0x00ff01ff00010100, 0x00ff01ff01ff0000, 0x00ff01ff0100ff00, 0x00ff01ff010000ff,
+    0x00ff01ff01000001, 0x00ff01ff01000100, 0x00ff01ff01010000, 0x00ff0100ffffff00,
+    0x00ff0100ffff0000, 0x00ff0100ffff0001, 0x00ff0100ffff0101, 0x00ff0100ff00ffff,
+    0x00ff0100ff0000ff, 0x00ff0100ff000000, 0x00ff0100ff0001ff, 0x00ff0100ff01ff00,
+    0x00ff0100ff0100ff, 0x00ff0100ff010001, 0x00ff010000ffffff, 0x00ff010000ff0000,
+    0x00ff010000ff0101, 0x00ff01000000ff00, 0x00ff01000000ff01, 0x00ff0100000000ff,
+    0x00ff010000000000, 0x00ff010000000001, 0x00ff010000000100, 0x00ff01000001ffff,
+    0x00ff01000001ff01, 0x00ff010000010000, 0x00ff010000010001, 0x00ff010000010101,
+    0x00ff010001ff0001, 0x00ff010001ff0100, 0x00ff01000100ff01, 0x00ff010001000000,
+    0x00ff010001000001, 0x00ff0100010001ff, 0x00ff01000101ff00, 0x00ff0100010100ff,
+    0x00ff010001010001, 0x00ff010001010100, 0x00ff0101ff000001, 0x00ff010100ff00ff,
+    0x00ff010100ff0001, 0x00ff010100ff0100, 0x00ff010100000000, 0x00ff0101000001ff,
+    0x00ff010100000101, 0x00ff0101000100ff, 0x00ff010100010100, 0x00ff0101010000ff,
+    0x00ff010101010000, 0x0000ffffffffff00, 0x0000ffffffff00ff, 0x0000ffffffff0000,
+    0x0000ffffffff0001, 0x0000ffffffff0100, 0x0000ffffff00ff01, 0x0000ffffff000000,
+    0x0000ffffff000101, 0x0000ffffff01ff00, 0x0000ffffff0100ff, 0x0000ffffff010100,
+    0x0000ffff00ffffff, 0x0000ffff00ff0000, 0x0000ffff00ff01ff, 0x0000ffff0000ff00,
+    0x0000ffff000000ff, 0x0000ffff00000000, 0x0000ffff00000001, 0x0000ffff00000100,
+    0x0000ffff00010000, 0x0000ffff000101ff, 0x0000ffff01ff0001, 0x0000ffff01ff0100,
+    0x0000ffff01000000, 0x0000ffff010001ff, 0x0000ffff0101ffff, 0x0000ffff0101ff00,
+    0x0000ffff01010001, 0x0000ffff01010100, 0x0000ff00ffff0000, 0x0000ff00ffff01ff,
+    0x0000ff00ffff0100, 0x0000ff00ffff0101, 0x0000ff00ff00ff00, 0x0000ff00ff0000ff,
+    0x0000ff00ff000000, 0x0000ff00ff000001, 0x0000ff00ff0001ff, 0x0000ff00ff000100,
+    0x0000ff00ff01ffff, 0x0000ff00ff010000, 0x0000ff00ff010001, 0x0000ff00ff0101ff,
+    0x0000ff00ff010101, 0x0000ff0000ffff00, 0x0000ff0000ff00ff, 0x0000ff0000ff0000,
+    0x0000ff0000ff0001, 0x0000ff0000ff0100, 0x0000ff000000ffff, 0x0000ff000000ff00,
+    0x0000ff000000ff01, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
+    0x0000ff00000001ff, 0x0000ff0000000100, 0x0000ff0000000101, 0x0000ff000001ff00,
+    0x0000ff00000100ff, 0x0000ff0000010000, 0x0000ff0000010001, 0x0000ff0000010100,
+    0x0000ff0001ffff01, 0x0000ff0001ff0000, 0x0000ff000100ff00, 0x0000ff00010000ff,
+    0x0000ff0001000000, 0x0000ff0001000001, 0x0000ff0001000100, 0x0000ff000101ffff,
+    0x0000ff0001010000, 0x0000ff0001010101, 0x0000ff01ffffff00, 0x0000ff01ffff0001,
+    0x0000ff01ff00ff01, 0x0000ff01ff000000, 0x0000ff01ff000101, 0x0000ff01ff01ff00,
+    0x0000ff01ff0100ff, 0x0000ff0100ffff01, 0x0000ff0100ff0000, 0x0000ff0100ff0101,
+    0x0000ff010000ff00, 0x0000ff01000000ff, 0x0000ff0100000000, 0x0000ff0100000001,
+    0x0000ff0100000100, 0x0000ff010001ff01, 0x0000ff0100010000, 0x0000ff0101ff0000,
+    0x0000ff010100ffff, 0x0000ff010100ff01, 0x0000ff0101000000, 0x0000ff0101000100,
+    0x0000ff0101000101, 0x0000ff01010100ff, 0x000000ffffff00ff, 0x000000ffffff0000,
+    0x000000ffff00ff00, 0x000000ffff0000ff, 0x000000ffff000000, 0x000000ffff000001,
+    0x000000ffff0001ff, 0x000000ffff000100, 0x000000ffff01ff00, 0x000000ffff010000,
+    0x000000ffff0101ff, 0x000000ffff010101, 0x000000ff00ffff00, 0x000000ff00ff00ff,
+    0x000000ff00ff0000, 0x000000ff00ff0001, 0x000000ff00ff0100, 0x000000ff00ff0101,
+    0x000000ff0000ffff, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
+    0x000000ff00000001, 0x000000ff000001ff, 0x000000ff00000100, 0x000000ff00000101,
+    0x000000ff0001ff00, 0x000000ff0001ff01, 0x000000ff000100ff, 0x000000ff00010000,
+    0x000000ff00010001, 0x000000ff00010100, 0x000000ff01ffffff, 0x000000ff01ff01ff,
+    0x000000ff01ff0101, 0x000000ff0100ff00, 0x000000ff010000ff, 0x000000ff01000000,
+    0x000000ff01000001, 0x000000ff01000100, 0x000000ff0101ff00, 0x000000ff010100ff,
+    0x000000ff01010000, 0x000000ff01010101, 0x00000000ffffff00, 0x00000000ffffff01,
+    0x00000000ffff00ff, 0x00000000ffff0000, 0x00000000ffff0001, 0x00000000ffff0100,
+    0x00000000ff00ffff, 0x00000000ff00ff00, 0x00000000ff00ff01, 0x00000000ff0000ff,
+    0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff000101,
+    0x00000000ff01ff00, 0x00000000ff0100ff, 0x00000000ff010000, 0x00000000ff010001,
+    0x00000000ff010100, 0x0000000000ffffff, 0x0000000000ffff00, 0x0000000000ffff01,
+    0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001, 0x0000000000ff01ff,
+    0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
+    0x00000000000000ff, 0x0000000000000000, 0x0000000000000001, 0x00000000000001ff,
+    0x0000000000000100, 0x0000000000000101, 0x000000000001ffff, 0x000000000001ff00,
+    0x00000000000100ff, 0x0000000000010000, 0x0000000000010001, 0x00000000000101ff,
+    0x0000000000010100, 0x0000000000010101, 0x0000000001ffff00, 0x0000000001ff00ff,
+    0x0000000001ff0000, 0x0000000001ff0100, 0x0000000001ff0101, 0x000000000100ffff,
+    0x000000000100ff00, 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001,
+    0x00000000010001ff, 0x0000000001000100, 0x000000000101ff00, 0x00000000010100ff,
+    0x0000000001010000, 0x0000000001010001, 0x0000000001010100, 0x00000001ffffffff,
+    0x00000001ffffff00, 0x00000001ffffff01, 0x00000001ffff00ff, 0x00000001ffff0001,
+    0x00000001ffff01ff, 0x00000001ffff0100, 0x00000001ff00ff00, 0x00000001ff0000ff,
+    0x00000001ff000000, 0x00000001ff0001ff, 0x00000001ff000100, 0x00000001ff01ffff,
+    0x00000001ff01ff00, 0x00000001ff01ff01, 0x00000001ff0100ff, 0x00000001ff010000,
+    0x00000001ff010001, 0x00000001ff0101ff, 0x00000001ff010100, 0x0000000100ffff00,
+    0x0000000100ff0000, 0x0000000100ff0001, 0x0000000100ff01ff, 0x0000000100ff0100,
+    0x0000000100ff0101, 0x000000010000ffff, 0x000000010000ff00, 0x000000010000ff01,
+    0x00000001000000ff, 0x0000000100000000, 0x0000000100000001, 0x00000001000001ff,
+    0x0000000100000100, 0x0000000100000101, 0x000000010001ff00, 0x00000001000100ff,
+    0x0000000100010000, 0x0000000100010100, 0x0000000101ffff01, 0x0000000101ff0000,
+    0x0000000101ff0001, 0x0000000101ff01ff, 0x0000000101ff0100, 0x0000000101ff0101,
+    0x000000010100ff00, 0x0000000101000000, 0x0000000101000101, 0x000000010101ff01,
+    0x0000000101010000, 0x0000000101010001, 0x00000001010101ff, 0x0000000101010100,
+    0x000001ffffff00ff, 0x000001ffffff0000, 0x000001ffffff0001, 0x000001ffffff0100,
+    0x000001ffff00ffff, 0x000001ffff000000, 0x000001ffff0001ff, 0x000001ffff01ff00,
+    0x000001ffff010101, 0x000001ff00ff0000, 0x000001ff00ff01ff, 0x000001ff00ff0101,
+    0x000001ff0000ff00, 0x000001ff000000ff, 0x000001ff00000000, 0x000001ff00000001,
+    0x000001ff000001ff, 0x000001ff00000100, 0x000001ff0001ffff, 0x000001ff0001ff01,
+    0x000001ff000100ff, 0x000001ff00010000, 0x000001ff01ffff01, 0x000001ff01ff0100,
+    0x000001ff0100ffff, 0x000001ff0100ff01, 0x000001ff01000000, 0x000001ff010001ff,
+    0x000001ff0101ff00, 0x000001ff01010100, 0x00000100ffffff00, 0x00000100ffffff01,
+    0x00000100ffff0000, 0x00000100ffff0101, 0x00000100ff00ff00, 0x00000100ff0000ff,
+    0x00000100ff000000, 0x00000100ff000001, 0x00000100ff000100, 0x00000100ff010000,
+    0x0000010000ffff00, 0x0000010000ff00ff, 0x0000010000ff0000, 0x0000010000ff0001,
+    0x0000010000ff0100, 0x000001000000ffff, 0x000001000000ff00, 0x000001000000ff01,
+    0x00000100000000ff, 0x0000010000000000, 0x0000010000000001, 0x00000100000001ff,
+    0x0000010000000100, 0x0000010000000101, 0x000001000001ff00, 0x00000100000100ff,
+    0x0000010000010000, 0x0000010000010001, 0x0000010000010100, 0x0000010001ffff00,
+    0x0000010001ff0000, 0x0000010001ff0100, 0x000001000100ff00, 0x00000100010000ff,
+    0x0000010001000000, 0x0000010001000001, 0x00000100010001ff, 0x0000010001000100,
+    0x0000010001010000, 0x00000101ffff00ff, 0x00000101ffff01ff, 0x00000101ff000000,
+    0x00000101ff000101, 0x00000101ff01ffff, 0x00000101ff010000, 0x00000101ff010001,
+    0x00000101ff010100, 0x0000010100ff0000, 0x0000010100ff01ff, 0x0000010100ff0100,
+    0x000001010000ff00, 0x0000010100000000, 0x0000010100000001, 0x00000101000001ff,
+    0x0000010100000100, 0x000001010001ff01, 0x0000010100010000, 0x00000101000101ff,
+    0x0000010100010101, 0x0000010101ffff00, 0x0000010101ff0101, 0x000001010100ff01,
+    0x0000010101000000, 0x0000010101000001, 0x00000101010001ff, 0x0000010101000101,
+    0x000001010101ff00, 0x0001ffffffff0000, 0x0001ffffff0000ff, 0x0001ffffff000001,
+    0x0001ffffff000100, 0x0001ffffff010000, 0x0001ffff00ff00ff, 0x0001ffff0000ffff,
+    0x0001ffff00000000, 0x0001ffff00000001, 0x0001ffff000001ff, 0x0001ffff00000101,
+    0x0001ffff0001ff00, 0x0001ffff000100ff, 0x0001ffff00010001, 0x0001ffff00010100,
+    0x0001ffff01ffff00, 0x0001ffff01000001, 0x0001ffff01010000, 0x0001ff00ffffff00,
+    0x0001ff00ffff00ff, 0x0001ff00ffff0001, 0x0001ff00ffff0100, 0x0001ff00ff00ff01,
+    0x0001ff00ff000000, 0x0001ff00ff01ff00, 0x0001ff00ff01ff01, 0x0001ff00ff010001,
+    0x0001ff00ff010100, 0x0001ff0000ff0000, 0x0001ff0000ff0100, 0x0001ff000000ff00,
+    0x0001ff0000000000, 0x0001ff0000000001, 0x0001ff0000000100, 0x0001ff0000010000,
+    0x0001ff0000010001, 0x0001ff0000010101, 0x0001ff0001ff00ff, 0x0001ff0001ff0101,
+    0x0001ff000100ff01, 0x0001ff0001000000, 0x0001ff000101ff00, 0x0001ff0001010001,
+    0x0001ff0001010100, 0x0001ff01ff00ff00, 0x0001ff01ff000001, 0x0001ff01ff000100,
+    0x0001ff0100ffffff, 0x0001ff0100ffff00, 0x0001ff0100ff0001, 0x0001ff0100000000,
+    0x0001ff0100000001, 0x0001ff01000001ff, 0x0001ff010001ffff, 0x0001ff0101ff0000,
+    0x0001ff010100ff00, 0x0001ff0101000001, 0x0001ff0101010000, 0x000100ffff00ff00,
+    0x000100ffff00ff01, 0x000100ffff000000, 0x000100ffff000001, 0x000100ffff000101,
+    0x000100ffff01ff00, 0x000100ffff010001, 0x000100ffff010100, 0x000100ff00ffffff,
+    0x000100ff00ffff01, 0x000100ff00ff0000, 0x000100ff00ff01ff, 0x000100ff00ff0101,
+    0x000100ff0000ff00, 0x000100ff000000ff, 0x000100ff00000000, 0x000100ff00000001,
+    0x000100ff00000100, 0x000100ff00000101, 0x000100ff0001ffff, 0x000100ff0001ff01,
+    0x000100ff00010000, 0x000100ff01ff00ff, 0x000100ff01ff0000, 0x000100ff01ff0100,
+    0x000100ff0100ffff, 0x000100ff0100ff01, 0x000100ff010000ff, 0x000100ff01000000,
+    0x000100ff01000001, 0x000100ff010001ff, 0x000100ff01000101, 0x000100ff0101ff00,
+    0x000100ff010100ff, 0x000100ff01010100, 0x00010000ffff0000, 0x00010000ffff01ff,
+    0x00010000ffff0101, 0x00010000ff00ff00, 0x00010000ff000000, 0x00010000ff000001,
+    0x00010000ff000100, 0x0001000000ff00ff, 0x0001000000ff0000, 0x0001000000ff0001,
+    0x0001000000ff0100, 0x000100000000ffff, 0x000100000000ff00, 0x00010000000000ff,
+    0x0001000000000000, 0x0001000000000001, 0x0001000000000100, 0x000100000001ff00,
+    0x00010000000100ff, 0x0001000000010000, 0x0001000000010001, 0x0001000000010100,
+    0x0001000001ff0001, 0x0001000001ff0100, 0x0001000001ff0101, 0x000100000100ff00,
+    0x0001000001000000, 0x0001000001000001, 0x0001000001000100, 0x0001000001000101,
+    0x000100000101ff01, 0x0001000001010000, 0x0001000001010001, 0x00010000010101ff,
+    0x00010001ffffff01, 0x00010001ffff0100, 0x00010001ff000000, 0x00010001ff01ffff,
+    0x00010001ff010001, 0x00010001ff0101ff, 0x00010001ff010100, 0x0001000100ffffff,
+    0x0001000100ff0000, 0x0001000100ff01ff, 0x0001000100ff0101, 0x000100010000ff00,
+    0x00010001000000ff, 0x0001000100000000, 0x0001000100000001, 0x00010001000001ff,
+    0x0001000100000101, 0x000100010001ffff, 0x0001000100010000, 0x00010001000101ff,
+    0x0001000101ffffff, 0x0001000101ffff01, 0x0001000101ff0000, 0x0001000101ff0101,
+    0x00010001010000ff, 0x0001000101000001, 0x00010001010001ff, 0x0001000101000100,
+    0x000100010101ffff, 0x00010001010100ff, 0x0001000101010001, 0x0001000101010101,
+    0x000101ffff000001, 0x000101ffff000100, 0x000101ffff010000, 0x000101ff00ffff00,
+    0x000101ff0000ff01, 0x000101ff00000000, 0x000101ff00000101, 0x000101ff0001ff00,
+    0x000101ff00010100, 0x000101ff01ff0000, 0x000101ff0100ff00, 0x000101ff010001ff,
+    0x000101ff01010001, 0x00010100ffffff00, 0x00010100ffff00ff, 0x00010100ff00ffff,
+    0x00010100ff000000, 0x00010100ff01ff00, 0x00010100ff0100ff, 0x00010100ff010001,
+    0x00010100ff010100, 0x0001010000ffffff, 0x0001010000ffff00, 0x0001010000ff0000,
+    0x0001010000ff0001, 0x0001010000ff01ff, 0x000101000000ff00, 0x00010100000000ff,
+    0x0001010000000000, 0x0001010000000001, 0x0001010000000100, 0x000101000001ffff,
+    0x0001010000010000, 0x0001010000010101, 0x0001010001ffff01, 0x0001010001ff00ff,
+    0x0001010001ff0101, 0x0001010001000000, 0x000101000101ff00, 0x00010100010100ff,
+    0x0001010001010000, 0x0001010001010100, 0x00010101ff00ff00, 0x00010101ff000001,
+    0x00010101ff0001ff, 0x0001010100ffff00, 0x0001010100ff00ff, 0x0001010100ff0100,
+    0x000101010000ffff, 0x0001010100000000, 0x00010101000001ff, 0x0001010100000101,
+    0x00010101000100ff, 0x0001010100010000, 0x0001010100010100, 0x0001010101ff0001,
+    0x00010101010000ff, 0x00010101010001ff, 0x0001010101000101, 0x0001010101010001,
+    0x01ffffffffffffff, 0x01ffffffffffff01, 0x01ffffffffff01ff, 0x01ffffffffff0101,
+    0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff, 0x01ffffffff010101,
+    0x01ffffff00ff0000, 0x01ffffff0000ffff, 0x01ffffff0000ff00, 0x01ffffff000000ff,
+    0x01ffffff00000001, 0x01ffffff00000100, 0x01ffffff00010000, 0x01ffffff01ffffff,
+    0x01ffffff01ffff01, 0x01ffffff01ff01ff, 0x01ffffff01ff0101, 0x01ffffff01000000,
+    0x01ffffff0101ffff, 0x01ffffff0101ff01, 0x01ffffff010101ff, 0x01ffffff01010101,
+    0x01ffff00ffff0000, 0x01ffff00ff00ff00, 0x01ffff00ff0000ff, 0x01ffff00ff000001,
+    0x01ffff00ff000100, 0x01ffff00ff010000, 0x01ffff0000ffff00, 0x01ffff0000ff00ff,
+    0x01ffff0000ff0100, 0x01ffff000000ffff, 0x01ffff000000ff01, 0x01ffff0000000000,
+    0x01ffff0000000001, 0x01ffff00000001ff, 0x01ffff0000000100, 0x01ffff00000100ff,
+    0x01ffff0000010001, 0x01ffff0000010100, 0x01ffff0001ff0000, 0x01ffff0001ff0100,
+    0x01ffff00010000ff, 0x01ffff0001000001, 0x01ffff0001000100, 0x01ffff0001010000,
+    0x01ffff01ffffffff, 0x01ffff01ffffff01, 0x01ffff01ffff01ff, 0x01ffff01ffff0101,
+    0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff01ff01, 0x01ffff01ff0101ff,
+    0x01ffff01ff010101, 0x01ffff010000ff00, 0x01ffff01000000ff, 0x01ffff0100000100,
+    0x01ffff0100010000, 0x01ffff0101ffffff, 0x01ffff0101ffff01, 0x01ffff0101ff01ff,
+    0x01ffff0101ff0101, 0x01ffff0101000000, 0x01ffff010101ffff, 0x01ffff010101ff01,
+    0x01ffff01010101ff, 0x01ffff0101010101, 0x01ff00ffff0000ff, 0x01ff00ffff000100,
+    0x01ff00ff00ffff00, 0x01ff00ff00ff00ff, 0x01ff00ff0000ff00, 0x01ff00ff00000000,
+    0x01ff00ff00000101, 0x01ff00ff0001ff00, 0x01ff00ff000100ff, 0x01ff00ff00010100,
+    0x01ff00ff010000ff, 0x01ff00ff01000100, 0x01ff0000ffffff00, 0x01ff0000ffff0100,
+    0x01ff0000ff00ff01, 0x01ff0000ff000000, 0x01ff0000ff000101, 0x01ff0000ff010001,
+    0x01ff0000ff010100, 0x01ff000000ffffff, 0x01ff000000ffff00, 0x01ff000000ff0000,
+    0x01ff000000ff01ff, 0x01ff00000000ff00, 0x01ff0000000000ff, 0x01ff000000000000,
+    0x01ff000000000001, 0x01ff000000000100, 0x01ff000000000101, 0x01ff000000010000,
+    0x01ff000000010001, 0x01ff0000000101ff, 0x01ff000000010101, 0x01ff000001ffff00,
+    0x01ff000001ff00ff, 0x01ff000001ff0001, 0x01ff000001ff0100, 0x01ff00000100ffff,
+    0x01ff00000100ff01, 0x01ff000001000000, 0x01ff0000010001ff, 0x01ff000001010001,
+    0x01ff0001ff00ff00, 0x01ff0001ff000001, 0x01ff0001ff000100, 0x01ff0001ff010000,
+    0x01ff000100ffff00, 0x01ff000100ff00ff, 0x01ff000100ff0100, 0x01ff000100ff0101,
+    0x01ff00010000ffff, 0x01ff000100000000, 0x01ff000100000100, 0x01ff000100000101,
+    0x01ff00010001ff00, 0x01ff000100010001, 0x01ff000100010101, 0x01ff000101ff0000,
+    0x01ff00010100ff00, 0x01ff000101000101, 0x01ff0001010100ff, 0x01ff01ffffffffff,
+    0x01ff01ffffffff01, 0x01ff01ffffff01ff, 0x01ff01ffffff0101, 0x01ff01ffff000000,
+    0x01ff01ffff01ffff, 0x01ff01ffff01ff01, 0x01ff01ffff0101ff, 0x01ff01ffff010101,
+    0x01ff01ff00ffff00, 0x01ff01ff00ff0000, 0x01ff01ff0000ff00, 0x01ff01ff000000ff,
+    0x01ff01ff00000100, 0x01ff01ff00010000, 0x01ff01ff00010100, 0x01ff01ff01ffffff,
+    0x01ff01ff01ffff01, 0x01ff01ff01ff01ff, 0x01ff01ff01ff0101, 0x01ff01ff01000000,
+    0x01ff01ff0101ffff, 0x01ff01ff0101ff01, 0x01ff01ff010101ff, 0x01ff01ff01010101,
+    0x01ff0100ffff0000, 0x01ff0100ffff0001, 0x01ff0100ff00ff00, 0x01ff0100ff0000ff,
+    0x01ff0100ff000001, 0x01ff0100ff010000, 0x01ff010000ffff00, 0x01ff010000ff00ff,
+    0x01ff010000ff0001, 0x01ff010000ff0100, 0x01ff01000000ffff, 0x01ff01000000ff01,
+    0x01ff010000000000, 0x01ff010000000101, 0x01ff01000001ff00, 0x01ff0100000100ff,
+    0x01ff010001ff0000, 0x01ff010001000001, 0x01ff010001000100, 0x01ff010001010000,
+    0x01ff0101ffffffff, 0x01ff0101ffffff01, 0x01ff0101ffff01ff, 0x01ff0101ffff0101,
+    0x01ff0101ff000000, 0x01ff0101ff01ffff, 0x01ff0101ff01ff01, 0x01ff0101ff0101ff,
+    0x01ff0101ff010101, 0x01ff010100ff0000, 0x01ff01010000ff00, 0x01ff0101000000ff,
+    0x01ff010100000001, 0x01ff010101ffffff, 0x01ff010101ffff01, 0x01ff010101ff01ff,
+    0x01ff010101ff0101, 0x01ff010101000000, 0x01ff01010101ffff, 0x01ff01010101ff01,
+    0x01ff0101010101ff, 0x01ff010101010101, 0x0100ffffffff0000, 0x0100ffffff00ff00,
+    0x0100ffffff000001, 0x0100ffffff0001ff, 0x0100ffffff000100, 0x0100ffffff010000,
+    0x0100ffff00ffff00, 0x0100ffff00ff0001, 0x0100ffff00ff0100, 0x0100ffff00000000,
+    0x0100ffff000001ff, 0x0100ffff00000101, 0x0100ffff00010100, 0x0100ffff00010101,
+    0x0100ffff01ff0000, 0x0100ffff0100ff00, 0x0100ffff010000ff, 0x0100ffff01000001,
+    0x0100ffff01000100, 0x0100ffff01010000, 0x0100ff00ffffff00, 0x0100ff00ffff00ff,
+    0x0100ff00ffff0001, 0x0100ff00ffff0100, 0x0100ff00ff00ffff, 0x0100ff00ff000000,
+    0x0100ff00ff0001ff, 0x0100ff00ff000101, 0x0100ff00ff01ff00, 0x0100ff00ff0100ff,
+    0x0100ff00ff010001, 0x0100ff00ff010100, 0x0100ff0000ffffff, 0x0100ff0000ff0000,
+    0x0100ff000000ffff, 0x0100ff000000ff00, 0x0100ff00000000ff, 0x0100ff0000000000,
+    0x0100ff0000000001, 0x0100ff0000000100, 0x0100ff000001ff01, 0x0100ff0000010000,
+    0x0100ff0001ff00ff, 0x0100ff0001ff0001, 0x0100ff000100ff01, 0x0100ff0001000000,
+    0x0100ff00010001ff, 0x0100ff000101ff00, 0x0100ff00010100ff, 0x0100ff0001010001,
+    0x0100ff0001010100, 0x0100ff01ffff0000, 0x0100ff01ff00ff00, 0x0100ff01ff0000ff,
+    0x0100ff01ff000100, 0x0100ff01ff010000, 0x0100ff0100ff00ff, 0x0100ff0100ff0001,
+    0x0100ff0100ff0100, 0x0100ff010000ffff, 0x0100ff010000ff01, 0x0100ff0100000000,
+    0x0100ff01000001ff, 0x0100ff0100010001, 0x0100ff0100010100, 0x0100ff0101ff0000,
+    0x0100ff01010000ff, 0x0100ff0101000001, 0x0100ff0101010100, 0x010000ffffffff00,
+    0x010000ffffff00ff, 0x010000ffffff0001, 0x010000ffff00ffff, 0x010000ffff000000,
+    0x010000ffff0001ff, 0x010000ffff010001, 0x010000ff00ffffff, 0x010000ff00ff0101,
+    0x010000ff0000ff00, 0x010000ff000000ff, 0x010000ff00000000, 0x010000ff00000001,
+    0x010000ff000001ff, 0x010000ff00000100, 0x010000ff0001ffff, 0x010000ff0001ff00,
+    0x010000ff0001ff01, 0x010000ff00010000, 0x010000ff01ff00ff, 0x010000ff01ff0001,
+    0x010000ff0100ff01, 0x010000ff010000ff, 0x010000ff01000000, 0x010000ff010001ff,
+    0x010000ff0101ff00, 0x010000ff01010100, 0x01000000ffffffff, 0x01000000ffff0000,
+    0x01000000ffff01ff, 0x01000000ffff0101, 0x01000000ff00ffff, 0x01000000ff00ff00,
+    0x01000000ff0000ff, 0x01000000ff000000, 0x01000000ff000001, 0x01000000ff000100,
+    0x01000000ff01ff00, 0x01000000ff010000, 0x01000000ff010100, 0x01000000ff010101,
+    0x0100000000ffff00, 0x0100000000ff00ff, 0x0100000000ff0000, 0x0100000000ff0001,
+    0x0100000000ff0100, 0x010000000000ffff, 0x010000000000ff00, 0x010000000000ff01,
+    0x01000000000000ff, 0x0100000000000000, 0x0100000000000001, 0x01000000000001ff,
+    0x0100000000000100, 0x0100000000000101, 0x010000000001ff00, 0x01000000000100ff,
+    0x0100000000010000, 0x0100000000010001, 0x0100000000010100, 0x0100000001ffff00,
+    0x0100000001ff0000, 0x0100000001ff01ff, 0x010000000100ff00, 0x010000000100ff01,
+    0x01000000010000ff, 0x0100000001000000, 0x0100000001000001, 0x0100000001000100,
+    0x0100000001000101, 0x010000000101ffff, 0x010000000101ff01, 0x0100000001010000,
+    0x01000000010101ff, 0x0100000001010101, 0x01000001ffffff00, 0x01000001ffff00ff,
+    0x01000001ff00ffff, 0x01000001ff000000, 0x01000001ff000100, 0x01000001ff01ffff,
+    0x01000001ff010001, 0x01000001ff010100, 0x0100000100ff0000, 0x0100000100ff01ff,
+    0x0100000100ff0100, 0x010000010000ff00, 0x010000010000ff01, 0x0100000100000000,
+    0x0100000100000001, 0x0100000100000100, 0x0100000100010000, 0x01000001000101ff,
+    0x0100000101ffff01, 0x0100000101ff00ff, 0x0100000101ff0100, 0x0100000101ff0101,
+    0x010000010100ff01, 0x01000001010000ff, 0x0100000101000000, 0x01000001010100ff,
+    0x0100000101010001, 0x0100000101010100, 0x010001ffffff0000, 0x010001ffff000001,
+    0x010001ffff000100, 0x010001ffff010000, 0x010001ff00ffff00, 0x010001ff00ff0001,
+    0x010001ff0000ffff, 0x010001ff0000ff01, 0x010001ff00000000, 0x010001ff00000001,
+    0x010001ff00000101, 0x010001ff000100ff, 0x010001ff00010000, 0x010001ff01ff0000,
+    0x010001ff0100ff00, 0x010001ff01000001, 0x010001ff01000100, 0x010001ff01010000,
+    0x01000100ffff00ff, 0x01000100ffff0001, 0x01000100ffff0100, 0x01000100ff00ffff,
+    0x01000100ff00ff01, 0x01000100ff000000, 0x01000100ff0001ff, 0x01000100ff000101,
+    0x01000100ff01ffff, 0x01000100ff01ff00, 0x01000100ff0100ff, 0x01000100ff010001,
+    0x0100010000ffffff, 0x0100010000ffff01, 0x0100010000ff0000, 0x0100010000ff01ff,
+    0x0100010000ff0101, 0x010001000000ff00, 0x01000100000000ff, 0x0100010000000000,
+    0x0100010000000001, 0x0100010000000100, 0x010001000001ff01, 0x0100010000010000,
+    0x0100010000010001, 0x0100010000010101, 0x0100010001ffff00, 0x0100010001ff00ff,
+    0x010001000100ffff, 0x010001000100ff01, 0x0100010001000000, 0x0100010001000101,
+    0x010001000101ff00, 0x0100010001010001, 0x01000101ffff0000, 0x01000101ff000000,
+    0x01000101ff010000, 0x0100010100ff00ff, 0x0100010100ff0001, 0x0100010100ff0100,
+    0x010001010000ffff, 0x0100010100000000, 0x01000101000001ff, 0x010001010001ff00,
+    0x0100010101ff0000, 0x010001010100ff00, 0x01000101010000ff, 0x0100010101000000,
+    0x0100010101000001, 0x0101ffffffffffff, 0x0101ffffffffff01, 0x0101ffffffff01ff,
+    0x0101ffffffff0101, 0x0101ffffff000000, 0x0101ffffff01ffff, 0x0101ffffff01ff01,
+    0x0101ffffff0101ff, 0x0101ffffff010101, 0x0101ffff00ff0000, 0x0101ffff0000ff00,
+    0x0101ffff000000ff, 0x0101ffff00000001, 0x0101ffff00000100, 0x0101ffff01ffffff,
+    0x0101ffff01ffff01, 0x0101ffff01ff01ff, 0x0101ffff01ff0101, 0x0101ffff01000000,
+    0x0101ffff0101ffff, 0x0101ffff0101ff01, 0x0101ffff010101ff, 0x0101ffff01010101,
+    0x0101ff00ffff0000, 0x0101ff00ffff0100, 0x0101ff00ff00ff00, 0x0101ff00ff0000ff,
+    0x0101ff00ff000001, 0x0101ff00ff000100, 0x0101ff00ff000101, 0x0101ff0000ff0001,
+    0x0101ff0000ff0100, 0x0101ff000000ff00, 0x0101ff0000000000, 0x0101ff00000001ff,
+    0x0101ff0000000101, 0x0101ff000001ff00, 0x0101ff00000100ff, 0x0101ff0001ff0000,
+    0x0101ff000100ffff, 0x0101ff000100ff01, 0x0101ff0001000001, 0x0101ff0001000100,
+    0x0101ff01ffffff01, 0x0101ff01ffff01ff, 0x0101ff01ffff0101, 0x0101ff01ff00ffff,
+    0x0101ff01ff000100, 0x0101ff01ff01ff01, 0x0101ff01ff0101ff, 0x0101ff01ff010101,
+    0x0101ff0100ff0000, 0x0101ff010000ff00, 0x0101ff0100000001, 0x0101ff0100000100,
+    0x0101ff0100010000, 0x0101ff0101ffffff, 0x0101ff0101ffff01, 0x0101ff0101ff01ff,
+    0x0101ff0101ff0101, 0x0101ff0101000000, 0x0101ff010101ffff, 0x0101ff010101ff01,
+    0x0101ff01010101ff, 0x0101ff0101010101, 0x010100ffff000100, 0x010100ffff010000,
+    0x010100ff00ffff00, 0x010100ff00ff00ff, 0x010100ff0000ffff, 0x010100ff000000ff,
+    0x010100ff00000000, 0x010100ff000001ff, 0x010100ff00000101, 0x010100ff0001ff00,
+    0x010100ff00010000, 0x010100ff00010001, 0x010100ff000101ff, 0x010100ff00010100,
+    0x010100ff01ff0000, 0x01010000ffff0001, 0x01010000ffff0100, 0x01010000ff00ffff,
+    0x01010000ff00ff01, 0x01010000ff000000, 0x01010000ff0001ff, 0x01010000ff010001,
+    0x01010000ff010100, 0x0101000000ffff01, 0x0101000000ff0000, 0x010100000000ff00,
+    0x01010000000000ff, 0x0101000000000000, 0x0101000000000001, 0x0101000000000100,
+    0x0101000000010000, 0x0101000000010101, 0x0101000001ffff00, 0x0101000001ff00ff,
+    0x0101000001ff0000, 0x0101000001ff0001, 0x0101000001ff0100, 0x010100000100ff01,
+    0x0101000001000000, 0x01010000010001ff, 0x01010001ffff0000, 0x01010001ff00ff00,
+    0x01010001ff000001, 0x01010001ff000101, 0x01010001ff01ff00, 0x01010001ff010000,
+    0x0101000100ff00ff, 0x0101000100ff0001, 0x0101000100ff0101, 0x010100010000ff01,
+    0x0101000100000000, 0x0101000100000001, 0x01010001000001ff, 0x010100010001ffff,
+    0x010100010001ff01, 0x0101000101ff0001, 0x010100010100ffff, 0x0101000101000000,
+    0x0101000101000001, 0x0101000101000100, 0x010100010101ff00, 0x01010001010100ff,
+    0x0101000101010001, 0x010101ffffffffff, 0x010101ffffffff01, 0x010101ffffff01ff,
+    0x010101ffffff0101, 0x010101ffff01ffff, 0x010101ffff01ff01, 0x010101ffff0101ff,
+    0x010101ffff010101, 0x010101ff0000ff00, 0x010101ff000000ff, 0x010101ff00000001,
+    0x010101ff00000100, 0x010101ff01ffffff, 0x010101ff01ffff01, 0x010101ff01ff01ff,
+    0x010101ff01ff0101, 0x010101ff01000000, 0x010101ff0101ffff, 0x010101ff0101ff01,
+    0x010101ff010101ff, 0x010101ff01010101, 0x01010100ffff0000, 0x01010100ff0000ff,
+    0x01010100ff000100, 0x01010100ff01ff00, 0x01010100ff010000, 0x0101010000ffff00,
+    0x010101000000ffff, 0x0101010000000000, 0x0101010000000101, 0x010101000001ff00,
+    0x0101010000010001, 0x0101010000010100, 0x010101000100ffff, 0x0101010001000001,
+    0x01010101ffffffff, 0x01010101ffffff01, 0x01010101ffff01ff, 0x01010101ffff0101,
+    0x01010101ff01ffff, 0x01010101ff01ff01, 0x01010101ff0101ff, 0x01010101ff010101,
+    0x010101010000ff00, 0x01010101000000ff, 0x0101010100000001, 0x0101010101ffffff,
+    0x0101010101ffff01, 0x0101010101ff01ff, 0x0101010101ff0101, 0x0101010101000000,
+    0x010101010101ffff, 0x010101010101ff01, 0x01010101010101ff, 0x0101010101010101,
+GGML_TABLE_END()
+#else
+GGML_TABLE_BEGIN(uint32_t, iq1s_grid_gpu, NGRID_IQ1S)
+    0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000,
+    0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101,
+    0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200,
+    0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212,
+    0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011,
+    0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111,
+    0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220,
+    0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022,
+    0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220,
+    0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101,
+    0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110,
+    0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111,
+    0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010,
+    0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210,
+    0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221,
+    0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021,
+    0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002,
+    0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101,
+    0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101,
+    0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211,
+    0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110,
+    0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022,
+    0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121,
+    0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220,
+    0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001,
+    0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101,
+    0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102,
+    0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012,
+    0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010,
+    0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111,
+    0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122,
+    0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222,
+    0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001,
+    0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102,
+    0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101,
+    0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000,
+    0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101,
+    0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112,
+    0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110,
+    0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211,
+    0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012,
+    0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111,
+    0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120,
+    0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122,
+    0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121,
+    0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221,
+    0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001,
+    0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101,
+    0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101,
+    0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011,
+    0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111,
+    0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011,
+    0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122,
+    0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121,
+    0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222,
+    0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101,
+    0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000,
+    0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200,
+    0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110,
+    0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112,
+    0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222,
+    0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021,
+    0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121,
+    0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201,
+    0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200,
+    0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101,
+    0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011,
+    0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010,
+    0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211,
+    0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121,
+    0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000,
+    0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202,
+    0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202,
+    0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211,
+    0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112,
+    0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020,
+    0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121,
+    0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222,
+    0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102,
+    0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100,
+    0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110,
+    0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011,
+    0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111,
+    0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110,
+    0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121,
+    0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222,
+    0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201,
+    0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102,
+    0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201,
+    0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012,
+    0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010,
+    0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010,
+    0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110,
+    0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011,
+    0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212,
+    0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021,
+    0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021,
+    0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021,
+    0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101,
+    0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101,
+    0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100,
+    0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010,
+    0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111,
+    0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010,
+    0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111,
+    0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120,
+    0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120,
+    0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101,
+    0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001,
+    0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201,
+    0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210,
+    0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211,
+    0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111,
+    0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112,
+    0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211,
+    0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010,
+    0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021,
+    0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122,
+    0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221,
+    0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102,
+    0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100,
+    0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101,
+    0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101,
+    0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101,
+    0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012,
+    0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110,
+    0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112,
+    0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210,
+    0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210,
+    0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210,
+    0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010,
+    0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110,
+    0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122,
+    0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020,
+    0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021,
+    0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022,
+    0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120,
+    0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222,
+    0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221,
+    0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001,
+    0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102,
+    0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201,
+    0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012,
+    0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111,
+    0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012,
+    0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110,
+    0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110,
+    0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121,
+    0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221,
+    0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220,
+    0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222,
+    0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000,
+    0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201,
+    0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012,
+    0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011,
+    0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212,
+    0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221,
+    0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121,
+    0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202,
+    0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202,
+    0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002,
+    0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101,
+    0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210,
+    0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112,
+    0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011,
+    0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011,
+    0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210,
+    0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020,
+    0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220,
+    0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222,
+    0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222,
+    0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001,
+    0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010,
+    0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111,
+    0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010,
+    0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110,
+    0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221,
+    0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122,
+    0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202,
+    0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100,
+    0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101,
+    0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112,
+    0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111,
+    0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211,
+    0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222,
+    0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221,
+    0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022,
+    0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101,
+    0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211,
+    0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111,
+    0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111,
+    0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010,
+    0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121,
+    0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222,
+    0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000,
+    0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202,
+    0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000,
+    0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202,
+    0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110,
+    0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110,
+    0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222,
+    0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120,
+    0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022,
+    0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101,
+    0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202,
+    0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110,
+    0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110,
+    0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111,
+    0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111,
+    0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120,
+    0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121,
+    0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001,
+    0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202,
+    0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001,
+    0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200,
+    0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011,
+    0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212,
+    0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012,
+    0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110,
+    0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012,
+    0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111,
+    0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020,
+    0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121,
+    0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222,
+    0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102,
+    0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102,
+    0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101,
+    0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212,
+    0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210,
+    0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111,
+    0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212,
+    0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221,
+    0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121,
+    0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002,
+    0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000,
+    0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202,
+    0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112,
+    0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111,
+    0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020,
+    0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221,
+    0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022,
+    0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100,
+    0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201,
+    0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112,
+    0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211,
+    0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012,
+    0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121,
+    0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020,
+    0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120,
+    0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200,
+    0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200,
+    0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110,
+    0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011,
+    0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222,
+    0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020,
+    0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222,
+GGML_TABLE_END()
+#endif
+
+#endif // GGML_COMMON_IMPL
+#endif // GGML_COMMON_IMPL

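A note on the two tables closing ggml-common.h above: under GGML_COMMON_IMPL_C, GGML_TABLE_BEGIN(type, name, size) expands to `static const type name[size] = {` and GGML_TABLE_END() to the closing `};`, so each table is just a static array. In the uint64_t iq1s_grid variant, every entry packs one eight-element IQ1_S grid point as signed bytes (0xff = -1, 0x00 = 0, 0x01 = +1); the uint32_t iq1s_grid_gpu variant in the `#else` branch appears to carry the same points as eight 4-bit fields holding 0, 1 or 2 (the same values offset by one). The sketch below is illustrative only and not part of the vendored sources; decode_iq1s_entry is a hypothetical helper, and the sample constant is copied from the table above.

    #include <stdint.h>
    #include <stdio.h>

    /* Under GGML_COMMON_IMPL_C the table macros reduce to a plain array. */
    #define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
    #define GGML_TABLE_END() };

    GGML_TABLE_BEGIN(uint64_t, demo_grid, 1)
        0x01ff000000000000,   /* one entry taken from iq1s_grid above */
    GGML_TABLE_END()

    /* Hypothetical helper: each byte is one signed grid value (-1, 0 or +1). */
    static void decode_iq1s_entry(uint64_t entry, int8_t out[8]) {
        for (int i = 0; i < 8; ++i) {
            out[i] = (int8_t)((entry >> (8 * i)) & 0xff);
        }
    }

    int main(void) {
        int8_t v[8];
        decode_iq1s_entry(demo_grid[0], v);
        for (int i = 0; i < 8; ++i) {
            printf("%d ", v[i]);   /* low byte first: 0 0 0 0 0 0 -1 1 */
        }
        printf("\n");
        return 0;
    }
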
+ 3147 - 0
llama/ggml-cuda.cu

@@ -0,0 +1,3147 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "ggml-cuda.h"
+#include "ggml.h"
+#include "ggml-backend-impl.h"
+
+#include "ggml-cuda/common.cuh"
+#include "ggml-cuda/acc.cuh"
+#include "ggml-cuda/arange.cuh"
+#include "ggml-cuda/argsort.cuh"
+#include "ggml-cuda/binbcast.cuh"
+#include "ggml-cuda/clamp.cuh"
+#include "ggml-cuda/concat.cuh"
+#include "ggml-cuda/conv-transpose-1d.cuh"
+#include "ggml-cuda/convert.cuh"
+#include "ggml-cuda/cpy.cuh"
+#include "ggml-cuda/cross-entropy-loss.cuh"
+#include "ggml-cuda/diagmask.cuh"
+#include "ggml-cuda/dmmv.cuh"
+#include "ggml-cuda/fattn.cuh"
+#include "ggml-cuda/getrows.cuh"
+#include "ggml-cuda/im2col.cuh"
+#include "ggml-cuda/mmq.cuh"
+#include "ggml-cuda/mmvq.cuh"
+#include "ggml-cuda/norm.cuh"
+#include "ggml-cuda/pad.cuh"
+#include "ggml-cuda/pool2d.cuh"
+#include "ggml-cuda/quantize.cuh"
+#include "ggml-cuda/rope.cuh"
+#include "ggml-cuda/scale.cuh"
+#include "ggml-cuda/softmax.cuh"
+#include "ggml-cuda/sumrows.cuh"
+#include "ggml-cuda/tsembd.cuh"
+#include "ggml-cuda/unary.cuh"
+#include "ggml-cuda/upscale.cuh"
+
+#include <algorithm>
+#include <array>
+#include <atomic>
+#include <cinttypes>
+#include <cstddef>
+#include <cstdint>
+#include <float.h>
+#include <limits>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdarg.h>
+#include <stdlib.h>
+#include <string>
+#include <vector>
+
+static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
+
+static void ggml_cuda_default_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
+    GGML_UNUSED(level);
+    GGML_UNUSED(user_data);
+    fprintf(stderr, "%s", msg);
+}
+
+ggml_log_callback ggml_cuda_log_callback = ggml_cuda_default_log_callback;
+void * ggml_cuda_log_user_data = NULL;
+
+GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback, void * user_data) {
+    ggml_cuda_log_callback = log_callback;
+    ggml_cuda_log_user_data = user_data;
+}
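
Since ggml_backend_cuda_log_set_callback is part of the public API vendored here, a minimal usage sketch may help (illustrative only; my_cuda_log is a hypothetical name, not from the sources):

    #include <stdio.h>
    #include "ggml.h"       // enum ggml_log_level, ggml_log_callback
    #include "ggml-cuda.h"  // ggml_backend_cuda_log_set_callback

    // Forward CUDA backend log messages to stdout instead of the
    // default callback's stderr.
    static void my_cuda_log(enum ggml_log_level level, const char * msg, void * user_data) {
        (void) level;
        (void) user_data;
        fputs(msg, stdout);
    }

    // Call once during startup:
    //   ggml_backend_cuda_log_set_callback(my_cuda_log, NULL);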
+
+#define GGML_CUDA_LOG_INFO(...) ggml_cuda_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
+#define GGML_CUDA_LOG_WARN(...) ggml_cuda_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
+#define GGML_CUDA_LOG_ERROR(...) ggml_cuda_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+
+GGML_ATTRIBUTE_FORMAT(2, 3)
+static void ggml_cuda_log(enum ggml_log_level level, const char * format, ...) {
+    if (ggml_cuda_log_callback != NULL) {
+        va_list args;
+        va_start(args, format);
+        char buffer[128];
+        int len = vsnprintf(buffer, 128, format, args);
+        if (len < 128) {
+            ggml_cuda_log_callback(level, buffer, ggml_cuda_log_user_data);
+        } else {
+            std::vector<char> buffer2(len + 1);  // vsnprintf adds a null terminator
+            va_end(args);
+            va_start(args, format);
+            vsnprintf(&buffer2[0], buffer2.size(), format, args);
+            ggml_cuda_log_callback(level, buffer2.data(), ggml_cuda_log_user_data);
+        }
+        va_end(args);
+    }
+}
+
+[[noreturn]]
+void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
+    int id = -1; // in case cudaGetDevice fails
+    cudaGetDevice(&id);
+
+    GGML_CUDA_LOG_ERROR("CUDA error: %s\n", msg);
+    GGML_CUDA_LOG_ERROR("  current device: %d, in function %s at %s:%d\n", id, func, file, line);
+    GGML_CUDA_LOG_ERROR("  %s\n", stmt);
+    // abort with GGML_ASSERT to get a stack trace
+    GGML_ABORT("CUDA error");
+}
+
+// this is faster on Windows
+// probably because the Windows CUDA libraries forget to make this check before invoking the drivers
+void ggml_cuda_set_device(int device) {
+    int current_device;
+    CUDA_CHECK(cudaGetDevice(&current_device));
+
+    if (device == current_device) {
+        return;
+    }
+
+    CUDA_CHECK(cudaSetDevice(device));
+}
+
+int ggml_cuda_get_device() {
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+    return id;
+}
+
+static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
+    ggml_cuda_set_device(device);
+#if defined(GGML_USE_HIPBLAS) && defined(GGML_HIP_UMA)
+    auto res = hipMallocManaged(ptr, size);
+    if (res == hipSuccess) {
+        // if error we "need" to know why...
+        CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
+    }
+    return res;
+#else
+
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+    cudaError_t err;
+    if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
+    {
+        err = cudaMallocManaged(ptr, size);
+    }
+    else
+    {
+        err = cudaMalloc(ptr, size);
+    }
+    return err;
+#else
+    return cudaMalloc(ptr, size);
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+
+#endif
+}
+
+static ggml_cuda_device_info ggml_cuda_init() {
+#ifdef __HIP_PLATFORM_AMD__
+    // Workaround for a rocBLAS bug when using multiple graphics cards:
+    // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
+    rocblas_initialize();
+    CUDA_CHECK(cudaDeviceSynchronize());
+#endif
+
+    ggml_cuda_device_info info = {};
+
+    cudaError_t err = cudaGetDeviceCount(&info.device_count);
+    if (err != cudaSuccess) {
+        GGML_CUDA_LOG_ERROR("%s: failed to initialize " GGML_CUDA_NAME ": %s\n", __func__, cudaGetErrorString(err));
+        return info;
+    }
+
+    GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
+
+    int64_t total_vram = 0;
+#ifdef GGML_CUDA_FORCE_MMQ
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:    yes\n", __func__);
+#else
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:    no\n", __func__);
+#endif // GGML_CUDA_FORCE_MMQ
+#ifdef GGML_CUDA_FORCE_CUBLAS
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__);
+#else
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
+#endif // GGML_CUDA_FORCE_CUBLAS
+    GGML_CUDA_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
+    for (int id = 0; id < info.device_count; ++id) {
+        int device_vmm = 0;
+
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
+        CUdevice device;
+        CU_CHECK(cuDeviceGet(&device, id));
+        CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
+
+        if (device_vmm) {
+            CUmemAllocationProp alloc_prop = {};
+            alloc_prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+            alloc_prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+            alloc_prop.location.id = id;
+            CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
+        }
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
+        info.devices[id].vmm = !!device_vmm;
+
+        cudaDeviceProp prop;
+        CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
+        GGML_CUDA_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
+
+        info.default_tensor_split[id] = total_vram;
+        total_vram += prop.totalGlobalMem;
+
+        info.devices[id].nsm   = prop.multiProcessorCount;
+        info.devices[id].smpb  = prop.sharedMemPerBlock;
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+        info.devices[id].smpbo = prop.sharedMemPerBlock;
+        info.devices[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
+#else
+        info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
+        info.devices[id].cc = 100*prop.major + 10*prop.minor;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    }
+
+    for (int id = 0; id < info.device_count; ++id) {
+        info.default_tensor_split[id] /= total_vram;
+    }
+
+    // configure logging to stdout
+    // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
+
+    return info;
+}
+
+const ggml_cuda_device_info & ggml_cuda_info() {
+    static ggml_cuda_device_info info = ggml_cuda_init();
+    return info;
+}
+
+// #define DEBUG_CUDA_MALLOC
+
+// buffer pool for cuda (legacy)
+struct ggml_cuda_pool_leg : public ggml_cuda_pool {
+    static const int MAX_BUFFERS = 256;
+
+    int device;
+    struct ggml_cuda_buffer {
+        void * ptr = nullptr;
+        size_t size = 0;
+    };
+
+    ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {};
+    size_t pool_size = 0;
+
+    explicit ggml_cuda_pool_leg(int device) :
+        device(device) {
+    }
+
+    ~ggml_cuda_pool_leg() {
+        ggml_cuda_set_device(device);
+        for (int i = 0; i < MAX_BUFFERS; ++i) {
+            ggml_cuda_buffer & b = buffer_pool[i];
+            if (b.ptr != nullptr) {
+                CUDA_CHECK(cudaFree(b.ptr));
+                pool_size -= b.size;
+            }
+        }
+        GGML_ASSERT(pool_size == 0);
+    }
+
+    void * alloc(size_t size, size_t * actual_size) override {
+#ifdef DEBUG_CUDA_MALLOC
+        int nnz = 0;
+        size_t max_size = 0;
+#endif
+        size_t best_diff = 1ull << 36;
+        int ibest = -1;
+        for (int i = 0; i < MAX_BUFFERS; ++i) {
+            ggml_cuda_buffer& b = buffer_pool[i];
+            if (b.ptr != nullptr) {
+#ifdef DEBUG_CUDA_MALLOC
+                ++nnz;
+                if (b.size > max_size) max_size = b.size;
+#endif
+                if (b.size >= size) {
+                    size_t diff = b.size - size;
+                    if (diff < best_diff) {
+                        best_diff = diff;
+                        ibest = i;
+                        if (!best_diff) {
+                            void * ptr = b.ptr;
+                            *actual_size = b.size;
+                            b.ptr = nullptr;
+                            b.size = 0;
+                            return ptr;
+                        }
+                    }
+                }
+            }
+        }
+        if (ibest >= 0) {
+            ggml_cuda_buffer& b = buffer_pool[ibest];
+            void * ptr = b.ptr;
+            *actual_size = b.size;
+            b.ptr = nullptr;
+            b.size = 0;
+            return ptr;
+        }
+        void * ptr;
+        size_t look_ahead_size = (size_t) (1.05 * size);
+        look_ahead_size = 256 * ((look_ahead_size + 255)/256);
+        ggml_cuda_set_device(device);
+        CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));
+        *actual_size = look_ahead_size;
+        pool_size += look_ahead_size;
+#ifdef DEBUG_CUDA_MALLOC
+        GGML_CUDA_LOG_INFO("%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, device, nnz,
+                           (uint32_t)(max_size / 1024 / 1024), (uint32_t)(pool_size / 1024 / 1024), (uint32_t)(size / 1024 / 1024));
+#endif
+        return ptr;
+    }
+
+    void free(void * ptr, size_t size) override {
+        for (int i = 0; i < MAX_BUFFERS; ++i) {
+            ggml_cuda_buffer& b = buffer_pool[i];
+            if (b.ptr == nullptr) {
+                b.ptr = ptr;
+                b.size = size;
+                return;
+            }
+        }
+        GGML_CUDA_LOG_WARN("Cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
+        ggml_cuda_set_device(device);
+        CUDA_CHECK(cudaFree(ptr));
+        pool_size -= size;
+    }
+};
+
+// pool with virtual memory
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
+struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
+    static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
+
+    int device;
+    CUdeviceptr pool_addr = 0;
+    size_t pool_used = 0;
+    size_t pool_size = 0;
+    size_t granularity;
+
+    explicit ggml_cuda_pool_vmm(int device) :
+        device(device),
+        granularity(ggml_cuda_info().devices[device].vmm_granularity) {
+    }
+
+    ~ggml_cuda_pool_vmm() {
+        if (pool_addr != 0) {
+            CU_CHECK(cuMemUnmap(pool_addr, pool_size));
+            CU_CHECK(cuMemAddressFree(pool_addr, CUDA_POOL_VMM_MAX_SIZE));
+        }
+    }
+
+    void * alloc(size_t size, size_t * actual_size) override {
+        // round up the allocation size to the alignment to ensure that all allocations are aligned for all data types
+        const size_t alignment = 128;
+        size = alignment * ((size + alignment - 1) / alignment);
+
+        size_t avail = pool_size - pool_used;
+
+        if (size > avail) {
+            // round up to the next multiple of the granularity
+            size_t reserve_size = size - avail;
+            reserve_size = granularity * ((reserve_size + granularity - 1) / granularity);
+
+            GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
+
+            // allocate more physical memory
+            CUmemAllocationProp prop = {};
+            prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+            prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+            prop.location.id = device;
+            CUmemGenericAllocationHandle handle;
+            CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));
+
+            // reserve virtual address space (if not already reserved)
+            if (pool_addr == 0) {
+                CU_CHECK(cuMemAddressReserve(&pool_addr, CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
+            }
+
+            // map at the end of the pool
+            CU_CHECK(cuMemMap(pool_addr + pool_size, reserve_size, 0, handle, 0));
+
+            // the memory allocation handle is no longer needed after mapping
+            CU_CHECK(cuMemRelease(handle));
+
+            // set access
+            CUmemAccessDesc access = {};
+            access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+            access.location.id = device;
+            access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+            CU_CHECK(cuMemSetAccess(pool_addr + pool_size, reserve_size, &access, 1));
+
+            // add to the pool
+            pool_size += reserve_size;
+
+            //printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
+            //       device, (unsigned long long) (pool_size/1024/1024),
+            //       (unsigned long long) (reserve_size/1024/1024));
+        }
+
+        GGML_ASSERT(pool_addr != 0);
+
+        void * ptr = (void *) (pool_addr + pool_used);
+        *actual_size = size;
+        pool_used += size;
+
+#ifdef DEBUG_CUDA_MALLOC
+        printf("cuda pool[%d]: allocated %llu bytes at %llx\n", device, (unsigned long long) size, ptr);
+#endif
+
+        return ptr;
+    }
+
+    void free(void * ptr, size_t size) override {
+#ifdef DEBUG_CUDA_MALLOC
+        printf("cuda pool[%d]: freed %llu bytes at %llx\n", device, (unsigned long long) size, ptr);
+#endif
+
+        pool_used -= size;
+
+        // all deallocations must be in reverse order of the allocations
+        GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
+    }
+};
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
+
+std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
+    if (ggml_cuda_info().devices[device].vmm) {
+        return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
+    }
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
+    return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
+}
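
One behavioral difference between the two pools above is worth calling out: ggml_cuda_pool_leg accepts frees in any order (it simply parks the pointer back into buffer_pool for reuse), while ggml_cuda_pool_vmm is a bump allocator whose free() only rewinds pool_used and asserts that the pointer being released is the most recent live allocation, so callers must release in strict LIFO order. A minimal sketch of a conforming call pattern, assuming a pool obtained from new_pool_for_device (illustrative only, not part of the vendored sources):

    #include "ggml-cuda/common.cuh"  // declares the ggml_cuda_pool interface

    // Illustrative only: the VMM pool aborts via GGML_ASSERT if frees are
    // not issued in reverse order of the matching allocations.
    static void pool_usage_sketch(ggml_cuda_pool & pool) {
        size_t actual_a = 0, actual_b = 0;
        void * a = pool.alloc(1 << 20, &actual_a);  // first allocation
        void * b = pool.alloc(1 << 16, &actual_b);  // allocated after a
        pool.free(b, actual_b);                     // must be freed first
        pool.free(a, actual_a);                     // then the older one
    }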
+
+// cuda buffer
+
+struct ggml_backend_cuda_buffer_context {
+    int device;
+    void * dev_ptr = nullptr;
+    std::string name;
+
+    ggml_backend_cuda_buffer_context(int device, void * dev_ptr) :
+        device(device), dev_ptr(dev_ptr),
+        name(GGML_CUDA_NAME + std::to_string(device)) {
+    }
+
+    ~ggml_backend_cuda_buffer_context() {
+        CUDA_CHECK(cudaFree(dev_ptr));
+    }
+};
+
+GGML_CALL static const char * ggml_backend_cuda_buffer_get_name(ggml_backend_buffer_t buffer) {
+    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+    return ctx->name.c_str();
+}
+
+GGML_CALL static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) {
+    return buffer->iface.get_name == ggml_backend_cuda_buffer_get_name;
+}
+
+GGML_CALL static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+    delete ctx;
+
+    // TODO: this needs to be freed in cuda and hipblas backends because
+    // the cuda backend implementation compiled with msvc
+    free(buffer);
+}
+
+GGML_CALL static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) {
+    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+    return ctx->dev_ptr;
+}
+
+GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+    if (tensor->view_src != NULL) {
+        assert(tensor->view_src->buffer->buft == buffer->buft);
+        return;
+    }
+
+    if (ggml_is_quantized(tensor->type) && tensor->view_src == nullptr && ggml_backend_buffer_get_usage(buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
+        // initialize padding to 0 to avoid possible NaN values
+        size_t original_size = ggml_nbytes(tensor);
+        size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
+
+        if (padded_size > original_size) {
+            ggml_cuda_set_device(ctx->device);
+            CUDA_CHECK(cudaMemset((char *)tensor->data + original_size, 0, padded_size - original_size));
+        }
+    }
+}
+
+GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+    ggml_cuda_set_device(ctx->device);
+    CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
+    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+}
+
+GGML_CALL static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+    ggml_cuda_set_device(ctx->device);
+    CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
+    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+}
+
+GGML_CALL static bool ggml_backend_cuda_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
+    if (ggml_backend_buffer_is_cuda(src->buffer)) {
+        ggml_backend_cuda_buffer_context * src_ctx = (ggml_backend_cuda_buffer_context *)src->buffer->context;
+        ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *)dst->buffer->context;
+        if (src_ctx->device == dst_ctx->device) {
+            CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(src), cudaMemcpyDeviceToDevice, cudaStreamPerThread));
+        } else {
+#ifdef GGML_CUDA_NO_PEER_COPY
+            return false;
+#else
+            CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, dst_ctx->device, src->data, src_ctx->device, ggml_nbytes(src), cudaStreamPerThread));
+#endif
+        }
+        CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+        return true;
+    }
+    return false;
+
+    GGML_UNUSED(buffer);
+}
+
+GGML_CALL static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+    ggml_cuda_set_device(ctx->device);
+    CUDA_CHECK(cudaDeviceSynchronize());
+    CUDA_CHECK(cudaMemset(ctx->dev_ptr, value, buffer->size));
+    CUDA_CHECK(cudaDeviceSynchronize());
+}
+
+static ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
+    /* .get_name        = */ ggml_backend_cuda_buffer_get_name,
+    /* .free_buffer     = */ ggml_backend_cuda_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_cuda_buffer_get_base,
+    /* .init_tensor     = */ ggml_backend_cuda_buffer_init_tensor,
+    /* .set_tensor      = */ ggml_backend_cuda_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_cuda_buffer_get_tensor,
+    /* .cpy_tensor      = */ ggml_backend_cuda_buffer_cpy_tensor,
+    /* .clear           = */ ggml_backend_cuda_buffer_clear,
+    /* .reset           = */ NULL,
+};
+
+// cuda buffer type
+struct ggml_backend_cuda_buffer_type_context {
+    int device;
+    std::string name;
+};
+
+GGML_CALL static const char * ggml_backend_cuda_buffer_type_name(ggml_backend_buffer_type_t buft) {
+    ggml_backend_cuda_buffer_type_context * ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
+
+    return ctx->name.c_str();
+}
+
+static bool ggml_backend_buft_is_cuda(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_cuda_buffer_type_name;
+}
+
+GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
+
+    ggml_cuda_set_device(buft_ctx->device);
+
+    size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
+
+    void * dev_ptr;
+    cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
+    if (err != cudaSuccess) {
+        // clear the error
+        cudaGetLastError();
+        GGML_CUDA_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err));
+        return nullptr;
+    }
+
+    ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr);
+
+    return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size);
+}
+
+GGML_CALL static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    return 128;
+
+    GGML_UNUSED(buft);
+}
+
+GGML_CALL static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+    size_t size = ggml_nbytes(tensor);
+    int64_t ne0 = tensor->ne[0];
+
+    if (ggml_is_quantized(tensor->type)) {
+        if (ne0 % MATRIX_ROW_PADDING != 0) {
+            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+        }
+    }
+
+    return size;
+
+    GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {
+    /* .get_name         = */ ggml_backend_cuda_buffer_type_name,
+    /* .alloc_buffer     = */ ggml_backend_cuda_buffer_type_alloc_buffer,
+    /* .get_alignment    = */ ggml_backend_cuda_buffer_type_get_alignment,
+    /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
+    /* .get_alloc_size   = */ ggml_backend_cuda_buffer_type_get_alloc_size,
+    /* .is_host          = */ NULL,
+};
+
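+// Returns a lazily initialized, process-lifetime buffer type object for the
+// given device. The static array is initialized once under a mutex, so
+// concurrent callers are safe; the context objects are never freed by design.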
+GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+
+    if (device >= ggml_backend_cuda_get_device_count()) {
+        return nullptr;
+    }
+
+    static ggml_backend_buffer_type ggml_backend_cuda_buffer_types[GGML_CUDA_MAX_DEVICES];
+
+    static bool ggml_backend_cuda_buffer_type_initialized = false;
+
+    if (!ggml_backend_cuda_buffer_type_initialized) {
+        for (int i = 0; i < GGML_CUDA_MAX_DEVICES; i++) {
+            ggml_backend_cuda_buffer_types[i] = {
+                /* .iface    = */ ggml_backend_cuda_buffer_type_interface,
+                /* .context  = */ new ggml_backend_cuda_buffer_type_context{i, GGML_CUDA_NAME + std::to_string(i)},
+            };
+        }
+        ggml_backend_cuda_buffer_type_initialized = true;
+    }
+
+    return &ggml_backend_cuda_buffer_types[device];
+}
+
+// cuda split buffer
+
+static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
+    int64_t row_rounding = 0;
+    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
+            continue;
+        }
+
+        const int cc = ggml_cuda_info().devices[id].cc;
+        row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc));
+    }
+    return row_rounding;
+}
+
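+// Computes the [row_low, row_high) slice of a tensor assigned to device id.
+// tensor_split holds cumulative start fractions per device; slice boundaries
+// are rounded down to the tile height returned by get_row_rounding so they
+// stay compatible with the MMQ kernels (the last device always ends at nrows).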
+static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split, int id) {
+    const int64_t nrows = ggml_nrows(tensor);
+    const int64_t rounding = get_row_rounding(tensor_split);
+
+    *row_low = id == 0 ? 0 : nrows*tensor_split[id];
+    *row_low -= *row_low % rounding;
+
+    if (id == ggml_backend_cuda_get_device_count() - 1) {
+        *row_high = nrows;
+    } else {
+        *row_high = nrows*tensor_split[id + 1];
+        *row_high -= *row_high % rounding;
+    }
+}
+
+static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]);
+}
+
+struct ggml_backend_cuda_split_buffer_type_context {
+    std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split;
+};
+
+struct ggml_backend_cuda_split_buffer_context {
+    ~ggml_backend_cuda_split_buffer_context() {
+        for (ggml_tensor_extra_gpu * extra : tensor_extras) {
+            for (int id = 0; id < GGML_CUDA_MAX_DEVICES; ++id) {
+                for (int64_t is = 0; is < GGML_CUDA_MAX_STREAMS; ++is) {
+                    if (extra->events[id][is] != nullptr) {
+                        CUDA_CHECK(cudaEventDestroy(extra->events[id][is]));
+                    }
+                }
+                if (extra->data_device[id] != nullptr) {
+                    CUDA_CHECK(cudaFree(extra->data_device[id]));
+                }
+            }
+            delete extra;
+        }
+    }
+
+    std::vector<ggml_tensor_extra_gpu *> tensor_extras;
+};
+
+GGML_CALL static const char * ggml_backend_cuda_split_buffer_get_name(ggml_backend_buffer_t buffer) {
+    return GGML_CUDA_NAME "_Split";
+
+    GGML_UNUSED(buffer);
+}
+
+static bool ggml_backend_buffer_is_cuda_split(ggml_backend_buffer_t buffer) {
+    return buffer->iface.get_name == ggml_backend_cuda_split_buffer_get_name;
+    GGML_UNUSED(ggml_backend_buffer_is_cuda_split); // only used in debug builds currently, avoid unused function warning in release builds
+}
+
+GGML_CALL static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
+    delete ctx;
+}
+
+GGML_CALL static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buffer) {
+    // the pointers are stored in the tensor extras; this is just a dummy address that is never dereferenced
+    return (void *)0x1000;
+
+    GGML_UNUSED(buffer);
+}
+
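+// Allocates the per-device slices of a split tensor: each device receives
+// only its rows (padded to MATRIX_ROW_PADDING for quantized access), and
+// per-stream events are created for cross-device synchronization later on.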
+GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
+
+    ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
+    ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
+
+    const int64_t ne0 = tensor->ne[0];
+
+    ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
+    ctx->tensor_extras.push_back(extra);
+
+    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        int64_t row_low, row_high;
+        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
+
+        int64_t nrows_split = row_high - row_low;
+        if (nrows_split == 0) {
+            continue;
+        }
+
+        size_t size = ggml_nbytes_split(tensor, nrows_split);
+        const size_t original_size = size;
+
+        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+        if (ne0 % MATRIX_ROW_PADDING != 0) {
+            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+        }
+
+        // FIXME: do not crash if cudaMalloc fails
+        // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
+        ggml_cuda_set_device(id);
+        char * buf;
+        CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id));
+
+        // set padding to 0 to avoid possible NaN values
+        if (size > original_size) {
+            CUDA_CHECK(cudaMemset(buf + original_size, 0, size - original_size));
+        }
+
+        extra->data_device[id] = buf;
+
+        for (int64_t is = 0; is < GGML_CUDA_MAX_STREAMS; ++is) {
+            CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id][is], cudaEventDisableTiming));
+        }
+    }
+    tensor->extra = extra;
+}
+
+GGML_CALL static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    // split tensors must always be set in their entirety at once
+    GGML_ASSERT(offset == 0);
+    GGML_ASSERT(size == ggml_nbytes(tensor));
+
+    ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
+
+    const int64_t ne0 = tensor->ne[0];
+    const size_t nb1 = tensor->nb[1];
+    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
+
+    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        int64_t row_low, row_high;
+        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
+
+        int64_t nrows_split = row_high - row_low;
+        if (nrows_split == 0) {
+            continue;
+        }
+
+        const size_t offset_split = row_low*nb1;
+        size_t size = ggml_nbytes_split(tensor, nrows_split);
+        const size_t original_size = size;
+
+        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+        if (ne0 % MATRIX_ROW_PADDING != 0) {
+            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+        }
+
+        const char * buf_host = (const char *)data + offset_split;
+        CUDA_CHECK(cudaMemcpyAsync(extra->data_device[id], buf_host, original_size, cudaMemcpyHostToDevice, cudaStreamPerThread));
+    }
+
+    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+    }
+}
+
+GGML_CALL static void ggml_backend_cuda_split_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    // split tensors must always be read in their entirety at once
+    GGML_ASSERT(offset == 0);
+    GGML_ASSERT(size == ggml_nbytes(tensor));
+
+    ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
+
+    const int64_t ne0 = tensor->ne[0];
+    const size_t nb1 = tensor->nb[1];
+    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
+
+    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        int64_t row_low, row_high;
+        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
+
+        int64_t nrows_split = row_high - row_low;
+        if (nrows_split == 0) {
+            continue;
+        }
+
+        const size_t offset_split = row_low*nb1;
+        size_t size = ggml_nbytes_split(tensor, nrows_split);
+        const size_t original_size = size;
+
+        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+        if (ne0 % MATRIX_ROW_PADDING != 0) {
+            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+        }
+
+        char * buf_host = (char *)data + offset_split;
+        CUDA_CHECK(cudaMemcpyAsync(buf_host, extra->data_device[id], original_size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
+    }
+
+    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+    }
+}
+
+GGML_CALL static void ggml_backend_cuda_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    GGML_UNUSED(buffer);
+    GGML_UNUSED(value);
+}
+
+static struct ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = {
+    /* .get_name        = */ ggml_backend_cuda_split_buffer_get_name,
+    /* .free_buffer     = */ ggml_backend_cuda_split_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_cuda_split_buffer_get_base,
+    /* .init_tensor     = */ ggml_backend_cuda_split_buffer_init_tensor,
+    /* .set_tensor      = */ ggml_backend_cuda_split_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_cuda_split_buffer_get_tensor,
+    /* .cpy_tensor      = */ NULL,
+    /* .clear           = */ ggml_backend_cuda_split_buffer_clear,
+    /* .reset           = */ NULL,
+};
+
+// cuda split buffer type
+
+GGML_CALL static const char * ggml_backend_cuda_split_buffer_type_name(ggml_backend_buffer_type_t buft) {
+    return GGML_CUDA_NAME "_Split";
+
+    GGML_UNUSED(buft);
+}
+
+static bool ggml_backend_buft_is_cuda_split(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_cuda_split_buffer_type_name;
+}
+
+GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point
+    // instead, we allocate them for each tensor separately in init_tensor
+    // however, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated,
+    // as returned by get_alloc_size. this limit is enforced during tensor allocation by ggml-alloc, so it must be correct.
+    ggml_backend_cuda_split_buffer_context * ctx = new ggml_backend_cuda_split_buffer_context();
+
+    return ggml_backend_buffer_init(buft, ggml_backend_cuda_split_buffer_interface, ctx, size);
+}
+
+GGML_CALL static size_t ggml_backend_cuda_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    return 128;
+
+    GGML_UNUSED(buft);
+}
+
+GGML_CALL static size_t ggml_backend_cuda_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+    ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context;
+
+    size_t total_size = 0;
+
+    const int64_t ne0 = tensor->ne[0];
+
+    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        int64_t row_low, row_high;
+        get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, id);
+
+        int64_t nrows_split = row_high - row_low;
+        if (nrows_split == 0) {
+            continue;
+        }
+
+        total_size += ggml_nbytes_split(tensor, nrows_split);
+
+        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+        if (ne0 % MATRIX_ROW_PADDING != 0) {
+            total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+        }
+    }
+
+    return total_size;
+}
+
+GGML_CALL static bool ggml_backend_cuda_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+    return false;
+
+    GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_interface = {
+    /* .get_name         = */ ggml_backend_cuda_split_buffer_type_name,
+    /* .alloc_buffer     = */ ggml_backend_cuda_split_buffer_type_alloc_buffer,
+    /* .get_alignment    = */ ggml_backend_cuda_split_buffer_type_get_alignment,
+    /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
+    /* .get_alloc_size   = */ ggml_backend_cuda_split_buffer_type_get_alloc_size,
+    /* .is_host          = */ ggml_backend_cuda_split_buffer_type_is_host,
+};
+
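+// Converts a user-supplied per-device proportion array into cumulative start
+// fractions and caches one buffer type per distinct split. For example, with
+// two devices a split of {3, 1} becomes {0.0, 0.75}: device 0 owns row
+// fractions [0.0, 0.75) and device 1 owns [0.75, 1.0).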
+GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split) {
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+
+    static std::map<std::array<float, GGML_CUDA_MAX_DEVICES>, struct ggml_backend_buffer_type> buft_map;
+
+    std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split_arr = {};
+
+    bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_CUDA_MAX_DEVICES, [](float x) { return x == 0.0f; });
+    if (all_zero) {
+        tensor_split_arr = ggml_cuda_info().default_tensor_split;
+    } else {
+        float split_sum = 0.0f;
+        for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {
+            tensor_split_arr[i] = split_sum;
+            split_sum += tensor_split[i];
+        }
+        for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {
+            tensor_split_arr[i] /= split_sum;
+        }
+    }
+
+    auto it = buft_map.find(tensor_split_arr);
+    if (it != buft_map.end()) {
+        return &it->second;
+    }
+
+    struct ggml_backend_buffer_type buft {
+        /* .iface   = */ ggml_backend_cuda_split_buffer_type_interface,
+        /* .context = */ new ggml_backend_cuda_split_buffer_type_context{tensor_split_arr},
+    };
+
+    auto result = buft_map.emplace(tensor_split_arr, buft);
+    return &result.first->second;
+}
+
+// host buffer type
+
+GGML_CALL static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
+    return GGML_CUDA_NAME "_Host";
+
+    GGML_UNUSED(buft);
+}
+
+GGML_CALL static const char * ggml_backend_cuda_host_buffer_name(ggml_backend_buffer_t buffer) {
+    return GGML_CUDA_NAME "_Host";
+
+    GGML_UNUSED(buffer);
+}
+
+GGML_CALL static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    CUDA_CHECK(cudaFreeHost(buffer->context));
+}
+
+static void * ggml_cuda_host_malloc(size_t size) {
+    if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
+        return nullptr;
+    }
+
+    void * ptr = nullptr;
+    cudaError_t err = cudaMallocHost((void **) &ptr, size);
+    if (err != cudaSuccess) {
+        // clear the error
+        cudaGetLastError();
+        GGML_CUDA_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
+                           size / 1024.0 / 1024.0, cudaGetErrorString(err));
+        return nullptr;
+    }
+
+    return ptr;
+}
+
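+// Pinned (page-locked) host memory allows faster host <-> device transfers.
+// If pinned allocation fails, or is disabled via GGML_CUDA_NO_PINNED, this
+// falls back to a plain CPU buffer so allocation itself does not fail here.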
+GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    void * ptr = ggml_cuda_host_malloc(size);
+
+    if (ptr == nullptr) {
+        // fallback to cpu buffer
+        return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
+    }
+
+    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
+    buffer->buft = buft;
+    buffer->iface.get_name = ggml_backend_cuda_host_buffer_name;
+    buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer;
+
+    return buffer;
+}
+
+GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() {
+    static struct ggml_backend_buffer_type ggml_backend_cuda_buffer_type_host = {
+        /* .iface    = */ {
+            /* .get_name         = */ ggml_backend_cuda_host_buffer_type_name,
+            /* .alloc_buffer     = */ ggml_backend_cuda_host_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
+            /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
+            /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
+            /* .is_host          = */ ggml_backend_cpu_buffer_type()->iface.is_host,
+        },
+        /* .context  = */ nullptr,
+    };
+
+    return &ggml_backend_cuda_buffer_type_host;
+}
+
+//static bool ggml_backend_buffer_is_cuda_host(ggml_backend_buffer_t buffer) {
+//    return buffer->buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name;
+//}
+
+/// kernels
+
+typedef void (*ggml_cuda_op_mul_mat_t)(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream);
+
+#ifndef GGML_CUDA_PEER_MAX_BATCH_SIZE
+#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128
+#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE
+
+#define MUL_MAT_SRC1_COL_STRIDE 128
+
+static __global__ void mul_mat_p021_f16_f32(
+    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) {
+
+    const half * x = (const half *) vx;
+
+    const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
+    const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+    const int channel_x = channel / (nchannels_y / nchannels_x);
+
+    const int nrows_y = ncols_x;
+    const int nrows_dst = nrows_x;
+    const int row_dst = row_x;
+
+    float tmp = 0.0f;
+
+    for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
+        const int col_x = col_x0 + threadIdx.x;
+
+        if (col_x >= ncols_x) {
+            break;
+        }
+
+        // x is transposed and permuted
+        const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
+        const float xi = __half2float(x[ix]);
+
+        const int row_y = col_x;
+
+        // y is not transposed but permuted
+        const int iy = channel*nrows_y + row_y;
+
+        tmp += xi * y[iy];
+    }
+
+    // dst is not transposed and not permuted
+    const int idst = channel*nrows_dst + row_dst;
+
+    // sum up partial sums and write back result
+    tmp = warp_reduce_sum(tmp);
+
+    if (threadIdx.x == 0) {
+        dst[idst] = tmp;
+    }
+}
+
+static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
+    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
+    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) {
+
+    const half * x = (const half *) vx;
+
+    const int row_x     = blockDim.y*blockIdx.y + threadIdx.y;
+    const int channel   = blockDim.z*blockIdx.z + threadIdx.z;
+    const int channel_x = channel / channel_x_divisor;
+
+    const int nrows_y   = ncols_x;
+    const int nrows_dst = nrows_x;
+    const int row_dst   = row_x;
+
+    const int idst = channel*nrows_dst + row_dst;
+
+    float tmp = 0.0f;
+
+    for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
+        const int col_x = col_x0 + threadIdx.x;
+
+        if (col_x >= ncols_x) {
+            break;
+        }
+
+        const int row_y = col_x;
+
+        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
+        const int iy = channel*nrows_y + row_y;
+
+        const float xi = __half2float(x[ix]);
+
+        tmp += xi * y[iy];
+    }
+
+    // sum up partial sums and write back result
+    tmp = warp_reduce_sum(tmp);
+
+    if (threadIdx.x == 0) {
+        dst[idst] = tmp;
+    }
+}
+
+static void ggml_mul_mat_p021_f16_f32_cuda(
+    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
+    const int nchannels_x, const int nchannels_y, cudaStream_t stream) {
+
+    const dim3 block_nums(1, nrows_x, nchannels_y);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y);
+}
+
+static void ggml_mul_mat_vec_nc_f16_f32_cuda(
+    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
+    const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) {
+
+    const dim3 block_nums(1, nrows_x, nchannels_y);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
+        (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x);
+}
+
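+// Copies rows [i1_low, i1_high) of plane (i2, i3) from a possibly
+// non-contiguous source tensor into a contiguous destination. Three cases:
+// fully contiguous rows (one linear copy), contiguous elements with a row
+// stride (one 2D copy), and strided elements (one 2D copy per row).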
+static cudaError_t ggml_cuda_cpy_tensor_2d(
+    void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
+
+    GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
+    char * src_ptr = (char *) src->data;
+    char * dst_ptr = (char *) dst;
+
+    const int64_t ne0 = src->ne[0];
+    const int64_t nb0 = src->nb[0];
+    const int64_t nb1 = src->nb[1];
+    const int64_t nb2 = src->nb[2];
+    const int64_t nb3 = src->nb[3];
+    const enum ggml_type type = src->type;
+    const int64_t ts = ggml_type_size(type);
+    const int64_t bs = ggml_blck_size(type);
+    int64_t i1_diff = i1_high - i1_low;
+
+    const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
+    if (nb0 == ts && nb1 == ts*ne0/bs) {
+        return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, cudaMemcpyDeviceToDevice, stream);
+    } else if (nb0 == ts) {
+        return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, cudaMemcpyDeviceToDevice, stream);
+    } else {
+        for (int64_t i1 = 0; i1 < i1_diff; i1++) {
+            const void * rx = (const void *) ((const char *) x + i1*nb1);
+            void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
+            // pretend the row is a matrix with cols=1
+            cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyDeviceToDevice, stream);
+            if (r != cudaSuccess) {
+                return r;
+            }
+        }
+        return cudaSuccess;
+    }
+}
+
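+// cuBLAS matmul for one slice of rows. On Volta or newer, with contiguous
+// fp16 or quantized src0 and default precision, both inputs are converted to
+// fp16 and multiplied with cublasGemmEx using fp16 compute, then the result
+// is converted back to fp32; otherwise both inputs are converted to fp32 and
+// cublasSgemm is used.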
+static void ggml_cuda_op_mul_mat_cublas(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+    GGML_ASSERT(src0_dd_i  != nullptr);
+    GGML_ASSERT(src1_ddf_i != nullptr);
+    GGML_ASSERT(dst_dd_i   != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne10 = src1->ne[0];
+
+    const int64_t ne0 = dst->ne[0];
+
+    const int64_t row_diff = row_high - row_low;
+
+    int id = ggml_cuda_get_device();
+
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // ldc == nrows of the matrix that cuBLAS writes into
+    int64_t ldc = id == ctx.device ? ne0 : row_diff;
+
+    const int compute_capability = ggml_cuda_info().devices[id].cc;
+
+    if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
+        // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
+        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
+        if (src0->type != GGML_TYPE_F16) {
+            const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
+            GGML_ASSERT(to_fp16_cuda != nullptr);
+            size_t ne = row_diff*ne00;
+            src0_as_f16.alloc(ne);
+            to_fp16_cuda(src0_dd_i, src0_as_f16.get(), ne, stream);
+        }
+        const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();
+
+        ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool(id));
+        if (src1->type != GGML_TYPE_F16) {
+            const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+            GGML_ASSERT(to_fp16_cuda != nullptr);
+            size_t ne = src1_ncols*ne10;
+            src1_as_f16.alloc(ne);
+            to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
+        }
+        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
+        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
+
+        const half alpha_f16 = 1.0f;
+        const half beta_f16 = 0.0f;
+
+        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
+        CUBLAS_CHECK(
+            cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                    row_diff, src1_ncols, ne10,
+                    &alpha_f16, src0_ptr,       CUDA_R_16F, ne00,
+                                src1_ptr,       CUDA_R_16F, ne10,
+                    &beta_f16,   dst_f16.get(), CUDA_R_16F, ldc,
+                    CUBLAS_COMPUTE_16F,
+                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+        to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+    } else {
+        ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));
+        ggml_cuda_pool_alloc<float> src1_ddq_as_f32(ctx.pool(id));
+
+        if (src0->type != GGML_TYPE_F32) {
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
+            GGML_ASSERT(to_fp32_cuda != nullptr);
+            src0_ddq_as_f32.alloc(row_diff*ne00);
+            to_fp32_cuda(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
+        }
+        if (src1->type != GGML_TYPE_F32) {
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type);
+            GGML_ASSERT(to_fp32_cuda != nullptr);
+            src1_ddq_as_f32.alloc(src1_ncols*ne10);
+            to_fp32_cuda(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
+        }
+
+        const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
+        const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
+
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+
+        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
+        CUBLAS_CHECK(
+            cublasSgemm(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                    row_diff, src1_ncols, ne10,
+                    &alpha, src0_ddf_i,  ne00,
+                            src1_ddf1_i, ne10,
+                    &beta,  dst_dd_i,    ldc));
+    }
+
+    GGML_UNUSED(dst);
+    GGML_UNUSED(src1_ddq_i);
+    GGML_UNUSED(src1_padded_row_size);
+}
+
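+// Toggles peer-to-peer access between the main device and all other devices
+// based on batch size: it is enabled only while n_tokens stays at or below
+// GGML_CUDA_PEER_MAX_BATCH_SIZE. The body is compiled only in release builds
+// (NDEBUG); debug builds never change peer access state.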
+static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
+    static bool peer_access_enabled = false;
+
+    const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;
+
+    if (peer_access_enabled == enable_peer_access) {
+        return;
+    }
+
+#ifdef NDEBUG
+    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        ggml_cuda_set_device(id);
+        CUDA_CHECK(cudaDeviceSynchronize());
+    }
+
+    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        ggml_cuda_set_device(id);
+
+        for (int id_other = 0; id_other < ggml_backend_cuda_get_device_count(); ++id_other) {
+            if (id == id_other) {
+                continue;
+            }
+            if (id != main_device && id_other != main_device) {
+                continue;
+            }
+
+            int can_access_peer;
+            CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other));
+            if (can_access_peer) {
+                if (enable_peer_access) {
+                    cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0);
+                    if (err != cudaErrorPeerAccessAlreadyEnabled) {
+                        CUDA_CHECK(err);
+                    }
+                } else {
+                    cudaError_t err = cudaDeviceDisablePeerAccess(id_other);
+                    if (err != cudaErrorPeerAccessNotEnabled) {
+                        CUDA_CHECK(err);
+                    }
+                }
+            }
+        }
+    }
+
+    ggml_cuda_set_device(main_device);
+#endif // NDEBUG
+
+    peer_access_enabled = enable_peer_access;
+
+    GGML_UNUSED(main_device);
+}
+
+static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
+    void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {
+
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+    // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
+    cudaMemcpy3DPeerParms p = {};
+    p.dstDevice = dstDevice;
+    p.dstPtr = make_cudaPitchedPtr(dst, dpitch, dpitch, height);
+    p.srcDevice = srcDevice;
+    p.srcPtr = make_cudaPitchedPtr(src, spitch, spitch, height);
+    p.extent = make_cudaExtent(width, height, 1);
+    return cudaMemcpy3DPeerAsync(&p, stream);
+#else
+    // HIP does not support cudaMemcpy3DPeerAsync or vmm pools
+    GGML_UNUSED(dstDevice);
+    GGML_UNUSED(srcDevice);
+    return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream);
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+}
+
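+// Generic driver for (optionally multi-GPU, row-split) matrix multiplication.
+// It assigns each device a slice of src0's rows, stages src1 on each device
+// (quantizing it to q8_1 when quantize_src1 is given), runs `op` per device
+// and per MUL_MAT_SRC1_COL_STRIDE-column chunk of src1, and finally gathers
+// the partial results on the main device, using CUDA events for ordering.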
+static void ggml_cuda_op_mul_mat(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
+    quantize_cuda_t quantize_src1) {
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+    const int64_t nrows1 = ggml_nrows(src1);
+
+    GGML_ASSERT(ne03 == ne13);
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+
+    const int64_t nb2 = dst->nb[2];
+    const int64_t nb3 = dst->nb[3];
+
+    GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer));
+    GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer));
+    ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;
+    ggml_backend_cuda_buffer_context * dst_ctx  = (ggml_backend_cuda_buffer_context *) dst->buffer->context;
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
+
+    GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
+
+    const int64_t i02_divisor = ne12 / ne02;
+
+    const size_t src0_ts = ggml_type_size(src0->type);
+    const size_t src0_bs = ggml_blck_size(src0->type);
+    const size_t q8_1_ts = sizeof(block_q8_1);
+    const size_t q8_1_bs = QK8_1;
+
+    const bool src0_is_contiguous = ggml_is_contiguous(src0);
+    const bool src1_is_contiguous = ggml_is_contiguous(src1);
+
+    const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
+
+    const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
+    GGML_ASSERT(!(split && ne02 > 1));
+    GGML_ASSERT(!(split && ne03 > 1));
+    GGML_ASSERT(!(split && ne02 < ne12));
+
+    ggml_tensor_extra_gpu * src0_extra = split ? (ggml_tensor_extra_gpu *) src0->extra : nullptr;
+
+
+    std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split;
+    if (split) {
+        ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
+        tensor_split = buft_ctx->tensor_split;
+    }
+
+    struct dev_data {
+        int cc;
+
+        ggml_cuda_pool_alloc<char>   src0_dd_alloc;
+        ggml_cuda_pool_alloc<float> src1_ddf_alloc;
+        ggml_cuda_pool_alloc<char>  src1_ddq_alloc;
+        ggml_cuda_pool_alloc<float>   dst_dd_alloc;
+
+        char  *  src0_dd = nullptr;
+        float * src1_ddf = nullptr; // float
+        char  * src1_ddq = nullptr; // q8_1
+        float *   dst_dd = nullptr;
+
+        int64_t  row_low;
+        int64_t row_high;
+    };
+
+    dev_data dev[GGML_CUDA_MAX_DEVICES];
+
+    int used_devices = 0;
+
+    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        dev[id].cc = ggml_cuda_info().devices[id].cc;
+
+        // by default, use all rows
+        dev[id].row_low  = 0;
+        dev[id].row_high = ne01;
+
+        // for multi GPU, get the row boundaries from tensor split
+        // and round to mul_mat_q tile sizes
+        if (split) {
+            const int64_t rounding = get_row_rounding(tensor_split);
+
+            if (id != 0) {
+                dev[id].row_low  = ne01*tensor_split[id];
+                if (dev[id].row_low < ne01) {
+                    dev[id].row_low -= dev[id].row_low % rounding;
+                }
+            }
+
+            if (id != ggml_backend_cuda_get_device_count() - 1) {
+                dev[id].row_high  = ne01*tensor_split[id + 1];
+                if (dev[id].row_high < ne01) {
+                    dev[id].row_high -= dev[id].row_high % rounding;
+                }
+            }
+        }
+    }
+
+    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        if ((!split && id != ctx.device) || dev[id].row_low == dev[id].row_high) {
+            continue;
+        }
+
+        used_devices++;
+
+        const bool src1_on_device = id == src1_ctx->device;
+        const bool  dst_on_device = id == dst_ctx->device;
+
+        ggml_cuda_set_device(id);
+        cudaStream_t stream = ctx.stream(id, 0);
+
+        if (src0_is_contiguous) {
+            dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data;
+        } else {
+            dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), ggml_nbytes(src0));
+        }
+
+        // If src0 is in a temporary compute buffer (partial offloading), there may be some padding that needs to be cleared:
+        if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
+            const int64_t nbytes_data    = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
+            const int64_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
+            CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream));
+        }
+
+        if (src1_on_device && src1_is_contiguous) {
+            dev[id].src1_ddf = (float *) src1->data;
+        } else {
+            dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ctx.pool(id), ggml_nelements(src1));
+        }
+
+        if (quantize_src1) {
+            size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
+            if (quantize_src1 == quantize_mmq_q8_1_cuda) {
+                src_1_ddq_size += get_mmq_x_max_host(dev[id].cc)*sizeof(block_q8_1_mmq);
+            }
+            dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size);
+
+            if (src1_on_device && src1_is_contiguous) {
+                quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream);
+                CUDA_CHECK(cudaGetLastError());
+            }
+        }
+
+        if (dst_on_device) {
+            dev[id].dst_dd = (float *) dst->data;
+        } else {
+            const size_t size_dst_ddf = split ? (dev[id].row_high - dev[id].row_low)*ne1 : ggml_nelements(dst);
+            dev[id].dst_dd = dev[id].dst_dd_alloc.alloc(ctx.pool(id), size_dst_ddf);
+        }
+    }
+
+    // if multiple devices are used they need to wait for the main device
+    // here an event is recorded that signals that the main device has finished calculating the input data
+    if (split && used_devices > 1) {
+        ggml_cuda_set_device(ctx.device);
+        CUDA_CHECK(cudaEventRecord(src0_extra->events[ctx.device][0], ctx.stream()));
+    }
+
+    const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
+    for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
+        const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_CUDA_MAX_STREAMS : 0;
+        const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
+
+        for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+            if ((!split && id != ctx.device) || dev[id].row_low == dev[id].row_high) {
+                continue;
+            }
+
+            const bool src1_on_device = id == src1_ctx->device;
+            const bool  dst_on_device = id == dst_ctx->device;
+            const int64_t row_diff = dev[id].row_high - dev[id].row_low;
+
+            ggml_cuda_set_device(id);
+            cudaStream_t stream = ctx.stream(id, is);
+
+            // wait for main GPU data if necessary
+            if (split && (id != ctx.device || is != 0)) {
+                CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[ctx.device][0], 0));
+            }
+
+            for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
+                const int64_t i03 = i0 / ne12;
+                const int64_t i02 = i0 % ne12;
+
+                size_t src1_ddq_i_offset = i0*ne11 * src1_padded_col_size*q8_1_ts/q8_1_bs;
+                if (quantize_src1 == quantize_mmq_q8_1_cuda) {
+                    src1_ddq_i_offset += src1_col_0 * sizeof(block_q8_1_mmq);
+                } else {
+                    src1_ddq_i_offset += src1_col_0 * src1_padded_col_size*q8_1_ts/q8_1_bs;
+                }
+
+                // for split tensors, src0_dd already points at this device's row slice
+                char  *  src0_dd_i =  dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
+                float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
+                char  * src1_ddq_i = dev[id].src1_ddq +  src1_ddq_i_offset;
+                float *   dst_dd_i =   dev[id].dst_dd + (i0*ne1  + src1_col_0) * (dst_on_device ? ne0 : row_diff);
+
+                // the main device memory buffer can be on VRAM scratch, with space for all partial results
+                // in that case an offset on dst_ddf_i is needed
+                if (id == ctx.device) {
+                    dst_dd_i += dev[id].row_low; // offset is 0 if no tensor split
+                }
+
+                // copy src0, src1 to device if necessary
+                if (src1_is_contiguous) {
+                    if (id != ctx.device) {
+                        if (quantize_src1) {
+                            char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
+                            if (quantize_src1 == quantize_mmq_q8_1_cuda) {
+                                const size_t pitch = ne11*sizeof(block_q8_1_mmq);
+                                const size_t width = src1_ncols*sizeof(block_q8_1_mmq);
+                                const size_t height = src1_padded_col_size/(4*QK8_1);
+                                CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(src1_ddq_i, id, pitch, src1_ddq_i_source, ctx.device, pitch, width, height, stream));
+                            } else {
+                                CUDA_CHECK(cudaMemcpyPeerAsync(
+                                    src1_ddq_i, id, src1_ddq_i_source, ctx.device, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream));
+                            }
+                        } else {
+                            float * src1_ddf_i_source = (float *) src1->data;
+                            src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
+                            CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddf_i, id, src1_ddf_i_source, ctx.device,
+                                                            src1_ncols*ne10*sizeof(float), stream));
+                        }
+                    }
+                } else if (src1_on_device && !src1_is_contiguous) {
+                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
+                                src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
+                } else {
+                    GGML_ABORT("fatal error");
+                }
+
+                if (quantize_src1 && !src1_is_contiguous) {
+                    quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream);
+                    CUDA_CHECK(cudaGetLastError());
+                }
+
+                if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
+                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
+                }
+
+                // do the computation
+                op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
+                    dev[id].row_low, dev[id].row_high, src1_ncols, src1_padded_col_size, stream);
+                CUDA_CHECK(cudaGetLastError());
+
+                // copy dst to host or other device if necessary
+                if (!dst_on_device) {
+                    void * dst_off_device = dst->data;
+                    if (split) {
+                        // src0 (the weight matrix) is stored transposed for better memory layout,
+                        // but dst is NOT transposed.
+                        // The outputs of matrix-matrix multiplications can therefore NOT simply be concatenated for >1 GPU.
+                        // Instead, each partial result needs to be copied to the correct slice along ne0 (the dst row index).
+                        // If dst is a vector with ne0 == 1 this copy is not needed, but it still produces correct results.
+                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
+                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
+                        dhf_dst_i += src1_col_0*ne0 + dev[id].row_low;
+                        CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(
+                            dhf_dst_i, ctx.device, ne0*sizeof(float), dst_dd_i, id, row_diff*sizeof(float), row_diff*sizeof(float), src1_ncols, stream));
+                    } else {
+                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
+                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
+                        dhf_dst_i += src1_col_0*ne0;
+                        CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_dd_i, src1_ncols*ne0*sizeof(float), cudaMemcpyDeviceToDevice, stream));
+                    }
+                }
+
+                // record an event so the main device can wait until this device is done
+                if (split && (id != ctx.device || is != 0)) {
+                    CUDA_CHECK(cudaEventRecord(src0_extra->events[id][is], stream));
+                }
+            }
+        }
+    }
+
+    // main device waits for all other devices to be finished
+    if (split && ggml_backend_cuda_get_device_count() > 1) {
+        int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
+        is_max = is_max <= GGML_CUDA_MAX_STREAMS ? is_max : GGML_CUDA_MAX_STREAMS;
+
+        ggml_cuda_set_device(ctx.device);
+        for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+            if (dev[id].row_low == dev[id].row_high) {
+                continue;
+            }
+            for (int64_t is = 0; is < is_max; ++is) {
+                CUDA_CHECK(cudaStreamWaitEvent(ctx.stream(), src0_extra->events[id][is], 0));
+            }
+        }
+    }
+}
+
+static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
+    GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
+    GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
+    GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+
+    const int64_t ne12 = src1->ne[2];
+
+    cudaStream_t main_stream = ctx.stream();
+
+    void  * src0_ddq = src0->data;
+    float * src1_ddf = (float *) src1->data;
+    float * dst_ddf  = (float *) dst->data;
+
+    ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
+}
+
+static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(!ggml_is_transposed(src0));
+    GGML_ASSERT(!ggml_is_transposed(src1));
+    GGML_ASSERT(!ggml_is_permuted(src0));
+    GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+
+    const int64_t nb01 = src0->nb[1];
+    const int64_t nb02 = src0->nb[2];
+
+    const int64_t ne12 = src1->ne[2];
+
+    cudaStream_t main_stream = ctx.stream();
+
+    void  * src0_ddq = src0->data;
+    float * src1_ddf = (float *) src1->data;
+    float * dst_ddf  = (float *) dst->data;
+
+    const int64_t row_stride_x = nb01 / sizeof(half);
+    const int64_t channel_stride_x = nb02 / sizeof(half);
+
+    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
+}
+
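+// Builds the per-matrix pointer arrays consumed by cublasGemmBatchedEx below:
+// one thread per (i12, i13) batch index maps dst to its (possibly broadcast)
+// sources via i02 = i12/r2 and i03 = i13/r3, and stores the three pointers.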
+static __global__ void k_compute_batched_ptrs(
+        const half * src0_as_f16, const half * src1_as_f16, char * dst,
+        const void ** ptrs_src, void ** ptrs_dst,
+        int64_t ne12, int64_t ne13,
+        int64_t ne23,
+        size_t  nb02, size_t  nb03,
+        size_t  nb12, size_t  nb13,
+        size_t  nbd2, size_t  nbd3,
+        int64_t r2,   int64_t r3) {
+    int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
+    int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (i13 >= ne13 || i12 >= ne12) {
+        return;
+    }
+
+    int64_t i03 = i13 / r3;
+    int64_t i02 = i12 / r2;
+
+    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
+    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
+    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2 + i13*nbd3;
+}
+
+static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(!ggml_is_transposed(src0));
+    GGML_ASSERT(!ggml_is_transposed(src1));
+
+    GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t ne_dst = ggml_nelements(dst);
+
+    cudaStream_t main_stream = ctx.stream();
+
+    CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
+
+    void * src0_ddq = src0->data;
+    half * src0_f16 = (half *) src0_ddq;
+    float * src1_ddf = (float *) src1->data;
+    float * dst_ddf  = (float *) dst->data;
+
+    // convert src1 to fp16
+    ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
+    if (src1->type != GGML_TYPE_F16) {
+        const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+        const int64_t ne_src1 = ggml_nelements(src1);
+        src1_f16_alloc.alloc(ne_src1);
+        GGML_ASSERT(to_fp16_cuda != nullptr);
+        to_fp16_cuda(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
+    }
+    half * src1_f16 = src1->type == GGML_TYPE_F16 ? (half *) src1_ddf : src1_f16_alloc.get();
+
+    ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
+    char * dst_t;
+
+    cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
+    cudaDataType_t      cu_data_type    = CUDA_R_16F;
+
+    // dst strides
+    size_t nbd2 = dst->nb[2];
+    size_t nbd3 = dst->nb[3];
+
+    const half  alpha_f16 = 1.0f;
+    const half  beta_f16  = 0.0f;
+
+    const float alpha_f32 = 1.0f;
+    const float beta_f32  = 0.0f;
+
+    const void * alpha = &alpha_f16;
+    const void * beta  = &beta_f16;
+
+    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+        dst_t = (char *) dst_f16.alloc(ne_dst);
+
+        nbd2 /= sizeof(float) / sizeof(half);
+        nbd3 /= sizeof(float) / sizeof(half);
+    } else {
+        dst_t = (char *) dst_ddf;
+
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+        cu_data_type    = CUDA_R_32F;
+
+        alpha = &alpha_f32;
+        beta  = &beta_f32;
+    }
+
+    GGML_ASSERT(ne12 % ne02 == 0);
+    GGML_ASSERT(ne13 % ne03 == 0);
+
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+
+#if 0
+    // use cublasGemmEx
+    {
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                int i03 = i13 / r3;
+                int i02 = i12 / r2;
+
+                CUBLAS_CHECK(
+                        cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            alpha, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F,   nb01/sizeof(half),
+                                   (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F,   nb11/sizeof(float),
+                            beta,  (      char *)       dst_t + i12*nbd2          + i13*nbd3,          cu_data_type, ne01,
+                            cu_compute_type,
+                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+            }
+        }
+    }
+#else
+#ifdef GGML_USE_MUSA
+    GGML_ASSERT(false);
+#else // !GGML_USE_MUSA
+    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+        // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+        // use cublasGemmStridedBatchedEx
+        CUBLAS_CHECK(
+        cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
+                ne01, ne11, ne10,
+                alpha, (const char *) src0_f16, CUDA_R_16F,   nb01/nb00, nb02/nb00,  // strideA
+                       (const char *) src1_f16, CUDA_R_16F,   nb11/nb10, nb12/nb10,  // strideB
+                beta,  (      char *)    dst_t, cu_data_type, ne01,       nb2/nb0,   // strideC
+                ne12*ne13,
+                cu_compute_type,
+                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+    } else {
+        // use cublasGemmBatchedEx
+        const int ne23 = ne12*ne13;
+
+        ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
+        ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
+
+        dim3 block_dims(ne13, ne12);
+        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                src0_f16, src1_f16, dst_t,
+                ptrs_src.get(), ptrs_dst.get(),
+                ne12, ne13,
+                ne23,
+                nb02, nb03,
+                src1->type == GGML_TYPE_F16 ? nb12 : nb12/2,
+                src1->type == GGML_TYPE_F16 ? nb13 : nb13/2,
+                nbd2, nbd3,
+                r2, r3);
+        CUDA_CHECK(cudaGetLastError());
+
+        CUBLAS_CHECK(
+        cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
+                ne01, ne11, ne10,
+                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,   nb01/nb00,
+                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   nb11/nb10,
+                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01,
+                ne23,
+                cu_compute_type,
+                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+    }
+#endif // GGML_USE_MUSA
+#endif
+
+    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+        to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
+    }
+}
+
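+// Top-level matmul dispatch: picks between the fused f16 vector kernels,
+// dequantize-mul-mat-vec, mul-mat-vec-q, mul-mat-q, batched cuBLAS, and the
+// generic split path, based on tensor types, shapes, contiguity, whether
+// src0 lives in a split buffer, and device compute capabilities.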
+static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
+
+    bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1;
+    bool          use_mul_mat_vec_q =  ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+    bool              use_mul_mat_q =  ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
+    // if mmvq is available it's a better choice than dmmv:
+#ifndef GGML_CUDA_FORCE_DMMV
+    use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
+#endif // GGML_CUDA_FORCE_DMMV
+
+    bool any_gpus_with_slow_fp16 = false;
+
+    if (split) {
+        ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
+        auto & tensor_split = buft_ctx->tensor_split;
+        for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+            // skip devices that are not going to do any work:
+            if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
+                continue;
+            }
+
+            const int cc            = ggml_cuda_info().devices[id].cc;
+            use_mul_mat_q           = use_mul_mat_q           && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+            any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
+        }
+    } else {
+        const int cc            = ggml_cuda_info().devices[ctx.device].cc;
+        use_mul_mat_q           = use_mul_mat_q           && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+        any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
+    }
+
+    // debug helpers
+    //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
+    //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
+    //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
+    //printf("      %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
+    //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
+    //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
+
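+    // kernel dispatch: prefer the specialized single-column (vector) kernels,
+    // then the batched cuBLAS path, then the quantized matrix kernels, and
+    // fall back to plain cuBLAS otherwise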
+    if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+        // FP32 precision KQ single-batch for batch size 1 without FlashAttention
+        ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst);
+    } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+        // FP32 precision KQV single-batch for batch size 1 without FlashAttention
+        ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst);
+    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
+               && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+        // KQ + KQV multi-batch without FlashAttention
+        ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
+    } else if (use_dequantize_mul_mat_vec) {
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
+    } else if (use_mul_mat_vec_q) {
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
+    } else if (use_mul_mat_q) {
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
+    } else {
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
+    }
+}
+
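+// maps a row of the contiguous scratch buffers used by mul_mat_id back to its
+// original position: i1 = index within the selected ids, i2 = token index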
+struct mmid_row_mapping {
+    int32_t i1;
+    int32_t i2;
+};
+
+static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
+                                                 int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
+                                                 const char * __restrict__ ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
+                                                 int64_t ne11, int64_t ne10,
+                                                 size_t nb11, size_t nb12) {
+    int32_t iid1 = blockIdx.x;
+    int32_t id = blockIdx.y;
+
+    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
+
+    if (row_id_i != i02) {
+        return;
+    }
+
+    const int64_t i11 = id % ne11;
+    const int64_t i12 = iid1;
+
+    __shared__ int src1_row;
+    if (threadIdx.x == 0) {
+        src1_row = atomicAdd(cur_src1_row, 1);
+        row_mapping[src1_row] = {id, iid1};
+    }
+    __syncthreads();
+
+    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
+    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
+
+    for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
+        src1_row_contiguous[i] = src1_row_original[i];
+    }
+}
+
+static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous,
+                                                  const mmid_row_mapping * __restrict__ row_mapping,
+                                                  int64_t ne0,
+                                                  size_t nb1, size_t nb2) {
+    int32_t i = blockIdx.x;
+
+    const int32_t i1 = row_mapping[i].i1;
+    const int32_t i2 = row_mapping[i].i2;
+
+    const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
+    float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
+
+    for (int j = threadIdx.x; j < ne0; j += blockDim.x) {
+        dst_row_original[j] = dst_row_contiguous[j];
+    }
+}
+
+static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * ids  = dst->src[2];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t n_as = ne02;
+    const int64_t n_ids = ids->ne[0];
+
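+    // the per-token expert ids are inspected on the host below, so copy them
+    // out of device memory first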
+    std::vector<char> ids_host(ggml_nbytes(ids));
+    const char * ids_dev = (const char *) ids->data;
+    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+
+    ggml_tensor src0_row = *src0;
+    ggml_tensor src1_row = *src1;
+    ggml_tensor dst_row  = *dst;
+
+    char * src0_original = (char *) src0->data;
+    char * src1_original = (char *) src1->data;
+    char * dst_original  = (char *)  dst->data;
+
+    src0_row.ne[2] = 1;
+    src0_row.ne[3] = 1;
+    src0_row.nb[3] = nb02;
+
+    src1_row.ne[1] = 1;
+    src1_row.ne[2] = 1;
+    src1_row.ne[3] = 1;
+    src1_row.nb[2] = nb11;
+    src1_row.nb[3] = nb11;
+
+    dst_row.ne[1] = 1;
+    dst_row.ne[2] = 1;
+    dst_row.ne[3] = 1;
+    dst_row.nb[2] = nb1;
+    dst_row.nb[3] = nb1;
+
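+    // ne12 == 1: each (token, id) pair selects a single src1 row, so run the
+    // matmuls directly on views; otherwise gather the rows belonging to each
+    // expert into contiguous scratch buffers and issue one matmul per expert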
+    if (ne12 == 1) {
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+            for (int64_t id = 0; id < n_ids; id++) {
+                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+
+                GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = iid1;
+
+                const int64_t i1 = id;
+                const int64_t i2 = i12;
+
+                src0_row.data = src0_original + i02*nb02;
+                src1_row.data = src1_original + i11*nb11 + i12*nb12;
+                dst_row.data  =  dst_original + i1*nb1   + i2*nb2;
+
+                ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+            }
+        }
+    } else {
+        ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
+        ggml_cuda_pool_alloc<char>  dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
+
+        src1_row.data = src1_contiguous.get();
+        dst_row.data  =  dst_contiguous.get();
+
+        for (int64_t i02 = 0; i02 < n_as; i02++) {
+            int64_t num_src1_rows = 0;
+
+            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+                for (int64_t id = 0; id < n_ids; id++) {
+                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+
+                    GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
+
+                    if (row_id_i != i02) {
+                        continue;
+                    }
+
+                    num_src1_rows++;
+                }
+            }
+
+            if (num_src1_rows == 0) {
+                continue;
+            }
+
+            ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
+            ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
+            CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
+
+            {
+                dim3 block_dims(std::min((unsigned int)ne10, 768u));
+                dim3 grid_dims(ids->ne[1], n_ids);
+                k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                        src1_original, src1_contiguous.get(),
+                        dev_cur_src1_row.get(), dev_row_mapping.get(),
+                        ids_dev, i02, ids->nb[1], ids->nb[0],
+                        ne11, ne10,
+                        nb11, nb12);
+                CUDA_CHECK(cudaGetLastError());
+            }
+
+            src0_row.data = src0_original + i02*nb02;
+
+            GGML_ASSERT(nb11 == sizeof(float)*ne10);
+            GGML_ASSERT(nb1 == sizeof(float)*ne0);
+
+            src1_row.ne[1] = num_src1_rows;
+            src1_row.nb[1] = nb11;
+            src1_row.nb[2] = num_src1_rows*nb11;
+            src1_row.nb[3] = num_src1_rows*nb11;
+
+            dst_row.ne[1] = num_src1_rows;
+            dst_row.nb[1] = nb1;
+            dst_row.nb[2] = num_src1_rows*nb1;
+            dst_row.nb[3] = num_src1_rows*nb1;
+
+            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+
+            {
+                dim3 block_dims(std::min((unsigned int)ne0, 768u));
+                dim3 grid_dims(num_src1_rows);
+                k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                        dst_original, dst_contiguous.get(),
+                        dev_row_mapping.get(),
+                        ne0,
+                        nb1, nb2);
+                CUDA_CHECK(cudaGetLastError());
+            }
+        }
+    }
+}
+
+static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
+    // why is this here instead of mul_mat?
+    if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) {
+        ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
+    }
+
+    switch (dst->op) {
+        case GGML_OP_REPEAT:
+            ggml_cuda_op_repeat(ctx, dst);
+            break;
+        case GGML_OP_GET_ROWS:
+            ggml_cuda_op_get_rows(ctx, dst);
+            break;
+        case GGML_OP_DUP:
+            ggml_cuda_dup(ctx, dst);
+            break;
+        case GGML_OP_CPY:
+            ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]);
+            break;
+        case GGML_OP_CONT:
+            ggml_cuda_dup(ctx, dst);
+            break;
+        case GGML_OP_ADD:
+            ggml_cuda_op_add(ctx, dst);
+            break;
+        case GGML_OP_SUB:
+            ggml_cuda_op_sub(ctx, dst);
+            break;
+        case GGML_OP_ACC:
+            ggml_cuda_op_acc(ctx, dst);
+            break;
+        case GGML_OP_MUL:
+            ggml_cuda_op_mul(ctx, dst);
+            break;
+        case GGML_OP_DIV:
+            ggml_cuda_op_div(ctx, dst);
+            break;
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(dst)) {
+                case GGML_UNARY_OP_GELU:
+                    ggml_cuda_op_gelu(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_SILU:
+                    ggml_cuda_op_silu(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_GELU_QUICK:
+                    ggml_cuda_op_gelu_quick(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_TANH:
+                    ggml_cuda_op_tanh(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_RELU:
+                    ggml_cuda_op_relu(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_SIGMOID:
+                    ggml_cuda_op_sigmoid(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_HARDSIGMOID:
+                    ggml_cuda_op_hardsigmoid(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_HARDSWISH:
+                    ggml_cuda_op_hardswish(ctx, dst);
+                    break;
+                default:
+                    return false;
+            }
+            break;
+        case GGML_OP_NORM:
+            ggml_cuda_op_norm(ctx, dst);
+            break;
+        case GGML_OP_GROUP_NORM:
+            ggml_cuda_op_group_norm(ctx, dst);
+            break;
+        case GGML_OP_CONCAT:
+            ggml_cuda_op_concat(ctx, dst);
+            break;
+        case GGML_OP_UPSCALE:
+            ggml_cuda_op_upscale(ctx, dst);
+            break;
+        case GGML_OP_PAD:
+            ggml_cuda_op_pad(ctx, dst);
+            break;
+        case GGML_OP_ARANGE:
+            ggml_cuda_op_arange(ctx, dst);
+            break;
+        case GGML_OP_TIMESTEP_EMBEDDING:
+            ggml_cuda_op_timestep_embedding(ctx, dst);
+            break;
+        case GGML_OP_LEAKY_RELU:
+            ggml_cuda_op_leaky_relu(ctx, dst);
+            break;
+        case GGML_OP_RMS_NORM:
+            ggml_cuda_op_rms_norm(ctx, dst);
+            break;
+        case GGML_OP_MUL_MAT:
+            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
+                GGML_CUDA_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
+                return false;
+            } else {
+                ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
+            }
+            break;
+        case GGML_OP_MUL_MAT_ID:
+            ggml_cuda_mul_mat_id(ctx, dst);
+            break;
+        case GGML_OP_SCALE:
+            ggml_cuda_op_scale(ctx, dst);
+            break;
+        case GGML_OP_SQR:
+            ggml_cuda_op_sqr(ctx, dst);
+            break;
+        case GGML_OP_SQRT:
+            ggml_cuda_op_sqrt(ctx, dst);
+            break;
+        case GGML_OP_SIN:
+            ggml_cuda_op_sin(ctx, dst);
+            break;
+        case GGML_OP_COS:
+            ggml_cuda_op_cos(ctx, dst);
+            break;
+        case GGML_OP_CLAMP:
+            ggml_cuda_op_clamp(ctx, dst);
+            break;
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            break;
+        case GGML_OP_DIAG_MASK_INF:
+            ggml_cuda_op_diag_mask_inf(ctx, dst);
+            break;
+        case GGML_OP_SOFT_MAX:
+            ggml_cuda_op_soft_max(ctx, dst);
+            break;
+        case GGML_OP_ROPE:
+            ggml_cuda_op_rope(ctx, dst);
+            break;
+        case GGML_OP_IM2COL:
+            ggml_cuda_op_im2col(ctx, dst);
+            break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            ggml_cuda_op_conv_transpose_1d(ctx, dst);
+            break;
+        case GGML_OP_POOL_2D:
+            ggml_cuda_op_pool2d(ctx, dst);
+            break;
+        case GGML_OP_SUM_ROWS:
+            ggml_cuda_op_sum_rows(ctx, dst);
+            break;
+        case GGML_OP_ARGSORT:
+            ggml_cuda_op_argsort(ctx, dst);
+            break;
+#if !defined(GGML_DISABLE_FLASH_ATTN)
+        case GGML_OP_FLASH_ATTN_EXT:
+            ggml_cuda_flash_attn_ext(ctx, dst);
+            break;
+#endif
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            ggml_cuda_cross_entropy_loss(ctx, dst);
+            break;
+        default:
+            return false;
+    }
+
+    cudaError_t err = cudaGetLastError();
+    if (err != cudaSuccess) {
+        GGML_CUDA_LOG_ERROR("%s: %s failed\n", __func__, ggml_op_desc(dst));
+        CUDA_CHECK(err);
+    }
+
+    return true;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+// backend
+
+GGML_CALL static const char * ggml_backend_cuda_name(ggml_backend_t backend) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+    return cuda_ctx->name.c_str();
+}
+
+GGML_CALL static void ggml_backend_cuda_free(ggml_backend_t backend) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+    delete cuda_ctx;
+    delete backend;
+}
+
+GGML_CALL static ggml_backend_buffer_type_t ggml_backend_cuda_get_default_buffer_type(ggml_backend_t backend) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+    return ggml_backend_cuda_buffer_type(cuda_ctx->device);
+}
+
+GGML_CALL static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+    GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
+
+    CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream()));
+}
+
+GGML_CALL static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+    GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
+
+    CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream()));
+}
+
+GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
+    ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
+    ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
+
+    if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) {
+        return false;
+    }
+
+    if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) {
+        return false;
+    }
+
+    // device -> device copy
+    ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context;
+    ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context;
+
+    ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context;
+    ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context;
+
+    if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) {
+#ifndef NDEBUG
+        GGML_CUDA_LOG_WARN("%s: backend and buffer devices do not match\n", __func__);
+#endif
+        return false;
+    }
+
+    if (backend_src != backend_dst) {
+        // copy on src stream
+        if (cuda_ctx_src->device == cuda_ctx_dst->device) {
+            CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
+        } else {
+#ifdef GGML_CUDA_NO_PEER_COPY
+            return false;
+#else
+            CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), cuda_ctx_src->stream()));
+#endif
+        }
+
+        // record event on src stream after the copy
+        if (!cuda_ctx_src->copy_event) {
+            ggml_cuda_set_device(cuda_ctx_src->device);
+            CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming));
+        }
+
+        CUDA_CHECK(cudaEventRecord(cuda_ctx_src->copy_event, cuda_ctx_src->stream()));
+
+        // wait on dst stream for the copy to complete
+        CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx_dst->stream(), cuda_ctx_src->copy_event, 0));
+    } else {
+        // src and dst are on the same backend
+        CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
+    }
+    return true;
+}
+
+GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+    CUDA_CHECK(cudaStreamSynchronize(cuda_ctx->stream()));
+
+    GGML_UNUSED(backend);
+}
+
+static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
+    graph_node_properties->node_address = node->data;
+    graph_node_properties->node_op = node->op;
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        graph_node_properties->ne[i] = node->ne[i];
+        graph_node_properties->nb[i] = node->nb[i];
+    }
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
+    }
+}
+
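+// a captured CUDA graph can only be re-launched unchanged if every node still
+// has the shapes, strides and data addresses recorded at capture time (copy
+// and view nodes are exempt from the address checks)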
+static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
+    if (node->data != graph_node_properties->node_address &&
+          node->op != GGML_OP_CPY &&
+          node->op != GGML_OP_VIEW) {
+        return false;
+    }
+
+    if (node->op != graph_node_properties->node_op) {
+        return false;
+    }
+
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (node->ne[i] != graph_node_properties->ne[i]) {
+            return false;
+        }
+        if (node->nb[i] != graph_node_properties->nb[i]) {
+            return false;
+        }
+    }
+
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        if (node->src[i] &&
+            node->src[i]->data != graph_node_properties->src_address[i] &&
+            node->op != GGML_OP_CPY &&
+            node->op != GGML_OP_VIEW
+        ) {
+            return false;
+        }
+    }
+    return true;
+}
+
+GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+    ggml_cuda_set_device(cuda_ctx->device);
+
+#ifdef USE_CUDA_GRAPH
+    static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
+
+    // Objects required for CUDA Graph
+    if (cuda_ctx->cuda_graph == nullptr) {
+        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
+    }
+
+    bool use_cuda_graph = true;
+    bool cuda_graph_update_required = false;
+    // vector of pointers to CUDA cpy kernels, used to identify the kernel
+    // parameters that need to be updated in the graph for each token
+    std::vector<void *> ggml_cuda_cpy_fn_ptrs;
+
+    if (cuda_ctx->cuda_graph->graph == nullptr) {
+        if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
+            cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
+#ifndef NDEBUG
+            GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
+#endif
+        }
+    }
+
+    // Disable CUDA graphs in the presence of the env var, an old GPU, a use case that is
+    // changing too rapidly, or a previous graph capture failure.
+    // Also disable for multi-GPU for now. TODO: investigate
+    if (disable_cuda_graphs_due_to_env
+        || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
+        || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
+        || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
+        use_cuda_graph = false;
+    }
+
+    if (use_cuda_graph) {
+        if (cuda_ctx->cuda_graph->instance == nullptr) {
+            cuda_graph_update_required = true;
+        }
+
+        // Check if the graph size has changed
+        if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
+            cuda_graph_update_required = true;
+            cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
+        }
+
+        // Loop over nodes in GGML graph to determine if CUDA graph update is required
+        // and store properties to allow this comparison for the next token
+        for (int i = 0; i < cgraph->n_nodes; i++) {
+            bool has_matching_properties = true;
+            if (!cuda_graph_update_required) {
+                has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+            }
+            if (!has_matching_properties) {
+                cuda_graph_update_required = true;
+            }
+            set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+        }
+
+        // Loop over nodes in GGML graph to obtain info needed for CUDA graph
+        cuda_ctx->cuda_graph->updated_kernel_arg.clear();
+        for (int i = 0; i < cgraph->n_nodes; i++) {
+            ggml_tensor * node = cgraph->nodes[i];
+
+            if (node->src[0] && ggml_backend_buffer_is_cuda_split(node->src[0]->buffer)) {
+                use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
+#ifndef NDEBUG
+                GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to split buffer\n", __func__);
+#endif
+            }
+
+            if (node->op == GGML_OP_MUL_MAT_ID) {
+                use_cuda_graph = false; // This node type is not supported by CUDA graph capture
+#ifndef NDEBUG
+                GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
+#endif
+            }
+
+            if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
+                // disable CUDA graphs for batch size > 1 for now.
+                // Changes in batch size or context size can cause changes to the grid size of some kernels.
+                use_cuda_graph = false;
+#ifndef NDEBUG
+                GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+#endif
+            }
+
+            if (node->op == GGML_OP_CPY) {
+                // store the copy op parameter which changes with each token.
+                cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
+                // store a pointer to each copy op CUDA kernel to identify it later
+                void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+                if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
+                    ggml_cuda_cpy_fn_ptrs.push_back(ptr);
+                }
+            }
+
+            if (!use_cuda_graph) {
+                break;
+            }
+        }
+
+        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
+        if (use_cuda_graph && cuda_graph_update_required) {
+            cuda_ctx->cuda_graph->number_consecutive_updates++;
+        } else {
+            cuda_ctx->cuda_graph->number_consecutive_updates = 0;
+        }
+
+        if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
+            cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
+#ifndef NDEBUG
+            GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
+#endif
+        }
+    }
+
+    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
+        CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
+    }
+
+#else
+    bool use_cuda_graph = false;
+    bool cuda_graph_update_required = false;
+#endif // USE_CUDA_GRAPH
+
+    bool graph_evaluated_or_captured = false;
+
+    while (!graph_evaluated_or_captured) {
+        // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
+        // With the use of CUDA graphs, the execution will be performed by the graph launch.
+        if (!use_cuda_graph || cuda_graph_update_required) {
+            for (int i = 0; i < cgraph->n_nodes; i++) {
+                ggml_tensor * node = cgraph->nodes[i];
+
+                if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+                    continue;
+                }
+
+#ifndef NDEBUG
+                assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
+                for (int j = 0; j < GGML_MAX_SRC; j++) {
+                    if (node->src[j] != nullptr) {
+                        assert(node->src[j]->buffer);
+                        assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
+                    }
+                }
+#endif
+
+                bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
+                if (!ok) {
+                    GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+                }
+                GGML_ASSERT(ok);
+            }
+        }
+
+#ifdef USE_CUDA_GRAPH
+        if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
+            if (cuda_ctx->cuda_graph->graph != nullptr) {
+                CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
+                cuda_ctx->cuda_graph->graph = nullptr;
+            }
+            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
+
+#if 0
+            if (disable_cuda_graphs_due_to_failed_capture) {
+                use_cuda_graph = false;
+                cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
+#ifndef NDEBUG
+                GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
+#endif
+            } else {
+                graph_evaluated_or_captured = true; // CUDA graph has been captured
+            }
+#endif
+            graph_evaluated_or_captured = true; // CUDA graph has been captured
+        } else {
+            graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
+        }
+    }
+
+    if (use_cuda_graph) {
+        if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
+            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        }
+
+        // Perform update to graph (if required for this token), and change copy parameter (required for every token)
+
+        if (cuda_graph_update_required) {
+            // Extract nodes from graph
+            // First call with null argument gets number of nodes in graph
+            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
+            // Subsequent call with non-null argument gets nodes
+            cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
+            cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
+            if (cuda_ctx->cuda_graph->num_nodes > 0) {
+                CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
+
+                // Loop over nodes, and extract kernel parameters from each node
+                for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+                    cudaGraphNodeType node_type;
+                    CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
+                    if (node_type == cudaGraphNodeTypeKernel) {
+                        cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
+                        if (stat == cudaErrorInvalidDeviceFunction) {
+                            // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
+                            // We don't need to update blas nodes, so clear error and move on.
+                            cudaGetLastError();
+                        } else {
+                            GGML_ASSERT(stat == cudaSuccess);
+                        }
+                    }
+                }
+            }
+        }
+
+        // One of the arguments to the copy kernel is updated for each token, hence we need to
+        // replace that argument with the updated value in the CUDA graph
+        if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
+            int k = 0;
+            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+                if (std::count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
+                    char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
+                    cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
+                    CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
+                }
+            }
+        }
+
+        // Update graph executable
+        cudaGraphExecUpdateResultInfo result_info;
+        cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
+        if (stat == cudaErrorGraphExecUpdateFailure) {
+#ifndef NDEBUG
+            GGML_CUDA_LOG_ERROR("%s: CUDA graph update failed\n", __func__);
+#endif
+            // The pre-existing graph exec cannot be updated due to violated constraints
+            // so instead clear error and re-instantiate
+            cudaGetLastError();
+            CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
+            cuda_ctx->cuda_graph->instance = nullptr;
+            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        } else {
+            GGML_ASSERT(stat == cudaSuccess);
+        }
+        // Launch graph
+        CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
+#else
+        graph_evaluated_or_captured = true;
+#endif // USE_CUDA_GRAPH
+    }
+
+    return GGML_STATUS_SUCCESS;
+}
+
+GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
+    switch (op->op) {
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_GELU:
+                case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_SIGMOID:
+                case GGML_UNARY_OP_HARDSIGMOID:
+                case GGML_UNARY_OP_HARDSWISH:
+                case GGML_UNARY_OP_GELU_QUICK:
+                case GGML_UNARY_OP_TANH:
+                    return ggml_is_contiguous(op->src[0]);
+                default:
+                    return false;
+            }
+            break;
+        case GGML_OP_MUL_MAT:
+        case GGML_OP_MUL_MAT_ID:
+            {
+                struct ggml_tensor * a = op->src[0];
+                struct ggml_tensor * b = op->src[1];
+                if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
+                    return false;
+                }
+                if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
+                    return false;
+                }
+                switch (a->type) {
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_F16:
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1:
+                    case GGML_TYPE_Q5_0:
+                    case GGML_TYPE_Q5_1:
+                    case GGML_TYPE_Q8_0:
+                    case GGML_TYPE_Q2_K:
+                    case GGML_TYPE_Q3_K:
+                    case GGML_TYPE_Q4_K:
+                    case GGML_TYPE_Q5_K:
+                    case GGML_TYPE_Q6_K:
+                    case GGML_TYPE_Q8_K:
+                    case GGML_TYPE_IQ1_M:
+                    case GGML_TYPE_IQ1_S:
+                    case GGML_TYPE_IQ2_S:
+                    case GGML_TYPE_IQ2_XS:
+                    case GGML_TYPE_IQ2_XXS:
+                    case GGML_TYPE_IQ3_S:
+                    case GGML_TYPE_IQ3_XXS:
+                    case GGML_TYPE_IQ4_NL:
+                    case GGML_TYPE_IQ4_XS:
+                        return true;
+                    default:
+                        return false;
+                }
+            } break;
+        case GGML_OP_GET_ROWS:
+            {
+                switch (op->src[0]->type) {
+                    case GGML_TYPE_F16:
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1:
+                    case GGML_TYPE_Q5_0:
+                    case GGML_TYPE_Q5_1:
+                    case GGML_TYPE_Q8_0:
+                        return true;
+                    default:
+                        return false;
+                }
+            } break;
+        case GGML_OP_CPY:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                ggml_type src1_type = op->src[1]->type;
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                return false;
+            } break;
+        case GGML_OP_DUP:
+        case GGML_OP_REPEAT:
+        case GGML_OP_CONCAT:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+            } break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                ggml_type src1_type = op->src[1]->type;
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                return false;
+            } break;
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+        case GGML_OP_NORM:
+        case GGML_OP_ADD:
+        case GGML_OP_SUB:
+        case GGML_OP_MUL:
+        case GGML_OP_DIV:
+        case GGML_OP_RMS_NORM:
+        case GGML_OP_SCALE:
+        case GGML_OP_SQR:
+        case GGML_OP_SQRT:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
+        case GGML_OP_CLAMP:
+        case GGML_OP_CONT:
+        case GGML_OP_DIAG_MASK_INF:
+        case GGML_OP_SOFT_MAX:
+            return true;
+        case GGML_OP_ROPE:
+            return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_IM2COL:
+        case GGML_OP_POOL_2D:
+        case GGML_OP_SUM_ROWS:
+        case GGML_OP_ARGSORT:
+        case GGML_OP_ACC:
+        case GGML_OP_GROUP_NORM:
+        case GGML_OP_UPSCALE:
+        case GGML_OP_PAD:
+        case GGML_OP_ARANGE:
+        case GGML_OP_TIMESTEP_EMBEDDING:
+        case GGML_OP_LEAKY_RELU:
+            return true;
+        case GGML_OP_FLASH_ATTN_EXT:
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+            return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
+#else
+            if (op->src[0]->ne[0] == 128) {
+                return true;
+            }
+            if (op->src[0]->ne[0] ==  64 && op->src[1]->type == GGML_TYPE_F16) {
+                return true;
+            }
+            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
+                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            return true;
+        default:
+            return false;
+    }
+
+    GGML_UNUSED(backend);
+}
+
+GGML_CALL static bool ggml_backend_cuda_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+    if (ggml_backend_buft_is_cuda_split(buft)) {
+        return true;
+    }
+
+    if (ggml_backend_buft_is_cuda(buft)) {
+        ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+        ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
+        return buft_ctx->device == cuda_ctx->device;
+    }
+
+    return false;
+}
+
+GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
+    const int min_batch_size = 32;
+
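+    // heuristic: offloading to the GPU only pays off once the batch is large
+    // enough to amortize the transfer cost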
+    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
+
+    GGML_UNUSED(backend);
+}
+
+static ggml_backend_event_t ggml_backend_cuda_event_new(ggml_backend_t backend) {
+#ifdef GGML_CUDA_NO_PEER_COPY
+    return nullptr;
+#else
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+    ggml_cuda_set_device(cuda_ctx->device);
+
+    cudaEvent_t event;
+    CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
+
+    return new ggml_backend_event {
+        /* .backend = */ backend,
+        /* .context = */ event,
+    };
+#endif
+}
+
+static void ggml_backend_cuda_event_free(ggml_backend_event_t event) {
+    CUDA_CHECK(cudaEventDestroy((cudaEvent_t)event->context));
+
+    delete event;
+}
+
+static void ggml_backend_cuda_event_record(ggml_backend_event_t event) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)event->backend->context;
+
+    CUDA_CHECK(cudaEventRecord((cudaEvent_t)event->context, cuda_ctx->stream()));
+}
+
+static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+    if (ggml_backend_is_cuda(event->backend)) {
+        CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), (cudaEvent_t)event->context, 0));
+    } else {
+#if 0
+        // untested
+        auto wait_fn = [](void * user_data) {
+            ggml_backend_event_t event = (ggml_backend_event_t)user_data;
+            ggml_backend_event_synchronize(event);
+        };
+
+        CUDA_CHECK(cudaLaunchHostFunc(cuda_ctx->stream(), wait_fn, event));
+#endif
+        GGML_ABORT("fatal error");
+    }
+}
+
+static void ggml_backend_cuda_event_synchronize(ggml_backend_event_t event) {
+    CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
+}
+
+static ggml_backend_i ggml_backend_cuda_interface = {
+    /* .get_name                = */ ggml_backend_cuda_name,
+    /* .free                    = */ ggml_backend_cuda_free,
+    /* .get_default_buffer_type = */ ggml_backend_cuda_get_default_buffer_type,
+    /* .set_tensor_async        = */ ggml_backend_cuda_set_tensor_async,
+    /* .get_tensor_async        = */ ggml_backend_cuda_get_tensor_async,
+    /* .cpy_tensor_async        = */ ggml_backend_cuda_cpy_tensor_async,
+    /* .synchronize             = */ ggml_backend_cuda_synchronize,
+    /* .graph_plan_create       = */ NULL,
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_update       = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ ggml_backend_cuda_graph_compute,
+    /* .supports_op             = */ ggml_backend_cuda_supports_op,
+    /* .supports_buft           = */ ggml_backend_cuda_supports_buft,
+    /* .offload_op              = */ ggml_backend_cuda_offload_op,
+    /* .event_new               = */ ggml_backend_cuda_event_new,
+    /* .event_free              = */ ggml_backend_cuda_event_free,
+    /* .event_record            = */ ggml_backend_cuda_event_record,
+    /* .event_wait              = */ ggml_backend_cuda_event_wait,
+    /* .event_synchronize       = */ ggml_backend_cuda_event_synchronize,
+};
+
+static ggml_guid_t ggml_backend_cuda_guid() {
+    static ggml_guid guid = { 0x2c, 0xdd, 0xe8, 0x1c, 0x65, 0xb3, 0x65, 0x73, 0x6a, 0x12, 0x88, 0x61, 0x1c, 0xc9, 0xdc, 0x25 };
+    return &guid;
+}
+
+GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device) {
+    if (device < 0 || device >= ggml_backend_cuda_get_device_count()) {
+        GGML_CUDA_LOG_ERROR("%s: invalid device %d\n", __func__, device);
+        return nullptr;
+    }
+
+    ggml_backend_cuda_context * ctx = new ggml_backend_cuda_context(device);
+    if (ctx == nullptr) {
+        GGML_CUDA_LOG_ERROR("%s: failed to allocate context\n", __func__);
+        return nullptr;
+    }
+
+    ggml_backend_t cuda_backend = new ggml_backend {
+        /* .guid      = */ ggml_backend_cuda_guid(),
+        /* .interface = */ ggml_backend_cuda_interface,
+        /* .context   = */ ctx
+    };
+
+    return cuda_backend;
+}
+
+GGML_CALL bool ggml_backend_is_cuda(ggml_backend_t backend) {
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cuda_guid());
+}
+
+GGML_CALL int ggml_backend_cuda_get_device_count() {
+    return ggml_cuda_info().device_count;
+}
+
+GGML_CALL void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size) {
+    cudaDeviceProp prop;
+    CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
+    snprintf(description, description_size, "%s", prop.name);
+}
+
+GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total) {
+    ggml_cuda_set_device(device);
+
+    CUDA_CHECK(cudaMemGetInfo(free, total));
+}
+
+GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
+    if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) {
+        return false;
+    }
+
+#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
+    cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
+    if (err != cudaSuccess) {
+        // clear the error
+        cudaGetLastError();
+
+        GGML_CUDA_LOG_WARN("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__,
+                           size / 1024.0 / 1024.0, cudaGetErrorString(err));
+        return false;
+    }
+    return true;
+#else
+    return false;
+#endif
+}
+
+GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
+    if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) {
+        return;
+    }
+
+    cudaError_t err = cudaHostUnregister(buffer);
+    if (err != cudaSuccess) {
+        // clear the error
+        cudaGetLastError();
+    }
+}
+
+// backend registry
+GGML_CALL static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * user_data) {
+    ggml_backend_t cuda_backend = ggml_backend_cuda_init((int) (intptr_t) user_data);
+    return cuda_backend;
+
+    GGML_UNUSED(params);
+}
+
+GGML_CALL int ggml_backend_cuda_reg_devices() {
+    int device_count = ggml_backend_cuda_get_device_count();
+    //int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
+    for (int i = 0; i < device_count; i++) {
+        char name[128];
+        snprintf(name, sizeof(name), "%s%d", GGML_CUDA_NAME, i);
+        ggml_backend_register(name, ggml_backend_reg_cuda_init, ggml_backend_cuda_buffer_type(i), (void *) (intptr_t) i);
+    }
+    return device_count;
+}

+ 75 - 0
llama/ggml-cuda.h

@@ -0,0 +1,75 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#ifdef GGML_USE_HIPBLAS
+#define GGML_CUDA_NAME "ROCm"
+#define GGML_CUBLAS_NAME "hipBLAS"
+#elif defined(GGML_USE_MUSA)
+#define GGML_CUDA_NAME "MUSA"
+#define GGML_CUBLAS_NAME "muBLAS"
+#else
+#define GGML_CUDA_NAME "CUDA"
+#define GGML_CUBLAS_NAME "cuBLAS"
+#endif
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+#define GGML_CUDA_MAX_DEVICES       16
+
+// backend API
+GGML_API GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device);
+
+GGML_API GGML_CALL bool ggml_backend_is_cuda(ggml_backend_t backend);
+
+// device buffer
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
+
+// split tensor buffer that splits matrices by rows across multiple devices
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split);
+
+// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
+
+GGML_API GGML_CALL int ggml_backend_cuda_reg_devices(void);
+
+GGML_API GGML_CALL int  ggml_backend_cuda_get_device_count(void);
+GGML_API GGML_CALL void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
+GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
+
+GGML_API GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
+GGML_API GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer);
+
+GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback, void * user_data);
+#ifdef  __cplusplus
+}
+#endif

+ 73 - 0
llama/ggml-cuda/acc.cu

@@ -0,0 +1,73 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "acc.cuh"
+
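+// GGML_OP_ACC: dst = src0 with src1 accumulated into the sub-view described by
+// op_params (strides nb1/nb2 and a byte offset)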
+static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne,
+    const int ne10, const int ne11, const int ne12,
+    const int nb1, const int nb2, int offset) {
+    const int i = blockDim.x * blockIdx.x + threadIdx.x;
+    if (i >= ne) {
+        return;
+    }
+    int src1_idx = i - offset;
+    int oz = src1_idx / nb2;
+    int oy = (src1_idx - (oz * nb2)) / nb1;
+    int ox = src1_idx % nb1;
+    if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
+        dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
+    } else {
+        dst[i] = x[i];
+    }
+}
+
+static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements,
+    const int ne10, const int ne11, const int ne12,
+    const int nb1, const int nb2, const int offset, cudaStream_t stream) {
+    int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
+    acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset);
+}
+
+void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const float * src0_d = (const float *)src0->data;
+    const float * src1_d = (const float *)src1->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
+
+    int nb1 = dst->op_params[0] / 4; // stride in elements (op_params store strides in bytes, float32 = 4 bytes)
+    int nb2 = dst->op_params[1] / 4;
+    // int nb3 = dst->op_params[2] / 4; // unused
+    int offset = dst->op_params[3] / 4; // offset in elements
+
+    acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, stream);
+}

+ 31 - 0
llama/ggml-cuda/acc.cuh

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_ACC_BLOCK_SIZE 256
+
+void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 60 - 0
llama/ggml-cuda/arange.cu

@@ -0,0 +1,60 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arange.cuh"
+
+static __global__ void arange_f32(float * dst, const int ne0, const float start, const float step) {
+    // blockIdx.x: idx of ne0 / BLOCK_SIZE
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+    dst[nidx] = start + step * nidx;
+}
+
+static void arange_f32_cuda(float * dst, const int ne0, const float start, const float step, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_ARANGE_BLOCK_SIZE - 1) / CUDA_ARANGE_BLOCK_SIZE;
+    arange_f32<<<num_blocks, CUDA_ARANGE_BLOCK_SIZE, 0, stream>>>(dst, ne0, start, step);
+}
+
+void ggml_cuda_op_arange(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    float start;
+    float stop;
+    float step;
+    memcpy(&start, (float *)dst->op_params + 0, sizeof(float));
+    memcpy(&stop,  (float *)dst->op_params + 1, sizeof(float));
+    memcpy(&step,  (float *)dst->op_params + 2, sizeof(float));
+
+    int64_t steps = (int64_t)ceil((stop - start) / step);
+    GGML_ASSERT(ggml_nelements(dst) == steps);
+
+    arange_f32_cuda(dst_d, dst->ne[0], start, step, stream);
+}
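
The `memcpy` calls above are how ggml extracts float parameters from the raw `int32_t op_params` array without breaking strict-aliasing rules, and `ceil((stop - start) / step)` is the element count the assert compares against `ggml_nelements(dst)`. A minimal host-side sketch of the same round trip (the packing order mirrors this file; the snippet is illustrative, not a public API):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    int main() {
        // ggml op parameters are raw int32 words; floats are bit-copied in and out.
        int32_t op_params[3];
        const float start = 0.0f, stop = 5.0f, step = 0.5f;
        memcpy(&op_params[0], &start, sizeof(float));
        memcpy(&op_params[1], &stop,  sizeof(float));
        memcpy(&op_params[2], &step,  sizeof(float));

        // Read back exactly as ggml_cuda_op_arange does.
        float s, e, d;
        memcpy(&s, (const float *) op_params + 0, sizeof(float));
        memcpy(&e, (const float *) op_params + 1, sizeof(float));
        memcpy(&d, (const float *) op_params + 2, sizeof(float));

        const int64_t steps = (int64_t) ceil((e - s) / d); // 10, matches ggml_nelements(dst)
        for (int64_t i = 0; i < steps; i++) {
            printf("%g ", s + d * (float) i); // the per-thread formula in arange_f32
        }
        printf("\n");
        return 0;
    }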

+ 31 - 0
llama/ggml-cuda/arange.cuh

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_ARANGE_BLOCK_SIZE 256
+
+void ggml_cuda_op_arange(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 130 - 0
llama/ggml-cuda/argsort.cu

@@ -0,0 +1,130 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "argsort.cuh"
+
+template<typename T>
+static inline __device__ void ggml_cuda_swap(T & a, T & b) {
+    T tmp = a;
+    a = b;
+    b = tmp;
+}
+
+template<ggml_sort_order order>
+static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
+    // bitonic sort
+    int col = threadIdx.x;
+    int row = blockIdx.y;
+
+    if (col >= ncols_pad) {
+        return;
+    }
+
+    const float * x_row = x + row * ncols;
+    extern __shared__ int dst_row[];
+
+    // initialize indices
+    dst_row[col] = col;
+
+    __syncthreads();
+
+    for (int k = 2; k <= ncols_pad; k *= 2) {
+        for (int j = k / 2; j > 0; j /= 2) {
+            int ixj = col ^ j;
+            if (ixj > col) {
+                if ((col & k) == 0) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
+                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+                    }
+                } else {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
+                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+                    }
+                }
+            }
+            __syncthreads();
+        }
+    }
+
+    // copy the result to dst without the padding
+    if (col < ncols) {
+        dst[row * ncols + col] = dst_row[col];
+    }
+}
+
+static int next_power_of_2(int x) {
+    int n = 1;
+    while (n < x) {
+        n *= 2;
+    }
+    return n;
+}
+
+static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
+    // bitonic sort requires ncols to be power of 2
+    const int ncols_pad = next_power_of_2(ncols);
+
+    const dim3 block_dims(ncols_pad, 1, 1);
+    const dim3 block_nums(1, nrows, 1);
+    const size_t shared_mem = ncols_pad * sizeof(int);
+
+    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
+    GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
+
+    if (order == GGML_SORT_ORDER_ASC) {
+        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+    } else if (order == GGML_SORT_ORDER_DESC) {
+        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+    } else {
+        GGML_ABORT("fatal error");
+    }
+}
+
+void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+
+    argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
+}
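
The bitonic network above needs a power-of-two row width, so each row is padded to `ncols_pad`; padding indices (`>= ncols`) behave as +infinity in the comparisons and are dropped when the result is written to `dst`. The per-row output is an ordinary argsort. A hedged host-side reference of what that means (illustrative only, not how the kernel is invoked):

    #include <algorithm>
    #include <cstdio>
    #include <numeric>
    #include <vector>

    // Host reference for what k_argsort_f32_i32 produces per row:
    // the indices that would sort the row (ascending here).
    std::vector<int> argsort_asc(const std::vector<float> & row) {
        std::vector<int> idx(row.size());
        std::iota(idx.begin(), idx.end(), 0);
        std::sort(idx.begin(), idx.end(),
                  [&](int a, int b) { return row[a] < row[b]; });
        return idx;
    }

    int main() {
        std::vector<float> row = {0.5f, -1.0f, 2.0f, 0.0f};
        for (int i : argsort_asc(row)) printf("%d ", i); // prints: 1 3 0 2
        printf("\n");
        return 0;
    }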

+ 29 - 0
llama/ggml-cuda/argsort.cuh

@@ -0,0 +1,29 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 314 - 0
llama/ggml-cuda/binbcast.cu

@@ -0,0 +1,314 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "binbcast.cuh"
+
+static __device__ __forceinline__ float op_repeat(const float a, const float b) {
+    return b;
+    GGML_UNUSED(a);
+}
+
+static __device__ __forceinline__ float op_add(const float a, const float b) {
+    return a + b;
+}
+
+static __device__ __forceinline__ float op_sub(const float a, const float b) {
+    return a - b;
+}
+
+static __device__ __forceinline__ float op_mul(const float a, const float b) {
+    return a * b;
+}
+
+static __device__ __forceinline__ float op_div(const float a, const float b) {
+    return a / b;
+}
+
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
+        int ne0, int ne1, int ne2, int ne3,
+        int ne10, int ne11, int ne12, int ne13,
+        /*int s0, */ int s1,  int s2,  int s3,
+        /*int s00,*/ int s01, int s02, int s03,
+        /*int s10,*/ int s11, int s12, int s13) {
+    const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
+    const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
+    const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
+    const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3;
+
+    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+        return;
+    }
+
+    const int i11 = i1 % ne11;
+    const int i12 = i2 % ne12;
+    const int i13 = i3 % ne13;
+
+    const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;
+    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+    const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;
+
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    dst_t * dst_row = dst + i_dst;
+
+    for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
+        const int i10 = i0 % ne10;
+        dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+    }
+}
+
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
+        int ne0, int ne1, int ne2, int ne3,
+        int ne10, int ne11, int ne12, int ne13,
+        /*int s0, */ int s1,  int s2,  int s3,
+        /*int s00,*/ int s01, int s02, int s03,
+        /*int s10,*/ int s11, int s12, int s13) {
+
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    const int i3 = i/(ne2*ne1*ne0);
+    const int i2 = (i/(ne1*ne0)) % ne2;
+    const int i1 = (i/ne0) % ne1;
+    const int i0 = i % ne0;
+
+    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+        return;
+    }
+
+    const int i11 = i1 % ne11;
+    const int i12 = i2 % ne12;
+    const int i13 = i3 % ne13;
+
+    const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;
+    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+    const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;
+
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    dst_t * dst_row = dst + i_dst;
+
+    const int i10 = i0 % ne10;
+    dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+}
+
+template<float (*bin_op)(const float, const float)>
+struct bin_bcast_cuda {
+    template<typename src0_t, typename src1_t, typename dst_t>
+    void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
+            const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
+            cudaStream_t stream) {
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        int nr0 = ne10/ne0;
+        int nr1 = ne11/ne1;
+        int nr2 = ne12/ne2;
+        int nr3 = ne13/ne3;
+
+        int nr[4] = { nr0, nr1, nr2, nr3 };
+
+        // collapse dimensions until first broadcast dimension
+        int64_t cne[] = {ne0, ne1, ne2, ne3};
+        int64_t cne0[] = {ne00, ne01, ne02, ne03};
+        int64_t cne1[] = {ne10, ne11, ne12, ne13};
+
+        size_t cnb[] = {nb0, nb1, nb2, nb3};
+        size_t cnb0[] = {nb00, nb01, nb02, nb03};
+        size_t cnb1[] = {nb10, nb11, nb12, nb13};
+
+        auto collapse = [](int64_t cne[]) {
+            cne[0] *= cne[1];
+            cne[1] = cne[2];
+            cne[2] = cne[3];
+            cne[3] = 1;
+        };
+
+        auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
+            cnb[1] *= cne[1];
+            cnb[2] *= cne[2];
+            cnb[3] *= cne[3];
+        };
+
+        if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
+            for (int i = 0; i < 4; i++) {
+                if (nr[i] != 1) {
+                    break;
+                }
+                if (i > 0) {
+                    collapse_nb(cnb, cne);
+                    collapse_nb(cnb0, cne0);
+                    collapse_nb(cnb1, cne1);
+                    collapse(cne);
+                    collapse(cne0);
+                    collapse(cne1);
+                }
+            }
+        }
+
+        {
+            int64_t ne0 = cne[0];
+            int64_t ne1 = cne[1];
+            int64_t ne2 = cne[2];
+            int64_t ne3 = cne[3];
+
+            //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
+            //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
+            //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
+            //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
+
+            int64_t ne10 = cne1[0];
+            int64_t ne11 = cne1[1];
+            int64_t ne12 = cne1[2];
+            int64_t ne13 = cne1[3];
+
+            size_t nb0 = cnb[0];
+            size_t nb1 = cnb[1];
+            size_t nb2 = cnb[2];
+            size_t nb3 = cnb[3];
+
+            size_t nb00 = cnb0[0];
+            size_t nb01 = cnb0[1];
+            size_t nb02 = cnb0[2];
+            size_t nb03 = cnb0[3];
+
+            size_t nb10 = cnb1[0];
+            size_t nb11 = cnb1[1];
+            size_t nb12 = cnb1[2];
+            size_t nb13 = cnb1[3];
+
+            size_t s0 = nb0 / sizeof(dst_t);
+            size_t s1 = nb1 / sizeof(dst_t);
+            size_t s2 = nb2 / sizeof(dst_t);
+            size_t s3 = nb3 / sizeof(dst_t);
+
+            size_t s10 = nb10 / sizeof(src1_t);
+            size_t s11 = nb11 / sizeof(src1_t);
+            size_t s12 = nb12 / sizeof(src1_t);
+            size_t s13 = nb13 / sizeof(src1_t);
+
+            size_t s00 = nb00 / sizeof(src0_t);
+            size_t s01 = nb01 / sizeof(src0_t);
+            size_t s02 = nb02 / sizeof(src0_t);
+            size_t s03 = nb03 / sizeof(src0_t);
+
+            GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
+            GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
+            GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
+            GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
+
+            GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
+            GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
+            GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
+            GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
+
+            GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
+            GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
+            GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
+            GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
+
+            GGML_ASSERT(s0 == 1);
+            GGML_ASSERT(s00 == 1);
+            GGML_ASSERT(s10 == 1);
+
+            const int block_size = 128;
+
+            int64_t hne0 = std::max(ne0/2LL, 1LL);
+
+            dim3 block_dims;
+            block_dims.x = std::min<unsigned int>(hne0, block_size);
+            block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
+            block_dims.z = std::min(std::min<unsigned int>(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U);
+
+            dim3 block_nums(
+                (hne0 + block_dims.x - 1) / block_dims.x,
+                (ne1 + block_dims.y - 1) / block_dims.y,
+                (ne2*ne3 + block_dims.z - 1) / block_dims.z
+            );
+
+            if (block_nums.z > 65535) {
+                // 65535 is the maximum grid size in the z dimension; fall back to the 1D grid kernel
+                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
+                k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
+                    src0_dd, src1_dd, dst_dd,
+                    ne0, ne1, ne2, ne3,
+                    ne10, ne11, ne12, ne13,
+                    /* s0, */ s1, s2, s3,
+                    /* s00, */ s01, s02, s03,
+                    /* s10, */ s11, s12, s13);
+            } else {
+                k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
+                    src0_dd, src1_dd, dst_dd,
+                    ne0, ne1, ne2, ne3,
+                    ne10, ne11, ne12, ne13,
+                    /* s0, */ s1, s2, s3,
+                    /* s00, */ s01, s02, s03,
+                    /* s10, */ s11, s12, s13);
+            }
+        }
+    }
+};
+
+template<class op>
+static void ggml_cuda_op_bin_bcast(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+        op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
+    } else {
+        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+            ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+        GGML_ABORT("fatal error");
+    }
+}
+
+void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
+}
+
+void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
+}
+
+void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_sub>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
+}
+
+void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
+}
+
+void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
+}
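
The broadcasting in `k_bin_bcast` comes entirely from the `i11 = i1 % ne11` (and similar) index wraps: any src1 dimension of extent 1 maps every dst coordinate to the same src1 slice, matching numpy-style broadcasting. A small host sketch of the rule, with illustrative sizes (adding a 4x1 row vector to a 4x3 matrix):

    #include <cstdio>

    int main() {
        const int ne0 = 4, ne1 = 3;   // dst is 4 x 3
        const int ne10 = 4, ne11 = 1; // src1 is 4 x 1, broadcast along dim 1
        float src0[ne1][ne0] = {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}};
        float src1[ne11][ne10] = {{10, 20, 30, 40}};
        float dst[ne1][ne0];
        for (int i1 = 0; i1 < ne1; i1++) {
            for (int i0 = 0; i0 < ne0; i0++) {
                // the modulo wrap is the broadcast: i1 % 1 == 0 for every row
                dst[i1][i0] = src0[i1][i0] + src1[i1 % ne11][i0 % ne10];
            }
        }
        for (int i1 = 0; i1 < ne1; i1++) {
            for (int i0 = 0; i0 < ne0; i0++) printf("%g ", dst[i1][i0]);
            printf("\n");
        }
        return 0;
    }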

+ 33 - 0
llama/ggml-cuda/binbcast.cuh

@@ -0,0 +1,33 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 60 - 0
llama/ggml-cuda/clamp.cu

@@ -0,0 +1,60 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "clamp.cuh"
+
+static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
+}
+
+static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
+    clamp_f32<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
+}
+
+
+void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    float min;
+    float max;
+    memcpy(&min, dst->op_params, sizeof(float));
+    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
+
+    clamp_f32_cuda(src0_d, dst_d, min, max, ggml_nelements(src0), stream);
+}
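
The nested ternary in `clamp_f32` is equivalent to composing `fminf`/`fmaxf`: both forms pin x into [min, max]. A tiny host check of the equivalence (illustrative only):

    #include <cmath>
    #include <cstdio>
    #include <initializer_list>

    // Same expression as the kernel body, evaluated on the host.
    static float clamp_ternary(float x, float lo, float hi) {
        return x < lo ? lo : (x > hi ? hi : x);
    }

    int main() {
        const float lo = -1.0f, hi = 1.0f;
        for (float x : {-2.0f, -0.5f, 0.0f, 0.5f, 3.0f}) {
            printf("%g -> %g (fmin/fmax form: %g)\n",
                   x, clamp_ternary(x, lo, hi), fminf(hi, fmaxf(lo, x)));
        }
        return 0;
    }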

+ 31 - 0
llama/ggml-cuda/clamp.cuh

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_CLAMP_BLOCK_SIZE 256
+
+void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 695 - 0
llama/ggml-cuda/common.cuh

@@ -0,0 +1,695 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "ggml.h"
+#include "ggml-cuda.h"
+
+#include <cstdint>
+#include <memory>
+
+#if defined(GGML_USE_HIPBLAS)
+#define GGML_COMMON_DECL_HIP
+#define GGML_COMMON_IMPL_HIP
+#else
+#define GGML_COMMON_DECL_CUDA
+#define GGML_COMMON_IMPL_CUDA
+#if defined(GGML_USE_MUSA)
+#define GGML_COMMON_DECL_MUSA
+#define GGML_COMMON_IMPL_MUSA
+#endif
+#endif
+#include "ggml-common.h"
+
+#include <cstdio>
+#include <array>
+#include <cassert>
+#include <cfloat>
+#include <string>
+#include <vector>
+
+#if defined(GGML_USE_HIPBLAS)
+#include "vendors/hip.h"
+#elif defined(GGML_USE_MUSA)
+#include "vendors/musa.h"
+#else
+#include "vendors/cuda.h"
+#endif // defined(GGML_USE_HIPBLAS)
+
+#define STRINGIZE_IMPL(...) #__VA_ARGS__
+#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
+
+#define WARP_SIZE 32
+#define CUDART_HMAX   11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
+#define CUDART_HMASK  12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
+
+#define CC_PASCAL     600
+#define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#define CC_VOLTA      700
+#define CC_TURING     750
+#define CC_AMPERE     800
+#define CC_OFFSET_AMD 1000000
+#define CC_RDNA1      (CC_OFFSET_AMD + 1010)
+#define CC_RDNA2      (CC_OFFSET_AMD + 1030)
+#define CC_RDNA3      (CC_OFFSET_AMD + 1100)
+
+#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+#define GGML_CUDA_MAX_STREAMS 8
+
+[[noreturn]]
+void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg);
+
+#define CUDA_CHECK_GEN(err, success, error_fn)                                      \
+     do {                                                                           \
+        auto err_ = (err);                                                          \
+        if (err_ != (success)) {                                                    \
+            ggml_cuda_error(#err, __func__, __FILE__, __LINE__, error_fn(err_));    \
+        }                                                                           \
+    } while (0)
+
+#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
+
+#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
+    static const char * cublas_get_error_str(const cublasStatus_t err) {
+        return cublasGetStatusString(err);
+    }
+#else
+    static const char * cublas_get_error_str(const cublasStatus_t err) {
+        switch (err) {
+            case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
+            case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
+            case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
+            case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
+            case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
+            case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
+            case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
+            case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
+            case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED";
+            default: return "unknown error";
+        }
+    }
+#endif // CUDART_VERSION >= 12000
+
+#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
+
+#if !defined(GGML_USE_HIPBLAS)
+static const char * cu_get_error_str(CUresult err) {
+    const char * err_str;
+    cuGetErrorString(err, &err_str);
+    return err_str;
+}
+#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
+#endif
+
+#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
+#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
+#else
+#define GGML_CUDA_ASSUME(x)
+#endif // CUDART_VERSION >= 11100
+
+#ifdef GGML_CUDA_F16
+typedef half dfloat; // dequantize float
+typedef half2 dfloat2;
+#else
+typedef float dfloat; // dequantize float
+typedef float2 dfloat2;
+#endif // GGML_CUDA_F16
+
+#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
+#define FP16_AVAILABLE
+#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
+
+#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
+#define FAST_FP16_AVAILABLE
+#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
+
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#define FP16_MMA_AVAILABLE
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
+#define INT8_MMA_AVAILABLE
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
+
+static constexpr bool fast_fp16_available(const int cc) {
+    return cc >= CC_PASCAL && cc != 610;
+}
+
+static constexpr bool fp16_mma_available(const int cc) {
+    return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
+}
+
+static constexpr bool int8_mma_available(const int cc) {
+    return cc < CC_OFFSET_AMD && cc >= CC_TURING;
+}
+
+[[noreturn]]
+static __device__ void no_device_code(
+    const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
+
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
+           file_name, line, function_name, arch);
+    GGML_UNUSED(arch_list);
+#else
+    printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
+           file_name, line, function_name, arch, arch_list);
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    __trap();
+
+    GGML_UNUSED(no_device_code); // suppress unused function warning
+}
+
+#ifdef __CUDA_ARCH__
+#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
+#else
+#define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
+#endif // __CUDA_ARCH__
+
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#ifdef FP16_AVAILABLE
+
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
+        reinterpret_cast<half&>(a.x) +=  __low2half(a_other);
+        reinterpret_cast<half&>(a.y) += __high2half(a_other);
+    }
+    return a;
+#else
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
+#else
+    NO_DEVICE_CODE;
+    return a;
+#endif // FP16_AVAILABLE
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+}
+
+static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
+#ifdef FP16_AVAILABLE
+
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+    return __float2half(fmaxf(__half2float(a), __half2float(b)));
+#else
+    return __hmax(a, b);
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+
+#else
+   NO_DEVICE_CODE;
+   GGML_UNUSED(b);
+   return a;
+#endif // FP16_AVAILABLE
+}
+
+static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+
+#if CUDART_VERSION >= CUDART_HMAX
+    return __hmax2(a, b);
+#else
+    half2 ret;
+    reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a),  __low2float(b)));
+    reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
+    return ret;
+#endif // CUDART_VERSION >= CUDART_HMAX
+
+#else
+    GGML_UNUSED(a);
+    GGML_UNUSED(b);
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+}
+
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#pragma unroll
+   for (int mask = 16; mask > 0; mask >>= 1) {
+       x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+   }
+   return x;
+#else
+   GGML_UNUSED(x);
+   NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+
+#if CUDART_VERSION < CUDART_HMASK
+static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
+    const uint32_t mask_low  = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
+    const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
+    return mask_low | mask_high;
+}
+#endif // CUDART_VERSION < CUDART_HMASK
+
+static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
+    c = __builtin_amdgcn_sdot4(a, b, c, false);
+#elif defined(RDNA3)
+    c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
+#elif defined(__gfx1010__) || defined(__gfx900__)
+    int tmp1;
+    int tmp2;
+    asm("\n \
+        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
+        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
+        v_add3_u32 %0, %1, %2, %0 \n \
+        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
+        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
+        v_add3_u32 %0, %1, %2, %0 \n \
+        "
+        : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
+        : "v"(a), "v"(b)
+    );
+#else
+    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
+    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+    c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
+#endif
+    return c;
+
+#else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
+    return __dp4a(a, b, c);
+#else // __CUDA_ARCH__ >= MIN_CC_DP4A
+    const int8_t * a8 = (const int8_t *) &a;
+    const int8_t * b8 = (const int8_t *) &b;
+    return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+}
+
+// TODO: move to ggml-common.h
+static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
+typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
+
+static __device__ __forceinline__ float get_alibi_slope(
+    const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
+) {
+    if (max_bias <= 0.0f) {
+        return 1.0f;
+    }
+    const float base = h < n_head_log2 ? m0 : m1;
+    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+    return powf(base, exph);
+}
+
+template <ggml_type type>
+struct ggml_cuda_type_traits;
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_F16> {
+    static constexpr int qk = 1;
+    static constexpr int qr = 1;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
+    static constexpr int qk = QK4_0;
+    static constexpr int qr = QR4_0;
+    static constexpr int qi = QI4_0;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q4_1> {
+    static constexpr int qk = QK4_1;
+    static constexpr int qr = QR4_1;
+    static constexpr int qi = QI4_1;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q5_0> {
+    static constexpr int qk = QK5_0;
+    static constexpr int qr = QR5_0;
+    static constexpr int qi = QI5_0;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q5_1> {
+    static constexpr int qk = QK5_1;
+    static constexpr int qr = QR5_1;
+    static constexpr int qi = QI5_1;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
+    static constexpr int qk = QK8_0;
+    static constexpr int qr = QR8_0;
+    static constexpr int qi = QI8_0;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR2_K;
+    static constexpr int qi = QI2_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q3_K> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR3_K;
+    static constexpr int qi = QI3_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q4_K> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR4_K;
+    static constexpr int qi = QI4_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q5_K> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR5_K;
+    static constexpr int qi = QI5_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR6_K;
+    static constexpr int qi = QI6_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR2_XXS;
+    static constexpr int qi = QI2_XXS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XS> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR2_XS;
+    static constexpr int qi = QI2_XS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_S> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR2_S;
+    static constexpr int qi = QI2_S;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ3_XXS> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR3_XXS;
+    static constexpr int qi = QI3_XXS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ1_S> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR1_S;
+    static constexpr int qi = QI1_S;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR1_M;
+    static constexpr int qi = QI1_M;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
+    static constexpr int qk = QK4_NL;
+    static constexpr int qr = QR4_NL;
+    static constexpr int qi = QI4_NL;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR4_XS;
+    static constexpr int qi = QI4_XS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR3_S;
+    static constexpr int qi = QI3_S;
+};
+
+//////////////////////
+
+struct ggml_cuda_device_info {
+    int device_count;
+
+    struct cuda_device_info {
+        int     cc;                 // compute capability
+        int     nsm;                // number of streaming multiprocessors
+        size_t  smpb;               // max. shared memory per block
+        size_t  smpbo;              // max. shared memory per block (with opt-in)
+        bool    vmm;                // virtual memory support
+        size_t  vmm_granularity;    // granularity of virtual memory
+        size_t  total_vram;
+    };
+
+    cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
+
+    std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};
+};
+
+const ggml_cuda_device_info & ggml_cuda_info();
+
+void ggml_cuda_set_device(int device);
+int ggml_cuda_get_device();
+
+struct ggml_cuda_pool {
+    virtual ~ggml_cuda_pool() = default;
+
+    virtual void * alloc(size_t size, size_t * actual_size) = 0;
+    virtual void free(void * ptr, size_t size) = 0;
+};
+
+template<typename T>
+struct ggml_cuda_pool_alloc {
+    ggml_cuda_pool * pool = nullptr;
+    T * ptr = nullptr;
+    size_t actual_size = 0;
+
+    ggml_cuda_pool_alloc() = default;
+
+    explicit ggml_cuda_pool_alloc(ggml_cuda_pool & pool) : pool(&pool) {
+    }
+
+    ggml_cuda_pool_alloc(ggml_cuda_pool & pool, size_t size) : pool(&pool) {
+        alloc(size);
+    }
+
+    ~ggml_cuda_pool_alloc() {
+        if (ptr != nullptr) {
+            pool->free(ptr, actual_size);
+        }
+    }
+
+    // size is in number of elements
+    T * alloc(size_t size) {
+        GGML_ASSERT(pool != nullptr);
+        GGML_ASSERT(ptr == nullptr);
+        ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
+        return ptr;
+    }
+
+    T * alloc(ggml_cuda_pool & pool, size_t size) {
+        this->pool = &pool;
+        return alloc(size);
+    }
+
+    T * get() {
+        return ptr;
+    }
+
+    ggml_cuda_pool_alloc(const ggml_cuda_pool_alloc &) = delete;
+    ggml_cuda_pool_alloc(ggml_cuda_pool_alloc &&) = delete;
+    ggml_cuda_pool_alloc& operator=(const ggml_cuda_pool_alloc &) = delete;
+    ggml_cuda_pool_alloc& operator=(ggml_cuda_pool_alloc &&) = delete;
+};
+
+
+// backend interface
+
+struct ggml_tensor_extra_gpu {
+    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
+    cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs
+};
+
+
+#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)
+#define USE_CUDA_GRAPH
+#endif
+
+struct ggml_graph_node_properties {
+    void * node_address;
+    ggml_op node_op;
+    int64_t ne[GGML_MAX_DIMS];
+    size_t nb[GGML_MAX_DIMS];
+    void * src_address[GGML_MAX_SRC];
+};
+
+struct ggml_cuda_graph {
+#ifdef USE_CUDA_GRAPH
+    ~ggml_cuda_graph() {
+        if (instance != nullptr) {
+            CUDA_CHECK(cudaGraphExecDestroy(instance));
+        }
+        if (graph != nullptr) {
+            CUDA_CHECK(cudaGraphDestroy(graph));
+        }
+    }
+    cudaGraph_t graph = nullptr;
+    cudaGraphExec_t instance = nullptr;
+    size_t num_nodes = 0;
+    std::vector<cudaGraphNode_t> nodes;
+    std::vector<cudaKernelNodeParams> params;
+    bool disable_due_to_gpu_arch = false;
+    bool disable_due_to_too_many_updates = false;
+    bool disable_due_to_failed_graph_capture = false;
+    int number_consecutive_updates = 0;
+    std::vector<ggml_graph_node_properties> ggml_graph_properties;
+    std::vector<char **> updated_kernel_arg;
+#endif
+};
+
+struct ggml_backend_cuda_context {
+    int device;
+    std::string name;
+    cudaEvent_t copy_event = nullptr;
+
+    cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
+    cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
+
+    std::unique_ptr<ggml_cuda_graph> cuda_graph;
+
+    explicit ggml_backend_cuda_context(int device) :
+        device(device),
+        name(GGML_CUDA_NAME + std::to_string(device)) {
+    }
+
+    ~ggml_backend_cuda_context() {
+        if (copy_event != nullptr) {
+            CUDA_CHECK(cudaEventDestroy(copy_event));
+        }
+        for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
+            for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
+                if (streams[i][j] != nullptr) {
+                    CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
+                }
+            }
+            if (cublas_handles[i] != nullptr) {
+                CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
+            }
+        }
+    }
+
+    cudaStream_t stream(int device, int stream) {
+        if (streams[device][stream] == nullptr) {
+            ggml_cuda_set_device(device);
+            CUDA_CHECK(cudaStreamCreateWithFlags(&streams[device][stream], cudaStreamNonBlocking));
+        }
+        return streams[device][stream];
+    }
+
+    cudaStream_t stream() {
+        return stream(device, 0);
+    }
+
+    cublasHandle_t cublas_handle(int device) {
+        if (cublas_handles[device] == nullptr) {
+            ggml_cuda_set_device(device);
+            CUBLAS_CHECK(cublasCreate(&cublas_handles[device]));
+            CUBLAS_CHECK(cublasSetMathMode(cublas_handles[device], CUBLAS_TF32_TENSOR_OP_MATH));
+        }
+        return cublas_handles[device];
+    }
+
+    cublasHandle_t cublas_handle() {
+        return cublas_handle(device);
+    }
+
+    // pool
+    std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
+
+    static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);
+
+    ggml_cuda_pool & pool(int device) {
+        if (pools[device] == nullptr) {
+            pools[device] = new_pool_for_device(device);
+        }
+        return *pools[device];
+    }
+
+    ggml_cuda_pool & pool() {
+        return pool(device);
+    }
+};
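
The `warp_reduce_*` helpers above use the XOR-shuffle butterfly: with masks 16, 8, 4, 2, 1 each lane repeatedly exchanges values with its partner lane, so after five steps every lane of the 32-wide warp holds the full reduction without touching shared memory. A standalone demo of the same idiom (`warp_sum` and `demo` are illustrative names, not part of the vendored headers):

    #include <cstdio>
    #include <cuda_runtime.h>

    // Butterfly reduction as in warp_reduce_sum: after 5 XOR-shuffle steps
    // every lane in the warp holds the sum of all 32 lane values.
    __device__ float warp_sum(float x) {
    #pragma unroll
        for (int mask = 16; mask > 0; mask >>= 1) {
            x += __shfl_xor_sync(0xffffffff, x, mask, 32);
        }
        return x;
    }

    __global__ void demo(float * out) {
        float v = (float) threadIdx.x; // lane i contributes i
        float s = warp_sum(v);
        if (threadIdx.x == 0) {
            *out = s;                  // 0 + 1 + ... + 31 = 496
        }
    }

    int main() {
        float *d_out, h_out;
        cudaMalloc(&d_out, sizeof(float));
        demo<<<1, 32>>>(d_out);
        cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
        printf("warp sum = %g\n", h_out); // prints: warp sum = 496
        cudaFree(d_out);
        return 0;
    }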

+ 222 - 0
llama/ggml-cuda/concat.cu

@@ -0,0 +1,222 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "concat.cuh"
+
+// contiguous kernels
+static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+
+    if (nidx < ne00) { // src0
+        int offset_src =
+            nidx +
+            blockIdx.y * ne00 +
+            blockIdx.z * ne00 * gridDim.y;
+        dst[offset_dst] = x[offset_src];
+    } else {
+        int offset_src =
+            (nidx - ne00) +
+            blockIdx.y * (ne0 - ne00) +
+            blockIdx.z * (ne0 - ne00) * gridDim.y;
+        dst[offset_dst] = y[offset_src];
+    }
+}
+
+static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) {
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+
+    if (blockIdx.y < ne01) { // src0
+        int offset_src =
+            nidx +
+            blockIdx.y * ne0 +
+            blockIdx.z * ne0 * ne01;
+        dst[offset_dst] = x[offset_src];
+    } else {
+        int offset_src =
+            nidx +
+            (blockIdx.y - ne01) * ne0 +
+            blockIdx.z * ne0 * (gridDim.y - ne01);
+        dst[offset_dst] = y[offset_src];
+    }
+}
+
+static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) {
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+
+    if (blockIdx.z < ne02) { // src0
+        int offset_src =
+            nidx +
+            blockIdx.y * ne0 +
+            blockIdx.z * ne0 * gridDim.y;
+        dst[offset_dst] = x[offset_src];
+    } else {
+        int offset_src =
+            nidx +
+            blockIdx.y * ne0 +
+            (blockIdx.z - ne02) * ne0 * gridDim.y;
+        dst[offset_dst] = y[offset_src];
+    }
+}
+
+static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2);
+    if (dim == 0) {
+        concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00);
+        return;
+    }
+    if (dim == 1) {
+        concat_f32_dim1<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne01);
+        return;
+    }
+    concat_f32_dim2<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
+}
+
+// non-contiguous kernel (slow)
+static __global__ void concat_f32_non_cont(
+        const char * src0,
+        const char * src1,
+              char * dst,
+           int64_t   ne00,
+           int64_t   ne01,
+           int64_t   ne02,
+           int64_t   ne03,
+          uint64_t   nb00,
+          uint64_t   nb01,
+          uint64_t   nb02,
+          uint64_t   nb03,
+           int64_t /*ne10*/,
+           int64_t /*ne11*/,
+           int64_t /*ne12*/,
+           int64_t /*ne13*/,
+          uint64_t   nb10,
+          uint64_t   nb11,
+          uint64_t   nb12,
+          uint64_t   nb13,
+           int64_t   ne0,
+           int64_t /*ne1*/,
+           int64_t /*ne2*/,
+           int64_t /*ne3*/,
+          uint64_t   nb0,
+          uint64_t   nb1,
+          uint64_t   nb2,
+          uint64_t   nb3,
+          int32_t   dim) {
+    const int64_t i3 = blockIdx.z;
+    const int64_t i2 = blockIdx.y;
+    const int64_t i1 = blockIdx.x;
+
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
+
+    const float * x;
+
+    for (int i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
+        if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+            x = (const float *)(src0 + (i3       )*nb03 + (i2       )*nb02 + (i1       )*nb01 + (i0       )*nb00);
+        } else {
+            x = (const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
+        }
+
+        float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+        *y = *x;
+    }
+}
+
+
+void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    cudaStream_t stream = ctx.stream();
+
+    const int32_t dim = ((int32_t *) dst->op_params)[0];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+        const float * src0_d = (const float *)src0->data;
+        const float * src1_d = (const float *)src1->data;
+
+        float * dst_d = (float *)dst->data;
+
+        if (dim != 3) {
+            for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+                concat_f32_cuda(
+                        src0_d + i3 * (src0->nb[3] / 4),
+                        src1_d + i3 * (src1->nb[3] / 4),
+                        dst_d + i3 * ( dst->nb[3] / 4),
+                        src0->ne[0], src0->ne[1], src0->ne[2],
+                        dst->ne[0],  dst->ne[1],  dst->ne[2], dim, stream);
+            }
+        } else {
+            const size_t size0 = ggml_nbytes(src0);
+            const size_t size1 = ggml_nbytes(src1);
+
+            CUDA_CHECK(cudaMemcpyAsync(dst_d,           src0_d, size0, cudaMemcpyDeviceToDevice, stream));
+            CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
+        }
+    } else {
+        dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
+        concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
+                (const char *)src0->data,
+                (const char *)src1->data,
+                (      char *)dst->data,
+                src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+                src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
+                dst->ne[0],  dst->ne[1],  dst->ne[2],  dst->ne[3],
+                dst->nb[0],  dst->nb[1],  dst->nb[2],  dst->nb[3], dim);
+    }
+}
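
The contiguous dim-0 kernel above routes each destination element to src0 or src1 depending on whether its innermost index falls before ne00; ne1 and ne2 play the role of gridDim.y and gridDim.z. A minimal CPU sketch of the same index arithmetic (concat_dim0_ref is a hypothetical helper for illustration, not part of the vendored file):

    // CPU reference for concat_f32_dim0: rows of length ne00 from x and
    // (ne0 - ne00) from y are joined into dst rows of length ne0.
    static void concat_dim0_ref(const float * x, const float * y, float * dst,
                                int ne00, int ne0, int ne1, int ne2) {
        for (int i2 = 0; i2 < ne2; i2++) {
            for (int i1 = 0; i1 < ne1; i1++) {
                for (int i0 = 0; i0 < ne0; i0++) {
                    const int dst_idx = i0 + i1*ne0 + i2*ne0*ne1;
                    if (i0 < ne00) {  // first ne00 columns come from x
                        dst[dst_idx] = x[i0 + i1*ne00 + i2*ne00*ne1];
                    } else {          // remaining columns come from y
                        const int r = ne0 - ne00;  // y's row length
                        dst[dst_idx] = y[(i0 - ne00) + i1*r + i2*r*ne1];
                    }
                }
            }
        }
    }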

+ 31 - 0
llama/ggml-cuda/concat.cuh

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_CONCAT_BLOCK_SIZE 256
+
+void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 113 - 0
llama/ggml-cuda/conv-transpose-1d.cu

@@ -0,0 +1,113 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "conv-transpose-1d.cuh"
+
+static __global__ void conv_transpose_1d_kernel(
+        const int s0, const int p0, const int d0, const int output_size,
+        const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
+        const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
+        const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
+        const float * src0, const float * src1,  float * dst) {
+    int global_index = threadIdx.x + blockIdx.x * blockDim.x;
+    if (global_index >= output_size) {
+        return;
+    }
+
+    int out_index = global_index / dst_ne0;
+
+    float accumulator = 0;
+
+    for (int c = 0; c < src0_ne2; c++) {
+        int idx = global_index % dst_ne0;
+
+        int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0);
+        int input_offset = src1_ne0 * c;
+
+        for (int i = 0; i < src1_ne0; i++) {
+            if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) {
+                continue;
+            }
+            int weight_idx = idx - i*s0;
+
+            float kernel_weight = src0[kernel_offset + weight_idx];
+            float input_value = src1[input_offset+i];
+
+            accumulator += kernel_weight * input_value;
+        }
+    }
+    dst[global_index] = accumulator;
+}
+
+static void conv_transpose_1d_f32_f32_cuda(
+        const int s0, const int p0, const int d0, const int output_size,
+        const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
+        const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
+        const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
+        const float * src0, const float * src1,  float * dst,
+        cudaStream_t stream) {
+
+    const int num_blocks = (output_size + CUDA_CONV_TRANSPOSE_1D_BLOCK_SIZE - 1) / CUDA_CONV_TRANSPOSE_1D_BLOCK_SIZE;
+    conv_transpose_1d_kernel<<<num_blocks, CUDA_CONV_TRANSPOSE_1D_BLOCK_SIZE, 0, stream>>>(
+        s0, p0, d0, output_size,
+        src0_ne0, src0_ne1, src0_ne2, src0_ne3,
+        src1_ne0, src1_ne1, src1_ne2, src1_ne3,
+        dst_ne0,  dst_ne1,  dst_ne2,  dst_ne3,
+        src0, src1, dst);
+}
+
+void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+
+    const ggml_tensor * src1 = dst->src[1];
+    const float * src1_d = (const float *)src1->data;
+
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+
+    const int s0 = opts[0];
+    const int p0 = 0; // opts[3];
+    const int d0 = 1; // opts[4];
+
+    const int64_t kernel_size = ggml_nelements(src0);
+    const int64_t input_size = ggml_nelements(src1);
+    const int64_t output_size = ggml_nelements(dst);
+
+    conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size,
+        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+        src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+        src0_d, src1_d, dst_d, stream);
+}
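
The kernel above computes, for each output sample, the sum of weight-input products whose stride-s0 placement lands on that sample; padding and dilation are fixed to 0 and 1 by the launcher. A CPU reference of the same accumulation under the kernel's weight layout (conv_transpose_1d_ref and its parameter names are illustrative, assuming kernel width x out channels x in channels weights as in the code above):

    // CPU reference: out[oc][o] sums w[c][oc][k] * in[c][i] over all (i, k)
    // with o == i*s0 + k, matching the kernel_offset/weight_idx math above.
    static void conv_transpose_1d_ref(const float * w, const float * in, float * out,
                                      int kw, int out_ch, int in_ch,
                                      int in_len, int out_len, int s0) {
        for (int oc = 0; oc < out_ch; oc++) {
            for (int o = 0; o < out_len; o++) {
                float acc = 0.0f;
                for (int c = 0; c < in_ch; c++) {
                    for (int i = 0; i < in_len; i++) {
                        const int k = o - i*s0;  // weight tap hit by (i, o)
                        if (k < 0 || k >= kw) continue;
                        acc += w[kw*out_ch*c + oc*kw + k] * in[in_len*c + i];
                    }
                }
                out[oc*out_len + o] = acc;
            }
        }
    }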

+ 31 - 0
llama/ggml-cuda/conv-transpose-1d.cuh

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_CONV_TRANSPOSE_1D_BLOCK_SIZE 256
+
+void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 712 - 0
llama/ggml-cuda/convert.cu

@@ -0,0 +1,712 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "convert.cuh"
+#include "dequantize.cuh"
+
+#define CUDA_Q8_0_NE_ALIGN 2048
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
+    const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x);
+
+    if (i >= k) {
+        return;
+    }
+
+    const int64_t ib = i/qk; // block index
+    const int64_t iqs = (i%qk)/qr; // quant index
+    const int64_t iybs = i - i%qk; // y block start index
+    const int64_t y_offset = qr == 1 ? 1 : qk/2;
+
+    // dequantize
+    dfloat2 v;
+    dequantize_kernel(vx, ib, iqs, v);
+
+    y[iybs + iqs + 0]        = v.x;
+    y[iybs + iqs + y_offset] = v.y;
+}
+
+template <bool need_check>
+static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
+#if __CUDA_ARCH__ >= CC_PASCAL
+    constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
+
+    const int64_t   i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
+    const int * x0 = ((int *) vx) + blockIdx.x * nint;
+    half2 * y2 = (half2 *) (y + i0);
+
+    __shared__ int vals[nint];
+
+#pragma unroll
+    for (int ix0 = 0; ix0 < nint; ix0 += WARP_SIZE) {
+        if (need_check && i0*sizeof(block_q8_0)/QK8_0 + sizeof(int)*(ix0 + threadIdx.x) >= k*sizeof(block_q8_0)/QK8_0) {
+            break;
+        }
+
+        const int ix = ix0 + threadIdx.x;
+        vals[ix] = x0[ix];
+    }
+
+    __syncthreads();
+
+#pragma unroll
+    for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
+        if (need_check && i0 + iy + 2*threadIdx.x >= k) {
+            return;
+        }
+
+        const half * b0 = ((const half  *) vals) + (sizeof(block_q8_0)/sizeof(half)) * ((iy + 2*threadIdx.x)/QK8_0);
+        const half    d = *b0;
+        const char2  qs = ((const char2 *) (b0 + 1))[threadIdx.x % (QK8_0/2)];
+
+        y2[iy/2 + threadIdx.x] = __hmul2(make_half2(qs.x, qs.y), __half2half2(d));
+    }
+#else
+    GGML_UNUSED(vx);
+    GGML_UNUSED(y);
+    GGML_UNUSED(k);
+    NO_DEVICE_CODE;
+#endif // __CUDA_ARCH__ >= CC_PASCAL
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
+
+    const int64_t i = blockIdx.x;
+
+    // assume 32 threads
+    const int64_t tid = threadIdx.x;
+    const int64_t il  = tid/8;
+    const int64_t ir  = tid%8;
+    const int64_t ib = 8*i + ir;
+    if (ib >= nb32) {
+        return;
+    }
+
+    dst_t * y = yy + 256*i + 32*ir + 4*il;
+
+    const block_q4_0 * x = (const block_q4_0 *)vx + ib;
+    const float d = __half2float(x->d);
+    const float dm = -8*d;
+
+    const uint8_t * q = x->qs + 4*il;
+
+    for (int l = 0; l < 4; ++l) {
+        y[l+ 0] = d * (q[l] & 0xF) + dm;
+        y[l+16] = d * (q[l] >>  4) + dm;
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
+
+    const int64_t i = blockIdx.x;
+
+    // assume 32 threads
+    const int64_t tid = threadIdx.x;
+    const int64_t il  = tid/8;
+    const int64_t ir  = tid%8;
+    const int64_t ib = 8*i + ir;
+    if (ib >= nb32) {
+        return;
+    }
+
+    dst_t * y = yy + 256*i + 32*ir + 4*il;
+
+    const block_q4_1 * x = (const block_q4_1 *)vx + ib;
+    const float2 d = __half22float2(x->dm);
+
+    const uint8_t * q = x->qs + 4*il;
+
+    for (int l = 0; l < 4; ++l) {
+        y[l+ 0] = d.x * (q[l] & 0xF) + d.y;
+        y[l+16] = d.x * (q[l] >>  4) + d.y;
+    }
+}
+
+//================================== k-quants
+
+template<typename dst_t>
+static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i   = blockIdx.x;
+    const block_q2_K * x = (const block_q2_K *) vx;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t n   = tid/32;
+    const int64_t l   = tid - 32*n;
+    const int64_t is  = 8*n + l/16;
+
+    const uint8_t q = x[i].qs[32*n + l];
+    dst_t * y = yy + i*QK_K + 128*n;
+
+    float dall = __low2half(x[i].dm);
+    float dmin = __high2half(x[i].dm);
+    y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
+    y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
+    y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
+    y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i = blockIdx.x;
+    const block_q3_K * x = (const block_q3_K *) vx;
+
+    const int64_t r = threadIdx.x/4;
+    const int64_t tid = r/2;
+    const int64_t is0 = r%2;
+    const int64_t l0 = 16*is0 + 4*(threadIdx.x%4);
+    const int64_t n = tid / 4;
+    const int64_t j = tid - 4*n;
+
+    uint8_t m = 1 << (4*n + j);
+    int64_t is = 8*n + 2*j + is0;
+    int shift = 2*j;
+
+    int8_t us = is <  4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
+                is <  8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
+                is < 12 ? (x[i].scales[is-8] >>  4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
+                          (x[i].scales[is-8] >>  4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
+    float d_all = x[i].d;
+    float dl = d_all * (us - 32);
+
+    dst_t * y = yy + i*QK_K + 128*n + 32*j;
+    const uint8_t * q = x[i].qs + 32*n;
+    const uint8_t * hm = x[i].hmask;
+
+    for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
+}
+
+static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
+    if (j < 4) {
+        d = q[j] & 63; m = q[j + 4] & 63;
+    } else {
+        d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+        m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+    const block_q4_K * x = (const block_q4_K *) vx;
+
+    const int64_t i = blockIdx.x;
+
+    // assume 32 threads
+    const int64_t tid = threadIdx.x;
+    const int64_t il  = tid/8;
+    const int64_t ir  = tid%8;
+    const int64_t is  = 2*il;
+    const int64_t n   = 4;
+
+    dst_t * y = yy + i*QK_K + 64*il + n*ir;
+
+    const float dall = __low2half(x[i].dm);
+    const float dmin = __high2half(x[i].dm);
+
+    const uint8_t * q = x[i].qs + 32*il + n*ir;
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, x[i].scales, sc, m);
+    const float d1 = dall * sc; const float m1 = dmin * m;
+    get_scale_min_k4(is + 1, x[i].scales, sc, m);
+    const float d2 = dall * sc; const float m2 = dmin * m;
+    for (int l = 0; l < n; ++l) {
+        y[l + 0] = d1 * (q[l] & 0xF) - m1;
+        y[l +32] = d2 * (q[l] >>  4) - m2;
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+    const block_q5_K * x = (const block_q5_K *) vx;
+
+    const int64_t i = blockIdx.x;
+
+    // assume 64 threads - this is very slightly better than the one below
+    const int64_t tid = threadIdx.x;
+    const int64_t il  = tid/16;   // il is in 0...3
+    const int64_t ir  = tid%16;   // ir is in 0...15
+    const int64_t is  = 2*il;     // is is in 0...6
+
+    dst_t * y = yy + i*QK_K + 64*il + 2*ir;
+
+    const float dall = __low2half(x[i].dm);
+    const float dmin = __high2half(x[i].dm);
+
+    const uint8_t * ql = x[i].qs + 32*il + 2*ir;
+    const uint8_t * qh = x[i].qh + 2*ir;
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, x[i].scales, sc, m);
+    const float d1 = dall * sc; const float m1 = dmin * m;
+    get_scale_min_k4(is + 1, x[i].scales, sc, m);
+    const float d2 = dall * sc; const float m2 = dmin * m;
+
+    uint8_t   hm  = 1 << (2*il);
+    y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;
+    y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;
+    hm <<= 1;
+    y[32] = d2 * ((ql[ 0] >>  4) + (qh[ 0] & hm ? 16 : 0)) - m2;
+    y[33] = d2 * ((ql[ 1] >>  4) + (qh[ 1] & hm ? 16 : 0)) - m2;
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+    const block_q6_K * x = (const block_q6_K *) vx;
+
+    const int64_t i = blockIdx.x;
+
+    // assume 64 threads - this is very slightly better than the one below
+    const int64_t tid = threadIdx.x;
+    const int64_t ip  = tid/32;   // ip is 0 or 1
+    const int64_t il  = tid - 32*ip; // 0...32
+    const int64_t is  = 8*ip + il/16;
+
+    dst_t * y = yy + i*QK_K + 128*ip + il;
+
+    const float d = x[i].d;
+
+    const uint8_t * ql = x[i].ql + 64*ip + il;
+    const uint8_t   qh = x[i].qh[32*ip + il];
+    const int8_t  * sc = x[i].scales + is;
+
+    y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+    y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
+    y[64] = d * sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+    y[96] = d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh >> 6) & 3) << 4)) - 32);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i   = blockIdx.x;
+    const block_iq2_xxs * x = (const block_iq2_xxs  *) vx;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint16_t * q2 = x[i].qs + 4*ib;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const uint8_t  * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);
+    const uint32_t aux32 = q2[2] | (q2[3] << 16);
+    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
+    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i   = blockIdx.x;
+    const block_iq2_xs * x = (const block_iq2_xs *) vx;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint16_t * q2 = x[i].qs + 4*ib;
+    const uint8_t  * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
+    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
+    const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i   = blockIdx.x;
+    const block_iq2_s * x = (const block_iq2_s *) vx;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
+    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
+    const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i   = blockIdx.x;
+    const block_iq3_xxs * x = (const block_iq3_xxs  *) vx;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint8_t  * q3 = x[i].qs + 8*ib;
+    const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
+    const uint8_t  * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
+    const uint8_t  * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
+    const uint32_t aux32 = gas[0] | (gas[1] << 16);
+    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;
+    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
+    for (int j = 0; j < 4; ++j) {
+        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i   = blockIdx.x;
+    const block_iq3_s * x = (const block_iq3_s *) vx;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint8_t * qs = x[i].qs + 8*ib;
+    const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
+    const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
+    const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
+    const uint8_t signs = x[i].signs[4*ib + il];
+    for (int j = 0; j < 4; ++j) {
+        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i   = blockIdx.x;
+    const block_iq1_s * x = (const block_iq1_s  *) vx;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
+    const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
+    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
+    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+    grid32[0] &= 0x0f0f0f0f;
+    for (int j = 0; j < 8; ++j) {
+        y[j] = d * (q[j] + delta);
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i   = blockIdx.x;
+    const block_iq1_m * x = (const block_iq1_m  *) vx;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint16_t * sc = (const uint16_t *)x[i].scales;
+    iq1m_scale_t scale;
+    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+    const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
+    const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
+    const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
+    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
+    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+    grid32[0] &= 0x0f0f0f0f;
+    for (int j = 0; j < 8; ++j) {
+        y[j] = d * (q[j] + delta);
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i   = blockIdx.x;
+    const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+    const uint8_t  * q4 = x[ib].qs + 4*il;
+    const float d = (float)x[ib].d;
+    for (int j = 0; j < 4; ++j) {
+        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+    const int64_t i   = blockIdx.x;
+    const block_iq4_xs * x = (const block_iq4_xs *)vx;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+    const uint8_t  * q4 = x[i].qs + 16*ib + 4*il;
+    const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
+    for (int j = 0; j < 4; ++j) {
+        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
+    }
+}
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
+    const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
+    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
+static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
+    if (k % CUDA_Q8_0_NE_ALIGN == 0) {
+        const bool need_check = false;
+        dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
+    } else {
+        const bool need_check = true;
+        dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
+    }
+}
+
+template<typename dst_t>
+static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb32 = k / 32;
+    const int nb = (k + 255) / 256;
+    dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
+}
+
+template<typename dst_t>
+static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb32 = k / 32;
+    const int nb = (k + 255) / 256;
+    dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
+}
+
+template<typename dst_t>
+static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = (k + QK_K - 1) / QK_K;
+    dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = (k + QK_K - 1) / QK_K;
+    dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template <typename src_t, typename dst_t>
+static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
+    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    const src_t * x = (src_t *) vx;
+
+    y[i] = x[i];
+}
+
+template <typename src_t, typename dst_t>
+static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
+to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return dequantize_row_q4_0_cuda;
+        case GGML_TYPE_Q4_1:
+            return dequantize_row_q4_1_cuda;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= CC_PASCAL) {
+                return dequantize_block_q8_0_f16_cuda;
+            }
+            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+        case GGML_TYPE_Q2_K:
+            return dequantize_row_q2_K_cuda;
+        case GGML_TYPE_Q3_K:
+            return dequantize_row_q3_K_cuda;
+        case GGML_TYPE_Q4_K:
+            return dequantize_row_q4_K_cuda;
+        case GGML_TYPE_Q5_K:
+            return dequantize_row_q5_K_cuda;
+        case GGML_TYPE_Q6_K:
+            return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_IQ2_XXS:
+            return dequantize_row_iq2_xxs_cuda;
+        case GGML_TYPE_IQ2_XS:
+            return dequantize_row_iq2_xs_cuda;
+        case GGML_TYPE_IQ2_S:
+            return dequantize_row_iq2_s_cuda;
+        case GGML_TYPE_IQ3_XXS:
+            return dequantize_row_iq3_xxs_cuda;
+        case GGML_TYPE_IQ1_S:
+            return dequantize_row_iq1_s_cuda;
+        case GGML_TYPE_IQ1_M:
+            return dequantize_row_iq1_m_cuda;
+        case GGML_TYPE_IQ4_NL:
+            return dequantize_row_iq4_nl_cuda;
+        case GGML_TYPE_IQ4_XS:
+            return dequantize_row_iq4_xs_cuda;
+        case GGML_TYPE_IQ3_S:
+            return dequantize_row_iq3_s_cuda;
+        case GGML_TYPE_F32:
+            return convert_unary_cuda<float>;
+        default:
+            return nullptr;
+    }
+}
+
+to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return dequantize_row_q4_0_cuda;
+        case GGML_TYPE_Q4_1:
+            return dequantize_row_q4_1_cuda;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+        case GGML_TYPE_Q2_K:
+            return dequantize_row_q2_K_cuda;
+        case GGML_TYPE_Q3_K:
+            return dequantize_row_q3_K_cuda;
+        case GGML_TYPE_Q4_K:
+            return dequantize_row_q4_K_cuda;
+        case GGML_TYPE_Q5_K:
+            return dequantize_row_q5_K_cuda;
+        case GGML_TYPE_Q6_K:
+            return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_IQ2_XXS:
+            return dequantize_row_iq2_xxs_cuda;
+        case GGML_TYPE_IQ2_XS:
+            return dequantize_row_iq2_xs_cuda;
+        case GGML_TYPE_IQ2_S:
+            return dequantize_row_iq2_s_cuda;
+        case GGML_TYPE_IQ3_XXS:
+            return dequantize_row_iq3_xxs_cuda;
+        case GGML_TYPE_IQ1_S:
+            return dequantize_row_iq1_s_cuda;
+        case GGML_TYPE_IQ1_M:
+            return dequantize_row_iq1_m_cuda;
+        case GGML_TYPE_IQ4_NL:
+            return dequantize_row_iq4_nl_cuda;
+        case GGML_TYPE_IQ4_XS:
+            return dequantize_row_iq4_xs_cuda;
+        case GGML_TYPE_IQ3_S:
+            return dequantize_row_iq3_s_cuda;
+        case GGML_TYPE_F16:
+            return convert_unary_cuda<half>;
+        default:
+            return nullptr;
+    }
+}
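
The Q8_0 paths above decode blocks of 32 values stored as one fp16 scale plus 32 signed bytes, with x[i] approximately equal to d * qs[i]; the f32-to-q8_0 copy kernels later in this diff produce the same layout. A standalone round-trip sketch of that scheme (q8_0_ref is illustrative and uses a float scale instead of ggml's fp16, purely to keep the sketch self-contained):

    #include <math.h>
    #include <stdint.h>

    // block_q8_0 stores d as fp16; float is used here to avoid
    // half-precision intrinsics in the sketch.
    struct q8_0_ref { float d; int8_t qs[32]; };

    static q8_0_ref quantize_q8_0_ref(const float * x) {
        float amax = 0.0f;  // absolute max of the block
        for (int j = 0; j < 32; j++) amax = fmaxf(amax, fabsf(x[j]));
        q8_0_ref b;
        b.d = amax / 127.0f;  // chosen so round(x/d) fits in int8
        const float id = b.d ? 1.0f/b.d : 0.0f;
        for (int j = 0; j < 32; j++) b.qs[j] = (int8_t) roundf(x[j]*id);
        return b;
    }

    static void dequantize_q8_0_ref(const q8_0_ref & b, float * y) {
        for (int j = 0; j < 32; j++) y[j] = b.d * b.qs[j];  // x[i] ~= d * qs[i]
    }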

+ 39 - 0
llama/ggml-cuda/convert.cuh

@@ -0,0 +1,39 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
+
+template<typename T>
+using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, cudaStream_t stream);
+
+typedef to_t_cuda_t<float> to_fp32_cuda_t;
+typedef to_t_cuda_t<half> to_fp16_cuda_t;
+
+to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type);
+
+to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type);
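
A minimal sketch of the intended call pattern for this lookup, assuming the surrounding ggml headers are included (dequantize_to_fp32 is a hypothetical wrapper; the ggml_* calls are the real API): resolve the converter for the tensor's type once, then run it on a stream.

    // dst_dev must be device memory with room for ggml_nelements(src)
    // floats; error handling is elided.
    static void dequantize_to_fp32(const ggml_tensor * src, float * dst_dev, cudaStream_t stream) {
        const to_fp32_cuda_t to_fp32 = ggml_get_to_fp32_cuda(src->type);
        GGML_ASSERT(to_fp32 != nullptr);  // nullptr means the type is unsupported
        to_fp32(src->data, dst_dev, ggml_nelements(src), stream);
    }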

+ 515 - 0
llama/ggml-cuda/cpy.cu

@@ -0,0 +1,515 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "cpy.cuh"
+
+typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
+
+static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    float * dsti = (float *) cdsti;
+
+    *dsti = *xi;
+}
+
+static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    half * dsti = (half *) cdsti;
+
+    *dsti = __float2half(*xi);
+}
+
+static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
+    const half * xi = (const half *) cxi;
+    half * dsti = (half *) cdsti;
+
+    *dsti = *xi;
+}
+
+static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
+    const half * xi = (const half *) cxi;
+    float * dsti = (float *) cdsti;
+
+    *dsti = *xi;
+}
+
+template <cpy_kernel_t cpy_1>
+static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
+                                   const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                   const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                   const int nb12, const int nb13) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= ne) {
+        return;
+    }
+
+    // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
+    // then combine those indices with the corresponding byte offsets to get the total offsets
+    const int64_t i03 = i/(ne00 * ne01 * ne02);
+    const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
+    const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+    const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+    const int64_t i13 = i/(ne10 * ne11 * ne12);
+    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
+
+    cpy_1(cx + x_offset, cdst + dst_offset);
+}
+
+static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q8_0 * dsti = (block_q8_0 *) cdsti;
+
+    float amax = 0.0f; // absolute max
+
+    for (int j = 0; j < QK8_0; j++) {
+        const float v = xi[j];
+        amax = fmaxf(amax, fabsf(v));
+    }
+
+    const float d = amax / ((1 << 7) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->d = d;
+
+    for (int j = 0; j < QK8_0; ++j) {
+        const float x0 = xi[j]*id;
+
+        dsti->qs[j] = roundf(x0);
+    }
+}
+
+static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q4_0 * dsti = (block_q4_0 *) cdsti;
+
+    float amax = 0.0f;
+    float vmax = 0.0f;
+
+    for (int j = 0; j < QK4_0; ++j) {
+        const float v = xi[j];
+        if (amax < fabsf(v)) {
+            amax = fabsf(v);
+            vmax = v;
+        }
+    }
+
+    const float d  = vmax / -8;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->d = d;
+
+    for (int j = 0; j < QK4_0/2; ++j) {
+        const float x0 = xi[0       + j]*id;
+        const float x1 = xi[QK4_0/2 + j]*id;
+
+        const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
+        const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
+
+        dsti->qs[j]  = xi0;
+        dsti->qs[j] |= xi1 << 4;
+    }
+}
+
+static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q4_1 * dsti = (block_q4_1 *) cdsti;
+
+    float vmin = FLT_MAX;
+    float vmax = -FLT_MAX;
+
+    for (int j = 0; j < QK4_1; ++j) {
+        const float v = xi[j];
+
+        if (v < vmin) vmin = v;
+        if (v > vmax) vmax = v;
+    }
+
+    const float d  = (vmax - vmin) / ((1 << 4) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->dm.x = d;
+    dsti->dm.y = vmin;
+
+    for (int j = 0; j < QK4_1/2; ++j) {
+        const float x0 = (xi[0       + j] - vmin)*id;
+        const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
+
+        const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
+        const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
+
+        dsti->qs[j]  = xi0;
+        dsti->qs[j] |= xi1 << 4;
+    }
+}
+
+static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q5_0 * dsti = (block_q5_0 *) cdsti;
+
+    float amax = 0.0f;
+    float vmax = 0.0f;
+
+    for (int j = 0; j < QK5_0; ++j) {
+        const float v = xi[j];
+        if (amax < fabsf(v)) {
+            amax = fabsf(v);
+            vmax = v;
+        }
+    }
+
+    const float d  = vmax / -16;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->d = d;
+
+    uint32_t qh = 0;
+    for (int j = 0; j < QK5_0/2; ++j) {
+        const float x0 = xi[0       + j]*id;
+        const float x1 = xi[QK5_0/2 + j]*id;
+
+        const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
+        const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
+
+        dsti->qs[j]  = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+    }
+    memcpy(dsti->qh, &qh, sizeof(qh));
+}
+
+static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q5_1 * dsti = (block_q5_1 *) cdsti;
+
+    float min = xi[0];
+    float max = xi[0];
+
+    for (int j = 1; j < QK5_1; ++j) {
+        const float v = xi[j];
+        min = v < min ? v : min;
+        max = v > max ? v : max;
+    }
+
+    const float d  = (max - min) / 31;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->dm.x = d;
+    dsti->dm.y = min;
+
+    uint32_t qh = 0;
+    for (int j = 0; j < QK5_1/2; ++j) {
+        const float x0 = (xi[0       + j] - min)*id;
+        const float x1 = (xi[QK5_1/2 + j] - min)*id;
+
+        const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+        const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+        dsti->qs[j]  = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+    }
+    memcpy(dsti->qh, &qh, sizeof(qh));
+}
+
+
+static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
+static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_iq4_nl * dsti = (block_iq4_nl *) cdsti;
+
+    float amax = 0.0f;
+    float vmax = 0.0f;
+
+    for (int j = 0; j < QK4_NL; ++j) {
+        const float v = xi[j];
+        if (amax < fabsf(v)) {
+            amax = fabsf(v);
+            vmax = v;
+        }
+    }
+
+    float d = vmax / kvalues_iq4nl[0];
+    const float id = d ? 1.0f/d : 0.0f;
+
+    float sumqx = 0, sumq2 = 0;
+    for (int j = 0; j < QK4_NL/2; ++j) {
+        const float x0 = xi[0        + j]*id;
+        const float x1 = xi[QK4_NL/2 + j]*id;
+        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
+        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
+        dsti->qs[j] = xi0 | (xi1 << 4);
+        const float v0 = kvalues_iq4nl[xi0];
+        const float v1 = kvalues_iq4nl[xi1];
+        const float w0 = xi[0        + j]*xi[0        + j];
+        const float w1 = xi[QK4_NL/2 + j]*xi[QK4_NL/2 + j];
+        sumqx += w0*v0*xi[j] + w1*v1*xi[QK4_NL/2 + j];
+        sumq2 += w0*v0*v0 + w1*v1*v1;
+    }
+
+    dsti->d = sumq2 > 0 ? sumqx/sumq2 : d;
+}
+
+template <cpy_kernel_t cpy_blck, int qk>
+static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
+                                 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                 const int nb12, const int nb13) {
+    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+
+    if (i >= ne) {
+        return;
+    }
+
+    const int i03 = i/(ne00 * ne01 * ne02);
+    const int i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
+    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+    const int i13 = i/(ne10 * ne11 * ne12);
+    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+
+    cpy_blck(cx + x_offset, cdst + dst_offset);
+}
+
+static void ggml_cpy_f16_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_f16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_q8_0_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK8_0 == 0);
+    const int num_blocks = ne / QK8_0;
+    cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_q4_0_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK4_0 == 0);
+    const int num_blocks = ne / QK4_0;
+    cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_q4_1_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK4_1 == 0);
+    const int num_blocks = ne / QK4_1;
+    cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_q5_0_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK5_0 == 0);
+    const int num_blocks = ne / QK5_0;
+    cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_q5_1_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK5_1 == 0);
+    const int num_blocks = ne / QK5_1;
+    cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_iq4_nl_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK4_NL == 0);
+    const int num_blocks = ne / QK4_NL;
+    cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f16_f16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
+    const int64_t ne = ggml_nelements(src0);
+    GGML_ASSERT(ne == ggml_nelements(src1));
+
+    GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
+    GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+
+    //GGML_ASSERT(src0->ne[3] == 1);
+
+    const int64_t nb00 = src0->nb[0];
+    const int64_t nb01 = src0->nb[1];
+    const int64_t nb02 = src0->nb[2];
+    const int64_t nb03 = src0->nb[3];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+
+    //GGML_ASSERT(src1->ne[3] == 1);
+
+    const int64_t nb10 = src1->nb[0];
+    const int64_t nb11 = src1->nb[1];
+    const int64_t nb12 = src1->nb[2];
+    const int64_t nb13 = src1->nb[3];
+
+    cudaStream_t main_stream = ctx.stream();
+
+    char * src0_ddc = (char *) src0->data;
+    char * src1_ddc = (char *) src1->data;
+
+    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
+        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
+        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
+        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else {
+        fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
+                ggml_type_name(src0->type), ggml_type_name(src1->type));
+        GGML_ABORT("fatal error");
+    }
+}
+
+void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    ggml_cuda_cpy(ctx, src0, dst);
+}
+
+void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
+    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
+        return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        // return the same kernel that ggml_cpy_f16_f16_cuda launches above
+        return (void*) cpy_f32_f16<cpy_1_f16_f16>;
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_f32_f16<cpy_1_f16_f32>;
+    } else {
+        fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
+                ggml_type_name(src0->type), ggml_type_name(src1->type));
+        GGML_ABORT("fatal error");
+    }
+}
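
For reference, the quantizing copy launchers above assert `ne % QK == 0` and launch one CUDA block per quantized block, so each grid block converts exactly one run of `QK` floats. A minimal CPU sketch of the F32 to Q8_0 case, following the `d = amax/127` convention of `cpy_blck_f32_q8_0` (`BlockQ8_0Ref` is a simplified stand-in for `block_q8_0`, whose real `d` field is a half):

```cpp
#include <cmath>
#include <cstdint>

// Simplified stand-in for block_q8_0: 32 weights kept as one scale plus
// 32 signed bytes, with x[i] ~= d * qs[i].
struct BlockQ8_0Ref {
    float  d;
    int8_t qs[32];
};

// What one grid block of cpy_f32_q<cpy_blck_f32_q8_0, QK8_0> computes:
// quantize one contiguous run of 32 floats with d = amax / 127.
static void cpy_blck_f32_q8_0_ref(const float * x, BlockQ8_0Ref * dst) {
    float amax = 0.0f;
    for (int i = 0; i < 32; ++i) {
        amax = std::fmax(amax, std::fabs(x[i]));
    }
    const float d  = amax / 127.0f;
    const float id = d != 0.0f ? 1.0f / d : 0.0f;
    dst->d = d;
    for (int i = 0; i < 32; ++i) {
        dst->qs[i] = (int8_t) std::roundf(x[i] * id);
    }
}
```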

+ 35 - 0
llama/ggml-cuda/cpy.cuh

@@ -0,0 +1,35 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_CPY_BLOCK_SIZE 32
+
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
+
+void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);

+ 132 - 0
llama/ggml-cuda/cross-entropy-loss.cu

@@ -0,0 +1,132 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+#include "cross-entropy-loss.cuh"
+#include "sumrows.cuh"
+
+#include <cmath>
+#include <cstdint>
+
+static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) {
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+    const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE;
+
+    const int ne_tmp = WARP_SIZE*nclasses;
+
+    extern __shared__ float tmp_all[];
+    float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp;
+    float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp;
+
+    // Each warp first loads ne_tmp logits/labels into shared memory:
+    for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) {
+        const int ig = i0*nclasses + i; // ig == i global
+
+        tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f;
+        tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f;
+    }
+
+    // Each thread in the warp then calculates the cross entropy loss for a single row.
+    // TODO: pad in order to avoid shared memory bank conflicts.
+
+    // Find maximum for softmax:
+    float max = -INFINITY;
+    for (int i = 0; i < nclasses; ++i) {
+        max = fmaxf(max, tmp_logits[lane_id*nclasses + i]);
+    }
+
+    // Calculate log(softmax(logits)) which is just logits - max:
+    float sum = 0.0f;
+    for (int i = 0; i < nclasses; ++i) {
+        float val = tmp_logits[lane_id*nclasses + i] - max;
+        sum += expf(val);
+        tmp_logits[lane_id*nclasses + i] = val;
+    }
+    sum = logf(sum);
+
+    // log(exp(logits - max) / sum) = (logits - max) - log(sum)
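+    // Worked example with illustrative numbers: logits = {1, 2} -> max = 2,
+    // shifted = {-1, 0}, sum = e^-1 + e^0 ~= 1.3679, log(sum) ~= 0.3133, so
+    // log_softmax = {-1.3133, -0.3133}; with labels = {0, 1} the loop below
+    // accumulates -0.3133, which is negated and scaled by 1/k afterwards.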
+    float loss = 0.0f;
+    for (int i = 0; i < nclasses; ++i) {
+        loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i];
+    }
+    loss = -warp_reduce_sum(loss) / (float)k;
+
+    __syncthreads();
+
+    if (lane_id == 0) {
+        tmp_all[warp_id] = loss;
+    }
+
+    __syncthreads();
+
+    if (warp_id != 0) {
+        return;
+    }
+
+    loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f;
+    loss = warp_reduce_sum(loss);
+
+    if (lane_id != 0) {
+        return;
+    }
+
+    dst[blockIdx.x] = loss;
+}
+
+void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    const int64_t ne00  = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float       * dst_d  = (float       *) dst->data;
+
+    ggml_cuda_pool & pool = ctx.pool();
+    cudaStream_t stream = ctx.stream();
+
+    const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
+    const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
+    const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float);
+
+    ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
+
+    cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
+
+    // Combine results from individual blocks:
+    sum_rows_f32_cuda(dst_tmp.ptr, dst_d, blocks_num.x, 1, stream);
+}
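
As a cross-check on the kernel above, a minimal CPU reference of the same quantity (a sketch, not a drop-in API): the mean over rows of minus the label-weighted log-softmax of the logits.

```cpp
#include <cmath>
#include <cstddef>

// Reference for ggml_cuda_cross_entropy_loss: logits and labels are
// [nrows x nclasses] row-major; the result is a single scalar averaged over rows.
static float cross_entropy_loss_ref(const float * logits, const float * labels,
                                    int nclasses, int nrows) {
    double total = 0.0;
    for (int r = 0; r < nrows; ++r) {
        const float * x = logits + (size_t) r * nclasses;
        const float * y = labels + (size_t) r * nclasses;

        // Max-shift for numerical stability, as in the kernel.
        float maxv = -INFINITY;
        for (int c = 0; c < nclasses; ++c) maxv = std::fmax(maxv, x[c]);

        double sum = 0.0;
        for (int c = 0; c < nclasses; ++c) sum += std::exp(x[c] - maxv);
        const double logsum = std::log(sum);

        double row_loss = 0.0;
        for (int c = 0; c < nclasses; ++c) {
            row_loss += ((double) x[c] - maxv - logsum) * y[c]; // log_softmax * label
        }
        total -= row_loss;
    }
    return (float) (total / nrows); // same 1/k normalization as the kernel
}
```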

+ 31 - 0
llama/ggml-cuda/cross-entropy-loss.cuh

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256
+
+void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 129 - 0
llama/ggml-cuda/dequantize.cuh

@@ -0,0 +1,129 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+
+    const dfloat d = x[ib].d;
+
+    const int vui = x[ib].qs[iqs];
+
+    v.x = vui & 0xF;
+    v.y = vui >> 4;
+
+#ifdef GGML_CUDA_F16
+    v = __hsub2(v, {8.0f, 8.0f});
+    v = __hmul2(v, {d, d});
+#else
+    v.x = (v.x - 8.0f) * d;
+    v.y = (v.y - 8.0f) * d;
+#endif // GGML_CUDA_F16
+}
+
+static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+    const block_q4_1 * x = (const block_q4_1 *) vx;
+
+    const dfloat d = __low2half(x[ib].dm);
+    const dfloat m = __high2half(x[ib].dm);
+
+    const int vui = x[ib].qs[iqs];
+
+    v.x = vui & 0xF;
+    v.y = vui >> 4;
+
+#ifdef GGML_CUDA_F16
+    v = __hmul2(v, {d, d});
+    v = __hadd2(v, {m, m});
+#else
+    v.x = (v.x * d) + m;
+    v.y = (v.y * d) + m;
+#endif // GGML_CUDA_F16
+}
+
+static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+    const block_q5_0 * x = (const block_q5_0 *) vx;
+
+    const dfloat d = x[ib].d;
+
+    uint32_t qh;
+    memcpy(&qh, x[ib].qh, sizeof(qh));
+
+    const int xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
+    const int xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
+
+    v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
+    v.y = ((x[ib].qs[iqs] >>  4) | xh_1);
+
+#ifdef GGML_CUDA_F16
+    v = __hsub2(v, {16.0f, 16.0f});
+    v = __hmul2(v, {d, d});
+#else
+    v.x = (v.x - 16.0f) * d;
+    v.y = (v.y - 16.0f) * d;
+#endif // GGML_CUDA_F16
+}
+
+static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+    const block_q5_1 * x = (const block_q5_1 *) vx;
+
+    const dfloat d = __low2half(x[ib].dm);
+    const dfloat m = __high2half(x[ib].dm);
+
+    uint32_t qh;
+    memcpy(&qh, x[ib].qh, sizeof(qh));
+
+    const int xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
+    const int xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
+
+    v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
+    v.y = ((x[ib].qs[iqs] >>  4) | xh_1);
+
+#ifdef GGML_CUDA_F16
+    v = __hmul2(v, {d, d});
+    v = __hadd2(v, {m, m});
+#else
+    v.x = (v.x * d) + m;
+    v.y = (v.y * d) + m;
+#endif // GGML_CUDA_F16
+}
+
+static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+    const block_q8_0 * x = (const block_q8_0 *) vx;
+
+    const dfloat d = x[ib].d;
+
+    v.x = x[ib].qs[iqs + 0];
+    v.y = x[ib].qs[iqs + 1];
+
+#ifdef GGML_CUDA_F16
+    v = __hmul2(v, {d, d});
+#else
+    v.x *= d;
+    v.y *= d;
+#endif // GGML_CUDA_F16
+}
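
Restated as scalar host code, the affine conventions these helpers implement (a sketch over simplified block fields, not the packed GPU layouts): q4_0 is `v = d*(q - 8)`, q4_1 is `v = d*q + m`, q5_0 and q5_1 extend the q4 variants with a fifth bit pulled from `qh`, and q8_0 is `v = d*q`.

```cpp
#include <cstdint>

struct DequantPair { float lo, hi; };

// Scalar restatement of dequantize_q4_0 for one packed byte. Each byte holds
// two 4-bit quants: the low nibble is value iqs of the block, the high nibble
// is value iqs + qk/2 (the paired v.x / v.y above).
static DequantPair dequant_q4_0_byte(uint8_t byte, float d) {
    // v = d * (q - 8): quants 0..15 map to the symmetric range [-8d, 7d]
    return { ((byte & 0x0F) - 8.0f) * d, ((byte >> 4) - 8.0f) * d };
}

// Scalar restatement of dequantize_q4_1.
static DequantPair dequant_q4_1_byte(uint8_t byte, float d, float m) {
    // v = d * q + m: quants 0..15 map to [m, m + 15d]
    return { (byte & 0x0F) * d + m, (byte >> 4) * d + m };
}

// Scalar restatement of dequantize_q8_0 for one signed byte.
static float dequant_q8_0_byte(int8_t q, float d) {
    return q * d; // v = d * q, q in [-128, 127]
}
```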

+ 66 - 0
llama/ggml-cuda/diagmask.cu

@@ -0,0 +1,66 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "diagmask.cuh"
+
+static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
+    const int col = blockDim.y*blockIdx.y + threadIdx.y;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (col >= ncols) {
+        return;
+    }
+
+    const int i = row*ncols + col;
+    //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
+    //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
+    dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
+}
+
+static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
+    const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
+    const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
+    const dim3 block_nums(nrows_x, block_num_x, 1);
+    diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
+}
+
+void ggml_cuda_op_diag_mask_inf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int nrows0 = ggml_nrows(src0);
+
+    const int n_past = ((int32_t *) dst->op_params)[0];
+
+    diag_mask_inf_f32_cuda(src0_d, dst_d, ne00, nrows0, ne01, n_past, stream);
+}
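
A CPU restatement of the masking rule (a sketch): columns past `n_past` plus the row's position within its channel are pushed down by `FLT_MAX` so they vanish under a later softmax; `rows_per_channel` is `ne01`, as passed by the host wrapper above.

```cpp
#include <cfloat>

// Reference for diag_mask_inf_f32: x and dst are [nrows x ncols] row-major,
// and rows of the same channel (head) share the same causal offset.
static void diag_mask_inf_ref(const float * x, float * dst,
                              int ncols, int nrows, int rows_per_channel, int n_past) {
    for (int row = 0; row < nrows; ++row) {
        for (int col = 0; col < ncols; ++col) {
            const int i = row * ncols + col;
            // Same arithmetic effect as the kernel's multiply-by-predicate form:
            // subtracting FLT_MAX is equivalent to -INFINITY within rounding error.
            dst[i] = x[i] - (col > n_past + row % rows_per_channel ? FLT_MAX : 0.0f);
        }
    }
}
```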

+ 31 - 0
llama/ggml-cuda/diagmask.cuh

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
+
+void ggml_cuda_op_diag_mask_inf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 709 - 0
llama/ggml-cuda/dmmv.cu

@@ -0,0 +1,709 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "dmmv.cuh"
+#include "dequantize.cuh"
+#include "convert.cuh"
+
+#ifndef K_QUANTS_PER_ITERATION
+#define K_QUANTS_PER_ITERATION 2
+#else
+static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
+#endif
+
+static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
+
+    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
+
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    if (row >= nrows) return;
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q2_K * x = (const block_q2_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
+
+    const int step = 16/K_QUANTS_PER_ITERATION;
+
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
+
+    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15 or 0...14 in steps of 2
+    const int q_offset = 32*im + l0;
+    const int s_offset = 8*im;
+    const int y_offset = 128*im + l0;
+
+    uint32_t aux[4];
+    const uint8_t * d = (const uint8_t *)aux;
+    const uint8_t * m = (const uint8_t *)(aux + 2);
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const float   * y = yy + i * QK_K + y_offset;
+        const uint8_t * q = x[i].qs + q_offset;
+
+        const float dall = __low2half(x[i].dm);
+        const float dmin = __high2half(x[i].dm);
+
+        const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
+        aux[0] = a[0] & 0x0f0f0f0f;
+        aux[1] = a[1] & 0x0f0f0f0f;
+        aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
+        aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
+
+        float sum1 = 0, sum2 = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
+                  + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
+                  + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
+                  + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
+                  + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
+                  + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
+                  + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
+                  +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
+            sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
+                  + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
+
+        }
+        tmp += dall * sum1 - dmin * sum2;
+
+    }
+
+    // sum up partial sums and write back result
+    tmp = warp_reduce_sum(tmp);
+
+    if (threadIdx.x == 0) {
+        dst[row] = tmp;
+    }
+}
+
+static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
+
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    if (row >= nrows) return;
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q3_K * x = (const block_q3_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
+
+    const int n  = K_QUANTS_PER_ITERATION;               // iterations in the inner loop
+    const int step = 16/K_QUANTS_PER_ITERATION;
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
+
+    const uint8_t m = 1 << (4*im);
+
+    const int l0 = n*in;                                 // 0...15 or 0...14 in steps of 2
+    const int q_offset =  32*im + l0;
+    const int y_offset = 128*im + l0;
+
+    uint16_t utmp[4];
+    const int8_t * s = (const int8_t *)utmp;
+
+    const uint16_t s_shift = 4*im;
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const float   * y  = yy + i * QK_K + y_offset;
+        const uint8_t * q = x[i].qs + q_offset;
+        const uint8_t * h = x[i].hmask + l0;
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
+        utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
+        utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
+        utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
+
+        const float d = x[i].d;
+
+        float sum = 0;
+        for (int l = 0; l < n; ++l) {
+            sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
+                 + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
+                 + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
+                 + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
+            sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
+                 + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
+                 + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
+                + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
+        }
+        tmp += d * sum;
+
+    }
+
+    // sum up partial sums and write back result
+    tmp = warp_reduce_sum(tmp);
+
+    if (threadIdx.x == 0) {
+        dst[row] = tmp;
+    }
+}
+
+static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
+
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    if (row >= nrows) return;
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q4_K * x = (const block_q4_K *)vx + ib0;
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
+
+    const int step = 8/K_QUANTS_PER_ITERATION;           // 8 or 4
+
+    const int il  = tid/step;                            // 0...3
+    const int ir  = tid - step*il;                       // 0...7 or 0...3
+    const int n   = 2 * K_QUANTS_PER_ITERATION;          // 2 or 4
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    uint16_t aux[4];
+    const uint8_t * sc = (const uint8_t *)aux;
+
+#if K_QUANTS_PER_ITERATION == 2
+    uint32_t q32[4];
+    const uint8_t * q4 = (const uint8_t *)q32;
+#else
+    uint16_t q16[4];
+    const uint8_t * q4 = (const uint8_t *)q16;
+#endif
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const float   * y1 = yy + i*QK_K + y_offset;
+        const float   * y2 = y1 + 128;
+
+        const float dall = __low2half(x[i].dm);
+        const float dmin = __high2half(x[i].dm);
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux[0] = a[im+0] & kmask1;
+        aux[1] = a[im+2] & kmask1;
+        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+#if K_QUANTS_PER_ITERATION == 2
+        const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
+        const uint32_t * q2 = q1 + 16;
+
+        q32[0] = q1[0] & 0x0f0f0f0f;
+        q32[1] = q1[0] & 0xf0f0f0f0;
+        q32[2] = q2[0] & 0x0f0f0f0f;
+        q32[3] = q2[0] & 0xf0f0f0f0;
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < 4; ++l) {
+            s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4];
+            s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12];
+            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+        }
+        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
+#else
+        const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
+        const uint16_t * q2 = q1 + 32;
+
+        q16[0] = q1[0] & 0x0f0f;
+        q16[1] = q1[0] & 0xf0f0;
+        q16[2] = q2[0] & 0x0f0f;
+        q16[3] = q2[0] & 0xf0f0;
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < 2; ++l) {
+            s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
+            s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
+            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+        }
+        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
+#endif
+
+    }
+
+    // sum up partial sums and write back result
+    tmp = warp_reduce_sum(tmp);
+
+    if (threadIdx.x == 0) {
+        dst[row] = tmp;
+    }
+}
+
+static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) {
+
+    const int row = blockIdx.x;
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q5_K * x = (const block_q5_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int tid = threadIdx.x/2;  // 0...15
+    const int ix  = threadIdx.x%2;
+
+    const int il  = tid/4;     // 0...3
+    const int ir  = tid - 4*il;// 0...3
+    const int n   = 2;
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    const uint8_t hm1  = 1 << (2*im);
+    const uint8_t hm2  = hm1 << 4;
+
+    uint16_t aux[4];
+    const uint8_t * sc = (const uint8_t *)aux;
+
+    uint16_t q16[8];
+    const uint8_t * q4 = (const uint8_t *)q16;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
+
+        const uint8_t * ql1 = x[i].qs + q_offset;
+        const uint8_t * qh  = x[i].qh + l0;
+        const float   * y1  = yy + i*QK_K + y_offset;
+        const float   * y2  = y1 + 128;
+
+        const float dall = __low2half(x[i].dm);
+        const float dmin = __high2half(x[i].dm);
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux[0] = a[im+0] & kmask1;
+        aux[1] = a[im+2] & kmask1;
+        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+        float4 sum = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        const uint16_t * q1 = (const uint16_t *)ql1;
+        const uint16_t * q2 = q1 + 32;
+        q16[0] = q1[0] & 0x0f0f;
+        q16[1] = q1[8] & 0x0f0f;
+        q16[2] = (q1[0] >> 4) & 0x0f0f;
+        q16[3] = (q1[8] >> 4) & 0x0f0f;
+        q16[4] = q2[0] & 0x0f0f;
+        q16[5] = q2[8] & 0x0f0f;
+        q16[6] = (q2[0] >> 4) & 0x0f0f;
+        q16[7] = (q2[8] >> 4) & 0x0f0f;
+        for (int l = 0; l < n; ++l) {
+            sum.x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+                   + y1[l+16] * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0));
+            sum.y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+                   + y1[l+48] * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0));
+            sum.z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+                   + y2[l+16] * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0));
+            sum.w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+                   + y2[l+48] * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0));
+            smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
+                  + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
+        }
+        tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
+    }
+
+    // sum up partial sums and write back result
+    tmp = warp_reduce_sum(tmp);
+
+    if (threadIdx.x == 0) {
+        dst[row] = tmp;
+    }
+}
+
+static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
+
+    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
+
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    if (row >= nrows) return;
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q6_K * x = (const block_q6_K *)vx + ib0;
+
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
+
+    const int step = 16/K_QUANTS_PER_ITERATION;          // 16 or 8
+
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
+
+#if K_QUANTS_PER_ITERATION == 1
+    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15
+    const int is = 0;
+#else
+    const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
+    const int is = in / 4;
+#endif
+    const int ql_offset = 64*im + l0;
+    const int qh_offset = 32*im + l0;
+    const int s_offset  =  8*im + is;
+    const int y_offset = 128*im + l0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const float   * y  = yy + i * QK_K + y_offset;
+        const uint8_t * ql = x[i].ql + ql_offset;
+        const uint8_t * qh = x[i].qh + qh_offset;
+        const int8_t  * s  = x[i].scales + s_offset;
+
+        const float d = x[i].d;
+
+#if K_QUANTS_PER_ITERATION == 1
+        float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+                  + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+                  + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+                  + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+                  + y[64] * s[4] * d * ((int8_t)((ql[ 0]  >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+                  + y[80] * s[5] * d * ((int8_t)((ql[16]  >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+                  + y[96] * s[6] * d * ((int8_t)((ql[32]  >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+                  +y[112] * s[7] * d * ((int8_t)((ql[48]  >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
+        tmp += sum;
+#else
+        float sum = 0;
+        for (int l = 0; l < 4; ++l) {
+            sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+                 + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
+                 + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
+                 + y[l+96] * s[6] * d * ((int8_t)((ql[l+32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
+        }
+        tmp += sum;
+#endif
+
+    }
+
+    // sum up partial sums and write back result
+    tmp = warp_reduce_sum(tmp);
+
+    if (threadIdx.x == 0) {
+        dst[row] = tmp;
+    }
+}
+
+static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+    const half * x = (const half *) vx;
+
+    // automatic half -> float type cast if dfloat == float
+    v.x = x[ib + iqs + 0];
+    v.y = x[ib + iqs + 1];
+}
+
+static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
+    return type == GGML_TYPE_Q4_0 ? dequantize_q4_0 :
+        type == GGML_TYPE_Q4_1 ? dequantize_q4_1 :
+        type == GGML_TYPE_Q5_0 ? dequantize_q5_0 :
+        type == GGML_TYPE_Q5_1 ? dequantize_q5_1 :
+        type == GGML_TYPE_Q8_0 ? dequantize_q8_0 :
+        type == GGML_TYPE_F16 ? convert_f16 :
+        nullptr;
+}
+
+template <ggml_type type>
+static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
+    constexpr int qk = ggml_cuda_type_traits<type>::qk; // quantized weights per x block
+    constexpr int qr = ggml_cuda_type_traits<type>::qr; // number of quantized weights per data value in x block
+    constexpr dequantize_kernel_t dequantize_kernel = get_dequantize_kernel(type);
+
+    const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y;
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int tid = threadIdx.x;
+
+    const int iter_stride = 2*GGML_CUDA_DMMV_X;
+    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
+    const int y_offset = qr == 1 ? 1 : qk/2;
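+    // Concrete case for intuition: with q4_0, qk = 32 and qr = 2, so each byte
+    // packs two weights and y_offset = qk/2 = 16; the dequantized pair (v.x, v.y)
+    // below multiplies y positions iqs + j/qr and iqs + j/qr + 16 of the same block.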
+
+// partial sum for each thread
+#ifdef GGML_CUDA_F16
+    half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
+#else
+    float tmp = 0.0f;
+#endif // GGML_CUDA_F16
+
+    for (int i = 0; i < ncols; i += iter_stride) {
+        const int col = i + vals_per_iter*tid;
+        const int64_t ib = ((int64_t)row*ncols + col)/qk; // x block index
+        const int iqs = (col%qk)/qr; // x quant index
+        const int iybs = col - col%qk; // y block start index
+
+// processing >2 values per i iter is faster for fast GPUs
+#pragma unroll
+        for (int j = 0; j < vals_per_iter; j += 2) {
+            // process 2 vals per j iter
+
+            // dequantize
+            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
+            dfloat2 v;
+            dequantize_kernel(vx, ib, iqs + j/qr, v);
+
+            // matrix multiplication
+            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+#ifdef GGML_CUDA_F16
+            tmp += __hmul2(v, {
+                y[iybs + iqs + j/qr + 0],
+                y[iybs + iqs + j/qr + y_offset]
+            });
+#else
+            tmp += v.x * y[iybs + iqs + j/qr + 0];
+            tmp += v.y * y[iybs + iqs + j/qr + y_offset];
+#endif // GGML_CUDA_F16
+        }
+    }
+
+    // sum up partial sums and write back result
+    tmp = warp_reduce_sum(tmp);
+
+    if (tid == 0) {
+#ifdef GGML_CUDA_F16
+        dst[row] = tmp.x + tmp.y;
+#else
+        dst[row] = tmp;
+#endif // GGML_CUDA_F16
+    }
+}
+
+static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    dequantize_mul_mat_vec<GGML_TYPE_Q4_0>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    dequantize_mul_mat_vec<GGML_TYPE_Q4_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    dequantize_mul_mat_vec<GGML_TYPE_Q5_0>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    dequantize_mul_mat_vec<GGML_TYPE_Q5_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    dequantize_mul_mat_vec<GGML_TYPE_Q8_0>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(32, ny, 1);
+    dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(32, ny, 1);
+    dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(32, ny, 1);
+    dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const dim3 block_dims(32, 1, 1);
+    dequantize_mul_mat_vec_q5_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
+}
+
+static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(32, ny, 1);
+    dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    dequantize_mul_mat_vec<GGML_TYPE_F16>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+void ggml_cuda_op_dequantize_mul_mat_vec(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
+    GGML_UNUSED(ctx);
+    const int64_t ne00 = src0->ne[0];
+    const int64_t row_diff = row_high - row_low;
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
+#ifdef GGML_CUDA_F16
+    ggml_cuda_pool_alloc<half> src1_dfloat_a(ctx.pool());
+    half * src1_dfloat = nullptr; // dfloat == half
+
+    bool src1_convert_f16 =
+        src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
+        src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
+        src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
+
+    if (src1_convert_f16) {
+        src1_dfloat = src1_dfloat_a.alloc(ne00);
+        const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+        GGML_ASSERT(to_fp16_cuda != nullptr);
+        to_fp16_cuda(src1_ddf_i, src1_dfloat, ne00, stream);
+    }
+#else
+    const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
+#endif // GGML_CUDA_F16
+
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            dequantize_mul_mat_vec_q4_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q4_1:
+            dequantize_mul_mat_vec_q4_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            dequantize_mul_mat_vec_q5_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            dequantize_mul_mat_vec_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            dequantize_mul_mat_vec_q8_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q2_K:
+            dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q3_K:
+            dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            dequantize_mul_mat_vec_q4_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q5_K:
+            dequantize_mul_mat_vec_q5_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_Q6_K:
+            dequantize_mul_mat_vec_q6_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_F16:
+            convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+
+    GGML_UNUSED(src1);
+    GGML_UNUSED(dst);
+    GGML_UNUSED(src1_ddq_i);
+    GGML_UNUSED(src1_ncols);
+    GGML_UNUSED(src1_padded_row_size);
+}
+
+bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) {
+    return src0_type == GGML_TYPE_Q4_0 || src0_type == GGML_TYPE_Q4_1 ||
+        src0_type == GGML_TYPE_Q5_0 || src0_type == GGML_TYPE_Q5_1 ||
+        src0_type == GGML_TYPE_Q8_0 || src0_type == GGML_TYPE_Q2_K ||
+        src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q4_K ||
+        src0_type == GGML_TYPE_Q5_K || src0_type == GGML_TYPE_Q6_K ||
+        src0_type == GGML_TYPE_F16;
+}
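
All of the kernels in this file compute the same operation, `dst[row] = dot(dequantize(row of x), y)`, and differ only in block layout and thread mapping. A minimal CPU reference for the q8_0 case (a sketch; `BlockQ8_0Ref` simplifies `block_q8_0`, whose real `d` field is a half):

```cpp
#include <cstdint>

// Simplified q8_0 block: 32 weights stored as one scale plus 32 signed bytes.
struct BlockQ8_0Ref {
    float  d;
    int8_t qs[32];
};

// Reference for dequantize_mul_mat_vec with GGML_TYPE_Q8_0:
// rows of A are stored block by block, dst[row] = dot(dequantize(A_row), y).
static void dmmv_q8_0_ref(const BlockQ8_0Ref * A, const float * y, float * dst,
                          int ncols, int nrows) {
    const int blocks_per_row = ncols / 32; // the launchers assert divisibility
    for (int row = 0; row < nrows; ++row) {
        float acc = 0.0f;
        for (int b = 0; b < blocks_per_row; ++b) {
            const BlockQ8_0Ref & blk = A[row * blocks_per_row + b];
            for (int j = 0; j < 32; ++j) {
                acc += blk.d * blk.qs[j] * y[b * 32 + j]; // dequantize, then multiply
            }
        }
        dst[row] = acc;
    }
}
```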

+ 46 - 0
llama/ggml-cuda/dmmv.cuh

@@ -0,0 +1,46 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+// dmmv = dequantize_mul_mat_vec
+
+// TODO: remove this?
+#ifndef GGML_CUDA_DMMV_X
+#define GGML_CUDA_DMMV_X 32
+#endif
+
+#ifndef GGML_CUDA_MMV_Y
+#define GGML_CUDA_MMV_Y 1
+#endif
+
+void ggml_cuda_op_dequantize_mul_mat_vec(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream);
+
+bool ggml_cuda_dmmv_type_supported(ggml_type src0_type);

+ 734 - 0
llama/ggml-cuda/fattn-common.cuh

@@ -0,0 +1,734 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "common.cuh"
+#include "convert.cuh"
+#include "vecdotq.cuh"
+
+#include <cstdint>
+
+#define FATTN_KQ_STRIDE       256
+#define HALF_MAX_HALF         __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
+#define SOFTMAX_FTZ_THRESHOLD -20.0f                   // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
+
+typedef void (* fattn_kernel_t)(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const float logit_softcap,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3);
+
+typedef half (*vec_dot_KQ_f16_t)(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
+typedef float (*vec_dot_KQ_f32_t)(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
+
+template<typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+    const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    T sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI4_0;
+        const int shift = k_KQ & (QI8_1/2);
+
+        const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+        const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+#ifdef FP16_AVAILABLE
+        if (std::is_same<T, half>::value) {
+            const half2  * Q_ds = (const half2  *) Q_ds_v;
+
+            const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
+        } else
+#endif // FP16_AVAILABLE
+        {
+            const float2 * Q_ds = (const float2 *) Q_ds_v;
+
+            sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+        }
+    }
+
+    return sum;
+}
+
+template<typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+    const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    T sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI4_1;
+        const int shift = k_KQ & (QI8_1/2);
+
+        const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+        const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+#ifdef FP16_AVAILABLE
+        if (std::is_same<T, half>::value) {
+            const half2  * Q_ds = (const half2  *) Q_ds_v;
+
+            const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
+            sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
+        } else
+#endif // FP16_AVAILABLE
+        {
+            const float2 * Q_ds = (const float2 *) Q_ds_v;
+
+            const float sumid4d8   =  __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
+            const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+
+            sum += (T) (sumid4d8 + m4s8scaled);
+        }
+    }
+
+    return sum;
+}
+
+template<typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+    const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    T sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI5_0;
+        const int iqs8  = k_KQ %  QI8_1;
+        const int shift = k_KQ & (QI8_1/2);
+
+        int v = (get_int_b2(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int vh = get_int_b2(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
+        v |= (vh <<  4) & 0x00000010; // 0 ->  4
+        v |= (vh << 11) & 0x00001000; // 1 -> 12
+        v |= (vh << 18) & 0x00100000; // 2 -> 20
+        v |= (vh << 25) & 0x10000000; // 3 -> 28
+
+        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+        const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+#ifdef FP16_AVAILABLE
+        if (std::is_same<T, half>::value) {
+            const half2  * Q_ds = (const half2  *) Q_ds_v;
+
+            const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
+        } else
+#endif // FP16_AVAILABLE
+        {
+            const float2 * Q_ds = (const float2 *) Q_ds_v;
+
+            sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+        }
+    }
+
+    return sum;
+}
+
+template<typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+    const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    T sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI5_1;
+        const int iqs8  = k_KQ %  QI8_1;
+        const int shift = k_KQ & (QI8_1/2);
+
+        int v = (get_int_b2(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int vh = get_int_b2(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
+        v |= (vh <<  4) & 0x00000010; // 0 ->  4
+        v |= (vh << 11) & 0x00001000; // 1 -> 12
+        v |= (vh << 18) & 0x00100000; // 2 -> 20
+        v |= (vh << 25) & 0x10000000; // 3 -> 28
+
+        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+        const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+#ifdef FP16_AVAILABLE
+        if (std::is_same<T, half>::value) {
+            const half2  * Q_ds = (const half2  *) Q_ds_v;
+
+            const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
+            sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
+        } else
+#endif // FP16_AVAILABLE
+        {
+            const float2 * Q_ds = (const float2 *) Q_ds_v;
+
+            const float sumid5d8   =  __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
+            const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+
+            sum += (T) (sumid5d8 + m5s8scaled);
+        }
+    }
+
+    return sum;
+}
+
+template <typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+    const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    T sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const int ib  = k_KQ / QI8_0;
+        const int iqs = k_KQ % QI8_0;
+
+        const int v = get_int_b2(K_q8_0[ib].qs, iqs);
+
+        T Q_d;
+        if (std::is_same<T, half>::value) {
+            const half2  * Q_ds = (const half2  *) Q_ds_v;
+            Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]);
+        } else {
+            const float2 * Q_ds = (const float2 *) Q_ds_v;
+            Q_d = Q_ds[k_KQ_0/WARP_SIZE].x;
+        }
+
+        sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d);
+    }
+
+    return sum;
+}
+
+template <typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
+
+    const half2 * K_h2 = (const half2 *) K_c;
+    GGML_UNUSED(Q_q8);
+    GGML_UNUSED(Q_ds_v);
+
+#ifdef FP16_AVAILABLE
+    if (std::is_same<T, half>::value) {
+        const half2 * Q_h2 = (const half2 *) Q_v;
+
+        half2 sum2 = make_half2(0.0f, 0.0f);
+
+#pragma unroll
+        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+            const int k_KQ = k_KQ_0 + threadIdx.x;
+
+            const half2 K_ik = K_h2[k_KQ];
+            sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
+        }
+
+        return __low2half(sum2) + __high2half(sum2);
+    }
+#endif // FP16_AVAILABLE
+
+    const float2 * Q_f2 = (const float2 *) Q_v;
+
+    float sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const half2 K_ik = K_h2[k_KQ];
+        sum +=  __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x;
+        sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y;
+    }
+
+    return sum;
+}
+
+template <typename Tds>
+static __device__ __forceinline__ void quantize_q8_1_to_shared(
+    const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) {
+
+    float vals[sizeof(int)] = {0.0f};
+#pragma unroll
+    for (int l = 0; l < sizeof(int); ++l) {
+        vals[l] = scale * x[4*threadIdx.x + l];
+    }
+
+    float amax = fabsf(vals[0]);
+    float sum  = vals[0];
+#pragma unroll
+    for (int l = 1; l < sizeof(int); ++l) {
+        amax = fmaxf(amax, fabsf(vals[l]));
+        sum += vals[l];
+    }
+#pragma unroll
+    for (int mask = QI8_1/2; mask > 0; mask >>= 1) {
+        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32));
+        sum +=             __shfl_xor_sync(0xFFFFFFFF, sum,  mask, 32);
+    }
+
+    const float d = amax / 127;
+    int q32 = 0;
+    int8_t * q8 = (int8_t *) &q32;
+
+    if (d != 0.0f) {
+#pragma unroll
+        for (int l = 0; l < sizeof(int); ++l) {
+            q8[l] = roundf(vals[l] / d);
+        }
+    }
+
+    yq32[threadIdx.x] = q32;
+    if (threadIdx.x % QI8_1 == 0) {
+        if (std::is_same<Tds, half2>::value) {
+            ((half2  *) yds)[threadIdx.x/QI8_1] =  make_half2(d, sum);
+        } else {
+            ((float2 *) yds)[threadIdx.x/QI8_1] = make_float2(d, sum);
+        }
+    }
+}
+
+typedef half  (*dequantize_1_f16_t)(const void *, const int64_t);
+typedef float (*dequantize_1_f32_t)(const void *, const int64_t);
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ vx, const int64_t i) {
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+
+    const int64_t ib    =  i          /  QK4_0;
+    const int     iqs   =  i          % (QK4_0/2);
+    const int     shift = (i % QK4_0) / (QK4_0/2);
+
+    const T   d  = x[ib].d;
+    const int q0 = x[ib].qs[iqs];
+    const int q  = ((q0 >> (4*shift)) & 0x0F) - 8;
+
+#ifdef FP16_AVAILABLE
+    if (std::is_same<T, half>::value) {
+        return ((half) d)*((half) q);
+    }
+#endif // FP16_AVAILABLE
+
+    return ((float) d)*((float) q);
+}
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) {
+    const block_q4_1 * x = (const block_q4_1 *) vx;
+
+    const int64_t ib    =  i          /  QK4_1;
+    const int     iqs   =  i          % (QK4_1/2);
+    const int     shift = (i % QK4_1) / (QK4_1/2);
+
+    const half2 dm = x[ib].dm;
+    const int   q0 = x[ib].qs[iqs];
+    const int   q  = ((q0 >> (4*shift)) & 0x0F);
+
+#ifdef FP16_AVAILABLE
+    if (std::is_same<T, half>::value) {
+        return __low2half(dm)*((half) q) + __high2half(dm);
+    }
+#endif // FP16_AVAILABLE
+
+    return __low2float(dm)*((float) q) + __high2float(dm);
+}
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ vx, const int64_t i) {
+    const block_q5_0 * x = (const block_q5_0 *) vx;
+
+    const int64_t ib    =  i          /  QK5_0;
+    const int     idq   =  i          %  QK5_0;
+    const int     iqs   =  i          % (QK5_0/2);
+    const int     shift = (i % QK5_0) / (QK5_0/2);
+
+    const T   d   = x[ib].d;
+    const int ql0 = x[ib].qs[iqs];
+    const int qh0 = get_int_b2(x[ib].qh, 0);
+    const int ql  = ((ql0 >> (4*shift)) & 0x0F);
+    const int qh  = ((qh0 >> idq) << 4) & 0x10;
+    const int q   = (ql | qh) - 16;
+
+#ifdef FP16_AVAILABLE
+    if (std::is_same<T, half>::value) {
+        return ((half) d)*((half) q);
+    }
+#endif // FP16_AVAILABLE
+
+    return ((float) d)*((float) q);
+}
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ vx, const int64_t i) {
+    const block_q5_1 * x = (const block_q5_1 *) vx;
+
+    const int64_t ib    =  i          /  QK5_1;
+    const int     idq   =  i          %  QK5_1;
+    const int     iqs   =  i          % (QK5_1/2);
+    const int     shift = (i % QK5_1) / (QK5_1/2);
+
+    const half2 dm  = x[ib].dm;
+    const int   ql0 = x[ib].qs[iqs];
+    const int   qh0 = get_int_b4(x[ib].qh, 0);
+    const int   ql  = ((ql0 >> (4*shift)) & 0x0F);
+    const int   qh  = ((qh0 >> idq) << 4) & 0x10;
+    const int   q   = (ql | qh);
+
+#ifdef FP16_AVAILABLE
+    if (std::is_same<T, half>::value) {
+        return __low2half(dm)*((half) q) + __high2half(dm);
+    }
+#endif // FP16_AVAILABLE
+
+    return __low2float(dm)*((float) q) + __high2float(dm);
+}
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) {
+    const block_q8_0 * x = (const block_q8_0 *) vx;
+
+    const int64_t ib  = i / QK8_0;
+    const int     iqs = i % QK8_0;
+
+    const T   d = x[ib].d;
+    const int q = x[ib].qs[iqs];
+
+#ifdef FP16_AVAILABLE
+    if (std::is_same<T, half>::value) {
+        return ((half) d)*((half) q);
+    }
+#endif // FP16_AVAILABLE
+
+    return ((float) d)*((float) q);
+}
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) {
+    const half * x = (const half *) vx;
+
+    return x[i];
+}
+
+template <int D>
+constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
+        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
+        nullptr;
+}
+
+template <int D>
+constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
+        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
+        nullptr;
+}
+
+constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) {
+    return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<half> :
+        type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<half> :
+        type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<half> :
+        type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<half> :
+        type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<half> :
+        type_V == GGML_TYPE_F16 ? dequantize_1_f16<half> :
+        nullptr;
+}
+
+constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
+    return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<float> :
+        type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<float> :
+        type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<float> :
+        type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<float> :
+        type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<float> :
+        type_V == GGML_TYPE_F16 ? dequantize_1_f16<float> :
+        nullptr;
+}
+
+template<int D, int parallel_blocks> // D == head size
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_combine_results(
+        const float  * __restrict__ VKQ_parts,
+        const float2 * __restrict__ VKQ_meta,
+        float * __restrict__ dst) {
+    VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
+    VKQ_meta  += parallel_blocks   * gridDim.y*blockIdx.x;
+    dst       +=                 D * gridDim.y*blockIdx.x;
+
+    const int tid = threadIdx.x;
+    __builtin_assume(tid < D);
+
+    __shared__ float2 meta[parallel_blocks];
+    if (tid < 2*parallel_blocks) {
+        ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
+    }
+
+    __syncthreads();
+
+    float kqmax = meta[0].x;
+#pragma unroll
+    for (int l = 1; l < parallel_blocks; ++l) {
+        kqmax = max(kqmax, meta[l].x);
+    }
+
+    float VKQ_numerator   = 0.0f;
+    float VKQ_denominator = 0.0f;
+#pragma unroll
+    for (int l = 0; l < parallel_blocks; ++l) {
+        const float diff = meta[l].x - kqmax;
+        const float KQ_max_scale = expf(diff);
+        const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
+        *((uint32_t *) &KQ_max_scale) &= ftz_mask;
+
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
+        VKQ_denominator += KQ_max_scale * meta[l].y;
+    }
+
+    dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
+}
+
+static void on_no_fattn_vec_case(const int D) {
+    if (D == 64) {
+        fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
+        fprintf(stderr, "By default only f16 KV cache is supported.\n");
+        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
+        GGML_ABORT("fatal error");
+    } else if (D == 128) {
+        fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
+        fprintf(stderr, "Supported combinations:\n");
+        fprintf(stderr, "  - K == q4_0, V == q4_0,  4.50 BPV\n");
+        fprintf(stderr, "  - K == q8_0, V == q8_0,  8.50 BPV\n");
+        fprintf(stderr, "  - K == f16,  V == f16,  16.00 BPV\n");
+        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
+        GGML_ABORT("fatal error");
+    } else {
+        fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
+        fprintf(stderr, "Only f16 is supported.\n");
+        GGML_ABORT("fatal error");
+    }
+}
+
+template <int D, int parallel_blocks>
+void launch_fattn(
+    ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
+    const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
+) {
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
+
+    const ggml_tensor * mask = dst->src[3];
+
+    ggml_tensor * KQV = dst;
+
+    GGML_ASSERT(Q->type == GGML_TYPE_F32);
+    GGML_ASSERT(KQV->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
+                                "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
+
+    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+
+    ggml_cuda_pool & pool = ctx.pool();
+    cudaStream_t main_stream = ctx.stream();
+
+    ggml_cuda_pool_alloc<half>   K_f16(pool);
+    ggml_cuda_pool_alloc<half>   V_f16(pool);
+    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
+    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
+
+    char * K_data = (char *) K->data;
+    size_t nb11 = K->nb[1];
+    size_t nb12 = K->nb[2];
+    size_t nb13 = K->nb[3];
+
+    char * V_data = (char *) V->data;
+    size_t nb21 = V->nb[1];
+    size_t nb22 = V->nb[2];
+    size_t nb23 = V->nb[3];
+
+    if (need_f16_K && K->type != GGML_TYPE_F16) {
+        K_f16.alloc(ggml_nelements(K));
+        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
+        to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
+        K_data = (char *) K_f16.ptr;
+
+        const size_t bs = ggml_blck_size(K->type);
+        const size_t ts = ggml_type_size(K->type);
+
+        nb11 = nb11*bs*sizeof(half)/ts;
+        nb12 = nb12*bs*sizeof(half)/ts;
+        nb13 = nb13*bs*sizeof(half)/ts;
+    }
+
+    if (need_f16_V && V->type != GGML_TYPE_F16) {
+        V_f16.alloc(ggml_nelements(V));
+        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+        to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+        V_data = (char *) V_f16.ptr;
+
+        const size_t bs = ggml_blck_size(V->type);
+        const size_t ts = ggml_type_size(V->type);
+
+        nb21 = nb21*bs*sizeof(half)/ts;
+        nb22 = nb22*bs*sizeof(half)/ts;
+        nb23 = nb23*bs*sizeof(half)/ts;
+    }
+
+    if (parallel_blocks > 1) {
+        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+    }
+
+    const dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
+    const int  shmem = 0;
+
+    float scale         = 1.0f;
+    float max_bias      = 0.0f;
+    float logit_softcap = 0.0f;
+
+    memcpy(&scale,         (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (float *) KQV->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0.0f) {
+        scale /= logit_softcap;
+    }
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
+        (const char *) Q->data,
+        K_data,
+        V_data,
+        mask ? ((const char *) mask->data) : nullptr,
+        (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+        scale, max_bias, m0, m1, n_head_log2, logit_softcap,
+        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+        mask ? mask->ne[1] : 0, mask ?  mask->nb[1] : 0,
+        Q->nb[1], Q->nb[2], Q->nb[3],
+        nb11, nb12, nb13,
+        nb21, nb22, nb23,
+        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+    );
+    CUDA_CHECK(cudaGetLastError());
+
+    if ((parallel_blocks) == 1) {
+        return;
+    }
+
+    const dim3 block_dim_combine(D, 1, 1);
+    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+    const int  shmem_combine = 0;
+
+    flash_attn_combine_results<D, parallel_blocks>
+        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+    CUDA_CHECK(cudaGetLastError());
+}

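When parallel_blocks > 1, launch_fattn splits the KV sequence across thread blocks and flash_attn_combine_results above merges the partial outputs. Each block reports an unnormalized output plus its local (max, denominator) pair in dst_meta, and the merge rescales everything to the global max so the softmax stays numerically stable. A hedged host-side sketch of that merge in scalar C++ (the kernel's flush-to-zero masking is omitted; the struct and function names are illustrative):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    struct Partial {
        std::vector<float> numerator; // unnormalized V*softmax(KQ) output, length D
        float kqmax;                  // local maximum of the logits
        float kqsum;                  // local softmax denominator
    };

    std::vector<float> combine(const std::vector<Partial> &parts, int D) {
        float kqmax = parts[0].kqmax;
        for (const Partial &p : parts) {
            kqmax = std::max(kqmax, p.kqmax); // global max over all blocks
        }

        std::vector<float> out(D, 0.0f);
        float denom = 0.0f;
        for (const Partial &p : parts) {
            const float scale = std::exp(p.kqmax - kqmax); // rescale to the global max
            for (int i = 0; i < D; ++i) {
                out[i] += scale * p.numerator[i];
            }
            denom += scale * p.kqsum;
        }
        for (int i = 0; i < D; ++i) {
            out[i] /= denom; // normalize once at the end
        }
        return out;
    }

Because exp(p.kqmax - kqmax) <= 1 for every block, no intermediate value overflows even when the logits themselves are large.
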
+ 379 - 0
llama/ggml-cuda/fattn-tile-f16.cu

@@ -0,0 +1,379 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+#include "fattn-common.cuh"
+#include "fattn-tile-f16.cuh"
+
+#define FATTN_KQ_STRIDE_TILE_F16 64
+
+template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(nwarps*WARP_SIZE, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_tile_ext_f16(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const float logit_softcap,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+#ifdef FP16_AVAILABLE
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
+    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.y              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.y / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const half   * maskh = (const half   *)  mask + ne11*ic0;
+
+    const int stride_KV2 = nb11 / sizeof(half2);
+
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
+
+    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
+
+    __shared__ half KQ[ncols*FATTN_KQ_STRIDE_TILE_F16];
+    half2 * KQ2 = (half2 *) KQ;
+
+    __shared__ half2 KV_tmp[FATTN_KQ_STRIDE_TILE_F16][D/2 + 1]; // Pad D to avoid memory bank conflicts.
+
+    half kqmax[ncols/nwarps];
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        kqmax[j0/nwarps] = -HALF_MAX_HALF;
+    }
+    half2 kqsum[ncols/nwarps] = {{0.0f, 0.0f}};
+
+    half2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
+
+    // Convert Q to half2, apply the scale, and store in shared memory:
+    __shared__ half2 Q_h2[ncols][D/2];
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
+            Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
+        }
+    }
+
+    __syncthreads();
+
+    const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
+    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
+        // Calculate KQ tile and keep track of new maximum KQ values:
+
+        half kqmax_new[ncols/nwarps];
+#pragma unroll
+        for (int j = 0; j < ncols/nwarps; ++j) {
+            kqmax_new[j] = kqmax[j];
+        }
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += nwarps) {
+            const int i_KQ = i_KQ_0 + threadIdx.y;
+
+#pragma unroll
+            for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+                const int k_KQ = k_KQ_0 + threadIdx.x;
+
+                KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
+            }
+        }
+
+        __syncthreads();
+
+        half2 sum2[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE][ncols/nwarps] = {{{0.0f, 0.0f}}};
+
+#pragma unroll
+        for (int k_KQ = 0; k_KQ < D/2; ++k_KQ) {
+            half2 K_k[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE];
+            half2 Q_k[ncols/nwarps];
+
+#pragma unroll
+            for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
+                const int i_KQ = i_KQ_0 + threadIdx.x;
+
+                K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
+            }
+#pragma unroll
+            for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                const int j_KQ = j_KQ_0 + threadIdx.y;
+
+                Q_k[j_KQ_0/nwarps] = Q_h2[j_KQ][k_KQ];
+            }
+
+#pragma unroll
+            for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
+#pragma unroll
+                for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                    sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE]*Q_k[j_KQ_0/nwarps];
+                }
+            }
+        }
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
+            const int i_KQ = i_KQ_0 + threadIdx.x;
+
+#pragma unroll
+            for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                const int j_KQ = j_KQ_0 + threadIdx.y;
+
+                half sum;
+                if (use_logit_softcap) {
+                    const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                    sum = logit_softcap * tanhf(tmp.x + tmp.y);
+                } else {
+                    sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                }
+                sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+
+                kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
+
+                KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F16 + i_KQ] = sum;
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
+            const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]));
+            kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
+
+#pragma unroll
+            for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F16/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                const half2 diff = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] - __half2half2(kqmax[j0/nwarps]);
+                const half2 val = h2exp(diff);
+                kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + val;
+                KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] = val;
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += nwarps) {
+            const int k = k0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += 2) {
+            half2  V_k[(D/2)/WARP_SIZE][2];
+            half2 KQ_k[ncols/nwarps];
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                V_k[i0/WARP_SIZE][0] = KV_tmp[k0 + 0][i];
+                V_k[i0/WARP_SIZE][1] = KV_tmp[k0 + 1][i];
+            }
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+                const int j = j0 + threadIdx.y;
+
+                KQ_k[j0/nwarps] = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + k0/2];
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+#pragma unroll
+                for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+                    VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][0]* __low2half2(KQ_k[j0/nwarps]);
+                    VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][1]*__high2half2(KQ_k[j0/nwarps]);
+                }
+            }
+        }
+
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
+        const int j_VKQ = j_VKQ_0 + threadIdx.y;
+
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+
+        half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
+        kqsum_j = warp_reduce_sum(kqsum_j);
+
+#pragma unroll
+        for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
+            const int i0 = i00 + 2*threadIdx.x;
+
+            half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+            if (parallel_blocks == 1) {
+                dst_val /= __half2half2(kqsum_j);
+            }
+            const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] =  __low2float(dst_val);
+            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
+        }
+
+        if (parallel_blocks != 1 && threadIdx.x == 0) {
+            dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+        }
+    }
+#else
+    NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+}
+
+template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
+void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    switch (Q->ne[0]) {
+        case  64: {
+            constexpr int      D = 64;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+        } break;
+        case 128: {
+            constexpr int      D = 128;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+        } break;
+        default: {
+            GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
+        } break;
+    }
+}
+
+void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+
+    const int32_t precision = KQV->op_params[3];
+    GGML_ASSERT(precision == GGML_PREC_DEFAULT);
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    if (Q->ne[1] <= 16) {
+        constexpr int cols_per_block = 16;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 32) {
+        constexpr int cols_per_block = 32;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    constexpr int cols_per_block = 32;
+    constexpr int parallel_blocks = 1;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    }
+}

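The per-tile update in flash_attn_tile_ext_f16 (kqmax_new, KQ_max_scale, and the VKQ rescale) is the standard online-softmax recurrence: the kernel never sees all logits at once, so it keeps a running max, denominator, and output accumulator, and rescales them whenever a new tile raises the max. A scalar sketch of one update step, under the simplifying assumption of a single query and a one-dimensional value:

    #include <algorithm>
    #include <cmath>

    // Fold one (logit, value) pair into the running softmax state.
    void online_softmax_step(float kq, float v, float &m, float &s, float &acc) {
        const float m_new = std::max(m, kq);
        const float scale = std::exp(m - m_new); // 1.0 when the max is unchanged
        const float p     = std::exp(kq - m_new);
        s   = s*scale + p;     // denominator
        acc = acc*scale + p*v; // unnormalized output
        m   = m_new;
    }
    // Initialize m to a very negative value (the f16 kernel uses -HALF_MAX_HALF
    // rather than -INFINITY to avoid NaN on subtraction), and s and acc to zero;
    // the final attention output is acc / s.
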
+ 29 - 0
llama/ggml-cuda/fattn-tile-f16.cuh

@@ -0,0 +1,29 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 371 - 0
llama/ggml-cuda/fattn-tile-f32.cu

@@ -0,0 +1,371 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+#include "fattn-common.cuh"
+#include "fattn-tile-f32.cuh"
+
+#define FATTN_KQ_STRIDE_TILE_F32 32
+
+template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(nwarps*WARP_SIZE, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_tile_ext_f32(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const float logit_softcap,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
+    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.y              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.y / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const half   * maskh = (const half   *)  mask + ne11*ic0;
+
+    const int stride_KV2 = nb11 / sizeof(half2);
+
+    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+
+    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
+
+    __shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32];
+
+    __shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts.
+    float2 * KV_tmp2 = (float2 *) KV_tmp;
+
+    float kqmax[ncols/nwarps];
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        kqmax[j0/nwarps] = -FLT_MAX/2.0f;
+    }
+    float kqsum[ncols/nwarps] = {0.0f};
+
+    float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
+
+    // Scale Q and store it in shared memory as f32:
+    __shared__ float Q_f[ncols][D];
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
+            float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f);
+            Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
+            Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
+        }
+    }
+
+    __syncthreads();
+
+    const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F32;
+    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F32) {
+        // Calculate KQ tile and keep track of new maximum KQ values:
+
+        float kqmax_new[ncols/nwarps];
+#pragma unroll
+        for (int j = 0; j < ncols/nwarps; ++j) {
+            kqmax_new[j] = kqmax[j];
+        }
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += nwarps) {
+            const int i_KQ = i_KQ_0 + threadIdx.y;
+
+#pragma unroll
+            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
+                const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
+                KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] =  __low2float(tmp);
+                KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
+            }
+        }
+
+        __syncthreads();
+
+        float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}};
+
+#pragma unroll
+        for (int k_KQ = 0; k_KQ < D; ++k_KQ) {
+            float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE];
+            float Q_k[ncols/nwarps];
+
+#pragma unroll
+            for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
+                const int i_KQ = i_KQ_0 + threadIdx.x;
+
+                K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
+            }
+#pragma unroll
+            for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                const int j_KQ = j_KQ_0 + threadIdx.y;
+
+                Q_k[j_KQ_0/nwarps] = Q_f[j_KQ][k_KQ];
+            }
+
+#pragma unroll
+            for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
+#pragma unroll
+                for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                    sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE] * Q_k[j_KQ_0/nwarps];
+                }
+            }
+        }
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
+            const int i_KQ = i_KQ_0 + threadIdx.x;
+
+#pragma unroll
+            for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                const int j_KQ = j_KQ_0 + threadIdx.y;
+
+                if (use_logit_softcap) {
+                    sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                }
+
+                sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
+
+                kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+
+                KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F32 + i_KQ] = sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps];
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
+            const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]);
+            kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
+
+            float kqsum_add = 0.0f;
+#pragma unroll
+            for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F32; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                const float diff = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] - kqmax[j0/nwarps];
+                const float val = expf(diff);
+                kqsum_add += val;
+                KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] = val;
+            }
+            kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
+                VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F32; k0 += nwarps) {
+            const int k = k0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                KV_tmp2[k*(D/2) + i].x =  __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
+                KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) {
+            float2 V_k[(D/2)/WARP_SIZE];
+            float  KQ_k[ncols/nwarps];
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i];
+            }
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+                const int j = j0 + threadIdx.y;
+
+                KQ_k[j0/nwarps] = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + k];
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+#pragma unroll
+                for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+                    VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps];
+                    VKQ[j0/nwarps][i0/WARP_SIZE].y += V_k[i0/WARP_SIZE].y*KQ_k[j0/nwarps];
+                }
+            }
+        }
+
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
+        const int j_VKQ = j_VKQ_0 + threadIdx.y;
+
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+
+        float kqsum_j = kqsum[j_VKQ_0/nwarps];
+        kqsum_j = warp_reduce_sum(kqsum_j);
+
+#pragma unroll
+        for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
+            const int i0 = i00 + 2*threadIdx.x;
+
+            float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+            if (parallel_blocks == 1) {
+                dst_val.x /= kqsum_j;
+                dst_val.y /= kqsum_j;
+            }
+            const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x;
+            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y;
+        }
+
+        if (parallel_blocks != 1 && threadIdx.x == 0) {
+            dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+        }
+    }
+}
+
+template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
+void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    switch (Q->ne[0]) {
+        case  64: {
+            constexpr int      D = 64;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+        } break;
+        case 128: {
+            constexpr int      D = 128;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+        } break;
+        default: {
+            GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
+        } break;
+    }
+}
+
+void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q = dst->src[0];
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    if (Q->ne[1] <= 16) {
+        constexpr int cols_per_block = 16;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 32) {
+        constexpr int cols_per_block = 32;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    constexpr int cols_per_block = 32;
+    constexpr int parallel_blocks = 1;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    }
+}

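Both tile kernels take a use_logit_softcap template flag, and launch_fattn pre-divides the softmax scale by logit_softcap when the cap is nonzero. Taken together, the effective logit is cap * tanh(scale * q.k / cap), which smoothly bounds every logit to (-cap, cap). A minimal sketch of the transform (the function name is illustrative):

    #include <cmath>

    float softcapped_logit(float qk_dot, float scale, float cap) {
        if (cap == 0.0f) {
            return scale * qk_dot; // capping disabled
        }
        // launch_fattn folds the division by cap into `scale` up front;
        // the kernels then apply cap * tanhf(...) to each KQ sum.
        return cap * std::tanh(scale * qk_dot / cap);
    }

This also explains the "Skip unused kernel variants" guard at the top of each kernel: soft-capped variants are only compiled for the head sizes that actually use them (128 and 256).
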
+ 29 - 0
llama/ggml-cuda/fattn-tile-f32.cuh

@@ -0,0 +1,29 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 468 - 0
llama/ggml-cuda/fattn-vec-f16.cuh

@@ -0,0 +1,468 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+#include "fattn-common.cuh"
+
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_vec_ext_f16(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const float logit_softcap,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+#ifdef FP16_AVAILABLE
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
+    constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
+    constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);
+
+    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
+    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
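+    // e.g. ne02 == 32 Q heads over ne12 == 8 K/V heads gives gqa_ratio == 4,
+    // so Q heads 4*h .. 4*h+3 all read K/V head h below.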
+    Q += nb02* blockIdx.y              + nb01*ic0;
+    K += nb12*(blockIdx.y / gqa_ratio);
+    V += nb22*(blockIdx.y / gqa_ratio);
+
+    const half * maskh = (const half   *)  mask + ne11*ic0;
+
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
+
+    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
+    constexpr int nwarps = D / WARP_SIZE;
+    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    __builtin_assume(tid < D);
+
+    __shared__ half KQ[ncols*D];
+    half2 * KQ2 = (half2 *) KQ;
+
+    half kqmax[ncols];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        kqmax[j] = -HALF_MAX_HALF;
+    }
+    half kqsum[ncols] = {0.0f};
+
+    __shared__ half kqmax_shared[ncols][WARP_SIZE];
+    __shared__ half kqsum_shared[ncols][WARP_SIZE];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (threadIdx.y == 0) {
+            kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF;
+            kqsum_shared[j][threadIdx.x] = 0.0f;
+        }
+    }
+    __syncthreads();
+
+    // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
+    half2  Q_h2[ncols][D/(2*WARP_SIZE)];
+    int   Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)];
+    half2  Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
+    if (Q_q8_1) {
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            if (j0 + nwarps > ncols && j >= ncols) {
+                break;
+            }
+
+            // Reuse KQ as temporary storage for converting Q to q8_1:
+            int   * tmp_q_i32 = (int   *) &KQ[j*D];
+            half2 * tmp_q_ds  = (half2 *) (tmp_q_i32 + D/sizeof(int));
+
+            // Set memory to zero if out of bounds:
+            if (ncols > 2 && ic0 + j >= ne01) {
+#pragma unroll
+                for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+                    const int i = i0 + threadIdx.x;
+
+                    tmp_q_i32[i] = 0;
+                }
+                if (threadIdx.x < D/QK8_1) {
+                    tmp_q_ds[threadIdx.x] = make_half2(0.0f, 0.0f);
+                }
+                continue;
+            }
+
+            const float * Q_f = (const float *) (Q + j*nb01);
+#pragma unroll
+            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+                quantize_q8_1_to_shared<half2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            int   * tmp_q_i32 = (int   *) &KQ[j*D];
+            half2 * tmp_q_ds  = (half2 *) (tmp_q_i32 + D/sizeof(int));
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
+                Q_ds[j][i0/WARP_SIZE]  = tmp_q_ds[i/QI8_1];
+            }
+        }
+
+        __syncthreads();
+    } else {
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
+                Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
+            }
+        }
+    }
+
+
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        KQ[j*D + tid] = -HALF_MAX_HALF;
+    }
+
+    half2 VKQ[ncols] = {{0.0f, 0.0f}};
+
+    const int k_start = parallel_blocks == 1 ? 0 : ip*D;
+    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
+        // Calculate KQ tile and keep track of new maximum KQ values:
+
+        // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
+        // see https://github.com/ggerganov/llama.cpp/pull/7061 .
+        // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
+        half kqmax_new = kqmax[0];
+        half kqmax_new_arr[ncols];
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            kqmax_new_arr[j] = kqmax[j];
+        }
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
+            const int i_KQ = i_KQ_0 + threadIdx.y;
+
+            if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
+                break;
+            }
+
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
+                sum = warp_reduce_sum(sum);
+
+                if (use_logit_softcap) {
+                    sum = logit_softcap*tanhf(sum);
+                }
+
+                sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+
+                if (ncols == 1) {
+                    kqmax_new        = ggml_cuda_hmax(kqmax_new,        sum);
+                } else {
+                    kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
+                }
+
+                if (threadIdx.x == 0) {
+                    KQ[j*D + i_KQ] = sum;
+                }
+            }
+        }
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
+
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+            if (threadIdx.x == 0) {
+                kqmax_shared[j][threadIdx.y] = kqmax_new_j;
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            half kqmax_new_j = kqmax_shared[j][threadIdx.x];
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
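+            // Online softmax: raising the running maximum from kqmax[j] to
+            // kqmax_new_j rescales every previously accumulated exp() term
+            // (kqsum and VKQ below) by exp(kqmax[j] - kqmax_new_j) <= 1.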
+            const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
+            kqmax[j] = kqmax_new_j;
+
+            const half val = hexp(KQ[j*D + tid] - kqmax[j]);
+            kqsum[j] = kqsum[j]*KQ_max_scale + val;
+            KQ[j*D + tid] = val;
+
+            VKQ[j] *= __half2half2(KQ_max_scale);
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int k0 = 0; k0 < D; k0 += 2) {
+            if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
+                break;
+            }
+
+            half2 V_k;
+            reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid);
+            reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
+            }
+        }
+
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        kqsum[j] = warp_reduce_sum(kqsum[j]);
+        if (threadIdx.x == 0) {
+            kqsum_shared[j][threadIdx.y] = kqsum[j];
+        }
+    }
+
+    __syncthreads();
+
+#pragma unroll
+    for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+        if (ncols > 2 && ic0 + j_VKQ >= ne01) {
+            break;
+        }
+
+        kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
+        kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
+
+        half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
+        if (parallel_blocks == 1) {
+            dst_val /= kqsum[j_VKQ];
+        }
+        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+        dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
+    }
+
+    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
+        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
+    }
+#else
+   NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+}
+
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
+void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    constexpr int nwarps = D/WARP_SIZE;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
+    constexpr bool need_f16_K = D != 128;
+    constexpr bool need_f16_V = D != 128 && D != 64;
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
+}
+
+template <int D, ggml_type type_K, ggml_type type_V>
+void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+    const ggml_tensor * K   = dst->src[1];
+    const ggml_tensor * V   = dst->src[2];
+
+    const int32_t precision = KQV->op_params[3];
+    GGML_ASSERT(precision == GGML_PREC_DEFAULT);
+
+    GGML_ASSERT(K->type == type_K);
+    GGML_ASSERT(V->type == type_V);
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    if (Q->ne[1] == 1) {
+        constexpr int cols_per_block  = 1;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    if (Q->ne[1] == 2) {
+        constexpr int cols_per_block  = 2;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 4) {
+        constexpr int cols_per_block  = 4;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 8) {
+        constexpr int cols_per_block  = 8;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    constexpr int cols_per_block  = 8;
+    constexpr int parallel_blocks = 1;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+    }
+}
+
+#define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V)                          \
+    template void ggml_cuda_flash_attn_ext_vec_f16_case                     \
+    <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_0);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_1);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_0);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_1);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q8_0);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16);
+
+extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
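
The DECL_FATTN_VEC_F16_CASE block above uses explicit template instantiation: each `extern` line declares that a given (D, type_K, type_V) instantiation exists in exactly one translation unit, so including this header never re-instantiates the kernel. A self-contained sketch of the mechanism (hypothetical names):

    #include <stdio.h>

    template <int D>
    int head_size() { return D; }

    // Header side: promise the instantiation without generating it here.
    extern template int head_size<64>();

    // Source side: exactly one .cu file provides the definition.
    template int head_size<64>();

    int main(void) {
        printf("%d\n", head_size<64>());
        return 0;
    }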

+ 446 - 0
llama/ggml-cuda/fattn-vec-f32.cuh

@@ -0,0 +1,446 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+#include "fattn-common.cuh"
+
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_vec_ext_f32(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const float logit_softcap,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K);
+    constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
+    constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);
+
+    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
+    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    Q += nb02* blockIdx.y              + nb01*ic0;
+    K += nb12*(blockIdx.y / gqa_ratio);
+    V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
+    const half * maskh = (const half   *)  mask + ne11*ic0;
+
+    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+
+    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
+    constexpr int nwarps = D / WARP_SIZE;
+    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    __builtin_assume(tid < D);
+
+    __shared__ float KQ[ncols*D];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        KQ[j*D + tid] = -FLT_MAX/2.0f;
+    }
+
+    float kqmax[ncols];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        kqmax[j] = -FLT_MAX/2.0f;
+    }
+    float kqsum[ncols] = {0.0f};
+
+    __shared__ float kqmax_shared[ncols][WARP_SIZE];
+    __shared__ float kqsum_shared[ncols][WARP_SIZE];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (threadIdx.y == 0) {
+            kqmax_shared[j][threadIdx.x] = -FLT_MAX/2.0f;
+            kqsum_shared[j][threadIdx.x] = 0.0f;
+        }
+    }
+    __syncthreads();
+
+    // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
+    float2  Q_f2[ncols][D/(2*WARP_SIZE)];
+    int    Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)];
+    float2  Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
+    if (Q_q8_1) {
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            if (j0 + nwarps > ncols && j >= ncols) {
+                break;
+            }
+
+            // Reuse KQ as temporary storage for converting Q to q8_1:
+            int    * tmp_q_i32 = (int    *) &KQ[j*D];
+            float2 * tmp_q_ds  = (float2 *) (tmp_q_i32 + D/sizeof(int));
+
+            // Set memory to zero if out of bounds:
+            if (ncols > 2 && ic0 + j >= ne01) {
+#pragma unroll
+                for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+                    const int i = i0 + threadIdx.x;
+
+                    tmp_q_i32[i] = 0;
+                }
+                if (threadIdx.x < D/QK8_1) {
+                    tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f);
+                }
+                continue;
+            }
+
+            const float * Q_f = (const float *) (Q + j*nb01);
+#pragma unroll
+            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+                quantize_q8_1_to_shared<float2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            int    * tmp_q_i32 = (int    *) &KQ[j*D];
+            float2 * tmp_q_ds  = (float2 *) (tmp_q_i32 + D/sizeof(int));
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
+                Q_ds[j][i0/WARP_SIZE]  = tmp_q_ds[i/QI8_1];
+            }
+        }
+
+        __syncthreads();
+    } else {
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                Q_f2[j][i0/WARP_SIZE]    = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
+                Q_f2[j][i0/WARP_SIZE].x *= scale;
+                Q_f2[j][i0/WARP_SIZE].y *= scale;
+            }
+        }
+    }
+
+    float VKQ[ncols] = {0.0f};
+
+    const int k_start = parallel_blocks == 1 ? 0 : ip*D;
+    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
+        // Calculate KQ tile and keep track of new maximum KQ values:
+
+        float kqmax_new_arr[ncols];
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            kqmax_new_arr[j] = kqmax[j];
+        }
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
+            const int i_KQ = i_KQ_0 + threadIdx.y;
+
+            if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
+                break;
+            }
+
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
+                sum = warp_reduce_sum(sum);
+
+                if (use_logit_softcap) {
+                    sum = logit_softcap*tanhf(sum);
+                }
+
+                sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
+
+                kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
+
+                if (threadIdx.x == 0) {
+                    KQ[j*D + i_KQ] = sum;
+                }
+            }
+        }
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            float kqmax_new_j = kqmax_new_arr[j];
+
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+            if (threadIdx.x == 0) {
+                kqmax_shared[j][threadIdx.y] = kqmax_new_j;
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            float kqmax_new_j = kqmax_shared[j][threadIdx.x];
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
+            kqmax[j] = kqmax_new_j;
+
+            const float val = expf(KQ[j*D + tid] - kqmax[j]);
+            kqsum[j] = kqsum[j]*KQ_max_scale + val;
+            KQ[j*D + tid] = val;
+
+            VKQ[j] *= KQ_max_scale;
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int k = 0; k < D; ++k) {
+            if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k >= ne11) {
+                break;
+            }
+
+            const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                VKQ[j] += V_ki*KQ[j*D + k];
+            }
+        }
+
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        kqsum[j] = warp_reduce_sum(kqsum[j]);
+        if (threadIdx.x == 0) {
+            kqsum_shared[j][threadIdx.y] = kqsum[j];
+        }
+    }
+
+    __syncthreads();
+
+#pragma unroll
+    for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+        if (ncols > 2 && ic0 + j_VKQ >= ne01) {
+            break;
+        }
+
+        kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
+        kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
+
+        float dst_val = VKQ[j_VKQ];
+        if (parallel_blocks == 1) {
+            dst_val /= kqsum[j_VKQ];
+        }
+        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+        dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
+    }
+
+    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
+        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
+    }
+}
+
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
+void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    constexpr int nwarps = D/WARP_SIZE;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
+    constexpr bool need_f16_K = D != 128;
+    constexpr bool need_f16_V = D != 128 && D != 64;
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
+}
+
+template <int D, ggml_type type_K, ggml_type type_V>
+void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+    const ggml_tensor * K   = dst->src[1];
+    const ggml_tensor * V   = dst->src[2];
+
+    GGML_ASSERT(K->type == type_K);
+    GGML_ASSERT(V->type == type_V);
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    if (Q->ne[1] == 1) {
+        constexpr int cols_per_block  = 1;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    if (Q->ne[1] == 2) {
+        constexpr int cols_per_block  = 2;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 4) {
+        constexpr int cols_per_block  = 4;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 8) {
+        constexpr int cols_per_block  = 8;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    constexpr int cols_per_block  = 8;
+    constexpr int parallel_blocks = 1;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+    }
+}
+
+#define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V)                          \
+    template void ggml_cuda_flash_attn_ext_vec_f32_case                     \
+    <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_0);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_1);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_0);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_1);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q8_0);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16);
+
+extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
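
Both vector kernels above implement the streaming ("online") softmax: per KV tile, the running maximum $m$ (kqmax), denominator $r$ (kqsum), and output accumulator $o$ (VKQ) are updated as

$$m' = \max\Big(m, \max_k s_k\Big), \qquad r' = r\,e^{m - m'} + \sum_k e^{s_k - m'}, \qquad o' = o\,e^{m - m'} + \sum_k e^{s_k - m'} v_k$$

so a single division of $o$ by $r$ at the end yields the softmax-weighted sum without ever materializing a full attention row.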

+ 569 - 0
llama/ggml-cuda/fattn-wmma-f16.cuh

@@ -0,0 +1,569 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+#include "fattn-common.cuh"
+
+#ifdef FP16_MMA_AVAILABLE
+#include <mma.h>
+#endif // FP16_MMA_AVAILABLE
+
+// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
+template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(nwarps*WARP_SIZE, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_ext_f16(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const float logit_softcap,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+#ifdef FP16_MMA_AVAILABLE
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
+    const int ip  =        blockIdx.x % parallel_blocks;  // Index in group of blocks running for the same column in parallel.
+
+    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
+    static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
+    constexpr int frag_m = ncols == 8 ? 32 : 16;
+    constexpr int frag_n = ncols == 8 ?  8 : 16;
+    static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;
+
+    constexpr int KQ_stride_tc  = nwarps*frag_m; // Number of KQ rows calculated in parallel.
+    constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
+    static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
+
+    // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
+    constexpr int D_padded = D + 8;
+    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
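+    // The +8 element padding staggers consecutive rows across the 32 shared
+    // memory banks so that column-wise accesses do not serialize on one bank.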
+    constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float * Q_f   = (const float *) (Q + nb02* blockIdx.y              + nb01*ic0);
+    const half  * K_h   = (const half  *) (K + nb12*(blockIdx.y / gqa_ratio));
+    const half  * V_h   = (const half  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const half  * maskh = (const half  *)  mask + (nb31/sizeof(half))* ic0;
+    const half2 * mask2 = (const half2 *)  mask + (nb31/sizeof(half))*(ic0/2);
+
+    const int stride_Q  = nb01 / sizeof(float);
+    const int stride_KV = nb11 / sizeof(half);
+
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
+    const half2 slope2 = make_half2(slopef, slopef);
+
+    const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
+
+    frag_b Q_b[D/16][ncols/frag_n];
+
+    // A single buffer for temporarily holding tiles of KQ and VKQ parts:
+    constexpr int mem_KQ = ncols*kqs_padded*kqar;
+    constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
+    __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
+    float * KQ_f = (float *) KQ;
+    half2 * KQ2 = (half2 *) KQ;
+
+    float    KQ_rowsum_f[ncols/nwarps] = {0.0f};
+    float       KQ_max_f[ncols/nwarps];
+    float KQ_max_scale_f[ncols/nwarps] = {0.0f};
+
+#pragma unroll
+    for (int j = 0; j < ncols/nwarps; ++j) {
+        KQ_max_f[j] = -FLT_MAX/2.0f;
+    }
+
+    half2    KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+    half2       KQ_max_h2[ncols/nwarps];
+    half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+
+#pragma unroll
+    for (int j = 0; j < ncols/nwarps; ++j) {
+        KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
+    }
+
+    __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
+    half2 * VKQ2 = (half2 *) VKQ;
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+#pragma unroll
+        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+                break;
+            }
+            VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
+        }
+    }
+
+    // Convert Q to half and apply scale, temporarily store in KQ:
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+#pragma unroll
+        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D && i >= D) {
+                break;
+            }
+            KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
+        }
+    }
+
+    __syncthreads();
+
+    // Load Q into tensor core fragments/registers since it will be used frequently:
+#pragma unroll
+    for (int i0 = 0; i0 < D; i0 += 16) {
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+        }
+    }
+
+    __syncthreads();
+
+    // Iterate over ne11 == previous tokens:
+    for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
+        // Calculate tile of KQ:
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
+            frag_c_KQ KQ_c[ncols/frag_n];
+#pragma unroll
+            for (int j = 0; j < ncols/frag_n; ++j) {
+                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+            }
+#pragma unroll
+            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
+                frag_a_K K_a;
+                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+#pragma unroll
+                for (int j = 0; j < ncols/frag_n; ++j) {
+                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                }
+            }
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+                nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+            }
+        }
+
+        __syncthreads();
+
+        // Calculate softmax for each KQ column using the current max. value.
+        // The divisor is stored in KQ_rowsum and will be applied at the end.
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            if (std::is_same<KQ_acc_t, float>::value) {
+                float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+
+                    if (use_logit_softcap) {
+                        KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
+                    }
+                }
+
+                float KQ_max_new = KQ_max_f[j0/nwarps];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
+                }
+                KQ_max_new = warp_reduce_max(KQ_max_new);
+
+                const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
+                KQ_max_scale_f[j0/nwarps] = expf(diff);
+                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+                    KQ_max_scale_f[j0/nwarps] = 0.0f;
+                }
+                KQ_max_f[j0/nwarps] = KQ_max_new;
+
+                float KQ_rowsum_add = 0.0f;
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
+                    KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
+                    if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+                        KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
+                    }
+                    KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
+                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
+                }
+                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+
+                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+                KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
+            } else {
+                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+
+                    if (use_logit_softcap) {
+                        // There is no dedicated hyperbolic tangent function for half2.
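+                        // Compute it via tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) instead: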
+                        KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
+                        KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
+                                               /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
+
+                        KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
+                    }
+                }
+
+                half2 KQ_max_new = KQ_max_h2[j0/nwarps];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
+                }
+                KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
+                const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
+                KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
+                const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+                *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
+                KQ_max_h2[j0/nwarps] = KQ_max_new;
+
+                half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
+                    KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
+                    const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+                    *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
+                    KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
+                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
+                }
+                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+
+                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+                KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
+            }
+        }
+
+        __syncthreads();
+
+        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+#pragma unroll
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+                nvcuda::wmma::load_matrix_sync(
+                    KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
+                    KQ + j0*(kqar*kqs_padded) + k,
+                    kqar*kqs_padded);
+            }
+        }
+
+        frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
+#pragma unroll
+        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
+#pragma unroll
+            for (int j = 0; j < ncols/frag_n; ++j) {
+                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+            }
+
+#pragma unroll
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+
+                frag_a_V v_a;
+                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+#pragma unroll
+                for (int j = 0; j < ncols/frag_n; ++j) {
+                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                }
+            }
+        }
+
+        __syncthreads();
+
+        const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+                nvcuda::wmma::store_matrix_sync(
+                    KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
+                    VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
+                    D_padded, nvcuda::wmma::mem_col_major);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            half2 VKQ_scale;
+            if (std::is_same<KQ_acc_t, float>::value) {
+                VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
+            } else {
+                VKQ_scale = KQ_max_scale_h2[j0/nwarps];
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+                if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+                    break;
+                }
+
+                half2 VKQ_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+                for (int l = 0; l < VKQ_ratio; ++l) {
+                    VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
+                }
+                VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
+            }
+        }
+
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j_VKQ = j0 + threadIdx.y;
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+
+        float KQ_rowsum_j;
+        if (std::is_same<KQ_acc_t, float>::value) {
+            KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
+        } else {
+            KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D && i >= D) {
+                break;
+            }
+            float dst_val = VKQ[j_VKQ*D_padded + i];
+            if (parallel_blocks == 1) {
+                dst_val /= KQ_rowsum_j;
+            }
+            dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
+        }
+
+        if (parallel_blocks == 1 || threadIdx.x != 0) {
+            continue;
+        }
+
+        float2 dst_meta_val;
+        if (std::is_same<KQ_acc_t, float>::value) {
+            dst_meta_val.x = KQ_max_f[j0/nwarps];
+        } else {
+            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
+        }
+        dst_meta_val.y = KQ_rowsum_j;
+        dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
+    }
+#else
+   NO_DEVICE_CODE;
+#endif // FP16_MMA_AVAILABLE
+}
+
+constexpr int get_max_power_of_2(int x) {
+    return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
+}
+
+static_assert(get_max_power_of_2(1) == 1, "Test failed.");
+static_assert(get_max_power_of_2(2) == 2, "Test failed.");
+static_assert(get_max_power_of_2(4) == 4, "Test failed.");
+static_assert(get_max_power_of_2(6) == 2, "Test failed.");
+
+// Number of VKQ rows calculated in parallel:
+constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
+    return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
+}
+
+static_assert(get_VKQ_stride(128, 1, 32) ==  32, "Test failed.");
+static_assert(get_VKQ_stride(128, 2, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
+static_assert(get_VKQ_stride( 64, 1, 32) ==  32, "Test failed.");
+static_assert(get_VKQ_stride( 64, 2, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride( 64, 4, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride( 80, 1, 16) ==  16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 2, 16) ==  16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");
+
+template <int D, int cols_per_block, typename KQ_acc_t>
+void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+
+    constexpr int nwarps = 4;
+
+    constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
+    const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    if (4*blocks_num_pb1 < 2*nsm) {
+        constexpr int parallel_blocks = 4;
+        fattn_kernel_t fattn_kernel;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        } else {
+            constexpr bool use_logit_softcap = true;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        }
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+        return;
+    }
+    if (2*blocks_num_pb1 < 2*nsm) {
+        constexpr int parallel_blocks = 2;
+        fattn_kernel_t fattn_kernel;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        } else {
+            constexpr bool use_logit_softcap = true;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        }
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+        return;
+    }
+    constexpr int parallel_blocks = 1;
+    fattn_kernel_t fattn_kernel;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        fattn_kernel = flash_attn_ext_f16<
+            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+    } else {
+        constexpr bool use_logit_softcap = true;
+        fattn_kernel = flash_attn_ext_f16<
+            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+    }
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+}
+
+#define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t)                         \
+    template void ggml_cuda_flash_attn_ext_wmma_f16_case                              \
+    <D, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
+extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float);
+extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float);
+extern DECL_FATTN_WMMA_F16_CASE( 96, 16, float);
+extern DECL_FATTN_WMMA_F16_CASE(112, 16, float);
+extern DECL_FATTN_WMMA_F16_CASE(128, 16, float);
+extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
+
+extern DECL_FATTN_WMMA_F16_CASE( 64, 32, float);
+extern DECL_FATTN_WMMA_F16_CASE( 80, 32, float);
+extern DECL_FATTN_WMMA_F16_CASE( 96, 32, float);
+extern DECL_FATTN_WMMA_F16_CASE(112, 32, float);
+extern DECL_FATTN_WMMA_F16_CASE(128, 32, float);
+// extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
+
+extern DECL_FATTN_WMMA_F16_CASE( 64,  8, half);
+extern DECL_FATTN_WMMA_F16_CASE( 96,  8, half);
+extern DECL_FATTN_WMMA_F16_CASE(128,  8, half);
+extern DECL_FATTN_WMMA_F16_CASE(256,  8, half);
+
+extern DECL_FATTN_WMMA_F16_CASE( 64, 16, half);
+extern DECL_FATTN_WMMA_F16_CASE( 80, 16, half);
+extern DECL_FATTN_WMMA_F16_CASE( 96, 16, half);
+extern DECL_FATTN_WMMA_F16_CASE(112, 16, half);
+extern DECL_FATTN_WMMA_F16_CASE(128, 16, half);
+extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
+
+extern DECL_FATTN_WMMA_F16_CASE( 64, 32, half);
+extern DECL_FATTN_WMMA_F16_CASE( 80, 32, half);
+extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half);
+extern DECL_FATTN_WMMA_F16_CASE(112, 32, half);
+extern DECL_FATTN_WMMA_F16_CASE(128, 32, half);
+extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
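
The get_VKQ_stride helper above caps the number of VKQ rows computed in parallel at the largest power of two dividing D/frag_m, times frag_m, so an odd ratio such as D = 80 with frag_m = 16 yields a stride of 16 regardless of nwarps. A minimal host-side sketch, not part of the vendored sources, that reproduces the two constexpr helpers and prints two of the values checked by the static_asserts:

    #include <cstdio>

    constexpr int get_max_power_of_2(int x) {
        return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
    }

    constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
        return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
    }

    int main() {
        printf("get_VKQ_stride(128, 4, 32) = %d\n", get_VKQ_stride(128, 4, 32)); // 128
        printf("get_VKQ_stride( 80, 4, 16) = %d\n", get_VKQ_stride( 80, 4, 16)); //  16
        return 0;
    }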

+ 371 - 0
llama/ggml-cuda/fattn.cu

@@ -0,0 +1,371 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+#include "fattn-common.cuh"
+#include "fattn-tile-f16.cuh"
+#include "fattn-tile-f32.cuh"
+#include "fattn-vec-f16.cuh"
+#include "fattn-vec-f32.cuh"
+#include "fattn-wmma-f16.cuh"
+#include "fattn.cuh"
+
+#include <cstdint>
+
+static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+
+    const int32_t precision = KQV->op_params[3];
+
+    if (precision != GGML_PREC_DEFAULT) {
+        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
+            constexpr int cols_per_block = 16;
+            switch (Q->ne[0]) {
+                case 64:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
+                    break;
+                case 80:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
+                    break;
+                case 96:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
+                    break;
+                case 112:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
+                    break;
+                case 128:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                    break;
+                case 256:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
+                    break;
+                default:
+                    GGML_ABORT("fatal error");
+                    break;
+            }
+        } else {
+            constexpr int cols_per_block = 32;
+            switch (Q->ne[0]) {
+                case 64:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
+                    break;
+                case 80:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
+                    break;
+                case 96:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
+                    break;
+                case 112:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
+                    break;
+                case 128:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                    break;
+                // case 256:
+                //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                //     break;
+                default:
+                    GGML_ABORT("fatal error");
+                    break;
+            }
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
+        constexpr int cols_per_block = 8;
+        switch (Q->ne[0]) {
+            case 64:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+                break;
+            case 96:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+                break;
+            case 128:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+                break;
+            case 256:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+                break;
+            default:
+                GGML_ABORT("fatal error");
+                break;
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 32) {
+        constexpr int cols_per_block = 16;
+        switch (Q->ne[0]) {
+            case 64:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+                break;
+            case 80:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
+                break;
+            case 96:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+                break;
+            case 112:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
+                break;
+            case 128:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+                break;
+            case 256:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+                break;
+            default:
+                GGML_ABORT("fatal error");
+                break;
+        }
+        return;
+    }
+
+    constexpr int cols_per_block = 32;
+    switch (Q->ne[0]) {
+        case 64:
+            ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+            break;
+        case 80:
+            ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
+            break;
+        case 96:
+            ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+            break;
+        case 112:
+            ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
+            break;
+        case 128:
+            ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+            break;
+        case 256:
+            ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+#define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \
+    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
+        ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
+        return;                                                             \
+    }                                                                       \
+
+static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * Q = dst->src[0];
+    ggml_tensor * K = dst->src[1];
+    ggml_tensor * V = dst->src[2];
+
+#ifdef GGML_CUDA_FA_ALL_QUANTS
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16 )
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_0)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_1)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_0)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_1)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q8_0)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16)
+
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+#else
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+#endif // GGML_CUDA_FA_ALL_QUANTS
+
+    on_no_fattn_vec_case(Q->ne[0]);
+}
+
+#define FATTN_VEC_F32_CASE(D, type_K, type_V)                               \
+    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
+        ggml_cuda_flash_attn_ext_vec_f32_case<D, type_K, type_V>(ctx, dst); \
+        return;                                                             \
+    }                                                                       \
+
+static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * Q = dst->src[0];
+    ggml_tensor * K = dst->src[1];
+    ggml_tensor * V = dst->src[2];
+
+#ifdef GGML_CUDA_FA_ALL_QUANTS
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_0)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_1)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_0)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_1)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q8_0)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16)
+
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+#else
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+#endif // GGML_CUDA_FA_ALL_QUANTS
+
+    on_no_fattn_vec_case(Q->ne[0]);
+}
+
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+
+    ggml_cuda_set_device(ctx.device);
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const int32_t precision = KQV->op_params[3];
+
+    // On AMD the tile kernels perform poorly, use the vec kernel instead:
+    if (cc >= CC_OFFSET_AMD) {
+        if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+        }
+        return;
+    }
+
+    if (!fast_fp16_available(cc)) {
+        if (Q->ne[1] <= 8) {
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
+        }
+        return;
+    }
+
+    if (!fp16_mma_available(cc)) {
+        if (Q->ne[1] <= 8) {
+            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+        }
+        return;
+    }
+
+    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
+        if (precision == GGML_PREC_DEFAULT) {
+            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+            return;
+        } else if(Q->ne[0] <= 128) {
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+            return;
+        }
+    }
+
+    ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+}
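
ggml_cuda_flash_attn_ext above is the single dispatch point that chooses between the vec, tile, and WMMA flash-attention kernels. A condensed host-side restatement of that decision tree, for illustration only (the enum and helper name are hypothetical, not part of ggml):

    enum fattn_kernel_kind {
        FATTN_VEC_F16, FATTN_VEC_F32, FATTN_TILE_F16, FATTN_TILE_F32, FATTN_WMMA_F16
    };

    // ne0 = head dimension Q->ne[0], ne1 = number of query rows Q->ne[1]
    static fattn_kernel_kind choose_fattn_kernel(bool is_amd, bool fast_fp16, bool fp16_mma,
                                                 bool prec_default, long ne0, long ne1) {
        if (is_amd) {                     // tile kernels perform poorly on AMD: use vec
            return prec_default && fast_fp16 ? FATTN_VEC_F16 : FATTN_VEC_F32;
        }
        if (!fast_fp16) {                 // no fast FP16 arithmetic: stay in FP32
            return ne1 <= 8 ? FATTN_VEC_F32 : FATTN_TILE_F32;
        }
        if (!fp16_mma) {                  // FP16 is fast but there are no tensor cores
            return ne1 <= 8 ? FATTN_VEC_F16 : FATTN_TILE_F16;
        }
        if (ne1 == 1 && ne0 % 64 == 0) {  // single query row, 2*WARP_SIZE == 64
            if (prec_default) return FATTN_VEC_F16;
            if (ne0 <= 128)   return FATTN_VEC_F32;
        }
        return FATTN_WMMA_F16;            // general case: WMMA tensor-core kernel
    }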

+ 29 - 0
llama/ggml-cuda/fattn.cuh

@@ -0,0 +1,29 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 203 - 0
llama/ggml-cuda/getrows.cu

@@ -0,0 +1,203 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "getrows.cuh"
+#include "dequantize.cuh"
+
+template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static __global__ void k_get_rows(
+            const void * src0, const int32_t * src1, dst_t * dst,
+            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+            size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
+    const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
+    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+    const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+    if (i00 >= ne00) {
+        return;
+    }
+
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
+
+    const int ib = i00/qk; // block index
+    const int iqs = (i00%qk)/qr; // quant index
+    const int iybs = i00 - i00%qk; // dst block start index
+    const int y_offset = qr == 1 ? 1 : qk/2;
+
+    // dequantize
+    dfloat2 v;
+    dequantize_kernel(src0_row, ib, iqs, v);
+
+    dst_row[iybs + iqs + 0]        = v.x;
+    dst_row[iybs + iqs + y_offset] = v.y;
+}
+
+template<typename src0_t, typename dst_t>
+static __global__ void k_get_rows_float(
+            const src0_t * src0, const int32_t * src1, dst_t * dst,
+            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+            size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
+    const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
+    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+    const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+    if (i00 >= ne00) {
+        return;
+    }
+
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
+
+    dst_row[i00] = src0_row[i00];
+}
+
+template<int qk, int qr, dequantize_kernel_t dq>
+static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+                            const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
+    const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
+    const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
+
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
+
+    GGML_ASSERT(ne00 % 2 == 0);
+
+    k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
+            src0_dd, src1_dd, dst_dd,
+            ne00, /*ne01, ne02, ne03,*/
+            /*ne10, ne11,*/ ne12, /*ne13,*/
+            /* s0,*/ s1, s2, s3,
+            /* nb00,*/ nb01, nb02, nb03,
+            s10, s11, s12/*, s13*/);
+
+    GGML_UNUSED(dst);
+}
+
+template<typename src0_t>
+static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+                                const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
+    const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
+    const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
+
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
+
+    k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
+            src0_dd, src1_dd, dst_dd,
+            ne00, /*ne01, ne02, ne03,*/
+            /*ne10, ne11,*/ ne12, /*ne13,*/
+            /* s0,*/ s1, s2, s3,
+            /* nb00,*/ nb01, nb02, nb03,
+            s10, s11, s12/*, s13*/);
+
+    GGML_UNUSED(dst);
+}
+
+void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const float * src0_d = (const float *)src0->data;
+    const float * src1_d = (const float *)src1->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
+
+    const int32_t * src1_i32 = (const int32_t *) src1_d;
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_F32:
+            get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q4_0:
+            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q4_1:
+            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        default:
+            // TODO: k-quants
+            GGML_ABORT("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
+            break;
+    }
+}
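
GGML_OP_GET_ROWS is a batched row gather: each entry of src1 selects one row of src0 for the output. A minimal CPU reference of the semantics, float-only and single-batch (an illustrative sketch, not part of the vendored sources); the CUDA kernels above additionally dequantize quantized src0 rows, two values per thread:

    #include <cstdint>
    #include <vector>

    // dst[r] = src0[src1[r]] for every row index r in src1.
    static std::vector<std::vector<float>> get_rows_ref(
            const std::vector<std::vector<float>> & src0,   // [nrows][ne00]
            const std::vector<int32_t>            & src1) { // row indices
        std::vector<std::vector<float>> dst;
        dst.reserve(src1.size());
        for (int32_t i01 : src1) {
            dst.push_back(src0[i01]); // copy the selected row
        }
        return dst;
    }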

+ 31 - 0
llama/ggml-cuda/getrows.cuh

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_GET_ROWS_BLOCK_SIZE 256
+
+void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 130 - 0
llama/ggml-cuda/im2col.cu

@@ -0,0 +1,130 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "im2col.cuh"
+
+template <typename T>
+static  __global__ void im2col_kernel(
+        const float * x, T * dst, int64_t batch_offset,
+        int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
+        int s0, int s1, int p0, int p1, int d0, int d1) {
+    const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
+    if (i >= pelements) {
+        return;
+    }
+
+    const int64_t  ksize = OW * (KH > 1 ? KW : 1);
+    const int64_t  kx = i / ksize;
+    const int64_t  kd = kx * ksize;
+    const int64_t  ky = (i - kd) / OW;
+    const int64_t  ix = i % OW;
+
+    const int64_t  oh = blockIdx.y;
+    const int64_t  batch = blockIdx.z / IC;
+    const int64_t  ic = blockIdx.z % IC;
+
+    const int64_t iiw = ix * s0 + kx * d0 - p0;
+    const int64_t iih = oh * s1 + ky * d1 - p1;
+
+    const int64_t offset_dst =
+        ((batch * OH + oh) * OW + ix) * CHW +
+        (ic * (KW * KH) + ky * KW + kx);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = 0.0f;
+    } else {
+        const int64_t offset_src = ic * offset_delta + batch * batch_offset;
+        dst[offset_dst] = x[offset_src + iih * IW + iiw];
+    }
+}
+
+template <typename T>
+static void im2col_cuda(const float * x, T* dst,
+    int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
+    int64_t batch, int64_t batch_offset, int64_t offset_delta,
+    int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
+    const int parallel_elements = OW * KW * KH;
+    const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
+    dim3 block_nums(num_blocks, OH, batch * IC);
+    im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+}
+
+static void im2col_cuda_f16(const float * x, half * dst,
+    int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
+    int64_t batch, int64_t batch_offset, int64_t offset_delta,
+    int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
+
+    im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
+}
+
+static void im2col_cuda_f32(const float * x, float * dst,
+    int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
+    int64_t batch, int64_t batch_offset, int64_t offset_delta,
+    int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
+
+    im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
+}
+
+void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const float * src1_d = (const float *)src1->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+
+    const int64_t IC = src1->ne[is_2D ? 2 : 1];
+    const int64_t IH = is_2D ? src1->ne[1] : 1;
+    const int64_t IW =         src1->ne[0];
+
+    const int64_t KH = is_2D ? src0->ne[1] : 1;
+    const int64_t KW =         src0->ne[0];
+
+    const int64_t OH = is_2D ? dst->ne[2] : 1;
+    const int64_t OW =         dst->ne[1];
+
+    const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+    const int64_t batch = src1->ne[3];
+    const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
+
+    if(dst->type == GGML_TYPE_F16) {
+        im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
+    } else {
+        im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
+    }
+}
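
The im2col kernel above flattens each IC x KH x KW receptive field into one row of dst, giving the output the layout [N, OH, OW, IC*KH*KW] so that the following convolution reduces to a matrix multiplication; out-of-bounds taps are written as zeros. A CPU reference of the same index mapping (an illustrative sketch, not part of the vendored sources):

    #include <cstdint>
    #include <vector>

    static std::vector<float> im2col_ref(const std::vector<float> & x, // [N, IC, IH, IW]
            int64_t N, int64_t IC, int64_t IH, int64_t IW,
            int64_t KH, int64_t KW, int64_t OH, int64_t OW,
            int s0, int s1, int p0, int p1, int d0, int d1) {
        const int64_t CHW = IC*KH*KW;
        std::vector<float> dst(N*OH*OW*CHW, 0.0f); // zero-filled: handles padding
        for (int64_t n = 0; n < N; ++n)
        for (int64_t oh = 0; oh < OH; ++oh)
        for (int64_t ow = 0; ow < OW; ++ow)
        for (int64_t ic = 0; ic < IC; ++ic)
        for (int64_t ky = 0; ky < KH; ++ky)
        for (int64_t kx = 0; kx < KW; ++kx) {
            const int64_t iih = oh*s1 + ky*d1 - p1; // input row of this tap
            const int64_t iiw = ow*s0 + kx*d0 - p0; // input column of this tap
            if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
                continue; // zero padding
            }
            dst[((n*OH + oh)*OW + ow)*CHW + ic*KH*KW + ky*KW + kx] =
                x[((n*IC + ic)*IH + iih)*IW + iiw];
        }
        return dst;
    }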

+ 31 - 0
llama/ggml-cuda/im2col.cuh

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_IM2COL_BLOCK_SIZE 256
+
+void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 247 - 0
llama/ggml-cuda/mma.cuh

@@ -0,0 +1,247 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+struct mma_int_A_I16K4 {
+    static constexpr int I  = 16;
+    static constexpr int K  = 4;
+    static constexpr int ne = 2;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l%2) * (I/2) + threadIdx.x / K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int /* l */) {
+        const int ret = threadIdx.x % K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  K);
+        return ret;
+    }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+        const int * xs = xs0 + (threadIdx.x%I)*stride;
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(x[0]), "+r"(x[1])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
+};
+
+struct mma_int_A_I16K8 {
+    static constexpr int I  = 16;
+    static constexpr int K  = 8;
+    static constexpr int ne = 4;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int l) {
+        const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  K);
+        return ret;
+    }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
+
+    __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
+        ((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
+    }
+};
+
+struct mma_int_B_J8K4 {
+    static constexpr int J  = 8;
+    static constexpr int K  = 4;
+    static constexpr int ne = 1;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_j(const int /* l */) {
+        const int ret = threadIdx.x / K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  J);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int /* l */) {
+        const int ret = threadIdx.x % K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  K);
+        return ret;
+    }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+        const int * xs = xs0 + (threadIdx.x%J)*stride;
+        asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
+            : "+r"(x[0])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
+};
+
+struct mma_int_B_J8K8 {
+    static constexpr int J  = 8;
+    static constexpr int K  = 8;
+    static constexpr int ne = 2;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_j(const int /* l */) {
+        const int ret = threadIdx.x / (K/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  J);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int l) {
+        const int ret = l * (K/2) + threadIdx.x % (K/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  K);
+        return ret;
+    }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+        const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(x[0]), "+r"(x[1])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
+};
+
+struct mma_int_C_I16J8 {
+    static constexpr int I  = 16;
+    static constexpr int J  = 8;
+    static constexpr int ne = 4;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_j(const int l) {
+        const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  J);
+        return ret;
+    }
+
+    __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
+#ifdef INT8_MMA_AVAILABLE
+#if __CUDA_ARCH__ >= CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+#else
+        // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[0]), "+r"(x[1])
+            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+#endif // __CUDA_ARCH__ >= CC_AMPERE
+#else
+        GGML_UNUSED(mma_A);
+        GGML_UNUSED(mma_B);
+        NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
+#ifdef INT8_MMA_AVAILABLE
+#if __CUDA_ARCH__ >= CC_AMPERE
+        asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
+#else
+        // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[0]), "+r"(x[1])
+            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[0]), "+r"(x[1])
+            : "r"(mma_A.x[2]), "r"(mma_B.x[1]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[3]), "r"(mma_B.x[1]));
+#endif // __CUDA_ARCH__ >= CC_AMPERE
+#else
+        GGML_UNUSED(mma_A);
+        GGML_UNUSED(mma_B);
+        NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+    }
+};
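
The mma_int_* structs above encode how a 32-thread warp maps onto the PTX mma tile fragments: get_i/get_j (or get_k) recover the matrix coordinates owned by each lane for element l. A host-side sketch, not part of the vendored sources, that enumerates the 16x8 int32 accumulator layout of mma_int_C_I16J8 (each of the 32 lanes holds ne = 4 of the 128 elements):

    #include <cstdio>

    int main() {
        const int I = 16, J = 8, ne = 4;
        for (int lane = 0; lane < 32; ++lane) {
            printf("lane %2d:", lane);
            for (int l = 0; l < ne; ++l) {
                const int i = (l/2)*(I/2) + lane/(J/2); // row, as in get_i(l)
                const int j = 2*(lane % (J/2)) + l%2;   // column, as in get_j(l)
                printf(" (%2d,%d)", i, j);
            }
            printf("\n");
        }
        return 0;
    }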

+ 176 - 0
llama/ggml-cuda/mmq.cu

@@ -0,0 +1,176 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "mmq.cuh"
+
+void ggml_cuda_op_mul_mat_q(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+    const int64_t ne00 = src0->ne[0];
+
+    const int64_t nb01 = src0->nb[1];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    GGML_ASSERT(ne10 % QK8_1 == 0);
+
+    const int64_t ne0 = dst->ne[0];
+
+    const int64_t row_diff = row_high - row_low;
+    const int64_t stride00 = nb01 / ggml_type_size(src0->type);
+
+    int id = ggml_cuda_get_device();
+    const int compute_capability = ggml_cuda_info().devices[id].cc;
+
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // nrows_dst == nrows of the matrix that the kernel writes into
+    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
+
+    const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst};
+
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q4_1:
+            mul_mat_q_case<GGML_TYPE_Q4_1>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            mul_mat_q_case<GGML_TYPE_Q5_0>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            mul_mat_q_case<GGML_TYPE_Q5_1>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q2_K:
+            mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q3_K:
+            mul_mat_q_case<GGML_TYPE_Q3_K>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            mul_mat_q_case<GGML_TYPE_Q4_K>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q5_K:
+            mul_mat_q_case<GGML_TYPE_Q5_K>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q6_K:
+            mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_XXS:
+            mul_mat_q_case<GGML_TYPE_IQ2_XXS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_XS:
+            mul_mat_q_case<GGML_TYPE_IQ2_XS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_S:
+            mul_mat_q_case<GGML_TYPE_IQ2_S>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ3_XXS:
+            mul_mat_q_case<GGML_TYPE_IQ3_XXS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ3_S:
+            mul_mat_q_case<GGML_TYPE_IQ3_S>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ1_S:
+            mul_mat_q_case<GGML_TYPE_IQ1_S>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ4_XS:
+            mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ4_NL:
+            mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+
+    GGML_UNUSED(src1);
+    GGML_UNUSED(dst);
+    GGML_UNUSED(src1_ddf_i);
+}
+
+bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
+#ifdef GGML_CUDA_FORCE_CUBLAS
+    return false;
+#endif // GGML_CUDA_FORCE_CUBLAS
+
+    bool mmq_supported;
+
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ4_NL:
+            mmq_supported = true;
+            break;
+        default:
+            mmq_supported = false;
+            break;
+    }
+
+    if (!mmq_supported) {
+        return false;
+    }
+
+    if (int8_mma_available(cc)) {
+        return true;
+    }
+
+    if (cc < MIN_CC_DP4A) {
+        return false;
+    }
+
+#ifdef GGML_CUDA_FORCE_MMQ
+    return true;
+#endif //GGML_CUDA_FORCE_MMQ
+
+    if (cc < CC_OFFSET_AMD) {
+        return cc < CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+    }
+
+    return cc < CC_RDNA3 || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+}
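
ggml_cuda_should_use_mmq above decides per matrix multiplication whether the quantized MMQ kernels or dequantize-then-cuBLAS is likely faster. A condensed sketch of the heuristic for a type that MMQ supports (the helper name and flattened parameters are hypothetical):

    // fast_fp16_mma_gpu: Volta or newer on NVIDIA, RDNA3 or newer on AMD.
    static bool should_use_mmq_sketch(bool forced_cublas, bool forced_mmq,
                                      bool int8_mma, bool dp4a_ok,
                                      bool fast_fp16_mma_gpu, long ne11) {
        if (forced_cublas) return false; // GGML_CUDA_FORCE_CUBLAS always wins
        if (int8_mma)      return true;  // int8 tensor cores: MMQ is fastest
        if (!dp4a_ok)      return false; // without dp4a MMQ cannot run at all
        if (forced_mmq)    return true;  // GGML_CUDA_FORCE_MMQ overrides the heuristic
        // Otherwise prefer cuBLAS only on GPUs with fast FP16 MMA, and only
        // once the batch is large enough to amortize the dequantization:
        return !fast_fp16_mma_gpu || ne11 < 64; // 64 == MMQ_DP4A_MAX_BATCH_SIZE
    }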

+ 2962 - 0
llama/ggml-cuda/mmq.cuh

@@ -0,0 +1,2962 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "common.cuh"
+#include "vecdotq.cuh"
+#include "mma.cuh"
+
+#include <climits>
+#include <cstdint>
+
+#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
+#define MMQ_ITER_K 256
+#define MMQ_NWARPS 8
+
+typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride);
+typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00);
+typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);
+
+enum mmq_q8_1_ds_layout {
+    MMQ_Q8_1_DS_LAYOUT_D4,
+    MMQ_Q8_1_DS_LAYOUT_DS4,
+    MMQ_Q8_1_DS_LAYOUT_D2S6,
+};
+
+struct block_q8_1_mmq {
+    // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block.
+    // The y float data is first grouped as blocks of 128 values.
+    // These blocks are then treated as individual data values and transposed.
+    //
+    // To avoid shared memory bank conflicts each block is padded with 16 bytes.
+    // This padding is also used to store block scales/partial sums.
+    // The scales multiplied with the quantized data are equal to the unquantized values.
+    // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization)
+    //     and are only needed for performance reasons.
+    //
+    // The exact data stored depends on the x data type.
+    union {
+        float d4[4];    // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3
+        half2 ds4[4];   // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3
+        half  d2s6[8];  // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values,
+                        //     stored as d0,d1,s1,s2,s3,s4,s5,s6
+    };
+    int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
+};
+static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
+static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1),      "Unexpected block_q8_1_mmq size");
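+
+// Size sketch, assuming QK8_1 == 32: qs holds 4*32 = 128 int8 values (128
+// bytes) and the union contributes 16 bytes of scales/partial sums, so each
+// block_q8_1_mmq is 144 bytes, exactly what the static_asserts above check
+// (4*QK8_1 + 4*sizeof(half2) == 4*sizeof(block_q8_1) == 144).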
+
+static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
+    switch (type_x) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+            return MMQ_Q8_1_DS_LAYOUT_DS4;
+        case GGML_TYPE_Q5_0:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_Q5_1:
+            return MMQ_Q8_1_DS_LAYOUT_DS4;
+        case GGML_TYPE_Q8_0:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_Q2_K:
+            return MMQ_Q8_1_DS_LAYOUT_D2S6;
+        case GGML_TYPE_Q3_K:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+            return MMQ_Q8_1_DS_LAYOUT_DS4;
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ3_S:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_IQ1_S:
+            return MMQ_Q8_1_DS_LAYOUT_DS4;
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ4_NL:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
+struct tile_x_sizes {
+    int qs;
+    int dm;
+    int sc;
+};
+
+static constexpr int get_mmq_x_max_host(const int cc) {
+    return int8_mma_available(cc) ? 128 :
+#ifdef GGML_CUDA_FORCE_MMQ
+        cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128                     : 64;
+#else
+        cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64;
+#endif // GGML_CUDA_FORCE_MMQ
+}
+
+static constexpr __device__ int get_mmq_x_max_device() {
+#ifdef INT8_MMA_AVAILABLE
+    return 128;
+#else // INT8_MMA_AVAILABLE
+
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    return 128;
+#else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
+#if __CUDA_ARCH__ >= CC_VOLTA
+#ifdef GGML_CUDA_FORCE_MMQ
+    return MMQ_DP4A_MAX_BATCH_SIZE;
+#else // GGML_CUDA_FORCE_MMQ
+    return 128;
+#endif // GGML_CUDA_FORCE_MMQ
+#else // __CUDA_ARCH__ >= CC_VOLTA
+
+    return 64;
+#endif // __CUDA_ARCH__ >= CC_VOLTA
+
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#endif // INT8_MMA_AVAILABLE
+}
+
+static constexpr int get_mmq_y_host(const int cc) {
+    return cc >= CC_OFFSET_AMD ? (cc == CC_RDNA1 ? 64 : 128) : (cc >= CC_VOLTA ? 128 : 64);
+}
+
+static constexpr __device__ int get_mmq_y_device() {
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA1)
+    return 64;
+#else
+    return 128;
+#endif // defined RDNA1
+#else
+#if __CUDA_ARCH__ >= CC_VOLTA
+    return 128;
+#else
+    return 64;
+#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+}
+
+#define MMQ_DP4A_TXS_Q4_0    tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_0   + mmq_y/QI4_0,     0}
+#define MMQ_DP4A_TXS_Q4_1    tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_1   + mmq_y/QI4_1,     0}
+#define MMQ_DP4A_TXS_Q8_0    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
+#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*4/QI8_0 + mmq_y/(QI8_0/4), 0}
+#define MMQ_DP4A_TXS_Q8_1    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0}
+#define MMQ_DP4A_TXS_Q2_K    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE         + mmq_y,           0}
+#define MMQ_DP4A_TXS_Q3_K    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y,                                     mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q4_K    tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_K,                     mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q5_K    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K   + mmq_y/QI5_K,     mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q6_K    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K   + mmq_y/QI6_K,     mmq_y*WARP_SIZE/8 + mmq_y/8}
+
+static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
+    return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
+        type == GGML_TYPE_Q4_1    ? MMQ_DP4A_TXS_Q4_1 :
+        type == GGML_TYPE_Q5_0    ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_Q5_1    ? MMQ_DP4A_TXS_Q8_1 :
+        type == GGML_TYPE_Q8_0    ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_Q2_K    ? MMQ_DP4A_TXS_Q2_K :
+        type == GGML_TYPE_Q3_K    ? MMQ_DP4A_TXS_Q3_K :
+        type == GGML_TYPE_Q4_K    ? MMQ_DP4A_TXS_Q4_K :
+        type == GGML_TYPE_Q5_K    ? MMQ_DP4A_TXS_Q5_K :
+        type == GGML_TYPE_Q6_K    ? MMQ_DP4A_TXS_Q6_K :
+        type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_IQ2_XS  ? MMQ_DP4A_TXS_Q8_0_16 :
+        type == GGML_TYPE_IQ2_S   ? MMQ_DP4A_TXS_Q8_0_16 :
+        type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_IQ3_S   ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_IQ1_S   ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_IQ4_XS  ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_IQ4_NL  ? MMQ_DP4A_TXS_Q8_0 :
+        tile_x_sizes{0, 0, 0};
+}
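+
+// Worked example, assuming WARP_SIZE == 32 and QI4_0 == 4: for GGML_TYPE_Q4_0
+// with mmq_y == 64 this yields qs = 64*32 + 64 = 2112 ints and
+// dm = 64*32/4 + 64/4 = 528 scales; sc is unused (0) for this type.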
+
+#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0                 + 4)
+#define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0                 + 4)
+#define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE                         + 4)
+#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/2                       + 4)
+#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K     + WARP_SIZE/8 + 7)
+
+static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
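+
+// E.g. assuming WARP_SIZE == 32 and QI8_0 == 8: MMQ_MMA_TILE_X_K_Q8_0 ==
+// 64 + 8 + 4 == 76, and 76 % 8 == 4. The row stride is deliberately offset by
+// half an 8-int group, presumably to avoid shared memory bank conflicts
+// between consecutive rows.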
+
+static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
+    return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_Q4_1    ? MMQ_MMA_TILE_X_K_Q8_1 :
+        type == GGML_TYPE_Q5_0    ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_Q5_1    ? MMQ_MMA_TILE_X_K_Q8_1 :
+        type == GGML_TYPE_Q8_0    ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_Q2_K    ? MMQ_MMA_TILE_X_K_Q2_K :
+        type == GGML_TYPE_Q3_K    ? MMQ_MMA_TILE_X_K_Q3_K :
+        type == GGML_TYPE_Q4_K    ? MMQ_MMA_TILE_X_K_Q8_1 :
+        type == GGML_TYPE_Q5_K    ? MMQ_MMA_TILE_X_K_Q8_1 :
+        type == GGML_TYPE_Q6_K    ? MMQ_MMA_TILE_X_K_Q6_K :
+        type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_IQ2_XS  ? MMQ_MMA_TILE_X_K_Q3_K :
+        type == GGML_TYPE_IQ2_S   ? MMQ_MMA_TILE_X_K_Q3_K :
+        type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_IQ3_S   ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_IQ1_S   ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_IQ4_XS  ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_IQ4_NL  ? MMQ_MMA_TILE_X_K_Q8_0 :
+        0;
+}
+
+#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
+
+static int mmq_get_granularity_host(const int mmq_x, const int cc) {
+    return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
+}
+
+#ifdef INT8_MMA_AVAILABLE
+static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
+    return mmq_x >= 48 ? 16 : 8;
+}
+#else
+static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) {
+    return 8;
+}
+#endif // INT8_MMA_AVAILABLE
+
+// ------------------------------------------------------------
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + 2*WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = threadIdx.x / QI4_0;
+    const int kqsx = threadIdx.x % QI4_0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
+        const int qs0 = get_int_b2(bxi->qs, kqsx);
+
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0]     = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
+#else
+        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
+        int i = i0 + threadIdx.y * QI4_0 + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0       + kbxd] = bxi->d;
+#else
+        x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
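+
+// Reference for the unpacking above: q4_0 stores two 4 bit values per byte,
+// biased by 8. For a packed byte 0x97, the low nibble yields 0x07 - 8 == -1
+// and the high nibble 0x09 - 8 == 1; __vsubss4 applies the "- 8" to all four
+// bytes of a word at once with signed saturation.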
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
+
+                int u[2*VDR_Q4_0_Q8_1_MMQ];
+
+#pragma unroll
+                for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
+                    u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs +  l];
+                    u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
+                }
+
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
+                    (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_0], u,
+                     x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = threadIdx.x / QI4_1;
+    const int kqsx = threadIdx.x % QI4_1;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
+        const int qs0 = get_int_b4(bxi->qs, kqsx);
+
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0]     = (qs0 >> 0) & 0x0F0F0F0F;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
+#else
+        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
+        int i = i0 + threadIdx.y * QI4_1 + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1       + kbxd] = bxi->dm;
+#else
+        x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
+
+                int u[2*VDR_Q4_1_Q8_1_MMQ];
+
+#pragma unroll
+                for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
+                    u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs +  l];
+                    u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
+                }
+
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
+                    (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_1], u,
+                     x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = threadIdx.x / QI5_0;
+    const int kqsx = threadIdx.x % QI5_0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
+
+        const int ql = get_int_b2(bxi->qs, kqsx);
+        const int qh = get_int_b2(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
+
+        int qs0 = (ql >>  0)   & 0x0F0F0F0F;
+        qs0    |= (qh <<  4)   & 0x00000010;  // 0 ->  4
+        qs0    |= (qh << 11)   & 0x00001000;  // 1 -> 12
+        qs0    |= (qh << 18)   & 0x00100000;  // 2 -> 20
+        qs0    |= (qh << 25)   & 0x10000000;  // 3 -> 28
+        qs0     = __vsubss4(qs0, 0x10101010); // subtract 16
+
+        int qs1 = (ql >>  4)   & 0x0F0F0F0F;
+        qs1    |= (qh >> 12)   & 0x00000010;  // 16 ->  4
+        qs1    |= (qh >>  5)   & 0x00001000;  // 17 -> 12
+        qs1    |= (qh <<  2)   & 0x00100000;  // 18 -> 20
+        qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28
+        qs1     = __vsubss4(qs1, 0x10101010); // subtract 16
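+
+        // In the "b -> p" comments above, bit b of the (shifted) qh word is
+        // the 5th bit of one quant value; it is moved to bit p, i.e. to bit 4
+        // of the byte that holds that value's low 4 bits from ql.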
+
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
+        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
+        int i = i0 + threadIdx.y * QI5_0 + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0       + kbxd] = bxi->d;
+#else
+        x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = threadIdx.x / QI5_1;
+    const int kqsx = threadIdx.x % QI5_1;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
+
+        const int ql = get_int_b4(bxi->qs, kqsx);
+        const int qh = get_int_b4(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
+
+        int qs0 = (ql >>  0) & 0x0F0F0F0F;
+        qs0    |= (qh <<  4) & 0x00000010; // 0 ->  4
+        qs0    |= (qh << 11) & 0x00001000; // 1 -> 12
+        qs0    |= (qh << 18) & 0x00100000; // 2 -> 20
+        qs0    |= (qh << 25) & 0x10000000; // 3 -> 28
+
+        int qs1 = (ql >>  4) & 0x0F0F0F0F;
+        qs1    |= (qh >> 12) & 0x00000010; // 16 ->  4
+        qs1    |= (qh >>  5) & 0x00001000; // 17 -> 12
+        qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
+        qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
+
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
+        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
+        int i = i0 + threadIdx.y * QI5_1 + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1       + kbxd] = bxi->dm;
+#else
+        x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_tile + 2*WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = threadIdx.x / QI8_0;
+    const int kqsx = threadIdx.x % QI8_0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0         + threadIdx.x] = get_int_b2(bxi[0].qs,               kqsx);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + 0         + threadIdx.x] = get_int_b2(bxi[0].qs,               kqsx);
+        x_qs[i*(2*WARP_SIZE + 1)     + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0/2) {
+        int i = i0 + threadIdx.y * (QI8_0/2) + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0             + kbxd] = bxi->d;
+#else
+        x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
+                    (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % WARP_SIZE],
+                     x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+            }
+        }
+    }
+}
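+
+// Note on the accumulator indexing used by all dp4a variants: each thread
+// owns (mmq_x/nwarps) * (mmq_y/WARP_SIZE) partial results, stored row-major
+// as sum[j0/nwarps * mmq_y/WARP_SIZE + i0/WARP_SIZE].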
+
+template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+    const half2 * y_ds = (const half2 *) y;
+
+    mma_A A[ntx][WARP_SIZE/QI8_0];
+    float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
+
+    const int i0 = (threadIdx.y/ntx)*rows_per_warp;
+
+#pragma unroll
+    for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+            const int k0 = k00 + k01;
+
+            A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+
+#pragma unroll
+            for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+                const int k0 = k00 + k01;
+
+                dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+            }
+        }
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+            mma_B  B;
+            float dB[mma_C::ne/2];
+
+            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+
+                if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+                    dB[l] =             y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+                } else {
+                    dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+                }
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C;
+                C.mma_K8(A[n][k01/QI8_0], B);
+
+#pragma unroll
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
+                }
+            }
+        }
+    }
+}
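+
+// Structure of the MMA path above: each warp owns ntx x-minitiles of
+// mma_C::I == 16 rows. The A fragments and their scales dA are preloaded once
+// per warp; then, for every 8-column block of y, a single B fragment is
+// loaded and reused across all ntx minitiles, with the int32 MMA result C
+// scaled by dA*dB on accumulation.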
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
+                    (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                    x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+        }
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_dm = (const half2 *) y;
+
+    mma_A    A[ntx][WARP_SIZE/QI8_1];
+    float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
+
+    const int i0 = (threadIdx.y/ntx)*rows_per_warp;
+
+#pragma unroll
+    for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+            const int k0 = k00 + k01;
+
+            A[n][k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+
+#pragma unroll
+            for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+                const int k0 = k00 + k01;
+
+                dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
+            }
+        }
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+            mma_B    B;
+            float2 dsB[mma_C::ne/2];
+
+            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+
+                dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C;
+                C.mma_K8(A[n][k01/QI8_1], B);
+
+#pragma unroll
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
+                }
+            }
+        }
+    }
+}
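+
+// The two accumulator updates above implement affine (scale + min)
+// dequantization: with x = d_x*q_x + m_x and y quantized as (d_y, s_y) where
+// s_y = d_y*sum(q_y), the contribution is d_x*d_y*<q_x, q_y> + m_x*s_y. Here
+// dmA packs (d_x, m_x) and dsB packs (d_y, s_y).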
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(
+                    &x_qs[i*(2*WARP_SIZE + 1) + k0],
+                    &y_qs[j*MMQ_TILE_Y_K + k01],
+                    &x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
+                    y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+        }
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+#ifdef INT8_MMA_AVAILABLE
+
+    typedef mma_int_A_I16K4 mma_A;
+    typedef mma_int_A_I16K8 mma_A_K8;
+    typedef mma_int_B_J8K4  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + WARP_SIZE*2;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+
+    mma_A   A[ntx][8];
+    float  dA[ntx][mma_C::ne/2][8];
+
+#pragma unroll
+    for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+            const int k0 = k00 + k01;
+
+            ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+
+#pragma unroll
+            for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
+                const int k0 = k00 + k01;
+
+                dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4];
+            }
+        }
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
+            mma_B B[2];
+            float dB[mma_C::ne/2];
+
+            B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),        MMQ_TILE_Y_K);
+            B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+
+                dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C[2];
+                C[0].mma_K4(A[n][k01/4 + 0], B[0]);
+                C[1].mma_K4(A[n][k01/4 + 1], B[1]);
+
+#pragma unroll
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
+                }
+            }
+        }
+    }
+#else
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+    NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % QI2_K;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_K) {
+        int i = i0 + threadIdx.y*(WARP_SIZE/QI2_K) + threadIdx.x/QI2_K;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride;
+
+        const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
+
+#pragma unroll
+        for (int l = 0; l < QR2_K; ++l) {
+            const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
+
+            const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
+
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + k] = x_qs_k;
+#endif // INT8_MMA_AVAILABLE
+        }
+
+        const int sc_m = bxi->scales[kqsx];
+#ifdef FAST_FP16_AVAILABLE
+        const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4));
+#else
+        const float2 bxi_dmf = __half22float2(bxi->dm);
+        const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
+#endif // FAST_FP16_AVAILABLE
+
+#ifdef INT8_MMA_AVAILABLE
+        x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
+#else
+        x_dm[i*(WARP_SIZE + 1)       + kqsx] = x_dm_ik;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
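+
+// q2_K stores a 4 bit scale in the low nibble and a 4 bit min in the high
+// nibble of each scales[] byte; x_dm therefore caches the premultiplied pair
+// (d * sc, dmin * m) for each 16-value group.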
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    float2 y_df[mmq_x/nwarps];
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+        y_df[j0/nwarps] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
+    }
+
+#pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                if (k01 < WARP_SIZE/2) {
+                    constexpr int ns = 2;
+                    sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+                        &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                        &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+                        &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+                } else {
+                    constexpr int ns = 1;
+                    sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+                        &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                        &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+                        &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+                }
+            }
+        }
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+#ifdef INT8_MMA_AVAILABLE
+
+    typedef mma_int_A_I16K4 mma_A;
+    typedef mma_int_A_I16K8 mma_A_K8;
+    typedef mma_int_B_J8K4  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+
+    mma_A   A[ntx][8];
+    float  dA[ntx][mma_C::ne/2][8];
+    float  mA[ntx][mma_C::ne/2][8];
+
+#pragma unroll
+    for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+            const int k0 = k00 + k01;
+
+            ((mma_A_K8 *) A[n])[k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+        }
+    }
+
+#pragma unroll
+    for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+
+#pragma unroll
+            for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
+                const int k0 = k00 + k01;
+
+                const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]);
+
+                dA[n][l][k01/(QI8_1/2)] = dm.x;
+                mA[n][l][k01/(QI8_1/2)] = dm.y;
+            }
+        }
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+        float2 dB[mma_C::ne/2];
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int j = j0 + mma_C::get_j(l);
+
+            dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
+        }
+
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+            mma_B B[2];
+
+            B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),        MMQ_TILE_Y_K);
+            B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+
+            mma_C Cm[2];
+            if (k01 >= WARP_SIZE * 3/4) {
+                mma_A A1;
+                A1.x[0] = 0x01010101;
+                A1.x[1] = 0x01010101;
+                Cm[0].mma_K4(A1, B[0]);
+                Cm[1].mma_K4(A1, B[1]);
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C Cd[2];
+
+                Cd[0].mma_K4(A[n][k01/4 + 0], B[0]);
+                Cd[1].mma_K4(A[n][k01/4 + 1], B[1]);
+
+#pragma unroll
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
+                    if (k01 >= WARP_SIZE * 3/4) {
+                        tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
+                    }
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
+                }
+            }
+        }
+
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
+            float2 sB[mma_C::ne/2];
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+
+                sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
+                }
+            }
+        }
+    }
+#else
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+    NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+    int   * x_sc = (int   *) (x_df + txs.dm);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % QI3_K;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI3_K) {
+        int i = i0 + threadIdx.y * (WARP_SIZE/QI3_K) + threadIdx.x / QI3_K;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+
+        const int x_ql_0 = get_int_b2(bxi->qs,    kqsx);
+        const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
+
+#pragma unroll
+        for (int l = 0; l < QR3_K; ++l) {
+            const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
+
+            const int x_ql_k =  (x_ql_0 >> (2*l))       & 0x03030303;
+            const int x_qh_k = ((x_qh_0 >>    l)  << 2) & 0x04040404;
+
+            const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
+
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + k] = x_qs_k;
+#endif // INT8_MMA_AVAILABLE
+        }
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
+        int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+
+        const int ksc = threadIdx.x % (WARP_SIZE/8);
+
+        const int ksc_low = ksc % (QI3_K/8);
+        const int shift_low = 4 * (ksc / (QI3_K/8));
+        const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
+
+        const int ksc_high = QI3_K/8;
+        const int shift_high = 2 * ksc;
+        const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
+
+        const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
+
+#ifdef INT8_MMA_AVAILABLE
+        const int8_t * sc8 = (const int8_t *) &sc;
+        const float d = bxi->d;
+
+#pragma unroll
+        for (int l = 0; l < sizeof(int); ++l) {
+            x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l];
+        }
+#else
+        x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+#ifndef INT8_MMA_AVAILABLE
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
+        int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+
+        x_df[i] = bxi->d;
+    }
+#endif // INT8_MMA_AVAILABLE
+}
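+
+// q3_K packs 16 6 bit scales into 12 bytes: the low 4 bits of each scale come
+// from the nibbles of bytes 0..7, the upper 2 bits from bytes 8..11, and
+// __vsubss4(..., 0x20202020) recenters each scale from [0, 63] to [-32, 31].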
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
+    const int   * x_sc = (const int   *) x_df + txs.dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                const int8_t * scales = ((const int8_t *) (x_sc + i*(WARP_SIZE/8) + i/8)) + k0/4;
+
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
+                    &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
+                    x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+        }
+    }
+}
+
+static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, const int ksc) {
+    // scale arrangement after the following two lines:
+    //   - ksc == 0: sc0, sc1, sc2, sc3
+    //   - ksc == 1: sc4, sc5, sc6, sc7
+    //   - ksc == 2:  m0,  m1,  m2,  m3
+    //   - ksc == 3:  m4,  m5,  m6,  m7
+    return ((scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F) | // lower 4 bits
+           ((scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030);  // upper 2 bits
+}
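+
+// Worked example: for ksc == 1 the expression reduces to
+// (scales[2] & 0x0F0F0F0F) | ((scales[0] >> 2) & 0x30303030), i.e. the low
+// 4 bits of sc4..sc7 come from bytes 8..11 and their upper 2 bits from bits
+// 6..7 of bytes 0..3 of the 12 byte q4_K/q5_K scale block.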
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + txs.qs);
+    int   * x_sc = (int   *) (x_dm + txs.dm);
+#endif // INT8_MMA_AVAILABLE
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+        const int qs0 = get_int_b4(bxi->qs, threadIdx.x);
+
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
+#else
+        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+#ifdef INT8_MMA_AVAILABLE
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
+        int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+
+        const int * scales = (const int *) bxi->scales;
+        const int ksc = threadIdx.x % (WARP_SIZE/16);
+
+        const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
+        const int  m32 = unpack_scales_q45_K(scales, ksc + 2);
+
+        const uint8_t * sc8 = (const uint8_t *) &sc32;
+        const uint8_t *  m8 = (const uint8_t *)  &m32;
+
+        const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
+
+#pragma unroll
+        for (int l = 0; l < sizeof(int); ++l) {
+            x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+        }
+    }
+
+#else
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI4_K) {
+        int i = (i0 + threadIdx.y*QI4_K + threadIdx.x) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+
+        x_dm[i] = bxi->dm;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8);
+
+        const int * scales = (const int *) bxi->scales;
+
+        const int ksc = threadIdx.x % (WARP_SIZE/8);
+        const int scales8 = unpack_scales_q45_K(scales, ksc);
+
+        x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
+    }
+#endif // INT8_MMA_AVAILABLE
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+    const int   * x_sc = (const int   *) x_dm + txs.dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                const uint8_t * sc = (const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/32] + 2*(k01/16);
+
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
+                    &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
+                    x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + txs.qs);
+    int   * x_sc = (int   *) (x_dm + txs.dm);
+#endif // INT8_MMA_AVAILABLE
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+        const int ky = QR5_K*threadIdx.x;
+
+        const int ql = get_int_b4(bxi->qs, threadIdx.x);
+        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+        const int qh = get_int_b4(bxi->qh, threadIdx.x % (QI5_K/4));
+        const int qh0 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 0)) << 4) & 0x10101010;
+        const int qh1 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 1)) << 4) & 0x10101010;
+
+        const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
+        const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + kq0] = ql0 | qh0;
+        x_qs[i*(2*WARP_SIZE + 1)     + kq1] = ql1 | qh1;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+#ifdef INT8_MMA_AVAILABLE
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
+        int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+
+        const int * scales = (const int *) bxi->scales;
+        const int ksc = threadIdx.x % (WARP_SIZE/16);
+
+        const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
+        const int  m32 = unpack_scales_q45_K(scales, ksc + 2);
+
+        const uint8_t * sc8 = (const uint8_t *) &sc32;
+        const uint8_t *  m8 = (const uint8_t *)  &m32;
+
+        const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
+
+#pragma unroll
+        for (int l = 0; l < sizeof(int); ++l) {
+            x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+        }
+    }
+
+#else
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI5_K) {
+        int i = (i0 + threadIdx.y*QI5_K + threadIdx.x) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+
+        x_dm[i] = bxi->dm;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
+        int i = (i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8)) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+
+        const int * scales = (const int *) bxi->scales;
+
+        const int ksc = threadIdx.x % (WARP_SIZE/8);
+        const int scales8 = unpack_scales_q45_K(scales, ksc);
+
+        x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
+    }
+#endif // INT8_MMA_AVAILABLE
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+    const int   * x_sc = (const int   *) x_dm + txs.dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16);
+
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
+                    &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
+                    x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+    int   * x_sc = (int   *) (x_df + WARP_SIZE/QI6_K);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+    int   * x_sc = (int   *) (x_df + txs.dm);
+#endif // INT8_MMA_AVAILABLE
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
+
+        const int ql = get_int_b2(bxi->ql, threadIdx.x);
+        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+        const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (threadIdx.x / (QI6_K/2)) + threadIdx.x % (QI6_K/4));
+        const int qh0 = ((qh >> ((threadIdx.x & 0x08) >> 2)) << 4) & 0x30303030;
+        const int qh1 =  (qh >> ((threadIdx.x & 0x08) >> 2))       & 0x30303030;
+
+        const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0;
+        const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+        x_qs[i*(2*WARP_SIZE + 1)     + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI6_K;  // == 1 if QK_K == 256
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
+        int i = (i0 + threadIdx.y * QI6_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q6_K       + kbxd] = bxi->d;
+#else
+        x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
+#else
+        x_sc[i*(WARP_SIZE/8) + i/8   + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
+    const int   * x_sc = (const int   *) x_df + txs.dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]);
+
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
+                    &x_qs[i*(QR6_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
+                    x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+        }
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+#ifdef INT8_MMA_AVAILABLE
+
+    typedef mma_int_A_I16K4 mma_A;
+    typedef mma_int_B_J8K4  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + WARP_SIZE*2;
+    const int   * x_sc = (const int   *) x_df + WARP_SIZE/QI6_K;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+
+    mma_A   A[ntx][8];
+    int   scA[ntx][mma_C::ne/2][8];
+    float  dA[ntx][mma_C::ne/2];
+
+#pragma unroll
+    for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+            const int k0 = k00 + k01;
+
+            A[n][k01/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0),        MMQ_MMA_TILE_X_K_Q6_K);
+            A[n][k01/4 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
+        }
+
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
+            const int k0 = k00 + k01;
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+
+                const int      sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
+                const int8_t * sc        = (const int8_t *) &sc_packed;
+
+#pragma unroll
+                for (int ksc = 0; ksc < sizeof(int); ++ksc) {
+                    scA[n][l][k01/4 + ksc] = sc[ksc];
+                }
+            }
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+
+            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
+        }
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+        float tmp[ntx][mma_C::ne] = {{0.0f}};
+
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+            mma_B B[2];
+            float dB[mma_C::ne/2];
+
+            B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0        + k01, MMQ_TILE_Y_K);
+            B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+
+                dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C[2];
+                C[0].mma_K4(A[n][k01/4 + 0], B[0]);
+                C[1].mma_K4(A[n][k01/4 + 1], B[1]);
+
+#pragma unroll
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
+                }
+            }
+        }
+
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2];
+            }
+        }
+    }
+#else
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+    NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = threadIdx.x / QI4_NL;
+    const int kqsx = threadIdx.x % QI4_NL;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
+
+        const int aux_q4 = get_int_b2(bxi->qs, kqsx);
+        const int2 v = get_int_from_table_16(aux_q4);
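+        // get_int_from_table_16 expands the two 4-bit nibbles packed in each byte of
+        // aux_q4 through the IQ4_NL codebook: v.x and v.y each pack four int8 values.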
+        const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + k0 + 0] = v.x;
+        x_qs[i*(2*WARP_SIZE + 1)     + k0 + 4] = v.y;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) {
+        int i = i0 + threadIdx.y * QI4_NL + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
+#else
+        x_df[i*(WARP_SIZE/4) + i/4   + kbxd] = __half2float(bxi->d);
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % (QI2_XXS/2);
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XXS/2)) {
+        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XXS) + threadIdx.x/(QI2_XXS/2);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride;
+
+        const int q2 = get_int_b2(bxi->qs, 2*kqsx+0);
+        const uint8_t * aux8 = (const uint8_t *) &q2;
+        const uint32_t aux32 = get_int_b2(bxi->qs, 2*kqsx+1);
+
+#pragma unroll
+        for (int l = 0; l < QR2_XXS; ++l) {
+            const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]);
+            const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F];
+
+            const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
+            const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
+
+            const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
+            const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
+
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid0;
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid1;
+#endif // INT8_MMA_AVAILABLE
+        }
+
+        const int ls = aux32 >> 28;
+        const float d = bxi->d;
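+        // (ls*d + d/2)/4 == d*(ls + 0.5f)/4, the IQ2_XXS scale convention; storing the
+        // expanded grid values as int8 lets the generic q8_0 dot product kernels be reused.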
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
+#else
+        x_df[i*(WARP_SIZE/4) + i/4   + kqsx] = (ls*d + d/2)/4;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % (QI2_XS/2);
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XS/2)) {
+        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XS) + threadIdx.x/(QI2_XS/2);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride;
+
+        const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
+        const uint16_t * q2 = (const uint16_t *) &q2_packed;
+
+#pragma unroll
+        for (int l = 0; l < QR2_XS; ++l) {
+            const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF));
+            const uint32_t * signs    = (const uint32_t *)(ksigns64   + (q2[l] >> 9));
+
+            const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
+            const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
+
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // INT8_MMA_AVAILABLE
+        }
+
+        const int ls = bxi->scales[kqsx];
+        const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
+#else
+        x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
+        x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % (QI2_S/2);
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_S/2)) {
+        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_S) + threadIdx.x/(QI2_S/2);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride;
+
+        const int       qs_packed = get_int_b2(bxi->qs, kqsx);
+        const uint8_t * qs        = (const uint8_t *) &qs_packed;
+
+        const int qh = bxi->qh[kqsx];
+
+        const int       signs_packed_32 = get_int_b2(bxi->qs, QK_K/32 + kqsx);
+        const uint8_t * signs_packed_8  = (const uint8_t *) &signs_packed_32;
+
+#pragma unroll
+        for (int l = 0; l < QR2_S; ++l) {
+            const int * grid_pos = (const int *)(iq2s_grid + (qs[l] | ((qh << (8-2*l)) & 0x300)));
+
+            const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000);
+            const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000);
+
+            const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
+            const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
+
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // INT8_MMA_AVAILABLE
+        }
+
+        const int ls = bxi->scales[kqsx];
+        const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
+#else
+        x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
+        x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % (QI3_XXS/2);
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_XXS/2)) {
+        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_XXS) + threadIdx.x/(QI3_XXS/2);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride;
+
+        const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
+        const uint8_t * q3 = (const uint8_t *) &q3_packed;
+        const uint32_t aux32 = get_int_b2(bxi->qs, QK_K/16 + kqsx);
+
+#pragma unroll
+        for (int l = 0; l < QR3_XXS; ++l) {
+            const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
+
+            const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
+
+            const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
+            const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
+
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // INT8_MMA_AVAILABLE
+        }
+
+        const int ls = aux32 >> 28;
+        const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
+#else
+        x_df[i*(WARP_SIZE/4) + i/4   + kqsx] = (ls*d + d/2)/2;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % (QI3_S/2);
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_S/2)) {
+        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_S) + threadIdx.x/(QI3_S/2);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride;
+
+        const int2      qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
+        const uint8_t * qs        = (const uint8_t *) &qs_packed;
+
+        const int qh = bxi->qh[kqsx];
+
+        const int       signs_packed_32 = get_int_b2(bxi->signs, kqsx);
+        const uint8_t * signs_packed_8  = (const uint8_t *) &signs_packed_32;
+
+#pragma unroll
+        for (int l = 0; l < QR3_S; ++l) {
+            const int2 grid_pos = make_int2(
+                iq3s_grid[qs[2*l+0] | ((qh << (8 - 2*l)) & 0x100)],
+                iq3s_grid[qs[2*l+1] | ((qh << (7 - 2*l)) & 0x100)]);
+
+            const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000);
+            const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000);
+
+            const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
+            const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
+
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+0)] = grid_l;
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+1)] = grid_h;
+#endif // INT8_MMA_AVAILABLE
+        }
+
+        const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
+        const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
+#else
+        x_df[i*(WARP_SIZE/4) + i/4   + kqsx] = ls*d;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_ds = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % QI1_S;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI1_S) {
+        int i = i0 + threadIdx.y*(WARP_SIZE/QI1_S) + threadIdx.x/QI1_S;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride;
+
+        const int       qs_packed = get_int_b2(bxi->qs, kqsx);
+        const uint8_t * qs        = (const uint8_t *) &qs_packed;
+
+        const int qh = bxi->qh[kqsx];
+
+#pragma unroll
+        for (int l = 0; l < QR1_S/2; ++l) {
+            const int grid = iq1s_grid_gpu[qs[l] | (((qh >> (3*l)) & 0x07) << 8)];
+
+            const int grid0 = (grid >> 0) & 0x0F0F0F0F;
+            const int grid1 = (grid >> 4) & 0x0F0F0F0F;
+
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+0)] = grid0;
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+1)] = grid1;
+#endif // INT8_MMA_AVAILABLE
+        }
+
+        const float  d1q   = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
+        const float  delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
+
+#ifdef INT8_MMA_AVAILABLE
+        x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
+#else
+        x_ds[i*(WARP_SIZE/4) + i/4   + kqsx] = make_half2(d1q, d1q*delta);
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = 0;           // threadIdx.x / QI4_XS
+    const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride + kbx;
+
+        const int aux_q4 = get_int_b4(bxi->qs, kqsx);
+        const int2 v = get_int_from_table_16(aux_q4);
+        const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + k0 + 0] = v.x;
+        x_qs[i*(2*WARP_SIZE + 1)     + k0 + 4] = v.y;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
+
+        const float d = __half2float(bxi->d);
+
+        const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
+            | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
+#else
+        x_df[i*(WARP_SIZE/4) + i/4   + threadIdx.x % 8] = d * (ls - 32);
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
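+// Write-back: each thread holds mmq_x*mmq_y/(nwarps*WARP_SIZE) partial sums, indexed as
+// sum[(j0/nwarps)*(mmq_y/WARP_SIZE) + i0/WARP_SIZE] in the dp4a path. The helpers below
+// scatter them to dst, clamping against i_max/j_max at the matrix edges; the MMA variant
+// maps fragment lanes through mma_C::get_i/get_j instead.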
+template<int mmq_x, int mmq_y, int nwarps, bool need_check>
+static __device__ __forceinline__ void mmq_write_back_dp4a(
+    const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+        if (j > j_max) {
+            return;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            if (need_check && i > i_max) {
+                continue;
+            }
+
+            dst[j*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+        }
+    }
+}
+
+template<int mmq_x, int mmq_y, int nwarps, bool need_check>
+static __device__ __forceinline__ void mmq_write_back_mma(
+    const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
+
+    typedef mma_int_C_I16J8 mma_C;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
+#ifdef INT8_MMA_AVAILABLE
+    static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
+#endif // INT8_MMA_AVAILABLE
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);
+
+                if (j > j_max) {
+                    continue;
+                }
+
+                const int i = i0 + n*mma_C::I + mma_C::get_i(l);
+
+                if (need_check && i > i_max) {
+                    continue;
+                }
+
+                dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l];
+            }
+        }
+    }
+}
+
+// -------------------------------------------------------------------------------------------------------------------------------------
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
+struct mmq_type_traits;
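+
+// Compile-time dispatch table: each specialization binds a quantization type to its tile
+// loader, MMA dot product and dp4a fallback. vdr is the "vec dot ratio", roughly how many
+// quantized values each thread consumes per dot-product step.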
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
+    static constexpr int              vdr          = VDR_Q4_0_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_0<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_DS4>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
+    static constexpr int              vdr          = VDR_Q4_1_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_1<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
+    static constexpr int              vdr          = VDR_Q5_0_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_0<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
+    static constexpr int              vdr          = VDR_Q5_1_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_1<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
+    static constexpr int              vdr          = VDR_Q8_0_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q8_0<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
+    static constexpr int              vdr          = VDR_Q2_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q2_K<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
+    static constexpr int              vdr          = VDR_Q3_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q3_K<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
+    static constexpr int              vdr          = VDR_Q4_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_K<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
+    static constexpr int              vdr          = VDR_Q5_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_K<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
+    static constexpr int              vdr          = VDR_Q6_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q6_K<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> {
+    static constexpr int              vdr          = VDR_IQ2_XXS_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_xxs<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XS> {
+    static constexpr int              vdr          = VDR_IQ2_XS_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_xs<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_S> {
+    static constexpr int              vdr          = VDR_IQ2_S_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_s<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_XXS> {
+    static constexpr int              vdr          = VDR_IQ3_XXS_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq3_xxs<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_S> {
+    static constexpr int              vdr          = VDR_IQ3_S_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq3_s<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S> {
+    static constexpr int              vdr          = VDR_IQ1_S_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq1_s<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
+    static constexpr int              vdr          = VDR_IQ4_NL_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
+    static constexpr int              vdr          = VDR_IQ4_XS_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
+static __device__ void mul_mat_q_process_tile(
+    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+    const int & ne00, const int & ne01, const int & stride01, const int & ne10, const int & ne11, const int & stride11, const int & ne0,
+    const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {
+
+    constexpr int              qk         = ggml_cuda_type_traits<type>::qk;
+    constexpr int              mmq_y      = get_mmq_y_device();
+    constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
+
+    extern __shared__ char data_mul_mat_q[];
+    int * tile_y = (int *) data_mul_mat_q;
+    int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
+
+#ifdef INT8_MMA_AVAILABLE
+    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
+    constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
+    constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
+
+    constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+
+    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
+
+    const int tile_x_max_i = ne01 - it*mmq_y - 1;
+    const int tile_y_max_j = ne11 - jt*mmq_x - 1;
+
+    const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
+
+    for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
+        load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
+
+        {
+            const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
+#pragma unroll
+            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
+                int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+
+                tile_y[l] = by0[l];
+            }
+        }
+
+        __syncthreads();
+
+        vec_dot(tile_x, tile_y, sum, 0);
+
+        __syncthreads();
+
+        {
+            const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
+#pragma unroll
+            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
+                int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+
+                tile_y[l] = by0[l];
+            }
+        }
+
+        __syncthreads();
+
+        vec_dot(tile_x, tile_y, sum, WARP_SIZE);
+
+        __syncthreads();
+    }
+
+    if (fixup) {
+        write_back(sum, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
+    } else {
+        write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
+    }
+}
+
+
+// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
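+// Rather than assigning one CUDA block per output tile, a fixed grid (one block per SM,
+// see launch_mul_mat_q) sweeps contiguous slices of the flattened (tile, k-block) work,
+// which keeps all SMs busy even when the tile count is small or uneven. Tiles whose k
+// range is split across two blocks get their partial sums combined by
+// mul_mat_q_stream_k_fixup.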
+
+template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*nwarps, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#else
+#if __CUDA_ARCH__ >= CC_VOLTA
+    __launch_bounds__(WARP_SIZE*nwarps, 1)
+#else
+    __launch_bounds__(WARP_SIZE*nwarps, 2)
+#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+static __global__ void mul_mat_q(
+    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+    const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
+
+    // Skip unused template specializations for faster compilation:
+    if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    constexpr int qk    = ggml_cuda_type_traits<type>::qk;
+    constexpr int mmq_y = get_mmq_y_device();
+
+    // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
+#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
+    {
+        constexpr bool fixup = false;
+        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+            (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+                blockIdx.x, blockIdx.y, 0, ne00/qk);
+        return;
+    }
+#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
+
+    const     int64_t blocks_per_ne00 = ne00 / qk;
+    constexpr int     blocks_per_iter = MMQ_ITER_K / qk;
+
+    const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
+    const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
+
+    // kbc == continuous k-block index, the current position in the flattened (j, i, k) work space.
+    int64_t kbc      = (int64_t) blockIdx.x     *blocks_per_ne00*ntx*nty / gridDim.x;
+    int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x;
+
+    kbc      -= (kbc      % blocks_per_ne00) % blocks_per_iter;
+    kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
+
+    // kb0 == k index when doing the matrix multiplication for an output tile.
+    int kb0_start = kbc % blocks_per_ne00;
+    int kb0_stop  = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
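+    // Worked example with illustrative numbers: blocks_per_ne00 = 8, ntx = nty = 2 gives
+    // 32 k-blocks of total work. A block assigned kbc = 12, kbc_stop = 30 first finishes
+    // the tail of tile (it=1, jt=0) with kb0 in [4, 8), then computes all of tile
+    // (it=0, jt=1), and finally writes its partial result kb0 in [0, 6) for tile
+    // (it=1, jt=1) to the fixup buffer.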
+    while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
+        const int jt =  kbc /    (blocks_per_ne00*nty);                    // j index of current tile.
+        const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; // i index of current tile.
+
+        constexpr bool fixup = false; // All iterations except (potentially) the last write their data directly to dst rather than to the fixup buffer.
+        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+            (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+             it, jt, kb0_start, kb0_stop);
+
+        kbc += blocks_per_ne00;
+        kbc -= kbc % blocks_per_ne00;
+
+        kb0_start = 0;
+        kb0_stop  = min(blocks_per_ne00, kbc_stop - kbc);
+    }
+
+    if (kbc >= kbc_stop) {
+        return;
+    }
+
+    const int jt =  kbc /    (blocks_per_ne00*nty);
+    const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
+
+    constexpr bool fixup = true; // The last iteration writes its data to the fixup buffer to avoid data races with other blocks.
+    mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+        (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+            it, jt, kb0_start, kb0_stop);
+}
+
+
+template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+static __global__ void mul_mat_q_stream_k_fixup(
+    float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ne00, const int ne01, const int ne11, const int ne0, const int block_num_mmq) {
+
+    constexpr int     mmq_y           = get_mmq_y_device();
+    constexpr int     qk              = ggml_cuda_type_traits<type>::qk;
+    constexpr int     blocks_per_iter = MMQ_ITER_K / qk;
+    const     int64_t blocks_per_ne00 = ne00 / qk;
+
+    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
+
+    const int ntx = (ne11 + mmq_x - 1) / mmq_x;
+    const int nty = (ne01 + mmq_y - 1) / mmq_y;
+
+    bool any_fixup = false;
+
+    const int bidx_start = ((blockIdx.y*nty + blockIdx.x)     * block_num_mmq)                           / (gridDim.y*gridDim.x);
+    const int bidx_stop  = ((blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq + gridDim.y*gridDim.x - 1) / (gridDim.y*gridDim.x);
+
+    int64_t kbc_0;
+    int64_t kbc_stop_0 = (int64_t) bidx_start*blocks_per_ne00*ntx*nty / block_num_mmq;
+
+    for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
+        kbc_0 = kbc_stop_0;
+        kbc_stop_0 = (int64_t) (bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq;
+
+        const int64_t kbc      = kbc_0      - (kbc_0      % blocks_per_ne00) % blocks_per_iter;
+        const int64_t kbc_stop = kbc_stop_0 - (kbc_stop_0 % blocks_per_ne00) % blocks_per_iter;
+
+        // Skip fixup tile if the MMQ CUDA block never wrote anything to it:
+        if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
+            continue;
+        }
+
+        const int jt =  kbc_stop /    (blocks_per_ne00*nty);
+        const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
+
+        // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
+        if (it != blockIdx.x || jt != blockIdx.y) {
+            continue;
+        }
+
+        any_fixup = true;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
+            }
+        }
+    }
+
+    if (!any_fixup) {
+        return;
+    }
+
+    dst += blockIdx.y*mmq_x*ne0 + blockIdx.x*mmq_y;
+
+    const int i_max = ne01 - blockIdx.x*mmq_y - 1;
+    const int j_max = ne11 - blockIdx.y*mmq_x - 1;
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+        if (j > j_max) {
+            return;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            if (need_check && i > i_max) {
+                continue;
+            }
+
+            dst[j*ne0 + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+        }
+    }
+}
+
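+// Host-side launch arguments. Following the usual ggml naming, ne00/ne01 are the k and
+// row dimensions of x, ne10/ne11 the k and column dimensions of y, and ne0 the leading
+// dimension of dst; the strides give the spacing between consecutive rows of x and y in
+// their respective quantized layouts.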
+struct mmq_args {
+    const char * x; const char * y; float * dst;
+    int64_t ne00; int64_t ne01; int64_t stride01;
+    int64_t ne10; int64_t ne11; int64_t stride11;
+    int64_t ne0;
+};
+
+template<ggml_type type>
+static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
+    const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
+    const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
+    const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+    const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
+    return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
+}
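+
+// The y tile is padded up to a multiple of the thread count (nwarps*WARP_SIZE ints) so the
+// unrolled copy loops in mul_mat_q_process_tile never run past its end; the same GGML_PAD
+// expression places tile_x after tile_y in the dynamic shared memory there.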
+
+template <ggml_type type, int mmq_x>
+static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
+    const int id = ggml_cuda_get_device();
+    const int cc = ggml_cuda_info().devices[id].cc;
+    const int nsm = ggml_cuda_info().devices[id].nsm;
+    const int mmq_y = get_mmq_y_host(cc);
+
+    const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
+
+    const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);
+
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+    static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+    if (!shmem_limit_raised[id]) {
+        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>,  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+        shmem_limit_raised[id] = true;
+    }
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+
+    const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
+    const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
+    const dim3 block_nums_xy_tiling(nty, ntx, 1);
+
+    const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
+    if (!use_stream_k) {
+        if (args.ne01 % mmq_y == 0) {
+            constexpr bool need_check = false;
+            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
+                (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+        } else {
+            constexpr bool need_check = true;
+            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
+                (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+        }
+        return;
+    }
+
+    const dim3 block_nums_mmq(nsm, 1, 1);
+
+    ggml_cuda_pool & pool = ctx.pool(id);
+    ggml_cuda_pool_alloc<float> tmp_fixup(pool, block_nums_mmq.x * mmq_x*mmq_y);
+
+    if (args.ne01 % mmq_y == 0) {
+        constexpr bool need_check = false;
+
+        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
+            (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+
+        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
+            (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
+    } else {
+        constexpr bool need_check = true;
+
+        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
+            (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+
+        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
+            (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
+    }
+}
+
+template <ggml_type type>
+void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
+    const int id    = ggml_cuda_get_device();
+    const int nsm   = ggml_cuda_info().devices[id].nsm;
+    const int cc    = ggml_cuda_info().devices[id].cc;
+    const int smpbo = ggml_cuda_info().devices[id].smpbo;
+
+    const int mmq_x_max = get_mmq_x_max_host(cc);
+    const int mmq_y = get_mmq_y_host(cc);
+    const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
+    const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
+
+    int mmq_x_best  = 0;
+    int nparts_best = INT_MAX;
+
+    for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
+        const int granularity = mmq_get_granularity_host(mmq_x, cc);
+
+        if (mmq_x % granularity != 0 || mmq_get_shmem<type>(mmq_x, mmq_y, cc) > smpbo) {
+            continue;
+        }
+
+        const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x;
+        const int nwaves_xy_tiling = ntiles_x*block_num_y;
+        const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling;
+
+        if (nparts < nparts_best) {
+            mmq_x_best  = mmq_x;
+            nparts_best = nparts;
+        }
+    }
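+
+    // Pick the smallest tile width that minimizes the number of work partitions (x tiles
+    // under stream-k, output tiles otherwise) while respecting the granularity and shared
+    // memory limits; the search stops early once a single partition suffices.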
+
+    switch (mmq_x_best) {
+        case   8:
+            launch_mul_mat_q<type,   8>(ctx, args, stream);
+            break;
+        case  16:
+            launch_mul_mat_q<type,  16>(ctx, args, stream);
+            break;
+        case  24:
+            launch_mul_mat_q<type,  24>(ctx, args, stream);
+            break;
+        case  32:
+            launch_mul_mat_q<type,  32>(ctx, args, stream);
+            break;
+        case  40:
+            launch_mul_mat_q<type,  40>(ctx, args, stream);
+            break;
+        case  48:
+            launch_mul_mat_q<type,  48>(ctx, args, stream);
+            break;
+        case  56:
+            launch_mul_mat_q<type,  56>(ctx, args, stream);
+            break;
+        case  64:
+            launch_mul_mat_q<type,  64>(ctx, args, stream);
+            break;
+        case  72:
+            launch_mul_mat_q<type,  72>(ctx, args, stream);
+            break;
+        case  80:
+            launch_mul_mat_q<type,  80>(ctx, args, stream);
+            break;
+        case  88:
+            launch_mul_mat_q<type,  88>(ctx, args, stream);
+            break;
+        case  96:
+            launch_mul_mat_q<type,  96>(ctx, args, stream);
+            break;
+        case 104:
+            launch_mul_mat_q<type, 104>(ctx, args, stream);
+            break;
+        case 112:
+            launch_mul_mat_q<type, 112>(ctx, args, stream);
+            break;
+        case 120:
+            launch_mul_mat_q<type, 120>(ctx, args, stream);
+            break;
+        case 128:
+            launch_mul_mat_q<type, 128>(ctx, args, stream);
+            break;
+        default:
+            fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
+#define DECL_MMQ_CASE(type)                                                        \
+    template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \
+
+extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
+extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
+extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
+extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
+extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
+extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
+
+// -------------------------------------------------------------------------------------------------------------------------
+
+void ggml_cuda_op_mul_mat_q(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream);
+
+bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11);

+ 451 - 0
llama/ggml-cuda/mmvq.cu

@@ -0,0 +1,451 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "mmvq.cuh"
+#include "vecdotq.cuh"
+
+typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
+
+static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
+    return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1 :
+        type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 :
+        type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 :
+        type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 :
+        type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 :
+        type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 :
+        type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 :
+        type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 :
+        type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 :
+        type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 :
+        type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
+        type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 :
+        type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 :
+        type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 :
+        type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 :
+        type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 :
+        type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 :
+        type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 :
+        type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 :
+        nullptr;
+}
+
+static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
+    return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ :
+        type == GGML_TYPE_Q4_1    ? VDR_Q4_1_Q8_1_MMVQ :
+        type == GGML_TYPE_Q5_0    ? VDR_Q5_0_Q8_1_MMVQ :
+        type == GGML_TYPE_Q5_1    ? VDR_Q5_1_Q8_1_MMVQ :
+        type == GGML_TYPE_Q8_0    ? VDR_Q8_0_Q8_1_MMVQ :
+        type == GGML_TYPE_Q2_K    ? VDR_Q2_K_Q8_1_MMVQ :
+        type == GGML_TYPE_Q3_K    ? VDR_Q3_K_Q8_1_MMVQ :
+        type == GGML_TYPE_Q4_K    ? VDR_Q4_K_Q8_1_MMVQ :
+        type == GGML_TYPE_Q5_K    ? VDR_Q5_K_Q8_1_MMVQ :
+        type == GGML_TYPE_Q6_K    ? VDR_Q6_K_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ2_XS  ? VDR_IQ2_XS_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ2_S   ? VDR_IQ2_S_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ3_XXS ? VDR_IQ3_XXS_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ3_S   ? VDR_IQ3_S_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ4_NL  ? VDR_IQ4_NL_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ4_XS  ? VDR_IQ4_XS_Q8_1_MMVQ :
+        1;
+}
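+
+// Both lookup functions above are constexpr, so the dispatch is resolved at
+// compile time for each kernel instantiation rather than at runtime.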
+
+template <ggml_type type, int ncols_y>
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+// tell the compiler to use as many registers as it wants, see nwarps definition below
+__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void mul_mat_vec_q(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    constexpr int qk  = ggml_cuda_type_traits<type>::qk;
+    constexpr int qi  = ggml_cuda_type_traits<type>::qi;
+    constexpr int vdr = get_vdr_mmvq(type);
+
+    constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
+
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
+    constexpr int nwarps              = 1;
+    constexpr int rows_per_cuda_block = 1;
+#else
+    constexpr int nwarps              = ncols_y <= 4 ? 4 : 2;
+    constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
+
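+    // Each x block is read by qi/vdr threads (vdr quantized values per
+    // vec_dot call), so one iteration of the loop below advances by
+    // blocks_per_iter == nwarps*WARP_SIZE/(qi/vdr) blocks along the row.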
+    const     int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    const     int row0 = rows_per_cuda_block*blockIdx.x;
+    const     int blocks_per_row_x = ncols_x / qk;
+    const     int blocks_per_col_y = nrows_y / QK8_1;
+    constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
+
+    // partial sum for each thread
+    float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
+
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
+        const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
+
+        // x block quant index when casting the quants to int
+        const int kqs = vdr * (tid % (qi/vdr));
+
+#pragma unroll
+        for (int j = 0; j < ncols_y; ++j) {
+#pragma unroll
+            for (int i = 0; i < rows_per_cuda_block; ++i) {
+                tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs);
+            }
+        }
+    }
+
+    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
+    if (threadIdx.y > 0) {
+#pragma unroll
+        for (int j = 0; j < ncols_y; ++j) {
+#pragma unroll
+            for (int i = 0; i < rows_per_cuda_block; ++i) {
+                tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
+            }
+        }
+    }
+    __syncthreads();
+    if (threadIdx.y > 0) {
+        return;
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int j = 0; j < ncols_y; ++j) {
+#pragma unroll
+        for (int i = 0; i < rows_per_cuda_block; ++i) {
+#pragma unroll
+            for (int l = 0; l < nwarps-1; ++l) {
+                tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
+            }
+            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
+        }
+
+        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
+            dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
+        }
+    }
+}
+
+template <ggml_type type>
+static void mul_mat_vec_q_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
+    GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
+
+    int id = ggml_cuda_get_device();
+
+    int64_t nwarps = 1;
+    int64_t rows_per_cuda_block = 1;
+
+    if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
+        switch(ncols_y) {
+            case 1:
+                nwarps = 4;
+                rows_per_cuda_block = 1;
+                break;
+            case 2:
+            case 3:
+            case 4:
+                nwarps = 4;
+                rows_per_cuda_block = 2;
+                break;
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                nwarps = 2;
+                rows_per_cuda_block = 2;
+                break;
+            default:
+                GGML_ABORT("fatal error");
+                break;
+        }
+    }
+    const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
+    const dim3 block_nums(nblocks, 1, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
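+    // Each CUDA block computes rows_per_cuda_block rows of dst for all ncols_y
+    // columns; ncols_y is a template parameter, so one kernel is instantiated
+    // per supported batch size (1..MMVQ_MAX_BATCH_SIZE).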
+    switch (ncols_y) {
+        case 1:
+            mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 2:
+            mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 3:
+            mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 4:
+            mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 5:
+            mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 6:
+            mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 7:
+            mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 8:
+            mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
+static void mul_mat_vec_q4_0_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q4_1_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q5_0_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q5_1_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q8_0_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q2_K_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q3_K_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q4_K_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q5_K_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q6_K_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq2_xxs_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq2_xs_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq2_s_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq3_xxs_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq1_s_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq1_m_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq4_nl_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq4_xs_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq3_s_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+void ggml_cuda_op_mul_mat_vec_q(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t row_diff = row_high - row_low;
+
+    const int64_t ne10 = src1->ne[0];
+    GGML_ASSERT(ne10 % QK8_1 == 0);
+
+    const int64_t ne0 = dst->ne[0];
+
+    int id = ggml_cuda_get_device();
+
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // nrows_dst == nrows of the matrix that the kernel writes into
+    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
+
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q4_1:
+            mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q2_K:
+            mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q3_K:
+            mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q5_K:
+            mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q6_K:
+            mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_IQ2_XXS:
+            mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_IQ2_XS:
+            mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_IQ2_S:
+            mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_IQ3_XXS:
+            mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_IQ1_S:
+            mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_IQ1_M:
+            mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_IQ4_NL:
+            mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_IQ4_XS:
+            mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_IQ3_S:
+            mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+
+    GGML_UNUSED(src1);
+    GGML_UNUSED(dst);
+    GGML_UNUSED(src1_ddf_i);
+    GGML_UNUSED(src1_ncols);
+    GGML_UNUSED(src1_padded_row_size);
+}

+ 35 - 0
llama/ggml-cuda/mmvq.cuh

@@ -0,0 +1,35 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
+
+void ggml_cuda_op_mul_mat_vec_q(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream);

+ 250 - 0
llama/ggml-cuda/norm.cu

@@ -0,0 +1,250 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "norm.cuh"
+
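+// norm_f32: row-wise normalization, dst = (x - mean(x)) * rsqrt(var(x) + eps).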
+template <int block_size>
+static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int tid = threadIdx.x;
+
+    float2 mean_var = make_float2(0.f, 0.f);
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = x[row*ncols + col];
+        mean_var.x += xi;
+        mean_var.y += xi * xi;
+    }
+
+    // sum up partial sums
+    mean_var = warp_reduce_sum(mean_var);
+    if (block_size > WARP_SIZE) {
+        __shared__ float2 s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = mean_var;
+        }
+        __syncthreads();
+        mean_var = s_sum[lane_id];
+        mean_var = warp_reduce_sum(mean_var);
+    }
+
+    const float mean = mean_var.x / ncols;
+    const float var = mean_var.y / ncols - mean * mean;
+    const float inv_std = rsqrtf(var + eps);
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
+    }
+}
+
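+// group_norm_f32: normalizes each contiguous group of group_size elements to
+// zero mean and unit variance; one CUDA block handles one group.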
+template <int block_size>
+static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
+    // blockIdx.x: num_groups idx
+    // threadIdx.x: block_size idx
+    int start = blockIdx.x * group_size;
+    int end = start + group_size;
+
+    start += threadIdx.x;
+
+    if (end >= ne_elements) {
+        end = ne_elements;
+    }
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int j = start; j < end; j += block_size) {
+        tmp += x[j];
+    }
+
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    float mean = tmp / group_size;
+    tmp = 0.0f;
+
+    for (int j = start; j < end; j += block_size) {
+        float xi = x[j] - mean;
+        dst[j] = xi;
+        tmp += xi * xi;
+    }
+
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    float variance = tmp / group_size;
+    float scale = rsqrtf(variance + eps);
+    for (int j = start; j < end; j += block_size) {
+        dst[j] *= scale;
+    }
+}
+
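+// rms_norm_f32: row-wise RMS normalization, dst = x * rsqrt(mean(x^2) + eps).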
+template <int block_size>
+static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int tid = threadIdx.x;
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = x[row*ncols + col];
+        tmp += xi * xi;
+    }
+
+    // sum up partial sums
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    const float mean = tmp / ncols;
+    const float scale = rsqrtf(mean + eps);
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[row*ncols + col] = scale * x[row*ncols + col];
+    }
+}
+
+static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    }
+}
+
+static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
+    if (group_size < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
+    }
+}
+
+static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    }
+}
+
+void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+}
+
+void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    int num_groups = dst->op_params[0];
+
+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));
+
+    int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
+    group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+}

+ 33 - 0
llama/ggml-cuda/norm.cuh

@@ -0,0 +1,33 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 75 - 0
llama/ggml-cuda/pad.cu

@@ -0,0 +1,75 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "pad.cuh"
+
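+// pad_f32: copies src into the leading region of dst and zero-fills the rest.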
+static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
+    // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
+    // blockIdx.y: idx of ne1
+    // blockIdx.x: idx of ne0 / BLOCK_SIZE
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    // operation
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
+        int offset_src =
+            nidx +
+            blockIdx.y * ne00 +
+            blockIdx.z * ne00 * ne01;
+        dst[offset_dst] = x[offset_src];
+    } else {
+        dst[offset_dst] = 0.0f;
+    }
+}
+
+static void pad_f32_cuda(const float * x, float * dst,
+    const int ne00, const int ne01, const int ne02, const int ne03,
+    const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2*ne3);
+    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
+}
+
+void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+
+    pad_f32_cuda(src0_d, dst_d,
+        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+}

+ 31 - 0
llama/ggml-cuda/pad.cuh

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_PAD_BLOCK_SIZE 256
+
+void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 120 - 0
llama/ggml-cuda/pool2d.cu

@@ -0,0 +1,120 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "pool2d.cuh"
+
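+// pool2d_nchw_kernel: one thread per output element of an NCHW avg/max pool.
+// For avg pooling the divisor is always the full kernel area kh*kw, even for
+// border outputs where part of the window falls in the padding.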
+template <typename Ti, typename To>
+static  __global__ void pool2d_nchw_kernel(
+        const int ih, const int iw, const int oh, const int ow,
+        const int kh, const int kw, const int sh, const int sw,
+        const int ph, const int pw, const int parallel_elements,
+        const Ti* src, To* dst, const enum ggml_op_pool op) {
+    int idx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (idx >= parallel_elements) {
+        return;
+    }
+
+    const int I_HW = ih * iw;
+    const int O_HW = oh * ow;
+    const int nc = idx / O_HW;
+    const int cur_oh = idx % O_HW / ow;
+    const int cur_ow = idx % O_HW % ow;
+    const Ti* i_ptr = src + nc * I_HW;
+    To* o_ptr = dst + nc * O_HW;
+    const int start_h = cur_oh * sh - ph;
+    const int bh = max(0, start_h);
+    const int eh = min(ih, start_h + kh);
+    const int start_w = cur_ow * sw - pw;
+    const int bw = max(0, start_w);
+    const int ew = min(iw, start_w + kw);
+    const To scale = 1. / (kh * kw);
+    To res = 0;
+
+    switch (op) {
+        case GGML_OP_POOL_AVG: res = 0; break;
+        case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
+        default: assert(false);
+    }
+
+    for (int i = bh; i < eh; i += 1) {
+        for (int j = bw; j < ew; j += 1) {
+#if __CUDA_ARCH__ >= 350
+            Ti cur = __ldg(i_ptr + i * iw + j);
+#else
+            Ti cur = i_ptr[i * iw + j];
+#endif
+            switch (op) {
+                case GGML_OP_POOL_AVG: res += cur * scale; break;
+                case GGML_OP_POOL_MAX: res = max(res, (To)cur); break;
+                default: assert(false);
+            }
+        }
+    }
+    o_ptr[cur_oh * ow + cur_ow] = res;
+}
+
+static void pool2d_nchw_kernel_f32_f32_cuda(
+        const int ih, const int iw, const int oh, const int ow,
+        const int kh, const int kw, const int sh, const int sw,
+        const int ph, const int pw, const int parallel_elements,
+        const float * src, float * dst, const enum ggml_op_pool op,
+        cudaStream_t stream) {
+
+    const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE;
+    dim3 block_nums(num_blocks);
+    pool2d_nchw_kernel<<<block_nums, CUDA_POOL2D_BLOCK_SIZE, 0, stream>>>(ih, iw, oh, ow, kh, kw, sh, sw, ph, pw, parallel_elements, src, dst, op);
+}
+
+void ggml_cuda_op_pool2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+    enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
+    const int k0 = opts[1];
+    const int k1 = opts[2];
+    const int s0 = opts[3];
+    const int s1 = opts[4];
+    const int p0 = opts[5];
+    const int p1 = opts[6];
+
+    const int64_t IH = src0->ne[1];
+    const int64_t IW = src0->ne[0];
+
+    const int64_t N = dst->ne[3];
+    const int64_t OC = dst->ne[2];
+    const int64_t OH = dst->ne[1];
+    const int64_t OW = dst->ne[0];
+
+    const int parallel_elements = N * OC * OH * OW;
+
+    pool2d_nchw_kernel_f32_f32_cuda(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_d, dst_d, op, stream);
+}

+ 31 - 0
llama/ggml-cuda/pool2d.cuh

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_POOL2D_BLOCK_SIZE 256
+
+void ggml_cuda_op_pool2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 195 - 0
llama/ggml-cuda/quantize.cu

@@ -0,0 +1,195 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "quantize.cuh"
+#include <cstdint>
+
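+// quantize_q8_1: each block of QK8_1 floats becomes int8 quants plus a half2
+// ds = (d, sum), where d = amax/127, so a value is recovered as d*q.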
+static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) {
+    const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (ix0 >= kx0_padded) {
+        return;
+    }
+
+    const int64_t ix1 = blockIdx.y;
+
+    const int64_t i_padded = ix1*kx0_padded + ix0;
+
+    block_q8_1 * y = (block_q8_1 *) vy;
+
+    const int64_t ib = i_padded / QK8_1; // block index
+    const int64_t iqs = i_padded % QK8_1; // quant index
+
+    const float xi = ix0 < kx ? x[ix1*kx + ix0] : 0.0f;
+    float amax = fabsf(xi);
+    float sum = xi;
+
+    amax = warp_reduce_max(amax);
+    sum = warp_reduce_sum(sum);
+
+    const float d = amax / 127;
+    const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
+
+    y[ib].qs[iqs] = q;
+
+    if (iqs > 0) {
+        return;
+    }
+
+    reinterpret_cast<half&>(y[ib].ds.x) = d;
+    reinterpret_cast<half&>(y[ib].ds.y) = sum;
+}
+
+template <mmq_q8_1_ds_layout ds_layout>
+static __global__ void quantize_mmq_q8_1(
+    const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
+
+    constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
+    constexpr int vals_per_sum   = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
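+    // For the D2S6 layout each 128-value block stores 2 scales (one per 64
+    // values) and 6 partial sums (one per 16 values over the first 96 values).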
+
+    const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
+
+    if (ix0 >= kx0_padded) {
+        return;
+    }
+
+    const float4 * x4 = (const float4 *) x;
+
+    const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
+
+    block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
+
+    const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
+    const int64_t ib  = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y;                   // block index in channel
+    const int64_t iqs = ix0 % (4*QK8_1);                                            // quant index in block
+
+    // Load 4 floats per thread and calculate the max. abs. value among them:
+    const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
+    float amax = fabsf(xi.x);
+    amax = fmaxf(amax, fabsf(xi.y));
+    amax = fmaxf(amax, fabsf(xi.z));
+    amax = fmaxf(amax, fabsf(xi.w));
+
+    // Exchange the max. abs. value across vals_per_scale/4 threads.
+#pragma unroll
+    for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
+        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
+    }
+
+    float sum;
+    if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
+        sum = xi.x + xi.y + xi.z + xi.w;
+
+        // Exchange the calculated sum across vals_per_sum/4 threads.
+#pragma unroll
+        for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
+            sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
+        }
+    }
+
+    const float d_inv = 127.0f / amax;
+    char4 q;
+    q.x = roundf(xi.x*d_inv);
+    q.y = roundf(xi.y*d_inv);
+    q.z = roundf(xi.z*d_inv);
+    q.w = roundf(xi.w*d_inv);
+
+    // Write back 4 int8 values as a single 32-bit value for better memory bandwidth:
+    char4 * yqs4 = (char4 *) y[ib].qs;
+    yqs4[iqs/4] = q;
+
+    if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
+        if (iqs % 16 != 0 || iqs >= 96) {
+            return;
+        }
+
+        y[ib].d2s6[2 + iqs/16] = sum;
+
+        if (iqs % 64 != 0) {
+            return;
+        }
+
+        const float d = 1.0f / d_inv;
+
+        y[ib].d2s6[iqs/64] = d;
+
+        return;
+    }
+
+    if (iqs % 32 != 0) {
+        return;
+    }
+
+    const float d = 1.0f / d_inv;
+
+    if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
+        y[ib].ds4[iqs/32] = make_half2(d, sum);
+    } else {
+        y[ib].d4[iqs/32]  = d;
+    }
+}
+
+void quantize_row_q8_1_cuda(
+    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
+    const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+
+    GGML_ASSERT(kx0_padded % QK8_1 == 0);
+
+    const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    const dim3 num_blocks(block_num_x, kx1*channels, 1);
+    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
+    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx0_padded);
+
+    GGML_UNUSED(type_x);
+}
+
+void quantize_mmq_q8_1_cuda(
+    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
+    const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+
+    GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
+
+    const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
+    const dim3 num_blocks(block_num_x, kx1, channels);
+    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
+    switch (mmq_get_q8_1_ds_layout(type_x)) {
+        case MMQ_Q8_1_DS_LAYOUT_D4:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        case MMQ_Q8_1_DS_LAYOUT_DS4:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        case MMQ_Q8_1_DS_LAYOUT_D2S6:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}

+ 50 - 0
llama/ggml-cuda/quantize.cuh

@@ -0,0 +1,50 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "common.cuh"
+#include "mmq.cuh"
+
+#include <cstdint>
+
+#define CUDA_QUANTIZE_BLOCK_SIZE     256
+#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128
+
+static_assert(MATRIX_ROW_PADDING %    CUDA_QUANTIZE_BLOCK_SIZE      == 0, "Risk of out-of-bounds access.");
+static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
+
+typedef void (*quantize_cuda_t)(
+    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
+    const ggml_type type_x, cudaStream_t stream);
+
+void quantize_row_q8_1_cuda(
+    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
+    const ggml_type type_x, cudaStream_t stream);
+
+void quantize_mmq_q8_1_cuda(
+    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
+    const ggml_type type_x, cudaStream_t stream);

+ 297 - 0
llama/ggml-cuda/rope.cu

@@ -0,0 +1,297 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "rope.cuh"
+
+struct rope_corr_dims {
+    float v[2];
+};
+
+static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static __device__ void rope_yarn(
+    float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
+    float * cos_theta, float * sin_theta) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
+    *cos_theta = cosf(theta) * mscale;
+    *sin_theta = sinf(theta) * mscale;
+}
+
+template<typename T, bool has_ff>
+static __global__ void rope_norm(
+    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+    if (i0 >= ne0) {
+        return;
+    }
+
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;
+
+        dst[i + 0] = x[i + 0];
+        dst[i + 1] = x[i + 1];
+
+        return;
+    }
+
+    const int i  = row*ne0 + i0;
+    const int i2 = row/p_delta_rows;
+
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + 1];
+
+    dst[i + 0] = x0*cos_theta - x1*sin_theta;
+    dst[i + 1] = x0*sin_theta + x1*cos_theta;
+}
+
+template<typename T, bool has_ff>
+static __global__ void rope_neox(
+    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+    if (i0 >= ne0) {
+        return;
+    }
+
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;
+
+        dst[i + 0] = x[i + 0];
+        dst[i + 1] = x[i + 1];
+
+        return;
+    }
+
+    const int i  = row*ne0 + i0/2;
+    const int i2 = row/p_delta_rows;
+
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + n_dims/2];
+
+    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
+    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+}
+
+template<typename T>
+static void rope_norm_cuda(
+    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
+    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    if (freq_factors == nullptr) {
+        rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors
+                );
+    } else {
+        rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors
+                );
+    }
+}
+
+template<typename T>
+static void rope_neox_cuda(
+    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
+    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    if (freq_factors == nullptr) {
+        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors
+                );
+    } else {
+        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors
+                );
+    }
+}
+
+static void rope_norm_cuda_f16(
+    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+
+    rope_norm_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+}
+
+static void rope_norm_cuda_f32(
+    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+
+    rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+}
+
+static void rope_neox_cuda_f16(
+    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+
+    rope_neox_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+}
+
+static void rope_neox_cuda_f32(
+    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+
+    rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+}
+
+void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    const float * src0_d = (const float *)src0->data;
+    const float * src1_d = (const float *)src1->data;
+
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t nr = ggml_nrows(src0);
+
+    //const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
+    // RoPE alteration for extended context
+    float freq_base;
+    float freq_scale;
+    float ext_factor;
+    float attn_factor;
+    float beta_fast;
+    float beta_slow;
+
+    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+
+    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+
+    const int32_t * pos = (const int32_t *) src1_d;
+
+    const float * freq_factors = nullptr;
+    if (src2 != nullptr) {
+        freq_factors = (const float *) src2->data;
+    }
+
+    rope_corr_dims corr_dims;
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
+
+    // compute
+    if (is_neox) {
+        if (src0->type == GGML_TYPE_F32) {
+            rope_neox_cuda_f32(
+                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
+            );
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_neox_cuda_f16(
+                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
+            );
+        } else {
+            GGML_ABORT("fatal error");
+        }
+    } else {
+        if (src0->type == GGML_TYPE_F32) {
+            rope_norm_cuda_f32(
+                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
+            );
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_norm_cuda_f16(
+                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
+            );
+        } else {
+            GGML_ABORT("fatal error");
+        }
+    }
+}
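
For reference, both kernels above implement the same rotation and differ only in how they pair elements: `rope_norm` rotates adjacent pairs `(x[i], x[i+1])`, while `rope_neox` pairs `x[i]` with `x[i + n_dims/2]`; dimensions at or beyond `n_dims` are copied through unrotated. With `b = freq_base`, `p = pos[i2]`, and magnitude scale `m` (attn_factor, YaRN-corrected when `ext_factor != 0`), the math readable from the code is:

$$
\theta_s = b^{-2/n_{\text{dims}}}, \qquad
\theta_{\text{base}} = p\,\theta_s^{\,i_0/2}, \qquad
\begin{pmatrix} y_0 \\ y_1 \end{pmatrix} =
m \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}
\begin{pmatrix} x_0 \\ x_1 \end{pmatrix}
$$

where `rope_yarn` blends the interpolated angle `freq_scale * theta_base` with the extrapolated `theta_base` using the ramp from `rope_yarn_ramp`. The host wrappers launch one `(1, CUDA_ROPE_BLOCK_SIZE)` block per pair of columns and one grid row per tensor row, and the scalar parameters are unpacked from the tensor's `op_params` blob (ints at indices 1, 2, 4; floats `memcpy`'d from indices 5–10).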

+ 31 - 0
llama/ggml-cuda/rope.cuh

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_ROPE_BLOCK_SIZE 256
+
+void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 57 - 0
llama/ggml-cuda/scale.cu

@@ -0,0 +1,57 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "scale.cuh"
+
+static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = scale * x[i];
+}
+
+static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
+    scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
+}
+
+void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    float scale;
+    memcpy(&scale, dst->op_params, sizeof(float));
+
+    scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
+}
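
Note the `memcpy` used to read `scale` back out of `op_params`: the blob is typed as an `int32_t` array, so copying bytes avoids the aliasing problems of a pointer cast. A minimal self-contained sketch of the round trip (the array below is a stand-in for `ggml_tensor::op_params`, not the real struct):

```c++
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    int32_t op_params[16] = {0};  // stand-in for ggml_tensor::op_params

    float scale_in = 0.125f;
    memcpy(op_params, &scale_in, sizeof(float));   // producer packs the float

    float scale_out;
    memcpy(&scale_out, op_params, sizeof(float));  // ggml_cuda_op_scale unpacks it the same way
    printf("scale = %f\n", scale_out);
}
```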

+ 31 - 0
llama/ggml-cuda/scale.cuh

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_SCALE_BLOCK_SIZE 256
+
+void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 232 - 0
llama/ggml-cuda/softmax.cu

@@ -0,0 +1,232 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+#include "softmax.cuh"
+
+template <typename T>
+static __device__ __forceinline__ float t2f32(T val) {
+    return (float) val;
+}
+
+template <>
+__device__ __forceinline__ float t2f32<half>(half val) {
+    return __half2float(val);
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
+
+    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
+
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
+
+    extern __shared__ float data_soft_max_f32[];
+    float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
+    // shared memory buffer to cache values between iterations:
+    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols;
+
+    float max_val = -INFINITY;
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
+        const int64_t ix = (int64_t)rowx*ncols + col;
+        const int64_t iy = (int64_t)rowy*ncols + col;
+
+        const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
+
+        vals[col] = val;
+        max_val = max(max_val, val);
+    }
+
+    // find the max value in the block
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf_iw[lane_id] = -INFINITY;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf_iw[warp_id] = max_val;
+        }
+        __syncthreads();
+
+        max_val = buf_iw[lane_id];
+        max_val = warp_reduce_max(max_val);
+    }
+
+    float tmp = 0.0f; // partial sum
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
+        const float val = expf(vals[col] - max_val);
+        tmp += val;
+        vals[col] = val;
+    }
+
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __syncthreads();
+        if (warp_id == 0) {
+            buf_iw[lane_id] = 0.0f;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf_iw[warp_id] = tmp;
+        }
+        __syncthreads();
+
+        tmp = buf_iw[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    const float inv_sum = 1.0f / tmp;
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            return;
+        }
+
+        const int64_t idst = (int64_t)rowx*ncols + col;
+        dst[idst] = vals[col] * inv_sum;
+    }
+}
+
+template<typename T>
+static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+    int nth = WARP_SIZE;
+    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth,     1, 1);
+    const dim3 block_nums(nrows_x, 1, 1);
+    const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
+    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+
+    const uint32_t n_head      = nrows_x/nrows_y;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
+    if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
+        switch (ncols_x) {
+            case 32:
+                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            case 64:
+                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            case 128:
+                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            case 256:
+                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            case 512:
+                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            case 1024:
+                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            case 2048:
+                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            case 4096:
+                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            default:
+                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+        }
+    } else {
+        const size_t shmem_low = WARP_SIZE*sizeof(float);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+    }
+}
+
+void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    const float * src0_d = (const float *)src0->data;
+    const void  * src1_d = src1 ? (const void *)src1->data : nullptr;
+
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
+    const int64_t ne00    = src0->ne[0];
+    const int64_t nrows_x = ggml_nrows(src0);
+    const int64_t nrows_y = src0->ne[1];
+
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+
+    if (use_f16) {
+        const half * src1_dd = (const half *)src1_d;
+
+        soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    } else {
+        const float * src1_dd = (const float *)src1_d;
+
+        soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    }
+}
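
Per row, the kernel computes a numerically stable softmax of the scaled logits with an optional additive mask and ALiBi bias:

$$
y_j = \frac{e^{\,s x_j + \alpha m_j - M}}{\sum_k e^{\,s x_k + \alpha m_k - M}},
\qquad M = \max_j \left(s x_j + \alpha m_j\right)
$$

where `s = scale`, `m` is the optional f16/f32 mask (broadcast across rows via `rowy = rowx % nrows_y`), and `alpha` is the per-head ALiBi slope derived from `max_bias` through `m0`, `m1`, and `n_head_log2` (it is 1 when `max_bias` is 0). Subtracting the row maximum `M` keeps `expf` in range; both `M` and the sum are computed with a warp shuffle reduction followed by a `WARP_SIZE`-wide shared-memory exchange when the block spans multiple warps. The `ncols` template specializations keep the whole row cached in shared memory when it fits (the `shmem < smpb` check); otherwise the kernel falls back to staging intermediate values through `dst`.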

+ 31 - 0
llama/ggml-cuda/softmax.cuh

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
+
+void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 65 - 0
llama/ggml-cuda/sumrows.cu

@@ -0,0 +1,65 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "sumrows.cuh"
+
+static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockIdx.x;
+    const int col = threadIdx.x;
+
+    float sum = 0.0f;
+    for (int i = col; i < ncols; i += blockDim.x) {
+        sum += x[row * ncols + i];
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    if (col == 0) {
+        dst[row] = sum;
+    }
+}
+
+void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(nrows, 1, 1);
+    k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+}
+
+void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
+}
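
Each row is reduced by a single warp (`block_dims = (WARP_SIZE, 1, 1)`), so only lane 0 writes the result. `warp_reduce_sum` is defined in `common.cuh`, which is not part of this hunk; a typical shuffle-based implementation looks like the following sketch (assumed shape, not the verbatim vendored helper):

```cuda
// Butterfly reduction over the 32 lanes of a warp: after log2(32) = 5
// exchange steps, every lane holds the sum of all lanes' inputs.
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
    }
    return x;
}
```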

+ 31 - 0
llama/ggml-cuda/sumrows.cuh

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
+
+void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 31 - 0
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
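
The rest of `template-instances/` follows this exact pattern: each autogenerated translation unit instantiates the f16 vector FlashAttention kernel for head size 128 and one (K type, V type) combination, spreading the expensive template instantiations across many `.cu` files so they compile in parallel. The macro itself lives in `fattn-vec-f16.cuh` (not shown in this excerpt); presumably it expands to an explicit template instantiation along these lines (a sketch, not the verbatim definition):

```c++
// Assumed shape of the macro: an explicit instantiation so the template
// body is compiled once, in this translation unit.
#define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V)                          \
    template void ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>( \
        ggml_backend_cuda_context & ctx, ggml_tensor * dst)
```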

+ 31 - 0
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);

+ 31 - 0
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);

+ 31 - 0
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);

+ 31 - 0
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);

+ 31 - 0
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);

+ 31 - 0
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);

+ 31 - 0
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);

+ 31 - 0
llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu

@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);

Some files were not shown because too many files changed in this diff