convert_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478
  1. package convert
  2. import (
  3. "bytes"
  4. "crypto/sha256"
  5. "encoding/binary"
  6. "encoding/hex"
  7. "encoding/json"
  8. "flag"
  9. "fmt"
  10. "io"
  11. "io/fs"
  12. "log/slog"
  13. "math"
  14. "os"
  15. "path/filepath"
  16. "slices"
  17. "strings"
  18. "testing"
  19. "golang.org/x/exp/maps"
  20. "github.com/ollama/ollama/fs/ggml"
  21. )
  22. type tensorData struct {
  23. Offsets []int `json:"data_offsets"`
  24. Type string `json:"dtype"`
  25. Shape []int `json:"shape"`
  26. }
  27. func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
  28. t.Helper()
  29. f, err := os.CreateTemp(t.TempDir(), "f16")
  30. if err != nil {
  31. t.Fatal(err)
  32. }
  33. defer f.Close()
  34. if err := ConvertModel(fsys, f); err != nil {
  35. t.Fatal(err)
  36. }
  37. r, err := os.Open(f.Name())
  38. if err != nil {
  39. t.Fatal(err)
  40. }
  41. t.Cleanup(func() { r.Close() })
  42. m, _, err := ggml.Decode(r, math.MaxInt)
  43. if err != nil {
  44. t.Fatal(err)
  45. }
  46. if _, err := r.Seek(0, io.SeekStart); err != nil {
  47. t.Fatal(err)
  48. }
  49. return r, m.KV(), m.Tensors()
  50. }
  51. func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string {
  52. actual := make(map[string]string)
  53. for k, v := range kv {
  54. if s, ok := v.(json.Marshaler); !ok {
  55. actual[k] = fmt.Sprintf("%v", v)
  56. } else {
  57. bts, err := json.Marshal(s)
  58. if err != nil {
  59. t.Fatal(err)
  60. }
  61. actual[k] = fmt.Sprintf("%x", sha256.Sum256(bts))
  62. }
  63. }
  64. for _, tensor := range tensors.Items() {
  65. sha256sum := sha256.New()
  66. sr := io.NewSectionReader(f, int64(tensors.Offset+tensor.Offset), int64(tensor.Size()))
  67. if _, err := io.Copy(sha256sum, sr); err != nil {
  68. t.Fatal(err)
  69. }
  70. actual[tensor.Name] = hex.EncodeToString(sha256sum.Sum(nil))
  71. }
  72. return actual
  73. }
  74. func TestMain(m *testing.M) {
  75. var level slog.Level
  76. flag.TextVar(&level, "level", slog.LevelInfo, "log level")
  77. flag.Parse()
  78. slog.SetLogLoggerLevel(level)
  79. os.Exit(m.Run())
  80. }
  81. func TestConvertModel(t *testing.T) {
  82. cases := []string{
  83. "Meta-Llama-3-8B-Instruct",
  84. "Meta-Llama-3.1-8B-Instruct",
  85. "Mistral-7B-Instruct-v0.2",
  86. "Mixtral-8x7B-Instruct-v0.1",
  87. "gemma-2b-it",
  88. "gemma-2-2b-it",
  89. // microsoft/Phi-3-mini-128-instruct@d548c233192db00165d842bf8edff054bb3212f8
  90. "Phi-3-mini-128k-instruct",
  91. "all-MiniLM-L6-v2",
  92. "gemma-2-9b-it",
  93. "Qwen2.5-0.5B-Instruct",
  94. "c4ai-command-r-v01",
  95. }
  96. for i := range cases {
  97. tt := cases[i]
  98. t.Run(tt, func(t *testing.T) {
  99. t.Parallel()
  100. p := filepath.Join("testdata", tt)
  101. if testing.Short() {
  102. t.Skip("skipping in short mode")
  103. } else if _, err := os.Stat(p); err != nil {
  104. t.Skipf("%s not found", p)
  105. }
  106. f, kv, tensors := convertFull(t, os.DirFS(p))
  107. actual := generateResultsJSON(t, f, kv, tensors)
  108. expectFile, err := os.Open(filepath.Join("testdata", fmt.Sprintf("%s.json", tt)))
  109. if err != nil {
  110. t.Fatal(err)
  111. }
  112. var expect map[string]string
  113. if err := json.NewDecoder(expectFile).Decode(&expect); err != nil {
  114. t.Fatal(err)
  115. }
  116. keys := maps.Keys(expect)
  117. slices.Sort(keys)
  118. for _, k := range keys {
  119. if v, ok := actual[k]; !ok {
  120. t.Errorf("missing %s", k)
  121. } else if v != expect[k] {
  122. t.Errorf("unexpected %s: want %s, got %s", k, expect[k], v)
  123. }
  124. }
  125. })
  126. }
  127. }
  128. func TestConvertInvalidTensorNames(t *testing.T) {
  129. f, err := os.CreateTemp(t.TempDir(), "testmodel")
  130. if err != nil {
  131. t.Fatal(err)
  132. }
  133. defer f.Close()
  134. tempDir := t.TempDir()
  135. td := map[string]*tensorData{}
  136. offset := 4096
  137. td["model.layers.0.self_attn.q_proj.weight"] = &tensorData{
  138. Offsets: []int{0, offset},
  139. Type: "F32",
  140. Shape: []int{4096, 4096},
  141. }
  142. td["blk.0.attn_q.weight"] = &tensorData{
  143. Offsets: []int{offset, offset * 2},
  144. Type: "F32",
  145. Shape: []int{4096, 4096},
  146. }
  147. generateSafetensorTestData(t, tempDir, td)
  148. err = ConvertModel(os.DirFS(tempDir), f)
  149. if err == nil || !strings.HasPrefix(err.Error(), "duplicate tensor name") {
  150. t.Errorf("expected error but didn't get one")
  151. }
  152. }
  153. func TestConvertInvalidDatatype(t *testing.T) {
  154. f, err := os.CreateTemp(t.TempDir(), "testmodel")
  155. if err != nil {
  156. t.Fatal(err)
  157. }
  158. defer f.Close()
  159. tempDir := t.TempDir()
  160. td := map[string]*tensorData{}
  161. offset := 4096 * 14336
  162. td["model.layers.0.mlp.down_proj.weight"] = &tensorData{
  163. Offsets: []int{0, offset},
  164. Type: "I8",
  165. Shape: []int{4096, 14336},
  166. }
  167. td["model.layers.0.mlp.down_proj.weight_format"] = &tensorData{
  168. Offsets: []int{offset, offset},
  169. Type: "U8",
  170. Shape: []int{},
  171. }
  172. generateSafetensorTestData(t, tempDir, td)
  173. err = ConvertModel(os.DirFS(tempDir), f)
  174. if err == nil || err.Error() != "unsupported safetensors model" {
  175. t.Errorf("expected error but didn't get one")
  176. }
  177. }
  178. func generateSafetensorTestData(t *testing.T, tempDir string, tensorData map[string]*tensorData) {
  179. data, err := json.Marshal(tensorData)
  180. if err != nil {
  181. t.Fatal(err)
  182. }
  183. var buf bytes.Buffer
  184. l := int64(len(data))
  185. err = binary.Write(&buf, binary.LittleEndian, l)
  186. if err != nil {
  187. t.Fatal(err)
  188. }
  189. _, err = buf.Write(data)
  190. if err != nil {
  191. t.Fatal(err)
  192. }
  193. fdata, err := os.Create(filepath.Join(tempDir, "model-00001-of-00001.safetensors"))
  194. if err != nil {
  195. t.Fatal(err)
  196. }
  197. defer fdata.Close()
  198. _, err = fdata.Write(buf.Bytes())
  199. if err != nil {
  200. t.Fatal(err)
  201. }
  202. configData := `
  203. {
  204. "architectures": [
  205. "LlamaForCausalLM"
  206. ]
  207. }
  208. `
  209. f, err := os.Create(filepath.Join(tempDir, "config.json"))
  210. if err != nil {
  211. t.Fatal(err)
  212. }
  213. defer f.Close()
  214. _, err = f.WriteString(configData)
  215. if err != nil {
  216. t.Fatal(err)
  217. }
  218. tokenizerData := `
  219. {
  220. }
  221. `
  222. f, err = os.Create(filepath.Join(tempDir, "tokenizer.json"))
  223. if err != nil {
  224. t.Fatal(err)
  225. }
  226. defer f.Close()
  227. _, err = f.WriteString(tokenizerData)
  228. if err != nil {
  229. t.Fatal(err)
  230. }
  231. }
  232. func TestConvertAdapter(t *testing.T) {
  233. type AdapterCase struct {
  234. Name string
  235. BaseKV map[string]any
  236. Expected map[string]string
  237. }
  238. cases := []AdapterCase{
  239. {
  240. Name: "discollama",
  241. BaseKV: map[string]any{
  242. "general.architecture": "llama",
  243. "llama.attention.head_count": uint32(32),
  244. "llama.attention.head_count_kv": uint32(8),
  245. },
  246. Expected: map[string]string{
  247. "general.architecture": "llama",
  248. "general.file_type": "1",
  249. "general.parameter_count": "106496",
  250. "general.type": "adapter",
  251. "general.version": "v0.2",
  252. "adapter.lora.alpha": "16",
  253. "adapter.type": "lora",
  254. "llama.attention.head_count": "32",
  255. "llama.attention.head_count_kv": "8",
  256. "blk.31.attn_q.weight.lora_a": "0eb3318b02cd313429bcc7621b539fdbb10240fea190c56c9e5f93fcd37a4e50",
  257. "blk.31.attn_q.weight.lora_b": "0eb3318b02cd313429bcc7621b539fdbb10240fea190c56c9e5f93fcd37a4e50",
  258. "blk.31.attn_v.weight.lora_a": "0eb3318b02cd313429bcc7621b539fdbb10240fea190c56c9e5f93fcd37a4e50",
  259. "blk.31.attn_v.weight.lora_b": "071dcafe89df065d6e1c935ecb8fdf6479b3c202eb912e7da938597673ff5857",
  260. },
  261. },
  262. }
  263. for _, c := range cases {
  264. t.Run(c.Name, func(t *testing.T) {
  265. t.Parallel()
  266. f, err := os.CreateTemp(t.TempDir(), "f16")
  267. if err != nil {
  268. t.Fatal(err)
  269. }
  270. defer f.Close()
  271. tempDir := t.TempDir()
  272. generateLoraTestData(t, tempDir)
  273. if err = ConvertAdapter(os.DirFS(tempDir), f, c.BaseKV); err != nil {
  274. t.Fatal(err)
  275. }
  276. r, err := os.Open(f.Name())
  277. if err != nil {
  278. t.Fatal(err)
  279. }
  280. defer r.Close()
  281. m, _, err := ggml.Decode(r, math.MaxInt)
  282. if err != nil {
  283. t.Fatal(err)
  284. }
  285. if _, err := r.Seek(0, io.SeekStart); err != nil {
  286. t.Fatal(err)
  287. }
  288. actual := generateResultsJSON(t, r, m.KV(), m.Tensors())
  289. keys := maps.Keys(c.Expected)
  290. slices.Sort(keys)
  291. for _, k := range keys {
  292. if v, ok := actual[k]; !ok {
  293. t.Errorf("missing %s", k)
  294. } else if v != c.Expected[k] {
  295. t.Errorf("unexpected %s: want %s, got %s", k, c.Expected[k], v)
  296. }
  297. }
  298. })
  299. }
  300. }
  301. func generateLoraTestData(t *testing.T, tempDir string) {
  302. offset := 4096 * 8 * 4
  303. td := map[string]*tensorData{"__metadata__": nil}
  304. td["model.layers.31.self_attn.q_proj.lora_a"] = &tensorData{
  305. Offsets: []int{0, offset},
  306. Type: "F32",
  307. Shape: []int{4096, 8},
  308. }
  309. td["model.layers.31.self_attn.q_proj.lora_b"] = &tensorData{
  310. Offsets: []int{offset, offset * 2},
  311. Type: "F32",
  312. Shape: []int{8, 4096},
  313. }
  314. td["model.layers.31.self_attn.v_proj.lora_a"] = &tensorData{
  315. Offsets: []int{offset * 2, offset * 3},
  316. Type: "F32",
  317. Shape: []int{4096, 8},
  318. }
  319. td["model.layers.31.self_attn.v_proj.lora_b"] = &tensorData{
  320. Offsets: []int{offset * 3, offset*3 + 8*1024*4},
  321. Type: "F32",
  322. Shape: []int{8, 1024},
  323. }
  324. data, err := json.Marshal(td)
  325. if err != nil {
  326. t.Fatal(err)
  327. }
  328. var buf bytes.Buffer
  329. l := int64(len(data))
  330. err = binary.Write(&buf, binary.LittleEndian, l)
  331. if err != nil {
  332. t.Fatal(err)
  333. }
  334. _, err = buf.Write(data)
  335. if err != nil {
  336. t.Fatal(err)
  337. }
  338. // write some data for the tensors
  339. ones := make([]float32, 4096*8)
  340. for i := range ones {
  341. ones[i] = float32(1)
  342. }
  343. for range 3 {
  344. err = binary.Write(&buf, binary.LittleEndian, ones)
  345. if err != nil {
  346. t.Fatal(err)
  347. }
  348. }
  349. ones = make([]float32, 1024*8)
  350. for i := range ones {
  351. ones[i] = float32(1)
  352. }
  353. err = binary.Write(&buf, binary.LittleEndian, ones)
  354. if err != nil {
  355. t.Fatal(err)
  356. }
  357. fdata, err := os.Create(filepath.Join(tempDir, "adapters.safetensors"))
  358. if err != nil {
  359. t.Fatal(err)
  360. }
  361. defer fdata.Close()
  362. _, err = fdata.Write(buf.Bytes())
  363. if err != nil {
  364. t.Fatal(err)
  365. }
  366. configData := `
  367. {
  368. "adapter_path": "adapters-test",
  369. "batch_size": 8,
  370. "config": "config-tiny.json",
  371. "data": "../discollama-completion",
  372. "grad_checkpoint": null,
  373. "iters": 1000,
  374. "learning_rate": 1e-05,
  375. "lora_layers": 1,
  376. "lora_parameters": {
  377. "rank": 8,
  378. "alpha": 16,
  379. "dropout": 0.0,
  380. "scale": 2.0
  381. },
  382. "lr_schedule": null,
  383. "max_seq_length": 2048,
  384. "model": "/Users/pdevine/git/Meta-Llama-3-8B-Instruct",
  385. "resume_adapter_file": null,
  386. "save_every": 100,
  387. "seed": 0,
  388. "steps_per_eval": 200,
  389. "steps_per_report": 10,
  390. "test": false,
  391. "test_batches": 500,
  392. "train": true,
  393. "use_dora": false,
  394. "val_batches": 25
  395. }
  396. `
  397. f, err := os.Create(filepath.Join(tempDir, "adapter_config.json"))
  398. if err != nil {
  399. t.Fatal(err)
  400. }
  401. defer f.Close()
  402. _, err = f.WriteString(configData)
  403. if err != nil {
  404. t.Fatal(err)
  405. }
  406. }