model.go 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. package gemma2
import (
	"fmt"
	"math"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)
  10. type Options struct {
  11. hiddenSize, numHeads, numKVHeads int
  12. attnKeyLen, attnValLen int
  13. eps, ropeBase, ropeScale float32
  14. attnLogitSoftcap float32
  15. finalLogitSoftcap float32
  16. largeModelScaling bool
  17. }
// Model is the Gemma 2 decoder: token embedding, a stack of decoder Layers,
// a final RMSNorm and an output projection, plus the shared *Options.
//
// NOTE(review): fields carry `gguf` tags naming their tensors; presumably the
// model loader fills them by reflection from those tensor names — confirm
// against the loader before renaming or reordering fields.
type Model struct {
	model.Base
	model.SentencePieceModel

	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []Layer       `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	// Output projects hidden states to vocabulary logits. The `alt:token_embd`
	// tag presumably lets the loader fall back to the token embedding tensor
	// when the file has no separate "output" tensor (tied embeddings) — confirm
	// against the gguf tag parser.
	Output *nn.Linear `gguf:"output,alt:token_embd"`

	*Options
}
  27. const (
  28. gemma27BLayerCount = 46
  29. )
  30. func New(c ml.Config) (model.Model, error) {
  31. m := Model{
  32. SentencePieceModel: model.NewSentencePieceModel(
  33. c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
  34. &model.Vocabulary{
  35. Values: c.Strings("tokenizer.ggml.tokens"),
  36. Scores: c.Floats("tokenizer.ggml.scores"),
  37. Types: c.Uints("tokenizer.ggml.token_type"),
  38. BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
  39. EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
  40. },
  41. ),
  42. Layers: make([]Layer, c.Uint("block_count")),
  43. Options: &Options{
  44. hiddenSize: int(c.Uint("embedding_length")),
  45. numHeads: int(c.Uint("attention.head_count")),
  46. numKVHeads: int(c.Uint("attention.head_count_kv")),
  47. attnKeyLen: int(c.Uint("attention.key_length")),
  48. attnValLen: int(c.Uint("attention.value_length")),
  49. eps: c.Float("attention.layer_norm_rms_epsilon"),
  50. ropeBase: c.Float("rope.freq_base", 10000.0),
  51. ropeScale: c.Float("rope.freq_scale", 1.0),
  52. attnLogitSoftcap: c.Float("attn_logit_softcapping"),
  53. finalLogitSoftcap: c.Float("final_logit_softcapping"),
  54. },
  55. }
  56. slidingWindowLen := int32(c.Uint("attention.sliding_window"))
  57. m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
  58. return &m, nil
  59. }
// SelfAttention holds one layer's query, key, value and output linear
// projections, applied in that roles' order by SelfAttention.Forward.
//
// NOTE(review): the `gguf` tags name the backing tensors; presumably they are
// resolved by reflection in the model loader — keep names/tags stable.
type SelfAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}
  66. func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
  67. batchSize := hiddenState.Dim(1)
  68. ropeType := uint32(2)
  69. q := sa.Query.Forward(ctx, hiddenState)
  70. q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
  71. q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
  72. if opts.largeModelScaling {
  73. q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize / opts.numHeads)))
  74. } else {
  75. q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
  76. }
  77. k := sa.Key.Forward(ctx, hiddenState)
  78. k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
  79. k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
  80. v := sa.Value.Forward(ctx, hiddenState)
  81. v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
  82. cache.Put(ctx, k, v)
  83. k, v, mask := cache.Get(ctx)
  84. q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  85. k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  86. v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
  87. kq := k.Mulmat(ctx, q)
  88. // logit softcap
  89. kq = kq.Scale(ctx, 1.0/float64(opts.attnLogitSoftcap))
  90. kq = kq.Tanh(ctx)
  91. kq = kq.Scale(ctx, float64(opts.attnLogitSoftcap))
  92. kq = kq.Add(ctx, mask)
  93. kq = kq.Softmax(ctx)
  94. kqv := v.Mulmat(ctx, kq)
  95. kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  96. kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
  97. return sa.Output.Forward(ctx, kqv)
  98. }
  99. func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
  100. return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
  101. }
// MLP is a layer's gated feed-forward block; MLP.Forward computes
// Down(GELU(Gate(x)) * Up(x)).
//
// NOTE(review): the `gguf` tags name the backing tensors; presumably they are
// resolved by reflection in the model loader — keep names/tags stable.
type MLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}
  107. func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
  108. hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
  109. return mlp.Down.Forward(ctx, hiddenState)
  110. }
// Layer is one decoder block. Layer.Forward applies, in order: AttentionNorm,
// SelfAttention, PostAttentionNorm and a residual add; then MLPNorm, MLP,
// PostMLPNorm and a second residual add.
//
// NOTE(review): tagged fields are presumably filled by reflection from the
// named tensors; the untagged SelfAttention/MLP fields rely on the loader's
// default naming — confirm before renaming.
type Layer struct {
	AttentionNorm     *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention     *SelfAttention
	PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
	MLPNorm           *nn.RMSNorm `gguf:"ffn_norm"`
	MLP               *MLP
	PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
}
  119. func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
  120. residual := hiddenState
  121. hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
  122. hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
  123. hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
  124. hiddenState = hiddenState.Add(ctx, residual)
  125. residual = hiddenState
  126. hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
  127. hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
  128. hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
  129. return hiddenState.Add(ctx, residual)
  130. }
  131. func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
  132. inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
  133. if err != nil {
  134. return nil, err
  135. }
  136. positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
  137. if err != nil {
  138. return nil, err
  139. }
  140. hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
  141. hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
  142. if len(m.Layers) == gemma27BLayerCount {
  143. m.Options.largeModelScaling = true
  144. }
  145. for i, layer := range m.Layers {
  146. cacheType := i % 2
  147. m.Cache.SetLayer(i)
  148. wc := m.Cache.(*kvcache.WrapperCache)
  149. wc.SetLayerType(cacheType)
  150. hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
  151. }
  152. hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
  153. hiddenState = m.Output.Forward(ctx, hiddenState)
  154. // final logit softcap
  155. hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
  156. hiddenState = hiddenState.Tanh(ctx)
  157. hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
  158. outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
  159. if err != nil {
  160. return nil, err
  161. }
  162. return hiddenState.Rows(ctx, outputs), nil
  163. }
  164. func init() {
  165. model.Register("gemma2", New)
  166. }