parser_test.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832
  1. package parser
  2. import (
  3. "bytes"
  4. "crypto/sha256"
  5. "encoding/binary"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "os"
  10. "strings"
  11. "testing"
  12. "unicode/utf16"
  13. "github.com/google/go-cmp/cmp"
  14. "github.com/stretchr/testify/assert"
  15. "github.com/stretchr/testify/require"
  16. "golang.org/x/text/encoding"
  17. "golang.org/x/text/encoding/unicode"
  18. "github.com/ollama/ollama/api"
  19. "github.com/ollama/ollama/fs/ggml"
  20. )
  21. func TestParseFileFile(t *testing.T) {
  22. input := `
  23. FROM model1
  24. ADAPTER adapter1
  25. LICENSE MIT
  26. PARAMETER param1 value1
  27. PARAMETER param2 value2
  28. TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
  29. {{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
  30. {{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
  31. {{ .Response }}<|eot_id|>"""
  32. `
  33. reader := strings.NewReader(input)
  34. modelfile, err := ParseFile(reader)
  35. require.NoError(t, err)
  36. expectedCommands := []Command{
  37. {Name: "model", Args: "model1"},
  38. {Name: "adapter", Args: "adapter1"},
  39. {Name: "license", Args: "MIT"},
  40. {Name: "param1", Args: "value1"},
  41. {Name: "param2", Args: "value2"},
  42. {Name: "template", Args: "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>"},
  43. }
  44. assert.Equal(t, expectedCommands, modelfile.Commands)
  45. }
  46. func TestParseFileTrimSpace(t *testing.T) {
  47. input := `
  48. FROM " model 1"
  49. ADAPTER adapter3
  50. LICENSE "MIT "
  51. PARAMETER param1 value1
  52. PARAMETER param2 value2
  53. TEMPLATE """ {{ if .System }}<|start_header_id|>system<|end_header_id|>
  54. {{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
  55. {{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
  56. {{ .Response }}<|eot_id|> """
  57. `
  58. reader := strings.NewReader(input)
  59. modelfile, err := ParseFile(reader)
  60. require.NoError(t, err)
  61. expectedCommands := []Command{
  62. {Name: "model", Args: " model 1"},
  63. {Name: "adapter", Args: "adapter3"},
  64. {Name: "license", Args: "MIT "},
  65. {Name: "param1", Args: "value1"},
  66. {Name: "param2", Args: "value2"},
  67. {Name: "template", Args: " {{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|> "},
  68. }
  69. assert.Equal(t, expectedCommands, modelfile.Commands)
  70. }
  71. func TestParseFileFrom(t *testing.T) {
  72. cases := []struct {
  73. input string
  74. expected []Command
  75. err error
  76. }{
  77. {
  78. "FROM \"FOO BAR \"",
  79. []Command{{Name: "model", Args: "FOO BAR "}},
  80. nil,
  81. },
  82. {
  83. "FROM \"FOO BAR\"\nPARAMETER param1 value1",
  84. []Command{{Name: "model", Args: "FOO BAR"}, {Name: "param1", Args: "value1"}},
  85. nil,
  86. },
  87. {
  88. "FROM FOOO BAR ",
  89. []Command{{Name: "model", Args: "FOOO BAR"}},
  90. nil,
  91. },
  92. {
  93. "FROM /what/is/the path ",
  94. []Command{{Name: "model", Args: "/what/is/the path"}},
  95. nil,
  96. },
  97. {
  98. "FROM foo",
  99. []Command{{Name: "model", Args: "foo"}},
  100. nil,
  101. },
  102. {
  103. "FROM /path/to/model",
  104. []Command{{Name: "model", Args: "/path/to/model"}},
  105. nil,
  106. },
  107. {
  108. "FROM /path/to/model/fp16.bin",
  109. []Command{{Name: "model", Args: "/path/to/model/fp16.bin"}},
  110. nil,
  111. },
  112. {
  113. "FROM llama3:latest",
  114. []Command{{Name: "model", Args: "llama3:latest"}},
  115. nil,
  116. },
  117. {
  118. "FROM llama3:7b-instruct-q4_K_M",
  119. []Command{{Name: "model", Args: "llama3:7b-instruct-q4_K_M"}},
  120. nil,
  121. },
  122. {
  123. "", nil, errMissingFrom,
  124. },
  125. {
  126. "PARAMETER param1 value1",
  127. nil,
  128. errMissingFrom,
  129. },
  130. {
  131. "PARAMETER param1 value1\nFROM foo",
  132. []Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}},
  133. nil,
  134. },
  135. {
  136. "PARAMETER what the \nFROM lemons make lemonade ",
  137. []Command{{Name: "what", Args: "the"}, {Name: "model", Args: "lemons make lemonade"}},
  138. nil,
  139. },
  140. }
  141. for _, c := range cases {
  142. t.Run("", func(t *testing.T) {
  143. modelfile, err := ParseFile(strings.NewReader(c.input))
  144. require.ErrorIs(t, err, c.err)
  145. if modelfile != nil {
  146. assert.Equal(t, c.expected, modelfile.Commands)
  147. }
  148. })
  149. }
  150. }
  151. func TestParseFileParametersMissingValue(t *testing.T) {
  152. input := `
  153. FROM foo
  154. PARAMETER param1
  155. `
  156. reader := strings.NewReader(input)
  157. _, err := ParseFile(reader)
  158. require.ErrorIs(t, err, io.ErrUnexpectedEOF)
  159. }
  160. func TestParseFileBadCommand(t *testing.T) {
  161. input := `
  162. FROM foo
  163. BADCOMMAND param1 value1
  164. `
  165. parserError := &ParserError{
  166. LineNumber: 3,
  167. Msg: errInvalidCommand.Error(),
  168. }
  169. _, err := ParseFile(strings.NewReader(input))
  170. if !errors.As(err, &parserError) {
  171. t.Errorf("unexpected error: expected: %s, actual: %s", parserError.Error(), err.Error())
  172. }
  173. }
  174. func TestParseFileMessages(t *testing.T) {
  175. cases := []struct {
  176. input string
  177. expected []Command
  178. err error
  179. }{
  180. {
  181. `
  182. FROM foo
  183. MESSAGE system You are a file parser. Always parse things.
  184. `,
  185. []Command{
  186. {Name: "model", Args: "foo"},
  187. {Name: "message", Args: "system: You are a file parser. Always parse things."},
  188. },
  189. nil,
  190. },
  191. {
  192. `
  193. FROM foo
  194. MESSAGE system You are a file parser. Always parse things.`,
  195. []Command{
  196. {Name: "model", Args: "foo"},
  197. {Name: "message", Args: "system: You are a file parser. Always parse things."},
  198. },
  199. nil,
  200. },
  201. {
  202. `
  203. FROM foo
  204. MESSAGE system You are a file parser. Always parse things.
  205. MESSAGE user Hey there!
  206. MESSAGE assistant Hello, I want to parse all the things!
  207. `,
  208. []Command{
  209. {Name: "model", Args: "foo"},
  210. {Name: "message", Args: "system: You are a file parser. Always parse things."},
  211. {Name: "message", Args: "user: Hey there!"},
  212. {Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
  213. },
  214. nil,
  215. },
  216. {
  217. `
  218. FROM foo
  219. MESSAGE system """
  220. You are a multiline file parser. Always parse things.
  221. """
  222. `,
  223. []Command{
  224. {Name: "model", Args: "foo"},
  225. {Name: "message", Args: "system: \nYou are a multiline file parser. Always parse things.\n"},
  226. },
  227. nil,
  228. },
  229. {
  230. `
  231. FROM foo
  232. MESSAGE somerandomrole I'm ok with you adding any role message now!
  233. `,
  234. []Command{
  235. {Name: "model", Args: "foo"},
  236. {Name: "message", Args: "somerandomrole: I'm ok with you adding any role message now!"},
  237. },
  238. nil,
  239. },
  240. {
  241. `
  242. FROM foo
  243. MESSAGE system
  244. `,
  245. nil,
  246. io.ErrUnexpectedEOF,
  247. },
  248. {
  249. `
  250. FROM foo
  251. MESSAGE system`,
  252. nil,
  253. io.ErrUnexpectedEOF,
  254. },
  255. }
  256. for _, tt := range cases {
  257. t.Run("", func(t *testing.T) {
  258. modelfile, err := ParseFile(strings.NewReader(tt.input))
  259. if modelfile != nil {
  260. assert.Equal(t, tt.expected, modelfile.Commands)
  261. }
  262. if tt.err == nil {
  263. if err != nil {
  264. t.Fatalf("expected no error, but got %v", err)
  265. }
  266. return
  267. }
  268. switch tt.err.(type) {
  269. case *ParserError:
  270. var pErr *ParserError
  271. if errors.As(err, &pErr) {
  272. // got the correct type of error
  273. return
  274. }
  275. }
  276. if errors.Is(err, tt.err) {
  277. return
  278. }
  279. t.Fatalf("unexpected error: expected: %v, actual: %v", tt.err, err)
  280. })
  281. }
  282. }
  283. func TestParseFileQuoted(t *testing.T) {
  284. cases := []struct {
  285. multiline string
  286. expected []Command
  287. err error
  288. }{
  289. {
  290. `
  291. FROM foo
  292. SYSTEM """
  293. This is a
  294. multiline system.
  295. """
  296. `,
  297. []Command{
  298. {Name: "model", Args: "foo"},
  299. {Name: "system", Args: "\nThis is a\nmultiline system.\n"},
  300. },
  301. nil,
  302. },
  303. {
  304. `
  305. FROM foo
  306. SYSTEM """
  307. This is a
  308. multiline system."""
  309. `,
  310. []Command{
  311. {Name: "model", Args: "foo"},
  312. {Name: "system", Args: "\nThis is a\nmultiline system."},
  313. },
  314. nil,
  315. },
  316. {
  317. `
  318. FROM foo
  319. SYSTEM """This is a
  320. multiline system."""
  321. `,
  322. []Command{
  323. {Name: "model", Args: "foo"},
  324. {Name: "system", Args: "This is a\nmultiline system."},
  325. },
  326. nil,
  327. },
  328. {
  329. `
  330. FROM foo
  331. SYSTEM """This is a multiline system."""
  332. `,
  333. []Command{
  334. {Name: "model", Args: "foo"},
  335. {Name: "system", Args: "This is a multiline system."},
  336. },
  337. nil,
  338. },
  339. {
  340. `
  341. FROM foo
  342. SYSTEM """This is a multiline system.""
  343. `,
  344. nil,
  345. io.ErrUnexpectedEOF,
  346. },
  347. {
  348. `
  349. FROM foo
  350. SYSTEM "
  351. `,
  352. nil,
  353. io.ErrUnexpectedEOF,
  354. },
  355. {
  356. `
  357. FROM foo
  358. SYSTEM """
  359. This is a multiline system with "quotes".
  360. """
  361. `,
  362. []Command{
  363. {Name: "model", Args: "foo"},
  364. {Name: "system", Args: "\nThis is a multiline system with \"quotes\".\n"},
  365. },
  366. nil,
  367. },
  368. {
  369. `
  370. FROM foo
  371. SYSTEM """"""
  372. `,
  373. []Command{
  374. {Name: "model", Args: "foo"},
  375. {Name: "system", Args: ""},
  376. },
  377. nil,
  378. },
  379. {
  380. `
  381. FROM foo
  382. SYSTEM ""
  383. `,
  384. []Command{
  385. {Name: "model", Args: "foo"},
  386. {Name: "system", Args: ""},
  387. },
  388. nil,
  389. },
  390. {
  391. `
  392. FROM foo
  393. SYSTEM "'"
  394. `,
  395. []Command{
  396. {Name: "model", Args: "foo"},
  397. {Name: "system", Args: "'"},
  398. },
  399. nil,
  400. },
  401. {
  402. `
  403. FROM foo
  404. SYSTEM """''"'""'""'"'''''""'""'"""
  405. `,
  406. []Command{
  407. {Name: "model", Args: "foo"},
  408. {Name: "system", Args: `''"'""'""'"'''''""'""'`},
  409. },
  410. nil,
  411. },
  412. {
  413. `
  414. FROM foo
  415. TEMPLATE """
  416. {{ .Prompt }}
  417. """`,
  418. []Command{
  419. {Name: "model", Args: "foo"},
  420. {Name: "template", Args: "\n{{ .Prompt }}\n"},
  421. },
  422. nil,
  423. },
  424. }
  425. for _, c := range cases {
  426. t.Run("", func(t *testing.T) {
  427. modelfile, err := ParseFile(strings.NewReader(c.multiline))
  428. require.ErrorIs(t, err, c.err)
  429. if modelfile != nil {
  430. assert.Equal(t, c.expected, modelfile.Commands)
  431. }
  432. })
  433. }
  434. }
  435. func TestParseFileParameters(t *testing.T) {
  436. cases := map[string]struct {
  437. name, value string
  438. }{
  439. "numa true": {"numa", "true"},
  440. "num_ctx 1": {"num_ctx", "1"},
  441. "num_batch 1": {"num_batch", "1"},
  442. "num_gqa 1": {"num_gqa", "1"},
  443. "num_gpu 1": {"num_gpu", "1"},
  444. "main_gpu 1": {"main_gpu", "1"},
  445. "low_vram true": {"low_vram", "true"},
  446. "logits_all true": {"logits_all", "true"},
  447. "vocab_only true": {"vocab_only", "true"},
  448. "use_mmap true": {"use_mmap", "true"},
  449. "use_mlock true": {"use_mlock", "true"},
  450. "num_thread 1": {"num_thread", "1"},
  451. "num_keep 1": {"num_keep", "1"},
  452. "seed 1": {"seed", "1"},
  453. "num_predict 1": {"num_predict", "1"},
  454. "top_k 1": {"top_k", "1"},
  455. "top_p 1.0": {"top_p", "1.0"},
  456. "min_p 0.05": {"min_p", "0.05"},
  457. "typical_p 1.0": {"typical_p", "1.0"},
  458. "repeat_last_n 1": {"repeat_last_n", "1"},
  459. "temperature 1.0": {"temperature", "1.0"},
  460. "repeat_penalty 1.0": {"repeat_penalty", "1.0"},
  461. "presence_penalty 1.0": {"presence_penalty", "1.0"},
  462. "frequency_penalty 1.0": {"frequency_penalty", "1.0"},
  463. "mirostat 1": {"mirostat", "1"},
  464. "mirostat_tau 1.0": {"mirostat_tau", "1.0"},
  465. "mirostat_eta 1.0": {"mirostat_eta", "1.0"},
  466. "penalize_newline true": {"penalize_newline", "true"},
  467. "stop ### User:": {"stop", "### User:"},
  468. "stop ### User: ": {"stop", "### User:"},
  469. "stop \"### User:\"": {"stop", "### User:"},
  470. "stop \"### User: \"": {"stop", "### User: "},
  471. "stop \"\"\"### User:\"\"\"": {"stop", "### User:"},
  472. "stop \"\"\"### User:\n\"\"\"": {"stop", "### User:\n"},
  473. "stop <|endoftext|>": {"stop", "<|endoftext|>"},
  474. "stop <|eot_id|>": {"stop", "<|eot_id|>"},
  475. "stop </s>": {"stop", "</s>"},
  476. }
  477. for k, v := range cases {
  478. t.Run(k, func(t *testing.T) {
  479. var b bytes.Buffer
  480. fmt.Fprintln(&b, "FROM foo")
  481. fmt.Fprintln(&b, "PARAMETER", k)
  482. modelfile, err := ParseFile(&b)
  483. require.NoError(t, err)
  484. assert.Equal(t, []Command{
  485. {Name: "model", Args: "foo"},
  486. {Name: v.name, Args: v.value},
  487. }, modelfile.Commands)
  488. })
  489. }
  490. }
  491. func TestParseFileComments(t *testing.T) {
  492. cases := []struct {
  493. input string
  494. expected []Command
  495. }{
  496. {
  497. `
  498. # comment
  499. FROM foo
  500. `,
  501. []Command{
  502. {Name: "model", Args: "foo"},
  503. },
  504. },
  505. }
  506. for _, c := range cases {
  507. t.Run("", func(t *testing.T) {
  508. modelfile, err := ParseFile(strings.NewReader(c.input))
  509. require.NoError(t, err)
  510. assert.Equal(t, c.expected, modelfile.Commands)
  511. })
  512. }
  513. }
  514. func TestParseFileFormatParseFile(t *testing.T) {
  515. cases := []string{
  516. `
  517. FROM foo
  518. ADAPTER adapter1
  519. LICENSE MIT
  520. PARAMETER param1 value1
  521. PARAMETER param2 value2
  522. TEMPLATE template1
  523. MESSAGE system You are a file parser. Always parse things.
  524. MESSAGE user Hey there!
  525. MESSAGE assistant Hello, I want to parse all the things!
  526. `,
  527. `
  528. FROM foo
  529. ADAPTER adapter1
  530. LICENSE MIT
  531. PARAMETER param1 value1
  532. PARAMETER param2 value2
  533. TEMPLATE template1
  534. MESSAGE system """
  535. You are a store greeter. Always respond with "Hello!".
  536. """
  537. MESSAGE user Hey there!
  538. MESSAGE assistant Hello, I want to parse all the things!
  539. `,
  540. `
  541. FROM foo
  542. ADAPTER adapter1
  543. LICENSE """
  544. Very long and boring legal text.
  545. Blah blah blah.
  546. "Oh look, a quote!"
  547. """
  548. PARAMETER param1 value1
  549. PARAMETER param2 value2
  550. TEMPLATE template1
  551. MESSAGE system """
  552. You are a store greeter. Always respond with "Hello!".
  553. """
  554. MESSAGE user Hey there!
  555. MESSAGE assistant Hello, I want to parse all the things!
  556. `,
  557. `
  558. FROM foo
  559. SYSTEM ""
  560. `,
  561. }
  562. for _, c := range cases {
  563. t.Run("", func(t *testing.T) {
  564. modelfile, err := ParseFile(strings.NewReader(c))
  565. require.NoError(t, err)
  566. modelfile2, err := ParseFile(strings.NewReader(modelfile.String()))
  567. require.NoError(t, err)
  568. assert.Equal(t, modelfile, modelfile2)
  569. })
  570. }
  571. }
  572. func TestParseFileUTF16ParseFile(t *testing.T) {
  573. data := `FROM bob
  574. PARAMETER param1 1
  575. PARAMETER param2 4096
  576. SYSTEM You are a utf16 file.
  577. `
  578. expected := []Command{
  579. {Name: "model", Args: "bob"},
  580. {Name: "param1", Args: "1"},
  581. {Name: "param2", Args: "4096"},
  582. {Name: "system", Args: "You are a utf16 file."},
  583. }
  584. t.Run("le", func(t *testing.T) {
  585. var b bytes.Buffer
  586. require.NoError(t, binary.Write(&b, binary.LittleEndian, []byte{0xff, 0xfe}))
  587. require.NoError(t, binary.Write(&b, binary.LittleEndian, utf16.Encode([]rune(data))))
  588. actual, err := ParseFile(&b)
  589. require.NoError(t, err)
  590. assert.Equal(t, expected, actual.Commands)
  591. })
  592. t.Run("be", func(t *testing.T) {
  593. var b bytes.Buffer
  594. require.NoError(t, binary.Write(&b, binary.BigEndian, []byte{0xfe, 0xff}))
  595. require.NoError(t, binary.Write(&b, binary.BigEndian, utf16.Encode([]rune(data))))
  596. actual, err := ParseFile(&b)
  597. require.NoError(t, err)
  598. assert.Equal(t, expected, actual.Commands)
  599. })
  600. }
  601. func TestParseMultiByte(t *testing.T) {
  602. input := `FROM test
  603. SYSTEM 你好👋`
  604. expect := []Command{
  605. {Name: "model", Args: "test"},
  606. {Name: "system", Args: "你好👋"},
  607. }
  608. encodings := []encoding.Encoding{
  609. unicode.UTF8,
  610. unicode.UTF16(unicode.LittleEndian, unicode.UseBOM),
  611. unicode.UTF16(unicode.BigEndian, unicode.UseBOM),
  612. }
  613. for _, encoding := range encodings {
  614. t.Run(fmt.Sprintf("%s", encoding), func(t *testing.T) {
  615. s, err := encoding.NewEncoder().String(input)
  616. require.NoError(t, err)
  617. actual, err := ParseFile(strings.NewReader(s))
  618. require.NoError(t, err)
  619. assert.Equal(t, expect, actual.Commands)
  620. })
  621. }
  622. }
  623. func TestCreateRequest(t *testing.T) {
  624. cases := []struct {
  625. input string
  626. expected *api.CreateRequest
  627. }{
  628. {
  629. `FROM test`,
  630. &api.CreateRequest{From: "test"},
  631. },
  632. {
  633. `FROM test
  634. TEMPLATE some template
  635. `,
  636. &api.CreateRequest{
  637. From: "test",
  638. Template: "some template",
  639. },
  640. },
  641. {
  642. `FROM test
  643. LICENSE single license
  644. PARAMETER temperature 0.5
  645. MESSAGE user Hello
  646. `,
  647. &api.CreateRequest{
  648. From: "test",
  649. License: []string{"single license"},
  650. Parameters: map[string]any{"temperature": float32(0.5)},
  651. Messages: []api.Message{
  652. {Role: "user", Content: "Hello"},
  653. },
  654. },
  655. },
  656. {
  657. `FROM test
  658. PARAMETER temperature 0.5
  659. PARAMETER top_k 1
  660. SYSTEM You are a bot.
  661. LICENSE license1
  662. LICENSE license2
  663. MESSAGE user Hello there!
  664. MESSAGE assistant Hi! How are you?
  665. `,
  666. &api.CreateRequest{
  667. From: "test",
  668. License: []string{"license1", "license2"},
  669. System: "You are a bot.",
  670. Parameters: map[string]any{"temperature": float32(0.5), "top_k": int64(1)},
  671. Messages: []api.Message{
  672. {Role: "user", Content: "Hello there!"},
  673. {Role: "assistant", Content: "Hi! How are you?"},
  674. },
  675. },
  676. },
  677. }
  678. for _, c := range cases {
  679. s, err := unicode.UTF8.NewEncoder().String(c.input)
  680. if err != nil {
  681. t.Fatal(err)
  682. }
  683. p, err := ParseFile(strings.NewReader(s))
  684. if err != nil {
  685. t.Error(err)
  686. }
  687. actual, err := p.CreateRequest("")
  688. if err != nil {
  689. t.Error(err)
  690. }
  691. if diff := cmp.Diff(actual, c.expected); diff != "" {
  692. t.Errorf("mismatch (-got +want):\n%s", diff)
  693. }
  694. }
  695. }
  696. func getSHA256Digest(t *testing.T, r io.Reader) (string, int64) {
  697. t.Helper()
  698. h := sha256.New()
  699. n, err := io.Copy(h, r)
  700. if err != nil {
  701. t.Fatal(err)
  702. }
  703. return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
  704. }
  705. func createBinFile(t *testing.T, kv map[string]any, ti []ggml.Tensor) (string, string) {
  706. t.Helper()
  707. f, err := os.CreateTemp(t.TempDir(), "testbin.*.gguf")
  708. if err != nil {
  709. t.Fatal(err)
  710. }
  711. defer f.Close()
  712. if err := ggml.WriteGGUF(f, kv, ti); err != nil {
  713. t.Fatal(err)
  714. }
  715. // Calculate sha256 of file
  716. if _, err := f.Seek(0, 0); err != nil {
  717. t.Fatal(err)
  718. }
  719. digest, _ := getSHA256Digest(t, f)
  720. return f.Name(), digest
  721. }
  722. func TestCreateRequestFiles(t *testing.T) {
  723. n1, d1 := createBinFile(t, nil, nil)
  724. n2, d2 := createBinFile(t, map[string]any{"foo": "bar"}, nil)
  725. cases := []struct {
  726. input string
  727. expected *api.CreateRequest
  728. }{
  729. {
  730. fmt.Sprintf("FROM %s", n1),
  731. &api.CreateRequest{Files: map[string]string{n1: d1}},
  732. },
  733. {
  734. fmt.Sprintf("FROM %s\nFROM %s", n1, n2),
  735. &api.CreateRequest{Files: map[string]string{n1: d1, n2: d2}},
  736. },
  737. }
  738. for _, c := range cases {
  739. s, err := unicode.UTF8.NewEncoder().String(c.input)
  740. if err != nil {
  741. t.Fatal(err)
  742. }
  743. p, err := ParseFile(strings.NewReader(s))
  744. if err != nil {
  745. t.Error(err)
  746. }
  747. actual, err := p.CreateRequest("")
  748. if err != nil {
  749. t.Error(err)
  750. }
  751. if diff := cmp.Diff(actual, c.expected); diff != "" {
  752. t.Errorf("mismatch (-got +want):\n%s", diff)
  753. }
  754. }
  755. }