Bruce MacDonald 3 ماه پیش
والد
کامیت
60f0b7db76
8فایلهای تغییر یافته به همراه817 افزوده شده و 11 حذف شده
  1. 3 1
      .gitignore
  2. 13 0
      ml/backend.go
  3. 45 10
      ml/backend/ggml/ggml.go
  4. 169 0
      model/README.md
  5. 91 0
      model/model_test/model_test.go
  6. 294 0
      model/model_test/testdata/qwen2.json
  7. 201 0
      model/qwen2/model.go
  8. 1 0
      runner/newrunner/runner.go

+ 3 - 1
.gitignore

@@ -12,4 +12,6 @@ test_data
 *.crt
 llama/build
 __debug_bin*
-llama/vendor
+llama/vendor
+model/model_test/testdata/*/
+!model/model_test/testdata/*.*

+ 13 - 0
ml/backend.go

@@ -24,6 +24,15 @@ type Backend interface {
 	NewContext() Context
 }
 
+type GraphLayer struct {
+	Name  string  `json:"name"`
+	Shape []int64 `json:"shape"`
+}
+
+type Graph struct {
+	Graph []GraphLayer `json:"graph"`
+}
+
 var backends = make(map[string]func(*os.File) (Backend, error))
 
 func RegisterBackend(name string, f func(*os.File) (Backend, error)) {
@@ -50,6 +59,10 @@ type Context interface {
 	Forward(Tensor)
 	Compute(Tensor) Tensor
 	Close() error
+
+	SetDebug(bool)
+	Trace(string, Tensor)
+	GetTrace() Graph
 }
 
 type Tensor interface {

+ 45 - 10
ml/backend/ggml/ggml.go

@@ -222,6 +222,7 @@ func (b *Backend) NewContext() ml.Context {
 			C.size_t(nodes),
 			true,
 		),
+		traceGraph: ml.Graph{},
 	}
 }
 
@@ -232,6 +233,9 @@ type Context struct {
 	sched *C.struct_ggml_backend_sched
 	graph *C.struct_ggml_cgraph
 	nodes int
+
+	debug      bool
+	traceGraph ml.Graph
 }
 
 func (c *Context) Forward(t ml.Tensor) {
@@ -320,6 +324,34 @@ func (c *Context) Close() error {
 	return nil
 }
 
+func (c *Context) SetDebug(debug bool) {
+	c.debug = debug
+}
+
+func (c *Context) Trace(name string, t ml.Tensor) {
+	if !c.debug {
+		return
+	}
+
+	shape := t.Shape()
+	shapeArr := make([]int64, 4)
+	for i := 0; i < len(shape); i++ {
+		shapeArr[i] = shape[i]
+	}
+
+	c.traceGraph.Graph = append(
+		c.traceGraph.Graph,
+		ml.GraphLayer{
+			Name:  name,
+			Shape: shapeArr,
+		},
+	)
+}
+
+func (c *Context) GetTrace() ml.Graph {
+	return c.traceGraph
+}
+
 type Tensor struct {
 	t    *C.struct_ggml_tensor
 	data []byte
@@ -555,16 +587,19 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
 
 	return &Tensor{
 		t: C.ggml_rope_ext(
-			ctx.(*Context).ctx, t.t, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
-			C.int(ropeDim),
-			131072,       // YaRN n_ctx_train
-			ropeTypeNorm, // ROPE_TYPE_NORM
-			C.float(ropeBase),
-			C.float(ropeScale),
-			0.,  // YaRN ext_factor
-			1.,  // YaRN attn_factor
-			32., // YaRN beta_fast
-			1.,  // YaRN beta_slow
+			ctx.(*Context).ctx,
+			t.t,                     // a tensor
+			positionIDs.(*Tensor).t, // b tensor with dims [512, 1, 1, 1]
+			nil,                     // c tensor (not shown in log)
+			C.int(64),               // n_dims: 64
+			2,                       // mode: 2 (ropeTypeNeox = 2)
+			C.int(32768),            // n_ctx_orig: 32768
+			C.float(1000000.0),      // freq_base: 1000000.000000
+			C.float(1.0),            // freq_scale: 1.000000
+			C.float(0.0),            // ext_factor: 0.000000
+			C.float(1.0),            // attn_factor: 1.000000
+			C.float(32.0),           // beta_fast: 32.000000
+			C.float(1.0),            // beta_slow: 1.000000
 		),
 	}
 }

+ 169 - 0
model/README.md

@@ -0,0 +1,169 @@
+# Ollama Models
+
+!! This is a work in progress document !!
+
+## Architecture
+
+```mermaid
+graph TB
+    subgraph Models["Model Layer: LLM Implementations"]
+        direction TB
+        llama["llama/model.go"]
+        mllama["mllama/model.go"]
+        qwen["qwen2/model.go"]
+        qwen_vl["qwen2vl/model.go"]
+        pixtral["pixtral/"]
+        
+        note1["Each model implements a specific architecture
+        - Defines model parameters
+        - Handles tokenization
+        - Implements forward pass
+        - Manages model weights"]
+    end
+
+    subgraph ML_Ops["Neural Network Operations"]
+        direction TB
+        nn_ops["nn/
+            linear.go - Matrix operations
+            embedding.go - Token embeddings
+            normalization.go - Layer normalization
+            convolution.go - Conv operations"]
+        
+        backend["ml/backend.go
+        Hardware Abstraction Layer
+        - Defines tensor operations
+        - Manages computation graphs
+        - Handles memory allocation"]
+
+        note2["Common neural net operations
+        used across different models
+        - Abstracts hardware details
+        - Provides unified API
+        - Manages computation flow"]
+    end
+
+    subgraph GGML["Hardware Execution Layer"]
+        direction TB
+        ggml["ggml.go
+        CGO Interface
+        - Bridges Go and C++
+        - Handles type conversion
+        - Manages memory between languages"]
+        
+        subgraph Hardware_Specific["Hardware-Specific Implementations"]
+            direction LR
+            cpu["ggml-cpu.h
+            CPU optimized ops"]
+            cuda["ggml-cuda.h
+            NVIDIA GPU ops"]
+            metal["ggml-metal.h
+            Apple GPU ops"]
+            vulkan["ggml-vulkan.h
+            Cross-platform GPU"]
+            opencl["ggml-opencl.h
+            OpenCL acceleration"]
+        end
+
+        note3["GGML provides optimized 
+        implementations for each hardware:
+        - Automatic dispatch
+        - Hardware-specific optimizations
+        - Memory management
+        - Parallel execution"]
+    end
+
+    %% Connections with explanations
+    Models --> |"Makes high-level calls
+    (e.g., self-attention)"| ML_Ops
+    ML_Ops --> |"Translates to tensor operations
+    (e.g., matmul, softmax)"| GGML
+    GGML --> |"Executes optimized code
+    on target hardware"| Hardware_Specific
+    
+    %% Styling
+    classDef model fill:#fff,stroke:#01579b,stroke-width:2px
+    classDef ml fill:#fff,stroke:#e65100,stroke-width:2px
+    classDef hw fill:#fff,stroke:#b71c1c,stroke-width:2px
+    classDef note fill:#fff,stroke:#666,stroke-dasharray: 5 5
+    
+    class llama,mllama,qwen,qwen_vl,pixtral model
+    class nn_ops,backend ml
+    class ggml,cpu,cuda,metal,vulkan,opencl hw
+    class note1,note2,note3 note
+
+    %% Style subgraphs
+    style Models fill:#fff,stroke:#01579b,stroke-width:2px
+    style ML_Ops fill:#fff,stroke:#e65100,stroke-width:2px
+    style GGML fill:#fff,stroke:#b71c1c,stroke-width:2px
+    style Hardware_Specific fill:#fff,stroke:#b71c1c,stroke-width:1px
+```
+
+## Adding support for a new model to Ollama
+
+1. Clone the Ollama repo and get it running locally: https://github.com/ollama/ollama/blob/main/docs/development.md
+2. Get the original model (research code) running locally. This will 99.99% of the time be a Python repository.
+3. Get a dump of the graph built with Pytorch or Safetensors. Use this snippet to do so.
+```python
+import torch
+import sys
+from safetensors.torch import load_file
+
+def extract_graph(model_path):
+    if model_path.endswith('.safetensors'):
+        state_dict = load_file(model_path)
+    else:
+        state_dict = torch.load(model_path, weights_only=True)
+    
+    graph = []
+    for name, tensor in state_dict.items():
+        if isinstance(tensor, torch.Tensor):
+            graph.append({
+                "name": name,
+                "shape": list(tensor.shape)
+            })
+    
+    print("{")
+    print('    "graph": [')
+    for i, layer in enumerate(graph):
+        comma = "," if i < len(graph) - 1 else ""
+        print(f'        {{"name": "{layer["name"]}", "shape": {layer["shape"]}}}{comma}')
+    print("    ]")
+    print("}")
+
+if __name__ == "__main__":
+    if len(sys.argv) != 2:
+        print("Usage: python extract.py <path/to/model>")
+        sys.exit(1)
+    
+    extract_graph(sys.argv[1])
+``` 
+4. Look at a previous model implementation pull request and copy the structure of the files needed. We will need:
+    1. A `model/<model-name>`  directory
+    2. A `model/<model-name>/model.go`  file to implement the architecture and forward pass.
+    3. A `model/<model-name>/convert.go`  file to implement to conversion from pytorch/safetensors to ggml.
+    4. `model/<model-name>/model_test.go`  and `model/<model-name>/convert_test.go` files for testing.
+    5. Modify main paths to make this new model accessible.
+5. Open a draft pull request in the `ollama/ollama` repo, as a place to ask questions and get answers from Ollama maintainers.
+6. Implement conversion from the model weights (pytorch, safetensors) to ggml in the `model/<your-model>/convert.go`  file. Reference other `convert.go` files. 
+7. Create a Modelfile that only references the pytorch/safetensor directory. We will handle the other fields later.
+Modelfile:
+```
+FROM /path/to/model
+```
+Use `ollama create` to convert the model:
+`go run . create <my-model> -f /path/to/Modelfie`
+6. Implement the `New()` and `Forward()` logic in `model/<your-model>/model.go` . Reference other `model.go` files. 
+
+Run the model and get the debug output of the forward pass to compare with the output of the research implementation from step 1: 
+`OLLAMA_DEBUG=1 go run . run <my-model>` 
+7. (maybe) Implement a new tokenizer, if needed.
+8. Test text generation, this step requires knowing the prompt format:
+`go run . run <my-model> "hello"`  
+9. Add tests to `model/<your-model>/model_test.go`  and `model/<your-model>/convert_test.go` 
+10. Push changes to `ollama/ollama` pull request, and move the pull request out of the draft state.
+11. Push model to ollama.com:
+    1. Find model prompt format and convert it to a Go template.
+    2. Create a Modelfile `FROM` the converted gguf, add the `TEMPLATE`, `LICENSE`, and parameters if needed.
+    3. `ollama create <your-namespace>/<your-model> -f /path/to/Modelfile`
+    4. `ollama push <your-namespace>/<your-model>`
+12. Run end-to-end integration tests.

+ 91 - 0
model/model_test/model_test.go

@@ -0,0 +1,91 @@
+package modeltest
+
+import (
+	"encoding/json"
+	"os"
+	"path/filepath"
+	"reflect"
+	"testing"
+
+	"github.com/ollama/ollama/cache"
+	"github.com/ollama/ollama/convert"
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model"
+	_ "github.com/ollama/ollama/model/qwen2"
+)
+
+func TestForward(t *testing.T) {
+	cases := []string{
+		"qwen2",
+		// Add more model architectures here...
+	}
+
+	for _, tt := range cases {
+		t.Run(tt, func(t *testing.T) {
+			t.Parallel()
+
+			p := filepath.Join("testdata", tt)
+			if testing.Short() {
+				t.Skip("skipping in short mode")
+			} else if _, err := os.Stat(p); err != nil {
+				t.Skipf("%s not found", p)
+			}
+
+			f, err := os.CreateTemp(t.TempDir(), "f16")
+			if err != nil {
+				t.Fatal(err)
+			}
+			defer func() {
+				f.Close()
+				os.Remove(f.Name())
+			}()
+
+			if err := convert.ConvertModel(os.DirFS(p), f); err != nil {
+				t.Fatal(err)
+			}
+
+			m, err := model.New(f.Name())
+			if err != nil {
+				t.Fatal(err)
+			}
+			b := m.Backend()
+			ctx := b.NewContext()
+			ctx.SetDebug(true)
+
+			// Run forward pass
+			_, err = model.Forward(ctx, m, model.WithCache(cache.NewCausalCache(m.Backend(), 2048, ml.DTypeF32)))
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			// Validate the graph layers
+			data, err := os.ReadFile(filepath.Join("testdata", tt+".json"))
+			if err != nil {
+				t.Fatal(err)
+			}
+			var expected ml.Graph
+			if err := json.Unmarshal(data, &expected); err != nil {
+				t.Fatal(err)
+			}
+
+			result := ctx.GetTrace()
+
+			if len(result.Graph) != len(expected.Graph) {
+				t.Errorf("expected %d layers, got %d", len(expected.Graph), len(result.Graph))
+			}
+
+			for i, layer := range expected.Graph {
+				if i >= len(result.Graph) {
+					break
+				}
+				actual := result.Graph[i]
+				if layer.Name != actual.Name {
+					t.Errorf("layer %d: expected name %s, got %s", i, layer.Name, actual.Name)
+				}
+				if !reflect.DeepEqual(layer.Shape, actual.Shape) {
+					t.Errorf("layer %d: expected shape %v, got %v", i, layer.Shape, actual.Shape)
+				}
+			}
+		})
+	}
+}

+ 294 - 0
model/model_test/testdata/qwen2.json

@@ -0,0 +1,294 @@
+{
+    "graph": [
+        {"name": "model.embed_tokens.weight", "shape": [151936, 896]},
+        {"name": "model.layers.0.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.0.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.0.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.0.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.0.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.0.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.0.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.0.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.0.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.0.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.0.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.0.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.1.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.1.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.1.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.1.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.1.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.1.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.1.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.1.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.1.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.1.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.1.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.1.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.10.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.10.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.10.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.10.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.10.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.10.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.10.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.10.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.10.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.10.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.10.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.10.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.11.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.11.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.11.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.11.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.11.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.11.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.11.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.11.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.11.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.11.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.11.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.11.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.12.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.12.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.12.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.12.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.12.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.12.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.12.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.12.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.12.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.12.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.12.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.12.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.13.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.13.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.13.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.13.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.13.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.13.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.13.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.13.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.13.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.13.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.13.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.13.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.14.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.14.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.14.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.14.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.14.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.14.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.14.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.14.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.14.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.14.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.14.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.14.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.15.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.15.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.15.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.15.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.15.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.15.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.15.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.15.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.15.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.15.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.15.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.15.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.16.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.16.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.16.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.16.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.16.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.16.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.16.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.16.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.16.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.16.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.16.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.16.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.17.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.17.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.17.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.17.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.17.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.17.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.17.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.17.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.17.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.17.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.17.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.17.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.18.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.18.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.18.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.18.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.18.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.18.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.18.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.18.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.18.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.18.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.18.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.18.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.19.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.19.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.19.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.19.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.19.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.19.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.19.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.19.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.19.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.19.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.19.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.19.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.2.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.2.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.2.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.2.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.2.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.2.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.2.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.2.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.2.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.2.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.2.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.2.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.20.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.20.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.20.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.20.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.20.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.20.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.20.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.20.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.20.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.20.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.20.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.20.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.21.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.21.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.21.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.21.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.21.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.21.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.21.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.21.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.21.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.21.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.21.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.21.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.22.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.22.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.22.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.22.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.22.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.22.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.22.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.22.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.22.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.22.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.22.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.22.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.23.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.23.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.23.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.23.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.23.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.23.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.23.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.23.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.23.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.23.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.23.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.23.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.3.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.3.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.3.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.3.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.3.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.3.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.3.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.3.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.3.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.3.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.3.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.3.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.4.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.4.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.4.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.4.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.4.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.4.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.4.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.4.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.4.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.4.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.4.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.4.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.5.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.5.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.5.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.5.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.5.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.5.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.5.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.5.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.5.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.5.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.5.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.5.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.6.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.6.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.6.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.6.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.6.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.6.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.6.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.6.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.6.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.6.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.6.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.6.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.7.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.7.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.7.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.7.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.7.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.7.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.7.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.7.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.7.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.7.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.7.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.7.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.8.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.8.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.8.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.8.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.8.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.8.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.8.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.8.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.8.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.8.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.8.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.8.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.9.input_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.9.mlp.down_proj.weight", "shape": [896, 4864]},
+        {"name": "model.layers.9.mlp.gate_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.9.mlp.up_proj.weight", "shape": [4864, 896]},
+        {"name": "model.layers.9.post_attention_layernorm.weight", "shape": [896]},
+        {"name": "model.layers.9.self_attn.k_proj.bias", "shape": [128]},
+        {"name": "model.layers.9.self_attn.k_proj.weight", "shape": [128, 896]},
+        {"name": "model.layers.9.self_attn.o_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.9.self_attn.q_proj.bias", "shape": [896]},
+        {"name": "model.layers.9.self_attn.q_proj.weight", "shape": [896, 896]},
+        {"name": "model.layers.9.self_attn.v_proj.bias", "shape": [128]},
+        {"name": "model.layers.9.self_attn.v_proj.weight", "shape": [128, 896]},
+        {"name": "model.norm.weight", "shape": [896]}
+    ]
+}

+ 201 - 0
model/qwen2/model.go

@@ -0,0 +1,201 @@
+package qwen2
+
+import (
+	"fmt"
+	"log/slog"
+	"math"
+
+	"github.com/ollama/ollama/cache"
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/model"
+)
+
+type Options struct {
+	hiddenSize, numHeads, numKVHeads int64
+	eps, ropeBase, ropeScale         float32
+	ropeDim                          uint32
+}
+
+type Model struct {
+	model.Base
+	model.BytePairEncoding
+
+	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
+	Layers         []Layer       `gguf:"blk"`
+	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
+	Output         *nn.Linear    `gguf:"output,alt:token_embd"`
+
+	*Options
+}
+
+func New(c ml.Config) (model.Model, error) {
+	m := &Model{
+		BytePairEncoding: model.BytePairEncoding{
+			Pretokenizer: c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
+			Vocabulary: &model.Vocabulary{
+				Values: c.Strings("tokenizer.ggml.tokens"),
+				Types:  c.Uints("tokenizer.ggml.token_type"),
+				Merges: c.Strings("tokenizer.ggml.merges"),
+				BOS:    c.Uint("tokenizer.ggml.bos_token_id"),
+				EOS:    c.Uint("tokenizer.ggml.eos_token_id"),
+			},
+		},
+		Layers: make([]Layer, c.Uint("block_count")),
+		Options: &Options{
+			hiddenSize: int64(c.Uint("embedding_length")),
+			numHeads:   int64(c.Uint("attention.head_count")),
+			numKVHeads: int64(c.Uint("attention.head_count_kv")),
+			eps:        c.Float("attention.layer_norm_rms_epsilon"),
+			ropeBase:   c.Float("rope.freq_base"),
+			ropeScale:  c.Float("rope.freq_scale", 1),
+			ropeDim:    c.Uint("rope.dimension_count", 64),
+		},
+	}
+
+	slog.Debug("model configuration",
+		"arch", "qwen2",
+		"vocab_size", len(c.Strings("tokenizer.ggml.tokens")),
+		"n_merges", len(c.Strings("tokenizer.ggml.merges")),
+		"n_ctx_train", c.Uint("context_length"),
+		"n_embd", m.hiddenSize,
+		"n_layer", len(m.Layers),
+		"n_head", m.numHeads,
+		"n_head_kv", m.numKVHeads,
+		"n_rot", m.ropeDim,
+		"f_norm_rms_eps", m.eps,
+		"rope_freq_base", m.ropeBase,
+		"rope_freq_scale", m.ropeScale,
+		"bos_token_id", c.Uint("tokenizer.ggml.bos_token_id"),
+		"eos_token_id", c.Uint("tokenizer.ggml.eos_token_id"),
+	)
+
+	return m, nil
+}
+
+type SelfAttention struct {
+	Query  *nn.Linear `gguf:"attn_q"`
+	Key    *nn.Linear `gguf:"attn_k"`
+	Value  *nn.Linear `gguf:"attn_v"`
+	Output *nn.Linear `gguf:"attn_output"`
+}
+
+func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, inputPositions ml.Tensor, layerIdx int, cache cache.Cache, opts *Options) ml.Tensor {
+	batchSize := hiddenState.Dim(1)
+	headDim := opts.hiddenSize / opts.numHeads
+
+	q := sa.Query.Forward(ctx, hiddenState)
+	ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.q_proj", layerIdx), q)
+
+	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
+	q = q.RoPE(ctx, inputPositions, nil, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.q_proj.rope", layerIdx), q)
+
+	k := sa.Key.Forward(ctx, hiddenState)
+	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
+	k = k.RoPE(ctx, inputPositions, nil, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.k_proj.rope", layerIdx), k)
+
+	v := sa.Value.Forward(ctx, hiddenState)
+	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
+	ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.v_proj", layerIdx), v)
+
+	k, v, mask := cache.Put(ctx, k, v)
+
+	q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+
+	kq := k.Mulmat(ctx, q)
+	kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
+	kq = kq.Add(ctx, mask)
+	kq = kq.Softmax(ctx)
+
+	kqv := v.Mulmat(ctx, kq)
+	kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
+
+	output := sa.Output.Forward(ctx, kqv)
+	return output
+}
+
+type MLP struct {
+	Up   *nn.Linear `gguf:"ffn_up"`
+	Down *nn.Linear `gguf:"ffn_down"`
+	Gate *nn.Linear `gguf:"ffn_gate"`
+}
+
+func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	return mlp.Down.Forward(ctx, hiddenState)
+}
+
+type Layer struct {
+	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
+	SelfAttention *SelfAttention
+	MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
+	MLP           *MLP
+}
+
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, layerIdx int, cache cache.Cache, opts *Options) ml.Tensor {
+	ctx.Trace(fmt.Sprintf("model.layers.%d.input", layerIdx), hiddenState)
+	residual := hiddenState
+
+	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
+	ctx.Trace(fmt.Sprintf("model.layers.%d.input_layernorm", layerIdx), hiddenState)
+
+	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, layerIdx, cache, opts)
+	ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.output", layerIdx), hiddenState)
+
+	hiddenState = hiddenState.Add(ctx, residual)
+	residual = hiddenState
+	ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.residual", layerIdx), hiddenState)
+
+	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
+	ctx.Trace(fmt.Sprintf("model.layers.%d.post_attention_layernorm", layerIdx), hiddenState)
+
+	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
+	ctx.Trace(fmt.Sprintf("model.layers.%d.mlp", layerIdx), hiddenState)
+
+	output := hiddenState.Add(ctx, residual)
+	ctx.Trace(fmt.Sprintf("model.layers.%d.output", layerIdx), output)
+
+	return output
+}
+
+func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
+	slog.Debug("input tokens", "input_ids", opts.Inputs())
+	inputs, err := ctx.FromIntSlice(opts.Inputs(), len(opts.Inputs()))
+	if err != nil {
+		return nil, err
+	}
+
+	positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions()))
+	if err != nil {
+		return nil, err
+	}
+
+	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+	ctx.Trace("model.embed_tokens", hiddenState)
+
+	for i, layer := range m.Layers {
+		hiddenState = layer.Forward(ctx, hiddenState, positions, i, opts.Cache.Sub(i), m.Options)
+	}
+
+	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
+	ctx.Trace("model.norm", hiddenState)
+
+	hiddenState = m.Output.Forward(ctx, hiddenState)
+	ctx.Trace("model.output", hiddenState)
+
+	outputs, err := ctx.FromIntSlice(opts.Outputs(), len(opts.Outputs()))
+	if err != nil {
+		return nil, err
+	}
+
+	return hiddenState.Rows(ctx, outputs), nil
+}
+
+func init() {
+	model.Register("qwen2", New)
+}

+ 1 - 0
runner/newrunner/runner.go

@@ -32,6 +32,7 @@ import (
 
 	_ "github.com/ollama/ollama/model/llama"
 	_ "github.com/ollama/ollama/model/mllama"
+	_ "github.com/ollama/ollama/model/qwen2"
 )
 
 // input is an element of the prompt to process, either