// convert.go (9.1 KB)


  1. package convert
  2. import (
  3. "bytes"
  4. "cmp"
  5. "encoding/binary"
  6. "encoding/json"
  7. "fmt"
  8. "io"
  9. "log/slog"
  10. "os"
  11. "path/filepath"
  12. "regexp"
  13. "slices"
  14. "github.com/mitchellh/mapstructure"
  15. "google.golang.org/protobuf/proto"
  16. "github.com/ollama/ollama/convert/sentencepiece"
  17. "github.com/ollama/ollama/llm"
  18. )
  // Params holds the model hyperparameters that GetParams decodes from a
// model directory's config.json. Trailing comments give the equivalent
// short hyperparameter names where one exists.
type Params struct {
	Architectures []string `json:"architectures"`
	VocabSize     int      `json:"vocab_size"`

	// Model dimensions.
	HiddenSize       int `json:"hidden_size"`       // n_embd
	HiddenLayers     int `json:"num_hidden_layers"` // n_layer
	ContextSize      int `json:"max_position_embeddings"`
	IntermediateSize int `json:"intermediate_size"`

	// Attention layout.
	AttentionHeads int `json:"num_attention_heads"` // n_head
	KeyValHeads    int `json:"num_key_value_heads"`

	// Normalization and rotary embedding constants.
	NormEPS      float64 `json:"rms_norm_eps"`
	RopeFreqBase float64 `json:"rope_theta"`

	// Special token IDs.
	BoSTokenID int `json:"bos_token_id"`
	EoSTokenID int `json:"eos_token_id"`
}
  // MetaData is one entry of a safetensors JSON header as decoded (via
// mapstructure) by ReadSafeTensors.
type MetaData struct {
	// Type is the header's "dtype" string; ReadSafeTensors does not consult it.
	Type string `mapstructure:"dtype"`
	// Shape lists the tensor dimensions. It is empty for non-tensor entries,
	// which ReadSafeTensors skips as metadata.
	Shape []int `mapstructure:"shape"`
	// Offsets is the "data_offsets" pair; ReadSafeTensors copies both values
	// into llm.Tensor.FileOffsets.
	Offsets []int `mapstructure:"data_offsets"`
}
  38. func ReadSafeTensors(fn string, offset uint64) ([]llm.Tensor, uint64, error) {
  39. f, err := os.Open(fn)
  40. if err != nil {
  41. return []llm.Tensor{}, 0, err
  42. }
  43. defer f.Close()
  44. var jsonSize uint64
  45. binary.Read(f, binary.LittleEndian, &jsonSize)
  46. buf := make([]byte, jsonSize)
  47. _, err = io.ReadFull(f, buf)
  48. if err != nil {
  49. return []llm.Tensor{}, 0, err
  50. }
  51. d := json.NewDecoder(bytes.NewBuffer(buf))
  52. d.UseNumber()
  53. var parsed map[string]interface{}
  54. if err = d.Decode(&parsed); err != nil {
  55. return []llm.Tensor{}, 0, err
  56. }
  57. var keys []string
  58. for k := range parsed {
  59. keys = append(keys, k)
  60. }
  61. slices.Sort(keys)
  62. slog.Info("converting layers")
  63. var tensors []llm.Tensor
  64. for _, k := range keys {
  65. vals := parsed[k].(map[string]interface{})
  66. var data MetaData
  67. if err = mapstructure.Decode(vals, &data); err != nil {
  68. return []llm.Tensor{}, 0, err
  69. }
  70. var size uint64
  71. var kind uint32
  72. switch len(data.Shape) {
  73. case 0:
  74. // metadata
  75. continue
  76. case 1:
  77. // convert to float32
  78. kind = 0
  79. size = uint64(data.Shape[0] * 4)
  80. case 2:
  81. // convert to float16
  82. kind = 1
  83. size = uint64(data.Shape[0] * data.Shape[1] * 2)
  84. }
  85. ggufName, err := GetTensorName(k)
  86. if err != nil {
  87. slog.Error("%v", err)
  88. return []llm.Tensor{}, 0, err
  89. }
  90. shape := []uint64{0, 0, 0, 0}
  91. for i := range data.Shape {
  92. shape[i] = uint64(data.Shape[i])
  93. }
  94. t := llm.Tensor{
  95. Name: ggufName,
  96. Kind: kind,
  97. Offset: offset,
  98. Shape: shape[:],
  99. FileName: fn,
  100. OffsetPadding: 8 + jsonSize,
  101. FileOffsets: []uint64{uint64(data.Offsets[0]), uint64(data.Offsets[1])},
  102. }
  103. slog.Debug(fmt.Sprintf("%v", t))
  104. tensors = append(tensors, t)
  105. offset += size
  106. }
  107. return tensors, offset, nil
  108. }
  109. func GetSafeTensors(dirpath string) ([]llm.Tensor, error) {
  110. var tensors []llm.Tensor
  111. files, err := filepath.Glob(filepath.Join(dirpath, "/model-*.safetensors"))
  112. if err != nil {
  113. return []llm.Tensor{}, err
  114. }
  115. var offset uint64
  116. for _, f := range files {
  117. var t []llm.Tensor
  118. var err error
  119. t, offset, err = ReadSafeTensors(f, offset)
  120. if err != nil {
  121. slog.Error("%v", err)
  122. return []llm.Tensor{}, err
  123. }
  124. tensors = append(tensors, t...)
  125. }
  126. return tensors, nil
  127. }
  128. func GetParams(dirpath string) (*Params, error) {
  129. f, err := os.Open(filepath.Join(dirpath, "config.json"))
  130. if err != nil {
  131. return nil, err
  132. }
  133. defer f.Close()
  134. var params Params
  135. d := json.NewDecoder(f)
  136. err = d.Decode(&params)
  137. if err != nil {
  138. return nil, err
  139. }
  140. return &params, nil
  141. }
  // Vocab is a tokenizer vocabulary stored as three parallel slices, which
// WriteGGUF writes to tokenizer.ggml.tokens, tokenizer.ggml.scores and
// tokenizer.ggml.token_type. Index i of each slice describes token ID i.
//
// Details on gguf's tokenizer can be found at:
// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#tokenizer
type Vocab struct {
	Tokens []string  // token text
	Scores []float32 // per-token score
	Types  []int32   // per-token type code
}
  149. func LoadTokens(dirpath string) (*Vocab, error) {
  150. slog.Info(fmt.Sprintf("reading vocab from %s", filepath.Join(dirpath, "tokenizer.model")))
  151. in, err := os.ReadFile(filepath.Join(dirpath, "tokenizer.model"))
  152. if err != nil {
  153. return nil, err
  154. }
  155. // To regenerate sentencepiece from the protobufs use:
  156. // protoc -I=./ --go_out=./ sentencepiece_model.proto
  157. modelProto := &sentencepiece.ModelProto{}
  158. if err := proto.Unmarshal(in, modelProto); err != nil {
  159. return nil, err
  160. }
  161. v := &Vocab{
  162. Tokens: make([]string, 0),
  163. Scores: make([]float32, 0),
  164. Types: make([]int32, 0),
  165. }
  166. pieces := modelProto.GetPieces()
  167. for _, p := range pieces {
  168. v.Tokens = append(v.Tokens, p.GetPiece())
  169. v.Scores = append(v.Scores, p.GetScore())
  170. t := p.GetType()
  171. v.Types = append(v.Types, int32(t))
  172. }
  173. slog.Info(fmt.Sprintf("vocab size: %d", len(v.Tokens)))
  174. // add any additional tokens
  175. addIn, err := os.ReadFile(filepath.Join(dirpath, "added_tokens.json"))
  176. if os.IsNotExist(err) {
  177. return v, nil
  178. } else if err != nil {
  179. return nil, err
  180. }
  181. slog.Info("reading user defined tokens")
  182. var extraTokenData map[string]int
  183. if err := json.Unmarshal(addIn, &extraTokenData); err != nil {
  184. return nil, err
  185. }
  186. type token struct {
  187. key string
  188. pos int
  189. }
  190. extraTokens := make([]token, 0)
  191. for k, id := range extraTokenData {
  192. extraTokens = append(extraTokens, token{k, id})
  193. }
  194. slices.SortFunc(extraTokens, func(a, b token) int {
  195. return cmp.Compare(a.pos, b.pos)
  196. })
  197. numToks := len(v.Tokens)
  198. for cnt, t := range extraTokens {
  199. // the token id should match the specific index for the total number of tokens
  200. if t.pos != cnt+numToks {
  201. return nil, fmt.Errorf("token ID '%d' for '%s' doesn't match total token size", t.pos, t.key)
  202. }
  203. v.Tokens = append(v.Tokens, t.key)
  204. v.Scores = append(v.Scores, -1000.0)
  205. v.Types = append(v.Types, int32(llm.GGUFTokenUserDefined))
  206. }
  207. slog.Info(fmt.Sprintf("vocab size w/ extra tokens: %d", len(v.Tokens)))
  208. return v, nil
  209. }
  // tensorNameRules maps the checkpoint tensor names found in safetensors
// headers to GGUF tensor names; $1 carries the layer index. Each pattern is
// anchored and matches '.' literally, so a name is translated as a whole or
// not at all, and at most one rule can apply. Patterns are compiled once at
// package initialization instead of on every lookup.
var tensorNameRules = []struct {
	pattern     *regexp.Regexp
	replacement string
}{
	{regexp.MustCompile(`^model\.embed_tokens\.weight$`), "token_embd.weight"},
	{regexp.MustCompile(`^model\.layers\.(\d+)\.input_layernorm\.weight$`), "blk.$1.attn_norm.weight"},
	{regexp.MustCompile(`^model\.layers\.(\d+)\.mlp\.down_proj\.weight$`), "blk.$1.ffn_down.weight"},
	{regexp.MustCompile(`^model\.layers\.(\d+)\.mlp\.gate_proj\.weight$`), "blk.$1.ffn_gate.weight"},
	{regexp.MustCompile(`^model\.layers\.(\d+)\.mlp\.up_proj\.weight$`), "blk.$1.ffn_up.weight"},
	{regexp.MustCompile(`^model\.layers\.(\d+)\.post_attention_layernorm\.weight$`), "blk.$1.ffn_norm.weight"},
	{regexp.MustCompile(`^model\.layers\.(\d+)\.self_attn\.k_proj\.weight$`), "blk.$1.attn_k.weight"},
	{regexp.MustCompile(`^model\.layers\.(\d+)\.self_attn\.o_proj\.weight$`), "blk.$1.attn_output.weight"},
	{regexp.MustCompile(`^model\.layers\.(\d+)\.self_attn\.q_proj\.weight$`), "blk.$1.attn_q.weight"},
	{regexp.MustCompile(`^model\.layers\.(\d+)\.self_attn\.v_proj\.weight$`), "blk.$1.attn_v.weight"},
	{regexp.MustCompile(`^lm_head\.weight$`), "output.weight"},
	{regexp.MustCompile(`^model\.norm\.weight$`), "output_norm.weight"},
}

// GetTensorName returns the GGUF name for the checkpoint tensor named n, or
// an error if n matches none of the known tensor names.
func GetTensorName(n string) (string, error) {
	for _, rule := range tensorNameRules {
		if rule.pattern.MatchString(n) {
			return rule.pattern.ReplaceAllString(n, rule.replacement), nil
		}
	}
	return "", fmt.Errorf("couldn't find a layer name for '%s'", n)
}
  239. func WriteGGUF(name string, tensors []llm.Tensor, params *Params, vocab *Vocab) (string, error) {
  240. c := llm.ContainerGGUF{
  241. ByteOrder: binary.LittleEndian,
  242. }
  243. m := llm.NewGGUFModel(&c)
  244. m.Tensors = tensors
  245. m.KV["general.architecture"] = "llama"
  246. m.KV["general.name"] = name
  247. m.KV["llama.context_length"] = uint32(params.ContextSize)
  248. m.KV["llama.embedding_length"] = uint32(params.HiddenSize)
  249. m.KV["llama.block_count"] = uint32(params.HiddenLayers)
  250. m.KV["llama.feed_forward_length"] = uint32(params.IntermediateSize)
  251. m.KV["llama.rope.dimension_count"] = uint32(128)
  252. m.KV["llama.attention.head_count"] = uint32(params.AttentionHeads)
  253. m.KV["llama.attention.head_count_kv"] = uint32(params.KeyValHeads)
  254. m.KV["llama.attention.layer_norm_rms_epsilon"] = float32(params.NormEPS)
  255. m.KV["llama.rope.freq_base"] = float32(params.RopeFreqBase)
  256. m.KV["general.file_type"] = uint32(1)
  257. m.KV["tokenizer.ggml.model"] = "llama"
  258. m.KV["tokenizer.ggml.tokens"] = vocab.Tokens
  259. m.KV["tokenizer.ggml.scores"] = vocab.Scores
  260. m.KV["tokenizer.ggml.token_type"] = vocab.Types
  261. m.KV["tokenizer.ggml.bos_token_id"] = uint32(params.BoSTokenID)
  262. m.KV["tokenizer.ggml.eos_token_id"] = uint32(params.EoSTokenID)
  263. m.KV["tokenizer.ggml.unknown_token_id"] = uint32(0)
  264. m.KV["tokenizer.ggml.add_bos_token"] = true
  265. m.KV["tokenizer.ggml.add_eos_token"] = false
  266. // llamacpp sets the chat template, however we don't need to set it since we pass it in through a layer
  267. // m.KV["tokenizer.chat_template"] = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" // XXX removeme
  268. c.V3.NumTensor = uint64(len(tensors))
  269. c.V3.NumKV = uint64(len(m.KV))
  270. f, err := os.CreateTemp("", "ollama-gguf")
  271. if err != nil {
  272. return "", err
  273. }
  274. defer f.Close()
  275. err = m.Encode(f)
  276. if err != nil {
  277. return "", err
  278. }
  279. return f.Name(), nil
  280. }