|
@@ -171,53 +171,20 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
|
|
return hiddenState.Add(ctx, residual)
|
|
return hiddenState.Add(ctx, residual)
|
|
}
|
|
}
|
|
|
|
|
|
-func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int {
|
|
|
|
- var embedding ml.Tensor
|
|
|
|
- var src, dst, length int
|
|
|
|
- var except []int
|
|
|
|
-
|
|
|
|
- for _, image := range multimodal {
|
|
|
|
- imageToken := image.Multimodal.(imageToken)
|
|
|
|
- imageSrc := imageToken.index
|
|
|
|
- imageDst := image.Index
|
|
|
|
-
|
|
|
|
- if embedding == nil {
|
|
|
|
- embedding = imageToken.embedding
|
|
|
|
- src = imageSrc
|
|
|
|
- dst = imageDst
|
|
|
|
- length = 1
|
|
|
|
- } else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
|
|
|
|
- src = imageSrc
|
|
|
|
- dst = imageDst
|
|
|
|
- length++
|
|
|
|
- } else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
|
|
|
|
- length++
|
|
|
|
- } else {
|
|
|
|
- visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
|
|
|
|
- ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
|
|
|
|
-
|
|
|
|
- embedding = imageToken.embedding
|
|
|
|
- src = imageSrc
|
|
|
|
- dst = imageDst
|
|
|
|
- length = 1
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- except = append(except, imageDst)
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- if embedding != nil {
|
|
|
|
- visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
|
|
|
|
- ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- return except
|
|
|
|
-}
|
|
|
|
-
|
|
|
|
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
|
|
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
|
|
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
|
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
|
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
|
|
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
|
|
|
|
|
|
- except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)
|
|
|
|
|
|
+ // set image embeddings
|
|
|
|
+ var except []int
|
|
|
|
+ for _, image := range opts.Multimodal {
|
|
|
|
+ visionOutputs := image.Multimodal.(ml.Tensor)
|
|
|
|
+ ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
|
|
|
+
|
|
|
|
+ for i := range visionOutputs.Dim(1) {
|
|
|
|
+ except = append(except, image.Index+i)
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
|
|
for i, layer := range m.Layers {
|
|
for i, layer := range m.Layers {
|
|
// gemma alternates between the sliding window (local) and causal (global)
|
|
// gemma alternates between the sliding window (local) and causal (global)
|