convert_test.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. package convert
  2. import (
  3. "bytes"
  4. "crypto/sha256"
  5. "encoding/binary"
  6. "encoding/hex"
  7. "encoding/json"
  8. "flag"
  9. "fmt"
  10. "io"
  11. "io/fs"
  12. "log/slog"
  13. "math"
  14. "os"
  15. "path/filepath"
  16. "slices"
  17. "testing"
  18. "golang.org/x/exp/maps"
  19. "github.com/ollama/ollama/llm"
  20. )
  21. func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, llm.Tensors) {
  22. t.Helper()
  23. f, err := os.CreateTemp(t.TempDir(), "f16")
  24. if err != nil {
  25. t.Fatal(err)
  26. }
  27. defer f.Close()
  28. if err := ConvertModel(fsys, f); err != nil {
  29. t.Fatal(err)
  30. }
  31. r, err := os.Open(f.Name())
  32. if err != nil {
  33. t.Fatal(err)
  34. }
  35. t.Cleanup(func() { r.Close() })
  36. m, _, err := llm.DecodeGGML(r, math.MaxInt)
  37. if err != nil {
  38. t.Fatal(err)
  39. }
  40. if _, err := r.Seek(0, io.SeekStart); err != nil {
  41. t.Fatal(err)
  42. }
  43. return r, m.KV(), m.Tensors()
  44. }
  45. func generateResultsJSON(t *testing.T, f *os.File, kv llm.KV, tensors llm.Tensors) map[string]string {
  46. actual := make(map[string]string)
  47. for k, v := range kv {
  48. if s, ok := v.(json.Marshaler); !ok {
  49. actual[k] = fmt.Sprintf("%v", v)
  50. } else {
  51. bts, err := json.Marshal(s)
  52. if err != nil {
  53. t.Fatal(err)
  54. }
  55. actual[k] = fmt.Sprintf("%x", sha256.Sum256(bts))
  56. }
  57. }
  58. for _, tensor := range tensors.Items {
  59. sha256sum := sha256.New()
  60. sr := io.NewSectionReader(f, int64(tensors.Offset+tensor.Offset), int64(tensor.Size()))
  61. if _, err := io.Copy(sha256sum, sr); err != nil {
  62. t.Fatal(err)
  63. }
  64. actual[tensor.Name] = hex.EncodeToString(sha256sum.Sum(nil))
  65. }
  66. return actual
  67. }
  68. func TestMain(m *testing.M) {
  69. var level slog.Level
  70. flag.TextVar(&level, "level", slog.LevelInfo, "log level")
  71. flag.Parse()
  72. slog.SetLogLoggerLevel(level)
  73. os.Exit(m.Run())
  74. }
  75. func TestConvertFull(t *testing.T) {
  76. cases := []string{
  77. "Meta-Llama-3-8B-Instruct",
  78. "Meta-Llama-3.1-8B-Instruct",
  79. "Mistral-7B-Instruct-v0.2",
  80. "Mixtral-8x7B-Instruct-v0.1",
  81. "gemma-2b-it",
  82. // microsoft/Phi-3-mini-128-instruct@d548c233192db00165d842bf8edff054bb3212f8
  83. "Phi-3-mini-128k-instruct",
  84. "all-MiniLM-L6-v2",
  85. "gemma-2-9b-it",
  86. }
  87. for i := range cases {
  88. tt := cases[i]
  89. t.Run(tt, func(t *testing.T) {
  90. t.Parallel()
  91. p := filepath.Join("testdata", tt)
  92. if testing.Short() {
  93. t.Skip("skipping in short mode")
  94. } else if _, err := os.Stat(p); err != nil {
  95. t.Skipf("%s not found", p)
  96. }
  97. f, kv, tensors := convertFull(t, os.DirFS(p))
  98. actual := generateResultsJSON(t, f, kv, tensors)
  99. expectFile, err := os.Open(filepath.Join("testdata", fmt.Sprintf("%s.json", tt)))
  100. if err != nil {
  101. t.Fatal(err)
  102. }
  103. var expect map[string]string
  104. if err := json.NewDecoder(expectFile).Decode(&expect); err != nil {
  105. t.Fatal(err)
  106. }
  107. keys := maps.Keys(expect)
  108. slices.Sort(keys)
  109. for _, k := range keys {
  110. if v, ok := actual[k]; !ok {
  111. t.Errorf("missing %s", k)
  112. } else if v != expect[k] {
  113. t.Errorf("unexpected %s: want %s, got %s", k, expect[k], v)
  114. }
  115. }
  116. })
  117. }
  118. }
  119. func TestConvertAdapter(t *testing.T) {
  120. type AdapterCase struct {
  121. Name string
  122. BaseKV map[string]any
  123. Expected map[string]string
  124. }
  125. cases := []AdapterCase{
  126. {
  127. Name: "discollama",
  128. BaseKV: map[string]any{
  129. "general.architecture": "llama",
  130. "llama.attention.head_count": uint32(32),
  131. "llama.attention.head_count_kv": uint32(8),
  132. },
  133. Expected: map[string]string{
  134. "general.architecture": "llama",
  135. "general.file_type": "1",
  136. "general.parameter_count": "106496",
  137. "general.type": "adapter",
  138. "general.version": "v0.2",
  139. "adapter.lora.alpha": "16",
  140. "adapter.type": "lora",
  141. "llama.attention.head_count": "32",
  142. "llama.attention.head_count_kv": "8",
  143. "blk.31.attn_q.weight.lora_a": "0eb3318b02cd313429bcc7621b539fdbb10240fea190c56c9e5f93fcd37a4e50",
  144. "blk.31.attn_q.weight.lora_b": "0eb3318b02cd313429bcc7621b539fdbb10240fea190c56c9e5f93fcd37a4e50",
  145. "blk.31.attn_v.weight.lora_a": "0eb3318b02cd313429bcc7621b539fdbb10240fea190c56c9e5f93fcd37a4e50",
  146. "blk.31.attn_v.weight.lora_b": "071dcafe89df065d6e1c935ecb8fdf6479b3c202eb912e7da938597673ff5857",
  147. },
  148. },
  149. }
  150. for _, c := range cases {
  151. t.Run(c.Name, func(t *testing.T) {
  152. t.Parallel()
  153. f, err := os.CreateTemp(t.TempDir(), "f16")
  154. if err != nil {
  155. t.Fatal(err)
  156. }
  157. defer f.Close()
  158. tempDir := t.TempDir()
  159. generateLoraTestData(t, tempDir)
  160. if err = ConvertAdapter(os.DirFS(tempDir), f, c.BaseKV); err != nil {
  161. t.Fatal(err)
  162. }
  163. r, err := os.Open(f.Name())
  164. if err != nil {
  165. t.Fatal(err)
  166. }
  167. defer r.Close()
  168. m, _, err := llm.DecodeGGML(r, math.MaxInt)
  169. if err != nil {
  170. t.Fatal(err)
  171. }
  172. if _, err := r.Seek(0, io.SeekStart); err != nil {
  173. t.Fatal(err)
  174. }
  175. actual := generateResultsJSON(t, r, m.KV(), m.Tensors())
  176. keys := maps.Keys(c.Expected)
  177. slices.Sort(keys)
  178. for _, k := range keys {
  179. if v, ok := actual[k]; !ok {
  180. t.Errorf("missing %s", k)
  181. } else if v != c.Expected[k] {
  182. t.Errorf("unexpected %s: want %s, got %s", k, c.Expected[k], v)
  183. }
  184. }
  185. })
  186. }
  187. }
  188. func generateLoraTestData(t *testing.T, tempDir string) {
  189. type tensorData struct {
  190. Offsets []int `json:"data_offsets"`
  191. Type string `json:"dtype"`
  192. Shape []int `json:"shape"`
  193. }
  194. offset := 4096 * 8 * 4
  195. td := map[string]*tensorData{"__metadata__": nil}
  196. td["model.layers.31.self_attn.q_proj.lora_a"] = &tensorData{
  197. Offsets: []int{0, offset},
  198. Type: "F32",
  199. Shape: []int{4096, 8},
  200. }
  201. td["model.layers.31.self_attn.q_proj.lora_b"] = &tensorData{
  202. Offsets: []int{offset, offset * 2},
  203. Type: "F32",
  204. Shape: []int{8, 4096},
  205. }
  206. td["model.layers.31.self_attn.v_proj.lora_a"] = &tensorData{
  207. Offsets: []int{offset * 2, offset * 3},
  208. Type: "F32",
  209. Shape: []int{4096, 8},
  210. }
  211. td["model.layers.31.self_attn.v_proj.lora_b"] = &tensorData{
  212. Offsets: []int{offset * 3, offset*3 + 8*1024*4},
  213. Type: "F32",
  214. Shape: []int{8, 1024},
  215. }
  216. data, err := json.Marshal(td)
  217. if err != nil {
  218. t.Fatal(err)
  219. }
  220. var buf bytes.Buffer
  221. l := int64(len(data))
  222. err = binary.Write(&buf, binary.LittleEndian, l)
  223. if err != nil {
  224. t.Fatal(err)
  225. }
  226. _, err = buf.Write(data)
  227. if err != nil {
  228. t.Fatal(err)
  229. }
  230. // write some data for the tensors
  231. ones := make([]float32, 4096*8)
  232. for i := range ones {
  233. ones[i] = float32(1)
  234. }
  235. for range 3 {
  236. err = binary.Write(&buf, binary.LittleEndian, ones)
  237. if err != nil {
  238. t.Fatal(err)
  239. }
  240. }
  241. ones = make([]float32, 1024*8)
  242. for i := range ones {
  243. ones[i] = float32(1)
  244. }
  245. err = binary.Write(&buf, binary.LittleEndian, ones)
  246. if err != nil {
  247. t.Fatal(err)
  248. }
  249. fdata, err := os.Create(filepath.Join(tempDir, "adapters.safetensors"))
  250. if err != nil {
  251. t.Fatal(err)
  252. }
  253. defer fdata.Close()
  254. _, err = fdata.Write(buf.Bytes())
  255. if err != nil {
  256. t.Fatal(err)
  257. }
  258. configData := `
  259. {
  260. "adapter_path": "adapters-test",
  261. "batch_size": 8,
  262. "config": "config-tiny.json",
  263. "data": "../discollama-completion",
  264. "grad_checkpoint": null,
  265. "iters": 1000,
  266. "learning_rate": 1e-05,
  267. "lora_layers": 1,
  268. "lora_parameters": {
  269. "rank": 8,
  270. "alpha": 16,
  271. "dropout": 0.0,
  272. "scale": 2.0
  273. },
  274. "lr_schedule": null,
  275. "max_seq_length": 2048,
  276. "model": "/Users/pdevine/git/Meta-Llama-3-8B-Instruct",
  277. "resume_adapter_file": null,
  278. "save_every": 100,
  279. "seed": 0,
  280. "steps_per_eval": 200,
  281. "steps_per_report": 10,
  282. "test": false,
  283. "test_batches": 500,
  284. "train": true,
  285. "use_dora": false,
  286. "val_batches": 25
  287. }
  288. `
  289. f, err := os.Create(filepath.Join(tempDir, "adapter_config.json"))
  290. if err != nil {
  291. t.Fatal(err)
  292. }
  293. defer f.Close()
  294. _, err = f.WriteString(configData)
  295. if err != nil {
  296. t.Fatal(err)
  297. }
  298. }