convert.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. package convert
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "io/fs"
  8. "log/slog"
  9. "github.com/ollama/ollama/llm"
  10. )
// Parameters holds the subset of a model's config.json that Convert
// needs: the architecture list used to pick a Converter and the
// expected vocabulary size.
type Parameters struct {
	// Architectures lists architecture names from config.json
	// (e.g. "LlamaForCausalLM"); Convert switches on the first entry.
	Architectures []string `json:"architectures"`

	// VocabSize is the vocabulary size declared by config.json. When it
	// exceeds the parsed tokenizer's token count, Convert pads the
	// vocabulary with dummy "[PADn]" tokens.
	VocabSize uint32 `json:"vocab_size"`
}
  15. func (Parameters) KV(t *Tokenizer) llm.KV {
  16. kv := llm.KV{
  17. "general.file_type": uint32(1),
  18. "general.quantization_version": uint32(2),
  19. "tokenizer.ggml.pre": t.Pre,
  20. "tokenizer.ggml.model": t.Vocabulary.Model,
  21. "tokenizer.ggml.tokens": t.Vocabulary.Tokens,
  22. "tokenizer.ggml.scores": t.Vocabulary.Scores,
  23. "tokenizer.ggml.token_type": t.Vocabulary.Types,
  24. }
  25. if len(t.Merges) > 0 {
  26. kv["tokenizer.ggml.merges"] = t.Merges
  27. }
  28. if t.Template != "" {
  29. kv["tokenizer.chat_template"] = t.Template
  30. }
  31. for _, sv := range t.SpecialVocabulary {
  32. kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID)
  33. kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken
  34. }
  35. return kv
  36. }
  37. func (Parameters) specialTokenTypes() []string {
  38. return []string{
  39. "bos", "eos", "unk", "sep", "pad", "cls", "mask",
  40. }
  41. }
// writeFile serializes kv and ts to ws in GGUF format by delegating to
// llm.WriteGGUF.
func (Parameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error {
	return llm.WriteGGUF(ws, kv, ts)
}
// Converter is implemented once per supported model architecture
// (llama, mixtral, gemma, phi3) and drives the conversion performed by
// Convert.
type Converter interface {
	// KV maps parameters to LLM key-values
	KV(*Tokenizer) llm.KV
	// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
	Tensors([]Tensor) []llm.Tensor
	// tensorName returns the LLM tensor name for a specific input name
	tensorName(string) string
	// specialTokenTypes returns any special token types the model uses
	specialTokenTypes() []string
	// writeFile writes the converted key-values and tensors to ws.
	writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
}
  56. // Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
  57. // and files it finds in the input path.
  58. // Supported input model formats include safetensors.
  59. // Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
  60. func Convert(fsys fs.FS, ws io.WriteSeeker) error {
  61. bts, err := fs.ReadFile(fsys, "config.json")
  62. if err != nil {
  63. return err
  64. }
  65. var p Parameters
  66. if err := json.Unmarshal(bts, &p); err != nil {
  67. return err
  68. }
  69. if len(p.Architectures) < 1 {
  70. return errors.New("unknown architecture")
  71. }
  72. var conv Converter
  73. switch p.Architectures[0] {
  74. case "LlamaForCausalLM", "MistralForCausalLM":
  75. conv = &llama{}
  76. case "MixtralForCausalLM":
  77. conv = &mixtral{}
  78. case "GemmaForCausalLM":
  79. conv = &gemma{}
  80. case "Phi3ForCausalLM":
  81. conv = &phi3{}
  82. default:
  83. return errors.New("unsupported architecture")
  84. }
  85. if err := json.Unmarshal(bts, conv); err != nil {
  86. return err
  87. }
  88. t, err := parseTokenizer(fsys, conv.specialTokenTypes())
  89. if err != nil {
  90. return err
  91. }
  92. if vocabSize := int(p.VocabSize); vocabSize > len(t.Vocabulary.Tokens) {
  93. slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", p.VocabSize, "actual", len(t.Vocabulary.Tokens))
  94. for i := range vocabSize - len(t.Vocabulary.Tokens) {
  95. t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
  96. t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
  97. t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
  98. }
  99. } else {
  100. slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
  101. }
  102. ts, err := parseTensors(fsys)
  103. if err != nil {
  104. return err
  105. }
  106. return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts))
  107. }