model_text.go 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. package gemma3
  2. import (
  3. "math"
  4. "github.com/ollama/ollama/kvcache"
  5. "github.com/ollama/ollama/ml"
  6. "github.com/ollama/ollama/ml/nn"
  7. "github.com/ollama/ollama/model"
  8. "github.com/ollama/ollama/model/input"
  9. )
  10. type TextOptions struct {
  11. hiddenSize, numHeads, numKVHeads int
  12. attnKeyLen, attnValLen int
  13. eps, ropeScale float32
  14. ropeLocalBase, ropeGlobalBase float32
  15. largeModelScaling bool
  16. }
  17. type TextModel struct {
  18. model.Base
  19. model.SentencePieceModel
  20. TokenEmbedding *nn.Embedding `gguf:"token_embd"`
  21. Layers []TextLayer `gguf:"blk"`
  22. OutputNorm *nn.RMSNorm `gguf:"output_norm"`
  23. Output *nn.Linear `gguf:"output,alt:token_embd"`
  24. *TextOptions
  25. }
  26. const (
  27. gemmaGlobalCacheCount = 6
  28. gemma27BLayerCount = 62
  29. )
  30. const (
  31. cacheTypeSWA = iota
  32. cacheTypeCausal
  33. )
  34. func newTextModel(c ml.Config) *TextModel {
  35. numBlocks := int(c.Uint("block_count"))
  36. m := TextModel{
  37. SentencePieceModel: model.NewSentencePieceModel(
  38. c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
  39. &model.Vocabulary{
  40. Values: c.Strings("tokenizer.ggml.tokens"),
  41. Scores: c.Floats("tokenizer.ggml.scores"),
  42. Types: c.Uints("tokenizer.ggml.token_type"),
  43. BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
  44. EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
  45. },
  46. ),
  47. Layers: make([]TextLayer, numBlocks),
  48. TextOptions: &TextOptions{
  49. hiddenSize: int(c.Uint("embedding_length")),
  50. numHeads: int(c.Uint("attention.head_count")),
  51. numKVHeads: int(c.Uint("attention.head_count_kv")),
  52. attnKeyLen: int(c.Uint("attention.key_length", 256)),
  53. attnValLen: int(c.Uint("attention.value_length", 256)),
  54. eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
  55. ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
  56. ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
  57. ropeScale: c.Float("rope.freq_scale", 1.0),
  58. },
  59. }
  60. if numBlocks == gemma27BLayerCount {
  61. m.largeModelScaling = true
  62. }
  63. return &m
  64. }
  65. type TextSelfAttention struct {
  66. Query *nn.Linear `gguf:"attn_q"`
  67. QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
  68. Key *nn.Linear `gguf:"attn_k"`
  69. KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
  70. Value *nn.Linear `gguf:"attn_v"`
  71. Output *nn.Linear `gguf:"attn_output"`
  72. }
  73. func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
  74. batchSize := hiddenState.Dim(1)
  75. ropeType := uint32(2)
  76. ropeBase := opts.ropeLocalBase
  77. if (layer+1)%gemmaGlobalCacheCount == 0 {
  78. ropeBase = opts.ropeGlobalBase
  79. }
  80. q := sa.Query.Forward(ctx, hiddenState)
  81. q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
  82. q = sa.QueryNorm.Forward(ctx, q, opts.eps)
  83. q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
  84. if opts.largeModelScaling {
  85. q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
  86. } else {
  87. q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
  88. }
  89. k := sa.Key.Forward(ctx, hiddenState)
  90. k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
  91. k = sa.KeyNorm.Forward(ctx, k, opts.eps)
  92. k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
  93. v := sa.Value.Forward(ctx, hiddenState)
  94. v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
  95. scaleFactor := 1.0
  96. kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
  97. kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
  98. return sa.Output.Forward(ctx, kqv)
  99. }
  100. func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
  101. ropeBase := m.TextOptions.ropeLocalBase
  102. if (layer+1)%gemmaGlobalCacheCount == 0 {
  103. ropeBase = m.TextOptions.ropeGlobalBase
  104. }
  105. return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
  106. }
  107. type TextMLP struct {
  108. Up *nn.Linear `gguf:"ffn_up"`
  109. Down *nn.Linear `gguf:"ffn_down"`
  110. Gate *nn.Linear `gguf:"ffn_gate"`
  111. }
  112. func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
  113. hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
  114. return mlp.Down.Forward(ctx, hiddenState)
  115. }
  116. type TextLayer struct {
  117. AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
  118. SelfAttention *TextSelfAttention
  119. PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
  120. MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
  121. MLP *TextMLP
  122. PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
  123. }
  124. func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
  125. residual := hiddenState
  126. hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
  127. hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
  128. hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
  129. // In the final layer (outputs != nil), optimize by pruning to just the token positions
  130. // we need logits for.
  131. if outputs != nil {
  132. hiddenState = hiddenState.Rows(ctx, outputs)
  133. residual = residual.Rows(ctx, outputs)
  134. }
  135. hiddenState = hiddenState.Add(ctx, residual)
  136. residual = hiddenState
  137. hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
  138. hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
  139. hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
  140. return hiddenState.Add(ctx, residual)
  141. }
  142. func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
  143. hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
  144. hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
  145. // set image embeddings
  146. var except []int
  147. for _, image := range opts.Multimodal {
  148. visionOutputs := image.Multimodal.(ml.Tensor)
  149. ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
  150. for i := range visionOutputs.Dim(1) {
  151. except = append(except, image.Index+i)
  152. }
  153. }
  154. for i, layer := range m.Layers {
  155. // gemma alternates between the sliding window (local) and causal (global)
  156. // kv cache every 6 layers
  157. cacheType := cacheTypeSWA
  158. if (i+1)%gemmaGlobalCacheCount == 0 {
  159. cacheType = cacheTypeCausal
  160. }
  161. cache.SetLayer(i)
  162. wc := cache.(*kvcache.WrapperCache)
  163. wc.SetLayerType(cacheType)
  164. if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
  165. causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
  166. }
  167. var lastLayerOutputs ml.Tensor
  168. if i == len(m.Layers)-1 {
  169. lastLayerOutputs = outputs
  170. }
  171. hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
  172. }
  173. hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
  174. return m.Output.Forward(ctx, hiddenState)
  175. }