parser_test.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833
  1. package parser
  2. import (
  3. "bytes"
  4. "crypto/sha256"
  5. "encoding/binary"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "os"
  10. "strings"
  11. "testing"
  12. "unicode/utf16"
  13. "github.com/google/go-cmp/cmp"
  14. "github.com/stretchr/testify/assert"
  15. "github.com/stretchr/testify/require"
  16. "golang.org/x/text/encoding"
  17. "golang.org/x/text/encoding/unicode"
  18. "github.com/ollama/ollama/api"
  19. "github.com/ollama/ollama/llm"
  20. )
  21. func TestParseFileFile(t *testing.T) {
  22. input := `
  23. FROM model1
  24. ADAPTER adapter1
  25. LICENSE MIT
  26. PARAMETER param1 value1
  27. PARAMETER param2 value2
  28. TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
  29. {{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
  30. {{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
  31. {{ .Response }}<|eot_id|>"""
  32. `
  33. reader := strings.NewReader(input)
  34. modelfile, err := ParseFile(reader)
  35. require.NoError(t, err)
  36. expectedCommands := []Command{
  37. {Name: "model", Args: "model1"},
  38. {Name: "adapter", Args: "adapter1"},
  39. {Name: "license", Args: "MIT"},
  40. {Name: "param1", Args: "value1"},
  41. {Name: "param2", Args: "value2"},
  42. {Name: "template", Args: "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>"},
  43. }
  44. assert.Equal(t, expectedCommands, modelfile.Commands)
  45. }
  46. func TestParseFileTrimSpace(t *testing.T) {
  47. input := `
  48. FROM " model 1"
  49. ADAPTER adapter3
  50. LICENSE "MIT "
  51. PARAMETER param1 value1
  52. PARAMETER param2 value2
  53. TEMPLATE """ {{ if .System }}<|start_header_id|>system<|end_header_id|>
  54. {{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
  55. {{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
  56. {{ .Response }}<|eot_id|> """
  57. `
  58. reader := strings.NewReader(input)
  59. modelfile, err := ParseFile(reader)
  60. require.NoError(t, err)
  61. expectedCommands := []Command{
  62. {Name: "model", Args: " model 1"},
  63. {Name: "adapter", Args: "adapter3"},
  64. {Name: "license", Args: "MIT "},
  65. {Name: "param1", Args: "value1"},
  66. {Name: "param2", Args: "value2"},
  67. {Name: "template", Args: " {{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|> "},
  68. }
  69. assert.Equal(t, expectedCommands, modelfile.Commands)
  70. }
  71. func TestParseFileFrom(t *testing.T) {
  72. cases := []struct {
  73. input string
  74. expected []Command
  75. err error
  76. }{
  77. {
  78. "FROM \"FOO BAR \"",
  79. []Command{{Name: "model", Args: "FOO BAR "}},
  80. nil,
  81. },
  82. {
  83. "FROM \"FOO BAR\"\nPARAMETER param1 value1",
  84. []Command{{Name: "model", Args: "FOO BAR"}, {Name: "param1", Args: "value1"}},
  85. nil,
  86. },
  87. {
  88. "FROM FOOO BAR ",
  89. []Command{{Name: "model", Args: "FOOO BAR"}},
  90. nil,
  91. },
  92. {
  93. "FROM /what/is/the path ",
  94. []Command{{Name: "model", Args: "/what/is/the path"}},
  95. nil,
  96. },
  97. {
  98. "FROM foo",
  99. []Command{{Name: "model", Args: "foo"}},
  100. nil,
  101. },
  102. {
  103. "FROM /path/to/model",
  104. []Command{{Name: "model", Args: "/path/to/model"}},
  105. nil,
  106. },
  107. {
  108. "FROM /path/to/model/fp16.bin",
  109. []Command{{Name: "model", Args: "/path/to/model/fp16.bin"}},
  110. nil,
  111. },
  112. {
  113. "FROM llama3:latest",
  114. []Command{{Name: "model", Args: "llama3:latest"}},
  115. nil,
  116. },
  117. {
  118. "FROM llama3:7b-instruct-q4_K_M",
  119. []Command{{Name: "model", Args: "llama3:7b-instruct-q4_K_M"}},
  120. nil,
  121. },
  122. {
  123. "", nil, errMissingFrom,
  124. },
  125. {
  126. "PARAMETER param1 value1",
  127. nil,
  128. errMissingFrom,
  129. },
  130. {
  131. "PARAMETER param1 value1\nFROM foo",
  132. []Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}},
  133. nil,
  134. },
  135. {
  136. "PARAMETER what the \nFROM lemons make lemonade ",
  137. []Command{{Name: "what", Args: "the"}, {Name: "model", Args: "lemons make lemonade"}},
  138. nil,
  139. },
  140. }
  141. for _, c := range cases {
  142. t.Run("", func(t *testing.T) {
  143. modelfile, err := ParseFile(strings.NewReader(c.input))
  144. require.ErrorIs(t, err, c.err)
  145. if modelfile != nil {
  146. assert.Equal(t, c.expected, modelfile.Commands)
  147. }
  148. })
  149. }
  150. }
  151. func TestParseFileParametersMissingValue(t *testing.T) {
  152. input := `
  153. FROM foo
  154. PARAMETER param1
  155. `
  156. reader := strings.NewReader(input)
  157. _, err := ParseFile(reader)
  158. require.ErrorIs(t, err, io.ErrUnexpectedEOF)
  159. }
  160. func TestParseFileBadCommand(t *testing.T) {
  161. input := `
  162. FROM foo
  163. BADCOMMAND param1 value1
  164. `
  165. parserError := &ParserError{
  166. LineNumber: 3,
  167. Msg: errInvalidCommand.Error(),
  168. }
  169. _, err := ParseFile(strings.NewReader(input))
  170. if !errors.As(err, &parserError) {
  171. t.Errorf("unexpected error: expected: %s, actual: %s", parserError.Error(), err.Error())
  172. }
  173. }
  174. func TestParseFileMessages(t *testing.T) {
  175. cases := []struct {
  176. input string
  177. expected []Command
  178. err error
  179. }{
  180. {
  181. `
  182. FROM foo
  183. MESSAGE system You are a file parser. Always parse things.
  184. `,
  185. []Command{
  186. {Name: "model", Args: "foo"},
  187. {Name: "message", Args: "system: You are a file parser. Always parse things."},
  188. },
  189. nil,
  190. },
  191. {
  192. `
  193. FROM foo
  194. MESSAGE system You are a file parser. Always parse things.`,
  195. []Command{
  196. {Name: "model", Args: "foo"},
  197. {Name: "message", Args: "system: You are a file parser. Always parse things."},
  198. },
  199. nil,
  200. },
  201. {
  202. `
  203. FROM foo
  204. MESSAGE system You are a file parser. Always parse things.
  205. MESSAGE user Hey there!
  206. MESSAGE assistant Hello, I want to parse all the things!
  207. `,
  208. []Command{
  209. {Name: "model", Args: "foo"},
  210. {Name: "message", Args: "system: You are a file parser. Always parse things."},
  211. {Name: "message", Args: "user: Hey there!"},
  212. {Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
  213. },
  214. nil,
  215. },
  216. {
  217. `
  218. FROM foo
  219. MESSAGE system """
  220. You are a multiline file parser. Always parse things.
  221. """
  222. `,
  223. []Command{
  224. {Name: "model", Args: "foo"},
  225. {Name: "message", Args: "system: \nYou are a multiline file parser. Always parse things.\n"},
  226. },
  227. nil,
  228. },
  229. {
  230. `
  231. FROM foo
  232. MESSAGE badguy I'm a bad guy!
  233. `,
  234. nil,
  235. &ParserError{
  236. LineNumber: 3,
  237. Msg: errInvalidMessageRole.Error(),
  238. },
  239. },
  240. {
  241. `
  242. FROM foo
  243. MESSAGE system
  244. `,
  245. nil,
  246. io.ErrUnexpectedEOF,
  247. },
  248. {
  249. `
  250. FROM foo
  251. MESSAGE system`,
  252. nil,
  253. io.ErrUnexpectedEOF,
  254. },
  255. }
  256. for _, tt := range cases {
  257. t.Run("", func(t *testing.T) {
  258. modelfile, err := ParseFile(strings.NewReader(tt.input))
  259. if modelfile != nil {
  260. assert.Equal(t, tt.expected, modelfile.Commands)
  261. }
  262. if tt.err == nil {
  263. if err != nil {
  264. t.Fatalf("expected no error, but got %v", err)
  265. }
  266. return
  267. }
  268. switch tt.err.(type) {
  269. case *ParserError:
  270. var pErr *ParserError
  271. if errors.As(err, &pErr) {
  272. // got the correct type of error
  273. return
  274. }
  275. }
  276. if errors.Is(err, tt.err) {
  277. return
  278. }
  279. t.Fatalf("unexpected error: expected: %v, actual: %v", tt.err, err)
  280. })
  281. }
  282. }
  283. func TestParseFileQuoted(t *testing.T) {
  284. cases := []struct {
  285. multiline string
  286. expected []Command
  287. err error
  288. }{
  289. {
  290. `
  291. FROM foo
  292. SYSTEM """
  293. This is a
  294. multiline system.
  295. """
  296. `,
  297. []Command{
  298. {Name: "model", Args: "foo"},
  299. {Name: "system", Args: "\nThis is a\nmultiline system.\n"},
  300. },
  301. nil,
  302. },
  303. {
  304. `
  305. FROM foo
  306. SYSTEM """
  307. This is a
  308. multiline system."""
  309. `,
  310. []Command{
  311. {Name: "model", Args: "foo"},
  312. {Name: "system", Args: "\nThis is a\nmultiline system."},
  313. },
  314. nil,
  315. },
  316. {
  317. `
  318. FROM foo
  319. SYSTEM """This is a
  320. multiline system."""
  321. `,
  322. []Command{
  323. {Name: "model", Args: "foo"},
  324. {Name: "system", Args: "This is a\nmultiline system."},
  325. },
  326. nil,
  327. },
  328. {
  329. `
  330. FROM foo
  331. SYSTEM """This is a multiline system."""
  332. `,
  333. []Command{
  334. {Name: "model", Args: "foo"},
  335. {Name: "system", Args: "This is a multiline system."},
  336. },
  337. nil,
  338. },
  339. {
  340. `
  341. FROM foo
  342. SYSTEM """This is a multiline system.""
  343. `,
  344. nil,
  345. io.ErrUnexpectedEOF,
  346. },
  347. {
  348. `
  349. FROM foo
  350. SYSTEM "
  351. `,
  352. nil,
  353. io.ErrUnexpectedEOF,
  354. },
  355. {
  356. `
  357. FROM foo
  358. SYSTEM """
  359. This is a multiline system with "quotes".
  360. """
  361. `,
  362. []Command{
  363. {Name: "model", Args: "foo"},
  364. {Name: "system", Args: "\nThis is a multiline system with \"quotes\".\n"},
  365. },
  366. nil,
  367. },
  368. {
  369. `
  370. FROM foo
  371. SYSTEM """"""
  372. `,
  373. []Command{
  374. {Name: "model", Args: "foo"},
  375. {Name: "system", Args: ""},
  376. },
  377. nil,
  378. },
  379. {
  380. `
  381. FROM foo
  382. SYSTEM ""
  383. `,
  384. []Command{
  385. {Name: "model", Args: "foo"},
  386. {Name: "system", Args: ""},
  387. },
  388. nil,
  389. },
  390. {
  391. `
  392. FROM foo
  393. SYSTEM "'"
  394. `,
  395. []Command{
  396. {Name: "model", Args: "foo"},
  397. {Name: "system", Args: "'"},
  398. },
  399. nil,
  400. },
  401. {
  402. `
  403. FROM foo
  404. SYSTEM """''"'""'""'"'''''""'""'"""
  405. `,
  406. []Command{
  407. {Name: "model", Args: "foo"},
  408. {Name: "system", Args: `''"'""'""'"'''''""'""'`},
  409. },
  410. nil,
  411. },
  412. {
  413. `
  414. FROM foo
  415. TEMPLATE """
  416. {{ .Prompt }}
  417. """`,
  418. []Command{
  419. {Name: "model", Args: "foo"},
  420. {Name: "template", Args: "\n{{ .Prompt }}\n"},
  421. },
  422. nil,
  423. },
  424. }
  425. for _, c := range cases {
  426. t.Run("", func(t *testing.T) {
  427. modelfile, err := ParseFile(strings.NewReader(c.multiline))
  428. require.ErrorIs(t, err, c.err)
  429. if modelfile != nil {
  430. assert.Equal(t, c.expected, modelfile.Commands)
  431. }
  432. })
  433. }
  434. }
  435. func TestParseFileParameters(t *testing.T) {
  436. cases := map[string]struct {
  437. name, value string
  438. }{
  439. "numa true": {"numa", "true"},
  440. "num_ctx 1": {"num_ctx", "1"},
  441. "num_batch 1": {"num_batch", "1"},
  442. "num_gqa 1": {"num_gqa", "1"},
  443. "num_gpu 1": {"num_gpu", "1"},
  444. "main_gpu 1": {"main_gpu", "1"},
  445. "low_vram true": {"low_vram", "true"},
  446. "logits_all true": {"logits_all", "true"},
  447. "vocab_only true": {"vocab_only", "true"},
  448. "use_mmap true": {"use_mmap", "true"},
  449. "use_mlock true": {"use_mlock", "true"},
  450. "num_thread 1": {"num_thread", "1"},
  451. "num_keep 1": {"num_keep", "1"},
  452. "seed 1": {"seed", "1"},
  453. "num_predict 1": {"num_predict", "1"},
  454. "top_k 1": {"top_k", "1"},
  455. "top_p 1.0": {"top_p", "1.0"},
  456. "min_p 0.05": {"min_p", "0.05"},
  457. "tfs_z 1.0": {"tfs_z", "1.0"},
  458. "typical_p 1.0": {"typical_p", "1.0"},
  459. "repeat_last_n 1": {"repeat_last_n", "1"},
  460. "temperature 1.0": {"temperature", "1.0"},
  461. "repeat_penalty 1.0": {"repeat_penalty", "1.0"},
  462. "presence_penalty 1.0": {"presence_penalty", "1.0"},
  463. "frequency_penalty 1.0": {"frequency_penalty", "1.0"},
  464. "mirostat 1": {"mirostat", "1"},
  465. "mirostat_tau 1.0": {"mirostat_tau", "1.0"},
  466. "mirostat_eta 1.0": {"mirostat_eta", "1.0"},
  467. "penalize_newline true": {"penalize_newline", "true"},
  468. "stop ### User:": {"stop", "### User:"},
  469. "stop ### User: ": {"stop", "### User:"},
  470. "stop \"### User:\"": {"stop", "### User:"},
  471. "stop \"### User: \"": {"stop", "### User: "},
  472. "stop \"\"\"### User:\"\"\"": {"stop", "### User:"},
  473. "stop \"\"\"### User:\n\"\"\"": {"stop", "### User:\n"},
  474. "stop <|endoftext|>": {"stop", "<|endoftext|>"},
  475. "stop <|eot_id|>": {"stop", "<|eot_id|>"},
  476. "stop </s>": {"stop", "</s>"},
  477. }
  478. for k, v := range cases {
  479. t.Run(k, func(t *testing.T) {
  480. var b bytes.Buffer
  481. fmt.Fprintln(&b, "FROM foo")
  482. fmt.Fprintln(&b, "PARAMETER", k)
  483. modelfile, err := ParseFile(&b)
  484. require.NoError(t, err)
  485. assert.Equal(t, []Command{
  486. {Name: "model", Args: "foo"},
  487. {Name: v.name, Args: v.value},
  488. }, modelfile.Commands)
  489. })
  490. }
  491. }
  492. func TestParseFileComments(t *testing.T) {
  493. cases := []struct {
  494. input string
  495. expected []Command
  496. }{
  497. {
  498. `
  499. # comment
  500. FROM foo
  501. `,
  502. []Command{
  503. {Name: "model", Args: "foo"},
  504. },
  505. },
  506. }
  507. for _, c := range cases {
  508. t.Run("", func(t *testing.T) {
  509. modelfile, err := ParseFile(strings.NewReader(c.input))
  510. require.NoError(t, err)
  511. assert.Equal(t, c.expected, modelfile.Commands)
  512. })
  513. }
  514. }
  515. func TestParseFileFormatParseFile(t *testing.T) {
  516. cases := []string{
  517. `
  518. FROM foo
  519. ADAPTER adapter1
  520. LICENSE MIT
  521. PARAMETER param1 value1
  522. PARAMETER param2 value2
  523. TEMPLATE template1
  524. MESSAGE system You are a file parser. Always parse things.
  525. MESSAGE user Hey there!
  526. MESSAGE assistant Hello, I want to parse all the things!
  527. `,
  528. `
  529. FROM foo
  530. ADAPTER adapter1
  531. LICENSE MIT
  532. PARAMETER param1 value1
  533. PARAMETER param2 value2
  534. TEMPLATE template1
  535. MESSAGE system """
  536. You are a store greeter. Always respond with "Hello!".
  537. """
  538. MESSAGE user Hey there!
  539. MESSAGE assistant Hello, I want to parse all the things!
  540. `,
  541. `
  542. FROM foo
  543. ADAPTER adapter1
  544. LICENSE """
  545. Very long and boring legal text.
  546. Blah blah blah.
  547. "Oh look, a quote!"
  548. """
  549. PARAMETER param1 value1
  550. PARAMETER param2 value2
  551. TEMPLATE template1
  552. MESSAGE system """
  553. You are a store greeter. Always respond with "Hello!".
  554. """
  555. MESSAGE user Hey there!
  556. MESSAGE assistant Hello, I want to parse all the things!
  557. `,
  558. `
  559. FROM foo
  560. SYSTEM ""
  561. `,
  562. }
  563. for _, c := range cases {
  564. t.Run("", func(t *testing.T) {
  565. modelfile, err := ParseFile(strings.NewReader(c))
  566. require.NoError(t, err)
  567. modelfile2, err := ParseFile(strings.NewReader(modelfile.String()))
  568. require.NoError(t, err)
  569. assert.Equal(t, modelfile, modelfile2)
  570. })
  571. }
  572. }
  573. func TestParseFileUTF16ParseFile(t *testing.T) {
  574. data := `FROM bob
  575. PARAMETER param1 1
  576. PARAMETER param2 4096
  577. SYSTEM You are a utf16 file.
  578. `
  579. expected := []Command{
  580. {Name: "model", Args: "bob"},
  581. {Name: "param1", Args: "1"},
  582. {Name: "param2", Args: "4096"},
  583. {Name: "system", Args: "You are a utf16 file."},
  584. }
  585. t.Run("le", func(t *testing.T) {
  586. var b bytes.Buffer
  587. require.NoError(t, binary.Write(&b, binary.LittleEndian, []byte{0xff, 0xfe}))
  588. require.NoError(t, binary.Write(&b, binary.LittleEndian, utf16.Encode([]rune(data))))
  589. actual, err := ParseFile(&b)
  590. require.NoError(t, err)
  591. assert.Equal(t, expected, actual.Commands)
  592. })
  593. t.Run("be", func(t *testing.T) {
  594. var b bytes.Buffer
  595. require.NoError(t, binary.Write(&b, binary.BigEndian, []byte{0xfe, 0xff}))
  596. require.NoError(t, binary.Write(&b, binary.BigEndian, utf16.Encode([]rune(data))))
  597. actual, err := ParseFile(&b)
  598. require.NoError(t, err)
  599. assert.Equal(t, expected, actual.Commands)
  600. })
  601. }
  602. func TestParseMultiByte(t *testing.T) {
  603. input := `FROM test
  604. SYSTEM 你好👋`
  605. expect := []Command{
  606. {Name: "model", Args: "test"},
  607. {Name: "system", Args: "你好👋"},
  608. }
  609. encodings := []encoding.Encoding{
  610. unicode.UTF8,
  611. unicode.UTF16(unicode.LittleEndian, unicode.UseBOM),
  612. unicode.UTF16(unicode.BigEndian, unicode.UseBOM),
  613. }
  614. for _, encoding := range encodings {
  615. t.Run(fmt.Sprintf("%s", encoding), func(t *testing.T) {
  616. s, err := encoding.NewEncoder().String(input)
  617. require.NoError(t, err)
  618. actual, err := ParseFile(strings.NewReader(s))
  619. require.NoError(t, err)
  620. assert.Equal(t, expect, actual.Commands)
  621. })
  622. }
  623. }
  624. func TestCreateRequest(t *testing.T) {
  625. cases := []struct {
  626. input string
  627. expected *api.CreateRequest
  628. }{
  629. {
  630. `FROM test`,
  631. &api.CreateRequest{From: "test"},
  632. },
  633. {
  634. `FROM test
  635. TEMPLATE some template
  636. `,
  637. &api.CreateRequest{
  638. From: "test",
  639. Template: "some template",
  640. },
  641. },
  642. {
  643. `FROM test
  644. LICENSE single license
  645. PARAMETER temperature 0.5
  646. MESSAGE user Hello
  647. `,
  648. &api.CreateRequest{
  649. From: "test",
  650. License: []string{"single license"},
  651. Parameters: map[string]any{"temperature": float32(0.5)},
  652. Messages: []api.Message{
  653. {Role: "user", Content: "Hello"},
  654. },
  655. },
  656. },
  657. {
  658. `FROM test
  659. PARAMETER temperature 0.5
  660. PARAMETER top_k 1
  661. SYSTEM You are a bot.
  662. LICENSE license1
  663. LICENSE license2
  664. MESSAGE user Hello there!
  665. MESSAGE assistant Hi! How are you?
  666. `,
  667. &api.CreateRequest{
  668. From: "test",
  669. License: []string{"license1", "license2"},
  670. System: "You are a bot.",
  671. Parameters: map[string]any{"temperature": float32(0.5), "top_k": int64(1)},
  672. Messages: []api.Message{
  673. {Role: "user", Content: "Hello there!"},
  674. {Role: "assistant", Content: "Hi! How are you?"},
  675. },
  676. },
  677. },
  678. }
  679. for _, c := range cases {
  680. s, err := unicode.UTF8.NewEncoder().String(c.input)
  681. if err != nil {
  682. t.Fatal(err)
  683. }
  684. p, err := ParseFile(strings.NewReader(s))
  685. if err != nil {
  686. t.Error(err)
  687. }
  688. actual, err := p.CreateRequest("")
  689. if err != nil {
  690. t.Error(err)
  691. }
  692. if diff := cmp.Diff(actual, c.expected); diff != "" {
  693. t.Errorf("mismatch (-got +want):\n%s", diff)
  694. }
  695. }
  696. }
  697. func getSHA256Digest(t *testing.T, r io.Reader) (string, int64) {
  698. t.Helper()
  699. h := sha256.New()
  700. n, err := io.Copy(h, r)
  701. if err != nil {
  702. t.Fatal(err)
  703. }
  704. return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
  705. }
  706. func createBinFile(t *testing.T, kv map[string]any, ti []llm.Tensor) (string, string) {
  707. t.Helper()
  708. f, err := os.CreateTemp(t.TempDir(), "testbin.*.gguf")
  709. if err != nil {
  710. t.Fatal(err)
  711. }
  712. defer f.Close()
  713. if err := llm.WriteGGUF(f, kv, ti); err != nil {
  714. t.Fatal(err)
  715. }
  716. // Calculate sha256 of file
  717. if _, err := f.Seek(0, 0); err != nil {
  718. t.Fatal(err)
  719. }
  720. digest, _ := getSHA256Digest(t, f)
  721. return f.Name(), digest
  722. }
  723. func TestCreateRequestFiles(t *testing.T) {
  724. n1, d1 := createBinFile(t, nil, nil)
  725. n2, d2 := createBinFile(t, map[string]any{"foo": "bar"}, nil)
  726. cases := []struct {
  727. input string
  728. expected *api.CreateRequest
  729. }{
  730. {
  731. fmt.Sprintf("FROM %s", n1),
  732. &api.CreateRequest{Files: map[string]string{n1: d1}},
  733. },
  734. {
  735. fmt.Sprintf("FROM %s\nFROM %s", n1, n2),
  736. &api.CreateRequest{Files: map[string]string{n1: d1, n2: d2}},
  737. },
  738. }
  739. for _, c := range cases {
  740. s, err := unicode.UTF8.NewEncoder().String(c.input)
  741. if err != nil {
  742. t.Fatal(err)
  743. }
  744. p, err := ParseFile(strings.NewReader(s))
  745. if err != nil {
  746. t.Error(err)
  747. }
  748. actual, err := p.CreateRequest("")
  749. if err != nil {
  750. t.Error(err)
  751. }
  752. if diff := cmp.Diff(actual, c.expected); diff != "" {
  753. t.Errorf("mismatch (-got +want):\n%s", diff)
  754. }
  755. }
  756. }