convert.go

package convert

import (
	"cmp"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"os"
	"path/filepath"
	"slices"
	"strings"

	"google.golang.org/protobuf/proto"

	"github.com/ollama/ollama/convert/sentencepiece"
	"github.com/ollama/ollama/llm"
)
// Params holds the model hyperparameters read from a transformers-style
// config.json; the json tags map each field to its key in that file.
type Params struct {
	Architectures     []string `json:"architectures"`
	VocabSize         int      `json:"vocab_size"`
	HiddenSize        int      `json:"hidden_size"`       // n_embd
	HiddenLayers      int      `json:"num_hidden_layers"` // n_layer
	ContextSize       int      `json:"max_position_embeddings"`
	IntermediateSize  int      `json:"intermediate_size"`
	AttentionHeads    int      `json:"num_attention_heads"` // n_head
	KeyValHeads       int      `json:"num_key_value_heads"`
	NormEPS           float64  `json:"rms_norm_eps"`
	BoSTokenID        int      `json:"bos_token_id"`
	EoSTokenID        int      `json:"eos_token_id"`
	HeadDimension     int      `json:"head_dim"`
	PaddingTokenID    int      `json:"pad_token_id"`
	RopeFrequencyBase float64  `json:"rope_theta"`
	Experts           int      `json:"num_local_experts"`
	ExpertsUsed       int      `json:"num_experts_per_tok"`

	ByteOrder
}
// ByteOrder combines binary.ByteOrder and binary.AppendByteOrder so one value
// can both decode and append fixed-size values in a single endianness.
type ByteOrder interface {
	binary.ByteOrder
	binary.AppendByteOrder
}
// ModelArch reads a specific architecture's tensors and vocabulary and writes
// the converted model out as GGUF.
type ModelArch interface {
	GetTensors() error
	LoadVocab() error
	WriteGGUF(io.WriteSeeker) error
}
// ModelFormat abstracts over the on-disk checkpoint format the source model
// was saved in (safetensors or torch .bin).
type ModelFormat interface {
	GetLayerName(string) (string, error)
	GetTensors(string, *Params) ([]llm.Tensor, error)
	GetParams(string) (*Params, error)
	GetModelArch(string, string, *Params) (ModelArch, error)
}
type ModelData struct {
	Path    string
	Name    string
	Params  *Params
	Vocab   *Vocab
	Tensors []llm.Tensor
	Format  ModelFormat
}
// GetModelFormat inspects the files in dirname and returns the matching
// ModelFormat.
func GetModelFormat(dirname string) (ModelFormat, error) {
	files, err := filepath.Glob(filepath.Join(dirname, "*"))
	if err != nil {
		return nil, err
	}

	for _, fn := range files {
		slog.Debug(fmt.Sprintf("file = %s", fn))
		if strings.HasSuffix(fn, ".safetensors") {
			return &SafetensorFormat{}, nil
		} else if strings.HasSuffix(fn, ".bin") {
			slog.Debug("model is torch")
			return &TorchFormat{}, nil
		}
	}

	return nil, fmt.Errorf("couldn't determine model format")
}
// Details on gguf's tokenizer can be found at:
// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#tokenizer
type Vocab struct {
	Tokens []string
	Scores []float32
	Types  []int32
}
// LoadSentencePieceTokens reads the SentencePiece tokenizer.model in dirpath
// (plus an optional added_tokens.json) and builds the model's vocabulary.
func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) {
	slog.Info(fmt.Sprintf("reading vocab from %s", filepath.Join(dirpath, "tokenizer.model")))
	in, err := os.ReadFile(filepath.Join(dirpath, "tokenizer.model"))
	if err != nil {
		return nil, err
	}

	// To regenerate sentencepiece from the protobufs use:
	// protoc -I=./ --go_out=./ sentencepiece_model.proto
	modelProto := &sentencepiece.ModelProto{}
	if err := proto.Unmarshal(in, modelProto); err != nil {
		return nil, err
	}

	v := &Vocab{
		Tokens: make([]string, 0),
		Scores: make([]float32, 0),
		Types:  make([]int32, 0),
	}

	pieces := modelProto.GetPieces()
	for _, p := range pieces {
		v.Tokens = append(v.Tokens, p.GetPiece())
		v.Scores = append(v.Scores, p.GetScore())

		// keep the known token types; anything else is treated as NORMAL
		t := p.GetType()
		switch t {
		case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
			sentencepiece.ModelProto_SentencePiece_CONTROL,
			sentencepiece.ModelProto_SentencePiece_UNUSED,
			sentencepiece.ModelProto_SentencePiece_BYTE:
		default:
			t = sentencepiece.ModelProto_SentencePiece_NORMAL
		}
		v.Types = append(v.Types, int32(t))
	}
	slog.Info(fmt.Sprintf("vocab size: %d", len(v.Tokens)))

	// add any additional tokens
	addIn, err := os.ReadFile(filepath.Join(dirpath, "added_tokens.json"))
	if os.IsNotExist(err) {
		return v, nil
	} else if err != nil {
		return nil, err
	}

	slog.Info("reading user defined tokens")

	var extraTokenData map[string]int
	if err := json.Unmarshal(addIn, &extraTokenData); err != nil {
		return nil, err
	}

	type token struct {
		key string
		pos int
	}

	extraTokens := make([]token, 0)
	for k, id := range extraTokenData {
		extraTokens = append(extraTokens, token{k, id})
	}

	slices.SortFunc(extraTokens, func(a, b token) int {
		return cmp.Compare(a.pos, b.pos)
	})

	numToks := len(v.Tokens)

	for cnt, t := range extraTokens {
		// the token id should match the specific index for the total number of tokens
		if t.pos != cnt+numToks {
			return nil, fmt.Errorf("token ID '%d' for '%s' doesn't match total token size", t.pos, t.key)
		}
		v.Tokens = append(v.Tokens, t.key)
		v.Scores = append(v.Scores, -1000.0)
		v.Types = append(v.Types, int32(llm.GGUFTokenUserDefined))
	}
	slog.Info(fmt.Sprintf("vocab size w/ extra tokens: %d", len(v.Tokens)))

	// pad the vocabulary with dummy tokens up to the size declared in the config
	if params.VocabSize > len(v.Tokens) {
		missingTokens := params.VocabSize - len(v.Tokens)
		slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens))
		for cnt := 0; cnt < missingTokens; cnt++ {
			v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1))
			v.Scores = append(v.Scores, -1)
			v.Types = append(v.Types, int32(llm.GGUFTokenUserDefined))
		}
	}

	return v, nil
}
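
For orientation, here is a minimal sketch of how a caller might wire these pieces together. It is not the package's own driver: the directory, model name, and output path are assumed inputs, and the order of GetModelArch's two string arguments (model name, then directory) is an assumption.

// example_convert.go — hypothetical driver, for illustration only
package main

import (
	"log"
	"os"

	"github.com/ollama/ollama/convert"
)

func main() {
	// assumed inputs for the sketch
	dir := "./model"        // directory containing config.json, tokenizer.model, weights
	name := "my-model"      // display name for the converted model
	outPath := "./model.gguf"

	// pick safetensors or torch based on the files present
	format, err := convert.GetModelFormat(dir)
	if err != nil {
		log.Fatal(err)
	}

	// read hyperparameters from the config
	params, err := format.GetParams(dir)
	if err != nil {
		log.Fatal(err)
	}

	// argument order (name, then directory) is assumed here
	arch, err := format.GetModelArch(name, dir, params)
	if err != nil {
		log.Fatal(err)
	}

	if err := arch.GetTensors(); err != nil {
		log.Fatal(err)
	}
	if err := arch.LoadVocab(); err != nil {
		log.Fatal(err)
	}

	// *os.File satisfies io.WriteSeeker, which WriteGGUF expects
	out, err := os.Create(outPath)
	if err != nil {
		log.Fatal(err)
	}
	defer out.Close()

	if err := arch.WriteGGUF(out); err != nil {
		log.Fatal(err)
	}
}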