convert.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. package convert
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "io/fs"
  8. "log/slog"
  9. "strings"
  10. "github.com/ollama/ollama/llm"
  11. )
// ModelParameters holds the fields parsed from a model's config.json that
// are needed to select a converter and to validate the tokenizer vocabulary.
type ModelParameters struct {
	Architectures []string `json:"architectures"`
	VocabSize     uint32   `json:"vocab_size"`
}
// AdapterParameters holds the fields parsed from a LoRA adapter's
// adapter_config.json. Alpha is the top-level lora_alpha value; the nested
// LoraParameters block, when present, takes precedence (see KV).
type AdapterParameters struct {
	Alpha          uint32 `json:"lora_alpha"`
	LoraLayers     uint32 `json:"lora_layers"`
	LoraParameters struct {
		Rank  uint32  `json:"rank"`
		Alpha float32 `json:"alpha"`
		Scale float32 `json:"scale"`
	} `json:"lora_parameters"`
}
  25. func (ModelParameters) KV(t *Tokenizer) llm.KV {
  26. kv := llm.KV{
  27. "general.file_type": uint32(1),
  28. "general.quantization_version": uint32(2),
  29. "tokenizer.ggml.pre": t.Pre,
  30. "tokenizer.ggml.model": t.Vocabulary.Model,
  31. "tokenizer.ggml.tokens": t.Vocabulary.Tokens,
  32. "tokenizer.ggml.scores": t.Vocabulary.Scores,
  33. "tokenizer.ggml.token_type": t.Vocabulary.Types,
  34. }
  35. if len(t.Merges) > 0 {
  36. kv["tokenizer.ggml.merges"] = t.Merges
  37. }
  38. if t.Template != "" {
  39. kv["tokenizer.chat_template"] = t.Template
  40. }
  41. for _, sv := range t.SpecialVocabulary {
  42. kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID)
  43. kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken
  44. }
  45. return kv
  46. }
  47. func (p AdapterParameters) KV() llm.KV {
  48. var alpha float32
  49. if p.LoraParameters.Alpha == 0 {
  50. alpha = float32(p.Alpha)
  51. } else {
  52. alpha = p.LoraParameters.Alpha
  53. }
  54. kv := llm.KV{
  55. "adapter.lora.alpha": alpha,
  56. "adapter.type": "lora",
  57. "general.file_type": uint32(1),
  58. "general.type": "adapter",
  59. "general.version": "v0.2",
  60. }
  61. return kv
  62. }
  63. func (ModelParameters) specialTokenTypes() []string {
  64. return []string{
  65. "bos", "eos", "unk", "sep", "pad", "cls", "mask",
  66. }
  67. }
// writeFile serializes the key-values and tensors as GGUF to ws.
func (ModelParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error {
	return llm.WriteGGUF(ws, kv, ts)
}
// writeFile serializes the adapter's key-values and tensors as GGUF to ws.
func (AdapterParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error {
	return llm.WriteGGUF(ws, kv, ts)
}
// ModelConverter converts a source model (e.g. safetensors) into an
// Ollama compatible GGUF file. One implementation exists per supported
// architecture.
type ModelConverter interface {
	// KV maps parameters to LLM key-values
	KV(*Tokenizer) llm.KV
	// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
	Tensors([]Tensor) []llm.Tensor
	// Replacements returns a list of string pairs to replace in tensor names.
	// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
	Replacements() []string
	// specialTokenTypes returns any special token types the model uses
	specialTokenTypes() []string
	// writeFile writes the model to the provided io.WriteSeeker
	writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
}
// moreParser is implemented by converters that need to read additional
// input files beyond config.json (see ConvertModel).
type moreParser interface {
	parseMore(fs.FS) error
}
// AdapterConverter converts a LoRA adapter into an Ollama compatible GGUF
// file. One implementation exists per supported base-model architecture.
type AdapterConverter interface {
	// KV maps parameters to LLM key-values
	KV(llm.KV) llm.KV
	// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
	Tensors([]Tensor) []llm.Tensor
	// Replacements returns a list of string pairs to replace in tensor names.
	// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
	Replacements() []string
	// writeFile writes the adapter to the provided io.WriteSeeker
	writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
}
  100. func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV llm.KV) error {
  101. bts, err := fs.ReadFile(fsys, "adapter_config.json")
  102. if err != nil {
  103. return err
  104. }
  105. var p AdapterParameters
  106. if err := json.Unmarshal(bts, &p); err != nil {
  107. return err
  108. }
  109. arch, ok := baseKV["general.architecture"]
  110. if !ok {
  111. return errors.New("architecture not set for the base model")
  112. }
  113. var conv AdapterConverter
  114. switch arch {
  115. case "llama":
  116. conv = &llamaAdapter{}
  117. case "gemma2":
  118. conv = &gemma2Adapter{}
  119. default:
  120. return errors.New("unsupported architecture")
  121. }
  122. ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
  123. if err != nil {
  124. return err
  125. }
  126. if err := json.Unmarshal(bts, conv); err != nil {
  127. return err
  128. }
  129. return conv.writeFile(ws, conv.KV(baseKV), conv.Tensors(ts))
  130. }
  131. // Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
  132. // and files it finds in the input path.
  133. // Supported input model formats include safetensors.
  134. // Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
  135. func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
  136. bts, err := fs.ReadFile(fsys, "config.json")
  137. if err != nil {
  138. return err
  139. }
  140. var p ModelParameters
  141. if err := json.Unmarshal(bts, &p); err != nil {
  142. return err
  143. }
  144. if len(p.Architectures) < 1 {
  145. return errors.New("unknown architecture")
  146. }
  147. var conv ModelConverter
  148. switch p.Architectures[0] {
  149. case "LlamaForCausalLM", "MistralForCausalLM":
  150. conv = &llamaModel{}
  151. case "MixtralForCausalLM":
  152. conv = &mixtralModel{}
  153. case "GemmaForCausalLM":
  154. conv = &gemmaModel{}
  155. case "Gemma2ForCausalLM":
  156. conv = &gemma2Model{}
  157. case "Phi3ForCausalLM":
  158. conv = &phi3Model{}
  159. case "Qwen2ForCausalLM":
  160. conv = &qwen2Model{}
  161. case "BertModel":
  162. conv = &bertModel{}
  163. case "CohereForCausalLM":
  164. conv = &commandrModel{}
  165. default:
  166. return errors.New("unsupported architecture")
  167. }
  168. if err := json.Unmarshal(bts, conv); err != nil {
  169. return err
  170. }
  171. if t, ok := conv.(moreParser); ok {
  172. if err := t.parseMore(fsys); err != nil {
  173. return err
  174. }
  175. }
  176. t, err := parseTokenizer(fsys, conv.specialTokenTypes())
  177. if err != nil {
  178. return err
  179. }
  180. vocabSize := int(p.VocabSize)
  181. switch {
  182. case vocabSize > len(t.Vocabulary.Tokens):
  183. slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
  184. for i := range vocabSize - len(t.Vocabulary.Tokens) {
  185. t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
  186. t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
  187. t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
  188. }
  189. case vocabSize < len(t.Vocabulary.Tokens):
  190. return fmt.Errorf("vocabulary is larger than expected '%d' instead of '%d'", len(t.Vocabulary.Tokens), vocabSize)
  191. default:
  192. slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
  193. }
  194. ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
  195. if err != nil {
  196. return err
  197. }
  198. return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts))
  199. }