model_text.go 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. package mllama
  2. import (
  3. "math"
  4. "slices"
  5. "github.com/ollama/ollama/kvcache"
  6. "github.com/ollama/ollama/ml"
  7. "github.com/ollama/ollama/ml/nn"
  8. )
  9. type TextSelfAttention struct {
  10. Query *nn.Linear `gguf:"attn_q"`
  11. Key *nn.Linear `gguf:"attn_k"`
  12. Value *nn.Linear `gguf:"attn_v"`
  13. Output *nn.Linear `gguf:"attn_output"`
  14. }
  15. func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  16. batchSize := hiddenState.Dim(1)
  17. headDim := opts.hiddenSize / opts.numHeads
  18. rc := ml.RopeConfig{
  19. PositionIDs: positions,
  20. RopeFactors: opts.RopeFactors,
  21. RopeDim: opts.ropeDim,
  22. RopeType: ml.RopeTypeStandard,
  23. OrigCtxLen: opts.ctxLen,
  24. RopeBase: opts.ropeBase,
  25. RopeScale: opts.ropeScale,
  26. }
  27. query := sa.Query.Forward(ctx, hiddenState)
  28. query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
  29. query = query.RoPE(ctx, rc)
  30. key := sa.Key.Forward(ctx, hiddenState)
  31. key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
  32. key = key.RoPE(ctx, rc)
  33. value := sa.Value.Forward(ctx, hiddenState)
  34. value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
  35. cache.Put(ctx, key, value)
  36. key, value, mask := cache.Get(ctx)
  37. query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  38. key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  39. value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
  40. scores := key.MulmatFullPrec(ctx, query)
  41. scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
  42. scores = scores.Add(ctx, mask)
  43. scores = scores.Softmax(ctx)
  44. attention := value.Mulmat(ctx, scores)
  45. attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  46. attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
  47. return sa.Output.Forward(ctx, attention)
  48. }
  49. func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
  50. // This will only get called for layers in the cache, which are just the self attention layers
  51. return key.RoPE(
  52. ctx,
  53. ml.RopeConfig{
  54. PositionIDs: shift,
  55. RopeFactors: m.RopeFactors,
  56. RopeDim: m.ropeDim,
  57. RopeType: ml.RopeTypeStandard,
  58. OrigCtxLen: m.ctxLen,
  59. RopeBase: m.ropeBase,
  60. RopeScale: m.ropeScale,
  61. },
  62. ), nil
  63. }
  64. type TextMLP struct {
  65. Up *nn.Linear `gguf:"ffn_up"`
  66. Down *nn.Linear `gguf:"ffn_down"`
  67. Gate *nn.Linear `gguf:"ffn_gate"`
  68. }
  69. func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
  70. hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
  71. return mlp.Down.Forward(ctx, hiddenState)
  72. }
  73. type TextSelfAttentionDecoderLayer struct {
  74. AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
  75. SelfAttention *TextSelfAttention
  76. MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
  77. MLP *TextMLP
  78. }
  79. func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  80. residual := hiddenState
  81. hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
  82. hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)
  83. hiddenState = hiddenState.Add(ctx, residual)
  84. residual = hiddenState
  85. hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
  86. hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
  87. return hiddenState.Add(ctx, residual)
  88. }
  89. type TextCrossAttention struct {
  90. QueryNorm *nn.RMSNorm `gguf:"cross_attn_q_norm"`
  91. Query *nn.Linear `gguf:"cross_attn_q_proj"`
  92. KeyNorm *nn.RMSNorm `gguf:"cross_attn_k_norm"`
  93. Key *nn.Linear `gguf:"cross_attn_k_proj"`
  94. Value *nn.Linear `gguf:"cross_attn_v_proj"`
  95. Output *nn.Linear `gguf:"cross_attn_o_proj"`
  96. }
  97. func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  98. batchSize := hiddenState.Dim(1)
  99. headDim := opts.hiddenSize / opts.numHeads
  100. query := ca.Query.Forward(ctx, hiddenState)
  101. query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
  102. query = ca.QueryNorm.Forward(ctx, query, opts.eps)
  103. var key, value ml.Tensor
  104. if crossAttentionStates != nil {
  105. numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
  106. key = ca.Key.Forward(ctx, crossAttentionStates)
  107. key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
  108. key = ca.KeyNorm.Forward(ctx, key, opts.eps)
  109. value = ca.Value.Forward(ctx, crossAttentionStates)
  110. value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
  111. cache.Put(ctx, key, value)
  112. } else {
  113. key, value, _ = cache.Get(ctx)
  114. }
  115. query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  116. key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  117. value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
  118. scores := key.Mulmat(ctx, query)
  119. scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
  120. scores = scores.Softmax(ctx)
  121. attention := value.Mulmat(ctx, scores)
  122. attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  123. attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
  124. return ca.Output.Forward(ctx, attention)
  125. }
  126. type TextCrossAttentionDecoderLayer struct {
  127. AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
  128. CrossAttention *TextCrossAttention
  129. AttentionGate ml.Tensor `gguf:"cross_attn_attn_gate"`
  130. MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
  131. MLP *TextMLP
  132. MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
  133. }
  134. func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  135. residual := hiddenState
  136. hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
  137. hiddenState = d.CrossAttention.Forward(ctx, hiddenState, crossAttentionStates, cache, opts)
  138. hiddenState = hiddenState.Mul(ctx, d.AttentionGate.Tanh(ctx))
  139. hiddenState = hiddenState.Add(ctx, residual)
  140. residual = hiddenState
  141. hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
  142. hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
  143. hiddenState = hiddenState.Mul(ctx, d.MLPGate.Tanh(ctx))
  144. return hiddenState.Add(ctx, residual)
  145. }
  146. type TextDecoderLayer interface {
  147. Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
  148. }
  149. type TextDecoder struct {
  150. Layers []TextDecoderLayer
  151. }
  152. func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
  153. for i, layer := range d.Layers {
  154. layerType := selfAttentionLayer
  155. if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
  156. layerType = crossAttentionLayer
  157. }
  158. cache.SetLayer(i)
  159. cache.SetLayerType(layerType)
  160. if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() {
  161. hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
  162. }
  163. }
  164. return hiddenState
  165. }
  166. type TextModelOptions struct {
  167. RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
  168. ctxLen, hiddenSize, numHeads, numKVHeads int
  169. eps, ropeBase, ropeScale float32
  170. ropeDim uint32
  171. crossAttentionLayers []uint32
  172. }
  173. type TextModel struct {
  174. TokenEmbedding *nn.Embedding `gguf:"token_embd"`
  175. Transformer *TextDecoder `gguf:"blk"`
  176. OutputNorm *nn.RMSNorm `gguf:"output_norm"`
  177. Output *nn.Linear `gguf:"output"`
  178. *TextModelOptions
  179. }
  180. func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
  181. hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
  182. hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
  183. hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
  184. return m.Output.Forward(ctx, hiddenState)
  185. }
  186. func newTextModel(c ml.Config) *TextModel {
  187. var decoderLayers []TextDecoderLayer
  188. for i := range c.Uint("block_count") {
  189. var textDecoderLayer TextDecoderLayer
  190. if slices.Contains(c.Uints("attention.cross_attention_layers"), i) {
  191. textDecoderLayer = &TextCrossAttentionDecoderLayer{}
  192. } else {
  193. textDecoderLayer = &TextSelfAttentionDecoderLayer{}
  194. }
  195. decoderLayers = append(decoderLayers, textDecoderLayer)
  196. }
  197. return &TextModel{
  198. Transformer: &TextDecoder{Layers: decoderLayers},
  199. TextModelOptions: &TextModelOptions{
  200. hiddenSize: int(c.Uint("embedding_length")),
  201. numHeads: int(c.Uint("attention.head_count")),
  202. numKVHeads: int(c.Uint("attention.head_count_kv")),
  203. eps: c.Float("attention.layer_norm_rms_epsilon"),
  204. ropeBase: c.Float("rope.freq_base"),
  205. ropeScale: c.Float("rope.freq_scale", 1),
  206. ropeDim: c.Uint("rope.dimension_count"),
  207. crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
  208. },
  209. }
  210. }