use fast attention

Michael Yang, 1 month ago
parent commit 8934324b72
3 changed files with 8 additions and 14 deletions
  1. ml/backend/ggml/ggml.go (+2 -2)
  2. model/models/gemma3/model.go (+2 -2)
  3. model/models/gemma3/model_vision.go (+4 -10)

ml/backend/ggml/ggml.go (+2 -2)

@@ -958,9 +958,9 @@ func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) m
 	var tt *C.struct_ggml_tensor
 	switch len(strides) {
 	case 0:
-		tt = C.ggml_set_1d_inplace(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
+		tt = C.ggml_set_1d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
 	case 1:
-		tt = C.ggml_set_2d_inplace(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
+		tt = C.ggml_set_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
 	default:
 		panic("unsupported number of dimensions")
 	}
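
The behavioral change in this file is dropping the _inplace variants: in ggml, the non-in-place set functions duplicate the destination tensor and write into the copy, while the _inplace forms write through a view of the destination. A minimal plain-Go sketch of that distinction, using ordinary slices instead of ggml tensors (the function names here are illustrative, not ggml API):

package main

import "fmt"

// setInPlace writes src into dst's own storage, mirroring the _inplace semantics.
func setInPlace(dst, src []float32, offset int) []float32 {
	copy(dst[offset:], src)
	return dst // shares dst's backing array
}

// set writes src into a fresh copy of dst, mirroring the non-in-place semantics.
func set(dst, src []float32, offset int) []float32 {
	out := append([]float32(nil), dst...)
	copy(out[offset:], src)
	return out // dst is left untouched
}

func main() {
	dst := []float32{0, 0, 0, 0}
	src := []float32{1, 2}
	fmt.Println(set(dst, src, 1), dst)        // [0 1 2 0] [0 0 0 0]
	fmt.Println(setInPlace(dst, src, 1), dst) // [0 1 2 0] [0 1 2 0]
}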

model/models/gemma3/model.go (+2 -2)

@@ -138,8 +138,8 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu
 				{Token: 255999}, // "<start_of_image>""
 			}
 
-			// <image_soft_token>
-			imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 262144}}, 256)...)
+			// pad inputs with placeholders for image embeddings
+			imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 0}}, 256)...)
 			// <end_of_image>
 			imageInputs = append(imageInputs, input.Input{Token: 256000})
 
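The placeholder padding leans on slices.Repeat from the Go standard library (Go 1.23+). A self-contained sketch of the sequence PostTokenize builds around each image, using a stand-in Input type rather than ollama's input.Input, with the marker tokens and the 256-slot placeholder count taken from the diff:

package main

import (
	"fmt"
	"slices"
)

// Input is a stand-in for ollama's input.Input; only the Token field matters here.
type Input struct{ Token int32 }

func main() {
	// "<start_of_image>" marker, then 256 placeholder slots that the image
	// embeddings will later occupy, then the "<end_of_image>" marker.
	imageInputs := []Input{{Token: 255999}}
	imageInputs = append(imageInputs, slices.Repeat([]Input{{Token: 0}}, 256)...)
	imageInputs = append(imageInputs, Input{Token: 256000})

	fmt.Println(len(imageInputs)) // 258
}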

model/models/gemma3/model_vision.go (+4 -10)

@@ -24,17 +24,11 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op
 	key := sa.Key.Forward(ctx, hiddenState)
 	value := sa.Value.Forward(ctx, hiddenState)
 
-	query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize).Permute(ctx, 0, 2, 1, 3)
-	key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize).Permute(ctx, 0, 2, 1, 3)
-	value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize).Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+	query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
+	key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
+	value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
 
-	scores := key.Mulmat(ctx, query)
-	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
-	scores = scores.Softmax(ctx)
-
-	attention := value.Mulmat(ctx, scores)
-	attention = attention.Reshape(ctx, headDim, attention.Dim(1), opts.numHeads, batchSize)
-	attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
 	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
 
 	hiddenState = sa.Output.Forward(ctx, attention)
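
The removed Mulmat/Scale/Softmax/Mulmat sequence spelled out standard scaled dot-product attention, softmax(Q·Kᵀ/√headDim)·V, which the single nn.Attention call now delegates to the backend (the commit title suggests a faster fused path). For reference, a single-head, plain-Go sketch of that computation, using slices instead of ggml tensors and not ollama's implementation:

package main

import (
	"fmt"
	"math"
)

// attention computes softmax(q·kᵀ / sqrt(headDim)) · v for one head.
// q, k, v have shape [seqLen][headDim].
func attention(q, k, v [][]float64, headDim int) [][]float64 {
	scale := 1.0 / math.Sqrt(float64(headDim))
	out := make([][]float64, len(q))
	for i := range q {
		// scores[j] = <q_i, k_j> * scale
		scores := make([]float64, len(k))
		maxScore := math.Inf(-1)
		for j := range k {
			var dot float64
			for d := 0; d < headDim; d++ {
				dot += q[i][d] * k[j][d]
			}
			scores[j] = dot * scale
			if scores[j] > maxScore {
				maxScore = scores[j]
			}
		}
		// numerically stable softmax over key positions
		var sum float64
		for j := range scores {
			scores[j] = math.Exp(scores[j] - maxScore)
			sum += scores[j]
		}
		// weighted sum of value rows
		out[i] = make([]float64, headDim)
		for j := range v {
			w := scores[j] / sum
			for d := 0; d < headDim; d++ {
				out[i][d] += w * v[j][d]
			}
		}
	}
	return out
}

func main() {
	q := [][]float64{{1, 0}, {0, 1}}
	k := [][]float64{{1, 0}, {0, 1}}
	v := [][]float64{{1, 2}, {3, 4}}
	fmt.Println(attention(q, k, v, 2))
}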