@@ -64,6 +64,7 @@ func New(c ml.Config) (model.Model, error) {
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
+ m.Cache.SetConfig(ml.CacheConfig{})
return &m, nil
}
@@ -84,7 +85,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
if opts.largeModelScaling {
- q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize / opts.numHeads)))
+ q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
} else {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
}
@@ -99,8 +100,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
cache.Put(ctx, k, v)
k, v, mask := cache.Get(ctx)
- q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
- k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+ q = q.Permute(ctx, 0, 2, 1, 3)
+ k = k.Permute(ctx, 0, 2, 1, 3)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := k.Mulmat(ctx, q)
@@ -144,12 +145,20 @@ type Layer struct {
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
}
-func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
+
+ // In the final layer (outputs != nil), optimize by pruning to just the token positions
+ // we need logits for.
+ if outputs != nil {
+ hiddenState = hiddenState.Rows(ctx, outputs)
+ residual = residual.Rows(ctx, outputs)
+ }
+
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
@@ -170,6 +179,11 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
return nil, err
}
+ outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
+ if err != nil {
+ return nil, err
+ }
+
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
@@ -182,7 +196,13 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
m.Cache.SetLayer(i)
wc := m.Cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
- hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
+
+ var lastLayerOutputs ml.Tensor
+ if i == len(m.Layers)-1 {
+ lastLayerOutputs = outputs
+ }
+
+ hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
@@ -192,12 +212,6 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
-
- outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
- if err != nil {
- return nil, err
- }
-
return hiddenState.Rows(ctx, outputs), nil
}
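
The net effect of the `Layer.Forward` and `Forward` changes above is that when `Outputs` marks only a few token positions, the final layer keeps just those rows, so the MLP, output norm, logit softcap, and output projection run over a handful of rows instead of the whole batch (during prompt processing, typically only the last position needs logits). A minimal, self-contained sketch of that row selection, using plain Go slices rather than the `ml.Tensor` API; `takeRows` here is a hypothetical stand-in for `Tensor.Rows`:

```go
package main

import "fmt"

// takeRows keeps only the rows (token positions) listed in outputs,
// mirroring the hiddenState.Rows(ctx, outputs) calls added in the last layer.
func takeRows(hidden [][]float32, outputs []int32) [][]float32 {
	pruned := make([][]float32, 0, len(outputs))
	for _, i := range outputs {
		pruned = append(pruned, hidden[i])
	}
	return pruned
}

func main() {
	// Hidden states for a 4-token batch with a hidden size of 3.
	hidden := [][]float32{
		{0.1, 0.2, 0.3},
		{0.4, 0.5, 0.6},
		{0.7, 0.8, 0.9},
		{1.0, 1.1, 1.2},
	}
	// Only the last position needs logits.
	outputs := []int32{3}
	fmt.Println(takeRows(hidden, outputs)) // [[1 1.1 1.2]]
}
```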