convert.go 6.2 KB

  1. package convert
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "io/fs"
  8. "log/slog"
  9. "strings"
  10. "github.com/ollama/ollama/api"
  11. "github.com/ollama/ollama/llm"
  12. )
// ModelParameters holds the subset of a model's config.json needed to
// select a converter and size the vocabulary.
type ModelParameters struct {
	Architectures []string `json:"architectures"`
	VocabSize     uint32   `json:"vocab_size"`
}
// AdapterParameters holds the subset of a LoRA adapter's
// adapter_config.json needed for conversion. Alpha is the top-level
// lora_alpha value; LoraParameters carries the nested variant, which
// takes precedence when non-zero (see AdapterParameters.KV).
type AdapterParameters struct {
	Alpha          uint32 `json:"lora_alpha"`
	LoraLayers     uint32 `json:"lora_layers"`
	LoraParameters struct {
		Rank  uint32  `json:"rank"`
		Alpha float32 `json:"alpha"`
		Scale float32 `json:"scale"`
	} `json:"lora_parameters"`
}
  26. func (ModelParameters) KV(t *Tokenizer) llm.KV {
  27. kv := llm.KV{
  28. "general.file_type": uint32(1),
  29. "general.quantization_version": uint32(2),
  30. "tokenizer.ggml.pre": t.Pre,
  31. "tokenizer.ggml.model": t.Vocabulary.Model,
  32. "tokenizer.ggml.tokens": t.Vocabulary.Tokens,
  33. "tokenizer.ggml.scores": t.Vocabulary.Scores,
  34. "tokenizer.ggml.token_type": t.Vocabulary.Types,
  35. }
  36. if len(t.Merges) > 0 {
  37. kv["tokenizer.ggml.merges"] = t.Merges
  38. }
  39. if t.Template != "" {
  40. kv["tokenizer.chat_template"] = t.Template
  41. }
  42. for _, sv := range t.SpecialVocabulary {
  43. kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID)
  44. kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken
  45. }
  46. return kv
  47. }
  48. func (p AdapterParameters) KV() llm.KV {
  49. var alpha float32
  50. if p.LoraParameters.Alpha == 0 {
  51. alpha = float32(p.Alpha)
  52. } else {
  53. alpha = p.LoraParameters.Alpha
  54. }
  55. kv := llm.KV{
  56. "adapter.lora.alpha": alpha,
  57. "adapter.type": "lora",
  58. "general.file_type": uint32(1),
  59. "general.type": "adapter",
  60. "general.version": "v0.2",
  61. }
  62. return kv
  63. }
  64. func (ModelParameters) specialTokenTypes() []string {
  65. return []string{
  66. "bos", "eos", "unk", "sep", "pad", "cls", "mask",
  67. }
  68. }
// writeFile serializes the key-values and tensors as GGUF to ws,
// reporting progress through fn.
func (ModelParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor, fn func(api.ProgressResponse)) error {
	return llm.WriteGGUF(ws, kv, ts, fn)
}
// writeFile serializes the key-values and tensors as GGUF to ws,
// reporting progress through fn.
func (AdapterParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor, fn func(api.ProgressResponse)) error {
	return llm.WriteGGUF(ws, kv, ts, fn)
}
// ModelConverter converts an architecture-specific model representation
// into an Ollama compatible GGUF model.
type ModelConverter interface {
	// KV maps parameters to LLM key-values
	KV(*Tokenizer) llm.KV
	// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
	Tensors([]Tensor) []llm.Tensor
	// Replacements returns a list of string pairs to replace in tensor names.
	// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
	Replacements() []string
	// specialTokenTypes returns any special token types the model uses
	specialTokenTypes() []string
	// writeFile writes the model to the provided io.WriteSeeker
	writeFile(io.WriteSeeker, llm.KV, []llm.Tensor, func(api.ProgressResponse)) error
}
// moreParser is implemented by converters that need to read additional
// files from the input beyond config.json.
type moreParser interface {
	parseMore(fs.FS) error
}
// AdapterConverter converts an architecture-specific LoRA adapter into
// an Ollama compatible GGUF adapter.
type AdapterConverter interface {
	// KV maps parameters to LLM key-values
	KV(llm.KV) llm.KV
	// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
	Tensors([]Tensor) []llm.Tensor
	// Replacements returns a list of string pairs to replace in tensor names.
	// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
	Replacements() []string
	// writeFile writes the adapter to the provided io.WriteSeeker
	writeFile(io.WriteSeeker, llm.KV, []llm.Tensor, func(api.ProgressResponse)) error
}
  101. func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV llm.KV, fn func(api.ProgressResponse)) error {
  102. bts, err := fs.ReadFile(fsys, "adapter_config.json")
  103. if err != nil {
  104. return err
  105. }
  106. var p AdapterParameters
  107. if err := json.Unmarshal(bts, &p); err != nil {
  108. return err
  109. }
  110. arch, ok := baseKV["general.architecture"]
  111. if !ok {
  112. return errors.New("architecture not set for the base model")
  113. }
  114. var conv AdapterConverter
  115. switch arch {
  116. case "llama":
  117. conv = &llamaAdapter{}
  118. case "gemma2":
  119. conv = &gemma2Adapter{}
  120. default:
  121. return errors.New("unsupported architecture")
  122. }
  123. ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
  124. if err != nil {
  125. return err
  126. }
  127. if err := json.Unmarshal(bts, conv); err != nil {
  128. return err
  129. }
  130. fn(api.ProgressResponse{
  131. Status: fmt.Sprintf("converting adapter 0%%"),
  132. })
  133. return conv.writeFile(ws, conv.KV(baseKV), conv.Tensors(ts), fn)
  134. }
  135. // Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
  136. // and files it finds in the input path.
  137. // Supported input model formats include safetensors.
  138. // Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
  139. func ConvertModel(fsys fs.FS, ws io.WriteSeeker, fn func(api.ProgressResponse)) error {
  140. bts, err := fs.ReadFile(fsys, "config.json")
  141. if err != nil {
  142. return err
  143. }
  144. var p ModelParameters
  145. if err := json.Unmarshal(bts, &p); err != nil {
  146. return err
  147. }
  148. if len(p.Architectures) < 1 {
  149. return errors.New("unknown architecture")
  150. }
  151. var conv ModelConverter
  152. switch p.Architectures[0] {
  153. case "LlamaForCausalLM", "MistralForCausalLM":
  154. conv = &llamaModel{}
  155. case "MixtralForCausalLM":
  156. conv = &mixtralModel{}
  157. case "GemmaForCausalLM":
  158. conv = &gemmaModel{}
  159. case "Gemma2ForCausalLM":
  160. conv = &gemma2Model{}
  161. case "Phi3ForCausalLM":
  162. conv = &phi3Model{}
  163. case "BertModel":
  164. conv = &bertModel{}
  165. default:
  166. return errors.New("unsupported architecture")
  167. }
  168. if err := json.Unmarshal(bts, conv); err != nil {
  169. return err
  170. }
  171. if t, ok := conv.(moreParser); ok {
  172. if err := t.parseMore(fsys); err != nil {
  173. return err
  174. }
  175. }
  176. t, err := parseTokenizer(fsys, conv.specialTokenTypes())
  177. if err != nil {
  178. return err
  179. }
  180. if vocabSize := int(p.VocabSize); vocabSize > len(t.Vocabulary.Tokens) {
  181. slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", p.VocabSize, "actual", len(t.Vocabulary.Tokens))
  182. for i := range vocabSize - len(t.Vocabulary.Tokens) {
  183. t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
  184. t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
  185. t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
  186. }
  187. } else {
  188. slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
  189. }
  190. ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
  191. if err != nil {
  192. return err
  193. }
  194. fn(api.ProgressResponse{
  195. Status: fmt.Sprintf("converting model 0%%"),
  196. })
  197. return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts), fn)
  198. }