model.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. package llama
  2. import (
  3. "math"
  4. "github.com/ollama/ollama/kvcache"
  5. "github.com/ollama/ollama/ml"
  6. "github.com/ollama/ollama/ml/nn"
  7. "github.com/ollama/ollama/model"
  8. )
  9. type Options struct {
  10. RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
  11. hiddenSize, numHeads, numKVHeads int
  12. eps, ropeBase, ropeScale float32
  13. ropeDim uint32
  14. }
  15. type Model struct {
  16. model.Base
  17. model.BytePairEncoding
  18. TokenEmbedding *nn.Embedding `gguf:"token_embd"`
  19. Layers []Layer `gguf:"blk"`
  20. OutputNorm *nn.RMSNorm `gguf:"output_norm"`
  21. Output *nn.Linear `gguf:"output,alt:token_embd"`
  22. *Options
  23. }
  24. func New(c ml.Config) (model.Model, error) {
  25. m := Model{
  26. BytePairEncoding: model.NewBytePairEncoding(
  27. c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
  28. &model.Vocabulary{
  29. Values: c.Strings("tokenizer.ggml.tokens"),
  30. Types: c.Uints("tokenizer.ggml.token_type"),
  31. Merges: c.Strings("tokenizer.ggml.merges"),
  32. BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
  33. EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
  34. },
  35. ),
  36. Layers: make([]Layer, c.Uint("block_count")),
  37. Options: &Options{
  38. hiddenSize: int(c.Uint("embedding_length")),
  39. numHeads: int(c.Uint("attention.head_count")),
  40. numKVHeads: int(c.Uint("attention.head_count_kv")),
  41. eps: c.Float("attention.layer_norm_rms_epsilon"),
  42. ropeBase: c.Float("rope.freq_base"),
  43. ropeScale: c.Float("rope.freq_scale", 1),
  44. ropeDim: c.Uint("rope.dimension_count"),
  45. },
  46. }
  47. m.Cache = kvcache.NewCausalCache(m.Shift)
  48. return &m, nil
  49. }
  50. type SelfAttention struct {
  51. Query *nn.Linear `gguf:"attn_q"`
  52. Key *nn.Linear `gguf:"attn_k"`
  53. Value *nn.Linear `gguf:"attn_v"`
  54. Output *nn.Linear `gguf:"attn_output"`
  55. }
  56. func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
  57. batchSize := hiddenState.Dim(1)
  58. headDim := opts.hiddenSize / opts.numHeads
  59. q := sa.Query.Forward(ctx, hiddenState)
  60. q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
  61. q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
  62. k := sa.Key.Forward(ctx, hiddenState)
  63. k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
  64. k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
  65. v := sa.Value.Forward(ctx, hiddenState)
  66. v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
  67. cache.Put(ctx, k, v)
  68. k, v, mask := cache.Get(ctx)
  69. q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  70. k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  71. v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
  72. kq := k.MulmatFullPrec(ctx, q)
  73. kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
  74. kq = kq.Add(ctx, mask)
  75. kq = kq.Softmax(ctx)
  76. kqv := v.Mulmat(ctx, kq)
  77. kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  78. kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
  79. return sa.Output.Forward(ctx, kqv)
  80. }
  81. func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
  82. return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
  83. }
  84. type MLP struct {
  85. Up *nn.Linear `gguf:"ffn_up"`
  86. Down *nn.Linear `gguf:"ffn_down"`
  87. Gate *nn.Linear `gguf:"ffn_gate"`
  88. }
  89. func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
  90. hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
  91. return mlp.Down.Forward(ctx, hiddenState)
  92. }
  93. type Layer struct {
  94. AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
  95. SelfAttention *SelfAttention
  96. MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
  97. MLP *MLP
  98. }
  99. func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
  100. residual := hiddenState
  101. hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
  102. hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
  103. hiddenState = hiddenState.Add(ctx, residual)
  104. residual = hiddenState
  105. hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
  106. hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
  107. return hiddenState.Add(ctx, residual)
  108. }
  109. func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
  110. inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
  111. if err != nil {
  112. return nil, err
  113. }
  114. positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
  115. if err != nil {
  116. return nil, err
  117. }
  118. hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
  119. for i, layer := range m.Layers {
  120. m.Cache.SetLayer(i)
  121. hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
  122. }
  123. hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
  124. hiddenState = m.Output.Forward(ctx, hiddenState)
  125. outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
  126. if err != nil {
  127. return nil, err
  128. }
  129. return hiddenState.Rows(ctx, outputs), nil
  130. }
  131. func init() {
  132. model.Register("llama", New)
  133. }