model_text.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. package mllama
  2. import (
  3. "math"
  4. "slices"
  5. "github.com/ollama/ollama/kvcache"
  6. "github.com/ollama/ollama/ml"
  7. "github.com/ollama/ollama/ml/nn"
  8. )
// TextSelfAttention holds the weights of one text self-attention block: the
// query/key/value/output projections and the RoPE frequency factors that
// Forward (and TextModel.Shift) pass to RoPE for queries and keys.
//
// The gguf tags name the tensors these fields are presumably loaded from
// (TODO confirm against the model loader).
type TextSelfAttention struct {
	Query       *nn.Linear `gguf:"attn_q"`
	Key         *nn.Linear `gguf:"attn_k"`
	Value       *nn.Linear `gguf:"attn_v"`
	Output      *nn.Linear `gguf:"attn_output"`
	RopeFactors ml.Tensor  `gguf:"rope_freqs.weight"`
}
  16. func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  17. batchSize := hiddenState.Dim(1)
  18. headDim := opts.hiddenSize / opts.numHeads
  19. ropeType := uint32(0)
  20. query := sa.Query.Forward(ctx, hiddenState)
  21. query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
  22. query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
  23. key := sa.Key.Forward(ctx, hiddenState)
  24. key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
  25. key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
  26. value := sa.Value.Forward(ctx, hiddenState)
  27. value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
  28. scaleFactor := 1.0 / math.Sqrt(float64(headDim))
  29. attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
  30. attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
  31. return sa.Output.Forward(ctx, attention)
  32. }
  33. func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
  34. // This will only get called for layers in the cache, which are just the self attention layers
  35. if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
  36. return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
  37. }
  38. return key, nil
  39. }
// TextMLP holds the weights of the gated feed-forward block used by both
// decoder layer kinds: TextMLP.Forward computes Down(SiLU(Gate(x)) * Up(x)).
type TextMLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}
  45. func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
  46. hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
  47. return mlp.Down.Forward(ctx, hiddenState)
  48. }
// TextSelfAttentionDecoderLayer is a decoder layer that RMS-normalizes its
// input before the self-attention sub-block and again before the MLP
// sub-block, adding a residual connection around each.
type TextSelfAttentionDecoderLayer struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention *TextSelfAttention

	MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
	MLP     *TextMLP
}
  55. func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, outputs, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  56. residual := hiddenState
  57. hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
  58. hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)
  59. // In the final layer (outputs != nil), optimize by pruning to just the token positions
  60. // we need logits for.
  61. if outputs != nil {
  62. hiddenState = hiddenState.Rows(ctx, outputs)
  63. residual = residual.Rows(ctx, outputs)
  64. }
  65. hiddenState = hiddenState.Add(ctx, residual)
  66. residual = hiddenState
  67. hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
  68. hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
  69. return hiddenState.Add(ctx, residual)
  70. }
// TextCrossAttention holds the weights for attention from text queries to
// vision keys/values. Queries and keys are RMS-normalized after being split
// into heads (see TextCrossAttention.Forward).
type TextCrossAttention struct {
	QueryNorm *nn.RMSNorm `gguf:"cross_attn_q_norm"`
	Query     *nn.Linear  `gguf:"cross_attn_q_proj"`
	KeyNorm   *nn.RMSNorm `gguf:"cross_attn_k_norm"`
	Key       *nn.Linear  `gguf:"cross_attn_k_proj"`
	Value     *nn.Linear  `gguf:"cross_attn_v_proj"`
	Output    *nn.Linear  `gguf:"cross_attn_o_proj"`
}
  79. func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  80. batchSize := hiddenState.Dim(1)
  81. headDim := opts.hiddenSize / opts.numHeads
  82. query := ca.Query.Forward(ctx, hiddenState)
  83. query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
  84. query = ca.QueryNorm.Forward(ctx, query, opts.eps)
  85. var key, value ml.Tensor
  86. if crossAttentionStates != nil {
  87. numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
  88. key = ca.Key.Forward(ctx, crossAttentionStates)
  89. key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
  90. key = ca.KeyNorm.Forward(ctx, key, opts.eps)
  91. value = ca.Value.Forward(ctx, crossAttentionStates)
  92. value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
  93. cache.Put(ctx, key, value)
  94. }
  95. key, value, _ = cache.Get(ctx)
  96. scaleFactor := 1.0 / math.Sqrt(float64(headDim))
  97. query = query.Permute(ctx, 0, 2, 1, 3)
  98. key = key.Permute(ctx, 0, 2, 1, 3)
  99. value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
  100. kq := key.MulmatFullPrec(ctx, query)
  101. kq = kq.Scale(ctx, scaleFactor)
  102. kq = kq.Softmax(ctx)
  103. kqv := value.Mulmat(ctx, kq)
  104. attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  105. attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
  106. return ca.Output.Forward(ctx, attention)
  107. }
// TextCrossAttentionDecoderLayer is a decoder layer whose attention sub-block
// attends to vision features. Both its attention and MLP sub-block outputs are
// multiplied by tanh of a learned gate tensor (AttentionGate, MLPGate) before
// the residual add.
type TextCrossAttentionDecoderLayer struct {
	AttentionNorm  *nn.RMSNorm `gguf:"attn_norm"`
	CrossAttention *TextCrossAttention
	AttentionGate  ml.Tensor `gguf:"cross_attn_attn_gate"`

	MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
	MLP     *TextMLP
	MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
}
  116. func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  117. residual := hiddenState
  118. hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
  119. hiddenState = d.CrossAttention.Forward(ctx, hiddenState, crossAttentionStates, cache, opts)
  120. hiddenState = hiddenState.Mul(ctx, d.AttentionGate.Tanh(ctx))
  121. hiddenState = hiddenState.Add(ctx, residual)
  122. residual = hiddenState
  123. hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
  124. hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
  125. hiddenState = hiddenState.Mul(ctx, d.MLPGate.Tanh(ctx))
  126. return hiddenState.Add(ctx, residual)
  127. }
// TextDecoderLayer is the common signature shared by the self-attention and
// cross-attention decoder layers, so TextDecoder can run a mixed stack.
// Each implementation ignores the arguments it does not need.
//
// For example, the self-attention layer ignores the cross-attention inputs, and
// the cross-attention layer ignores positionIDs, outputs and mask.
type TextDecoderLayer interface {
	Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
}
// TextDecoder is the ordered stack of decoder layers. A layer's index decides
// whether it is treated as a cross-attention layer (see
// TextModelOptions.crossAttentionLayers).
type TextDecoder struct {
	Layers []TextDecoderLayer
}
  134. func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  135. for i, layer := range d.Layers {
  136. layerType := selfAttentionLayer
  137. if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
  138. layerType = crossAttentionLayer
  139. }
  140. cache.SetLayer(i)
  141. cache.SetLayerType(layerType)
  142. if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() {
  143. var lastLayerOutputs ml.Tensor
  144. if i == len(d.Layers)-1 {
  145. lastLayerOutputs = outputs
  146. }
  147. hiddenState = layer.Forward(ctx, hiddenState, positionIDs, lastLayerOutputs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
  148. }
  149. }
  150. return hiddenState
  151. }
// TextModelOptions holds the text decoder hyperparameters that newTextModel
// reads from the model config.
//
// Fields:
//   - hiddenSize is "embedding_length".
//   - numHeads and numKVHeads are the query and key/value head counts. The
//     per-head size is hiddenSize/numHeads.
//   - eps is the RMSNorm epsilon.
//   - ropeBase, ropeScale and ropeDim are passed to RoPE.
//   - crossAttentionLayers lists the indices of the layers built as
//     cross-attention layers.
type TextModelOptions struct {
	hiddenSize, numHeads, numKVHeads int
	eps, ropeBase, ropeScale         float32
	ropeDim                          uint32
	crossAttentionLayers             []uint32
}
// TextModel is the text decoder of the mllama model. It consists of a token
// embedding, the decoder layer stack, a final RMSNorm and an output
// projection.
//
// The embedded *TextModelOptions exposes the hyperparameters directly on the
// model (e.g. m.eps, m.ropeDim).
type TextModel struct {
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Transformer    *TextDecoder  `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output"`

	*TextModelOptions
}
  165. func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
  166. hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
  167. hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
  168. hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
  169. return m.Output.Forward(ctx, hiddenState)
  170. }
  171. func newTextModel(c ml.Config) *TextModel {
  172. var decoderLayers []TextDecoderLayer
  173. for i := range c.Uint("block_count") {
  174. var textDecoderLayer TextDecoderLayer
  175. if slices.Contains(c.Uints("attention.cross_attention_layers"), i) {
  176. textDecoderLayer = &TextCrossAttentionDecoderLayer{}
  177. } else {
  178. textDecoderLayer = &TextSelfAttentionDecoderLayer{}
  179. }
  180. decoderLayers = append(decoderLayers, textDecoderLayer)
  181. }
  182. return &TextModel{
  183. Transformer: &TextDecoder{Layers: decoderLayers},
  184. TextModelOptions: &TextModelOptions{
  185. hiddenSize: int(c.Uint("embedding_length")),
  186. numHeads: int(c.Uint("attention.head_count")),
  187. numKVHeads: int(c.Uint("attention.head_count_kv")),
  188. eps: c.Float("attention.layer_norm_rms_epsilon"),
  189. ropeBase: c.Float("rope.freq_base"),
  190. ropeScale: c.Float("rope.freq_scale", 1),
  191. ropeDim: c.Uint("rope.dimension_count"),
  192. crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
  193. },
  194. }
  195. }