safetensors.go 7.3 KB


  1. package convert
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "log/slog"
  9. "os"
  10. "path/filepath"
  11. "regexp"
  12. "slices"
  13. "github.com/d4l3k/go-bfloat16"
  14. "github.com/mitchellh/mapstructure"
  15. "github.com/x448/float16"
  16. "github.com/ollama/ollama/llm"
  17. )
  // safetensorWriterTo streams one tensor's data out of a safetensors file.
// readTensors stores a value of this type as the tensor's WriterTo.
type safetensorWriterTo struct {
	t        *llm.Tensor // tensor being written; WriteTo reads its Kind to choose the output encoding
	params   *Params     // model parameters; not read by WriteTo itself, but reachable by handler through r
	bo       ByteOrder   // byte order used when writing the converted values
	filename string      // safetensors file that holds the tensor data
	// start and end are the tensor's data_offsets from the JSON header.
	// padding (set by readTensors to 8 + the JSON header size) is added to
	// start to get the absolute seek position; end-start is the number of
	// source bytes to read.
	start, end, padding uint64
	// handler, when non-nil, is invoked by WriteTo after seeking to the
	// tensor's first byte and replaces the built-in conversion entirely.
	handler func(w io.Writer, r safetensorWriterTo, f *os.File) error
}
  // tensorMetaData holds one tensor entry of a safetensors JSON header;
// readTensors fills it with mapstructure using the tags below.
type tensorMetaData struct {
	Type    string `mapstructure:"dtype"`        // element type string from the header
	Shape   []int  `mapstructure:"shape"`        // dimension sizes; an empty shape makes readTensors skip the entry
	Offsets []int  `mapstructure:"data_offsets"` // start and end byte offsets of the data, relative to the end of the header
}
  // SafetensorFormat reads model directories that contain a config.json and
// model-*.safetensors weight shards, for conversion to GGUF tensors.
type SafetensorFormat struct{}
  32. func (m *SafetensorFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor, error) {
  33. slog.Debug("getting tensor data")
  34. var tensors []llm.Tensor
  35. files, err := filepath.Glob(filepath.Join(dirpath, "/model-*.safetensors"))
  36. if err != nil {
  37. return nil, err
  38. }
  39. var offset uint64
  40. for _, f := range files {
  41. var t []llm.Tensor
  42. var err error
  43. t, offset, err = m.readTensors(f, offset, params)
  44. if err != nil {
  45. slog.Error("%v", err)
  46. return nil, err
  47. }
  48. tensors = append(tensors, t...)
  49. }
  50. slog.Debug(fmt.Sprintf("all tensors = %d", len(tensors)))
  51. return tensors, nil
  52. }
  53. func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params) ([]llm.Tensor, uint64, error) {
  54. f, err := os.Open(fn)
  55. if err != nil {
  56. return nil, 0, err
  57. }
  58. defer f.Close()
  59. var jsonSize uint64
  60. if err := binary.Read(f, binary.LittleEndian, &jsonSize); err != nil {
  61. return nil, 0, err
  62. }
  63. buf := make([]byte, jsonSize)
  64. _, err = io.ReadFull(f, buf)
  65. if err != nil {
  66. return nil, 0, err
  67. }
  68. d := json.NewDecoder(bytes.NewBuffer(buf))
  69. d.UseNumber()
  70. var parsed map[string]interface{}
  71. if err = d.Decode(&parsed); err != nil {
  72. return nil, 0, err
  73. }
  74. var keys []string
  75. for k := range parsed {
  76. keys = append(keys, k)
  77. }
  78. slices.Sort(keys)
  79. slog.Info("converting layers")
  80. var tensors []llm.Tensor
  81. for _, k := range keys {
  82. vals := parsed[k].(map[string]interface{})
  83. var data tensorMetaData
  84. if err = mapstructure.Decode(vals, &data); err != nil {
  85. slog.Error("couldn't decode properly")
  86. return nil, 0, err
  87. }
  88. var size uint64
  89. var kind uint32
  90. switch len(data.Shape) {
  91. case 0:
  92. // metadata
  93. continue
  94. case 1:
  95. // convert to float32
  96. kind = 0
  97. size = uint64(data.Shape[0] * 4)
  98. case 2:
  99. // convert to float16
  100. kind = 1
  101. size = uint64(data.Shape[0] * data.Shape[1] * 2)
  102. }
  103. ggufName, err := m.GetLayerName(k)
  104. if err != nil {
  105. slog.Error("%v", err)
  106. return nil, 0, err
  107. }
  108. shape := []uint64{0, 0, 0, 0}
  109. for i := range data.Shape {
  110. shape[i] = uint64(data.Shape[i])
  111. }
  112. t := llm.Tensor{
  113. Name: ggufName,
  114. Kind: kind,
  115. Offset: offset,
  116. Shape: shape[:],
  117. }
  118. t.WriterTo = safetensorWriterTo{
  119. t: &t,
  120. params: params,
  121. bo: params.ByteOrder,
  122. filename: fn,
  123. start: uint64(data.Offsets[0]),
  124. end: uint64(data.Offsets[1]),
  125. padding: 8 + jsonSize,
  126. }
  127. offset += size
  128. tensors = append(tensors, t)
  129. }
  130. slog.Debug(fmt.Sprintf("total tensors for file = %d", len(tensors)))
  131. slog.Debug(fmt.Sprintf("offset = %d", offset))
  132. return tensors, offset, nil
  133. }
  134. func (m *SafetensorFormat) GetParams(dirpath string) (*Params, error) {
  135. f, err := os.Open(filepath.Join(dirpath, "config.json"))
  136. if err != nil {
  137. return nil, err
  138. }
  139. defer f.Close()
  140. var params Params
  141. d := json.NewDecoder(f)
  142. err = d.Decode(&params)
  143. if err != nil {
  144. return nil, err
  145. }
  146. params.ByteOrder = binary.LittleEndian
  147. return &params, nil
  148. }
  149. func (m *SafetensorFormat) GetLayerName(n string) (string, error) {
  150. directMap := map[string]string{
  151. "model.embed_tokens.weight": "token_embd.weight",
  152. "lm_head.weight": "output.weight",
  153. "model.norm.weight": "output_norm.weight",
  154. }
  155. tMap := map[string]string{
  156. "model.layers.(\\d+).input_layernorm.weight": "blk.$1.attn_norm.weight",
  157. "model.layers.(\\d+).mlp.down_proj.weight": "blk.$1.ffn_down.weight",
  158. "model.layers.(\\d+).mlp.gate_proj.weight": "blk.$1.ffn_gate.weight",
  159. "model.layers.(\\d+).mlp.up_proj.weight": "blk.$1.ffn_up.weight",
  160. "model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight",
  161. "model.layers.(\\d+).self_attn.k_proj.weight": "blk.$1.attn_k.weight",
  162. "model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight",
  163. "model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight",
  164. "model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight",
  165. "model.layers.(\\d+).block_sparse_moe.gate.weight": "blk.$1.ffn_gate_inp.weight",
  166. "model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w1.weight": "blk.$1.ffn_gate.$2.weight",
  167. "model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w2.weight": "blk.$1.ffn_down.$2.weight",
  168. "model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w3.weight": "blk.$1.ffn_up.$2.weight",
  169. }
  170. v, ok := directMap[n]
  171. if ok {
  172. return v, nil
  173. }
  174. // quick hack to rename the layers to gguf format
  175. for k, v := range tMap {
  176. re := regexp.MustCompile(k)
  177. newName := re.ReplaceAllString(n, v)
  178. if newName != n {
  179. return newName, nil
  180. }
  181. }
  182. return "", fmt.Errorf("couldn't find a layer name for '%s'", n)
  183. }
  184. func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
  185. f, err := os.Open(r.filename)
  186. if err != nil {
  187. return 0, err
  188. }
  189. defer f.Close()
  190. if _, err = f.Seek(int64(r.padding+r.start), 0); err != nil {
  191. return 0, err
  192. }
  193. // use the handler if one is present
  194. if r.handler != nil {
  195. return 0, r.handler(w, r, f)
  196. }
  197. remaining := r.end - r.start
  198. bufSize := uint64(10240)
  199. var finished bool
  200. for {
  201. data := make([]byte, min(bufSize, remaining))
  202. b, err := io.ReadFull(f, data)
  203. remaining -= uint64(b)
  204. if err == io.EOF || remaining <= 0 {
  205. finished = true
  206. } else if err != nil {
  207. return 0, err
  208. }
  209. // convert bfloat16 -> ieee float32
  210. tDataF32 := bfloat16.DecodeFloat32(data)
  211. switch r.t.Kind {
  212. case 0:
  213. if err := binary.Write(w, r.bo, tDataF32); err != nil {
  214. return 0, err
  215. }
  216. case 1:
  217. // convert float32 -> float16
  218. tempBuf := make([]uint16, len(data)/2)
  219. for cnt, v := range tDataF32 {
  220. tDataF16 := float16.Fromfloat32(v)
  221. tempBuf[cnt] = uint16(tDataF16)
  222. }
  223. if err := binary.Write(w, r.bo, tempBuf); err != nil {
  224. return 0, err
  225. }
  226. }
  227. if finished {
  228. break
  229. }
  230. }
  231. return 0, nil
  232. }
  233. func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (ModelArch, error) {
  234. switch len(params.Architectures) {
  235. case 0:
  236. return nil, fmt.Errorf("No architecture specified to convert")
  237. case 1:
  238. switch params.Architectures[0] {
  239. case "MistralForCausalLM":
  240. return &MistralModel{
  241. ModelData{
  242. Name: name,
  243. Path: dirPath,
  244. Params: params,
  245. Format: m,
  246. },
  247. }, nil
  248. case "MixtralForCausalLM":
  249. return &MixtralModel{
  250. ModelData{
  251. Name: name,
  252. Path: dirPath,
  253. Params: params,
  254. Format: m,
  255. },
  256. }, nil
  257. case "GemmaForCausalLM":
  258. return &GemmaModel{
  259. ModelData{
  260. Name: name,
  261. Path: dirPath,
  262. Params: params,
  263. Format: m,
  264. },
  265. }, nil
  266. default:
  267. return nil, fmt.Errorf("Models based on '%s' are not yet supported", params.Architectures[0])
  268. }
  269. }
  270. return nil, fmt.Errorf("Unknown error")
  271. }