model_text.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. package mllama
  2. import (
  3. "math"
  4. "slices"
  5. "github.com/ollama/ollama/ml"
  6. "github.com/ollama/ollama/ml/nn"
  7. "github.com/ollama/ollama/model"
  8. )
  9. type TextSelfAttention struct {
  10. Query *nn.Linear `gguf:"attn_q"`
  11. Key *nn.Linear `gguf:"attn_k"`
  12. Value *nn.Linear `gguf:"attn_v"`
  13. Output *nn.Linear `gguf:"attn_output"`
  14. }
  15. func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
  16. batchSize := hiddenState.Dim(1)
  17. headDim := opts.hiddenSize / opts.numHeads
  18. query := sa.Query.Forward(ctx, hiddenState)
  19. query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
  20. query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
  21. key := sa.Key.Forward(ctx, hiddenState)
  22. key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
  23. key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
  24. value := sa.Value.Forward(ctx, hiddenState)
  25. value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
  26. key, value = cache.Put(ctx, key, value, cache.Options)
  27. query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  28. key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  29. value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
  30. scores := key.Mulmat(ctx, query)
  31. scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
  32. if mask != nil {
  33. scores = scores.Add(ctx, mask)
  34. }
  35. scores = scores.Softmax(ctx)
  36. attention := value.Mulmat(ctx, scores)
  37. attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  38. attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
  39. return sa.Output.Forward(ctx, attention)
  40. }
  41. type TextMLP struct {
  42. Up *nn.Linear `gguf:"ffn_up"`
  43. Down *nn.Linear `gguf:"ffn_down"`
  44. Gate *nn.Linear `gguf:"ffn_gate"`
  45. }
  46. func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
  47. hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
  48. return mlp.Down.Forward(ctx, hiddenState)
  49. }
  50. type TextSelfAttentionDecoderLayer struct {
  51. AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
  52. SelfAttention *TextSelfAttention
  53. MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
  54. MLP *TextMLP
  55. }
  56. func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
  57. residual := hiddenState
  58. hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
  59. hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)
  60. hiddenState = hiddenState.Add(ctx, residual)
  61. residual = hiddenState
  62. hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
  63. hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
  64. return hiddenState.Add(ctx, residual)
  65. }
  66. type TextCrossAttention struct {
  67. QueryNorm *nn.RMSNorm `gguf:"cross_attn_q_norm"`
  68. Query *nn.Linear `gguf:"cross_attn_q_proj"`
  69. KeyNorm *nn.RMSNorm `gguf:"cross_attn_k_norm"`
  70. Key *nn.Linear `gguf:"cross_attn_k_proj"`
  71. Value *nn.Linear `gguf:"cross_attn_v_proj"`
  72. Output *nn.Linear `gguf:"cross_attn_o_proj"`
  73. }
  74. func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
  75. batchSize := hiddenState.Dim(1)
  76. headDim := opts.hiddenSize / opts.numHeads
  77. numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
  78. query := ca.Query.Forward(ctx, hiddenState)
  79. query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
  80. query = ca.QueryNorm.Forward(ctx, query, opts.eps)
  81. key := ca.Key.Forward(ctx, crossAttentionStates)
  82. key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
  83. key = ca.KeyNorm.Forward(ctx, key, opts.eps)
  84. value := ca.Value.Forward(ctx, crossAttentionStates)
  85. value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
  86. // TODO cache key, value
  87. query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  88. key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  89. value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
  90. scores := key.Mulmat(ctx, query)
  91. scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
  92. scores = scores.Softmax(ctx)
  93. attention := value.Mulmat(ctx, scores)
  94. attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  95. attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
  96. return ca.Output.Forward(ctx, attention)
  97. }
  98. type TextCrossAttentionDecoderLayer struct {
  99. AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
  100. CrossAttention *TextCrossAttention
  101. AttentionGate ml.Tensor `gguf:"cross_attn_attn_gate"`
  102. MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
  103. MLP *TextMLP
  104. MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
  105. }
  106. func (d TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
  107. residual := hiddenState
  108. hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
  109. hiddenState = d.CrossAttention.Forward(ctx, hiddenState, crossAttentionStates, cache, opts)
  110. hiddenState = hiddenState.Mul(ctx, d.AttentionGate.Tanh(ctx))
  111. hiddenState = hiddenState.Add(ctx, residual)
  112. residual = hiddenState
  113. hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
  114. hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
  115. hiddenState = hiddenState.Mul(ctx, d.MLPGate.Tanh(ctx))
  116. return hiddenState.Add(ctx, residual)
  117. }
  118. type TextDecoderLayer interface {
  119. Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor
  120. }
  121. type TextDecoder struct {
  122. Layers []TextDecoderLayer
  123. }
  124. func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
  125. for i, layer := range d.Layers {
  126. if !slices.Contains(opts.crossAttentionLayers, uint32(i)) || crossAttentionStates != nil {
  127. hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache.Sub(i), opts)
  128. }
  129. }
  130. return hiddenState
  131. }
  132. type TextModelOptions struct {
  133. RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
  134. hiddenSize, numHeads, numKVHeads int64
  135. eps, ropeBase, ropeScale float32
  136. ropeDim uint32
  137. crossAttentionLayers []uint32
  138. }
  139. type TextModel struct {
  140. TokenEmbedding *nn.Embedding `gguf:"token_embd"`
  141. Transformer *TextDecoder `gguf:"blk"`
  142. OutputNorm *nn.RMSNorm `gguf:"output_norm"`
  143. Output *nn.Linear `gguf:"output"`
  144. *TextModelOptions
  145. }
  146. func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache) ml.Tensor {
  147. hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
  148. hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
  149. hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
  150. return m.Output.Forward(ctx, hiddenState)
  151. }
  152. func newTextModel(c ml.Config) *TextModel {
  153. var decoderLayers []TextDecoderLayer
  154. for i := range c.Uint("block_count") {
  155. var textDecoderLayer TextDecoderLayer
  156. if slices.Contains(c.Uints("attention.cross_attention_layers"), i) {
  157. textDecoderLayer = &TextCrossAttentionDecoderLayer{}
  158. } else {
  159. textDecoderLayer = &TextSelfAttentionDecoderLayer{}
  160. }
  161. decoderLayers = append(decoderLayers, textDecoderLayer)
  162. }
  163. return &TextModel{
  164. Transformer: &TextDecoder{Layers: decoderLayers},
  165. TextModelOptions: &TextModelOptions{
  166. hiddenSize: int64(c.Uint("embedding_length")),
  167. numHeads: int64(c.Uint("attention.head_count")),
  168. numKVHeads: int64(c.Uint("attention.head_count_kv")),
  169. eps: c.Float("attention.layer_norm_rms_epsilon"),
  170. ropeBase: c.Float("rope.freq_base"),
  171. ropeScale: c.Float("rope.freq_scale", 1),
  172. ropeDim: c.Uint("rope.dimension_count"),
  173. crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
  174. },
  175. }
  176. }