// model_text.go

package gemma3

import (
	"math"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
)
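// TextOptions holds the text-decoder hyperparameters, populated from the GGUF
// metadata in newTextModel below.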
type TextOptions struct {
	hiddenSize, numHeads, numKVHeads int
	attnKeyLen, attnValLen           int
	eps, ropeScale                   float32
	ropeLocalBase, ropeGlobalBase    float32
	finalLogitSoftcap                float32
	largeModelScaling                bool
}
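// TextModel is the Gemma 3 text decoder: token embeddings, a stack of
// transformer layers, a final RMSNorm, and the output (LM head) projection.
// The gguf tags name the tensors to load; "output,alt:token_embd" lets the
// LM head fall back to the tied token-embedding weights when the GGUF has no
// separate output tensor.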
type TextModel struct {
	model.Base
	model.SentencePieceModel

	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []TextLayer   `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	*TextOptions
}
const (
	// gemma27BLayerCount identifies the 27B variant by its layer count
	// (Gemma 3 27B has 62 decoder layers) so Forward can enable
	// largeModelScaling for it.
	gemma27BLayerCount = 62
)
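// Layer cache types: each layer writes to either the sliding-window (SWA)
// half or the causal half of the wrapped kv cache (see Forward below).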
const (
	cacheTypeSWA = iota
	cacheTypeCausal
)
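// newTextModel builds the text decoder from GGUF metadata: the SentencePiece
// tokenizer (vocabulary, scores, token types, BOS/EOS ids), one TextLayer per
// "block_count", and the hyperparameters collected in TextOptions.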
func newTextModel(c ml.Config) *TextModel {
	m := TextModel{
		SentencePieceModel: model.NewSentencePieceModel(
			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Scores: c.Floats("tokenizer.ggml.scores"),
				Types:  c.Uints("tokenizer.ggml.token_type"),
				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
			},
		),
		Layers: make([]TextLayer, c.Uint("block_count")),
		TextOptions: &TextOptions{
			hiddenSize:        int(c.Uint("embedding_length")),
			numHeads:          int(c.Uint("attention.head_count")),
			numKVHeads:        int(c.Uint("attention.head_count_kv")),
			attnKeyLen:        int(c.Uint("attention.key_length")),
			attnValLen:        int(c.Uint("attention.value_length")),
			eps:               c.Float("text.attention.layer_norm_rms_epsilon"),
			ropeLocalBase:     c.Float("text.rope.local.freq_base", 10000.0),
			ropeGlobalBase:    c.Float("text.rope.global.freq_base", 1000000.0),
			ropeScale:         c.Float("text.rope.freq_scale", 1.0),
			finalLogitSoftcap: c.Float("text.final_logit_softcapping"),
		},
	}

	return &m
}
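// TextSelfAttention holds one layer's attention projections. Note the
// RMSNorms applied to the query and key heads (QK-norm), which run before
// RoPE in Forward.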
type TextSelfAttention struct {
	Query     *nn.Linear  `gguf:"attn_q"`
	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
	Key       *nn.Linear  `gguf:"attn_k"`
	KeyNorm   *nn.RMSNorm `gguf:"attn_k_norm"`
	Value     *nn.Linear  `gguf:"attn_v"`
	Output    *nn.Linear  `gguf:"attn_output"`
}
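// Forward computes self-attention for one layer. Layers use the local RoPE
// frequency base except every sixth layer, which is a global-attention layer
// and uses the global base. Queries are pre-scaled here (by 1/sqrt(attnKeyLen),
// or 1/sqrt(hiddenSize/numHeads) for the 27B model), so nn.Attention is called
// with a scale factor of 1.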
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
	batchSize := hiddenState.Dim(1)
	ropeType := uint32(2)

	ropeBase := opts.ropeLocalBase
	if (layer+1)%6 == 0 {
		ropeBase = opts.ropeGlobalBase
	}

	q := sa.Query.Forward(ctx, hiddenState)
	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)

	if opts.largeModelScaling {
		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
	} else {
		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
	}

	k := sa.Key.Forward(ctx, hiddenState)
	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)

	v := sa.Value.Forward(ctx, hiddenState)
	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

	scaleFactor := 1.0
	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
	kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)

	return sa.Output.Forward(ctx, kqv)
}
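// Shift is the hook the kv cache uses to re-apply RoPE to cached keys when
// their positions change (for example when the context window slides). It
// mirrors the per-layer frequency-base selection used in Forward.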
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	ropeBase := m.TextOptions.ropeLocalBase
	if (layer+1)%6 == 0 {
		ropeBase = m.TextOptions.ropeGlobalBase
	}

	return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
}
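// TextMLP is the gated feed-forward block: down(GELU(gate(x)) * up(x)).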
type TextMLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, hiddenState)
}
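// TextLayer is one decoder block. Gemma 3 normalizes both before and after
// each sub-block: attn_norm/post_attention_norm around self-attention and
// ffn_norm/post_ffw_norm around the MLP, with the residual added after each
// post-norm (see Forward below).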
type TextLayer struct {
	AttentionNorm     *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention     *TextSelfAttention
	PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
	MLPNorm           *nn.RMSNorm `gguf:"ffn_norm"`
	MLP               *TextMLP
	PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
}
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
	return hiddenState.Add(ctx, residual)
}
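// Forward runs the full text decoder. If the caller does not supply
// precomputed input embeddings, token embeddings are looked up and scaled by
// sqrt(hiddenSize). The 27B variant is detected by its layer count to enable
// largeModelScaling, each layer is routed to the sliding-window or causal
// half of the wrapped kv cache, and the final logits are softcapped:
// logits = softcap * tanh(logits / softcap).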
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, embeddings, outputs ml.Tensor, cache kvcache.Cache) ml.Tensor {
	if embeddings == nil {
		embeddings = m.TokenEmbedding.Forward(ctx, inputs)
	}

	hiddenState := embeddings.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))

	if len(m.Layers) == gemma27BLayerCount {
		m.TextOptions.largeModelScaling = true
	}

	for i, layer := range m.Layers {
		// gemma alternates between the sliding window (local) and causal (global)
		// kv cache every 6 layers
		cacheType := cacheTypeSWA
		if (i+1)%6 == 0 {
			cacheType = cacheTypeCausal
		}
		cache.SetLayer(i)
		wc := cache.(*kvcache.WrapperCache)
		wc.SetLayerType(cacheType)

		var lastLayerOutputs ml.Tensor
		if i == len(m.Layers)-1 {
			lastLayerOutputs = outputs
		}

		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
	}

	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	hiddenState = m.Output.Forward(ctx, hiddenState)

	// final logit softcap
	hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextOptions.finalLogitSoftcap))
	hiddenState = hiddenState.Tanh(ctx)
	return hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap))
}