convert_test.go 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449
  1. package convert
  2. import (
  3. "bytes"
  4. "crypto/sha256"
  5. "encoding/binary"
  6. "encoding/hex"
  7. "encoding/json"
  8. "flag"
  9. "fmt"
  10. "io"
  11. "io/fs"
  12. "log/slog"
  13. "math"
  14. "os"
  15. "path/filepath"
  16. "slices"
  17. "testing"
  18. "golang.org/x/exp/maps"
  19. "github.com/ollama/ollama/api"
  20. "github.com/ollama/ollama/llm"
  21. )
  22. func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, llm.Tensors) {
  23. t.Helper()
  24. f, err := os.CreateTemp(t.TempDir(), "f16")
  25. if err != nil {
  26. t.Fatal(err)
  27. }
  28. defer f.Close()
  29. if err := ConvertModel(fsys, f, func(api.ProgressResponse) {}); err != nil {
  30. t.Fatal(err)
  31. }
  32. r, err := os.Open(f.Name())
  33. if err != nil {
  34. t.Fatal(err)
  35. }
  36. t.Cleanup(func() { r.Close() })
  37. m, _, err := llm.DecodeGGML(r, math.MaxInt)
  38. if err != nil {
  39. t.Fatal(err)
  40. }
  41. if _, err := r.Seek(0, io.SeekStart); err != nil {
  42. t.Fatal(err)
  43. }
  44. return r, m.KV(), m.Tensors()
  45. }
  46. func generateResultsJSON(t *testing.T, f *os.File, kv llm.KV, tensors llm.Tensors) map[string]string {
  47. actual := make(map[string]string)
  48. for k, v := range kv {
  49. if s, ok := v.(json.Marshaler); !ok {
  50. actual[k] = fmt.Sprintf("%v", v)
  51. } else {
  52. bts, err := json.Marshal(s)
  53. if err != nil {
  54. t.Fatal(err)
  55. }
  56. actual[k] = fmt.Sprintf("%x", sha256.Sum256(bts))
  57. }
  58. }
  59. for _, tensor := range tensors.Items {
  60. sha256sum := sha256.New()
  61. sr := io.NewSectionReader(f, int64(tensors.Offset+tensor.Offset), int64(tensor.Size()))
  62. if _, err := io.Copy(sha256sum, sr); err != nil {
  63. t.Fatal(err)
  64. }
  65. actual[tensor.Name] = hex.EncodeToString(sha256sum.Sum(nil))
  66. }
  67. return actual
  68. }
  69. func TestMain(m *testing.M) {
  70. var level slog.Level
  71. flag.TextVar(&level, "level", slog.LevelInfo, "log level")
  72. flag.Parse()
  73. slog.SetLogLoggerLevel(level)
  74. os.Exit(m.Run())
  75. }
  76. func TestConvertFull(t *testing.T) {
  77. cases := []string{
  78. "Meta-Llama-3-8B-Instruct",
  79. "Meta-Llama-3.1-8B-Instruct",
  80. "Mistral-7B-Instruct-v0.2",
  81. "Mixtral-8x7B-Instruct-v0.1",
  82. "gemma-2b-it",
  83. // microsoft/Phi-3-mini-128-instruct@d548c233192db00165d842bf8edff054bb3212f8
  84. "Phi-3-mini-128k-instruct",
  85. "all-MiniLM-L6-v2",
  86. "gemma-2-9b-it",
  87. }
  88. for i := range cases {
  89. tt := cases[i]
  90. t.Run(tt, func(t *testing.T) {
  91. t.Parallel()
  92. p := filepath.Join("testdata", tt)
  93. if testing.Short() {
  94. t.Skip("skipping in short mode")
  95. } else if _, err := os.Stat(p); err != nil {
  96. t.Skipf("%s not found", p)
  97. }
  98. f, kv, tensors := convertFull(t, os.DirFS(p))
  99. actual := generateResultsJSON(t, f, kv, tensors)
  100. expectFile, err := os.Open(filepath.Join("testdata", fmt.Sprintf("%s.json", tt)))
  101. if err != nil {
  102. t.Fatal(err)
  103. }
  104. var expect map[string]string
  105. if err := json.NewDecoder(expectFile).Decode(&expect); err != nil {
  106. t.Fatal(err)
  107. }
  108. keys := maps.Keys(expect)
  109. slices.Sort(keys)
  110. for _, k := range keys {
  111. if v, ok := actual[k]; !ok {
  112. t.Errorf("missing %s", k)
  113. } else if v != expect[k] {
  114. t.Errorf("unexpected %s: want %s, got %s", k, expect[k], v)
  115. }
  116. }
  117. })
  118. }
  119. }
  120. func TestConvertInvalidDatatype(t *testing.T) {
  121. f, err := os.CreateTemp(t.TempDir(), "testmodel")
  122. if err != nil {
  123. t.Fatal(err)
  124. }
  125. defer f.Close()
  126. tempDir := t.TempDir()
  127. generateSafetensorTestData(t, tempDir)
  128. err = ConvertModel(os.DirFS(tempDir), f, func(api.ProgressResponse) {})
  129. if err == nil || err.Error() != "unsupported safetensors model" {
  130. t.Errorf("expected error but didn't get one")
  131. }
  132. }
  133. func generateSafetensorTestData(t *testing.T, tempDir string) {
  134. type tensorData struct {
  135. Offsets []int `json:"data_offsets"`
  136. Type string `json:"dtype"`
  137. Shape []int `json:"shape"`
  138. }
  139. offset := 4096 * 14336
  140. td := map[string]*tensorData{}
  141. td["model.layers.0.mlp.down_proj.weight"] = &tensorData{
  142. Offsets: []int{0, offset},
  143. Type: "I8",
  144. Shape: []int{4096, 14336},
  145. }
  146. td["model.layers.0.mlp.down_proj.weight_format"] = &tensorData{
  147. Offsets: []int{offset, offset},
  148. Type: "U8",
  149. Shape: []int{},
  150. }
  151. data, err := json.Marshal(td)
  152. if err != nil {
  153. t.Fatal(err)
  154. }
  155. var buf bytes.Buffer
  156. l := int64(len(data))
  157. err = binary.Write(&buf, binary.LittleEndian, l)
  158. if err != nil {
  159. t.Fatal(err)
  160. }
  161. _, err = buf.Write(data)
  162. if err != nil {
  163. t.Fatal(err)
  164. }
  165. fdata, err := os.Create(filepath.Join(tempDir, "model-00001-of-00001.safetensors"))
  166. if err != nil {
  167. t.Fatal(err)
  168. }
  169. defer fdata.Close()
  170. _, err = fdata.Write(buf.Bytes())
  171. if err != nil {
  172. t.Fatal(err)
  173. }
  174. configData := `
  175. {
  176. "architectures": [
  177. "LlamaForCausalLM"
  178. ]
  179. }
  180. `
  181. f, err := os.Create(filepath.Join(tempDir, "config.json"))
  182. if err != nil {
  183. t.Fatal(err)
  184. }
  185. defer f.Close()
  186. _, err = f.WriteString(configData)
  187. if err != nil {
  188. t.Fatal(err)
  189. }
  190. tokenizerData := `
  191. {
  192. }
  193. `
  194. f, err = os.Create(filepath.Join(tempDir, "tokenizer.json"))
  195. if err != nil {
  196. t.Fatal(err)
  197. }
  198. defer f.Close()
  199. _, err = f.WriteString(tokenizerData)
  200. if err != nil {
  201. t.Fatal(err)
  202. }
  203. }
  204. func TestConvertAdapter(t *testing.T) {
  205. type AdapterCase struct {
  206. Name string
  207. BaseKV map[string]any
  208. Expected map[string]string
  209. }
  210. cases := []AdapterCase{
  211. {
  212. Name: "discollama",
  213. BaseKV: map[string]any{
  214. "general.architecture": "llama",
  215. "llama.attention.head_count": uint32(32),
  216. "llama.attention.head_count_kv": uint32(8),
  217. },
  218. Expected: map[string]string{
  219. "general.architecture": "llama",
  220. "general.file_type": "1",
  221. "general.parameter_count": "106496",
  222. "general.type": "adapter",
  223. "general.version": "v0.2",
  224. "adapter.lora.alpha": "16",
  225. "adapter.type": "lora",
  226. "llama.attention.head_count": "32",
  227. "llama.attention.head_count_kv": "8",
  228. "blk.31.attn_q.weight.lora_a": "0eb3318b02cd313429bcc7621b539fdbb10240fea190c56c9e5f93fcd37a4e50",
  229. "blk.31.attn_q.weight.lora_b": "0eb3318b02cd313429bcc7621b539fdbb10240fea190c56c9e5f93fcd37a4e50",
  230. "blk.31.attn_v.weight.lora_a": "0eb3318b02cd313429bcc7621b539fdbb10240fea190c56c9e5f93fcd37a4e50",
  231. "blk.31.attn_v.weight.lora_b": "071dcafe89df065d6e1c935ecb8fdf6479b3c202eb912e7da938597673ff5857",
  232. },
  233. },
  234. }
  235. for _, c := range cases {
  236. t.Run(c.Name, func(t *testing.T) {
  237. t.Parallel()
  238. f, err := os.CreateTemp(t.TempDir(), "f16")
  239. if err != nil {
  240. t.Fatal(err)
  241. }
  242. defer f.Close()
  243. tempDir := t.TempDir()
  244. generateLoraTestData(t, tempDir)
  245. if err = ConvertAdapter(os.DirFS(tempDir), f, c.BaseKV, func(api.ProgressResponse) {}); err != nil {
  246. t.Fatal(err)
  247. }
  248. r, err := os.Open(f.Name())
  249. if err != nil {
  250. t.Fatal(err)
  251. }
  252. defer r.Close()
  253. m, _, err := llm.DecodeGGML(r, math.MaxInt)
  254. if err != nil {
  255. t.Fatal(err)
  256. }
  257. if _, err := r.Seek(0, io.SeekStart); err != nil {
  258. t.Fatal(err)
  259. }
  260. actual := generateResultsJSON(t, r, m.KV(), m.Tensors())
  261. keys := maps.Keys(c.Expected)
  262. slices.Sort(keys)
  263. for _, k := range keys {
  264. if v, ok := actual[k]; !ok {
  265. t.Errorf("missing %s", k)
  266. } else if v != c.Expected[k] {
  267. t.Errorf("unexpected %s: want %s, got %s", k, c.Expected[k], v)
  268. }
  269. }
  270. })
  271. }
  272. }
  273. func generateLoraTestData(t *testing.T, tempDir string) {
  274. type tensorData struct {
  275. Offsets []int `json:"data_offsets"`
  276. Type string `json:"dtype"`
  277. Shape []int `json:"shape"`
  278. }
  279. offset := 4096 * 8 * 4
  280. td := map[string]*tensorData{"__metadata__": nil}
  281. td["model.layers.31.self_attn.q_proj.lora_a"] = &tensorData{
  282. Offsets: []int{0, offset},
  283. Type: "F32",
  284. Shape: []int{4096, 8},
  285. }
  286. td["model.layers.31.self_attn.q_proj.lora_b"] = &tensorData{
  287. Offsets: []int{offset, offset * 2},
  288. Type: "F32",
  289. Shape: []int{8, 4096},
  290. }
  291. td["model.layers.31.self_attn.v_proj.lora_a"] = &tensorData{
  292. Offsets: []int{offset * 2, offset * 3},
  293. Type: "F32",
  294. Shape: []int{4096, 8},
  295. }
  296. td["model.layers.31.self_attn.v_proj.lora_b"] = &tensorData{
  297. Offsets: []int{offset * 3, offset*3 + 8*1024*4},
  298. Type: "F32",
  299. Shape: []int{8, 1024},
  300. }
  301. data, err := json.Marshal(td)
  302. if err != nil {
  303. t.Fatal(err)
  304. }
  305. var buf bytes.Buffer
  306. l := int64(len(data))
  307. err = binary.Write(&buf, binary.LittleEndian, l)
  308. if err != nil {
  309. t.Fatal(err)
  310. }
  311. _, err = buf.Write(data)
  312. if err != nil {
  313. t.Fatal(err)
  314. }
  315. // write some data for the tensors
  316. ones := make([]float32, 4096*8)
  317. for i := range ones {
  318. ones[i] = float32(1)
  319. }
  320. for range 3 {
  321. err = binary.Write(&buf, binary.LittleEndian, ones)
  322. if err != nil {
  323. t.Fatal(err)
  324. }
  325. }
  326. ones = make([]float32, 1024*8)
  327. for i := range ones {
  328. ones[i] = float32(1)
  329. }
  330. err = binary.Write(&buf, binary.LittleEndian, ones)
  331. if err != nil {
  332. t.Fatal(err)
  333. }
  334. fdata, err := os.Create(filepath.Join(tempDir, "adapters.safetensors"))
  335. if err != nil {
  336. t.Fatal(err)
  337. }
  338. defer fdata.Close()
  339. _, err = fdata.Write(buf.Bytes())
  340. if err != nil {
  341. t.Fatal(err)
  342. }
  343. configData := `
  344. {
  345. "adapter_path": "adapters-test",
  346. "batch_size": 8,
  347. "config": "config-tiny.json",
  348. "data": "../discollama-completion",
  349. "grad_checkpoint": null,
  350. "iters": 1000,
  351. "learning_rate": 1e-05,
  352. "lora_layers": 1,
  353. "lora_parameters": {
  354. "rank": 8,
  355. "alpha": 16,
  356. "dropout": 0.0,
  357. "scale": 2.0
  358. },
  359. "lr_schedule": null,
  360. "max_seq_length": 2048,
  361. "model": "/Users/pdevine/git/Meta-Llama-3-8B-Instruct",
  362. "resume_adapter_file": null,
  363. "save_every": 100,
  364. "seed": 0,
  365. "steps_per_eval": 200,
  366. "steps_per_report": 10,
  367. "test": false,
  368. "test_batches": 500,
  369. "train": true,
  370. "use_dora": false,
  371. "val_batches": 25
  372. }
  373. `
  374. f, err := os.Create(filepath.Join(tempDir, "adapter_config.json"))
  375. if err != nil {
  376. t.Fatal(err)
  377. }
  378. defer f.Close()
  379. _, err = f.WriteString(configData)
  380. if err != nil {
  381. t.Fatal(err)
  382. }
  383. }