additional review comments

Jesse Gross, 1 month ago
commit 98272fbd58
2 files changed, 32 insertions(+), 16 deletions(-)
1. ml/backend/ggml/ggml.go (+20 -6)
2. model/models/mllama/model_text.go (+12 -10)

ml/backend/ggml/ggml.go (+20 -6)

@@ -402,7 +402,10 @@ func (b *Backend) NewContext() ml.Context {
 }
 
 func (b *Backend) NewContextSize(n int) ml.Context {
-	n = min(n, b.maxGraphNodes)
+	if n > b.maxGraphNodes {
+		panic(fmt.Errorf("requested number of graph nodes (%v) for new context exceeds maximum (%v)", n, b.maxGraphNodes))
+	}
+
 	return &Context{
 		b:             b,
 		maxGraphNodes: n,
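
The behavioral change here: an oversized request used to be clamped silently; now it is treated as a caller bug. A minimal, runnable sketch of the new contract, with maxGraphNodes as an illustrative constant standing in for the limit stored on Backend:

package main

import "fmt"

// Illustrative stand-in for the limit stored on Backend.
const maxGraphNodes = 8192

// Mirrors the new check: exceeding the cap is a programming error,
// so it panics instead of silently truncating the graph.
func newContextSize(n int) int {
	if n > maxGraphNodes {
		panic(fmt.Errorf("requested number of graph nodes (%v) for new context exceeds maximum (%v)", n, maxGraphNodes))
	}
	return n
}

func main() {
	fmt.Println(newContextSize(512)) // ok: within the limit
	// newContextSize(1 << 20)       // would panic rather than clamp to 8192
}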
@@ -534,7 +537,7 @@ func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
 		panic("unsupported dtype")
 	}
 
-	if len(shape) < 1 {
+	if len(shape) < 1 || shape[0] == 0 {
 		var shape C.int64_t = 0
 		return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
 	} else if len(shape) > 4 {
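
With the added shape[0] == 0 case, any zero-element request, not just a missing shape, takes the zero-length 1-D tensor path instead of reaching ggml with a zero dimension. A standalone sketch of the dispatch with the cgo calls stubbed out; describeShape is a hypothetical helper:

package main

import "fmt"

// Hypothetical cgo-free helper mirroring newTensor's shape dispatch.
func describeShape(shape []int) string {
	if len(shape) < 1 || shape[0] == 0 {
		// Both cases now collapse to a zero-length 1-D tensor.
		return "zero-length 1-D tensor"
	} else if len(shape) > 4 {
		return "unsupported: more than 4 dimensions (branch body elided in the diff)"
	}
	return fmt.Sprintf("regular tensor with shape %v", shape)
}

func main() {
	fmt.Println(describeShape(nil))         // zero-length 1-D tensor
	fmt.Println(describeShape([]int{0, 5})) // zero-length 1-D tensor (the new case)
	fmt.Println(describeShape([]int{2, 3})) // regular tensor with shape [2 3]
}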
@@ -565,6 +568,11 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
 
 func checkShape[S ~[]E, E any](s S, shape ...int) error {
 	n := len(s)
+
+	if n == 0 {
+		return nil
+	}
+
 	for _, v := range shape {
 		n /= v
 	}
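
Why the early return is needed: without it an empty slice can never validate, since n starts at 0, stays 0 through every division, and fails the final count check. A runnable sketch of the whole function under that assumption; the diff shows only the head and the division loop, so the trailing check and the error text below are guesses:

package main

import (
	"errors"
	"fmt"
)

func checkShape[S ~[]E, E any](s S, shape ...int) error {
	n := len(s)

	// New: an empty slice is always acceptable (it pairs with the
	// zero-length tensor path in newTensor).
	if n == 0 {
		return nil
	}

	for _, v := range shape {
		n /= v
	}

	// Assumed tail: the element count must divide out exactly.
	if n != 1 {
		return errors.New("shape does not match the number of elements")
	}

	return nil
}

func main() {
	fmt.Println(checkShape([]float32{1, 2, 3, 4, 5, 6}, 2, 3)) // <nil>: 6 == 2*3
	fmt.Println(checkShape([]float32{}, 2, 3))                 // <nil>: short-circuits
	fmt.Println(checkShape([]float32{1, 2}, 2, 3))             // error: 2 != 6
}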
@@ -577,22 +585,28 @@ func checkShape[S ~[]E, E any](s S, shape ...int) error {
 }
 
 func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
-	if err := checkShape(s, shape...); err != nil && len(shape) > 0 {
+	if err := checkShape(s, shape...); err != nil {
 		return nil, err
 	}
 
 	t := c.newTensor(ml.DTypeF32, shape)
-	C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
+	if len(s) > 0 {
+		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
+	}
+
 	return t, nil
 }
 
 func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
-	if err := checkShape(s, shape...); err != nil && len(shape) > 0 {
+	if err := checkShape(s, shape...); err != nil {
 		return nil, err
 	}
 
 	t := c.newTensor(ml.DTypeI32, shape)
-	C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
+	if len(s) > 0 {
+		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
+	}
+
 	return t, nil
 }
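
The len(s) > 0 guards matter because taking &s[0] of an empty slice panics at runtime in Go; with checkShape now accepting empty input, FromFloatSlice and FromIntSlice must be able to build zero-element tensors without touching the (nonexistent) backing array. A tiny illustration of the hazard:

package main

import "fmt"

// Taking the address of the first element is only legal when the
// slice is non-empty; &s[0] on an empty slice panics at runtime.
func firstElemPtr(s []float32) *float32 {
	if len(s) == 0 {
		return nil // mirrors the diff: skip the copy entirely
	}
	return &s[0]
}

func main() {
	fmt.Println(firstElemPtr([]float32{1, 2, 3}) != nil) // true
	fmt.Println(firstElemPtr(nil) != nil)                // false, and no panic
}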
 
 

model/models/mllama/model_text.go (+12 -10)

@@ -10,10 +10,11 @@ import (
 )
 
 type TextSelfAttention struct {
-	Query  *nn.Linear `gguf:"attn_q"`
-	Key    *nn.Linear `gguf:"attn_k"`
-	Value  *nn.Linear `gguf:"attn_v"`
-	Output *nn.Linear `gguf:"attn_output"`
+	Query       *nn.Linear `gguf:"attn_q"`
+	Key         *nn.Linear `gguf:"attn_k"`
+	Value       *nn.Linear `gguf:"attn_v"`
+	Output      *nn.Linear `gguf:"attn_output"`
+	RopeFactors ml.Tensor  `gguf:"rope_freqs.weight"`
 }
 
 func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
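
Moving RopeFactors onto TextSelfAttention leans on the same gguf:"..." tag convention as the existing fields: the loader binds struct fields to tensor names by tag, so the tensor keeps the name rope_freqs.weight while changing owners. A self-contained sketch of tag lookup via reflection; Tensor here is a stand-in for ml.Tensor, and the real loader lives elsewhere in the repo:

package main

import (
	"fmt"
	"reflect"
)

// Stand-in for ml.Tensor; only the struct tags matter for this sketch.
type Tensor struct{}

type TextSelfAttention struct {
	Query       Tensor `gguf:"attn_q"`
	RopeFactors Tensor `gguf:"rope_freqs.weight"`
}

func main() {
	// Walk the fields the way a tag-driven loader would.
	t := reflect.TypeOf(TextSelfAttention{})
	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		fmt.Printf("%s -> %s\n", f.Name, f.Tag.Get("gguf"))
	}
	// Query -> attn_q
	// RopeFactors -> rope_freqs.weight
}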
@@ -22,11 +23,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 
 	query := sa.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
 
 	key := sa.Key.Forward(ctx, hiddenState)
 	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
 
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -39,8 +40,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	// This will only get called for layers in the causal cache, which are just the self attention layers
-	return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
+	if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
+		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
+	}
+
+	return key, nil
 }
 
 type TextMLP struct {
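
The deleted comment assumed Shift would only ever see self-attention layers; the new body checks instead, re-rotating shifted keys only for layers that applied RoPE and returning everything else unchanged. A standalone sketch of that dispatch; the layer types are illustrative stand-ins for mllama's mix of self-attention and cross-attention decoder layers:

package main

import "fmt"

// Illustrative stand-ins for mllama's two decoder layer kinds.
type layer interface{ isLayer() }

type selfAttentionLayer struct{}

func (selfAttentionLayer) isLayer() {}

type crossAttentionLayer struct{}

func (crossAttentionLayer) isLayer() {}

// Mirrors the new Shift: only self-attention layers own RoPE factors,
// so only their cached keys need re-rotating after a cache shift.
func shift(layers []layer, i int) string {
	if _, ok := layers[i].(selfAttentionLayer); ok {
		return "re-apply RoPE to the shifted keys"
	}
	return "return the keys unchanged"
}

func main() {
	layers := []layer{selfAttentionLayer{}, crossAttentionLayer{}}
	fmt.Println(shift(layers, 0)) // re-apply RoPE to the shifted keys
	fmt.Println(shift(layers, 1)) // return the keys unchanged
}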
@@ -191,8 +195,6 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
 }
 
 type TextModelOptions struct {
-	RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
-
 	hiddenSize, numHeads, numKVHeads int
 	eps, ropeBase, ropeScale         float32
 	ropeDim                          uint32