convert.go

package convert

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"log/slog"
	"strings"

	"github.com/ollama/ollama/llm"
)
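
// Parameters holds the fields read from a model's config.json that drive
// conversion: the architecture list and the expected vocabulary size.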
type Parameters struct {
	Architectures []string `json:"architectures"`
	VocabSize     uint32   `json:"vocab_size"`
}
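
// KV returns the common llm key-value metadata for a model: the general file
// type and quantization version, plus the tokenizer's vocabulary and, when
// present, its merges, chat template, and special tokens.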
func (Parameters) KV(t *Tokenizer) llm.KV {
	kv := llm.KV{
		"general.file_type":            uint32(1),
		"general.quantization_version": uint32(2),
		"tokenizer.ggml.pre":           t.Pre,
		"tokenizer.ggml.model":         t.Vocabulary.Model,
		"tokenizer.ggml.tokens":        t.Vocabulary.Tokens,
		"tokenizer.ggml.scores":        t.Vocabulary.Scores,
		"tokenizer.ggml.token_type":    t.Vocabulary.Types,
	}

	if len(t.Merges) > 0 {
		kv["tokenizer.ggml.merges"] = t.Merges
	}

	if t.Template != "" {
		kv["tokenizer.chat_template"] = t.Template
	}

	for _, sv := range t.SpecialVocabulary {
		kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID)
		kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken
	}

	return kv
}
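
// specialTokenTypes returns the special token types recognized when parsing
// a tokenizer's special vocabulary.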
func (Parameters) specialTokenTypes() []string {
	return []string{
		"bos", "eos", "unk", "sep", "pad", "cls", "mask",
	}
}
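
// writeFile writes the model's key-values and tensors to ws in GGUF format.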
func (Parameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error {
	return llm.WriteGGUF(ws, kv, ts)
}

type Converter interface {
	// KV maps parameters to LLM key-values
	KV(*Tokenizer) llm.KV
	// Tensors maps input tensors to LLM tensors. Model-specific modifications
	// can be done here.
	Tensors([]Tensor) []llm.Tensor
	// Replacements returns a list of string pairs to replace in tensor names.
	// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details.
	Replacements() []string

	// specialTokenTypes returns any special token types the model uses
	specialTokenTypes() []string
	// writeFile writes the model to the provided io.WriteSeeker
	writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
}
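
// moreParser is an optional interface for converters that need to parse
// additional input files beyond config.json.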
type moreParser interface {
	parseMore(fs.FS) error
}

// Convert writes an Ollama compatible model to the provided io.WriteSeeker
// based on configurations and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizer files include tokenizer.json (preferred) and
// tokenizer.model.
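//
// A minimal usage sketch (the paths and file names here are illustrative,
// not part of this package):
//
//	fsys := os.DirFS("/path/to/model") // directory with config.json, tokenizer, tensor files
//	out, err := os.Create("model.gguf")
//	if err != nil {
//		return err
//	}
//	defer out.Close()
//	if err := Convert(fsys, out); err != nil { // *os.File satisfies io.WriteSeeker
//		return err
//	}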
func Convert(fsys fs.FS, ws io.WriteSeeker) error {
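	// Read config.json, which names the model architecture and the expected
	// vocabulary size.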
	bts, err := fs.ReadFile(fsys, "config.json")
	if err != nil {
		return err
	}

	var p Parameters
	if err := json.Unmarshal(bts, &p); err != nil {
		return err
	}

	if len(p.Architectures) < 1 {
		return errors.New("unknown architecture")
	}
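
	// Pick a model-specific converter based on the first listed architecture.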
	var conv Converter
	switch p.Architectures[0] {
	case "LlamaForCausalLM", "MistralForCausalLM":
		conv = &llama{}
	case "MixtralForCausalLM":
		conv = &mixtral{}
	case "GemmaForCausalLM":
		conv = &gemma{}
	case "Gemma2ForCausalLM":
		conv = &gemma2{}
	case "Phi3ForCausalLM":
		conv = &phi3{}
	case "BertModel":
		conv = &bert{}
	default:
		return errors.New("unsupported architecture")
	}
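
	// Decode config.json a second time, now into the model-specific converter
	// so it can populate its own fields.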
	if err := json.Unmarshal(bts, conv); err != nil {
		return err
	}
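
	// Some converters need more than config.json; let them parse any
	// additional input files now.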
	if t, ok := conv.(moreParser); ok {
		if err := t.parseMore(fsys); err != nil {
			return err
		}
	}
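
	// Parse the tokenizer from the input files (tokenizer.json preferred,
	// tokenizer.model as a fallback).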
	t, err := parseTokenizer(fsys, conv.specialTokenTypes())
	if err != nil {
		return err
	}
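
	// If config.json declares a larger vocabulary than the tokenizer actually
	// provides, pad with placeholder tokens so the two counts agree.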
	if vocabSize := int(p.VocabSize); vocabSize > len(t.Vocabulary.Tokens) {
		slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", p.VocabSize, "actual", len(t.Vocabulary.Tokens))
		for i := range vocabSize - len(t.Vocabulary.Tokens) {
			t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
			t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
			t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
		}
	} else {
		slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
	}
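
	// Load the input tensors, rewriting their names with the converter's
	// replacement pairs.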
	ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
	if err != nil {
		return err
	}
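
	// Write the converted model: the converter's key-values plus its tensors.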
	return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts))
}