// model_text.go

  1. package mllama
  2. import (
  3. "math"
  4. "slices"
  5. "github.com/ollama/ollama/kvcache"
  6. "github.com/ollama/ollama/ml"
  7. "github.com/ollama/ollama/ml/nn"
  8. )
  9. type TextSelfAttention struct {
  10. Query *nn.Linear `gguf:"attn_q"`
  11. Key *nn.Linear `gguf:"attn_k"`
  12. Value *nn.Linear `gguf:"attn_v"`
  13. Output *nn.Linear `gguf:"attn_output"`
  14. }
  15. func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  16. batchSize := hiddenState.Dim(1)
  17. headDim := opts.hiddenSize / opts.numHeads
  18. query := sa.Query.Forward(ctx, hiddenState)
  19. query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
  20. query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
  21. key := sa.Key.Forward(ctx, hiddenState)
  22. key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
  23. key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
  24. value := sa.Value.Forward(ctx, hiddenState)
  25. value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
  26. cache.Put(ctx, key, value)
  27. key, value, mask := cache.Get(ctx)
  28. query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  29. key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  30. value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
  31. scores := key.MulmatFullPrec(ctx, query)
  32. scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
  33. scores = scores.Add(ctx, mask)
  34. scores = scores.Softmax(ctx)
  35. attention := value.Mulmat(ctx, scores)
  36. attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  37. attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
  38. return sa.Output.Forward(ctx, attention)
  39. }
  40. func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
  41. // This will only get called for layers in the cache, which are just the self attention layers
  42. return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
  43. }
  44. type TextMLP struct {
  45. Up *nn.Linear `gguf:"ffn_up"`
  46. Down *nn.Linear `gguf:"ffn_down"`
  47. Gate *nn.Linear `gguf:"ffn_gate"`
  48. }
  49. func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
  50. hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
  51. return mlp.Down.Forward(ctx, hiddenState)
  52. }
  53. type TextSelfAttentionDecoderLayer struct {
  54. AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
  55. SelfAttention *TextSelfAttention
  56. MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
  57. MLP *TextMLP
  58. }
  59. func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  60. residual := hiddenState
  61. hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
  62. hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)
  63. hiddenState = hiddenState.Add(ctx, residual)
  64. residual = hiddenState
  65. hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
  66. hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
  67. return hiddenState.Add(ctx, residual)
  68. }
  69. type TextCrossAttention struct {
  70. QueryNorm *nn.RMSNorm `gguf:"cross_attn_q_norm"`
  71. Query *nn.Linear `gguf:"cross_attn_q_proj"`
  72. KeyNorm *nn.RMSNorm `gguf:"cross_attn_k_norm"`
  73. Key *nn.Linear `gguf:"cross_attn_k_proj"`
  74. Value *nn.Linear `gguf:"cross_attn_v_proj"`
  75. Output *nn.Linear `gguf:"cross_attn_o_proj"`
  76. }
  77. func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  78. batchSize := hiddenState.Dim(1)
  79. headDim := opts.hiddenSize / opts.numHeads
  80. query := ca.Query.Forward(ctx, hiddenState)
  81. query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
  82. query = ca.QueryNorm.Forward(ctx, query, opts.eps)
  83. var key, value ml.Tensor
  84. if crossAttentionStates != nil {
  85. numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
  86. key = ca.Key.Forward(ctx, crossAttentionStates)
  87. key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
  88. key = ca.KeyNorm.Forward(ctx, key, opts.eps)
  89. value = ca.Value.Forward(ctx, crossAttentionStates)
  90. value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
  91. cache.Put(ctx, key, value)
  92. } else {
  93. key, value, _ = cache.Get(ctx)
  94. }
  95. query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  96. key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  97. value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
  98. scores := key.Mulmat(ctx, query)
  99. scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
  100. scores = scores.Softmax(ctx)
  101. attention := value.Mulmat(ctx, scores)
  102. attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  103. attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
  104. return ca.Output.Forward(ctx, attention)
  105. }
  106. type TextCrossAttentionDecoderLayer struct {
  107. AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
  108. CrossAttention *TextCrossAttention
  109. AttentionGate ml.Tensor `gguf:"cross_attn_attn_gate"`
  110. MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
  111. MLP *TextMLP
  112. MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
  113. }
  114. func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  115. residual := hiddenState
  116. hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
  117. hiddenState = d.CrossAttention.Forward(ctx, hiddenState, crossAttentionStates, cache, opts)
  118. hiddenState = hiddenState.Mul(ctx, d.AttentionGate.Tanh(ctx))
  119. hiddenState = hiddenState.Add(ctx, residual)
  120. residual = hiddenState
  121. hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
  122. hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
  123. hiddenState = hiddenState.Mul(ctx, d.MLPGate.Tanh(ctx))
  124. return hiddenState.Add(ctx, residual)
  125. }
  126. type TextDecoderLayer interface {
  127. Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
  128. }
  129. type TextDecoder struct {
  130. Layers []TextDecoderLayer
  131. }
  132. func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  133. for i, layer := range d.Layers {
  134. layerType := selfAttentionLayer
  135. if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
  136. layerType = crossAttentionLayer
  137. }
  138. cache.SetLayer(i)
  139. cache.SetLayerType(layerType)
  140. if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() {
  141. hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
  142. }
  143. }
  144. return hiddenState
  145. }
  146. type TextModelOptions struct {
  147. RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
  148. hiddenSize, numHeads, numKVHeads int
  149. eps, ropeBase, ropeScale float32
  150. ropeDim uint32
  151. crossAttentionLayers []uint32
  152. }
  153. type TextModel struct {
  154. TokenEmbedding *nn.Embedding `gguf:"token_embd"`
  155. Transformer *TextDecoder `gguf:"blk"`
  156. OutputNorm *nn.RMSNorm `gguf:"output_norm"`
  157. Output *nn.Linear `gguf:"output"`
  158. *TextModelOptions
  159. }
  160. func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
  161. hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
  162. hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
  163. hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
  164. return m.Output.Forward(ctx, hiddenState)
  165. }
  166. func newTextModel(c ml.Config) *TextModel {
  167. var decoderLayers []TextDecoderLayer
  168. for i := range c.Uint("block_count") {
  169. var textDecoderLayer TextDecoderLayer
  170. if slices.Contains(c.Uints("attention.cross_attention_layers"), i) {
  171. textDecoderLayer = &TextCrossAttentionDecoderLayer{}
  172. } else {
  173. textDecoderLayer = &TextSelfAttentionDecoderLayer{}
  174. }
  175. decoderLayers = append(decoderLayers, textDecoderLayer)
  176. }
  177. return &TextModel{
  178. Transformer: &TextDecoder{Layers: decoderLayers},
  179. TextModelOptions: &TextModelOptions{
  180. hiddenSize: int(c.Uint("embedding_length")),
  181. numHeads: int(c.Uint("attention.head_count")),
  182. numKVHeads: int(c.Uint("attention.head_count_kv")),
  183. eps: c.Float("attention.layer_norm_rms_epsilon"),
  184. ropeBase: c.Float("rope.freq_base"),
  185. ropeScale: c.Float("rope.freq_scale", 1),
  186. ropeDim: c.Uint("rope.dimension_count"),
  187. crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
  188. },
  189. }
  190. }