model_text.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. package mllama
  2. import (
  3. "math"
  4. "slices"
  5. "github.com/ollama/ollama/cache"
  6. "github.com/ollama/ollama/ml"
  7. "github.com/ollama/ollama/ml/nn"
  8. )
  9. type TextSelfAttention struct {
  10. Query *nn.Linear `gguf:"attn_q"`
  11. Key *nn.Linear `gguf:"attn_k"`
  12. Value *nn.Linear `gguf:"attn_v"`
  13. Output *nn.Linear `gguf:"attn_output"`
  14. }
  15. func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache cache.Cache, opts *TextModelOptions) ml.Tensor {
  16. batchSize := hiddenState.Dim(1)
  17. headDim := opts.hiddenSize / opts.numHeads
  18. query := sa.Query.Forward(ctx, hiddenState)
  19. query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
  20. query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
  21. key := sa.Key.Forward(ctx, hiddenState)
  22. key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
  23. key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
  24. value := sa.Value.Forward(ctx, hiddenState)
  25. value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
  26. cache.Put(ctx, key, value)
  27. key, value, mask := cache.Get(ctx)
  28. query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  29. key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  30. value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
  31. scores := key.Mulmat(ctx, query)
  32. scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
  33. scores = scores.Add(ctx, mask)
  34. scores = scores.Softmax(ctx)
  35. attention := value.Mulmat(ctx, scores)
  36. attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  37. attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
  38. return sa.Output.Forward(ctx, attention)
  39. }
  40. type TextMLP struct {
  41. Up *nn.Linear `gguf:"ffn_up"`
  42. Down *nn.Linear `gguf:"ffn_down"`
  43. Gate *nn.Linear `gguf:"ffn_gate"`
  44. }
  45. func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
  46. hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
  47. return mlp.Down.Forward(ctx, hiddenState)
  48. }
  49. type TextSelfAttentionDecoderLayer struct {
  50. AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
  51. SelfAttention *TextSelfAttention
  52. MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
  53. MLP *TextMLP
  54. }
  55. func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache cache.Cache, _ *cache.TensorCache, opts *TextModelOptions) ml.Tensor {
  56. residual := hiddenState
  57. hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
  58. hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)
  59. hiddenState = hiddenState.Add(ctx, residual)
  60. residual = hiddenState
  61. hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
  62. hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
  63. return hiddenState.Add(ctx, residual)
  64. }
  65. func (d *TextSelfAttentionDecoderLayer) Run() bool {
  66. return true
  67. }
  68. type TextCrossAttention struct {
  69. QueryNorm *nn.RMSNorm `gguf:"cross_attn_q_norm"`
  70. Query *nn.Linear `gguf:"cross_attn_q_proj"`
  71. KeyNorm *nn.RMSNorm `gguf:"cross_attn_k_norm"`
  72. Key *nn.Linear `gguf:"cross_attn_k_proj"`
  73. Value *nn.Linear `gguf:"cross_attn_v_proj"`
  74. Output *nn.Linear `gguf:"cross_attn_o_proj"`
  75. }
  76. func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, _ cache.Cache, tCache *cache.TensorCache, opts *TextModelOptions) ml.Tensor {
  77. batchSize := hiddenState.Dim(1)
  78. headDim := opts.hiddenSize / opts.numHeads
  79. query := ca.Query.Forward(ctx, hiddenState)
  80. query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
  81. query = ca.QueryNorm.Forward(ctx, query, opts.eps)
  82. var key, value ml.Tensor
  83. if crossAttentionStates != nil {
  84. numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
  85. key = ca.Key.Forward(ctx, crossAttentionStates)
  86. key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
  87. key = ca.KeyNorm.Forward(ctx, key, opts.eps)
  88. value = ca.Value.Forward(ctx, crossAttentionStates)
  89. value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
  90. tCache.Put(ctx, key, value)
  91. } else {
  92. key, value, _ = tCache.Get(ctx)
  93. }
  94. query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  95. key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  96. value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
  97. scores := key.Mulmat(ctx, query)
  98. scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
  99. scores = scores.Softmax(ctx)
  100. attention := value.Mulmat(ctx, scores)
  101. attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  102. attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
  103. return ca.Output.Forward(ctx, attention)
  104. }
  105. type TextCrossAttentionDecoderLayer struct {
  106. AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
  107. CrossAttention *TextCrossAttention
  108. AttentionGate ml.Tensor `gguf:"cross_attn_attn_gate"`
  109. MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
  110. MLP *TextMLP
  111. MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
  112. run bool
  113. }
  114. func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache cache.Cache, tCache *cache.TensorCache, opts *TextModelOptions) ml.Tensor {
  115. d.run = true
  116. residual := hiddenState
  117. hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
  118. hiddenState = d.CrossAttention.Forward(ctx, hiddenState, crossAttentionStates, cache, tCache, opts)
  119. hiddenState = hiddenState.Mul(ctx, d.AttentionGate.Tanh(ctx))
  120. hiddenState = hiddenState.Add(ctx, residual)
  121. residual = hiddenState
  122. hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
  123. hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
  124. hiddenState = hiddenState.Mul(ctx, d.MLPGate.Tanh(ctx))
  125. return hiddenState.Add(ctx, residual)
  126. }
  127. func (d *TextCrossAttentionDecoderLayer) Run() bool {
  128. return d.run
  129. }
  130. type TextDecoderLayer interface {
  131. Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache cache.Cache, tCache *cache.TensorCache, opts *TextModelOptions) ml.Tensor
  132. Run() bool
  133. }
  134. type TextDecoder struct {
  135. Layers []TextDecoderLayer
  136. }
  137. func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache cache.Cache, tCache *cache.TensorCache, opts *TextModelOptions) ml.Tensor {
  138. for i, layer := range d.Layers {
  139. if layer.Run() || crossAttentionStates != nil {
  140. hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache.Sub(i), tCache.Sub(i), opts)
  141. }
  142. }
  143. return hiddenState
  144. }
  145. type TextModelOptions struct {
  146. RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
  147. hiddenSize, numHeads, numKVHeads int64
  148. eps, ropeBase, ropeScale float32
  149. ropeDim uint32
  150. crossAttentionLayers []uint32
  151. }
  152. type TextModel struct {
  153. TokenEmbedding *nn.Embedding `gguf:"token_embd"`
  154. Transformer *TextDecoder `gguf:"blk"`
  155. OutputNorm *nn.RMSNorm `gguf:"output_norm"`
  156. Output *nn.Linear `gguf:"output"`
  157. *TextModelOptions
  158. }
  159. func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache cache.Cache, tCache *cache.TensorCache) ml.Tensor {
  160. hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
  161. hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, tCache, m.TextModelOptions)
  162. hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
  163. return m.Output.Forward(ctx, hiddenState)
  164. }
  165. func newTextModel(c ml.Config) *TextModel {
  166. var decoderLayers []TextDecoderLayer
  167. for i := range c.Uint("block_count") {
  168. var textDecoderLayer TextDecoderLayer
  169. if slices.Contains(c.Uints("attention.cross_attention_layers"), i) {
  170. textDecoderLayer = &TextCrossAttentionDecoderLayer{}
  171. } else {
  172. textDecoderLayer = &TextSelfAttentionDecoderLayer{}
  173. }
  174. decoderLayers = append(decoderLayers, textDecoderLayer)
  175. }
  176. return &TextModel{
  177. Transformer: &TextDecoder{Layers: decoderLayers},
  178. TextModelOptions: &TextModelOptions{
  179. hiddenSize: int64(c.Uint("embedding_length")),
  180. numHeads: int64(c.Uint("attention.head_count")),
  181. numKVHeads: int64(c.Uint("attention.head_count_kv")),
  182. eps: c.Float("attention.layer_norm_rms_epsilon"),
  183. ropeBase: c.Float("rope.freq_base"),
  184. ropeScale: c.Float("rope.freq_scale", 1),
  185. ropeDim: c.Uint("rope.dimension_count"),
  186. crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
  187. },
  188. }
  189. }