Michael Yang 2 months ago
parent
commit
760e8fa656
1 file changed, 56 additions and 28 deletions
model/bert/model.go

@@ -1,7 +1,6 @@
 package bert

 import (
-	"fmt"
 	"math"

 	"github.com/ollama/ollama/ml"
@@ -13,21 +12,32 @@ func init() {
 	model.Register("bert", New)
 }

+type PoolingType int
+
+const (
+	PoolingTypeNone PoolingType = iota
+	PoolingTypeMean
+	PoolingTypeCLS
+	PoolingTypeLast
+	PoolingTypeRank
+)
+
 type Options struct {
 	hiddenSize, numHeads int64
 	eps                  float32
+	poolingType          PoolingType
 }

 type Model struct {
 	model.Base
 	model.BytePairEncoding

-	TokenEmbedding     *nn.Embedding `ggml:"token_embd"`
-	TypeEmbedding      *nn.Embedding `ggml:"type_embd,alt:token_types"`
-	PositionEmbedding  *nn.Embedding `ggml:"position_embd"`
-	TokenEmbeddingNorm *nn.LayerNorm `ggml:"token_embd_norm"`
+	TokenEmbedding     *nn.Embedding `gguf:"token_embd"`
+	TypeEmbedding      *nn.Embedding `gguf:"type_embd,alt:token_types"`
+	PositionEmbedding  *nn.Embedding `gguf:"position_embd"`
+	TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"`

-	Layers []EncoderLayer `ggml:"blk"`
+	Layers []EncoderLayer `gguf:"blk"`

 	*Options
 }
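The new PoolingType constants appear to mirror the pooling modes llama.cpp defines for embedding models (0 = none, 1 = mean, 2 = CLS, 3 = last, 4 = rank); only mean pooling is handled in Forward below. As a rough sketch on plain slices of what the single-token modes would select (this helper is hypothetical, not part of the change):

    // selectToken picks one token's embedding from a [seq][hidden] slice,
    // matching CLS/last pooling semantics. Hypothetical helper for clarity.
    func selectToken(hidden [][]float32, pt PoolingType) []float32 {
    	switch pt {
    	case PoolingTypeCLS:
    		return hidden[0] // BERT's [CLS] token sits at position 0
    	case PoolingTypeLast:
    		return hidden[len(hidden)-1]
    	default:
    		return nil // none/rank: no single-token selection applies
    	}
    }

The struct tags also change from ggml: to gguf: throughout, presumably tracking a rename of the struct-tag key used when mapping GGUF tensor names onto model fields.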
@@ -38,33 +48,49 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 	if err != nil {
 		return nil, err
 	}
-	fmt.Println("inputs", inputs.Shape(), ml.Dump(inputs))

 	types, err := ctx.FromIntSlice([]int32{0}, 1)
 	if err != nil {
 		return nil, err
 	}
-	fmt.Println("types", types.Shape(), ml.Dump(types))

 	positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions()))
 	if err != nil {
 		return nil, err
 	}
-	fmt.Println("positions", positions.Shape(), ml.Dump(positions))

 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
-	fmt.Println("TokenEmbedding.Forward", hiddenState.Shape(), ml.Dump(hiddenState))
-	return hiddenState, nil
 	hiddenState = hiddenState.Add(ctx, m.TypeEmbedding.Forward(ctx, types))
-	fmt.Println("TypeEmbedding.Forward", hiddenState.Shape(), ml.Dump(hiddenState))
 	hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positions))
-	fmt.Println("PositionEmbedding.Forward", hiddenState.Shape(), ml.Dump(hiddenState))
 	hiddenState = m.TokenEmbeddingNorm.Forward(ctx, hiddenState, m.eps)
-	fmt.Println("TokenEmbeddingNorm.Forward", hiddenState.Shape(), ml.Dump(hiddenState))

 	for i, layer := range m.Layers {
 		hiddenState = layer.Forward(ctx, hiddenState, positions, opts.Cache.Sub(i), m.Options)
-		fmt.Println("EncoderLayer.Forward", i, hiddenState.Shape(), ml.Dump(hiddenState))
+	}
+
+	switch m.poolingType {
+	case PoolingTypeMean:
+		sum := func(s []int32) (sum int32) {
+			for _, v := range s {
+				sum += v
+			}
+
+			return
+		}
+
+		// TODO: handle batch
+		f32s := make([]float32, len(opts.Positions())*len(opts.Positions()))
+		for i := range opts.Positions() {
+			f32s[i] = 1 / float32(sum(opts.Positions()))
+		}
+
+		means, err := ctx.FromFloatSlice(f32s, len(opts.Positions()), len(opts.Positions()))
+		if err != nil {
+			return nil, err
+		}
+
+		hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+		hiddenState = hiddenState.Mulmat(ctx, means)
 	}

 	return hiddenState, nil
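Conceptually, mean pooling reduces the n per-token vectors h_1 … h_n to (h_1 + … + h_n)/n; the branch above builds a weight matrix and applies it with Mulmat against the permuted hidden state. Note that as committed it appears to fill only the first n entries of f32s and divides by the sum of the position values rather than the token count, and the TODO says batching is unhandled. For reference, a minimal self-contained sketch of a conventional mean-pooling matrix product on plain slices (names are illustrative):

    // matMean multiplies an n×n averaging matrix (every entry 1/n) against the
    // n×d hidden state, so each output row is the mean over all token rows.
    func matMean(hidden [][]float32) [][]float32 {
    	n, d := len(hidden), len(hidden[0])
    	out := make([][]float32, n)
    	for i := range out {
    		out[i] = make([]float32, d)
    		for k := 0; k < n; k++ {
    			for j := 0; j < d; j++ {
    				out[i][j] += hidden[k][j] / float32(n)
    			}
    		}
    	}
    	return out
    }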
@@ -72,9 +98,9 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {

 type EncoderLayer struct {
 	*SelfAttention
-	MLPNorm *nn.LayerNorm `ggml:"attn_output_norm"`
+	MLPNorm *nn.LayerNorm `gguf:"attn_output_norm"`
 	*MLP
-	LayerOutputNorm *nn.LayerNorm `ggml:"ffn_output_norm"`
+	LayerOutputNorm *nn.LayerNorm `gguf:"layer_output_norm"`
 }

 func (e *EncoderLayer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor {
@@ -82,19 +108,19 @@ func (e *EncoderLayer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tenso

 	hiddenState = e.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
 	hiddenState = hiddenState.Add(ctx, residual)
+	hiddenState = e.MLPNorm.Forward(ctx, hiddenState, opts.eps)
 	residual = hiddenState

-	hiddenState = e.MLPNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
 	hiddenState = hiddenState.Add(ctx, residual)
 	return e.LayerOutputNorm.Forward(ctx, hiddenState, opts.eps)
 }

 type SelfAttention struct {
-	Query  *nn.Linear `ggml:"attn_q"`
-	Key    *nn.Linear `ggml:"attn_k"`
-	Value  *nn.Linear `ggml:"attn_v"`
-	Output *nn.Linear `ggml:"attn_output"`
+	Query  *nn.Linear `gguf:"attn_q"`
+	Key    *nn.Linear `gguf:"attn_k"`
+	Value  *nn.Linear `gguf:"attn_v"`
+	Output *nn.Linear `gguf:"attn_output"`
 }

 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor {
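Two behavioral changes sit in this hunk alongside the tag renames. First, LayerOutputNorm now loads from layer_output_norm instead of ffn_output_norm. Second, the MLPNorm call moves above residual = hiddenState, which matches BERT's post-LayerNorm layout: each sublayer's output is normalized before being captured as the residual input to the next sublayer. The per-layer dataflow after the reorder, written out as comments:

    // x = MLPNorm(x + SelfAttention(x))   // normalize first, then take the MLP residual
    // x = LayerOutputNorm(x + MLP(x))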
@@ -105,7 +131,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)

 	key := sa.Key.Forward(ctx, hiddenState)
-	key = key.Reshape(ctx, opts.numHeads, headDim, batchSize)
+	key = key.Reshape(ctx, headDim, opts.numHeads, batchSize)

 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numHeads, batchSize)
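This hunk fixes the key projection's reshape, whose first two dimensions were swapped relative to query and value. All three projections must split the hidden dimension head-first; with bert-base-like sizes (assumed here for illustration, not read from the model):

    // Assumed bert-base dimensions, for the shape arithmetic only.
    hiddenSize, numHeads := int64(768), int64(12)
    headDim := hiddenSize / numHeads // 64
    // Q, K, V each reshape to (headDim, numHeads, batch) = (64, 12, B);
    // the old (numHeads, headDim, batch) = (12, 64, B) order scrambled K's per-head layout.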
@@ -128,8 +154,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }

 type MLP struct {
-	Up   *nn.Linear `ggml:"ffn_up"`
-	Down *nn.Linear `ggml:"ffn_down"`
+	Up   *nn.Linear `gguf:"ffn_up"`
+	Down *nn.Linear `gguf:"ffn_down"`
 }

 func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
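The MLP tags move to gguf: as well. The body of MLP.Forward is elided by this hunk; a standard BERT feed-forward block computes Down(GELU(Up(x))), so a hedged sketch of that shape (whether this matches the elided body, and whether ml.Tensor exposes a GELU method, are assumptions):

    // Typical BERT FFN shape; assumed, not shown by this diff.
    func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
    	return mlp.Down.Forward(ctx, mlp.Up.Forward(ctx, hiddenState).GELU(ctx))
    }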
@@ -138,6 +164,7 @@ func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml

 func New(c ml.Config) (model.Model, error) {
 	return &Model{
+		Layers: make([]EncoderLayer, c.Uint("block_count")),
 		BytePairEncoding: model.NewBytePairEncoding(
 			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
 			&model.Vocabulary{
@@ -149,9 +176,10 @@ func New(c ml.Config) (model.Model, error) {
 			},
 		),
 		Options: &Options{
-			hiddenSize: int64(c.Uint("embedding_length")),
-			numHeads:   int64(c.Uint("attention.head_count")),
-			eps:        c.Float("attention.layer_norm_epsilon"),
+			hiddenSize:  int64(c.Uint("embedding_length")),
+			numHeads:    int64(c.Uint("attention.head_count")),
+			eps:         c.Float("attention.layer_norm_epsilon"),
+			poolingType: PoolingType(c.Uint("pooling_type")),
 		},
 	}, nil
 }
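For a bert-base-style GGUF, these keys would resolve to values along the following lines (illustrative numbers; the real ones come from the model's key/value metadata):

    // Illustrative only; actual values are read from GGUF metadata.
    opts := &Options{
    	hiddenSize:  768,             // embedding_length
    	numHeads:    12,              // attention.head_count
    	eps:         1e-12,           // attention.layer_norm_epsilon
    	poolingType: PoolingTypeMean, // pooling_type = 1
    }

The added Layers: make([]EncoderLayer, c.Uint("block_count")) line sizes the encoder stack from the same metadata (12 blocks for bert-base) so the gguf-tagged blk tensors have layers to load into.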