Prechádzať zdrojové kódy

image processing for llama3.2 (#6963)

Co-authored-by: jmorganca <jmorganca@gmail.com>
Co-authored-by: Michael Yang <mxyng@pm.me>
Co-authored-by: Jesse Gross <jesse@ollama.com>
Patrick Devine 6 mesiacov pred
rodič
commit
c7cb0f0602

+ 1 - 2
cmd/cmd.go

@@ -21,7 +21,6 @@ import (
 	"path/filepath"
 	"regexp"
 	"runtime"
-	"slices"
 	"strconv"
 	"strings"
 	"sync/atomic"
@@ -453,7 +452,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
-	opts.MultiModal = slices.Contains(info.Details.Families, "clip")
+	opts.MultiModal = len(info.ProjectorInfo) != 0
 	opts.ParentModel = info.Details.ParentModel
 
 	if interactive {

+ 20 - 27
cmd/interactive.go

@@ -494,28 +494,22 @@ func buildModelfile(opts runOptions) string {
 }
 
 func normalizeFilePath(fp string) string {
-	// Define a map of escaped characters and their replacements
-	replacements := map[string]string{
-		"\\ ":  " ",  // Escaped space
-		"\\(":  "(",  // Escaped left parenthesis
-		"\\)":  ")",  // Escaped right parenthesis
-		"\\[":  "[",  // Escaped left square bracket
-		"\\]":  "]",  // Escaped right square bracket
-		"\\{":  "{",  // Escaped left curly brace
-		"\\}":  "}",  // Escaped right curly brace
-		"\\$":  "$",  // Escaped dollar sign
-		"\\&":  "&",  // Escaped ampersand
-		"\\;":  ";",  // Escaped semicolon
-		"\\'":  "'",  // Escaped single quote
-		"\\\\": "\\", // Escaped backslash
-		"\\*":  "*",  // Escaped asterisk
-		"\\?":  "?",  // Escaped question mark
-	}
-
-	for escaped, actual := range replacements {
-		fp = strings.ReplaceAll(fp, escaped, actual)
-	}
-	return fp
+	return strings.NewReplacer(
+		"\\ ", " ", // Escaped space
+		"\\(", "(", // Escaped left parenthesis
+		"\\)", ")", // Escaped right parenthesis
+		"\\[", "[", // Escaped left square bracket
+		"\\]", "]", // Escaped right square bracket
+		"\\{", "{", // Escaped left curly brace
+		"\\}", "}", // Escaped right curly brace
+		"\\$", "$", // Escaped dollar sign
+		"\\&", "&", // Escaped ampersand
+		"\\;", ";", // Escaped semicolon
+		"\\'", "'", // Escaped single quote
+		"\\\\", "\\", // Escaped backslash
+		"\\*", "*", // Escaped asterisk
+		"\\?", "?", // Escaped question mark
+	).Replace(fp)
 }
 
 func extractFileNames(input string) []string {
@@ -535,10 +529,9 @@ func extractFileData(input string) (string, []api.ImageData, error) {
 	for _, fp := range filePaths {
 		nfp := normalizeFilePath(fp)
 		data, err := getImageData(nfp)
-		if err != nil {
-			if os.IsNotExist(err) {
-				continue
-			}
+		if errors.Is(err, os.ErrNotExist) {
+			continue
+		} else if err != nil {
 			fmt.Fprintf(os.Stderr, "Couldn't process image: %q\n", err)
 			return "", imgs, err
 		}
@@ -546,7 +539,7 @@ func extractFileData(input string) (string, []api.ImageData, error) {
 		input = strings.ReplaceAll(input, fp, "")
 		imgs = append(imgs, data)
 	}
-	return input, imgs, nil
+	return strings.TrimSpace(input), imgs, nil
 }
 
 func getImageData(filePath string) ([]byte, error) {

+ 2 - 2
convert/convert_test.go

@@ -29,7 +29,7 @@ type tensorData struct {
 	Shape   []int  `json:"shape"`
 }
 
-func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, llm.Tensors) {
+func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, *llm.Tensors) {
 	t.Helper()
 
 	f, err := os.CreateTemp(t.TempDir(), "f16")
@@ -60,7 +60,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, llm.Tensors) {
 	return r, m.KV(), m.Tensors()
 }
 
-func generateResultsJSON(t *testing.T, f *os.File, kv llm.KV, tensors llm.Tensors) map[string]string {
+func generateResultsJSON(t *testing.T, f *os.File, kv llm.KV, tensors *llm.Tensors) map[string]string {
 	actual := make(map[string]string)
 	for k, v := range kv {
 		if s, ok := v.(json.Marshaler); !ok {

+ 1 - 0
go.mod

@@ -22,6 +22,7 @@ require (
 	github.com/mattn/go-runewidth v0.0.14
 	github.com/nlpodyssey/gopickle v0.3.0
 	github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
+	golang.org/x/image v0.14.0
 )
 
 require (

+ 2 - 0
go.sum

@@ -230,6 +230,8 @@ golang.org/x/image v0.0.0-20200430140353-33d19683fad8/go.mod h1:FeLwcggjj3mMvU+o
 golang.org/x/image v0.0.0-20200618115811-c13761719519/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
 golang.org/x/image v0.0.0-20201208152932-35266b937fa6/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
 golang.org/x/image v0.0.0-20210216034530-4410531fe030/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
+golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
+golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
 golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
 golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
 golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=

+ 4 - 0
llama/ggml-cuda.cu

@@ -2296,6 +2296,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_PAD:
             ggml_cuda_op_pad(ctx, dst);
             break;
+        case GGML_OP_UNPAD:
+            ggml_cuda_op_unpad(ctx, dst);
+            break;
         case GGML_OP_ARANGE:
             ggml_cuda_op_arange(ctx, dst);
             break;
@@ -3018,6 +3021,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_GROUP_NORM:
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
+        case GGML_OP_UNPAD:
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:

+ 46 - 0
llama/ggml-cuda/pad.cu

@@ -73,3 +73,49 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
         dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
 }
+
+static __global__ void unpad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
+    // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
+    // blockIdx.y: idx of ne1
+    // blockIDx.x: idx of ne0 / BLOCK_SIZE
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    // operation
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
+        int offset_src =
+            nidx +
+            blockIdx.y * ne00 +
+            blockIdx.z * ne00 * ne01;
+        dst[offset_dst] = x[offset_src];
+    }
+}
+
+static void unpad_f32_cuda(const float * x, float * dst,
+    const int ne00, const int ne01, const int ne02, const int ne03,
+    const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2*ne3);
+    unpad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
+}
+
+void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+
+    unpad_f32_cuda(src0_d, dst_d,
+        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+}

+ 1 - 0
llama/ggml-cuda/pad.cuh

@@ -29,3 +29,4 @@
 #define CUDA_PAD_BLOCK_SIZE 256
 
 void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+ 45 - 0
llama/ggml-metal.metal

@@ -2055,6 +2055,51 @@ kernel void kernel_pad_f32(
     }
 }
 
+kernel void kernel_unpad_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    constant   int64_t & ne2,
+    constant   int64_t & ne3,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+            if (i0 < ne00) {
+                dst_ptr[i0] = src0_ptr[i0];
+            }
+        }
+
+        return;
+    }
+}
+
 kernel void kernel_arange_f32(
     device        char * dst,
     constant   int64_t & ne0,

+ 33 - 0
llama/ggml-metal_darwin_arm64.m

@@ -219,6 +219,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
     GGML_METAL_KERNEL_TYPE_PAD_F32,
+    GGML_METAL_KERNEL_TYPE_UNPAD_F32,
     GGML_METAL_KERNEL_TYPE_ARANGE_F32,
     GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
@@ -715,6 +716,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                    im2col_f32,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                       pad_f32,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UNPAD_F32,                     unpad_f32,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,        timestep_embedding_f32,         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32,                    arange_f32,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,           argsort_f32_i32_asc,            true);
@@ -872,6 +874,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
             return false;
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
+        case GGML_OP_UNPAD:
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
@@ -2681,6 +2684,36 @@ static void ggml_metal_encode_node(
 
                 const int nth = MIN(1024, ne0);
 
+                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_UNPAD:
+            {
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UNPAD_F32].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
+                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
+                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
+                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
+                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
+                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
+                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
+                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+
+                const int nth = MIN(1024, ne0);
+
                 [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_ARANGE:

+ 91 - 2
llama/ggml.c

@@ -3023,6 +3023,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "POOL_2D_BACK",
     "UPSCALE",
     "PAD",
+    "UNPAD",
     "ARANGE",
     "TIMESTEP_EMBEDDING",
     "ARGSORT",
@@ -3056,7 +3057,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
+static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3117,6 +3118,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "pool_2d_back(x)",
     "upscale(x)",
     "pad(x)",
+    "unpad(x)",
     "arange(start, stop, step)",
     "timestep_embedding(timesteps, dim, max_period)",
     "argsort(x)",
@@ -3150,7 +3152,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
+static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -6981,6 +6983,32 @@ struct ggml_tensor * ggml_pad(
     return result;
 }
 
+// ggml_unpad
+
+struct ggml_tensor * ggml_unpad(
+    struct ggml_context * ctx,
+    struct ggml_tensor  * a,
+    int p0, int p1, int p2, int p3) {
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ABORT("fatal error"); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
+            a->ne[0] - p0,
+            a->ne[1] - p1,
+            a->ne[2] - p2,
+            a->ne[3] - p3);
+
+    result->op = GGML_OP_UNPAD;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
 // ggml_arange
 
 struct ggml_tensor * ggml_arange(
@@ -15338,6 +15366,58 @@ static void ggml_compute_forward_pad(
     }
 }
 
+static void ggml_compute_forward_unpad_f32(
+    const struct ggml_compute_params *params,
+    struct ggml_tensor *dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float * dst_ptr = (float *) dst->data;
+
+    // TODO: optimize
+
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                for (int64_t i3 = 0; i3 < ne3; ++i3) {
+                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+
+                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        dst_ptr[dst_idx] = *src_ptr;
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_unpad(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_unpad_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
 
 // ggml_compute_forward_arange
 
@@ -17320,6 +17400,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_pad(params, tensor);
             } break;
+        case GGML_OP_UNPAD:
+            {
+                ggml_compute_forward_unpad(params, tensor);
+            } break;
         case GGML_OP_ARANGE:
             {
                 ggml_compute_forward_arange(params, tensor);
@@ -18395,6 +18479,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ABORT("fatal error"); // TODO: not implemented
             }
+        case GGML_OP_UNPAD:
+            {
+                GGML_ABORT("fatal error"); // TODO: not implemented
+            }
         case GGML_OP_ARANGE:
             {
                 GGML_ABORT("fatal error"); // TODO: not implemented
@@ -19191,6 +19279,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
+        case GGML_OP_UNPAD:
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:

+ 10 - 0
llama/ggml.h

@@ -532,6 +532,7 @@ extern "C" {
         GGML_OP_POOL_2D_BACK,
         GGML_OP_UPSCALE, // nearest interpolate
         GGML_OP_PAD,
+        GGML_OP_UNPAD,
         GGML_OP_ARANGE,
         GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
@@ -1790,6 +1791,15 @@ extern "C" {
             int                  p2,
             int                  p3);
 
+    // unpad each dimension: [x, ..., x, y, ..., y] -> [x, ..., x]
+    GGML_API struct ggml_tensor * ggml_unpad(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                  p0,
+            int                  p1,
+            int                  p2,
+            int                  p3);
+
     // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
     // timesteps: [N,]
     // return: [N, dim]

+ 443 - 13
llama/llama.cpp

@@ -195,6 +195,7 @@ static std::string format(const char * fmt, ...) {
 
 enum llm_arch {
     LLM_ARCH_LLAMA,
+    LLM_ARCH_MLLAMA,
     LLM_ARCH_FALCON,
     LLM_ARCH_BAICHUAN,
     LLM_ARCH_GROK,
@@ -249,6 +250,7 @@ enum llm_arch {
 
 static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_LLAMA,           "llama"        },
+    { LLM_ARCH_MLLAMA,          "mllama"       },
     { LLM_ARCH_FALCON,          "falcon"       },
     { LLM_ARCH_GROK,            "grok"         },
     { LLM_ARCH_GPT2,            "gpt2"         },
@@ -356,6 +358,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_SLIDING_WINDOW,
     LLM_KV_ATTENTION_SCALE,
     LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
+    LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_FREQ_BASE,
@@ -465,6 +468,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"           },
     { LLM_KV_ATTENTION_SCALE,                  "%s.attention.scale"                    },
     { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,  "%s.attention.block_skip_connection.%d" },
+    { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers"   },
 
     { LLM_KV_ROPE_DIMENSION_COUNT,          "%s.rope.dimension_count"                 },
     { LLM_KV_ROPE_FREQ_BASE,                "%s.rope.freq_base"                       },
@@ -639,6 +643,14 @@ enum llm_tensor {
     LLM_TENSOR_CLS,
     LLM_TENSOR_CLS_OUT,
     LLM_TENSOR_BSKCN_TV,
+    LLM_TENSOR_CROSS_ATTN_K_NORM,
+    LLM_TENSOR_CROSS_ATTN_K_PROJ,
+    LLM_TENSOR_CROSS_ATTN_O_PROJ,
+    LLM_TENSOR_CROSS_ATTN_Q_NORM,
+    LLM_TENSOR_CROSS_ATTN_Q_PROJ,
+    LLM_TENSOR_CROSS_ATTN_V_PROJ,
+    LLM_TENSOR_CROSS_ATTN_ATTN_GATE,
+    LLM_TENSOR_CROSS_ATTN_MLP_GATE,
 };
 
 static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
@@ -668,6 +680,40 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_MLLAMA,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXP,    "blk.%d.ffn_gate.%d" },
+            { LLM_TENSOR_FFN_DOWN_EXP,    "blk.%d.ffn_down.%d" },
+            { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
+            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_CROSS_ATTN_K_NORM,    "blk.%d.cross_attn_k_norm" },
+            { LLM_TENSOR_CROSS_ATTN_K_PROJ,    "blk.%d.cross_attn_k_proj" },
+            { LLM_TENSOR_CROSS_ATTN_O_PROJ,    "blk.%d.cross_attn_o_proj" },
+            { LLM_TENSOR_CROSS_ATTN_Q_NORM,    "blk.%d.cross_attn_q_norm" },
+            { LLM_TENSOR_CROSS_ATTN_Q_PROJ,    "blk.%d.cross_attn_q_proj" },
+            { LLM_TENSOR_CROSS_ATTN_V_PROJ,    "blk.%d.cross_attn_v_proj" },
+            { LLM_TENSOR_CROSS_ATTN_ATTN_GATE, "blk.%d.cross_attn_attn_gate" },
+            { LLM_TENSOR_CROSS_ATTN_MLP_GATE,  "blk.%d.cross_attn_mlp_gate" },
+        },
+    },
     {
         LLM_ARCH_BAICHUAN,
         {
@@ -2416,6 +2462,7 @@ enum e_model {
     MODEL_40B,
     MODEL_65B,
     MODEL_70B,
+    MODEL_90B,
     MODEL_236B,
     MODEL_314B,
     MODEL_SMALL,
@@ -2460,6 +2507,7 @@ struct llama_hparams {
     std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
 
     std::array<std::array<uint32_t, LLAMA_MAX_LAYERS>, 4> n_bskcn_arr;
+    std::array<uint32_t, LLAMA_MAX_LAYERS> cross_attn_layers;
 
     uint32_t n_layer_dense_lead = 0;
     uint32_t n_lora_q = 0;
@@ -2528,10 +2576,11 @@ struct llama_hparams {
         if (this->n_expert      != other.n_expert)      return true;
         if (this->n_expert_used != other.n_expert_used) return true;
 
-        if (this->n_head_arr    != other.n_head_arr)    return true;
-        if (this->n_head_kv_arr != other.n_head_kv_arr) return true;
-        if (this->n_ff_arr      != other.n_ff_arr)      return true;
-        if (this->n_bskcn_arr   != other.n_bskcn_arr)   return true;
+        if (this->n_head_arr        != other.n_head_arr)        return true;
+        if (this->n_head_kv_arr     != other.n_head_kv_arr)     return true;
+        if (this->n_ff_arr          != other.n_ff_arr)          return true;
+        if (this->n_bskcn_arr       != other.n_bskcn_arr)       return true;
+        if (this->cross_attn_layers != other.cross_attn_layers) return true;
 
         if (this->n_rel_attn_bkts    != other.n_rel_attn_bkts)    return true;
         if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true;
@@ -2649,6 +2698,10 @@ struct llama_hparams {
 
         GGML_ABORT("fatal error");
     }
+
+    bool cross_attention_layer(uint32_t il) const {
+        return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
+    }
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
@@ -2832,6 +2885,16 @@ struct llama_layer {
     struct ggml_tensor * ffn_down_scale;
 
     struct ggml_tensor * bskcn_tv;
+
+    // cross attention
+    struct ggml_tensor * cross_attn_k_norm;
+    struct ggml_tensor * cross_attn_k_proj;
+    struct ggml_tensor * cross_attn_o_proj;
+    struct ggml_tensor * cross_attn_q_norm;
+    struct ggml_tensor * cross_attn_q_proj;
+    struct ggml_tensor * cross_attn_v_proj;
+    struct ggml_tensor * cross_attn_attn_gate;
+    struct ggml_tensor * cross_attn_mlp_gate;
 };
 
 // very similar to llama_batch,
@@ -3478,6 +3541,12 @@ struct llama_context {
     struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
     struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
+
+    // TODO (jmorganca): this should most likely be passed in as part of a batch
+    // and not set on the context for all batches.
+    float * cross_attn_state = nullptr;
+    bool cross_attn_state_first_pass = true;
+    struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061]
 };
 
 struct llama_lora_weight {
@@ -3712,6 +3781,18 @@ static bool llama_kv_cache_init(
     cache.v_l.reserve(n_layer);
 
     for (int i = 0; i < (int) n_layer; i++) {
+        // for cross attention layers
+        if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layer(i)) {
+            struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
+            ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
+            ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
+            ggml_format_name(k, "cache_k_l%d", i);
+            ggml_format_name(v, "cache_v_l%d", i);
+            cache.k_l.push_back(k);
+            cache.v_l.push_back(v);
+            continue;
+        }
+
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
@@ -5486,12 +5567,14 @@ static void llm_load_hparams(
     }
 
     // zero-out the per-layer hparams
-    std::fill(hparams.n_head_arr.begin(),    hparams.n_head_arr.end(),    0);
-    std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
-    std::fill(hparams.n_ff_arr.begin(),      hparams.n_ff_arr.end(),      0);
+    std::fill(hparams.n_head_arr.begin(),             hparams.n_head_arr.end(),        0);
+    std::fill(hparams.n_head_kv_arr.begin(),          hparams.n_head_kv_arr.end(),     0);
+    std::fill(hparams.n_ff_arr.begin(),               hparams.n_ff_arr.end(),          0);
+    std::fill(hparams.cross_attn_layers.begin(),      hparams.cross_attn_layers.end(), -1);
 
-    ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff_arr,   hparams.n_layer);
-    ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer);
+    ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH,       hparams.n_ff_arr,          hparams.n_layer);
+    ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT,      hparams.n_head_arr,        hparams.n_layer);
+    ml.get_arr(LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, hparams.cross_attn_layers, false);
 
     // n_head_kv is optional, default to n_head
     hparams.n_head_kv_arr = hparams.n_head_arr;
@@ -5540,7 +5623,7 @@ static void llm_load_hparams(
 
         ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
 
-        if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
+        if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_MLLAMA || model.arch == LLM_ARCH_FALCON) {
             if (hparams.n_rot != hparams.n_embd_head_k) {
                 throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
             }
@@ -5580,6 +5663,16 @@ static void llm_load_hparams(
                     }
                 }
             } break;
+        case LLM_ARCH_MLLAMA:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 40: model.type = e_model::MODEL_11B; break;
+                    case 100: model.type = e_model::MODEL_90B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_MINICPM:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -7275,6 +7368,55 @@ static bool llm_load_tensors(
                         layer.rope_short = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
                     }
                 } break;
+            case LLM_ARCH_MLLAMA:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab+8});
+
+                    // output
+                    {
+                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        // if output is NULL, init from the input tok embed
+                        if (model.output == NULL) {
+                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+                        }
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        if (hparams.cross_attention_layer(i)) {
+                            layer.cross_attn_k_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_NORM,   "weight", i), {128});
+                            layer.cross_attn_k_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_PROJ,   "weight", i), {n_embd, 1024});
+                            layer.cross_attn_o_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_O_PROJ,   "weight", i), {n_embd, n_embd});
+                            layer.cross_attn_q_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_Q_NORM, "weight", i), {128});
+                            layer.cross_attn_q_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_Q_PROJ, "weight", i), {n_embd, n_embd});
+                            layer.cross_attn_v_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_V_PROJ, "weight", i), {n_embd, 1024});
+                            layer.cross_attn_attn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_ATTN_GATE, i), {1});
+                            layer.cross_attn_mlp_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_MLP_GATE, i), {1});
+                            layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                            layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                            layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
+                            layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                            layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        } else {
+                            layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                            layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head});
+                            layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
+                            layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
+                            layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
+                            layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                            layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                            layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
+                            layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                            layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                        }
+                    }
+                } break;
             case LLM_ARCH_GROK:
                 {
                     if (n_expert == 0) {
@@ -9119,7 +9261,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
 
         if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
             model.hparams.n_vocab != model.vocab.id_to_token.size()) {
-            throw std::runtime_error("vocab size mismatch");
+            LLAMA_LOG_WARN("%s: vocab mismatch %u !- %zu ...\n", __func__, model.hparams.n_vocab, model.vocab.id_to_token.size());
         }
 
         if (params.vocab_only) {
@@ -9204,7 +9346,7 @@ static struct ggml_tensor * llm_build_inp_embd(
 
         inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
     } else {
-       lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
+        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
         inpL = lctx.inp_embd;
         ggml_set_input(lctx.inp_embd);
     }
@@ -9219,6 +9361,22 @@ static struct ggml_tensor * llm_build_inp_embd(
     return inpL;
 }
 
+static struct ggml_tensor * llm_build_inp_cross_attn_state(
+        struct ggml_context * ctx,
+       struct llama_context & lctx,
+        const llama_hparams & hparams,
+         const llm_build_cb & cb) {
+    const int64_t n_embd = hparams.n_embd;
+
+    struct ggml_tensor * inpCAS;
+    lctx.inp_cross_attn_state = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4);
+    cb(lctx.inp_cross_attn_state, "inp_cross_attn_state", -1);
+    ggml_set_input(lctx.inp_cross_attn_state);
+    inpCAS = lctx.inp_cross_attn_state;
+
+    return inpCAS;
+}
+
 static void llm_build_kv_store(
         struct ggml_context * ctx,
         const llama_hparams & hparams,
@@ -10193,6 +10351,7 @@ struct llm_build_context {
         lctx.inp_pos_bucket    = nullptr;
         lctx.inp_embd_enc      = nullptr;
         lctx.inp_KQ_mask_cross = nullptr;
+        lctx.inp_cross_attn_state = nullptr;
     }
 
     void free() {
@@ -10780,6 +10939,253 @@ struct llm_build_context {
                 LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
 
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    struct ggml_cgraph * build_mllama() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+        struct ggml_tensor * inpCAS;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpCAS = llm_build_inp_cross_attn_state(ctx0, lctx, hparams, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            if (hparams.cross_attention_layer(il)) {
+                if (!lctx.cross_attn_state) {
+                    continue;
+                }
+
+                // cross attention layer
+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_q_proj, cur);
+                cb(Qcur, "Qcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                cb(Qcur, "Qcur", il);
+
+                Qcur = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+                cb(Qcur, "Qcur", il);
+
+                // TODO: is this required?
+                Qcur = ggml_cont(ctx0, Qcur);
+                cb(Qcur, "Qcur", il);
+
+                Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur;
+                if (lctx.cross_attn_state_first_pass) {
+                    Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS);
+                    cb(Kcur, "Kcur", il);
+
+                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404);
+                    cb(Kcur, "Kcur", il);
+
+                    Kcur = ggml_permute(ctx0, Kcur, 0, 2, 1, 3);
+                    cb(Kcur, "Kcur", il);
+
+                    // TODO: is this required?
+                    Kcur = ggml_cont(ctx0, Kcur);
+                    cb(Kcur, "Kcur", il);
+
+                    Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
+                    cb(Kcur, "Kcur", il);
+
+                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il]));
+                } else {
+                    Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
+                    cb(Kcur, "Kcur (view)", il);
+                }
+
+                struct ggml_tensor * Vcur;
+                if (lctx.cross_attn_state_first_pass) {
+                    Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS);
+                    cb(Vcur, "Vcur", il);
+
+                    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, 6404);
+                    cb(Vcur, "Vcur", il);
+
+                    Vcur = ggml_permute(ctx0, Vcur, 0, 2, 1, 3);
+                    cb(Vcur, "Vcur", il);
+
+                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il]));
+                } else {
+                    Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]);
+                    cb(Vcur, "Vcur (view)", il);
+                }
+
+                struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur);
+                cb(kq, "kq", il);
+
+                kq = ggml_scale_inplace(ctx0, kq, 1.0f/sqrtf(float(n_embd_head)));
+                cb(kq, "kq_scaled", il);
+
+                // TODO: apply causal masks
+                struct ggml_tensor * kq_soft_max = ggml_soft_max_inplace(ctx0, kq);
+                cb(kq_soft_max, "kq_soft_max", il);
+
+                Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur));
+                cb(Vcur, "Vcur", il);
+
+                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, Vcur, kq_soft_max);
+                cb(kqv, "kqv", il);
+
+                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+                cb(kqv_merged, "kqv_merged", il);
+
+                cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+                cb(cur, "kqv_merged_cont", il);
+
+                cur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_o_proj, cur);
+                cb(cur, "cur", il);
+
+                // TODO: do this in place once?
+                cur = ggml_mul(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_attn_gate));
+
+                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+                cb(ffn_inp, "ffn_inp", il);
+
+                // feed-forward network
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+
+                // TODO: do this inplace once?
+                cur = ggml_add_inplace(ctx0, ggml_mul_inplace(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_mlp_gate)), ffn_inp);
+                cb(cur, "ffn_out", il);
+
+                cur = lctx.cvec.apply_to(ctx0, cur, il);
+                cb(cur, "l_out", il);
+
+                // input for next layer
+                inpL = cur;
+            } else {
+                // self attention layer
+
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                struct ggml_tensor * rope_factors = build_rope_factors(il);
+
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+
+
+                if (il == n_layer - 1) {
+                    // skip computing output for unused tokens
+                    struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                    n_tokens = n_outputs;
+                    cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                    inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+                }
+
+                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+                cb(ffn_inp, "ffn_inp", il);
+
+                // feed-forward network
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+
+                cur = ggml_add(ctx0, cur, ffn_inp);
+                cb(cur, "ffn_out", il);
+
+                cur = lctx.cvec.apply_to(ctx0, cur, il);
+                cb(cur, "l_out", il);
+
+                // input for next layer
+                inpL = cur;
+            }
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
         // lm_head
         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
@@ -16527,6 +16933,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_llama();
             } break;
+        case LLM_ARCH_MLLAMA:
+            {
+                result = llm.build_mllama();
+            } break;
         case LLM_ARCH_BAICHUAN:
             {
                 result = llm.build_baichuan();
@@ -16799,6 +17209,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
     }
 
+    // TODO (jmorganca): this might copy a lot of data on every request of a
+    // single generation even though it doesn't change, so we should
+    // find a way to not set this more than one time per image
+    if (lctx.inp_cross_attn_state &&
+        lctx.inp_cross_attn_state->buffer) {
+        ggml_backend_tensor_set(lctx.inp_cross_attn_state, lctx.cross_attn_state, 0, hparams.n_embd * 1601 * 4 * ggml_element_size(lctx.inp_cross_attn_state));
+    }
+
     if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
         GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
         const int64_t n_tokens = batch.n_tokens;
@@ -17481,6 +17899,10 @@ static int llama_decode_internal(
 
         llama_set_inputs(lctx, ubatch);
 
+        // TODO: replace with something better to find out if its
+        // our first actual pass
+        lctx.cross_attn_state_first_pass = false;
+
         llama_graph_compute(lctx, gf, n_threads, threadpool);
 
         // update the kv ring buffer
@@ -18674,7 +19096,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         if (llama_model_has_encoder(&model)) {
             n_attn_layer *= 3;
         }
-        GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
+        if (qs.n_attention_wv != n_attn_layer) {
+            LLAMA_LOG_WARN("%s: n_attention_wv is unexpected, expected: %d, found: %d\n", __func__, n_attn_layer, qs.n_attention_wv);
+        }
     }
 
     size_t total_size_org = 0;
@@ -19770,6 +20194,11 @@ struct llama_context * llama_new_context_with_model(
     return ctx;
 }
 
+void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state) {
+    ctx->cross_attn_state_first_pass = true;
+    ctx->cross_attn_state = cross_attn_state;
+}
+
 void llama_free(struct llama_context * ctx) {
     delete ctx;
 }
@@ -19840,6 +20269,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
 
         // use what we call a normal RoPE, operating on pairs of consecutive head values
         case LLM_ARCH_LLAMA:
+        case LLM_ARCH_MLLAMA:
         case LLM_ARCH_BAICHUAN:
         case LLM_ARCH_STARCODER:
         case LLM_ARCH_PLAMO:

+ 91 - 5
llama/llama.go

@@ -60,7 +60,9 @@ package llama
 #include <stdlib.h>
 #include "llama.h"
 #include "clip.h"
+#include "ggml.h"
 #include "llava.h"
+#include "mllama.h"
 #include "sampling_ext.h"
 
 bool llamaProgressCallback(float progress, void *user_data);
@@ -410,18 +412,60 @@ func Quantize(infile, outfile string, ftype uint32) error {
 
 // llava
 type ClipContext struct {
-	c *C.struct_clip_ctx
+	c        *C.struct_clip_ctx
+	m        *C.struct_mllama_ctx
+	IsMllama bool
+	embedPin runtime.Pinner
+	pinned   bool
 }
 
-func NewClipContext(modelPath string) *ClipContext {
+func getVisionArch(mp *C.char) (string, error) {
+	gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)})
+	if gguf_ctx == nil {
+		return "", errors.New("unable to load vision projector")
+	}
+	defer C.gguf_free(gguf_ctx)
+
+	arch_index := C.gguf_find_key(gguf_ctx, C.CString("general.architecture"))
+	if int(arch_index) < 0 {
+		return "", errors.New("unknown vision model architecture")
+	}
+
+	arch := C.gguf_get_val_str(gguf_ctx, arch_index)
+
+	return C.GoString(arch), nil
+}
+
+func NewClipContext(modelPath string) (*ClipContext, error) {
 	mp := C.CString(modelPath)
 	defer C.free(unsafe.Pointer(mp))
-	cc := C.clip_model_load(mp, 1)
-	return &ClipContext{c: cc}
+
+	arch, err := getVisionArch(mp)
+	if err != nil {
+		return nil, err
+	}
+
+	var cc ClipContext
+	if arch == "clip" {
+		cc.c = C.clip_model_load(mp, 1)
+	} else if arch == "mllama" {
+		cc.m = C.mllama_model_load(mp, 1)
+		cc.IsMllama = true
+	} else {
+		return nil, fmt.Errorf("unknown vision model architecture: %s", arch)
+	}
+
+	// XXX: check embedding size?
+	return &cc, nil
 }
 
 func (c *ClipContext) Free() {
-	C.clip_free(c.c)
+	if c.c != nil {
+		C.clip_free(c.c)
+	}
+	if c.m != nil {
+		C.mllama_free(c.m)
+	}
 }
 
 func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []byte) [][]float32 {
@@ -445,6 +489,48 @@ func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []
 	return embed
 }
 
+func NewMllamaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []byte, aspectRatioId int) [][]float32 {
+	img := C.mllama_image_init()
+	defer C.mllama_image_free(img)
+
+	C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img)
+
+	numTokens := int(C.mllama_n_positions(clipContext.m) * C.mllama_n_tiles(clipContext.m))
+	numEmbed := llamaContext.Model().NEmbd()
+
+	rows := make([]float32, numEmbed*numTokens)
+	C.mllama_image_encode(clipContext.m, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0])))
+
+	embed := make([][]float32, numTokens)
+	for i := range embed {
+		embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
+	}
+
+	return embed
+}
+
+// This really needs to be set on a batch instead
+func MllamaSetCrossAttn(llamaContext *Context, clipContext *ClipContext, embed [][]float32) {
+	if embed != nil {
+		if clipContext.pinned {
+			panic("Cross attention state already pinned")
+		}
+
+		embedData := &embed[0][0]
+		clipContext.embedPin.Pin(embedData)
+		clipContext.pinned = true
+
+		C.llama_set_cross_attn_state(llamaContext.c, (*C.float)(unsafe.Pointer(embedData)))
+	} else {
+		C.llama_set_cross_attn_state(llamaContext.c, (*C.float)(C.NULL))
+
+		if clipContext.pinned {
+			clipContext.embedPin.Unpin()
+			clipContext.pinned = false
+		}
+	}
+}
+
 // sampling
 // TODO: this is a temporary wrapper to allow calling C++ code from CGo
 type SamplingContext struct {

+ 4 - 0
llama/llama.h

@@ -449,6 +449,10 @@ extern "C" {
                      struct llama_model * model,
             struct llama_context_params   params);
 
+    // TODO (jmorganca): this should most likely be passed in as part of a batch
+    // and not set on the context for all batches.
+    LLAMA_API void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state);
+
     // Frees all allocated memory
     LLAMA_API void llama_free(struct llama_context * ctx);
 

+ 900 - 0
llama/mllama.cpp

@@ -0,0 +1,900 @@
+// NOTE: This is modified from clip.cpp for Mllama only
+#include "mllama.h"
+
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+#include "ggml.h"
+
+#ifdef GGML_USE_CUDA
+#include "ggml-cuda.h"
+#endif
+
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
+#ifdef GGML_USE_CANN
+#include "ggml-cann.h"
+#endif
+
+#ifdef GGML_USE_VULKAN
+#include "ggml-vulkan.h"
+#endif
+
+#include <algorithm>
+#include <cmath>
+#include <cstdarg>
+#include <cstdlib>
+#include <cstring>
+#include <fstream>
+#include <stdexcept>
+#include <vector>
+
+#define REQUIRE(x)                                           \
+    do {                                                     \
+        if (!(x)) {                                          \
+            throw std::runtime_error("REQUIRE failed: " #x); \
+        }                                                    \
+    } while (0)
+
+#define LOG(fmt, ...) fprintf(stderr, "%s: " fmt "\n", __func__, ##__VA_ARGS__)
+
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+    #define NOMINMAX
+#endif
+#include <windows.h>
+#if __GLIBCXX__
+#include <cstdio>
+#include <ext/stdio_filebuf.h>
+#include <fcntl.h>
+#endif
+#endif
+
+struct mllama_image {
+    int width;
+    int height;
+
+    int num_channels = 3;
+    int num_tiles = 4;
+
+    int aspect_ratio_id;
+
+    std::vector<float> data;
+};
+
+static std::string format(const char *fmt, ...) {
+    va_list args;
+    va_start(args, fmt);
+    std::vector<char> b(128);
+    int n = vsnprintf(b.data(), b.size(), fmt, args);
+    REQUIRE(n >= 0 && n < b.size());
+    va_end(args);
+    return std::string(b.data(), b.size());
+}
+
+//
+// utilities to get data from a gguf file
+//
+
+static int get_key_index(const gguf_context *ctx, const char *key) {
+    int key_index = gguf_find_key(ctx, key);
+    REQUIRE(key_index != -1);
+    return key_index;
+}
+
+static std::vector<uint32_t> get_u32_array(const gguf_context *ctx, const std::string &key) {
+    const int i = get_key_index(ctx, key.c_str());
+    const int n = gguf_get_arr_n(ctx, i);
+    const uint32_t *data = (uint32_t *)gguf_get_arr_data(ctx, i);
+
+    std::vector<uint32_t> s(n);
+    for (size_t j = 0; j < s.size(); j++) {
+        s[j] = data[j];
+    }
+
+    return s;
+}
+
+static uint32_t get_u32(const gguf_context *ctx, const std::string &key) {
+    return gguf_get_val_u32(ctx, get_key_index(ctx, key.c_str()));
+}
+
+static float get_f32(const gguf_context *ctx, const std::string &key) {
+    return gguf_get_val_f32(ctx, get_key_index(ctx, key.c_str()));
+}
+
+static std::string get_ftype(int ftype) {
+    return ggml_type_name(static_cast<ggml_type>(ftype));
+}
+
+//
+// mllama layers
+//
+
+struct mllama_hparams {
+    uint32_t image_size;
+    uint32_t patch_size;
+    uint32_t hidden_size;
+    uint32_t n_intermediate;
+    uint32_t projection_dim;
+    uint32_t n_head;
+    uint32_t n_layer;
+    uint32_t n_global_layer;
+    uint32_t n_tiles;
+
+    float eps;
+
+    std::vector<bool> intermediate_layers;
+};
+
+struct mllama_layer {
+    // attention
+    struct ggml_tensor *k_w;
+    struct ggml_tensor *k_b;
+    struct ggml_tensor *q_w;
+    struct ggml_tensor *q_b;
+    struct ggml_tensor *v_w;
+    struct ggml_tensor *v_b;
+
+    struct ggml_tensor *o_w;
+    struct ggml_tensor *o_b;
+
+    struct ggml_tensor *attn_gate;
+
+    // layernorm 1
+    struct ggml_tensor *ln_1_w;
+    struct ggml_tensor *ln_1_b;
+
+    // ff
+    struct ggml_tensor *ff_i_w;
+    struct ggml_tensor *ff_i_b;
+
+    struct ggml_tensor *ff_o_w;
+    struct ggml_tensor *ff_o_b;
+
+    struct ggml_tensor *ff_gate;
+
+    // layernorm 2
+    struct ggml_tensor *ln_2_w;
+    struct ggml_tensor *ln_2_b;
+};
+
+struct mllama_vision_model {
+    struct mllama_hparams hparams;
+
+    // embeddings
+    struct ggml_tensor *class_embedding;
+    struct ggml_tensor *patch_embeddings;
+    struct ggml_tensor *position_embeddings;
+    struct ggml_tensor *position_embeddings_gate;
+    struct ggml_tensor *tile_position_embeddings;
+    struct ggml_tensor *tile_position_embeddings_gate;
+    struct ggml_tensor *pre_tile_position_embeddings;
+    struct ggml_tensor *pre_tile_position_embeddings_gate;
+    struct ggml_tensor *post_tile_position_embeddings;
+    struct ggml_tensor *post_tile_position_embeddings_gate;
+
+    struct ggml_tensor *pre_ln_w;
+    struct ggml_tensor *pre_ln_b;
+
+    std::vector<mllama_layer> layers;
+    std::vector<mllama_layer> global_layers;
+
+    struct ggml_tensor *post_ln_w;
+    struct ggml_tensor *post_ln_b;
+
+    struct ggml_tensor *mm_0_w;
+    struct ggml_tensor *mm_0_b;
+};
+
+struct mllama_ctx {
+    struct mllama_vision_model vision_model;
+
+    uint32_t ftype = 1;
+
+    struct gguf_context *ctx_gguf;
+    struct ggml_context *ctx_data;
+
+    std::vector<uint8_t> buf_compute_meta;
+
+    // memory buffers to evaluate the model
+    ggml_backend_buffer_t params_buffer = nullptr;
+
+    ggml_backend_t backend = nullptr;
+    ggml_gallocr_t compute_alloc = nullptr;
+};
+
+static ggml_tensor *mllama_image_build_encoder_layer(
+    struct ggml_context *ctx0, const size_t il, const struct mllama_layer &layer, struct ggml_tensor *embeddings,
+    const float eps, const int hidden_size, const int batch_size, const int n_head, const int d_head) {
+    struct ggml_tensor *cur = embeddings;
+
+    {
+        // layernorm1
+        cur = ggml_norm(ctx0, cur, eps);
+        cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.ln_1_w), layer.ln_1_b);
+        ggml_set_name(cur, format("%d pre layernorm", il).c_str());
+    }
+
+    {
+        // self-attention
+        struct ggml_tensor *Q = ggml_mul_mat(ctx0, layer.q_w, cur);
+        if (layer.q_b != nullptr) {
+            Q = ggml_add(ctx0, Q, layer.q_b);
+        }
+
+        Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, Q->ne[1], batch_size);
+        Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+        ggml_set_name(Q, format("%d query", il).c_str());
+
+        struct ggml_tensor *K = ggml_mul_mat(ctx0, layer.k_w, cur);
+        if (layer.k_b != nullptr) {
+            K = ggml_add(ctx0, K, layer.k_b);
+        }
+
+        K = ggml_reshape_4d(ctx0, K, d_head, n_head, K->ne[1], batch_size);
+        K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+        ggml_set_name(K, format("%d key", il).c_str());
+
+        struct ggml_tensor *V = ggml_mul_mat(ctx0, layer.v_w, cur);
+        if (layer.v_b != nullptr) {
+            V = ggml_add(ctx0, V, layer.v_b);
+        }
+
+        V = ggml_reshape_4d(ctx0, V, d_head, n_head, V->ne[1], batch_size);
+        V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+        ggml_set_name(V, format("%d value", il).c_str());
+
+        struct ggml_tensor *KQ = ggml_mul_mat(ctx0, K, Q);
+        KQ = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf((float)d_head));
+        KQ = ggml_soft_max_inplace(ctx0, KQ);
+        ggml_set_name(KQ, format("%d KQ", il).c_str());
+
+        struct ggml_tensor *KQV = ggml_mul_mat(ctx0, V, KQ);
+        KQV = ggml_reshape_4d(ctx0, KQV, d_head, KQV->ne[1], n_head, batch_size);
+        KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+        KQV = ggml_cont_3d(ctx0, KQV, hidden_size, KQV->ne[2], batch_size);
+        ggml_set_name(KQV, format("%d KQV", il).c_str());
+
+        cur = ggml_mul_mat(ctx0, layer.o_w, KQV);
+        if (layer.o_b != nullptr) {
+            cur = ggml_add(ctx0, cur, layer.o_b);
+        }
+        ggml_set_name(cur, format("%d self attention", il).c_str());
+
+        if (layer.attn_gate != nullptr) {
+            cur = ggml_mul_inplace(ctx0, cur, layer.attn_gate);
+            ggml_set_name(cur, format("%d self attention gate", il).c_str());
+        }
+    }
+
+    cur = ggml_add(ctx0, cur, embeddings);
+    ggml_set_name(cur, format("%d residual", il).c_str());
+
+    embeddings = cur;
+
+    {
+        // layernorm2
+        cur = ggml_norm(ctx0, cur, eps);
+        cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.ln_2_w), layer.ln_2_b);
+        ggml_set_name(cur, format("%d post layernorm", il).c_str());
+    }
+
+    {
+        // feed forward
+        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, layer.ff_i_w, cur), layer.ff_i_b);
+        cur = ggml_gelu_inplace(ctx0, cur);
+        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, layer.ff_o_w, cur), layer.ff_o_b);
+        ggml_set_name(cur, format("%d feed forward", il).c_str());
+
+        if (layer.ff_gate != nullptr) {
+            cur = ggml_mul_inplace(ctx0, cur, layer.ff_gate);
+            ggml_set_name(cur, format("%d feed forward gate", il).c_str());
+        }
+    }
+
+    // residual 2
+    cur = ggml_add(ctx0, cur, embeddings);
+    ggml_set_name(cur, format("%d residual", il).c_str());
+
+    embeddings = cur;
+
+    return embeddings;
+}
+
+static ggml_cgraph *mllama_image_build_graph(mllama_ctx *ctx, const mllama_image_batch *imgs) {
+    const auto &model = ctx->vision_model;
+    const auto &hparams = model.hparams;
+
+    const int image_size = hparams.image_size;
+    const int image_size_width = image_size;
+    const int image_size_height = image_size;
+
+    const int patch_size = hparams.patch_size;
+    const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
+    const int num_positions = num_patches + (model.class_embedding == nullptr ? 0 : 1);
+    const int hidden_size = hparams.hidden_size;
+    const int n_head = hparams.n_head;
+    const int d_head = hidden_size / n_head;
+
+    const int batch_size = imgs->size;
+    REQUIRE(batch_size == 1);
+
+    int num_tiles = 4;
+    int num_channels = 3;
+    if (imgs->data != nullptr) {
+        num_tiles = imgs->data[0].num_tiles > 0 ? imgs->data[0].num_tiles : num_tiles;
+        num_channels = imgs->data[0].num_channels > 0 ? imgs->data[0].num_channels : num_channels;
+    }
+
+    struct ggml_init_params params = {
+        ctx->buf_compute_meta.size(), // mem_size
+        ctx->buf_compute_meta.data(), // mem_buffer
+        true,                         // no_alloc
+    };
+
+    struct ggml_context *ctx0 = ggml_init(params);
+    struct ggml_cgraph *gf = ggml_new_graph(ctx0);
+
+    struct ggml_tensor *inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, num_channels, num_tiles);
+    ggml_set_name(inp_raw, "inp_raw");
+    ggml_set_input(inp_raw);
+
+    struct ggml_tensor *inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+
+    inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, num_tiles);
+    inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
+
+    struct ggml_tensor *aspect_ratios = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, imgs->size);
+    ggml_set_name(aspect_ratios, "aspect_ratios");
+    ggml_set_input(aspect_ratios);
+
+    if (model.pre_tile_position_embeddings != nullptr) {
+        struct ggml_tensor *pre_tile_position_embeddings = ggml_get_rows(ctx0, model.pre_tile_position_embeddings, aspect_ratios);
+        ggml_set_name(pre_tile_position_embeddings, "pre_tile_position_embeddings");
+
+        pre_tile_position_embeddings = ggml_reshape_3d(ctx0, pre_tile_position_embeddings, hidden_size, 1, num_tiles);
+        if (model.pre_tile_position_embeddings_gate != nullptr) {
+            pre_tile_position_embeddings = ggml_mul_inplace(ctx0, pre_tile_position_embeddings, model.pre_tile_position_embeddings_gate);
+        }
+
+        inp = ggml_add(ctx0, inp, pre_tile_position_embeddings);
+    }
+
+    struct ggml_tensor *embeddings = inp;
+
+    if (model.class_embedding != nullptr) {
+        // concat class_embeddings and patch_embeddings
+        embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, num_tiles);
+        ggml_set_name(embeddings, "embeddings");
+        ggml_set_input(embeddings);
+        for (int i = 0; i < num_tiles; ++i) {
+            // repeat class embeddings for each tile
+            embeddings = ggml_acc_inplace(ctx0, embeddings, model.class_embedding, embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], i * embeddings->nb[2]);
+        }
+
+        embeddings = ggml_acc_inplace(ctx0, embeddings, inp, embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
+    }
+
+    struct ggml_tensor *positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
+    ggml_set_name(positions, "positions");
+    ggml_set_input(positions);
+
+    struct ggml_tensor *position_embd = ggml_get_rows(ctx0, model.position_embeddings, positions);
+    if (model.position_embeddings_gate != nullptr) {
+        position_embd = ggml_mul_inplace(ctx0, position_embd, model.position_embeddings_gate);
+    }
+
+    embeddings = ggml_add(ctx0, embeddings, position_embd);
+
+    if (model.tile_position_embeddings != nullptr) {
+        struct ggml_tensor *tile_position_embeddings = ggml_get_rows(ctx0, model.tile_position_embeddings, aspect_ratios);
+        ggml_set_name(tile_position_embeddings, "tile_position_embeddings");
+
+        tile_position_embeddings = ggml_reshape_3d(ctx0, tile_position_embeddings, hidden_size, num_positions, num_tiles);
+        if (model.tile_position_embeddings_gate != nullptr) {
+            tile_position_embeddings = ggml_mul_inplace(ctx0, tile_position_embeddings, model.tile_position_embeddings_gate);
+        }
+
+        embeddings = ggml_add(ctx0, embeddings, tile_position_embeddings);
+    }
+
+    // pre-layernorm
+    if (model.pre_ln_w != nullptr) {
+        embeddings = ggml_mul(ctx0, ggml_norm(ctx0, embeddings, hparams.eps), model.pre_ln_w);
+        if (model.pre_ln_b != nullptr) {
+            embeddings = ggml_add(ctx0, embeddings, model.pre_ln_b);
+        }
+
+        ggml_set_name(embeddings, "pre layernorm");
+    }
+
+    const int num_padding_patches = 8 - (embeddings->ne[1] % 8) % 8;
+
+    embeddings = ggml_pad(ctx0, embeddings, 0, num_padding_patches, 0, 0);
+    embeddings = ggml_view_3d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1] * embeddings->ne[2], batch_size, embeddings->nb[1], embeddings->nb[2] * embeddings->ne[3], 0);
+
+    std::vector<struct ggml_tensor *> intermediate_embeddings;
+
+    // encoder
+    for (size_t il = 0; il < model.layers.size(); il++) {
+        if (hparams.intermediate_layers[il]) {
+            intermediate_embeddings.push_back(embeddings);
+        }
+
+        embeddings = mllama_image_build_encoder_layer(
+            ctx0, il, model.layers[il], embeddings,
+            hparams.eps, hidden_size, batch_size, n_head, d_head);
+    }
+
+    // post-layernorm
+    if (model.post_ln_w != nullptr) {
+        embeddings = ggml_mul(ctx0, ggml_norm(ctx0, embeddings, hparams.eps), model.post_ln_w);
+        if (model.post_ln_b != nullptr) {
+            embeddings = ggml_add(ctx0, embeddings, model.post_ln_b);
+        }
+
+        ggml_set_name(embeddings, "post layernorm");
+    }
+
+    embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, num_positions + num_padding_patches, num_tiles);
+
+    if (model.post_tile_position_embeddings != nullptr) {
+        struct ggml_tensor *post_tile_position_embeddings = ggml_get_rows(ctx0, model.post_tile_position_embeddings, aspect_ratios);
+        ggml_set_name(post_tile_position_embeddings, "post_tile_position_embeddings");
+
+        post_tile_position_embeddings = ggml_reshape_3d(ctx0, post_tile_position_embeddings, hidden_size, 1, num_tiles);
+        if (model.post_tile_position_embeddings_gate != nullptr) {
+            post_tile_position_embeddings = ggml_mul(ctx0, post_tile_position_embeddings, model.post_tile_position_embeddings_gate);
+        }
+
+        embeddings = ggml_add(ctx0, embeddings, post_tile_position_embeddings);
+    }
+
+    embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, num_tiles * (num_positions + num_padding_patches), 1);
+
+    // global encoder
+    for (size_t il = 0; il < model.global_layers.size(); il++) {
+        embeddings = mllama_image_build_encoder_layer(
+            ctx0, il, model.global_layers[il], embeddings,
+            hparams.eps, hidden_size, batch_size, n_head, d_head);
+    }
+
+    struct ggml_tensor *stacked_embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 0, hidden_size, (num_positions + num_padding_patches) * num_tiles);
+    for (size_t i = 0; i < intermediate_embeddings.size(); ++i) {
+        stacked_embeddings = ggml_concat(ctx0, stacked_embeddings, ggml_reshape_3d(ctx0, intermediate_embeddings[i], 1, intermediate_embeddings[i]->ne[0], intermediate_embeddings[i]->ne[1]), 0);
+    }
+
+    stacked_embeddings = ggml_reshape_4d(ctx0, stacked_embeddings, intermediate_embeddings.size() * hidden_size, num_positions + num_padding_patches, num_tiles, batch_size);
+    stacked_embeddings = ggml_unpad(ctx0, stacked_embeddings, 0, num_padding_patches, 0, 0);
+
+    embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, num_positions + num_padding_patches, num_tiles);
+    embeddings = ggml_unpad(ctx0, embeddings, 0, num_padding_patches, 0, 0);
+    embeddings = ggml_concat(ctx0, embeddings, stacked_embeddings, 0);
+
+    // mllama projector
+    embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_0_w, embeddings), model.mm_0_b);
+    ggml_set_name(embeddings, "multi modal projector");
+
+    // build the graph
+    ggml_build_forward_expand(gf, embeddings);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
+static struct ggml_tensor *mllama_tensor_load(struct ggml_context *ctx, const char *name, const bool optional) {
+    struct ggml_tensor *cur = ggml_get_tensor(ctx, name);
+    REQUIRE(cur != nullptr || optional);
+    return cur;
+}
+
+static std::vector<struct mllama_layer> mllama_layers_load(struct ggml_context *ctx, const char *prefix, const int n) {
+    std::vector<struct mllama_layer> layers(n);
+    for (size_t i = 0; i < layers.size(); i++) {
+        auto &layer = layers[i];
+        layer.ln_1_w = mllama_tensor_load(ctx, format("%s.blk.%d.ln1.weight", prefix, i).c_str(), false);
+        layer.ln_1_b = mllama_tensor_load(ctx, format("%s.blk.%d.ln1.bias", prefix, i).c_str(), false);
+        layer.ln_2_w = mllama_tensor_load(ctx, format("%s.blk.%d.ln2.weight", prefix, i).c_str(), false);
+        layer.ln_2_b = mllama_tensor_load(ctx, format("%s.blk.%d.ln2.bias", prefix, i).c_str(), false);
+
+        layer.k_w = mllama_tensor_load(ctx, format("%s.blk.%d.attn_k.weight", prefix, i).c_str(), false);
+        layer.k_b = mllama_tensor_load(ctx, format("%s.blk.%d.attn_k.bias", prefix, i).c_str(), true);
+        layer.q_w = mllama_tensor_load(ctx, format("%s.blk.%d.attn_q.weight", prefix, i).c_str(), false);
+        layer.q_b = mllama_tensor_load(ctx, format("%s.blk.%d.attn_q.bias", prefix, i).c_str(), true);
+        layer.v_w = mllama_tensor_load(ctx, format("%s.blk.%d.attn_v.weight", prefix, i).c_str(), false);
+        layer.v_b = mllama_tensor_load(ctx, format("%s.blk.%d.attn_v.bias", prefix, i).c_str(), true);
+        layer.o_w = mllama_tensor_load(ctx, format("%s.blk.%d.attn_out.weight", prefix, i).c_str(), false);
+        layer.o_b = mllama_tensor_load(ctx, format("%s.blk.%d.attn_out.bias", prefix, i).c_str(), true);
+
+        layer.ff_i_w = mllama_tensor_load(ctx, format("%s.blk.%d.ffn_down.weight", prefix, i).c_str(), false);
+        layer.ff_i_b = mllama_tensor_load(ctx, format("%s.blk.%d.ffn_down.bias", prefix, i).c_str(), false);
+        layer.ff_o_w = mllama_tensor_load(ctx, format("%s.blk.%d.ffn_up.weight", prefix, i).c_str(), false);
+        layer.ff_o_b = mllama_tensor_load(ctx, format("%s.blk.%d.ffn_up.bias", prefix, i).c_str(), false);
+
+        layer.attn_gate = mllama_tensor_load(ctx, format("%s.blk.%d.attn_gate", prefix, i).c_str(), true);
+        layer.ff_gate = mllama_tensor_load(ctx, format("%s.blk.%d.ffn_gate", prefix, i).c_str(), true);
+    }
+
+    return layers;
+}
+
+// read and create ggml_context containing the tensors and their data
+struct mllama_ctx *mllama_model_load(const char *fname, const int verbosity = 1) {
+    struct ggml_context *meta = nullptr;
+
+    struct gguf_init_params params = {
+        true,  // no_alloc
+        &meta, // ctx
+    };
+
+    struct gguf_context *ctx = gguf_init_from_file(fname, params);
+    REQUIRE(ctx != nullptr);
+
+    if (verbosity >= 1) {
+        const int n_tensors = gguf_get_n_tensors(ctx);
+        const int n_kv = gguf_get_n_kv(ctx);
+        const std::string ftype = get_ftype(get_u32(ctx, "general.file_type"));
+        const int idx_desc = get_key_index(ctx, "general.description");
+        const std::string description = gguf_get_val_str(ctx, idx_desc);
+        const int idx_name = gguf_find_key(ctx, "general.name");
+        if (idx_name != -1) { // make name optional temporarily as some of the uploaded models missing it due to a bug
+            const std::string name = gguf_get_val_str(ctx, idx_name);
+            LOG("model name:   %s", name.c_str());
+        }
+        LOG("description:  %s", description.c_str());
+        LOG("GGUF version: %d", gguf_get_version(ctx));
+        LOG("alignment:    %zu", gguf_get_alignment(ctx));
+        LOG("n_tensors:    %d", n_tensors);
+        LOG("n_kv:         %d", n_kv);
+        LOG("ftype:        %s", ftype.c_str());
+        LOG("");
+    }
+    const int n_tensors = gguf_get_n_tensors(ctx);
+
+    mllama_ctx *new_mllama = new mllama_ctx{};
+
+#ifdef GGML_USE_CUDA
+    new_mllama->backend = ggml_backend_cuda_init(0);
+    LOG("vision using CUDA backend");
+#endif
+
+#ifdef GGML_USE_METAL
+    new_mllama->backend = ggml_backend_metal_init();
+    LOG("vision using Metal backend");
+#endif
+
+#ifdef GGML_USE_CANN
+    new_mllama->backend = ggml_backend_cann_init(0);
+    LOG("vision using CANN backend");
+#endif
+
+#ifdef GGML_USE_VULKAN
+    new_mllama->backend = ggml_backend_vk_init(0);
+    LOG("vision using Vulkan backend");
+#endif
+
+    if (!new_mllama->backend) {
+        new_mllama->backend = ggml_backend_cpu_init();
+        LOG("vision using CPU backend");
+    }
+
+    // load tensors
+    {
+        std::vector<uint8_t> read_buf;
+        struct ggml_init_params params = {
+            (n_tensors + 1) * ggml_tensor_overhead(), // mem_size
+            nullptr,                                  // mem_buffer
+            true,                                     // no_alloc
+        };
+
+        new_mllama->ctx_data = ggml_init(params);
+        if (!new_mllama->ctx_data) {
+            LOG("ggml_init() failed");
+            mllama_free(new_mllama);
+            gguf_free(ctx);
+            return nullptr;
+        }
+
+#ifdef _WIN32
+        int wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, NULL, 0);
+        if (!wlen) {
+            return NULL;
+        }
+        wchar_t * wbuf = (wchar_t *) malloc(wlen * sizeof(wchar_t));
+        wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, wbuf, wlen);
+        if (!wlen) {
+            free(wbuf);
+            return NULL;
+        }
+#if __GLIBCXX__
+        int fd = _wopen(wbuf, _O_RDONLY | _O_BINARY);
+        __gnu_cxx::stdio_filebuf<char> buffer(fd, std::ios_base::in);
+        std::istream fin(&buffer);
+#else // MSVC
+        // unused in our current build
+        auto fin = std::ifstream(wbuf, std::ios::binary);
+#endif
+        free(wbuf);
+#else
+        auto fin = std::ifstream(fname, std::ios::binary);
+#endif
+        if (!fin) {
+            LOG("cannot open model file for loading tensors\n");
+            mllama_free(new_mllama);
+            gguf_free(ctx);
+            return nullptr;
+        }
+
+        // add tensors to context
+        for (int i = 0; i < n_tensors; ++i) {
+            const char *name = gguf_get_tensor_name(ctx, i);
+            struct ggml_tensor *t = ggml_get_tensor(meta, name);
+            struct ggml_tensor *cur = ggml_dup_tensor(new_mllama->ctx_data, t);
+            ggml_set_name(cur, name);
+        }
+
+        // alloc memory and offload data
+        new_mllama->params_buffer = ggml_backend_alloc_ctx_tensors(new_mllama->ctx_data, new_mllama->backend);
+        for (int i = 0; i < n_tensors; ++i) {
+            const char *name = gguf_get_tensor_name(ctx, i);
+            struct ggml_tensor *cur = ggml_get_tensor(new_mllama->ctx_data, name);
+            const size_t offset = gguf_get_data_offset(ctx) + gguf_get_tensor_offset(ctx, i);
+            fin.seekg(offset, std::ios::beg);
+            if (!fin) {
+                LOG("failed to seek for tensor %s\n", name);
+                mllama_free(new_mllama);
+                gguf_free(ctx);
+                return nullptr;
+            }
+            int num_bytes = ggml_nbytes(cur);
+            if (ggml_backend_buffer_is_host(new_mllama->params_buffer)) {
+                // for the CPU and Metal backend, we can read directly into the tensor
+                fin.read(reinterpret_cast<char *>(cur->data), num_bytes);
+            } else {
+                // read into a temporary buffer first, then copy to device memory
+                read_buf.resize(num_bytes);
+                fin.read(reinterpret_cast<char *>(read_buf.data()), num_bytes);
+                ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
+            }
+        }
+
+#if defined(_WIN32) && defined(__GLIBCXX__)
+        close(fd);
+#else
+        fin.close();
+#endif
+    }
+
+    // vision model
+    // load vision model
+    auto &vision_model = new_mllama->vision_model;
+    auto &hparams = vision_model.hparams;
+    hparams.hidden_size = get_u32(ctx, "mllama.vision.embedding_length");
+    hparams.n_head = get_u32(ctx, "mllama.vision.attention.head_count");
+    hparams.n_intermediate = get_u32(ctx, "mllama.vision.feed_forward_length");
+    hparams.n_layer = get_u32(ctx, "mllama.vision.block_count");
+    hparams.n_global_layer = get_u32(ctx, "mllama.vision.global.block_count");
+    hparams.n_tiles = get_u32(ctx, "mllama.vision.max_num_tiles");
+    hparams.image_size = get_u32(ctx, "mllama.vision.image_size");
+    hparams.patch_size = get_u32(ctx, "mllama.vision.patch_size");
+    hparams.projection_dim = get_u32(ctx, "mllama.vision.projection_dim");
+    hparams.eps = get_f32(ctx, "mllama.vision.attention.layer_norm_epsilon");
+
+    std::vector<uint32_t> intermediate_layers_indices = get_u32_array(ctx, "mllama.vision.intermediate_layers_indices");
+    hparams.intermediate_layers.resize(hparams.n_layer);
+    for (size_t i = 0; i < intermediate_layers_indices.size(); i++) {
+        hparams.intermediate_layers[intermediate_layers_indices[i]] = true;
+    }
+
+    if (verbosity >= 2) {
+        LOG("");
+        LOG("vision model hparams");
+        LOG("image_size         %d", hparams.image_size);
+        LOG("patch_size         %d", hparams.patch_size);
+        LOG("v_hidden_size      %d", hparams.hidden_size);
+        LOG("v_n_intermediate   %d", hparams.n_intermediate);
+        LOG("v_projection_dim   %d", hparams.projection_dim);
+        LOG("v_n_head           %d", hparams.n_head);
+        LOG("v_n_layer          %d", hparams.n_layer);
+        LOG("v_n_global_layer   %d", hparams.n_global_layer);
+        LOG("v_eps              %f", hparams.eps);
+    }
+
+    vision_model.class_embedding = mllama_tensor_load(new_mllama->ctx_data, "v.class_embd", true);
+    vision_model.patch_embeddings = mllama_tensor_load(new_mllama->ctx_data, "v.patch_embd.weight", true);
+
+    vision_model.position_embeddings = mllama_tensor_load(new_mllama->ctx_data, "v.position_embd.weight", true);
+    vision_model.position_embeddings_gate = mllama_tensor_load(new_mllama->ctx_data, "v.position_embd.gate", true);
+
+    vision_model.pre_ln_w = mllama_tensor_load(new_mllama->ctx_data, "v.pre_ln.weight", true);
+    vision_model.pre_ln_b = mllama_tensor_load(new_mllama->ctx_data, "v.pre_ln.bias", true);
+    vision_model.post_ln_w = mllama_tensor_load(new_mllama->ctx_data, "v.post_ln.weight", true);
+    vision_model.post_ln_b = mllama_tensor_load(new_mllama->ctx_data, "v.post_ln.bias", true);
+
+    vision_model.tile_position_embeddings = mllama_tensor_load(new_mllama->ctx_data, "v.tile_position_embd.weight", true);
+    vision_model.tile_position_embeddings_gate = mllama_tensor_load(new_mllama->ctx_data, "v.tile_position_embd.gate", true);
+
+    vision_model.pre_tile_position_embeddings = mllama_tensor_load(new_mllama->ctx_data, "v.pre_tile_position_embd.weight", true);
+    vision_model.pre_tile_position_embeddings_gate = mllama_tensor_load(new_mllama->ctx_data, "v.pre_tile_position_embd.gate", true);
+
+    vision_model.post_tile_position_embeddings = mllama_tensor_load(new_mllama->ctx_data, "v.post_tile_position_embd.weight", true);
+    vision_model.post_tile_position_embeddings_gate = mllama_tensor_load(new_mllama->ctx_data, "v.post_tile_position_embd.gate", true);
+
+    vision_model.mm_0_w = mllama_tensor_load(new_mllama->ctx_data, "mm.0.weight", false);
+    vision_model.mm_0_b = mllama_tensor_load(new_mllama->ctx_data, "mm.0.bias", false);
+
+    vision_model.layers = mllama_layers_load(new_mllama->ctx_data, "v", hparams.n_layer);
+    vision_model.global_layers = mllama_layers_load(new_mllama->ctx_data, "v.global", hparams.n_global_layer);
+
+    ggml_free(meta);
+
+    new_mllama->ctx_gguf = ctx;
+
+    {
+        // measure mem requirement and allocate
+        new_mllama->buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
+        new_mllama->compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(new_mllama->backend));
+        struct mllama_image_batch batch;
+        batch.size = 1;
+        ggml_cgraph *gf = mllama_image_build_graph(new_mllama, &batch);
+        ggml_gallocr_reserve(new_mllama->compute_alloc, gf);
+        size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_mllama->compute_alloc, 0);
+        LOG("compute allocated memory: %.2f MB", compute_memory_buffer_size / 1024.0 / 1024.0);
+    }
+
+    return new_mllama;
+}
+
+struct mllama_image *mllama_image_init() {
+    return new mllama_image();
+}
+
+void mllama_image_free(struct mllama_image *img) { delete img; }
+void mllama_image_batch_free(struct mllama_image_batch *batch) {
+    if (batch->size > 0) {
+        delete[] batch->data;
+        batch->size = 0;
+    }
+}
+
+bool mllama_image_load_from_data(const void *data, const int n, const int width, const int height, const int num_channels, const int num_tiles, const int aspect_ratio_id, struct mllama_image *img) {
+    img->width = width;
+    img->height = height;
+    img->num_channels = num_channels;
+    img->num_tiles = num_tiles;
+    img->aspect_ratio_id = aspect_ratio_id;
+    img->data.resize(n);
+
+    memcpy(img->data.data(), data, n);
+    return true;
+}
+
+inline int mllama(int x, int lower, int upper) {
+    return std::max(lower, std::min(x, upper));
+}
+
+void mllama_free(mllama_ctx *ctx) {
+    ggml_free(ctx->ctx_data);
+    gguf_free(ctx->ctx_gguf);
+
+    ggml_backend_buffer_free(ctx->params_buffer);
+    ggml_backend_free(ctx->backend);
+    ggml_gallocr_free(ctx->compute_alloc);
+    delete ctx;
+}
+
+bool mllama_image_encode(struct mllama_ctx *ctx, const int n_threads, mllama_image *img, float *vec) {
+    mllama_image_batch imgs{};
+    imgs.size = 1;
+    imgs.data = img;
+    return mllama_image_batch_encode(ctx, n_threads, &imgs, vec);
+}
+
+bool mllama_image_batch_encode(mllama_ctx *ctx, const int n_threads, const mllama_image_batch *imgs, float *vec) {
+    int batch_size = imgs->size;
+    REQUIRE(batch_size == 1);
+
+    // build the inference graph
+    ggml_cgraph *gf = mllama_image_build_graph(ctx, imgs);
+    ggml_gallocr_alloc_graph(ctx->compute_alloc, gf);
+
+    // set inputs
+    const auto &model = ctx->vision_model;
+    const auto &hparams = model.hparams;
+
+    const int image_size = hparams.image_size;
+    int image_size_width = image_size;
+    int image_size_height = image_size;
+
+    const int patch_size = hparams.patch_size;
+    const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
+    const int num_positions = num_patches + (model.class_embedding == nullptr ? 0 : 1);
+
+    {
+        struct ggml_tensor *inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
+        ggml_backend_tensor_set(inp_raw, imgs->data[0].data.data(), 0, ggml_nbytes(inp_raw));
+    }
+
+    {
+        struct ggml_tensor *embeddings = ggml_graph_get_tensor(gf, "embeddings");
+        if (embeddings != nullptr) {
+            void *zeros = malloc(ggml_nbytes(embeddings));
+            memset(zeros, 0, ggml_nbytes(embeddings));
+            ggml_backend_tensor_set(embeddings, zeros, 0, ggml_nbytes(embeddings));
+            free(zeros);
+        }
+    }
+
+    {
+        struct ggml_tensor *positions = ggml_graph_get_tensor(gf, "positions");
+        if (positions != nullptr) {
+            int *positions_data = (int *)malloc(ggml_nbytes(positions));
+            for (int i = 0; i < num_positions; i++) {
+                positions_data[i] = i;
+            }
+            ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
+            free(positions_data);
+        }
+    }
+
+    {
+        struct ggml_tensor *aspect_ratios = ggml_graph_get_tensor(gf, "aspect_ratios");
+        if (aspect_ratios != nullptr) {
+            int *aspect_ratios_data = (int *)malloc(ggml_nbytes(aspect_ratios));
+            aspect_ratios_data[0] = imgs->data[0].aspect_ratio_id;
+            ggml_backend_tensor_set(aspect_ratios, aspect_ratios_data, 0, ggml_nbytes(aspect_ratios));
+            free(aspect_ratios_data);
+        }
+    }
+
+    if (ggml_backend_is_cpu(ctx->backend)) {
+        ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
+    }
+
+    ggml_backend_graph_compute(ctx->backend, gf);
+
+    // the last node is the embedding tensor
+    struct ggml_tensor *embeddings = ggml_graph_node(gf, ggml_graph_n_nodes(gf) - 1);
+
+    // copy the embeddings to the location passed by the user
+    ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
+
+    return true;
+}
+
+int32_t mllama_image_size(const struct mllama_ctx *ctx) {
+    return ctx->vision_model.hparams.image_size;
+}
+
+int32_t mllama_patch_size(const struct mllama_ctx *ctx) {
+    return ctx->vision_model.hparams.patch_size;
+}
+
+int32_t mllama_hidden_size(const struct mllama_ctx *ctx) {
+    return ctx->vision_model.hparams.hidden_size;
+}
+
+int mllama_n_patches(const struct mllama_ctx *ctx) {
+    const auto &hparams = ctx->vision_model.hparams;
+    return (hparams.image_size / hparams.patch_size) * (hparams.image_size / hparams.patch_size);
+}
+
+int mllama_n_positions(const struct mllama_ctx *ctx) {
+    return mllama_n_patches(ctx) + (ctx->vision_model.class_embedding == nullptr ? 0 : 1);
+}
+
+int mllama_n_tiles(const struct mllama_ctx *ctx) {
+    return ctx->vision_model.hparams.n_tiles;
+}
+
+int mllama_n_embd(const struct mllama_ctx *ctx) {
+    return ctx->vision_model.hparams.projection_dim;
+}
+
+size_t mllama_n_embd_bytes(const struct mllama_ctx *ctx) {
+    return mllama_n_positions(ctx) * mllama_n_embd(ctx) * mllama_n_tiles(ctx) * sizeof(float);
+}

+ 61 - 0
llama/mllama.h

@@ -0,0 +1,61 @@
+#ifndef MLLAMA_H
+#define MLLAMA_H
+
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef LLAMA_SHARED
+#if defined(_WIN32) && !defined(__MINGW32__)
+#ifdef LLAMA_BUILD
+#define MLLAMA_API __declspec(dllexport)
+#else
+#define MLLAMA_API __declspec(dllimport)
+#endif
+#else
+#define MLLAMA_API __attribute__((visibility("default")))
+#endif
+#else
+#define MLLAMA_API
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct mllama_ctx;
+
+struct mllama_image_batch {
+    struct mllama_image *data;
+    size_t size;
+};
+
+MLLAMA_API struct mllama_ctx *mllama_model_load(const char *fname, int verbosity);
+MLLAMA_API struct mllama_ctx *mllama_model_load_cpu(const char *fname, int verbosity);
+
+MLLAMA_API void mllama_free(struct mllama_ctx *ctx);
+
+MLLAMA_API int32_t mllama_image_size(const struct mllama_ctx *ctx);
+MLLAMA_API int32_t mllama_patch_size(const struct mllama_ctx *ctx);
+MLLAMA_API int32_t mllama_hidden_size(const struct mllama_ctx *ctx);
+
+MLLAMA_API int mllama_n_patches(const struct mllama_ctx *ctx);
+MLLAMA_API int mllama_n_positions(const struct mllama_ctx *ctx);
+MLLAMA_API int mllama_n_tiles(const struct mllama_ctx *ctx);
+MLLAMA_API int mllama_n_embd(const struct mllama_ctx *ctx);
+MLLAMA_API size_t mllama_n_embd_bytes(const struct mllama_ctx *ctx);
+
+MLLAMA_API struct mllama_image *mllama_image_init();
+
+MLLAMA_API void mllama_image_free(struct mllama_image *img);
+MLLAMA_API void mllama_image_batch_free(struct mllama_image_batch *batch);
+
+MLLAMA_API bool mllama_image_load_from_data(const void *data, const int n, const int nx, const int ny, const int nc, const int nt, const int aspect_ratio_id, struct mllama_image *img);
+
+MLLAMA_API bool mllama_image_encode(struct mllama_ctx *ctx, int n_threads, struct mllama_image *img, float *vec);
+MLLAMA_API bool mllama_image_batch_encode(struct mllama_ctx *ctx, int n_threads, const struct mllama_image_batch *imgs, float *vec);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLLAMA_H

+ 690 - 0
llama/patches/0010-add-mllama-support.patch

@@ -0,0 +1,690 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: jmorganca <jmorganca@gmail.com>
+Date: Thu, 17 Oct 2024 15:18:22 -0700
+Subject: [PATCH] add mllama support
+
+mllama adds cross-attention layers to the standard llama architecture
+it also requires a way to input a new tensor: cross_attention_state
+once per generation
+
+cross-attention layers don't change and so they are cached in the
+kv cache once per run
+
+remaining is to implement the cross attention mask
+---
+ include/llama.h |   4 +
+ src/llama.cpp   | 456 ++++++++++++++++++++++++++++++++++++++++++++++--
+ 2 files changed, 447 insertions(+), 13 deletions(-)
+
+diff --git a/include/llama.h b/include/llama.h
+index 7cae1bbe..122e3cf1 100644
+--- a/include/llama.h
++++ b/include/llama.h
+@@ -423,6 +423,10 @@ extern "C" {
+                      struct llama_model * model,
+             struct llama_context_params   params);
+ 
++    // TODO (jmorganca): this should most likely be passed in as part of a batch
++    // and not set on the context for all batches.
++    LLAMA_API void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state);
++
+     // Frees all allocated memory
+     LLAMA_API void llama_free(struct llama_context * ctx);
+ 
+diff --git a/src/llama.cpp b/src/llama.cpp
+index 83b80b59..b189a19a 100644
+--- a/src/llama.cpp
++++ b/src/llama.cpp
+@@ -169,6 +169,7 @@ static std::string format(const char * fmt, ...) {
+ 
+ enum llm_arch {
+     LLM_ARCH_LLAMA,
++    LLM_ARCH_MLLAMA,
+     LLM_ARCH_FALCON,
+     LLM_ARCH_BAICHUAN,
+     LLM_ARCH_GROK,
+@@ -223,6 +224,7 @@ enum llm_arch {
+ 
+ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
+     { LLM_ARCH_LLAMA,           "llama"        },
++    { LLM_ARCH_MLLAMA,          "mllama"       },
+     { LLM_ARCH_FALCON,          "falcon"       },
+     { LLM_ARCH_GROK,            "grok"         },
+     { LLM_ARCH_GPT2,            "gpt2"         },
+@@ -330,6 +332,7 @@ enum llm_kv {
+     LLM_KV_ATTENTION_SLIDING_WINDOW,
+     LLM_KV_ATTENTION_SCALE,
+     LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
++    LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,
+ 
+     LLM_KV_ROPE_DIMENSION_COUNT,
+     LLM_KV_ROPE_FREQ_BASE,
+@@ -439,6 +442,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
+     { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"           },
+     { LLM_KV_ATTENTION_SCALE,                  "%s.attention.scale"                    },
+     { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,  "%s.attention.block_skip_connection.%d" },
++    { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers"   },
+ 
+     { LLM_KV_ROPE_DIMENSION_COUNT,          "%s.rope.dimension_count"                 },
+     { LLM_KV_ROPE_FREQ_BASE,                "%s.rope.freq_base"                       },
+@@ -613,6 +617,14 @@ enum llm_tensor {
+     LLM_TENSOR_CLS,
+     LLM_TENSOR_CLS_OUT,
+     LLM_TENSOR_BSKCN_TV,
++    LLM_TENSOR_CROSS_ATTN_K_NORM,
++    LLM_TENSOR_CROSS_ATTN_K_PROJ,
++    LLM_TENSOR_CROSS_ATTN_O_PROJ,
++    LLM_TENSOR_CROSS_ATTN_Q_NORM,
++    LLM_TENSOR_CROSS_ATTN_Q_PROJ,
++    LLM_TENSOR_CROSS_ATTN_V_PROJ,
++    LLM_TENSOR_CROSS_ATTN_ATTN_GATE,
++    LLM_TENSOR_CROSS_ATTN_MLP_GATE,
+ };
+ 
+ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
+@@ -642,6 +654,40 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
+             { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+         },
+     },
++    {
++        LLM_ARCH_MLLAMA,
++        {
++            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
++            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
++            { LLM_TENSOR_OUTPUT,          "output" },
++            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
++            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
++            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
++            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
++            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
++            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
++            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
++            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
++            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
++            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
++            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
++            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
++            { LLM_TENSOR_FFN_GATE_EXP,    "blk.%d.ffn_gate.%d" },
++            { LLM_TENSOR_FFN_DOWN_EXP,    "blk.%d.ffn_down.%d" },
++            { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
++            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
++            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
++            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
++            { LLM_TENSOR_CROSS_ATTN_K_NORM,    "blk.%d.cross_attn_k_norm" },
++            { LLM_TENSOR_CROSS_ATTN_K_PROJ,    "blk.%d.cross_attn_k_proj" },
++            { LLM_TENSOR_CROSS_ATTN_O_PROJ,    "blk.%d.cross_attn_o_proj" },
++            { LLM_TENSOR_CROSS_ATTN_Q_NORM,    "blk.%d.cross_attn_q_norm" },
++            { LLM_TENSOR_CROSS_ATTN_Q_PROJ,    "blk.%d.cross_attn_q_proj" },
++            { LLM_TENSOR_CROSS_ATTN_V_PROJ,    "blk.%d.cross_attn_v_proj" },
++            { LLM_TENSOR_CROSS_ATTN_ATTN_GATE, "blk.%d.cross_attn_attn_gate" },
++            { LLM_TENSOR_CROSS_ATTN_MLP_GATE,  "blk.%d.cross_attn_mlp_gate" },
++        },
++    },
+     {
+         LLM_ARCH_BAICHUAN,
+         {
+@@ -2390,6 +2436,7 @@ enum e_model {
+     MODEL_40B,
+     MODEL_65B,
+     MODEL_70B,
++    MODEL_90B,
+     MODEL_236B,
+     MODEL_314B,
+     MODEL_SMALL,
+@@ -2434,6 +2481,7 @@ struct llama_hparams {
+     std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
+ 
+     std::array<std::array<uint32_t, LLAMA_MAX_LAYERS>, 4> n_bskcn_arr;
++    std::array<uint32_t, LLAMA_MAX_LAYERS> cross_attn_layers;
+ 
+     uint32_t n_layer_dense_lead = 0;
+     uint32_t n_lora_q = 0;
+@@ -2502,10 +2550,11 @@ struct llama_hparams {
+         if (this->n_expert      != other.n_expert)      return true;
+         if (this->n_expert_used != other.n_expert_used) return true;
+ 
+-        if (this->n_head_arr    != other.n_head_arr)    return true;
+-        if (this->n_head_kv_arr != other.n_head_kv_arr) return true;
+-        if (this->n_ff_arr      != other.n_ff_arr)      return true;
+-        if (this->n_bskcn_arr   != other.n_bskcn_arr)   return true;
++        if (this->n_head_arr        != other.n_head_arr)        return true;
++        if (this->n_head_kv_arr     != other.n_head_kv_arr)     return true;
++        if (this->n_ff_arr          != other.n_ff_arr)          return true;
++        if (this->n_bskcn_arr       != other.n_bskcn_arr)       return true;
++        if (this->cross_attn_layers != other.cross_attn_layers) return true;
+ 
+         if (this->n_rel_attn_bkts    != other.n_rel_attn_bkts)    return true;
+         if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true;
+@@ -2623,6 +2672,10 @@ struct llama_hparams {
+ 
+         GGML_ABORT("fatal error");
+     }
++
++    bool cross_attention_layer(uint32_t il) const {
++        return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
++    }
+ };
+ 
+ static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
+@@ -2806,6 +2859,16 @@ struct llama_layer {
+     struct ggml_tensor * ffn_down_scale;
+ 
+     struct ggml_tensor * bskcn_tv;
++
++    // cross attention
++    struct ggml_tensor * cross_attn_k_norm;
++    struct ggml_tensor * cross_attn_k_proj;
++    struct ggml_tensor * cross_attn_o_proj;
++    struct ggml_tensor * cross_attn_q_norm;
++    struct ggml_tensor * cross_attn_q_proj;
++    struct ggml_tensor * cross_attn_v_proj;
++    struct ggml_tensor * cross_attn_attn_gate;
++    struct ggml_tensor * cross_attn_mlp_gate;
+ };
+ 
+ // very similar to llama_batch,
+@@ -3452,6 +3515,12 @@ struct llama_context {
+     struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
+     struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
+     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
++
++    // TODO (jmorganca): this should most likely be passed in as part of a batch
++    // and not set on the context for all batches.
++    float * cross_attn_state = nullptr;
++    bool cross_attn_state_first_pass = true;
++    struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061]
+ };
+ 
+ struct llama_lora_weight {
+@@ -3686,6 +3755,18 @@ static bool llama_kv_cache_init(
+     cache.v_l.reserve(n_layer);
+ 
+     for (int i = 0; i < (int) n_layer; i++) {
++        // for cross attention layers
++        if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layer(i)) {
++            struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
++            ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
++            ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
++            ggml_format_name(k, "cache_k_l%d", i);
++            ggml_format_name(v, "cache_v_l%d", i);
++            cache.k_l.push_back(k);
++            cache.v_l.push_back(v);
++            continue;
++        }
++
+         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
+         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+ 
+@@ -5460,12 +5541,14 @@ static void llm_load_hparams(
+     }
+ 
+     // zero-out the per-layer hparams
+-    std::fill(hparams.n_head_arr.begin(),    hparams.n_head_arr.end(),    0);
+-    std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
+-    std::fill(hparams.n_ff_arr.begin(),      hparams.n_ff_arr.end(),      0);
++    std::fill(hparams.n_head_arr.begin(),             hparams.n_head_arr.end(),        0);
++    std::fill(hparams.n_head_kv_arr.begin(),          hparams.n_head_kv_arr.end(),     0);
++    std::fill(hparams.n_ff_arr.begin(),               hparams.n_ff_arr.end(),          0);
++    std::fill(hparams.cross_attn_layers.begin(),      hparams.cross_attn_layers.end(), -1);
+ 
+-    ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff_arr,   hparams.n_layer);
+-    ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer);
++    ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH,       hparams.n_ff_arr,          hparams.n_layer);
++    ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT,      hparams.n_head_arr,        hparams.n_layer);
++    ml.get_arr(LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, hparams.cross_attn_layers, false);
+ 
+     // n_head_kv is optional, default to n_head
+     hparams.n_head_kv_arr = hparams.n_head_arr;
+@@ -5514,7 +5597,7 @@ static void llm_load_hparams(
+ 
+         ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
+ 
+-        if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
++        if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_MLLAMA || model.arch == LLM_ARCH_FALCON) {
+             if (hparams.n_rot != hparams.n_embd_head_k) {
+                 throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
+             }
+@@ -5554,6 +5637,16 @@ static void llm_load_hparams(
+                     }
+                 }
+             } break;
++        case LLM_ARCH_MLLAMA:
++            {
++                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
++
++                switch (hparams.n_layer) {
++                    case 40: model.type = e_model::MODEL_11B; break;
++                    case 100: model.type = e_model::MODEL_90B; break;
++                    default: model.type = e_model::MODEL_UNKNOWN;
++                }
++            } break;
+         case LLM_ARCH_MINICPM:
+             {
+                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+@@ -7249,6 +7342,55 @@ static bool llm_load_tensors(
+                         layer.rope_short = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                     }
+                 } break;
++            case LLM_ARCH_MLLAMA:
++                {
++                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab+8});
++
++                    // output
++                    {
++                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
++                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
++
++                        // if output is NULL, init from the input tok embed
++                        if (model.output == NULL) {
++                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
++                        }
++                    }
++
++                    for (int i = 0; i < n_layer; ++i) {
++                        ggml_context * ctx_layer = ctx_for_layer(i);
++                        ggml_context * ctx_split = ctx_for_layer_split(i);
++
++                        auto & layer = model.layers[i];
++
++                        if (hparams.cross_attention_layer(i)) {
++                            layer.cross_attn_k_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_NORM,   "weight", i), {128});
++                            layer.cross_attn_k_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_K_PROJ,   "weight", i), {n_embd, 1024});
++                            layer.cross_attn_o_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_O_PROJ,   "weight", i), {n_embd, n_embd});
++                            layer.cross_attn_q_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_Q_NORM, "weight", i), {128});
++                            layer.cross_attn_q_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_Q_PROJ, "weight", i), {n_embd, n_embd});
++                            layer.cross_attn_v_proj = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_V_PROJ, "weight", i), {n_embd, 1024});
++                            layer.cross_attn_attn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_ATTN_GATE, i), {1});
++                            layer.cross_attn_mlp_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_CROSS_ATTN_MLP_GATE, i), {1});
++                            layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
++                            layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
++                            layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
++                            layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
++                            layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
++                        } else {
++                            layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
++                            layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head});
++                            layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
++                            layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
++                            layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
++                            layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
++                            layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
++                            layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
++                            layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
++                            layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
++                        }
++                    }
++                } break;
+             case LLM_ARCH_GROK:
+                 {
+                     if (n_expert == 0) {
+@@ -9093,7 +9235,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
+ 
+         if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
+             model.hparams.n_vocab != model.vocab.id_to_token.size()) {
+-            throw std::runtime_error("vocab size mismatch");
++            LLAMA_LOG_WARN("%s: vocab mismatch %u !- %zu ...\n", __func__, model.hparams.n_vocab, model.vocab.id_to_token.size());
+         }
+ 
+         if (params.vocab_only) {
+@@ -9178,7 +9320,7 @@ static struct ggml_tensor * llm_build_inp_embd(
+ 
+         inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
+     } else {
+-       lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
++        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
+         inpL = lctx.inp_embd;
+         ggml_set_input(lctx.inp_embd);
+     }
+@@ -9193,6 +9335,22 @@ static struct ggml_tensor * llm_build_inp_embd(
+     return inpL;
+ }
+ 
++static struct ggml_tensor * llm_build_inp_cross_attn_state(
++        struct ggml_context * ctx,
++       struct llama_context & lctx,
++        const llama_hparams & hparams,
++         const llm_build_cb & cb) {
++    const int64_t n_embd = hparams.n_embd;
++
++    struct ggml_tensor * inpCAS;
++    lctx.inp_cross_attn_state = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1601, 4);
++    cb(lctx.inp_cross_attn_state, "inp_cross_attn_state", -1);
++    ggml_set_input(lctx.inp_cross_attn_state);
++    inpCAS = lctx.inp_cross_attn_state;
++
++    return inpCAS;
++}
++
+ static void llm_build_kv_store(
+         struct ggml_context * ctx,
+         const llama_hparams & hparams,
+@@ -10167,6 +10325,7 @@ struct llm_build_context {
+         lctx.inp_pos_bucket    = nullptr;
+         lctx.inp_embd_enc      = nullptr;
+         lctx.inp_KQ_mask_cross = nullptr;
++        lctx.inp_cross_attn_state = nullptr;
+     }
+ 
+     void free() {
+@@ -10754,6 +10913,253 @@ struct llm_build_context {
+                 LLM_NORM_RMS, cb, -1);
+         cb(cur, "result_norm", -1);
+ 
++        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
++        cb(cur, "result_output", -1);
++
++        ggml_build_forward_expand(gf, cur);
++
++        return gf;
++    }
++
++    struct ggml_cgraph * build_mllama() {
++        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
++
++        // mutable variable, needed during the last layer of the computation to skip unused tokens
++        int32_t n_tokens = this->n_tokens;
++
++        const int64_t n_embd_head = hparams.n_embd_head_v;
++        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
++        GGML_ASSERT(n_embd_head == hparams.n_rot);
++
++        struct ggml_tensor * cur;
++        struct ggml_tensor * inpL;
++        struct ggml_tensor * inpCAS;
++
++        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
++        inpCAS = llm_build_inp_cross_attn_state(ctx0, lctx, hparams, cb);
++
++        // inp_pos - contains the positions
++        struct ggml_tensor * inp_pos = build_inp_pos();
++
++        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
++        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
++
++        for (int il = 0; il < n_layer; ++il) {
++            struct ggml_tensor * inpSA = inpL;
++
++            // norm
++            cur = llm_build_norm(ctx0, inpL, hparams,
++                    model.layers[il].attn_norm, NULL,
++                    LLM_NORM_RMS, cb, il);
++            cb(cur, "attn_norm", il);
++
++            if (hparams.cross_attention_layer(il)) {
++                if (!lctx.cross_attn_state) {
++                    continue;
++                }
++
++                // cross attention layer
++                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_q_proj, cur);
++                cb(Qcur, "Qcur", il);
++
++                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
++                cb(Qcur, "Qcur", il);
++
++                Qcur = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
++                cb(Qcur, "Qcur", il);
++
++                // TODO: is this required?
++                Qcur = ggml_cont(ctx0, Qcur);
++                cb(Qcur, "Qcur", il);
++
++                Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
++                cb(Qcur, "Qcur", il);
++
++                struct ggml_tensor * Kcur;
++                if (lctx.cross_attn_state_first_pass) {
++                    Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS);
++                    cb(Kcur, "Kcur", il);
++
++                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404);
++                    cb(Kcur, "Kcur", il);
++
++                    Kcur = ggml_permute(ctx0, Kcur, 0, 2, 1, 3);
++                    cb(Kcur, "Kcur", il);
++
++                    // TODO: is this required?
++                    Kcur = ggml_cont(ctx0, Kcur);
++                    cb(Kcur, "Kcur", il);
++
++                    Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
++                    cb(Kcur, "Kcur", il);
++
++                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self.k_l[il]));
++                } else {
++                    Kcur = ggml_view_tensor(ctx0, kv_self.k_l[il]);
++                    cb(Kcur, "Kcur (view)", il);
++                }
++
++                struct ggml_tensor * Vcur;
++                if (lctx.cross_attn_state_first_pass) {
++                    Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS);
++                    cb(Vcur, "Vcur", il);
++
++                    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, 6404);
++                    cb(Vcur, "Vcur", il);
++
++                    Vcur = ggml_permute(ctx0, Vcur, 0, 2, 1, 3);
++                    cb(Vcur, "Vcur", il);
++
++                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self.v_l[il]));
++                } else {
++                    Vcur = ggml_view_tensor(ctx0, kv_self.v_l[il]);
++                    cb(Vcur, "Vcur (view)", il);
++                }
++
++                struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur);
++                cb(kq, "kq", il);
++
++                kq = ggml_scale_inplace(ctx0, kq, 1.0f/sqrtf(float(n_embd_head)));
++                cb(kq, "kq_scaled", il);
++
++                // TODO: apply causal masks
++                struct ggml_tensor * kq_soft_max = ggml_soft_max_inplace(ctx0, kq);
++                cb(kq_soft_max, "kq_soft_max", il);
++
++                Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur));
++                cb(Vcur, "Vcur", il);
++
++                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, Vcur, kq_soft_max);
++                cb(kqv, "kqv", il);
++
++                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
++                cb(kqv_merged, "kqv_merged", il);
++
++                cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
++                cb(cur, "kqv_merged_cont", il);
++
++                cur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_o_proj, cur);
++                cb(cur, "cur", il);
++
++                // TODO: do this in place once?
++                cur = ggml_mul(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_attn_gate));
++
++                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
++                cb(ffn_inp, "ffn_inp", il);
++
++                // feed-forward network
++                cur = llm_build_norm(ctx0, ffn_inp, hparams,
++                        model.layers[il].ffn_norm, NULL,
++                        LLM_NORM_RMS, cb, il);
++                cb(cur, "ffn_norm", il);
++
++                cur = llm_build_ffn(ctx0, lctx, cur,
++                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
++                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
++                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
++                        NULL,
++                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
++                cb(cur, "ffn_out", il);
++
++                // TODO: do this inplace once?
++                cur = ggml_add_inplace(ctx0, ggml_mul_inplace(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_mlp_gate)), ffn_inp);
++                cb(cur, "ffn_out", il);
++
++                cur = lctx.cvec.apply_to(ctx0, cur, il);
++                cb(cur, "l_out", il);
++
++                // input for next layer
++                inpL = cur;
++            } else {
++                // self attention layer
++
++                // rope freq factors for llama3; may return nullptr for llama2 and other models
++                struct ggml_tensor * rope_factors = build_rope_factors(il);
++
++                // compute Q and K and RoPE them
++                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
++                cb(Qcur, "Qcur", il);
++                if (model.layers[il].bq) {
++                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
++                    cb(Qcur, "Qcur", il);
++                }
++
++                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
++                cb(Kcur, "Kcur", il);
++                if (model.layers[il].bk) {
++                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
++                    cb(Kcur, "Kcur", il);
++                }
++
++                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
++                cb(Vcur, "Vcur", il);
++                if (model.layers[il].bv) {
++                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
++                    cb(Vcur, "Vcur", il);
++                }
++
++                Qcur = ggml_rope_ext(
++                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
++                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
++                    ext_factor, attn_factor, beta_fast, beta_slow
++                );
++                cb(Qcur, "Qcur", il);
++
++                Kcur = ggml_rope_ext(
++                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
++                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
++                    ext_factor, attn_factor, beta_fast, beta_slow
++                );
++                cb(Kcur, "Kcur", il);
++
++                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
++                        model.layers[il].wo, model.layers[il].bo,
++                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
++
++
++                if (il == n_layer - 1) {
++                    // skip computing output for unused tokens
++                    struct ggml_tensor * inp_out_ids = build_inp_out_ids();
++                    n_tokens = n_outputs;
++                    cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
++                    inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
++                }
++
++                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
++                cb(ffn_inp, "ffn_inp", il);
++
++                // feed-forward network
++                cur = llm_build_norm(ctx0, ffn_inp, hparams,
++                        model.layers[il].ffn_norm, NULL,
++                        LLM_NORM_RMS, cb, il);
++                cb(cur, "ffn_norm", il);
++
++                cur = llm_build_ffn(ctx0, lctx, cur,
++                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
++                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
++                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
++                        NULL,
++                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
++                cb(cur, "ffn_out", il);
++
++                cur = ggml_add(ctx0, cur, ffn_inp);
++                cb(cur, "ffn_out", il);
++
++                cur = lctx.cvec.apply_to(ctx0, cur, il);
++                cb(cur, "l_out", il);
++
++                // input for next layer
++                inpL = cur;
++            }
++        }
++
++        cur = inpL;
++
++        cur = llm_build_norm(ctx0, cur, hparams,
++                model.output_norm, NULL,
++                LLM_NORM_RMS, cb, -1);
++        cb(cur, "result_norm", -1);
++
+         // lm_head
+         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+         cb(cur, "result_output", -1);
+@@ -16501,6 +16907,10 @@ static struct ggml_cgraph * llama_build_graph(
+             {
+                 result = llm.build_llama();
+             } break;
++        case LLM_ARCH_MLLAMA:
++            {
++                result = llm.build_mllama();
++            } break;
+         case LLM_ARCH_BAICHUAN:
+             {
+                 result = llm.build_baichuan();
+@@ -16773,6 +17183,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
+         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
+     }
+ 
++    // TODO (jmorganca): this might copy a lot of data on every request of a
++    // single generation even though it doesn't change, so we should
++    // find a way to not set this more than one time per image
++    if (lctx.inp_cross_attn_state &&
++        lctx.inp_cross_attn_state->buffer) {
++        ggml_backend_tensor_set(lctx.inp_cross_attn_state, lctx.cross_attn_state, 0, hparams.n_embd * 1601 * 4 * ggml_element_size(lctx.inp_cross_attn_state));
++    }
++
+     if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
+         GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
+         const int64_t n_tokens = batch.n_tokens;
+@@ -17455,6 +17873,10 @@ static int llama_decode_internal(
+ 
+         llama_set_inputs(lctx, ubatch);
+ 
++        // TODO: replace with something better to find out if its
++        // our first actual pass
++        lctx.cross_attn_state_first_pass = false;
++
+         llama_graph_compute(lctx, gf, n_threads, threadpool);
+ 
+         // update the kv ring buffer
+@@ -18648,7 +19070,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
+         if (llama_model_has_encoder(&model)) {
+             n_attn_layer *= 3;
+         }
+-        GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
++        if (qs.n_attention_wv != n_attn_layer) {
++            LLAMA_LOG_WARN("%s: n_attention_wv is unexpected, expected: %d, found: %d\n", __func__, n_attn_layer, qs.n_attention_wv);
++        }
+     }
+ 
+     size_t total_size_org = 0;
+@@ -19744,6 +20168,11 @@ struct llama_context * llama_new_context_with_model(
+     return ctx;
+ }
+ 
++void llama_set_cross_attn_state(struct llama_context * ctx, float * cross_attn_state) {
++    ctx->cross_attn_state_first_pass = true;
++    ctx->cross_attn_state = cross_attn_state;
++}
++
+ void llama_free(struct llama_context * ctx) {
+     delete ctx;
+ }
+@@ -19814,6 +20243,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
+ 
+         // use what we call a normal RoPE, operating on pairs of consecutive head values
+         case LLM_ARCH_LLAMA:
++        case LLM_ARCH_MLLAMA:
+         case LLM_ARCH_BAICHUAN:
+         case LLM_ARCH_STARCODER:
+         case LLM_ARCH_PLAMO:

+ 409 - 0
llama/patches/0011-add-unpad-operator.patch

@@ -0,0 +1,409 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: Michael Yang <mxyng@pm.me>
+Date: Thu, 17 Oct 2024 17:19:25 -0700
+Subject: [PATCH] add unpad operator
+
+---
+ ggml/include/ggml.h        | 10 ++++
+ ggml/src/ggml-cuda.cu      |  4 ++
+ ggml/src/ggml-cuda/pad.cu  | 46 +++++++++++++++++++
+ ggml/src/ggml-cuda/pad.cuh |  1 +
+ ggml/src/ggml-metal.m      | 33 ++++++++++++++
+ ggml/src/ggml-metal.metal  | 45 ++++++++++++++++++
+ ggml/src/ggml.c            | 93 +++++++++++++++++++++++++++++++++++++-
+ 7 files changed, 230 insertions(+), 2 deletions(-)
+
+diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
+index ce3d92cb..962cb5f7 100644
+--- a/ggml/include/ggml.h
++++ b/ggml/include/ggml.h
+@@ -506,6 +506,7 @@ extern "C" {
+         GGML_OP_POOL_2D_BACK,
+         GGML_OP_UPSCALE, // nearest interpolate
+         GGML_OP_PAD,
++        GGML_OP_UNPAD,
+         GGML_OP_ARANGE,
+         GGML_OP_TIMESTEP_EMBEDDING,
+         GGML_OP_ARGSORT,
+@@ -1764,6 +1765,15 @@ extern "C" {
+             int                  p2,
+             int                  p3);
+ 
++    // unpad each dimension: [x, ..., x, y, ..., y] -> [x, ..., x]
++    GGML_API struct ggml_tensor * ggml_unpad(
++            struct ggml_context * ctx,
++            struct ggml_tensor  * a,
++            int                  p0,
++            int                  p1,
++            int                  p2,
++            int                  p3);
++
+     // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
+     // timesteps: [N,]
+     // return: [N, dim]
+diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
+index fe77b81c..6e84af56 100644
+--- a/ggml/src/ggml-cuda.cu
++++ b/ggml/src/ggml-cuda.cu
+@@ -2270,6 +2270,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
+         case GGML_OP_PAD:
+             ggml_cuda_op_pad(ctx, dst);
+             break;
++        case GGML_OP_UNPAD:
++            ggml_cuda_op_unpad(ctx, dst);
++            break;
+         case GGML_OP_ARANGE:
+             ggml_cuda_op_arange(ctx, dst);
+             break;
+@@ -2992,6 +2995,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
+         case GGML_OP_GROUP_NORM:
+         case GGML_OP_UPSCALE:
+         case GGML_OP_PAD:
++        case GGML_OP_UNPAD:
+         case GGML_OP_ARANGE:
+         case GGML_OP_TIMESTEP_EMBEDDING:
+         case GGML_OP_LEAKY_RELU:
+diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu
+index aba539e8..39fd4b16 100644
+--- a/ggml/src/ggml-cuda/pad.cu
++++ b/ggml/src/ggml-cuda/pad.cu
+@@ -47,3 +47,49 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+         dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+ }
++
++static __global__ void unpad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
++    // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
++    // blockIdx.y: idx of ne1
++    // blockIDx.x: idx of ne0 / BLOCK_SIZE
++    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
++    if (nidx >= ne0) {
++        return;
++    }
++
++    // operation
++    int offset_dst =
++        nidx +
++        blockIdx.y * ne0 +
++        blockIdx.z * ne0 * gridDim.y;
++    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
++        int offset_src =
++            nidx +
++            blockIdx.y * ne00 +
++            blockIdx.z * ne00 * ne01;
++        dst[offset_dst] = x[offset_src];
++    }
++}
++
++static void unpad_f32_cuda(const float * x, float * dst,
++    const int ne00, const int ne01, const int ne02, const int ne03,
++    const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
++    int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
++    dim3 gridDim(num_blocks, ne1, ne2*ne3);
++    unpad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
++}
++
++void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
++    const ggml_tensor * src0 = dst->src[0];
++    const float * src0_d = (const float *)src0->data;
++    float * dst_d = (float *)dst->data;
++    cudaStream_t stream = ctx.stream();
++
++    GGML_ASSERT(src0->type == GGML_TYPE_F32);
++    GGML_ASSERT(dst->type == GGML_TYPE_F32);
++    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
++
++    unpad_f32_cuda(src0_d, dst_d,
++        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
++        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
++}
+diff --git a/ggml/src/ggml-cuda/pad.cuh b/ggml/src/ggml-cuda/pad.cuh
+index 8fd386b0..e2ededc3 100644
+--- a/ggml/src/ggml-cuda/pad.cuh
++++ b/ggml/src/ggml-cuda/pad.cuh
+@@ -3,3 +3,4 @@
+ #define CUDA_PAD_BLOCK_SIZE 256
+ 
+ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
++void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
+index 829c5e39..25702d85 100644
+--- a/ggml/src/ggml-metal.m
++++ b/ggml/src/ggml-metal.m
+@@ -193,6 +193,7 @@
+     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
+     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
+     GGML_METAL_KERNEL_TYPE_PAD_F32,
++    GGML_METAL_KERNEL_TYPE_UNPAD_F32,
+     GGML_METAL_KERNEL_TYPE_ARANGE_F32,
+     GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
+     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
+@@ -689,6 +690,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                    im2col_f32,                     true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                       pad_f32,                        true);
++        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UNPAD_F32,                     unpad_f32,                        true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,        timestep_embedding_f32,         true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32,                    arange_f32,                     true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,           argsort_f32_i32_asc,            true);
+@@ -846,6 +848,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
+             return false;
+         case GGML_OP_UPSCALE:
+         case GGML_OP_PAD:
++        case GGML_OP_UNPAD:
+         case GGML_OP_ARANGE:
+         case GGML_OP_TIMESTEP_EMBEDDING:
+         case GGML_OP_ARGSORT:
+@@ -2655,6 +2658,36 @@ static void ggml_metal_encode_node(
+ 
+                 const int nth = MIN(1024, ne0);
+ 
++                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
++            } break;
++        case GGML_OP_UNPAD:
++            {
++                GGML_ASSERT(src0->type == GGML_TYPE_F32);
++
++                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UNPAD_F32].pipeline;
++
++                [encoder setComputePipelineState:pipeline];
++                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
++                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
++                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
++                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
++                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
++                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
++                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
++                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
++                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
++                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
++                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
++                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
++                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
++                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
++                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
++                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
++                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
++                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
++
++                const int nth = MIN(1024, ne0);
++
+                 [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+             } break;
+         case GGML_OP_ARANGE:
+diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
+index 2b200032..09887511 100644
+--- a/ggml/src/ggml-metal.metal
++++ b/ggml/src/ggml-metal.metal
+@@ -2029,6 +2029,51 @@ kernel void kernel_pad_f32(
+     }
+ }
+ 
++kernel void kernel_unpad_f32(
++    device  const char * src0,
++    device        char * dst,
++    constant   int64_t & ne00,
++    constant   int64_t & ne01,
++    constant   int64_t & ne02,
++    constant   int64_t & ne03,
++    constant  uint64_t & nb00,
++    constant  uint64_t & nb01,
++    constant  uint64_t & nb02,
++    constant  uint64_t & nb03,
++    constant   int64_t & ne0,
++    constant   int64_t & ne1,
++    constant   int64_t & ne2,
++    constant   int64_t & ne3,
++    constant  uint64_t & nb0,
++    constant  uint64_t & nb1,
++    constant  uint64_t & nb2,
++    constant  uint64_t & nb3,
++    uint3 tgpig[[threadgroup_position_in_grid]],
++    uint3 tpitg[[thread_position_in_threadgroup]],
++    uint3   ntg[[threads_per_threadgroup]]) {
++
++    const int64_t i3 = tgpig.z;
++    const int64_t i2 = tgpig.y;
++    const int64_t i1 = tgpig.x;
++
++    const int64_t i03 = i3;
++    const int64_t i02 = i2;
++    const int64_t i01 = i1;
++
++    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
++    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
++
++    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
++        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
++            if (i0 < ne00) {
++                dst_ptr[i0] = src0_ptr[i0];
++            }
++        }
++
++        return;
++    }
++}
++
+ kernel void kernel_arange_f32(
+     device        char * dst,
+     constant   int64_t & ne0,
+diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
+index bcbc32d9..f4864ac8 100644
+--- a/ggml/src/ggml.c
++++ b/ggml/src/ggml.c
+@@ -2997,6 +2997,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
+     "POOL_2D_BACK",
+     "UPSCALE",
+     "PAD",
++    "UNPAD",
+     "ARANGE",
+     "TIMESTEP_EMBEDDING",
+     "ARGSORT",
+@@ -3030,7 +3031,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
+     "OPT_STEP_ADAMW",
+ };
+ 
+-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
++static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
+ 
+ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
+     "none",
+@@ -3091,6 +3092,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
+     "pool_2d_back(x)",
+     "upscale(x)",
+     "pad(x)",
++    "unpad(x)",
+     "arange(start, stop, step)",
+     "timestep_embedding(timesteps, dim, max_period)",
+     "argsort(x)",
+@@ -3124,7 +3126,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
+     "adamw(x)",
+ };
+ 
+-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
++static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
+ 
+ static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
+ 
+@@ -6955,6 +6957,32 @@ struct ggml_tensor * ggml_pad(
+     return result;
+ }
+ 
++// ggml_unpad
++
++struct ggml_tensor * ggml_unpad(
++    struct ggml_context * ctx,
++    struct ggml_tensor  * a,
++    int p0, int p1, int p2, int p3) {
++    bool is_node = false;
++
++    if (a->grad) {
++        GGML_ABORT("fatal error"); // TODO: implement backward
++        is_node = true;
++    }
++
++    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
++            a->ne[0] - p0,
++            a->ne[1] - p1,
++            a->ne[2] - p2,
++            a->ne[3] - p3);
++
++    result->op = GGML_OP_UNPAD;
++    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
++    result->src[0] = a;
++
++    return result;
++}
++
+ // ggml_arange
+ 
+ struct ggml_tensor * ggml_arange(
+@@ -15312,6 +15340,58 @@ static void ggml_compute_forward_pad(
+     }
+ }
+ 
++static void ggml_compute_forward_unpad_f32(
++    const struct ggml_compute_params *params,
++    struct ggml_tensor *dst) {
++
++    const struct ggml_tensor * src0 = dst->src[0];
++
++    GGML_ASSERT(src0->nb[0] == sizeof(float));
++    GGML_ASSERT( dst->nb[0] == sizeof(float));
++
++    const int ith = params->ith;
++    const int nth = params->nth;
++
++    GGML_TENSOR_UNARY_OP_LOCALS
++
++    float * dst_ptr = (float *) dst->data;
++
++    // TODO: optimize
++
++    for (int64_t i2 = 0; i2 < ne2; ++i2) {
++        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
++            for (int64_t i0 = 0; i0 < ne0; ++i0) {
++                for (int64_t i3 = 0; i3 < ne3; ++i3) {
++                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
++
++                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
++
++                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
++                        dst_ptr[dst_idx] = *src_ptr;
++                    }
++                }
++            }
++        }
++    }
++}
++
++static void ggml_compute_forward_unpad(
++    const struct ggml_compute_params * params,
++    struct ggml_tensor * dst) {
++
++    const struct ggml_tensor * src0 = dst->src[0];
++
++    switch (src0->type) {
++        case GGML_TYPE_F32:
++            {
++                ggml_compute_forward_unpad_f32(params, dst);
++            } break;
++        default:
++            {
++                GGML_ABORT("fatal error");
++            }
++    }
++}
+ 
+ // ggml_compute_forward_arange
+ 
+@@ -17294,6 +17374,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
+             {
+                 ggml_compute_forward_pad(params, tensor);
+             } break;
++        case GGML_OP_UNPAD:
++            {
++                ggml_compute_forward_unpad(params, tensor);
++            } break;
+         case GGML_OP_ARANGE:
+             {
+                 ggml_compute_forward_arange(params, tensor);
+@@ -18369,6 +18453,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
+             {
+                 GGML_ABORT("fatal error"); // TODO: not implemented
+             }
++        case GGML_OP_UNPAD:
++            {
++                GGML_ABORT("fatal error"); // TODO: not implemented
++            }
+         case GGML_OP_ARANGE:
+             {
+                 GGML_ABORT("fatal error"); // TODO: not implemented
+@@ -19165,6 +19253,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
+             } break;
+         case GGML_OP_UPSCALE:
+         case GGML_OP_PAD:
++        case GGML_OP_UNPAD:
+         case GGML_OP_ARANGE:
+         case GGML_OP_TIMESTEP_EMBEDDING:
+         case GGML_OP_ARGSORT:

+ 31 - 3
llama/runner/runner.go

@@ -206,6 +206,26 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
 		}
 	}
 
+	if s.clip.cc != nil {
+		var embed [][]float32
+
+		if s.clip.cc.IsMllama && len(images) >= 1 {
+			hash := s.cache.HashImage(images[0].Data)
+
+			s.clip.mu.Lock()
+			var err error
+			embed, err = s.cache.FindImage(hash)
+			if err != nil {
+				embed = llama.NewMllamaImageEmbed(s.lc, s.clip.cc, images[0].Data, images[0].AspectRatioID)
+				s.cache.AddImage(hash, embed)
+			}
+			s.clip.mu.Unlock()
+		}
+		s.mu.Lock()
+		llama.MllamaSetCrossAttn(s.lc, s.clip.cc, embed)
+		s.mu.Unlock()
+	}
+
 	return inputs, nil
 }
 
@@ -294,6 +314,9 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
 	close(seq.responses)
 	close(seq.embedding)
 	seq.cache.InUse = false
+	if s.clip.cc != nil {
+		llama.MllamaSetCrossAttn(s.lc, s.clip.cc, nil)
+	}
 	s.seqs[seqIndex] = nil
 }
 
@@ -517,8 +540,9 @@ type Options struct {
 }
 
 type ImageData struct {
-	Data []byte `json:"data"`
-	ID   int    `json:"id"`
+	Data          []byte `json:"data"`
+	ID            int    `json:"id"`
+	AspectRatioID int    `json:"aspect_ratio_id"`
 }
 
 type CompletionRequest struct {
@@ -770,7 +794,11 @@ func (s *Server) loadModel(
 	}
 
 	if ppath != "" {
-		s.clip.cc = llama.NewClipContext(ppath)
+		var err error
+		s.clip.cc, err = llama.NewClipContext(ppath)
+		if err != nil {
+			panic(err)
+		}
 	}
 
 	s.cache = NewInputCache(s.lc, kvSize, s.parallel, multiUserCache)

+ 2 - 2
llm/ggla.go

@@ -51,8 +51,8 @@ func (llm *ggla) KV() KV {
 	return llm.kv
 }
 
-func (llm *ggla) Tensors() Tensors {
-	return Tensors{
+func (llm *ggla) Tensors() *Tensors {
+	return &Tensors{
 		Items:  llm.tensors,
 		Offset: llm.tensorOffset,
 	}

+ 27 - 16
llm/ggml.go

@@ -5,7 +5,9 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"slices"
 	"strings"
+	"sync"
 
 	"github.com/ollama/ollama/util/bufioutil"
 )
@@ -17,7 +19,7 @@ type GGML struct {
 
 type model interface {
 	KV() KV
-	Tensors() Tensors
+	Tensors() *Tensors
 }
 
 type KV map[string]any
@@ -123,25 +125,34 @@ func (kv KV) ChatTemplate() string {
 type Tensors struct {
 	Items  []*Tensor
 	Offset uint64
-}
 
-func (ts Tensors) Layers() map[string]Layer {
-	layers := make(map[string]Layer)
-	for _, t := range ts.Items {
-		parts := strings.Split(t.Name, ".")
-		if parts[0] == "blk" {
-			// join first and second part, e.g. blk.%d
-			parts = append([]string{fmt.Sprintf("%s.%s", parts[0], parts[1])}, parts[2:]...)
-		}
+	layers     map[string]Layer
+	layersOnce sync.Once
+}
 
-		if _, ok := layers[parts[0]]; !ok {
-			layers[parts[0]] = make(Layer)
+func (ts *Tensors) Layers() map[string]Layer {
+	ts.layersOnce.Do(func() {
+		ts.layers = make(map[string]Layer)
+		for _, t := range ts.Items {
+			parts := strings.Split(t.Name, ".")
+			if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 {
+				if len(parts) > index+2 {
+					// blk and mm should have a number after them, join it
+					parts = append(
+						[]string{strings.Join(parts[:index+2], ".")},
+						parts[index+2:]...)
+				}
+			}
+
+			if _, ok := ts.layers[parts[0]]; !ok {
+				ts.layers[parts[0]] = make(Layer)
+			}
+
+			ts.layers[parts[0]][strings.Join(parts[1:], ".")] = t
 		}
+	})
 
-		layers[parts[0]][strings.Join(parts[1:], ".")] = t
-	}
-
-	return layers
+	return ts.layers
 }
 
 type Layer map[string]*Tensor

+ 2 - 2
llm/gguf.go

@@ -110,8 +110,8 @@ func (llm *gguf) KV() KV {
 	return llm.kv
 }
 
-func (llm *gguf) Tensors() Tensors {
-	return Tensors{
+func (llm *gguf) Tensors() *Tensors {
+	return &Tensors{
 		Items:  llm.tensors,
 		Offset: llm.tensorOffset,
 	}

+ 73 - 4
llm/memory.go

@@ -3,6 +3,7 @@ package llm
 import (
 	"fmt"
 	"log/slog"
+	"os"
 	"strconv"
 	"strings"
 
@@ -63,6 +64,8 @@ type MemoryEstimate struct {
 	memoryLayerOutput   uint64
 	graphFullOffload    uint64
 	graphPartialOffload uint64
+
+	projectorWeights, projectorGraph uint64
 }
 
 // Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
@@ -78,7 +81,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string,
 	var graphOffload uint64
 
 	// Projectors loaded into GPU0 only
-	var projectorSize uint64
+	var projectorWeights uint64
+	var projectorGraph uint64
 
 	// Conditional output size on GPU 0
 	var memoryLayerOutput uint64
@@ -103,7 +107,9 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string,
 	slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", availableList)
 
 	for _, projector := range projectors {
-		projectorSize += projectorMemoryRequirements(projector)
+		weight, graph := projectorMemoryRequirements(projector)
+		projectorWeights += weight
+		projectorGraph += graph
 
 		// multimodal models require at least 2048 context
 		opts.NumCtx = max(opts.NumCtx, 2048)
@@ -149,7 +155,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string,
 	}
 
 	// Output layer handled at the end if we have space
-	gpuZeroOverhead := projectorSize
+	gpuZeroOverhead := projectorWeights + projectorGraph
 
 	// Reduce set of GPUs to only those that have sufficient space to fit overhead and at least one layer
 	var layerCount int
@@ -303,6 +309,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string,
 		memoryLayerOutput:   memoryLayerOutput,
 		graphFullOffload:    graphFullOffload,
 		graphPartialOffload: graphPartialOffload,
+		projectorWeights:    projectorWeights,
+		projectorGraph:      projectorGraph,
 	}
 
 	if gpus[0].Library == "cpu" {
@@ -323,7 +331,19 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string,
 
 func (m MemoryEstimate) log() {
 	overhead := envconfig.GpuOverhead()
-	slog.Info(
+
+	log := slog.With()
+	if m.projectorWeights > 0 {
+		log = log.With(
+			slog.Group(
+				"projector",
+				"weights", format.HumanBytes2(m.projectorWeights),
+				"graph", format.HumanBytes2(m.projectorGraph),
+			),
+		)
+	}
+
+	log.Info(
 		"offload to "+m.inferenceLibrary,
 		slog.Group(
 			"layers",
@@ -371,3 +391,52 @@ func (m MemoryEstimate) log() {
 		),
 	)
 }
+
+func projectorMemoryRequirements(filename string) (weights, graphSize uint64) {
+	file, err := os.Open(filename)
+	if err != nil {
+		return 0, 0
+	}
+	defer file.Close()
+
+	ggml, _, err := DecodeGGML(file, 0)
+	if err != nil {
+		return 0, 0
+	}
+
+	for _, layer := range ggml.Tensors().Layers() {
+		weights += layer.size()
+	}
+
+	switch arch := ggml.KV().Architecture(); arch {
+	case "mllama":
+		kv := func(n string) uint64 {
+			if v, ok := ggml.KV()[arch+".vision."+n].(uint32); ok {
+				return uint64(v)
+			}
+
+			return 0
+		}
+
+		imageSize := kv("image_size")
+
+		maxNumTiles := kv("max_num_tiles")
+		embeddingLength := kv("embedding_length")
+		headCount := kv("attention.head_count")
+
+		numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
+		if _, ok := ggml.Tensors().Layers()["v"]["class_embd"]; ok {
+			numPatches++
+		}
+
+		numPaddedPatches := numPatches + 8 - (numPatches%8)%8
+
+		graphSize = 4 * (8 +
+			imageSize*imageSize*kv("num_channels")*maxNumTiles +
+			embeddingLength*numPatches*maxNumTiles +
+			9*embeddingLength*numPaddedPatches*maxNumTiles +
+			numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
+	}
+
+	return weights, graphSize
+}

+ 3 - 22
llm/server.go

@@ -442,26 +442,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
 	return nil, finalErr
 }
 
-func projectorMemoryRequirements(filename string) uint64 {
-	file, err := os.Open(filename)
-	if err != nil {
-		return 0
-	}
-	defer file.Close()
-
-	ggml, _, err := DecodeGGML(file, 0)
-	if err != nil {
-		return 0
-	}
-
-	var mem uint64
-	for _, layer := range ggml.Tensors().Layers() {
-		mem += layer.size()
-	}
-
-	return mem
-}
-
 type ServerStatus int
 
 const ( // iota is reset to 0
@@ -673,8 +653,9 @@ ws ::= ([ \t\n] ws)?
 const maxBufferSize = 512 * format.KiloByte
 
 type ImageData struct {
-	Data []byte `json:"data"`
-	ID   int    `json:"id"`
+	Data          []byte `json:"data"`
+	ID            int    `json:"id"`
+	AspectRatioID int    `json:"aspect_ratio_id"`
 }
 
 type completion struct {

+ 240 - 0
server/imageproc/images.go

@@ -0,0 +1,240 @@
+package imageproc
+
+import (
+	"bytes"
+	"fmt"
+	"image"
+	"image/color"
+	_ "image/jpeg"
+	_ "image/png"
+	"math"
+	"slices"
+
+	"golang.org/x/image/draw"
+)
+
+func GetSupportedAspectRatios(maxTiles int) []image.Point {
+	ratios := []image.Point{}
+
+	for w := range maxTiles {
+		for h := range maxTiles {
+			if (w+1)*(h+1) <= maxTiles {
+				ratios = append(ratios, image.Point{w + 1, h + 1})
+			}
+		}
+	}
+
+	return ratios
+}
+
+func clip(a, a_min, a_max int) int {
+	if a < a_min {
+		return a_min
+	} else if a > a_max {
+		return a_max
+	}
+
+	return a
+}
+
+func getImageSizeFitToCanvas(imageSize, canvasSize image.Point, tileSize int) image.Point {
+	targetWidth := clip(imageSize.X, tileSize, canvasSize.X)
+	targetHeight := clip(imageSize.Y, tileSize, canvasSize.Y)
+
+	scaleWidth := float64(targetWidth) / float64(imageSize.X)
+	scaleHeight := float64(targetHeight) / float64(imageSize.Y)
+
+	var w, h int
+
+	if scaleWidth < scaleHeight {
+		w = targetWidth
+		h = min(int(math.Floor(float64(imageSize.Y)*scaleWidth)), targetHeight)
+	} else {
+		w = min(int(math.Floor(float64(imageSize.X)*scaleHeight)), targetWidth)
+		h = targetHeight
+	}
+
+	return image.Point{w, h}
+}
+
+func getOptimalTiledCanvas(imageSize image.Point, maxImageTiles, tileSize int) image.Point {
+	possibleTileArrangements := GetSupportedAspectRatios(maxImageTiles)
+	possibleCanvasSizes := []image.Point{}
+	for _, pta := range possibleTileArrangements {
+		possibleCanvasSizes = append(possibleCanvasSizes, image.Point{pta.X * tileSize, pta.Y * tileSize})
+	}
+
+	scales := []float64{}
+
+	for _, pcs := range possibleCanvasSizes {
+		scaleHeight := float64(pcs.Y) / float64(imageSize.Y)
+		scaleWidth := float64(pcs.X) / float64(imageSize.X)
+
+		if scaleWidth > scaleHeight {
+			scales = append(scales, scaleHeight)
+		} else {
+			scales = append(scales, scaleWidth)
+		}
+	}
+
+	var minUpscale float64
+	var maxDownscale float64
+	var upscale bool
+
+	for _, s := range scales {
+		if s > 1.0 {
+			upscale = true
+			if minUpscale == 0 {
+				minUpscale = s
+			} else {
+				minUpscale = math.Min(minUpscale, s)
+			}
+		} else {
+			maxDownscale = math.Max(maxDownscale, s)
+		}
+	}
+
+	selectedScale := maxDownscale
+	if upscale {
+		selectedScale = minUpscale
+	}
+
+	var selectedCanvas image.Point
+	for n, pcs := range possibleCanvasSizes {
+		if scales[n] == selectedScale {
+			// choose the smallest possible canvas
+			if selectedCanvas.X == 0 && selectedCanvas.Y == 0 {
+				selectedCanvas = pcs
+			} else if pcs.X*pcs.Y < selectedCanvas.X*selectedCanvas.Y {
+				selectedCanvas = pcs
+			}
+		}
+	}
+	return selectedCanvas
+}
+
+func splitToTiles(img image.Image, numTilesSize image.Point) []image.Image {
+	b := img.Bounds()
+	width := b.Max.X - b.Min.X
+	height := b.Max.Y - b.Min.Y
+	tileHeight := height / numTilesSize.Y
+	tileWidth := width / numTilesSize.X
+
+	images := []image.Image{}
+
+	for h := range numTilesSize.Y {
+		for w := range numTilesSize.X {
+			rect := image.Rect(tileWidth*w, tileHeight*h, tileWidth*(w+1), tileHeight*(h+1))
+			images = append(images, img.(interface {
+				SubImage(image.Rectangle) image.Image
+			}).SubImage(rect))
+		}
+	}
+
+	return images
+}
+
+// remove the "alpha" channel by drawing over a prefilled image
+func compositeImage(img image.Image) image.Image {
+	dst := image.NewRGBA(img.Bounds())
+
+	white := color.RGBA{255, 255, 255, 255}
+	draw.Draw(dst, dst.Bounds(), &image.Uniform{white}, image.Point{}, draw.Src)
+	draw.Draw(dst, dst.Bounds(), img, img.Bounds().Min, draw.Over)
+
+	return dst
+}
+
+func ResizeImage(img image.Image, format string, outputSize image.Point, maxImageTiles int) (image.Image, image.Point) {
+	if format == "png" {
+		img = compositeImage(img)
+	}
+
+	b := img.Bounds()
+	tileSize := outputSize.Y
+
+	canvasSize := getOptimalTiledCanvas(b.Max, maxImageTiles, tileSize)
+	aspectRatio := image.Point{canvasSize.X / tileSize, canvasSize.Y / tileSize}
+	newSize := getImageSizeFitToCanvas(b.Max, canvasSize, tileSize)
+
+	dst := image.NewRGBA(image.Rect(0, 0, newSize.X, newSize.Y))
+
+	// scaling choices:
+	//   NearestNeighbor	fast, blocky output
+	//   ApproxBiLinear	fast, medium quality
+	//   BiLinear		slow, high quality
+	//   CatmullRom		very slow, very high quality
+	draw.BiLinear.Scale(dst, dst.Rect, img, b, draw.Over, nil)
+
+	return dst, aspectRatio
+}
+
+func PadImage(img image.Image, outputSize, aspectRatio image.Point) image.Image {
+	paddedSize := image.Point{
+		X: outputSize.X * aspectRatio.X,
+		Y: outputSize.Y * aspectRatio.Y,
+	}
+
+	dst := image.NewRGBA(image.Rect(0, 0, paddedSize.X, paddedSize.Y))
+	draw.Draw(dst, img.Bounds(), img, image.Point{0, 0}, draw.Over)
+
+	return dst
+}
+
+func PackImages(img image.Image, aspectRatio image.Point, mean, std [3]float32) []float32 {
+	subImages := splitToTiles(img, aspectRatio)
+
+	var pixelVals []float32
+
+	for _, subImg := range subImages {
+		bounds := subImg.Bounds()
+		var rVals, gVals, bVals []float32
+		for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
+			for x := bounds.Min.X; x < bounds.Max.X; x++ {
+				c := subImg.At(x, y)
+				r, g, b, _ := c.RGBA()
+				rVal := float32(r>>8) / 255.0
+				gVal := float32(g>>8) / 255.0
+				bVal := float32(b>>8) / 255.0
+
+				rVal = (rVal - mean[0]) / std[0]
+				gVal = (gVal - mean[1]) / std[1]
+				bVal = (bVal - mean[2]) / std[2]
+
+				rVals = append(rVals, rVal)
+				gVals = append(gVals, gVal)
+				bVals = append(bVals, bVal)
+			}
+		}
+		pixelVals = append(pixelVals, rVals...)
+		pixelVals = append(pixelVals, gVals...)
+		pixelVals = append(pixelVals, bVals...)
+	}
+
+	return pixelVals
+}
+
+func Preprocess(imageData []byte) ([]float32, int, error) {
+	// todo: need guard in here for bad image data
+
+	// mllama values
+	outputSize := image.Point{560, 560}
+	maxTiles := 4
+
+	// clip values
+	mean := [3]float32{0.48145466, 0.4578275, 0.40821073}
+	std := [3]float32{0.26862954, 0.26130258, 0.27577711}
+
+	img, format, err := image.Decode(bytes.NewReader(imageData))
+	if err != nil {
+		return nil, 0, fmt.Errorf("failed to decode image: %w", err)
+	}
+
+	newImage, aspectRatio := ResizeImage(img, format, outputSize, maxTiles)
+	newImage = PadImage(newImage, outputSize, aspectRatio)
+
+	data := PackImages(newImage, aspectRatio, mean, std)
+	aspectRatioIndex := slices.Index(GetSupportedAspectRatios(maxTiles), aspectRatio) + 1
+
+	return data, aspectRatioIndex, nil
+}

+ 344 - 0
server/imageproc/images_test.go

@@ -0,0 +1,344 @@
+package imageproc
+
+import (
+	"bytes"
+	"image"
+	"image/png"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+)
+
+func TestAspectRatios(t *testing.T) {
+	type aspectCase struct {
+		MaxTiles int
+		Expected []image.Point
+	}
+
+	cases := []aspectCase{
+		{
+			MaxTiles: 1,
+			Expected: []image.Point{{1, 1}},
+		},
+		{
+			MaxTiles: 2,
+			Expected: []image.Point{{1, 1}, {1, 2}, {2, 1}},
+		},
+		{
+			MaxTiles: 3,
+			Expected: []image.Point{{1, 1}, {1, 2}, {1, 3}, {2, 1}, {3, 1}},
+		},
+		{
+			MaxTiles: 4,
+			Expected: []image.Point{{1, 1}, {1, 2}, {1, 3}, {1, 4}, {2, 1}, {2, 2}, {3, 1}, {4, 1}},
+		},
+	}
+
+	for _, c := range cases {
+		actual := GetSupportedAspectRatios(c.MaxTiles)
+
+		if diff := cmp.Diff(actual, c.Expected); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	}
+}
+
+func TestGetImageSizeFitToCanvas(t *testing.T) {
+	type imageSizeCase struct {
+		ImageRect  image.Point
+		CanvasRect image.Point
+		TileSize   int
+		Expected   image.Point
+	}
+
+	cases := []imageSizeCase{
+		{
+			ImageRect:  image.Point{400, 400},
+			CanvasRect: image.Point{640, 480},
+			TileSize:   200,
+			Expected:   image.Point{400, 400},
+		},
+		{
+			ImageRect:  image.Point{1024, 768},
+			CanvasRect: image.Point{640, 480},
+			TileSize:   200,
+			Expected:   image.Point{640, 480},
+		},
+		{
+			ImageRect:  image.Point{500, 500},
+			CanvasRect: image.Point{1000, 1000},
+			TileSize:   750,
+			Expected:   image.Point{750, 750},
+		},
+		{
+			ImageRect:  image.Point{500, 1000},
+			CanvasRect: image.Point{2000, 2000},
+			TileSize:   2000,
+			Expected:   image.Point{1000, 2000},
+		},
+		{
+			ImageRect:  image.Point{4000, 3000},
+			CanvasRect: image.Point{2000, 1000},
+			TileSize:   1000,
+			Expected:   image.Point{1333, 1000},
+		},
+		{
+			ImageRect:  image.Point{667, 1000},
+			CanvasRect: image.Point{1000, 1000},
+			TileSize:   560,
+			Expected:   image.Point{667, 1000},
+		},
+	}
+
+	for _, c := range cases {
+		actual := getImageSizeFitToCanvas(c.ImageRect, c.CanvasRect, c.TileSize)
+
+		if actual != c.Expected {
+			t.Errorf("incorrect image rect: '%#v'. expected: '%#v'", actual, c.Expected)
+		}
+	}
+}
+
+func TestGetOptimalTiledCanvas(t *testing.T) {
+	type tiledCanvasSizeCase struct {
+		ImageSize     image.Point
+		MaxImageTiles int
+		TileSize      int
+		Expected      image.Point
+	}
+
+	cases := []tiledCanvasSizeCase{
+		{
+			ImageSize:     image.Point{1024, 768},
+			MaxImageTiles: 4,
+			TileSize:      1000,
+			Expected:      image.Point{2000, 1000},
+		},
+		{
+			ImageSize:     image.Point{1024, 768},
+			MaxImageTiles: 4,
+			TileSize:      560,
+			Expected:      image.Point{1120, 1120},
+		},
+	}
+
+	for _, c := range cases {
+		actual := getOptimalTiledCanvas(c.ImageSize, c.MaxImageTiles, c.TileSize)
+
+		if actual != c.Expected {
+			t.Errorf("incorrect tiled canvas: '%#v'. expected: '%#v'", actual, c.Expected)
+		}
+	}
+}
+
+func TestSplitToTiles(t *testing.T) {
+	type splitCase struct {
+		TestImage    image.Image
+		NumTilesSize image.Point
+		Expected     []image.Image
+	}
+
+	cases := []splitCase{
+		{
+			TestImage:    image.NewRGBA(image.Rect(0, 0, 1024, 768)),
+			NumTilesSize: image.Point{1, 1},
+			Expected:     []image.Image{image.NewRGBA(image.Rect(0, 0, 1024, 768))},
+		},
+		{
+			TestImage:    image.NewRGBA(image.Rect(0, 0, 1000, 500)),
+			NumTilesSize: image.Point{2, 1},
+			Expected: []image.Image{
+				image.NewRGBA(image.Rect(0, 0, 500, 500)),
+				image.NewRGBA(image.Rect(500, 0, 1000, 500)),
+			},
+		},
+		{
+			TestImage:    image.NewRGBA(image.Rect(0, 0, 1000, 1000)),
+			NumTilesSize: image.Point{2, 2},
+			Expected: []image.Image{
+				image.NewRGBA(image.Rect(0, 0, 500, 500)),
+				image.NewRGBA(image.Rect(500, 0, 1000, 500)),
+				image.NewRGBA(image.Rect(0, 500, 500, 1000)),
+				image.NewRGBA(image.Rect(500, 500, 1000, 1000)),
+			},
+		},
+	}
+
+	for _, c := range cases {
+		actual := splitToTiles(c.TestImage, c.NumTilesSize)
+
+		if len(actual) != len(c.Expected) {
+			t.Errorf("incorrect number of images '%d': expected: '%d'", len(actual), len(c.Expected))
+		}
+
+		for i := range actual {
+			if actual[i].Bounds() != c.Expected[i].Bounds() {
+				t.Errorf("image size incorrect: '%#v': expected: '%#v'", actual[i].Bounds(), c.Expected[i].Bounds())
+			}
+		}
+	}
+}
+
+func TestResize(t *testing.T) {
+	type resizeCase struct {
+		TestImage           image.Image
+		OutputSize          image.Point
+		MaxImageTiles       int
+		ExpectedImage       image.Image
+		ExpectedAspectRatio image.Point
+	}
+
+	cases := []resizeCase{
+		{
+			TestImage:           image.NewRGBA(image.Rect(0, 0, 200, 200)),
+			OutputSize:          image.Point{100, 100},
+			MaxImageTiles:       1,
+			ExpectedImage:       image.NewRGBA(image.Rect(0, 0, 100, 100)),
+			ExpectedAspectRatio: image.Point{1, 1},
+		},
+		{
+			TestImage:           image.NewRGBA(image.Rect(0, 0, 200, 200)),
+			OutputSize:          image.Point{100, 100},
+			MaxImageTiles:       2,
+			ExpectedImage:       image.NewRGBA(image.Rect(0, 0, 100, 100)),
+			ExpectedAspectRatio: image.Point{1, 1},
+		},
+		{
+			TestImage:           image.NewRGBA(image.Rect(0, 0, 10, 10)),
+			OutputSize:          image.Point{560, 560},
+			MaxImageTiles:       4,
+			ExpectedImage:       image.NewRGBA(image.Rect(0, 0, 560, 560)),
+			ExpectedAspectRatio: image.Point{1, 1},
+		},
+		{
+			TestImage:           image.NewRGBA(image.Rect(0, 0, 2560, 1920)),
+			OutputSize:          image.Point{560, 560},
+			MaxImageTiles:       4,
+			ExpectedImage:       image.NewRGBA(image.Rect(0, 0, 1120, 840)),
+			ExpectedAspectRatio: image.Point{2, 2},
+		},
+		{
+			TestImage:           image.NewRGBA(image.Rect(0, 0, 1024, 768)),
+			OutputSize:          image.Point{560, 560},
+			MaxImageTiles:       4,
+			ExpectedImage:       image.NewRGBA(image.Rect(0, 0, 1024, 768)),
+			ExpectedAspectRatio: image.Point{2, 2},
+		},
+	}
+
+	for _, c := range cases {
+		actualImage, actualAspectRatio := ResizeImage(c.TestImage, "png", c.OutputSize, c.MaxImageTiles)
+
+		if actualImage.Bounds() != c.ExpectedImage.Bounds() {
+			t.Errorf("image size incorrect: '%#v': expected: '%#v'", actualImage.Bounds(), c.ExpectedImage.Bounds())
+		}
+
+		if actualAspectRatio != c.ExpectedAspectRatio {
+			t.Errorf("aspect ratio incorrect: '%#v': expected: '%#v'", actualAspectRatio, c.ExpectedAspectRatio)
+		}
+	}
+}
+
+func TestPad(t *testing.T) {
+	type padCase struct {
+		TestImage   image.Image
+		OutputSize  image.Point
+		AspectRatio image.Point
+		Expected    image.Image
+	}
+
+	cases := []padCase{
+		{
+			TestImage:   image.NewRGBA(image.Rect(0, 0, 1000, 667)),
+			OutputSize:  image.Point{560, 560},
+			AspectRatio: image.Point{2, 2},
+			Expected:    image.NewRGBA(image.Rect(0, 0, 1120, 1120)),
+		},
+	}
+
+	for _, c := range cases {
+		actual := PadImage(c.TestImage, c.OutputSize, c.AspectRatio)
+
+		if actual.Bounds() != c.Expected.Bounds() {
+			t.Errorf("image size incorrect: '%#v': expected: '%#v'", actual.Bounds(), c.Expected.Bounds())
+		}
+	}
+}
+
+func TestPackImages(t *testing.T) {
+	type packCase struct {
+		TestImage    image.Image
+		AspectRatio  image.Point
+		ExpectedVals int
+	}
+
+	mean := [3]float32{0.48145466, 0.4578275, 0.40821073}
+	std := [3]float32{0.26862954, 0.26130258, 0.27577711}
+
+	cases := []packCase{
+		{
+			TestImage:    image.NewRGBA(image.Rect(0, 0, 1120, 1120)),
+			AspectRatio:  image.Point{2, 2},
+			ExpectedVals: 2 * 2 * 3 * 560 * 560,
+		},
+		{
+			TestImage:    image.NewRGBA(image.Rect(0, 0, 560, 560)),
+			AspectRatio:  image.Point{1, 1},
+			ExpectedVals: 1 * 1 * 3 * 560 * 560,
+		},
+		{
+			TestImage:    image.NewRGBA(image.Rect(0, 0, 1120, 560)),
+			AspectRatio:  image.Point{1, 2},
+			ExpectedVals: 1 * 2 * 3 * 560 * 560,
+		},
+	}
+
+	for _, c := range cases {
+		actualVals := PackImages(c.TestImage, c.AspectRatio, mean, std)
+		if len(actualVals) != c.ExpectedVals {
+			t.Errorf("packed image size incorrect: '%d': expected: '%d'", len(actualVals), c.ExpectedVals)
+		}
+	}
+}
+
+func TestPreprocess(t *testing.T) {
+	type preprocessCase struct {
+		TestImage             image.Image
+		ExpectedVals          int
+		ExpectedAspectRatioID int
+	}
+
+	cases := []preprocessCase{
+		{
+			TestImage:             image.NewRGBA(image.Rect(0, 0, 10, 10)),
+			ExpectedVals:          0,
+			ExpectedAspectRatioID: 1,
+		},
+		{
+			TestImage:             image.NewRGBA(image.Rect(0, 0, 1024, 768)),
+			ExpectedVals:          0,
+			ExpectedAspectRatioID: 6,
+		},
+	}
+
+	for _, c := range cases {
+		var buf bytes.Buffer
+		err := png.Encode(&buf, c.TestImage)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		imgData, aspectRatioID, err := Preprocess(buf.Bytes())
+		if err != nil {
+			t.Fatalf("error processing: %q", err)
+		}
+
+		if len(imgData) == 0 {
+			t.Errorf("no image data returned")
+		}
+
+		if aspectRatioID != c.ExpectedAspectRatioID {
+			t.Errorf("aspect ratio incorrect: '%d': expected: '%d'", aspectRatioID, c.ExpectedAspectRatioID)
+		}
+	}
+}

+ 3 - 1
server/model.go

@@ -194,7 +194,9 @@ func parseFromFile(ctx context.Context, command string, baseLayers []*layerGGML,
 		mediatype := "application/vnd.ollama.image.model"
 		if ggml.Name() == "ggla" || ggml.KV().Kind() == "adapter" {
 			mediatype = "application/vnd.ollama.image.adapter"
-		} else if ggml.KV().Architecture() == "clip" {
+		}
+
+		if _, ok := ggml.KV()[fmt.Sprintf("%s.vision.block_count", ggml.KV().Architecture())]; ok || ggml.KV().Kind() == "projector" {
 			mediatype = "application/vnd.ollama.image.projector"
 		}
 

+ 82 - 14
server/prompt.go

@@ -3,24 +3,42 @@ package server
 import (
 	"bytes"
 	"context"
+	"encoding/binary"
+	"errors"
+	"fmt"
 	"log/slog"
+	"strings"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/server/imageproc"
 	"github.com/ollama/ollama/template"
 )
 
 type tokenizeFunc func(context.Context, string) ([]int, error)
 
+var errTooManyImages = errors.New("vision model only supports a single image per message")
+
 // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
 // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
 // latest message and 2) system messages
 func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
 	var system []api.Message
-	// always include the last message
+
+	isMllama := checkMllamaModelFamily(m)
+
 	n := len(msgs) - 1
 	// in reverse, find all messages that fit into context window
-	for i := n - 1; i >= 0; i-- {
+	for i := n; i >= 0; i-- {
+		if isMllama && len(msgs[i].Images) > 1 {
+			return "", nil, errTooManyImages
+		}
+
+		// always include the last message
+		if i == n {
+			continue
+		}
+
 		system = make([]api.Message, 0)
 		for j := range i {
 			if msgs[j].Role == "system" {
@@ -38,16 +56,16 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 			return "", nil, err
 		}
 
-		c := len(s)
+		ctxLen := len(s)
 		if m.ProjectorPaths != nil {
 			for _, m := range msgs[i:] {
 				// images are represented as 768 sized embeddings
 				// TODO: get embedding length from project metadata
-				c += 768 * len(m.Images)
+				ctxLen += 768 * len(m.Images)
 			}
 		}
 
-		if c > opts.NumCtx {
+		if ctxLen > opts.NumCtx {
 			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
 			break
 		} else {
@@ -55,20 +73,70 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 		}
 	}
 
+	currMsgIdx := n
+
+	if isMllama {
+		lastMsgIdx := len(msgs) - 1
+		for i := lastMsgIdx; i >= currMsgIdx; i-- {
+			if len(msgs[i].Images) > 0 {
+				data, aspectRatioID, err := imageproc.Preprocess(msgs[i].Images[0])
+				if err != nil {
+					return "", nil, err
+				}
+
+				buf := new(bytes.Buffer)
+				err = binary.Write(buf, binary.LittleEndian, data)
+				if err != nil {
+					return "", nil, err
+				}
+
+				imgData := llm.ImageData{
+					Data:          buf.Bytes(),
+					AspectRatioID: aspectRatioID,
+				}
+
+				msgs[i].Content = strings.TrimSpace("<|image|>" + msgs[i].Content)
+				images = append(images, imgData)
+				break
+			}
+		}
+	} else {
+		for cnt, msg := range msgs[currMsgIdx:] {
+			prefix := ""
+			prompt := msg.Content
+			for _, i := range msg.Images {
+				imgData := llm.ImageData{
+					ID:   len(images),
+					Data: i,
+				}
+
+				imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
+				if !strings.Contains(prompt, "[img]") {
+					prefix += imgTag
+				} else {
+					prompt = strings.Replace(prompt, "[img]", imgTag, 1)
+				}
+
+				images = append(images, imgData)
+			}
+			msgs[currMsgIdx+cnt].Content = strings.TrimSpace(prefix + " " + prompt)
+		}
+	}
+
 	// truncate any messages that do not fit into the context window
 	var b bytes.Buffer
-	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
+	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
 		return "", nil, err
 	}
 
-	for _, m := range msgs[n:] {
-		for _, i := range m.Images {
-			images = append(images, llm.ImageData{
-				ID:   len(images),
-				Data: i,
-			})
+	return b.String(), images, nil
+}
+
+func checkMllamaModelFamily(m *Model) bool {
+	for _, arch := range m.Config.ModelFamilies {
+		if arch == "mllama" {
+			return true
 		}
 	}
-
-	return b.String(), images, nil
+	return false
 }

+ 156 - 14
server/prompt_test.go

@@ -3,6 +3,8 @@ package server
 import (
 	"bytes"
 	"context"
+	"image"
+	"image/png"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
@@ -13,18 +15,53 @@ import (
 
 func TestChatPrompt(t *testing.T) {
 	type expect struct {
-		prompt string
-		images [][]byte
+		prompt        string
+		images        [][]byte
+		aspectRatioID int
+		error         error
+	}
+
+	tmpl, err := template.Parse(`
+{{- if .System }}{{ .System }} {{ end }}
+{{- if .Prompt }}{{ .Prompt }} {{ end }}
+{{- if .Response }}{{ .Response }} {{ end }}`)
+	if err != nil {
+		t.Fatal(err)
+	}
+	visionModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
+	mllamaModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}, Config: ConfigV2{ModelFamilies: []string{"mllama"}}}
+
+	createImg := func(width, height int) ([]byte, error) {
+		img := image.NewRGBA(image.Rect(0, 0, 5, 5))
+		var buf bytes.Buffer
+
+		if err := png.Encode(&buf, img); err != nil {
+			return nil, err
+		}
+
+		return buf.Bytes(), nil
+	}
+
+	imgBuf, err := createImg(5, 5)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	imgBuf2, err := createImg(6, 6)
+	if err != nil {
+		t.Fatal(err)
 	}
 
 	cases := []struct {
 		name  string
+		model Model
 		limit int
 		msgs  []api.Message
 		expect
 	}{
 		{
 			name:  "messages",
+			model: visionModel,
 			limit: 64,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
@@ -37,6 +74,7 @@ func TestChatPrompt(t *testing.T) {
 		},
 		{
 			name:  "truncate messages",
+			model: visionModel,
 			limit: 1,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
@@ -49,6 +87,7 @@ func TestChatPrompt(t *testing.T) {
 		},
 		{
 			name:  "truncate messages with image",
+			model: visionModel,
 			limit: 64,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
@@ -64,6 +103,7 @@ func TestChatPrompt(t *testing.T) {
 		},
 		{
 			name:  "truncate messages with images",
+			model: visionModel,
 			limit: 64,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
@@ -79,6 +119,7 @@ func TestChatPrompt(t *testing.T) {
 		},
 		{
 			name:  "messages with images",
+			model: visionModel,
 			limit: 2048,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
@@ -95,6 +136,7 @@ func TestChatPrompt(t *testing.T) {
 		},
 		{
 			name:  "message with image tag",
+			model: visionModel,
 			limit: 2048,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
@@ -111,6 +153,7 @@ func TestChatPrompt(t *testing.T) {
 		},
 		{
 			name:  "messages with interleaved images",
+			model: visionModel,
 			limit: 2048,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
@@ -129,6 +172,7 @@ func TestChatPrompt(t *testing.T) {
 		},
 		{
 			name:  "truncate message with interleaved images",
+			model: visionModel,
 			limit: 1024,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
@@ -146,6 +190,7 @@ func TestChatPrompt(t *testing.T) {
 		},
 		{
 			name:  "message with system prompt",
+			model: visionModel,
 			limit: 2048,
 			msgs: []api.Message{
 				{Role: "system", Content: "You are the Test Who Lived."},
@@ -159,6 +204,7 @@ func TestChatPrompt(t *testing.T) {
 		},
 		{
 			name:  "out of order system",
+			model: visionModel,
 			limit: 2048,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
@@ -170,23 +216,113 @@ func TestChatPrompt(t *testing.T) {
 				prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
 			},
 		},
-	}
-
-	tmpl, err := template.Parse(`
-{{- if .System }}{{ .System }} {{ end }}
-{{- if .Prompt }}{{ .Prompt }} {{ end }}
-{{- if .Response }}{{ .Response }} {{ end }}`)
-	if err != nil {
-		t.Fatal(err)
+		{
+			name:  "multiple images same prompt",
+			model: visionModel,
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "Compare these two pictures of hotdogs", Images: []api.ImageData{[]byte("one hotdog"), []byte("two hotdogs")}},
+			},
+			expect: expect{
+				prompt: "[img-0][img-1] Compare these two pictures of hotdogs ",
+				images: [][]byte{[]byte("one hotdog"), []byte("two hotdogs")},
+			},
+		},
+		{
+			name:  "messages with mllama (no images)",
+			model: mllamaModel,
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+			},
+			expect: expect{
+				prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
+			},
+		},
+		{
+			name:  "messages with mllama single prompt",
+			model: mllamaModel,
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "How many hotdogs are in this image?", Images: []api.ImageData{imgBuf}},
+			},
+			expect: expect{
+				prompt:        "<|image|>How many hotdogs are in this image? ",
+				images:        [][]byte{imgBuf},
+				aspectRatioID: 1,
+			},
+		},
+		{
+			name:  "messages with mllama",
+			model: mllamaModel,
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf}},
+			},
+			expect: expect{
+				prompt:        "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ",
+				images:        [][]byte{imgBuf},
+				aspectRatioID: 1,
+			},
+		},
+		{
+			name:  "multiple messages with mllama",
+			model: mllamaModel,
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{imgBuf}},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf2}},
+			},
+			expect: expect{
+				prompt:        "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ",
+				images:        [][]byte{imgBuf2},
+				aspectRatioID: 1,
+			},
+		},
+		{
+			name:  "earlier image with mllama",
+			model: mllamaModel,
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "How many hotdogs are in this image?", Images: []api.ImageData{imgBuf}},
+				{Role: "assistant", Content: "There are four hotdogs."},
+				{Role: "user", Content: "Which ones have mustard?"},
+			},
+			expect: expect{
+				prompt:        "<|image|>How many hotdogs are in this image? There are four hotdogs. Which ones have mustard? ",
+				images:        [][]byte{imgBuf},
+				aspectRatioID: 1,
+			},
+		},
+		{
+			name:  "too many images with mllama",
+			model: mllamaModel,
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf, imgBuf}},
+			},
+			expect: expect{
+				error: errTooManyImages,
+			},
+		},
 	}
 
 	for _, tt := range cases {
 		t.Run(tt.name, func(t *testing.T) {
-			model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
+			model := tt.model
 			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
 			prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
-			if err != nil {
+			if tt.error == nil && err != nil {
 				t.Fatal(err)
+			} else if tt.error != nil && err != tt.error {
+				t.Fatalf("expected err '%q', got '%q'", tt.error, err)
 			}
 
 			if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
@@ -202,8 +338,14 @@ func TestChatPrompt(t *testing.T) {
 					t.Errorf("expected ID %d, got %d", i, images[i].ID)
 				}
 
-				if !bytes.Equal(images[i].Data, tt.images[i]) {
-					t.Errorf("expected %q, got %q", tt.images[i], images[i])
+				if len(model.Config.ModelFamilies) == 0 {
+					if !bytes.Equal(images[i].Data, tt.images[i]) {
+						t.Errorf("expected %q, got %q", tt.images[i], images[i].Data)
+					}
+				} else {
+					if images[i].AspectRatioID != tt.aspectRatioID {
+						t.Errorf("expected aspect ratio %d, got %d", tt.aspectRatioID, images[i].AspectRatioID)
+					}
 				}
 			}
 		})

+ 25 - 13
server/routes.go

@@ -119,20 +119,21 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
+	model, err := GetModel(req.Model)
+	if err != nil {
+		switch {
+		case os.IsNotExist(err):
+			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
+		case err.Error() == "invalid model name":
+			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		default:
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		}
+		return
+	}
+
 	// expire the runner
 	if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
-		model, err := GetModel(req.Model)
-		if err != nil {
-			switch {
-			case os.IsNotExist(err):
-				c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
-			case err.Error() == "invalid model name":
-				c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-			default:
-				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-			}
-			return
-		}
 		s.sched.expireRunner(model)
 
 		c.JSON(http.StatusOK, api.GenerateResponse{
@@ -169,6 +170,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 	checkpointLoaded := time.Now()
 
+	// load the model
 	if req.Prompt == "" {
 		c.JSON(http.StatusOK, api.GenerateResponse{
 			Model:      req.Model,
@@ -179,6 +181,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
+	isMllama := checkMllamaModelFamily(model)
+	if isMllama && len(req.Images) > 1 {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image: more than one image sent"})
+		return
+	}
+
 	images := make([]llm.ImageData, len(req.Images))
 	for i := range req.Images {
 		images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
@@ -212,7 +220,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 			}
 
 			for _, i := range images {
-				msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
+				if isMllama {
+					msgs = append(msgs, api.Message{Role: "user", Content: "<|image|>"})
+				} else {
+					msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
+				}
 			}
 
 			values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})

+ 6 - 6
server/routes_generate_test.go

@@ -421,22 +421,22 @@ func TestGenerate(t *testing.T) {
 
 	t.Run("missing body", func(t *testing.T) {
 		w := createRequest(t, s.GenerateHandler, nil)
-		if w.Code != http.StatusBadRequest {
-			t.Errorf("expected status 400, got %d", w.Code)
+		if w.Code != http.StatusNotFound {
+			t.Errorf("expected status 404, got %d", w.Code)
 		}
 
-		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
+		if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
 			t.Errorf("mismatch (-got +want):\n%s", diff)
 		}
 	})
 
 	t.Run("missing model", func(t *testing.T) {
 		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
-		if w.Code != http.StatusBadRequest {
-			t.Errorf("expected status 400, got %d", w.Code)
+		if w.Code != http.StatusNotFound {
+			t.Errorf("expected status 404, got %d", w.Code)
 		}
 
-		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
+		if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
 			t.Errorf("mismatch (-got +want):\n%s", diff)
 		}
 	})

+ 1 - 1
server/routes_test.go

@@ -562,7 +562,7 @@ func TestShow(t *testing.T) {
 		Modelfile: fmt.Sprintf(
 			"FROM %s\nFROM %s",
 			createBinFile(t, llm.KV{"general.architecture": "test"}, nil),
-			createBinFile(t, llm.KV{"general.architecture": "clip"}, nil),
+			createBinFile(t, llm.KV{"general.type": "projector", "general.architecture": "clip"}, nil),
 		),
 	})
 

+ 0 - 13
template/template.go

@@ -5,7 +5,6 @@ import (
 	"embed"
 	"encoding/json"
 	"errors"
-	"fmt"
 	"io"
 	"math"
 	"slices"
@@ -302,22 +301,10 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 // into a single message. collate also collects and returns all system messages.
 // collate mutates message content adding image tags ([img-%d]) as needed
 func collate(msgs []api.Message) (string, []*api.Message) {
-	var n int
-
 	var system []string
 	var collated []*api.Message
 	for i := range msgs {
 		msg := msgs[i]
-		for range msg.Images {
-			imageTag := fmt.Sprintf("[img-%d]", n)
-			if !strings.Contains(msg.Content, "[img]") {
-				msg.Content = strings.TrimSpace("[img] " + msg.Content)
-			}
-
-			msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
-			n++
-		}
-
 		if msg.Role == "system" {
 			system = append(system, msg.Content)
 		}

+ 0 - 39
template/template_test.go

@@ -317,45 +317,6 @@ What is your name?<|im_end|>
 <|im_start|>assistant
 `,
 		},
-		{
-			"moondream",
-			[]template{
-				// this does not have a "no response" test because it's impossible to render the same output
-				{"response", `{{ if .Prompt }}Question: {{ .Prompt }}
-
-{{ end }}Answer: {{ .Response }}
-
-`},
-				{"messages", `
-{{- range .Messages }}
-{{- if eq .Role "user" }}Question: {{ .Content }}
-
-{{ else if eq .Role "assistant" }}Answer: {{ .Content }}
-
-{{ end }}
-{{- end }}Answer: `},
-			},
-			Values{
-				Messages: []api.Message{
-					{Role: "user", Content: "What's in this image?", Images: []api.ImageData{[]byte("")}},
-					{Role: "assistant", Content: "It's a hot dog."},
-					{Role: "user", Content: "What's in _this_ image?"},
-					{Role: "user", Images: []api.ImageData{[]byte("")}},
-					{Role: "user", Content: "Is it a hot dog?"},
-				},
-			},
-			`Question: [img-0] What's in this image?
-
-Answer: It's a hot dog.
-
-Question: What's in _this_ image?
-
-[img-1]
-
-Is it a hot dog?
-
-Answer: `,
-		},
 	}
 
 	for _, tt := range cases {