瀏覽代碼

set non-causal attention

Michael Yang 1 月之前
父節點
當前提交
0df1800436
共有 6 個文件被更改,包括 57 次插入25 次删除
  1. 0 3
      convert/convert_gemma3.go
  2. 1 0
      ml/backend.go
  3. 14 0
      ml/backend/ggml/ggml.go
  4. 27 14
      model/models/gemma3/model.go
  5. 8 4
      model/models/gemma3/model_text.go
  6. 7 4
      model/process_text.go

+ 0 - 3
convert/convert_gemma3.go

@@ -58,9 +58,6 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
 	kv["gemma3.vision.num_channels"] = p.VisionModel.NumChannels
 	kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
 	kv["gemma3.vision.attention.layer_norm_epsilon"] = p.VisionModel.LayerNormEpsilon
-
-	kv["tokenizer.ggml.bos_token_id"] = uint32(2)
-	kv["tokenizer.ggml.eot_token_id"] = uint32(1)
 	return kv
 }
 

+ 1 - 0
ml/backend.go

@@ -148,6 +148,7 @@ type Tensor interface {
 	View(ctx Context, offset int, shape ...int) Tensor
 	Permute(ctx Context, shape ...int) Tensor
 	Contiguous(ctx Context) Tensor
+	Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor
 
 	Pad(ctx Context, shape ...int) Tensor
 	Unpad(ctx Context, shape ...int) Tensor

+ 14 - 0
ml/backend/ggml/ggml.go

@@ -954,6 +954,20 @@ func (t *Tensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
 	}
 }
 
+func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
+	var tt *C.struct_ggml_tensor
+	switch len(strides) {
+	case 0:
+		tt = C.ggml_set_1d_inplace(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
+	case 1:
+		tt = C.ggml_set_2d_inplace(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
+	default:
+		panic("unsupported number of dimensions")
+	}
+
+	return &Tensor{b: t.b, t: tt}
+}
+
 func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
 	var kqMask *C.struct_ggml_tensor
 	if mask != nil {

+ 27 - 14
model/models/gemma3/model.go

@@ -51,8 +51,10 @@ func New(c ml.Config) (model.Model, error) {
 				Types:  c.Uints("tokenizer.ggml.token_type"),
 				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
 				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
-				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
+				EOS:    int32(1),
 				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
+				EOT:    int32(106),
+				AddEOT: c.Bool("tokenizer.ggml.add_eot_token", false),
 			},
 		),
 		ImageProcessor: newImageProcessor(c),
@@ -109,35 +111,46 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu
 
 	for i := range inputs {
 		if inputs[i].Multimodal == nil {
-			if len(images) > 0 {
-				inputs[i].Multimodal = images[0].Multimodal
-				inputs[i].MultimodalHash = images[0].MultimodalHash
-				for j := 1; j < len(images); j++ {
+			for j := range images {
+				if j == 0 {
+					inputs[i].Multimodal = images[j].Multimodal
+					inputs[i].MultimodalHash = images[j].MultimodalHash
+				} else {
 					inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
 					fnvHash.Reset()
 					binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
-					binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
+					binary.Write(fnvHash, binary.NativeEndian, images[j].MultimodalHash)
 					inputs[i].MultimodalHash = fnvHash.Sum64()
 				}
-				images = nil
 			}
+
+			images = nil
 		} else {
 			images = append(images, inputs[i])
 			inputs[i].Token = -1
 		}
 	}
 
-	inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 })
+	for i := range inputs {
+		if inputs[i].Token == -1 {
+			imageInputs := []input.Input{
+				{Token: 108},    // "\n\n"
+				{Token: 255999}, // "<start_of_image>""
+			}
+
+			// <image_soft_token>
+			imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 262144}}, 256)...)
+			// <end_of_image>
+			imageInputs = append(imageInputs, input.Input{Token: 256000})
+
+			inputs = append(inputs[:i], append(imageInputs, inputs[i+1:]...)...)
+		}
+	}
 
 	return inputs, nil
 }
 
 func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
-	var embeddings ml.Tensor
-	if opts.Multimodal != nil {
-		embeddings = opts.Multimodal[0].Multimodal.(ml.Tensor)
-	}
-
 	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
 	if err != nil {
 		return nil, err
@@ -153,7 +166,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	return m.TextModel.Forward(ctx, inputs, positions, embeddings, outputs, m.Cache), nil
+	return m.TextModel.Forward(ctx, inputs, positions, outputs, opts.Multimodal, m.Cache), nil
 }
 
 func init() {

+ 8 - 4
model/models/gemma3/model_text.go

@@ -7,6 +7,7 @@ import (
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
 	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
 )
 
 type TextOptions struct {
@@ -165,12 +166,15 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
 	return hiddenState.Add(ctx, residual)
 }
 
-func (m *TextModel) Forward(ctx ml.Context, inputs, positions, embeddings, outputs ml.Tensor, cache kvcache.Cache) ml.Tensor {
-	if embeddings == nil {
-		embeddings = m.TokenEmbedding.Forward(ctx, inputs)
+func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, multimodal []input.MultimodalIndex, cache kvcache.Cache) ml.Tensor {
+	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+	if multimodal != nil {
+		visionOutputs := multimodal[0].Multimodal.(ml.Tensor)
+		offset := multimodal[0].Index - 1 - visionOutputs.Dim(1)
+		hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(0))
 	}
 
-	hiddenState := embeddings.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
+	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
 
 	if len(m.Layers) == gemma27BLayerCount {
 		m.TextOptions.largeModelScaling = true

+ 7 - 4
model/process_text.go

@@ -4,6 +4,7 @@ import (
 	"cmp"
 	"iter"
 	"log/slog"
+	"slices"
 	"strings"
 	"sync"
 
@@ -39,8 +40,8 @@ type Vocabulary struct {
 	Scores []float32
 	Merges []string
 
-	BOS, EOS       int32
-	AddBOS, AddEOS bool
+	BOS, EOS, EOT          int32
+	AddBOS, AddEOS, AddEOT bool
 
 	specialOnce sync.Once
 	special     []string
@@ -57,7 +58,7 @@ func (v *Vocabulary) Is(id int32, special Special) bool {
 	case SpecialBOS:
 		return id == v.BOS
 	case SpecialEOS:
-		return id == v.EOS
+		return id == v.EOS || id == v.EOT
 	default:
 		return false
 	}
@@ -85,7 +86,9 @@ func (v *Vocabulary) Decode(id int32) string {
 func (v *Vocabulary) SpecialVocabulary() []string {
 	v.specialOnce.Do(func() {
 		for i := range v.Values {
-			if v.Types[i] == TOKEN_TYPE_CONTROL {
+			if slices.Contains([]int{105, 106}, i) {
+				v.special = append(v.special, v.Values[i])
+			} else if v.Types[i] == TOKEN_TYPE_CONTROL {
 				v.special = append(v.special, v.Values[i])
 			}
 		}