@@ -173,25 +173,54 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
 	return hiddenState.Add(ctx, residual)
 }
 
-func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, multimodal []input.MultimodalIndex, cache kvcache.Cache) ml.Tensor {
-	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
-	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
-
-	if multimodal != nil {
-		visionOutputs := multimodal[0].Multimodal.(ml.Tensor)
-		offset := multimodal[0].Index - 1 - visionOutputs.Dim(1)
-		hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(1))
+func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex, positions []int32) []int32 {
+	var embedding ml.Tensor
+	var src, dst, length int
+	var except []int32
+
+	for _, image := range multimodal {
+		imageToken := image.Multimodal.(imageToken)
+		imageSrc := imageToken.index
+		imageDst := image.Index
+
+		if embedding == nil {
+			embedding = imageToken.embedding
+			src = imageSrc
+			dst = imageDst
+			length = 1
+		} else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
+			src = imageSrc
+			dst = imageDst
+			length++
+		} else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
+			length++
+		} else {
+			visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
+			ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
+
+			embedding = imageToken.embedding
+			src = imageSrc
+			dst = imageDst
+			length = 1
+		}
 
-		if causal, ok := cache.(*kvcache.WrapperCache).UnderlyingCache().(*kvcache.Causal); ok {
-			except := make([]int32, visionOutputs.Dim(1))
-			for i := 0; i < visionOutputs.Dim(1); i++ {
-				except[i] = int32(offset + i)
-			}
+		except = append(except, positions[imageDst])
+	}
 
-			causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
-		}
+	if embedding != nil {
+		visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
+		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
 	}
 
+	return except
+}
+
+func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
+	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
+
+	except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal, opts.Positions)
+
 	for i, layer := range m.Layers {
 		// gemma alternates between the sliding window (local) and causal (global)
 		// kv cache every 6 layers
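
The heart of the new setImageEmbeddings is a run-length coalescing pass: image tokens whose source rows in the vision embedding and destination rows in the hidden state are contiguous, and that come from the same embedding tensor, are batched into a single View/Copy rather than one copy per token. Below is a minimal standalone sketch of that coalescing on plain integers; span and coalesce are illustrative names that do not appear in the patch, and the same-tensor check is omitted.

package main

import "fmt"

// span is an illustrative stand-in for one batched copy: length
// contiguous rows starting at row src of the vision embedding and
// row dst of the hidden state.
type span struct{ src, dst, length int }

// coalesce mirrors the run-length logic of setImageEmbeddings on
// plain ints: a (src, dst) pair that extends the current run at
// either end grows it; anything else flushes the run and starts a
// new one.
func coalesce(pairs [][2]int) []span {
	var spans []span
	var cur *span
	for _, p := range pairs {
		src, dst := p[0], p[1]
		switch {
		case cur == nil:
			cur = &span{src, dst, 1}
		case src+1 == cur.src && dst+1 == cur.dst:
			// the new pair immediately precedes the run: move its start back
			cur.src, cur.dst = src, dst
			cur.length++
		case cur.src+cur.length == src && cur.dst+cur.length == dst:
			// the new pair immediately follows the run: just grow it
			cur.length++
		default:
			spans = append(spans, *cur)
			cur = &span{src, dst, 1}
		}
	}
	if cur != nil {
		spans = append(spans, *cur)
	}
	return spans
}

func main() {
	// four contiguous image tokens plus one stray collapse into two copies
	fmt.Println(coalesce([][2]int{{0, 5}, {1, 6}, {2, 7}, {3, 8}, {10, 20}}))
	// prints [{0 5 4} {10 20 1}]
}

The middle two branches mirror the patch's two extension cases: a token that immediately precedes the current run moves the run's start back, while one that immediately follows it only grows length.
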
@@ -203,6 +232,10 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 		wc := cache.(*kvcache.WrapperCache)
 		wc.SetLayerType(cacheType)
 
+		if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
+			causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
+		}
+
 		var lastLayerOutputs ml.Tensor
 		if i == len(m.Layers)-1 {
 			lastLayerOutputs = outputs
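
The positions collected in except feed the per-layer SetCausal call above, exempting image-token positions from the causal mask. As a rough illustration of that exemption idea only, and not of the real kvcache.Causal masking code, a hypothetical buildMask could look like this:

package main

import "fmt"

// buildMask is an illustration under assumed semantics: query position
// q may attend to key position k when k <= q (the usual causal rule) or
// when k is listed in except, so exempted image-token positions stay
// visible regardless of order. The actual kvcache.Causal mask is built
// differently; this only shows the effect of the Except option.
func buildMask(n int, except []int32) [][]bool {
	exempt := make(map[int32]bool, len(except))
	for _, p := range except {
		exempt[p] = true
	}
	mask := make([][]bool, n)
	for q := range mask {
		mask[q] = make([]bool, n)
		for k := range mask[q] {
			mask[q][k] = k <= q || exempt[int32(k)]
		}
	}
	return mask
}

func main() {
	// positions 1 and 2 hold image tokens: every query can see them
	for _, row := range buildMask(4, []int32{1, 2}) {
		fmt.Println(row)
	}
}
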