
ml: let model specify rope configuration

Add support for model-specific RoPE configuration parameters by:

1. Creating a new `RopeConfig` struct to encapsulate all RoPE parameters
2. Adding `RopeType` enum to specify different RoPE variants (Standard/NeoX)
3. Extracting original context length from model config
4. Refactoring `RoPE()` interface to use the new config struct
5. Updating llama and mllama models to use new RoPE configuration

This change lets models specify their RoPE implementation type and
original context length, both of which are needed for correct position
embedding calculation and model compatibility.
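
For illustration, a minimal sketch of the new call shape. The helper name, import path, and all argument values below are assumptions made for the example, not part of this change; the real wiring is shown in the llama and mllama diffs further down.

```go
package example

// Import path assumed from the repository layout; adjust if it differs.
import "github.com/ollama/ollama/ml"

// applyRope is a hypothetical helper showing the new RoPE call shape.
// Previously each call site passed the parameters positionally:
//
//	q.RoPE(ctx, positionIDs, ropeFactors, ropeDim, ropeBase, ropeScale)
func applyRope(ctx ml.Context, q, k, positionIDs, ropeFactors ml.Tensor,
	ropeDim uint32, origCtxLen int, ropeBase, ropeScale float32) (ml.Tensor, ml.Tensor) {
	rc := ml.RopeConfig{
		PositionIDs: positionIDs,
		RopeFactors: ropeFactors, // optional; the ggml backend substitutes an empty tensor when nil
		RopeDim:     ropeDim,
		RopeType:    ml.RopeTypeStandard,
		OrigCtxLen:  origCtxLen, // context length the model was trained with
		RopeBase:    ropeBase,
		RopeScale:   ropeScale,
	}

	// The same config is reused for the query and key projections.
	return q.RoPE(ctx, rc), k.RoPE(ctx, rc)
}
```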
Bruce MacDonald · 2 months ago · commit 8815a8ee25
5 changed files with 104 additions and 28 deletions:

  1. kvcache/causal_test.go (+1 -1)
  2. ml/backend.go (+37 -1)
  3. ml/backend/ggml/ggml.go (+12 -13)
  4. model/models/llama/model.go (+28 -7)
  5. model/models/mllama/model_text.go (+26 -6)

kvcache/causal_test.go (+1 -1)

@@ -430,7 +430,7 @@ func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0
 	panic("not implemented")
 	panic("not implemented")
 }
 }
 
 
-func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim uint32, base, scale float32) ml.Tensor {
+func (t *testTensor) RoPE(ctx ml.Context, rc ml.RopeConfig) ml.Tensor {
 	panic("not implemented")
 	panic("not implemented")
 }
 }
 
 

ml/backend.go (+37 -1)

@@ -43,6 +43,42 @@ func NewBackend(f *os.File) (Backend, error) {
 	return nil, fmt.Errorf("unsupported backend")
 }
 
+// RopeType specifies the type of RoPE (Rotary Position Embedding) to use, these types are implemented in the backend
+type RopeType int
+
+const (
+	RopeTypeStandard RopeType = iota
+	_                         // not yet used
+	RopeTypeNeoX
+)
+
+// RopeConfig contains all configuration for the RoPE (Rotary Position Embedding) operation
+type RopeConfig struct {
+	// PositionIDs contains the position indices for each token in the sequence
+	// These indices are used to calculate the rotary embeddings
+	PositionIDs Tensor
+
+	// RopeFactors is an optional tensor containing pre-computed rotation factors
+	RopeFactors Tensor
+
+	// RopeDim specifies the dimension size for the rotary embeddings
+	RopeDim uint32
+
+	// RopeType indicates which RoPE variant to use (e.g. normal or neox)
+	RopeType RopeType
+
+	// OrigCtxLen stores the original context length the model was trained with
+	OrigCtxLen int
+
+	// RopeBase is the base value used in the frequency calculation
+	RopeBase float32
+
+	// RopeScale is a scaling factor applied to position indices
+	RopeScale float32
+
+	// YaRN parameters can be added here if they need to be configurable
+}
+
 type Context interface {
 	Zeros(dtype DType, shape ...int) Tensor
 	FromFloatSlice(s []float32, shape ...int) (Tensor, error)
@@ -75,7 +111,7 @@ type Tensor interface {
 	Scale(ctx Context, s float64) Tensor
 
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
-	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor
+	RoPE(ctx Context, rc RopeConfig) Tensor
 
 	Tanh(ctx Context) Tensor
 	GELU(ctx Context) Tensor
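
One detail worth calling out in the const block above: the blank identifier skips a value so that RopeTypeNeoX ends up as 2, which appears intended to line up with ggml's rope mode numbering (NeoX-style RoPE is mode 2 there), since the backend passes the value straight through as C.int(rc.RopeType). A standalone sketch, with the enum mirrored locally so it runs on its own:

```go
package main

import "fmt"

// Local mirror of ml.RopeType from the diff above, reproduced here only so
// this snippet is self-contained.
type RopeType int

const (
	RopeTypeStandard RopeType = iota // 0
	_                                // 1 is skipped (not yet used)
	RopeTypeNeoX                     // 2; assumed to match ggml's NeoX rope mode
)

func main() {
	fmt.Println(RopeTypeStandard, RopeTypeNeoX) // prints: 0 2
}
```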

ml/backend/ggml/ggml.go (+12 -13)

@@ -579,13 +579,9 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	}
 }
 
-const (
-	ropeTypeNorm C.int = iota
-)
-
-func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
-	if ropeFactors == nil {
-		ropeFactors = &Tensor{}
+func (t *Tensor) RoPE(ctx ml.Context, rc ml.RopeConfig) ml.Tensor {
+	if rc.RopeFactors == nil {
+		rc.RopeFactors = &Tensor{}
 	}
 
 	dequant := t.t
@@ -595,12 +591,15 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
 
 	return &Tensor{
 		t: C.ggml_rope_ext(
-			ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
-			C.int(ropeDim),
-			131072,       // YaRN n_ctx_train
-			ropeTypeNorm, // ROPE_TYPE_NORM
-			C.float(ropeBase),
-			C.float(ropeScale),
+			ctx.(*Context).ctx,
+			dequant,
+			rc.PositionIDs.(*Tensor).t,
+			rc.RopeFactors.(*Tensor).t,
+			C.int(rc.RopeDim),
+			C.int(rc.RopeType),
+			C.int(rc.OrigCtxLen),
+			C.float(rc.RopeBase),
+			C.float(rc.RopeScale),
 			0.,  // YaRN ext_factor
 			1.,  // YaRN attn_factor
 			32., // YaRN beta_fast

model/models/llama/model.go (+28 -7)

@@ -10,10 +10,10 @@ import (
 )
 
 type Options struct {
-	RopeFactors                      ml.Tensor `gguf:"rope_freqs.weight"`
-	hiddenSize, numHeads, numKVHeads int
-	eps, ropeBase, ropeScale         float32
-	ropeDim                          uint32
+	RopeFactors                              ml.Tensor `gguf:"rope_freqs.weight"`
+	ctxLen, hiddenSize, numHeads, numKVHeads int
+	eps, ropeBase, ropeScale                 float32
+	ropeDim                                  uint32
 }
 
 type Model struct {
@@ -46,6 +46,7 @@ func New(c ml.Config) (model.Model, error) {
 			numHeads:   int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			eps:        c.Float("attention.layer_norm_rms_epsilon"),
+			ctxLen:     int(c.Uint("context_length")),
 			ropeBase:   c.Float("rope.freq_base"),
 			ropeScale:  c.Float("rope.freq_scale", 1),
 			ropeDim:    c.Uint("rope.dimension_count"),
@@ -67,14 +68,23 @@ type SelfAttention struct {
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
+	rc := ml.RopeConfig{
+		PositionIDs: positionIDs,
+		RopeFactors: opts.RopeFactors,
+		RopeDim:     opts.ropeDim,
+		RopeType:    ml.RopeTypeStandard,
+		OrigCtxLen:  opts.ctxLen,
+		RopeBase:    opts.ropeBase,
+		RopeScale:   opts.ropeScale,
+	}
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, rc)
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, rc)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -99,7 +109,18 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
+	return key.RoPE(
+		ctx,
+		ml.RopeConfig{
+			PositionIDs: shift,
+			RopeFactors: m.Options.RopeFactors,
+			RopeDim:     m.Options.ropeDim,
+			RopeType:    ml.RopeTypeStandard,
+			OrigCtxLen:  m.Options.ctxLen,
+			RopeBase:    m.Options.ropeBase,
+			RopeScale:   m.Options.ropeScale,
+		},
+	), nil
 }
 
 type MLP struct {

model/models/mllama/model_text.go (+26 -6)

@@ -19,14 +19,23 @@ type TextSelfAttention struct {
 func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
+	rc := ml.RopeConfig{
+		PositionIDs: positions,
+		RopeFactors: opts.RopeFactors,
+		RopeDim:     opts.ropeDim,
+		RopeType:    ml.RopeTypeStandard,
+		OrigCtxLen:  opts.ctxLen,
+		RopeBase:    opts.ropeBase,
+		RopeScale:   opts.ropeScale,
+	}
 
 	query := sa.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	query = query.RoPE(ctx, rc)
 
 	key := sa.Key.Forward(ctx, hiddenState)
 	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	key = key.RoPE(ctx, rc)
 
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -52,7 +61,18 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	// This will only get called for layers in the cache, which are just the self attention layers
-	return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
+	return key.RoPE(
+		ctx,
+		ml.RopeConfig{
+			PositionIDs: shift,
+			RopeFactors: m.RopeFactors,
+			RopeDim:     m.ropeDim,
+			RopeType:    ml.RopeTypeStandard,
+			OrigCtxLen:  m.ctxLen,
+			RopeBase:    m.ropeBase,
+			RopeScale:   m.ropeScale,
+		},
+	), nil
 }
 
 type TextMLP struct {
@@ -189,9 +209,9 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, cr
 type TextModelOptions struct {
 	RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
 
-	hiddenSize, numHeads, numKVHeads int
-	eps, ropeBase, ropeScale         float32
-	ropeDim                          uint32
+	ctxLen, hiddenSize, numHeads, numKVHeads int
+	eps, ropeBase, ropeScale                 float32
+	ropeDim                                  uint32
 
 	crossAttentionLayers []uint32
 }