model_text.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. package mllama
  2. import (
  3. "math"
  4. "slices"
  5. "github.com/ollama/ollama/kvcache"
  6. "github.com/ollama/ollama/ml"
  7. "github.com/ollama/ollama/ml/nn"
  8. )
  9. type TextSelfAttention struct {
  10. Query *nn.Linear `gguf:"attn_q"`
  11. Key *nn.Linear `gguf:"attn_k"`
  12. Value *nn.Linear `gguf:"attn_v"`
  13. Output *nn.Linear `gguf:"attn_output"`
  14. }
  15. func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  16. batchSize := hiddenState.Dim(1)
  17. headDim := opts.hiddenSize / opts.numHeads
  18. ropeType := uint32(0)
  19. query := sa.Query.Forward(ctx, hiddenState)
  20. query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
  21. query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
  22. key := sa.Key.Forward(ctx, hiddenState)
  23. key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
  24. key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
  25. value := sa.Value.Forward(ctx, hiddenState)
  26. value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
  27. cache.Put(ctx, key, value)
  28. key, value, mask := cache.Get(ctx)
  29. query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  30. key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  31. value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
  32. scores := key.MulmatFullPrec(ctx, query)
  33. scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
  34. scores = scores.Add(ctx, mask)
  35. scores = scores.Softmax(ctx)
  36. attention := value.Mulmat(ctx, scores)
  37. attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  38. attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
  39. return sa.Output.Forward(ctx, attention)
  40. }
  41. func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
  42. // This will only get called for layers in the cache, which are just the self attention layers
  43. return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
  44. }
  45. type TextMLP struct {
  46. Up *nn.Linear `gguf:"ffn_up"`
  47. Down *nn.Linear `gguf:"ffn_down"`
  48. Gate *nn.Linear `gguf:"ffn_gate"`
  49. }
  50. func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
  51. hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
  52. return mlp.Down.Forward(ctx, hiddenState)
  53. }
  54. type TextSelfAttentionDecoderLayer struct {
  55. AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
  56. SelfAttention *TextSelfAttention
  57. MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
  58. MLP *TextMLP
  59. }
  60. func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  61. residual := hiddenState
  62. hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
  63. hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)
  64. hiddenState = hiddenState.Add(ctx, residual)
  65. residual = hiddenState
  66. hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
  67. hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
  68. return hiddenState.Add(ctx, residual)
  69. }
  70. type TextCrossAttention struct {
  71. QueryNorm *nn.RMSNorm `gguf:"cross_attn_q_norm"`
  72. Query *nn.Linear `gguf:"cross_attn_q_proj"`
  73. KeyNorm *nn.RMSNorm `gguf:"cross_attn_k_norm"`
  74. Key *nn.Linear `gguf:"cross_attn_k_proj"`
  75. Value *nn.Linear `gguf:"cross_attn_v_proj"`
  76. Output *nn.Linear `gguf:"cross_attn_o_proj"`
  77. }
  78. func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  79. batchSize := hiddenState.Dim(1)
  80. headDim := opts.hiddenSize / opts.numHeads
  81. query := ca.Query.Forward(ctx, hiddenState)
  82. query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
  83. query = ca.QueryNorm.Forward(ctx, query, opts.eps)
  84. var key, value ml.Tensor
  85. if crossAttentionStates != nil {
  86. numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
  87. key = ca.Key.Forward(ctx, crossAttentionStates)
  88. key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
  89. key = ca.KeyNorm.Forward(ctx, key, opts.eps)
  90. value = ca.Value.Forward(ctx, crossAttentionStates)
  91. value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
  92. cache.Put(ctx, key, value)
  93. } else {
  94. key, value, _ = cache.Get(ctx)
  95. }
  96. query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  97. key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  98. value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
  99. scores := key.Mulmat(ctx, query)
  100. scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
  101. scores = scores.Softmax(ctx)
  102. attention := value.Mulmat(ctx, scores)
  103. attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  104. attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
  105. return ca.Output.Forward(ctx, attention)
  106. }
  107. type TextCrossAttentionDecoderLayer struct {
  108. AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
  109. CrossAttention *TextCrossAttention
  110. AttentionGate ml.Tensor `gguf:"cross_attn_attn_gate"`
  111. MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
  112. MLP *TextMLP
  113. MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
  114. }
  115. func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  116. residual := hiddenState
  117. hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
  118. hiddenState = d.CrossAttention.Forward(ctx, hiddenState, crossAttentionStates, cache, opts)
  119. hiddenState = hiddenState.Mul(ctx, d.AttentionGate.Tanh(ctx))
  120. hiddenState = hiddenState.Add(ctx, residual)
  121. residual = hiddenState
  122. hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
  123. hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
  124. hiddenState = hiddenState.Mul(ctx, d.MLPGate.Tanh(ctx))
  125. return hiddenState.Add(ctx, residual)
  126. }
  127. type TextDecoderLayer interface {
  128. Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
  129. }
  130. type TextDecoder struct {
  131. Layers []TextDecoderLayer
  132. }
  133. func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  134. for i, layer := range d.Layers {
  135. layerType := selfAttentionLayer
  136. if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
  137. layerType = crossAttentionLayer
  138. }
  139. cache.SetLayer(i)
  140. cache.SetLayerType(layerType)
  141. if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() {
  142. hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
  143. }
  144. }
  145. return hiddenState
  146. }
  147. type TextModelOptions struct {
  148. RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
  149. hiddenSize, numHeads, numKVHeads int
  150. eps, ropeBase, ropeScale float32
  151. ropeDim uint32
  152. crossAttentionLayers []uint32
  153. }
  154. type TextModel struct {
  155. TokenEmbedding *nn.Embedding `gguf:"token_embd"`
  156. Transformer *TextDecoder `gguf:"blk"`
  157. OutputNorm *nn.RMSNorm `gguf:"output_norm"`
  158. Output *nn.Linear `gguf:"output"`
  159. *TextModelOptions
  160. }
  161. func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
  162. hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
  163. hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
  164. hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
  165. return m.Output.Forward(ctx, hiddenState)
  166. }
  167. func newTextModel(c ml.Config) *TextModel {
  168. var decoderLayers []TextDecoderLayer
  169. for i := range c.Uint("block_count") {
  170. var textDecoderLayer TextDecoderLayer
  171. if slices.Contains(c.Uints("attention.cross_attention_layers"), i) {
  172. textDecoderLayer = &TextCrossAttentionDecoderLayer{}
  173. } else {
  174. textDecoderLayer = &TextSelfAttentionDecoderLayer{}
  175. }
  176. decoderLayers = append(decoderLayers, textDecoderLayer)
  177. }
  178. return &TextModel{
  179. Transformer: &TextDecoder{Layers: decoderLayers},
  180. TextModelOptions: &TextModelOptions{
  181. hiddenSize: int(c.Uint("embedding_length")),
  182. numHeads: int(c.Uint("attention.head_count")),
  183. numKVHeads: int(c.Uint("attention.head_count_kv")),
  184. eps: c.Float("attention.layer_norm_rms_epsilon"),
  185. ropeBase: c.Float("rope.freq_base"),
  186. ropeScale: c.Float("rope.freq_scale", 1),
  187. ropeDim: c.Uint("rope.dimension_count"),
  188. crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
  189. },
  190. }
  191. }