model_text.go

package gemma3

import (
    "math"

    "github.com/ollama/ollama/kvcache"
    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/ml/nn"
    "github.com/ollama/ollama/model"
    "github.com/ollama/ollama/model/input"
)
// TextOptions holds the text-model hyperparameters read from the GGUF metadata.
type TextOptions struct {
    hiddenSize, numHeads, numKVHeads int
    attnKeyLen, attnValLen           int
    eps, ropeScale                   float32
    ropeLocalBase, ropeGlobalBase    float32
    finalLogitSoftcap                float32
    largeModelScaling                bool
}
// TextModel is the Gemma 3 text decoder: token embedding, a stack of
// transformer layers, and the output projection (the alt:token_embd tag lets
// the output head reuse the embedding weights when no separate output tensor
// is present).
type TextModel struct {
    model.Base
    model.SentencePieceModel

    TokenEmbedding *nn.Embedding `gguf:"token_embd"`
    Layers         []TextLayer   `gguf:"blk"`
    OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
    Output         *nn.Linear    `gguf:"output,alt:token_embd"`

    *TextOptions
}
const (
    // Every sixth layer uses the global (causal) KV cache and the global RoPE
    // base; the remaining layers use the sliding-window (local) cache.
    gemmaGlobalCacheCount = 6

    // The 27B variant is identified by its block count and uses a different
    // query scale (see largeModelScaling).
    gemma27BLayerCount = 62
)

const (
    cacheTypeSWA = iota
    cacheTypeCausal
)
// newTextModel builds the text model from GGUF metadata: the SentencePiece
// tokenizer, the layer stack, and the attention/RoPE hyperparameters with
// their Gemma 3 defaults.
func newTextModel(c ml.Config) *TextModel {
    numBlocks := int(c.Uint("block_count"))

    m := TextModel{
        SentencePieceModel: model.NewSentencePieceModel(
            c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
            &model.Vocabulary{
                Values: c.Strings("tokenizer.ggml.tokens"),
                Scores: c.Floats("tokenizer.ggml.scores"),
                Types:  c.Uints("tokenizer.ggml.token_type"),
                BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
                EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
            },
        ),
        Layers: make([]TextLayer, numBlocks),
        TextOptions: &TextOptions{
            hiddenSize:        int(c.Uint("embedding_length")),
            numHeads:          int(c.Uint("attention.head_count")),
            numKVHeads:        int(c.Uint("attention.head_count_kv")),
            attnKeyLen:        int(c.Uint("attention.key_length", 256)),
            attnValLen:        int(c.Uint("attention.value_length", 256)),
            eps:               c.Float("attention.layer_norm_rms_epsilon", 1e-06),
            ropeLocalBase:     c.Float("rope.local.freq_base", 10000.0),
            ropeGlobalBase:    c.Float("rope.global.freq_base", 1000000.0),
            ropeScale:         c.Float("rope.freq_scale", 1.0),
            finalLogitSoftcap: c.Float("final_logit_softcapping", 30.0),
        },
    }

    if numBlocks == gemma27BLayerCount {
        m.largeModelScaling = true
    }

    return &m
}
type TextSelfAttention struct {
    Query     *nn.Linear  `gguf:"attn_q"`
    QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
    Key       *nn.Linear  `gguf:"attn_k"`
    KeyNorm   *nn.RMSNorm `gguf:"attn_k_norm"`
    Value     *nn.Linear  `gguf:"attn_v"`
    Output    *nn.Linear  `gguf:"attn_output"`
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
    batchSize := hiddenState.Dim(1)
    ropeType := uint32(2)

    // Global-attention layers use a different RoPE frequency base than the
    // sliding-window layers.
    ropeBase := opts.ropeLocalBase
    if (layer+1)%gemmaGlobalCacheCount == 0 {
        ropeBase = opts.ropeGlobalBase
    }

    q := sa.Query.Forward(ctx, hiddenState)
    q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
    q = sa.QueryNorm.Forward(ctx, q, opts.eps)
    q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)

    // The query is pre-scaled here, so nn.Attention below runs with a scale
    // factor of 1. The 27B model scales by the per-head hidden size instead of
    // the key length.
    if opts.largeModelScaling {
        q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
    } else {
        q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
    }

    k := sa.Key.Forward(ctx, hiddenState)
    k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
    k = sa.KeyNorm.Forward(ctx, k, opts.eps)
    k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)

    v := sa.Value.Forward(ctx, hiddenState)
    v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

    scaleFactor := 1.0
    kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
    kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)

    return sa.Output.Forward(ctx, kqv)
}
// Shift re-applies RoPE to cached keys when the cache shifts positions, using
// the same per-layer frequency base as Forward.
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
    ropeBase := m.TextOptions.ropeLocalBase
    if (layer+1)%gemmaGlobalCacheCount == 0 {
        ropeBase = m.TextOptions.ropeGlobalBase
    }

    return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
}
type TextMLP struct {
    Up   *nn.Linear `gguf:"ffn_up"`
    Down *nn.Linear `gguf:"ffn_down"`
    Gate *nn.Linear `gguf:"ffn_gate"`
}

// Forward computes the gated feed-forward block: down(gelu(gate(x)) * up(x)).
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
    hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
    return mlp.Down.Forward(ctx, hiddenState)
}
type TextLayer struct {
    AttentionNorm     *nn.RMSNorm `gguf:"attn_norm"`
    SelfAttention     *TextSelfAttention
    PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
    MLPNorm           *nn.RMSNorm `gguf:"ffn_norm"`
    MLP               *TextMLP
    PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
}
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
    residual := hiddenState

    hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
    hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
    hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)

    // In the final layer (outputs != nil), optimize by pruning to just the
    // token positions we need logits for.
    if outputs != nil {
        hiddenState = hiddenState.Rows(ctx, outputs)
        residual = residual.Rows(ctx, outputs)
    }

    hiddenState = hiddenState.Add(ctx, residual)
    residual = hiddenState

    hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
    hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
    hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)

    return hiddenState.Add(ctx, residual)
}
// setImageEmbeddings copies vision embeddings into hiddenState at their
// multimodal token positions, coalescing adjacent image tokens into one copy
// per contiguous run. It returns the destination indices so Forward can exempt
// them from the causal mask.
func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int {
    var embedding ml.Tensor
    var src, dst, length int
    var except []int

    for _, image := range multimodal {
        imageToken := image.Multimodal.(imageToken)
        imageSrc := imageToken.index
        imageDst := image.Index

        if embedding == nil {
            // Start the first run.
            embedding = imageToken.embedding
            src = imageSrc
            dst = imageDst
            length = 1
        } else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
            // Extend the current run backward by one position.
            src = imageSrc
            dst = imageDst
            length++
        } else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
            // Extend the current run forward by one position.
            length++
        } else {
            // Flush the completed run, then start a new one.
            visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
            ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))

            embedding = imageToken.embedding
            src = imageSrc
            dst = imageDst
            length = 1
        }

        except = append(except, imageDst)
    }

    // Flush the final run, if any.
    if embedding != nil {
        visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
        ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
    }

    return except
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
    hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
    hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))

    except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)

    for i, layer := range m.Layers {
        // gemma alternates between the sliding window (local) and causal (global)
        // kv cache every 6 layers
        cacheType := cacheTypeSWA
        if (i+1)%gemmaGlobalCacheCount == 0 {
            cacheType = cacheTypeCausal
        }
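        // Concretely, with gemmaGlobalCacheCount = 6 the 0-indexed layers
        // 5, 11, 17, ... are the global (causal) ones; those same layers use
        // ropeGlobalBase in self-attention and Shift, while every other layer
        // uses the sliding-window cache and ropeLocalBase.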
        cache.SetLayer(i)
        wc := cache.(*kvcache.WrapperCache)
        wc.SetLayerType(cacheType)

        // Image token positions (except) are excluded from the causal mask.
        if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
            causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
        }

        var lastLayerOutputs ml.Tensor
        if i == len(m.Layers)-1 {
            lastLayerOutputs = outputs
        }

        hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
    }

    hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
    hiddenState = m.Output.Forward(ctx, hiddenState)

    // final logit softcap: logits = softcap * tanh(logits / softcap)
    hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextOptions.finalLogitSoftcap))
    hiddenState = hiddenState.Tanh(ctx)
    return hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap))
}
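
// softcapScalar is an illustrative sketch only; the function and its name are
// not part of the model. It shows the scalar operation that the
// Scale → Tanh → Scale sequence at the end of Forward applies element-wise:
// softcap(x) = limit * tanh(x/limit), which smoothly bounds each logit to the
// interval (-limit, limit), with limit = finalLogitSoftcap (30.0 by default).
func softcapScalar(x, limit float64) float64 {
    return limit * math.Tanh(x/limit)
}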