model_text.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. package mllama
  2. import (
  3. "math"
  4. "slices"
  5. "github.com/ollama/ollama/kvcache"
  6. "github.com/ollama/ollama/ml"
  7. "github.com/ollama/ollama/ml/nn"
  8. )
  9. type TextSelfAttention struct {
  10. Query *nn.Linear `gguf:"attn_q"`
  11. Key *nn.Linear `gguf:"attn_k"`
  12. Value *nn.Linear `gguf:"attn_v"`
  13. Output *nn.Linear `gguf:"attn_output"`
  14. }
  15. func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  16. batchSize := hiddenState.Dim(1)
  17. headDim := opts.hiddenSize / opts.numHeads
  18. query := sa.Query.Forward(ctx, hiddenState)
  19. query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
  20. query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
  21. key := sa.Key.Forward(ctx, hiddenState)
  22. key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
  23. key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
  24. value := sa.Value.Forward(ctx, hiddenState)
  25. value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
  26. cache.Put(ctx, key, value)
  27. key, value, mask := cache.Get(ctx)
  28. query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  29. key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  30. value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
  31. scaleFactor := 1.0 / math.Sqrt(float64(headDim))
  32. attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
  33. attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
  34. return sa.Output.Forward(ctx, attention)
  35. }
  36. func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
  37. // This will only get called for layers in the cache, which are just the self attention layers
  38. return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
  39. }
  40. type TextMLP struct {
  41. Up *nn.Linear `gguf:"ffn_up"`
  42. Down *nn.Linear `gguf:"ffn_down"`
  43. Gate *nn.Linear `gguf:"ffn_gate"`
  44. }
  45. func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
  46. hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
  47. return mlp.Down.Forward(ctx, hiddenState)
  48. }
  49. type TextSelfAttentionDecoderLayer struct {
  50. AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
  51. SelfAttention *TextSelfAttention
  52. MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
  53. MLP *TextMLP
  54. }
  55. func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, outputs, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  56. residual := hiddenState
  57. hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
  58. hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)
  59. // In the final layer (outputs != nil), optimize by pruning to just the token positions
  60. // we need logits for.
  61. if outputs != nil {
  62. hiddenState = hiddenState.Rows(ctx, outputs)
  63. residual = residual.Rows(ctx, outputs)
  64. }
  65. hiddenState = hiddenState.Add(ctx, residual)
  66. residual = hiddenState
  67. hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
  68. hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
  69. return hiddenState.Add(ctx, residual)
  70. }
  71. type TextCrossAttention struct {
  72. QueryNorm *nn.RMSNorm `gguf:"cross_attn_q_norm"`
  73. Query *nn.Linear `gguf:"cross_attn_q_proj"`
  74. KeyNorm *nn.RMSNorm `gguf:"cross_attn_k_norm"`
  75. Key *nn.Linear `gguf:"cross_attn_k_proj"`
  76. Value *nn.Linear `gguf:"cross_attn_v_proj"`
  77. Output *nn.Linear `gguf:"cross_attn_o_proj"`
  78. }
  79. func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  80. batchSize := hiddenState.Dim(1)
  81. headDim := opts.hiddenSize / opts.numHeads
  82. query := ca.Query.Forward(ctx, hiddenState)
  83. query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
  84. query = ca.QueryNorm.Forward(ctx, query, opts.eps)
  85. var key, value, mask ml.Tensor
  86. if crossAttentionStates != nil {
  87. numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
  88. key = ca.Key.Forward(ctx, crossAttentionStates)
  89. key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
  90. key = ca.KeyNorm.Forward(ctx, key, opts.eps)
  91. value = ca.Value.Forward(ctx, crossAttentionStates)
  92. value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
  93. cache.Put(ctx, key, value)
  94. } else {
  95. key, value, mask = cache.Get(ctx)
  96. }
  97. query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  98. key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  99. value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
  100. scaleFactor := 1.0 / math.Sqrt(float64(headDim))
  101. attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
  102. attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
  103. return ca.Output.Forward(ctx, attention)
  104. }
  105. type TextCrossAttentionDecoderLayer struct {
  106. AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
  107. CrossAttention *TextCrossAttention
  108. AttentionGate ml.Tensor `gguf:"cross_attn_attn_gate"`
  109. MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
  110. MLP *TextMLP
  111. MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
  112. }
  113. func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  114. residual := hiddenState
  115. hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
  116. hiddenState = d.CrossAttention.Forward(ctx, hiddenState, crossAttentionStates, cache, opts)
  117. hiddenState = hiddenState.Mul(ctx, d.AttentionGate.Tanh(ctx))
  118. hiddenState = hiddenState.Add(ctx, residual)
  119. residual = hiddenState
  120. hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
  121. hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
  122. hiddenState = hiddenState.Mul(ctx, d.MLPGate.Tanh(ctx))
  123. return hiddenState.Add(ctx, residual)
  124. }
  125. type TextDecoderLayer interface {
  126. Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
  127. }
  128. type TextDecoder struct {
  129. Layers []TextDecoderLayer
  130. }
  131. func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  132. for i, layer := range d.Layers {
  133. layerType := selfAttentionLayer
  134. if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
  135. layerType = crossAttentionLayer
  136. }
  137. cache.SetLayer(i)
  138. cache.SetLayerType(layerType)
  139. if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() {
  140. var lastLayerOutputs ml.Tensor
  141. if i == len(d.Layers)-1 {
  142. lastLayerOutputs = outputs
  143. }
  144. hiddenState = layer.Forward(ctx, hiddenState, positionIDs, lastLayerOutputs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
  145. }
  146. }
  147. return hiddenState
  148. }
  149. type TextModelOptions struct {
  150. RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
  151. hiddenSize, numHeads, numKVHeads int
  152. eps, ropeBase, ropeScale float32
  153. ropeDim uint32
  154. crossAttentionLayers []uint32
  155. }
  156. type TextModel struct {
  157. TokenEmbedding *nn.Embedding `gguf:"token_embd"`
  158. Transformer *TextDecoder `gguf:"blk"`
  159. OutputNorm *nn.RMSNorm `gguf:"output_norm"`
  160. Output *nn.Linear `gguf:"output"`
  161. *TextModelOptions
  162. }
  163. func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
  164. hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
  165. hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
  166. hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
  167. return m.Output.Forward(ctx, hiddenState)
  168. }
  169. func newTextModel(c ml.Config) *TextModel {
  170. var decoderLayers []TextDecoderLayer
  171. for i := range c.Uint("block_count") {
  172. var textDecoderLayer TextDecoderLayer
  173. if slices.Contains(c.Uints("attention.cross_attention_layers"), i) {
  174. textDecoderLayer = &TextCrossAttentionDecoderLayer{}
  175. } else {
  176. textDecoderLayer = &TextSelfAttentionDecoderLayer{}
  177. }
  178. decoderLayers = append(decoderLayers, textDecoderLayer)
  179. }
  180. return &TextModel{
  181. Transformer: &TextDecoder{Layers: decoderLayers},
  182. TextModelOptions: &TextModelOptions{
  183. hiddenSize: int(c.Uint("embedding_length")),
  184. numHeads: int(c.Uint("attention.head_count")),
  185. numKVHeads: int(c.Uint("attention.head_count_kv")),
  186. eps: c.Float("attention.layer_norm_rms_epsilon"),
  187. ropeBase: c.Float("rope.freq_base"),
  188. ropeScale: c.Float("rope.freq_scale", 1),
  189. ropeDim: c.Uint("rope.dimension_count"),
  190. crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
  191. },
  192. }
  193. }