convert_test.go (11 KB)
  1. package convert
  2. import (
  3. "bytes"
  4. "crypto/sha256"
  5. "encoding/binary"
  6. "encoding/hex"
  7. "encoding/json"
  8. "flag"
  9. "fmt"
  10. "io"
  11. "io/fs"
  12. "log/slog"
  13. "math"
  14. "os"
  15. "path/filepath"
  16. "slices"
  17. "strings"
  18. "testing"
  19. "golang.org/x/exp/maps"
  20. "github.com/ollama/ollama/llm"
  21. )
// tensorData mirrors a single entry in a safetensors JSON header: the
// [start, end) byte range of the tensor payload, its dtype string
// (e.g. "F32"), and its shape. Field order matters: json.Marshal emits
// struct fields in declaration order, and the tests hash the resulting
// header bytes indirectly via the written files.
type tensorData struct {
	Offsets []int  `json:"data_offsets"`
	Type    string `json:"dtype"`
	Shape   []int  `json:"shape"`
}

// generate holds the value of the -generate flag registered in TestMain.
// When it matches a case name in TestConvertModel, that model's
// expected-results JSON file is rewritten instead of compared.
var generate string
  28. func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, *llm.Tensors) {
  29. t.Helper()
  30. f, err := os.CreateTemp(t.TempDir(), "f16")
  31. if err != nil {
  32. t.Fatal(err)
  33. }
  34. defer f.Close()
  35. if err := ConvertModel(fsys, f); err != nil {
  36. t.Fatal(err)
  37. }
  38. r, err := os.Open(f.Name())
  39. if err != nil {
  40. t.Fatal(err)
  41. }
  42. t.Cleanup(func() { r.Close() })
  43. m, _, err := llm.DecodeGGML(r, math.MaxInt)
  44. if err != nil {
  45. t.Fatal(err)
  46. }
  47. if _, err := r.Seek(0, io.SeekStart); err != nil {
  48. t.Fatal(err)
  49. }
  50. return r, m.KV(), m.Tensors()
  51. }
  52. func generateResultsJSON(t *testing.T, f *os.File, kv llm.KV, tensors *llm.Tensors) map[string]string {
  53. actual := make(map[string]string)
  54. for k, v := range kv {
  55. if s, ok := v.(json.Marshaler); !ok {
  56. actual[k] = fmt.Sprintf("%v", v)
  57. } else {
  58. bts, err := json.Marshal(s)
  59. if err != nil {
  60. t.Fatal(err)
  61. }
  62. actual[k] = fmt.Sprintf("%x", sha256.Sum256(bts))
  63. }
  64. }
  65. for _, tensor := range tensors.Items {
  66. sha256sum := sha256.New()
  67. sr := io.NewSectionReader(f, int64(tensors.Offset+tensor.Offset), int64(tensor.Size()))
  68. if _, err := io.Copy(sha256sum, sr); err != nil {
  69. t.Fatal(err)
  70. }
  71. actual[tensor.Name] = hex.EncodeToString(sha256sum.Sum(nil))
  72. }
  73. return actual
  74. }
  75. func TestMain(m *testing.M) {
  76. var level slog.Level
  77. flag.TextVar(&level, "level", slog.LevelInfo, "log level")
  78. flag.StringVar(&generate, "generate", "", "generate model data")
  79. flag.Parse()
  80. slog.SetLogLoggerLevel(level)
  81. os.Exit(m.Run())
  82. }
  83. func TestConvertModel(t *testing.T) {
  84. cases := []string{
  85. "Meta-Llama-3-8B-Instruct",
  86. "Meta-Llama-3.1-8B-Instruct",
  87. "Mistral-7B-Instruct-v0.2",
  88. "Mixtral-8x7B-Instruct-v0.1",
  89. "gemma-2b-it",
  90. "gemma-2-2b-it",
  91. // microsoft/Phi-3-mini-128-instruct@d548c233192db00165d842bf8edff054bb3212f8
  92. "Phi-3-mini-128k-instruct",
  93. "all-MiniLM-L6-v2",
  94. "gemma-2-9b-it",
  95. "Qwen2.5-0.5B-Instruct",
  96. "c4ai-command-r-v01",
  97. "c4ai-command-r7b-12-2024",
  98. }
  99. for i := range cases {
  100. tt := cases[i]
  101. t.Run(tt, func(t *testing.T) {
  102. t.Parallel()
  103. p := filepath.Join("testdata", tt)
  104. if testing.Short() {
  105. t.Skip("skipping in short mode")
  106. } else if _, err := os.Stat(p); err != nil {
  107. t.Skipf("%s not found", p)
  108. }
  109. f, kv, tensors := convertFull(t, os.DirFS(p))
  110. actual := generateResultsJSON(t, f, kv, tensors)
  111. if generate != "" && generate == tt {
  112. outFile := filepath.Join("testdata", fmt.Sprintf("%s.json", tt))
  113. data, err := json.MarshalIndent(actual, "", " ")
  114. if err != nil {
  115. t.Fatal(err)
  116. }
  117. if err := os.WriteFile(outFile, data, 0o644); err != nil {
  118. t.Fatal(err)
  119. }
  120. t.Logf("Generated expected results for %s", tt)
  121. return
  122. }
  123. expectFile, err := os.Open(filepath.Join("testdata", fmt.Sprintf("%s.json", tt)))
  124. if err != nil {
  125. t.Fatal(err)
  126. }
  127. var expect map[string]string
  128. if err := json.NewDecoder(expectFile).Decode(&expect); err != nil {
  129. t.Fatal(err)
  130. }
  131. keys := maps.Keys(expect)
  132. slices.Sort(keys)
  133. for _, k := range keys {
  134. if v, ok := actual[k]; !ok {
  135. t.Errorf("missing %s", k)
  136. } else if v != expect[k] {
  137. t.Errorf("unexpected %s: want %s, got %s", k, expect[k], v)
  138. }
  139. }
  140. })
  141. }
  142. }
  143. func TestConvertInvalidTensorNames(t *testing.T) {
  144. f, err := os.CreateTemp(t.TempDir(), "testmodel")
  145. if err != nil {
  146. t.Fatal(err)
  147. }
  148. defer f.Close()
  149. tempDir := t.TempDir()
  150. td := map[string]*tensorData{}
  151. offset := 4096
  152. td["model.layers.0.self_attn.q_proj.weight"] = &tensorData{
  153. Offsets: []int{0, offset},
  154. Type: "F32",
  155. Shape: []int{4096, 4096},
  156. }
  157. td["blk.0.attn_q.weight"] = &tensorData{
  158. Offsets: []int{offset, offset * 2},
  159. Type: "F32",
  160. Shape: []int{4096, 4096},
  161. }
  162. generateSafetensorTestData(t, tempDir, td)
  163. err = ConvertModel(os.DirFS(tempDir), f)
  164. if err == nil || !strings.HasPrefix(err.Error(), "duplicate tensor name") {
  165. t.Errorf("expected error but didn't get one")
  166. }
  167. }
  168. func TestConvertInvalidDatatype(t *testing.T) {
  169. f, err := os.CreateTemp(t.TempDir(), "testmodel")
  170. if err != nil {
  171. t.Fatal(err)
  172. }
  173. defer f.Close()
  174. tempDir := t.TempDir()
  175. td := map[string]*tensorData{}
  176. offset := 4096 * 14336
  177. td["model.layers.0.mlp.down_proj.weight"] = &tensorData{
  178. Offsets: []int{0, offset},
  179. Type: "I8",
  180. Shape: []int{4096, 14336},
  181. }
  182. td["model.layers.0.mlp.down_proj.weight_format"] = &tensorData{
  183. Offsets: []int{offset, offset},
  184. Type: "U8",
  185. Shape: []int{},
  186. }
  187. generateSafetensorTestData(t, tempDir, td)
  188. err = ConvertModel(os.DirFS(tempDir), f)
  189. if err == nil || err.Error() != "unsupported safetensors model" {
  190. t.Errorf("expected error but didn't get one")
  191. }
  192. }
  193. func generateSafetensorTestData(t *testing.T, tempDir string, tensorData map[string]*tensorData) {
  194. data, err := json.Marshal(tensorData)
  195. if err != nil {
  196. t.Fatal(err)
  197. }
  198. var buf bytes.Buffer
  199. l := int64(len(data))
  200. err = binary.Write(&buf, binary.LittleEndian, l)
  201. if err != nil {
  202. t.Fatal(err)
  203. }
  204. _, err = buf.Write(data)
  205. if err != nil {
  206. t.Fatal(err)
  207. }
  208. fdata, err := os.Create(filepath.Join(tempDir, "model-00001-of-00001.safetensors"))
  209. if err != nil {
  210. t.Fatal(err)
  211. }
  212. defer fdata.Close()
  213. _, err = fdata.Write(buf.Bytes())
  214. if err != nil {
  215. t.Fatal(err)
  216. }
  217. configData := `
  218. {
  219. "architectures": [
  220. "LlamaForCausalLM"
  221. ]
  222. }
  223. `
  224. f, err := os.Create(filepath.Join(tempDir, "config.json"))
  225. if err != nil {
  226. t.Fatal(err)
  227. }
  228. defer f.Close()
  229. _, err = f.WriteString(configData)
  230. if err != nil {
  231. t.Fatal(err)
  232. }
  233. tokenizerData := `
  234. {
  235. }
  236. `
  237. f, err = os.Create(filepath.Join(tempDir, "tokenizer.json"))
  238. if err != nil {
  239. t.Fatal(err)
  240. }
  241. defer f.Close()
  242. _, err = f.WriteString(tokenizerData)
  243. if err != nil {
  244. t.Fatal(err)
  245. }
  246. }
  247. func TestConvertAdapter(t *testing.T) {
  248. type AdapterCase struct {
  249. Name string
  250. BaseKV map[string]any
  251. Expected map[string]string
  252. }
  253. cases := []AdapterCase{
  254. {
  255. Name: "discollama",
  256. BaseKV: map[string]any{
  257. "general.architecture": "llama",
  258. "llama.attention.head_count": uint32(32),
  259. "llama.attention.head_count_kv": uint32(8),
  260. },
  261. Expected: map[string]string{
  262. "general.architecture": "llama",
  263. "general.file_type": "1",
  264. "general.parameter_count": "106496",
  265. "general.type": "adapter",
  266. "general.version": "v0.2",
  267. "adapter.lora.alpha": "16",
  268. "adapter.type": "lora",
  269. "llama.attention.head_count": "32",
  270. "llama.attention.head_count_kv": "8",
  271. "blk.31.attn_q.weight.lora_a": "0eb3318b02cd313429bcc7621b539fdbb10240fea190c56c9e5f93fcd37a4e50",
  272. "blk.31.attn_q.weight.lora_b": "0eb3318b02cd313429bcc7621b539fdbb10240fea190c56c9e5f93fcd37a4e50",
  273. "blk.31.attn_v.weight.lora_a": "0eb3318b02cd313429bcc7621b539fdbb10240fea190c56c9e5f93fcd37a4e50",
  274. "blk.31.attn_v.weight.lora_b": "071dcafe89df065d6e1c935ecb8fdf6479b3c202eb912e7da938597673ff5857",
  275. },
  276. },
  277. }
  278. for _, c := range cases {
  279. t.Run(c.Name, func(t *testing.T) {
  280. t.Parallel()
  281. f, err := os.CreateTemp(t.TempDir(), "f16")
  282. if err != nil {
  283. t.Fatal(err)
  284. }
  285. defer f.Close()
  286. tempDir := t.TempDir()
  287. generateLoraTestData(t, tempDir)
  288. if err = ConvertAdapter(os.DirFS(tempDir), f, c.BaseKV); err != nil {
  289. t.Fatal(err)
  290. }
  291. r, err := os.Open(f.Name())
  292. if err != nil {
  293. t.Fatal(err)
  294. }
  295. defer r.Close()
  296. m, _, err := llm.DecodeGGML(r, math.MaxInt)
  297. if err != nil {
  298. t.Fatal(err)
  299. }
  300. if _, err := r.Seek(0, io.SeekStart); err != nil {
  301. t.Fatal(err)
  302. }
  303. actual := generateResultsJSON(t, r, m.KV(), m.Tensors())
  304. keys := maps.Keys(c.Expected)
  305. slices.Sort(keys)
  306. for _, k := range keys {
  307. if v, ok := actual[k]; !ok {
  308. t.Errorf("missing %s", k)
  309. } else if v != c.Expected[k] {
  310. t.Errorf("unexpected %s: want %s, got %s", k, c.Expected[k], v)
  311. }
  312. }
  313. })
  314. }
  315. }
// generateLoraTestData writes a synthetic MLX-style LoRA adapter into
// tempDir: an adapters.safetensors file containing four all-ones F32
// tensors (q_proj/v_proj lora_a and lora_b for layer 31) plus the
// matching adapter_config.json.
func generateLoraTestData(t *testing.T, tempDir string) {
	// Byte size of one 4096x8 F32 tensor; the first three tensors each
	// occupy exactly this many bytes.
	offset := 4096 * 8 * 4

	// Include a nil "__metadata__" entry in the header — real MLX
	// checkpoints carry one, and the converter is presumably expected
	// to skip it (NOTE(review): confirm against the converter).
	td := map[string]*tensorData{"__metadata__": nil}
	td["model.layers.31.self_attn.q_proj.lora_a"] = &tensorData{
		Offsets: []int{0, offset},
		Type:    "F32",
		Shape:   []int{4096, 8},
	}
	td["model.layers.31.self_attn.q_proj.lora_b"] = &tensorData{
		Offsets: []int{offset, offset * 2},
		Type:    "F32",
		Shape:   []int{8, 4096},
	}
	td["model.layers.31.self_attn.v_proj.lora_a"] = &tensorData{
		Offsets: []int{offset * 2, offset * 3},
		Type:    "F32",
		Shape:   []int{4096, 8},
	}
	// v_proj.lora_b is smaller (8x1024 = 8*1024*4 bytes).
	td["model.layers.31.self_attn.v_proj.lora_b"] = &tensorData{
		Offsets: []int{offset * 3, offset*3 + 8*1024*4},
		Type:    "F32",
		Shape:   []int{8, 1024},
	}

	data, err := json.Marshal(td)
	if err != nil {
		t.Fatal(err)
	}

	var buf bytes.Buffer

	// safetensors layout: little-endian int64 header length, then the
	// JSON header, then the raw tensor payloads.
	l := int64(len(data))
	err = binary.Write(&buf, binary.LittleEndian, l)
	if err != nil {
		t.Fatal(err)
	}

	_, err = buf.Write(data)
	if err != nil {
		t.Fatal(err)
	}

	// write some data for the tensors
	ones := make([]float32, 4096*8)
	for i := range ones {
		ones[i] = float32(1)
	}

	// Payloads for the three 4096x8-sized tensors (q_proj a/b, v_proj a).
	for range 3 {
		err = binary.Write(&buf, binary.LittleEndian, ones)
		if err != nil {
			t.Fatal(err)
		}
	}

	// Payload for the smaller v_proj.lora_b tensor.
	ones = make([]float32, 1024*8)
	for i := range ones {
		ones[i] = float32(1)
	}

	err = binary.Write(&buf, binary.LittleEndian, ones)
	if err != nil {
		t.Fatal(err)
	}

	fdata, err := os.Create(filepath.Join(tempDir, "adapters.safetensors"))
	if err != nil {
		t.Fatal(err)
	}
	defer fdata.Close()

	_, err = fdata.Write(buf.Bytes())
	if err != nil {
		t.Fatal(err)
	}

	// adapter_config.json in the shape produced by MLX LoRA fine-tuning;
	// lora_parameters (rank/alpha/scale) is the part the converter reads
	// (NOTE(review): converter behavior inferred from the "adapter.lora.alpha"
	// expectation in TestConvertAdapter — confirm).
	configData := `
{
"adapter_path": "adapters-test",
"batch_size": 8,
"config": "config-tiny.json",
"data": "../discollama-completion",
"grad_checkpoint": null,
"iters": 1000,
"learning_rate": 1e-05,
"lora_layers": 1,
"lora_parameters": {
"rank": 8,
"alpha": 16,
"dropout": 0.0,
"scale": 2.0
},
"lr_schedule": null,
"max_seq_length": 2048,
"model": "/Users/pdevine/git/Meta-Llama-3-8B-Instruct",
"resume_adapter_file": null,
"save_every": 100,
"seed": 0,
"steps_per_eval": 200,
"steps_per_report": 10,
"test": false,
"test_batches": 500,
"train": true,
"use_dora": false,
"val_batches": 25
}
`

	f, err := os.Create(filepath.Join(tempDir, "adapter_config.json"))
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()

	_, err = f.WriteString(configData)
	if err != nil {
		t.Fatal(err)
	}
}
  421. }