Selaa lähdekoodia

backend: API to support full precision matmul

Most tensor backends try to optimize performance by using a lower
precision for matmuls. However, some operations (such as kq) on
some models are sensitive to this and require full precision.
Jesse Gross 2 kuukautta sitten
vanhempi
commit
d773b7d671
4 muutettua tiedostoa jossa 12 lisäystä ja 2 poistoa
  1. 1 0
      ml/backend.go
  2. 9 0
      ml/backend/ggml/ggml.go
  3. 1 1
      model/llama/model.go
  4. 1 1
      model/mllama/model_text.go

+ 1 - 0
ml/backend.go

@@ -66,6 +66,7 @@ type Tensor interface {
 	Add(ctx Context, t2 Tensor) Tensor
 	Mul(ctx Context, t2 Tensor) Tensor
 	Mulmat(ctx Context, t2 Tensor) Tensor
+	MulmatFullPrec(ctx Context, t2 Tensor) Tensor
 
 	Softmax(ctx Context) Tensor
 	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor

+ 9 - 0
ml/backend/ggml/ggml.go

@@ -421,6 +421,15 @@ func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	}
 }
 
+func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
+	mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
+	C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
+
+	return &Tensor{
+		t: mul,
+	}
+}
+
 func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
 	tt := (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
 	if b != nil {

+ 1 - 1
model/llama/model.go

@@ -80,7 +80,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 
-	kq := k.Mulmat(ctx, q)
+	kq := k.MulmatFullPrec(ctx, q)
 	kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
 	kq = kq.Softmax(ctx)
 

+ 1 - 1
model/mllama/model_text.go

@@ -37,7 +37,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas
 	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 
-	scores := key.Mulmat(ctx, query)
+	scores := key.MulmatFullPrec(ctx, query)
 	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
 
 	if mask != nil {