convert.go

package convert

import (
	"cmp"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"log/slog"
	"os"
	"path/filepath"
	"slices"
	"strings"

	"google.golang.org/protobuf/proto"

	"github.com/ollama/ollama/convert/sentencepiece"
	"github.com/ollama/ollama/llm"
)
type Params struct {
	Architectures     []string `json:"architectures"`
	VocabSize         int      `json:"vocab_size"`
	HiddenSize        int      `json:"hidden_size"`       // n_embd
	HiddenLayers      int      `json:"num_hidden_layers"` // n_layer
	ContextSize       int      `json:"max_position_embeddings"`
	IntermediateSize  int      `json:"intermediate_size"`
	AttentionHeads    int      `json:"num_attention_heads"` // n_head
	KeyValHeads       int      `json:"num_key_value_heads"`
	NormEPS           float64  `json:"rms_norm_eps"`
	BoSTokenID        int      `json:"bos_token_id"`
	EoSTokenID        int      `json:"eos_token_id"`
	HeadDimension     int      `json:"head_dim"`
	PaddingTokenID    int      `json:"pad_token_id"`
	RopeFrequencyBase float64  `json:"rope_theta"`
	Experts           int      `json:"num_local_experts"`
	ExpertsUsed       int      `json:"num_experts_per_tok"`

	ByteOrder
}
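
// loadParamsSketch is an illustrative, hypothetical helper and is not used by
// the rest of the package: the json tags on Params line up with the fields of
// a HuggingFace-style config.json, so the struct can be filled by decoding
// that file. The "config.json" filename is an assumption here; the real
// decoding happens inside the format-specific GetParams implementations. The
// embedded ByteOrder carries no json tag and is presumably set separately by
// the caller.
func loadParamsSketch(dirpath string) (*Params, error) {
	data, err := os.ReadFile(filepath.Join(dirpath, "config.json"))
	if err != nil {
		return nil, err
	}

	var params Params
	if err := json.Unmarshal(data, &params); err != nil {
		return nil, err
	}
	return &params, nil
}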
// ByteOrder combines the standard library's byte order and byte-appending
// interfaces.
type ByteOrder interface {
	binary.ByteOrder
	binary.AppendByteOrder
}

// ModelArch is implemented per supported model architecture and handles
// loading its tensors and vocabulary and writing the converted GGUF file.
type ModelArch interface {
	GetTensors() error
	LoadVocab() error
	WriteGGUF() (string, error)
}

// ModelFormat abstracts over the on-disk checkpoint formats (safetensors or
// torch) that a model can be converted from.
type ModelFormat interface {
	GetLayerName(string) (string, error)
	GetTensors(string, *Params) ([]llm.Tensor, error)
	GetParams(string) (*Params, error)
	GetModelArch(string, string, *Params) (ModelArch, error)
}

// ModelData collects everything needed to convert a model: its location,
// parameters, vocabulary, tensors, and source format.
type ModelData struct {
	Path    string
	Name    string
	Params  *Params
	Vocab   *Vocab
	Tensors []llm.Tensor
	Format  ModelFormat
}
// GetModelFormat inspects the files in dirname and returns the matching
// ModelFormat implementation.
func GetModelFormat(dirname string) (ModelFormat, error) {
	files, err := filepath.Glob(filepath.Join(dirname, "*"))
	if err != nil {
		return nil, err
	}

	for _, fn := range files {
		slog.Debug(fmt.Sprintf("file = %s", fn))
		if strings.HasSuffix(fn, ".safetensors") {
			return &SafetensorFormat{}, nil
		} else if strings.HasSuffix(fn, ".bin") {
			slog.Debug("model is torch")
			return &TorchFormat{}, nil
		}
	}

	return nil, fmt.Errorf("couldn't determine model format")
}
// Details on gguf's tokenizer can be found at:
// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#tokenizer
type Vocab struct {
	Tokens []string
	Scores []float32
	Types  []int32
}
func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) {
	slog.Info(fmt.Sprintf("reading vocab from %s", filepath.Join(dirpath, "tokenizer.model")))
	in, err := os.ReadFile(filepath.Join(dirpath, "tokenizer.model"))
	if err != nil {
		return nil, err
	}

	// To regenerate sentencepiece from the protobufs use:
	// protoc -I=./ --go_out=./ sentencepiece_model.proto
	modelProto := &sentencepiece.ModelProto{}
	if err := proto.Unmarshal(in, modelProto); err != nil {
		return nil, err
	}

	v := &Vocab{
		Tokens: make([]string, 0),
		Scores: make([]float32, 0),
		Types:  make([]int32, 0),
	}

	pieces := modelProto.GetPieces()
	for _, p := range pieces {
		v.Tokens = append(v.Tokens, p.GetPiece())
		v.Scores = append(v.Scores, p.GetScore())
		t := p.GetType()
		// keep the recognized special piece types as-is; everything else is
		// treated as a normal token
		switch t {
		case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
			sentencepiece.ModelProto_SentencePiece_CONTROL,
			sentencepiece.ModelProto_SentencePiece_UNUSED,
			sentencepiece.ModelProto_SentencePiece_BYTE:
		default:
			t = sentencepiece.ModelProto_SentencePiece_NORMAL
		}
		v.Types = append(v.Types, int32(t))
	}
	slog.Info(fmt.Sprintf("vocab size: %d", len(v.Tokens)))

	// add any additional tokens
	addIn, err := os.ReadFile(filepath.Join(dirpath, "added_tokens.json"))
	if os.IsNotExist(err) {
		return v, nil
	} else if err != nil {
		return nil, err
	}

	slog.Info("reading user defined tokens")

	var extraTokenData map[string]int
	if err := json.Unmarshal(addIn, &extraTokenData); err != nil {
		return nil, err
	}

	type token struct {
		key string
		pos int
	}

	extraTokens := make([]token, 0)
	for k, id := range extraTokenData {
		extraTokens = append(extraTokens, token{k, id})
	}

	slices.SortFunc(extraTokens, func(a, b token) int {
		return cmp.Compare(a.pos, b.pos)
	})

	numToks := len(v.Tokens)

	for cnt, t := range extraTokens {
		// the token id should match the specific index for the total number of tokens
		if t.pos != cnt+numToks {
			return nil, fmt.Errorf("token ID '%d' for '%s' doesn't match total token size", t.pos, t.key)
		}
		v.Tokens = append(v.Tokens, t.key)
		v.Scores = append(v.Scores, -1000.0)
		v.Types = append(v.Types, int32(llm.GGUFTokenUserDefined))
	}
	slog.Info(fmt.Sprintf("vocab size w/ extra tokens: %d", len(v.Tokens)))

	// pad the vocab with placeholder tokens so it matches the size declared
	// in the model's parameters
	if params.VocabSize > len(v.Tokens) {
		missingTokens := params.VocabSize - len(v.Tokens)
		slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens))
		for cnt := 0; cnt < missingTokens; cnt++ {
			v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1))
			v.Scores = append(v.Scores, -1)
			v.Types = append(v.Types, int32(llm.GGUFTokenUserDefined))
		}
	}

	return v, nil
}
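
// convertSketch is a rough, hypothetical walkthrough of how the pieces above
// are intended to compose; it is not called from this file. The string
// argument to GetParams is assumed to be the model directory, and the order
// of the two string arguments to GetModelArch (model name, then directory) is
// assumed here rather than taken from this file. LoadSentencePieceTokens is
// presumably invoked from inside the architecture's LoadVocab implementation.
func convertSketch(name, dirpath string) (string, error) {
	// detect whether the checkpoint is safetensors or torch
	format, err := GetModelFormat(dirpath)
	if err != nil {
		return "", err
	}

	params, err := format.GetParams(dirpath)
	if err != nil {
		return "", err
	}

	arch, err := format.GetModelArch(name, dirpath, params)
	if err != nil {
		return "", err
	}

	if err := arch.GetTensors(); err != nil {
		return "", err
	}
	if err := arch.LoadVocab(); err != nil {
		return "", err
	}

	// the string result of WriteGGUF is presumably the path of the converted file
	return arch.WriteGGUF()
}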