model_text.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. package mllama
  2. import (
  3. "math"
  4. "slices"
  5. "github.com/ollama/ollama/cache"
  6. "github.com/ollama/ollama/ml"
  7. "github.com/ollama/ollama/ml/nn"
  8. )
  9. type TextSelfAttention struct {
  10. Query *nn.Linear `gguf:"attn_q"`
  11. Key *nn.Linear `gguf:"attn_k"`
  12. Value *nn.Linear `gguf:"attn_v"`
  13. Output *nn.Linear `gguf:"attn_output"`
  14. }
  15. func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache cache.Cache, opts *TextModelOptions) ml.Tensor {
  16. batchSize := hiddenState.Dim(1)
  17. headDim := opts.hiddenSize / opts.numHeads
  18. query := sa.Query.Forward(ctx, hiddenState)
  19. query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
  20. query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
  21. key := sa.Key.Forward(ctx, hiddenState)
  22. key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
  23. key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
  24. value := sa.Value.Forward(ctx, hiddenState)
  25. value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
  26. key, value, mask := cache.Put(ctx, key, value)
  27. query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  28. key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  29. value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
  30. scores := key.Mulmat(ctx, query)
  31. scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
  32. scores = scores.Add(ctx, mask)
  33. scores = scores.Softmax(ctx)
  34. attention := value.Mulmat(ctx, scores)
  35. attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  36. attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
  37. return sa.Output.Forward(ctx, attention)
  38. }
  39. type TextMLP struct {
  40. Up *nn.Linear `gguf:"ffn_up"`
  41. Down *nn.Linear `gguf:"ffn_down"`
  42. Gate *nn.Linear `gguf:"ffn_gate"`
  43. }
  44. func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
  45. hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
  46. return mlp.Down.Forward(ctx, hiddenState)
  47. }
  48. type TextSelfAttentionDecoderLayer struct {
  49. AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
  50. SelfAttention *TextSelfAttention
  51. MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
  52. MLP *TextMLP
  53. }
  54. func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache cache.Cache, _ *cache.TensorCache, opts *TextModelOptions) ml.Tensor {
  55. residual := hiddenState
  56. hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
  57. hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)
  58. hiddenState = hiddenState.Add(ctx, residual)
  59. residual = hiddenState
  60. hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
  61. hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
  62. return hiddenState.Add(ctx, residual)
  63. }
  64. func (d *TextSelfAttentionDecoderLayer) Run() bool {
  65. return true
  66. }
  67. type TextCrossAttention struct {
  68. QueryNorm *nn.RMSNorm `gguf:"cross_attn_q_norm"`
  69. Query *nn.Linear `gguf:"cross_attn_q_proj"`
  70. KeyNorm *nn.RMSNorm `gguf:"cross_attn_k_norm"`
  71. Key *nn.Linear `gguf:"cross_attn_k_proj"`
  72. Value *nn.Linear `gguf:"cross_attn_v_proj"`
  73. Output *nn.Linear `gguf:"cross_attn_o_proj"`
  74. }
  75. func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, _ cache.Cache, tCache *cache.TensorCache, opts *TextModelOptions) ml.Tensor {
  76. batchSize := hiddenState.Dim(1)
  77. headDim := opts.hiddenSize / opts.numHeads
  78. query := ca.Query.Forward(ctx, hiddenState)
  79. query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
  80. query = ca.QueryNorm.Forward(ctx, query, opts.eps)
  81. var key, value ml.Tensor
  82. if crossAttentionStates != nil {
  83. numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
  84. key = ca.Key.Forward(ctx, crossAttentionStates)
  85. key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
  86. key = ca.KeyNorm.Forward(ctx, key, opts.eps)
  87. value = ca.Value.Forward(ctx, crossAttentionStates)
  88. value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
  89. tCache.Put(ctx, key, value)
  90. } else {
  91. key, value, _ = tCache.Get(ctx)
  92. }
  93. query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  94. key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  95. value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
  96. scores := key.Mulmat(ctx, query)
  97. scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
  98. scores = scores.Softmax(ctx)
  99. attention := value.Mulmat(ctx, scores)
  100. attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  101. attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
  102. return ca.Output.Forward(ctx, attention)
  103. }
  104. type TextCrossAttentionDecoderLayer struct {
  105. AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
  106. CrossAttention *TextCrossAttention
  107. AttentionGate ml.Tensor `gguf:"cross_attn_attn_gate"`
  108. MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
  109. MLP *TextMLP
  110. MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
  111. run bool
  112. }
  113. func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache cache.Cache, tCache *cache.TensorCache, opts *TextModelOptions) ml.Tensor {
  114. d.run = true
  115. residual := hiddenState
  116. hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
  117. hiddenState = d.CrossAttention.Forward(ctx, hiddenState, crossAttentionStates, cache, tCache, opts)
  118. hiddenState = hiddenState.Mul(ctx, d.AttentionGate.Tanh(ctx))
  119. hiddenState = hiddenState.Add(ctx, residual)
  120. residual = hiddenState
  121. hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
  122. hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
  123. hiddenState = hiddenState.Mul(ctx, d.MLPGate.Tanh(ctx))
  124. return hiddenState.Add(ctx, residual)
  125. }
  126. func (d *TextCrossAttentionDecoderLayer) Run() bool {
  127. return d.run
  128. }
  129. type TextDecoderLayer interface {
  130. Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache cache.Cache, tCache *cache.TensorCache, opts *TextModelOptions) ml.Tensor
  131. Run() bool
  132. }
  133. type TextDecoder struct {
  134. Layers []TextDecoderLayer
  135. }
  136. func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache cache.Cache, tCache *cache.TensorCache, opts *TextModelOptions) ml.Tensor {
  137. for i, layer := range d.Layers {
  138. if layer.Run() || crossAttentionStates != nil {
  139. hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache.Sub(i), tCache.Sub(i), opts)
  140. }
  141. }
  142. return hiddenState
  143. }
  144. type TextModelOptions struct {
  145. RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
  146. hiddenSize, numHeads, numKVHeads int64
  147. eps, ropeBase, ropeScale float32
  148. ropeDim uint32
  149. crossAttentionLayers []uint32
  150. }
  151. type TextModel struct {
  152. TokenEmbedding *nn.Embedding `gguf:"token_embd"`
  153. Transformer *TextDecoder `gguf:"blk"`
  154. OutputNorm *nn.RMSNorm `gguf:"output_norm"`
  155. Output *nn.Linear `gguf:"output"`
  156. *TextModelOptions
  157. }
  158. func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache cache.Cache, tCache *cache.TensorCache) ml.Tensor {
  159. hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
  160. hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, tCache, m.TextModelOptions)
  161. hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
  162. return m.Output.Forward(ctx, hiddenState)
  163. }
  164. func newTextModel(c ml.Config) *TextModel {
  165. var decoderLayers []TextDecoderLayer
  166. for i := range c.Uint("block_count") {
  167. var textDecoderLayer TextDecoderLayer
  168. if slices.Contains(c.Uints("attention.cross_attention_layers"), i) {
  169. textDecoderLayer = &TextCrossAttentionDecoderLayer{}
  170. } else {
  171. textDecoderLayer = &TextSelfAttentionDecoderLayer{}
  172. }
  173. decoderLayers = append(decoderLayers, textDecoderLayer)
  174. }
  175. return &TextModel{
  176. Transformer: &TextDecoder{Layers: decoderLayers},
  177. TextModelOptions: &TextModelOptions{
  178. hiddenSize: int64(c.Uint("embedding_length")),
  179. numHeads: int64(c.Uint("attention.head_count")),
  180. numKVHeads: int64(c.Uint("attention.head_count_kv")),
  181. eps: c.Float("attention.layer_norm_rms_epsilon"),
  182. ropeBase: c.Float("rope.freq_base"),
  183. ropeScale: c.Float("rope.freq_scale", 1),
  184. ropeDim: c.Uint("rope.dimension_count"),
  185. crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
  186. },
  187. }
  188. }