// gemma.go
  1. package convert
  2. import (
  3. "encoding/binary"
  4. "fmt"
  5. "io"
  6. "log/slog"
  7. "os"
  8. "strings"
  9. "github.com/d4l3k/go-bfloat16"
  10. "github.com/pdevine/tensor"
  11. "github.com/pdevine/tensor/native"
  12. "github.com/ollama/ollama/llm"
  13. )
// GemmaModel drives conversion of a Gemma checkpoint to GGUF. It embeds
// ModelData for the shared model state (path, params, tensors, vocab)
// and supplies the Gemma-specific tensor and metadata handling.
type GemmaModel struct {
	ModelData
}
  17. func gemmaLayerHandler(w io.Writer, r safetensorWriterTo, f *os.File) error {
  18. slog.Debug(fmt.Sprintf("converting '%s'", r.t.Name))
  19. data := make([]byte, r.end-r.start)
  20. if err := binary.Read(f, r.bo, data); err != nil {
  21. return err
  22. }
  23. tDataF32 := bfloat16.DecodeFloat32(data)
  24. var err error
  25. tDataF32, err = addOnes(tDataF32, int(r.t.Shape[0]))
  26. if err != nil {
  27. return err
  28. }
  29. if err := binary.Write(w, r.bo, tDataF32); err != nil {
  30. return err
  31. }
  32. return nil
  33. }
  34. func addOnes(data []float32, vectorSize int) ([]float32, error) {
  35. n := tensor.New(tensor.WithShape(vectorSize), tensor.WithBacking(data))
  36. ones := tensor.Ones(tensor.Float32, vectorSize)
  37. var err error
  38. n, err = n.Add(ones)
  39. if err != nil {
  40. return []float32{}, err
  41. }
  42. newN, err := native.SelectF32(n, 0)
  43. if err != nil {
  44. return []float32{}, err
  45. }
  46. var fullTensor []float32
  47. for _, v := range newN {
  48. fullTensor = append(fullTensor, v...)
  49. }
  50. return fullTensor, nil
  51. }
  52. func (m *GemmaModel) GetTensors() error {
  53. t, err := m.Format.GetTensors(m.Path, m.Params)
  54. if err != nil {
  55. return err
  56. }
  57. slog.Debug(fmt.Sprintf("Total tensors: %d", len(t)))
  58. m.Tensors = []llm.Tensor{}
  59. for _, l := range t {
  60. if strings.HasSuffix(l.Name, "norm.weight") {
  61. wt := l.WriterTo.(safetensorWriterTo)
  62. wt.handler = gemmaLayerHandler
  63. l.WriterTo = wt
  64. }
  65. m.Tensors = append(m.Tensors, l)
  66. }
  67. return nil
  68. }
  69. func (m *GemmaModel) LoadVocab() error {
  70. v, err := LoadSentencePieceTokens(m.Path, m.Params)
  71. if err != nil {
  72. return err
  73. }
  74. m.Vocab = v
  75. return nil
  76. }
  77. func (m *GemmaModel) WriteGGUF() (string, error) {
  78. kv := llm.KV{
  79. "general.architecture": "gemma",
  80. "general.name": m.Name,
  81. "gemma.context_length": uint32(m.Params.ContextSize),
  82. "gemma.embedding_length": uint32(m.Params.HiddenSize),
  83. "gemma.block_count": uint32(m.Params.HiddenLayers),
  84. "gemma.feed_forward_length": uint32(m.Params.IntermediateSize),
  85. "gemma.attention.head_count": uint32(m.Params.AttentionHeads),
  86. "gemma.attention.head_count_kv": uint32(m.Params.KeyValHeads),
  87. "gemma.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
  88. "gemma.attention.key_length": uint32(m.Params.HeadDimension),
  89. "gemma.attention.value_length": uint32(m.Params.HeadDimension),
  90. "general.file_type": uint32(1),
  91. "tokenizer.ggml.model": "llama",
  92. "tokenizer.ggml.tokens": m.Vocab.Tokens,
  93. "tokenizer.ggml.scores": m.Vocab.Scores,
  94. "tokenizer.ggml.token_type": m.Vocab.Types,
  95. "tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
  96. "tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
  97. "tokenizer.ggml.padding_token_id": uint32(m.Params.PaddingTokenID),
  98. "tokenizer.ggml.unknown_token_id": uint32(3),
  99. "tokenizer.ggml.add_bos_token": true,
  100. "tokenizer.ggml.add_eos_token": false,
  101. }
  102. f, err := os.CreateTemp("", "ollama-gguf")
  103. if err != nil {
  104. return "", err
  105. }
  106. defer f.Close()
  107. mod := llm.NewGGUFV3(m.Params.ByteOrder)
  108. if err := mod.Encode(f, kv, m.Tensors); err != nil {
  109. return "", err
  110. }
  111. return f.Name(), nil
  112. }