model.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. package qwen2
  2. import (
  3. "fmt"
  4. "log/slog"
  5. "math"
  6. "github.com/ollama/ollama/cache"
  7. "github.com/ollama/ollama/ml"
  8. "github.com/ollama/ollama/ml/nn"
  9. "github.com/ollama/ollama/model"
  10. )
  // Options holds the hyperparameters that New reads from the model config.
// Model.Forward hands the same *Options to every layer.
type Options struct {
	// Transformer dimensions.
	hiddenSize int64 // embedding_length
	numHeads   int64 // attention.head_count
	numKVHeads int64 // attention.head_count_kv

	// RMSNorm epsilon and rotary-embedding parameters.
	eps       float32 // attention.layer_norm_rms_epsilon
	ropeBase  float32 // rope.freq_base
	ropeScale float32 // rope.freq_scale
	ropeDim   uint32  // rope.dimension_count; passed to RoPE as its dimension argument
}
  16. type Model struct {
  17. model.Base
  18. model.BytePairEncoding
  19. TokenEmbedding *nn.Embedding `gguf:"token_embd"`
  20. Layers []Layer `gguf:"blk"`
  21. OutputNorm *nn.RMSNorm `gguf:"output_norm"`
  22. Output *nn.Linear `gguf:"output,alt:token_embd"`
  23. *Options
  24. }
  25. func New(c ml.Config) (model.Model, error) {
  26. m := &Model{
  27. BytePairEncoding: model.BytePairEncoding{
  28. Pretokenizer: c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
  29. Vocabulary: &model.Vocabulary{
  30. Values: c.Strings("tokenizer.ggml.tokens"),
  31. Types: c.Uints("tokenizer.ggml.token_type"),
  32. Merges: c.Strings("tokenizer.ggml.merges"),
  33. BOS: c.Uint("tokenizer.ggml.bos_token_id"),
  34. EOS: c.Uint("tokenizer.ggml.eos_token_id"),
  35. },
  36. },
  37. Layers: make([]Layer, c.Uint("block_count")),
  38. Options: &Options{
  39. hiddenSize: int64(c.Uint("embedding_length")),
  40. numHeads: int64(c.Uint("attention.head_count")),
  41. numKVHeads: int64(c.Uint("attention.head_count_kv")),
  42. eps: c.Float("attention.layer_norm_rms_epsilon"),
  43. ropeBase: c.Float("rope.freq_base"),
  44. ropeScale: c.Float("rope.freq_scale", 1),
  45. ropeDim: c.Uint("rope.dimension_count", 64),
  46. },
  47. }
  48. slog.Debug("model configuration",
  49. "arch", "qwen2",
  50. "vocab_size", len(c.Strings("tokenizer.ggml.tokens")),
  51. "n_merges", len(c.Strings("tokenizer.ggml.merges")),
  52. "n_ctx_train", c.Uint("context_length"),
  53. "n_embd", m.hiddenSize,
  54. "n_layer", len(m.Layers),
  55. "n_head", m.numHeads,
  56. "n_head_kv", m.numKVHeads,
  57. "n_rot", m.ropeDim,
  58. "f_norm_rms_eps", m.eps,
  59. "rope_freq_base", m.ropeBase,
  60. "rope_freq_scale", m.ropeScale,
  61. "bos_token_id", c.Uint("tokenizer.ggml.bos_token_id"),
  62. "eos_token_id", c.Uint("tokenizer.ggml.eos_token_id"),
  63. )
  64. return m, nil
  65. }
  66. type SelfAttention struct {
  67. Query *nn.Linear `gguf:"attn_q"`
  68. Key *nn.Linear `gguf:"attn_k"`
  69. Value *nn.Linear `gguf:"attn_v"`
  70. Output *nn.Linear `gguf:"attn_output"`
  71. }
  72. func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, inputPositions ml.Tensor, layerIdx int, cache cache.Cache, opts *Options) ml.Tensor {
  73. batchSize := hiddenState.Dim(1)
  74. headDim := opts.hiddenSize / opts.numHeads
  75. q := sa.Query.Forward(ctx, hiddenState)
  76. ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.q_proj", layerIdx), q)
  77. q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
  78. q = q.RoPE(ctx, inputPositions, nil, opts.ropeDim, opts.ropeBase, opts.ropeScale)
  79. ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.q_proj.rope", layerIdx), q)
  80. k := sa.Key.Forward(ctx, hiddenState)
  81. k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
  82. k = k.RoPE(ctx, inputPositions, nil, opts.ropeDim, opts.ropeBase, opts.ropeScale)
  83. ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.k_proj.rope", layerIdx), k)
  84. v := sa.Value.Forward(ctx, hiddenState)
  85. v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
  86. ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.v_proj", layerIdx), v)
  87. k, v, mask := cache.Put(ctx, k, v)
  88. q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  89. k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  90. v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
  91. kq := k.Mulmat(ctx, q)
  92. kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
  93. kq = kq.Add(ctx, mask)
  94. kq = kq.Softmax(ctx)
  95. kqv := v.Mulmat(ctx, kq)
  96. kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  97. kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
  98. output := sa.Output.Forward(ctx, kqv)
  99. return output
  100. }
  101. type MLP struct {
  102. Up *nn.Linear `gguf:"ffn_up"`
  103. Down *nn.Linear `gguf:"ffn_down"`
  104. Gate *nn.Linear `gguf:"ffn_gate"`
  105. }
  106. func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
  107. hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
  108. return mlp.Down.Forward(ctx, hiddenState)
  109. }
  110. type Layer struct {
  111. AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
  112. SelfAttention *SelfAttention
  113. MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
  114. MLP *MLP
  115. }
  116. func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, layerIdx int, cache cache.Cache, opts *Options) ml.Tensor {
  117. ctx.Trace(fmt.Sprintf("model.layers.%d.input", layerIdx), hiddenState)
  118. residual := hiddenState
  119. hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
  120. ctx.Trace(fmt.Sprintf("model.layers.%d.input_layernorm", layerIdx), hiddenState)
  121. hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, layerIdx, cache, opts)
  122. ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.output", layerIdx), hiddenState)
  123. hiddenState = hiddenState.Add(ctx, residual)
  124. residual = hiddenState
  125. ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.residual", layerIdx), hiddenState)
  126. hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
  127. ctx.Trace(fmt.Sprintf("model.layers.%d.post_attention_layernorm", layerIdx), hiddenState)
  128. hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
  129. ctx.Trace(fmt.Sprintf("model.layers.%d.mlp", layerIdx), hiddenState)
  130. output := hiddenState.Add(ctx, residual)
  131. ctx.Trace(fmt.Sprintf("model.layers.%d.output", layerIdx), output)
  132. return output
  133. }
  134. func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
  135. slog.Debug("input tokens", "input_ids", opts.Inputs())
  136. inputs, err := ctx.FromIntSlice(opts.Inputs(), len(opts.Inputs()))
  137. if err != nil {
  138. return nil, err
  139. }
  140. positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions()))
  141. if err != nil {
  142. return nil, err
  143. }
  144. hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
  145. ctx.Trace("model.embed_tokens", hiddenState)
  146. for i, layer := range m.Layers {
  147. hiddenState = layer.Forward(ctx, hiddenState, positions, i, opts.Cache.Sub(i), m.Options)
  148. }
  149. hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
  150. ctx.Trace("model.norm", hiddenState)
  151. hiddenState = m.Output.Forward(ctx, hiddenState)
  152. ctx.Trace("model.output", hiddenState)
  153. outputs, err := ctx.FromIntSlice(opts.Outputs(), len(opts.Outputs()))
  154. if err != nil {
  155. return nil, err
  156. }
  157. return hiddenState.Rows(ctx, outputs), nil
  158. }
  159. func init() {
  160. model.Register("qwen2", New)
  161. }