convert.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. package convert
import (
	"cmp"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"log/slog"
	"strings"

	"github.com/ollama/ollama/fs/ggml"
)
// ModelParameters holds the subset of a model's config.json that conversion
// needs: the architecture list and the vocabulary size.
type ModelParameters struct {
	Architectures []string `json:"architectures"`
	VocabSize     uint32   `json:"vocab_size"`

	// TextModel mirrors the nested "text_config" object; it supplies the
	// vocabulary size when the top-level vocab_size is absent (zero).
	TextModel TextParameters `json:"text_config"`
}
// TextParameters mirrors the nested "text_config" object found in some model
// configs; only the vocabulary size is read from it.
type TextParameters struct {
	VocabSize uint32 `json:"vocab_size"`
}
// AdapterParameters holds the fields read from a LoRA adapter's
// adapter_config.json.
type AdapterParameters struct {
	Alpha      uint32 `json:"lora_alpha"`
	LoraLayers uint32 `json:"lora_layers"`

	// LoraParameters carries the nested per-adapter values; its Alpha takes
	// precedence over the top-level lora_alpha when non-zero (see KV).
	LoraParameters struct {
		Rank  uint32  `json:"rank"`
		Alpha float32 `json:"alpha"`
		Scale float32 `json:"scale"`
	} `json:"lora_parameters"`
}
  29. func (ModelParameters) KV(t *Tokenizer) ggml.KV {
  30. kv := ggml.KV{
  31. "general.file_type": uint32(1),
  32. "general.quantization_version": uint32(2),
  33. "tokenizer.ggml.pre": t.Pre,
  34. "tokenizer.ggml.model": t.Vocabulary.Model,
  35. "tokenizer.ggml.tokens": t.Vocabulary.Tokens,
  36. "tokenizer.ggml.scores": t.Vocabulary.Scores,
  37. "tokenizer.ggml.token_type": t.Vocabulary.Types,
  38. }
  39. if len(t.Merges) > 0 {
  40. kv["tokenizer.ggml.merges"] = t.Merges
  41. }
  42. if t.Template != "" {
  43. kv["tokenizer.chat_template"] = t.Template
  44. }
  45. for _, sv := range t.SpecialVocabulary {
  46. kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID)
  47. kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken
  48. }
  49. return kv
  50. }
  51. func (p AdapterParameters) KV() ggml.KV {
  52. var alpha float32
  53. if p.LoraParameters.Alpha == 0 {
  54. alpha = float32(p.Alpha)
  55. } else {
  56. alpha = p.LoraParameters.Alpha
  57. }
  58. kv := ggml.KV{
  59. "adapter.lora.alpha": alpha,
  60. "adapter.type": "lora",
  61. "general.file_type": uint32(1),
  62. "general.type": "adapter",
  63. "general.version": "v0.2",
  64. }
  65. return kv
  66. }
  67. func (ModelParameters) specialTokenTypes() []string {
  68. return []string{
  69. "bos", "eos", "unk", "sep", "pad", "cls", "mask",
  70. }
  71. }
// writeFile writes the converted model to ws in GGUF format.
func (ModelParameters) writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
	return ggml.WriteGGUF(ws, kv, ts)
}
// writeFile writes the converted adapter to ws in GGUF format.
func (AdapterParameters) writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
	return ggml.WriteGGUF(ws, kv, ts)
}
// ModelConverter is implemented once per supported model architecture and
// drives the conversion of that architecture's metadata and tensors to GGUF.
type ModelConverter interface {
	// KV maps parameters to LLM key-values
	KV(*Tokenizer) ggml.KV
	// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
	Tensors([]Tensor) []ggml.Tensor
	// Replacements returns a list of string pairs to replace in tensor names.
	// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
	Replacements() []string

	// specialTokenTypes returns any special token types the model uses
	specialTokenTypes() []string
	// writeFile writes the model to the provided io.WriteSeeker
	writeFile(io.WriteSeeker, ggml.KV, []ggml.Tensor) error
}
// moreParser is optionally implemented by converters that need to read
// additional files from the input before conversion; ConvertModel invokes
// parseMore when the converter satisfies this interface.
type moreParser interface {
	parseMore(fs.FS) error
}
// AdapterConverter is implemented once per supported base-model architecture
// and drives the conversion of a LoRA adapter's metadata and tensors to GGUF.
type AdapterConverter interface {
	// KV maps parameters to LLM key-values
	KV(ggml.KV) ggml.KV
	// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
	Tensors([]Tensor) []ggml.Tensor
	// Replacements returns a list of string pairs to replace in tensor names.
	// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
	Replacements() []string

	// writeFile writes the adapter to the provided io.WriteSeeker
	writeFile(io.WriteSeeker, ggml.KV, []ggml.Tensor) error
}
  104. func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV ggml.KV) error {
  105. bts, err := fs.ReadFile(fsys, "adapter_config.json")
  106. if err != nil {
  107. return err
  108. }
  109. var p AdapterParameters
  110. if err := json.Unmarshal(bts, &p); err != nil {
  111. return err
  112. }
  113. arch, ok := baseKV["general.architecture"]
  114. if !ok {
  115. return errors.New("architecture not set for the base model")
  116. }
  117. var conv AdapterConverter
  118. switch arch {
  119. case "llama":
  120. conv = &llamaAdapter{}
  121. case "gemma2":
  122. conv = &gemma2Adapter{}
  123. default:
  124. return errors.New("unsupported architecture")
  125. }
  126. ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
  127. if err != nil {
  128. return err
  129. }
  130. if err := json.Unmarshal(bts, conv); err != nil {
  131. return err
  132. }
  133. return conv.writeFile(ws, conv.KV(baseKV), conv.Tensors(ts))
  134. }
  135. // Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
  136. // and files it finds in the input path.
  137. // Supported input model formats include safetensors.
  138. // Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
  139. func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
  140. bts, err := fs.ReadFile(fsys, "config.json")
  141. if err != nil {
  142. return err
  143. }
  144. var p ModelParameters
  145. if err := json.Unmarshal(bts, &p); err != nil {
  146. return err
  147. }
  148. if len(p.Architectures) < 1 {
  149. return errors.New("unknown architecture")
  150. }
  151. var conv ModelConverter
  152. switch p.Architectures[0] {
  153. case "LlamaForCausalLM", "MistralForCausalLM":
  154. conv = &llamaModel{}
  155. case "MixtralForCausalLM":
  156. conv = &mixtralModel{}
  157. case "GemmaForCausalLM":
  158. conv = &gemmaModel{}
  159. case "Gemma2ForCausalLM":
  160. conv = &gemma2Model{}
  161. case "Gemma3ForConditionalGeneration":
  162. conv = &gemma3Model{}
  163. case "Phi3ForCausalLM":
  164. conv = &phi3Model{}
  165. case "Qwen2ForCausalLM":
  166. conv = &qwen2Model{}
  167. case "BertModel":
  168. conv = &bertModel{}
  169. case "CohereForCausalLM":
  170. conv = &commandrModel{}
  171. default:
  172. return errors.New("unsupported architecture")
  173. }
  174. if err := json.Unmarshal(bts, conv); err != nil {
  175. return err
  176. }
  177. if t, ok := conv.(moreParser); ok {
  178. if err := t.parseMore(fsys); err != nil {
  179. return err
  180. }
  181. }
  182. t, err := parseTokenizer(fsys, conv.specialTokenTypes())
  183. if err != nil {
  184. return err
  185. }
  186. vocabSize := int(p.VocabSize)
  187. if vocabSize == 0 {
  188. tVocabSize := int(p.TextModel.VocabSize)
  189. vocabSize = tVocabSize
  190. }
  191. switch {
  192. case vocabSize > len(t.Vocabulary.Tokens):
  193. slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
  194. for i := range vocabSize - len(t.Vocabulary.Tokens) {
  195. t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
  196. t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
  197. t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
  198. }
  199. case vocabSize < len(t.Vocabulary.Tokens):
  200. return fmt.Errorf("vocabulary is larger than expected '%d' instead of '%d'", len(t.Vocabulary.Tokens), vocabSize)
  201. default:
  202. slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
  203. }
  204. ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
  205. if err != nil {
  206. return err
  207. }
  208. return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts))
  209. }