model.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. package llama
  2. import (
  3. "math"
  4. "github.com/ollama/ollama/kvcache"
  5. "github.com/ollama/ollama/ml"
  6. "github.com/ollama/ollama/ml/nn"
  7. "github.com/ollama/ollama/model"
  8. )
  9. type Options struct {
  10. RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
  11. hiddenSize, numHeads, numKVHeads int
  12. eps, ropeBase, ropeScale float32
  13. ropeDim uint32
  14. }
// Model bundles the embedded model.Base and model.BytePairEncoding with the
// tensors tagged for gguf loading and the shared *Options.
type Model struct {
	model.Base
	model.BytePairEncoding

	// TokenEmbedding turns the input token tensor into the first hidden
	// state in Forward.
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	// Layers are run in slice order by Forward; New sizes the slice from
	// the "block_count" configuration value.
	Layers []Layer `gguf:"blk"`
	// OutputNorm is applied (with eps) to the hidden state after the last
	// layer.
	OutputNorm *nn.RMSNorm `gguf:"output_norm"`
	// Output produces the tensor that Forward returns.
	// NOTE(review): the "alt:token_embd" tag presumably lets the loader fall
	// back to the embedding weights when no separate output tensor exists —
	// confirm against the gguf loader.
	Output *nn.Linear `gguf:"output,alt:token_embd"`

	*Options
}
  24. func New(c ml.Config) (model.Model, error) {
  25. m := Model{
  26. BytePairEncoding: model.NewBytePairEncoding(
  27. c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
  28. &model.Vocabulary{
  29. Values: c.Strings("tokenizer.ggml.tokens"),
  30. Types: c.Uints("tokenizer.ggml.token_type"),
  31. Merges: c.Strings("tokenizer.ggml.merges"),
  32. BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
  33. AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
  34. EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
  35. AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
  36. },
  37. ),
  38. Layers: make([]Layer, c.Uint("block_count")),
  39. Options: &Options{
  40. hiddenSize: int(c.Uint("embedding_length")),
  41. numHeads: int(c.Uint("attention.head_count")),
  42. numKVHeads: int(c.Uint("attention.head_count_kv")),
  43. eps: c.Float("attention.layer_norm_rms_epsilon"),
  44. ropeBase: c.Float("rope.freq_base"),
  45. ropeScale: c.Float("rope.freq_scale", 1),
  46. ropeDim: c.Uint("rope.dimension_count"),
  47. },
  48. }
  49. m.Cache = kvcache.NewCausalCache(m.Shift)
  50. return &m, nil
  51. }
// SelfAttention groups the four linear projections used by its Forward
// method: query, key and value on the way in, output on the way out.
type SelfAttention struct {
	// Query projects the hidden state before it is reshaped per head.
	Query *nn.Linear `gguf:"attn_q"`
	// Key projects the hidden state before it is reshaped per KV head.
	Key *nn.Linear `gguf:"attn_k"`
	// Value projects the hidden state before it is reshaped per KV head.
	Value *nn.Linear `gguf:"attn_v"`
	// Output is applied to the reshaped attention result.
	Output *nn.Linear `gguf:"attn_output"`
}
  58. func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
  59. batchSize := hiddenState.Dim(1)
  60. headDim := opts.hiddenSize / opts.numHeads
  61. q := sa.Query.Forward(ctx, hiddenState)
  62. q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
  63. q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
  64. k := sa.Key.Forward(ctx, hiddenState)
  65. k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
  66. k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
  67. v := sa.Value.Forward(ctx, hiddenState)
  68. v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
  69. scaleFactor := 1.0 / math.Sqrt(float64(headDim))
  70. kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
  71. kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
  72. return sa.Output.Forward(ctx, kqv)
  73. }
  74. func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
  75. return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
  76. }
// MLP holds the three linear projections of the feed-forward block; see
// its Forward method for how they are combined.
type MLP struct {
	// Up is multiplied element-wise with the activated gate output.
	Up *nn.Linear `gguf:"ffn_up"`
	// Down is applied last and produces the block's result.
	Down *nn.Linear `gguf:"ffn_down"`
	// Gate is passed through SILU before the multiplication with Up.
	Gate *nn.Linear `gguf:"ffn_gate"`
}
  82. func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
  83. hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
  84. return mlp.Down.Forward(ctx, hiddenState)
  85. }
// Layer is one block of the model: a norm followed by self-attention, then
// a second norm followed by the MLP (see Layer.Forward).
type Layer struct {
	// AttentionNorm is applied to the input before self-attention.
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	// SelfAttention has no gguf tag of its own; its fields carry the tags.
	SelfAttention *SelfAttention
	// MLPNorm is applied to the post-attention state before the MLP.
	MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
	// MLP has no gguf tag of its own; its fields carry the tags.
	MLP *MLP
}
  92. func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
  93. residual := hiddenState
  94. hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
  95. hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
  96. // In the final layer (outputs != nil), optimize by pruning to just the token positions
  97. // we need logits for.
  98. if outputs != nil {
  99. hiddenState = hiddenState.Rows(ctx, outputs)
  100. residual = residual.Rows(ctx, outputs)
  101. }
  102. hiddenState = hiddenState.Add(ctx, residual)
  103. residual = hiddenState
  104. hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
  105. hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
  106. return hiddenState.Add(ctx, residual)
  107. }
  108. func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
  109. inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
  110. if err != nil {
  111. return nil, err
  112. }
  113. positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
  114. if err != nil {
  115. return nil, err
  116. }
  117. outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
  118. if err != nil {
  119. return nil, err
  120. }
  121. hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
  122. for i, layer := range m.Layers {
  123. m.Cache.SetLayer(i)
  124. var lastLayerOutputs ml.Tensor
  125. if i == len(m.Layers)-1 {
  126. lastLayerOutputs = outputs
  127. }
  128. hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
  129. }
  130. hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
  131. return m.Output.Forward(ctx, hiddenState), nil
  132. }
  133. func init() {
  134. model.Register("llama", New)
  135. }