process_text_test.go 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. package mllama
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "os"
  6. "path/filepath"
  7. "strconv"
  8. "testing"
  9. "github.com/google/go-cmp/cmp"
  10. "github.com/ollama/ollama/model"
  11. )
  12. func TestProcessText(t *testing.T) {
  13. ours, err := model.New(filepath.Join("testdata", "model.bin"))
  14. if errors.Is(err, os.ErrNotExist) {
  15. t.Skip("no model.bin")
  16. } else if err != nil {
  17. t.Fatal(err)
  18. }
  19. t.Run("decode", func(t *testing.T) {
  20. f, err := os.Open(filepath.Join("testdata", "theirs.json"))
  21. if errors.Is(err, os.ErrNotExist) {
  22. t.Skip("no theirs.json")
  23. } else if err != nil {
  24. t.Fatal(err)
  25. }
  26. defer f.Close()
  27. var theirs [][]byte
  28. if err := json.NewDecoder(f).Decode(&theirs); err != nil {
  29. t.Fatal(err)
  30. }
  31. for id := range theirs {
  32. ids := []int32{int32(id)}
  33. s, err := ours.(model.TextProcessor).Decode(ids)
  34. if err != nil {
  35. t.Fatal(err)
  36. }
  37. if diff := cmp.Diff(string(theirs[id]), s); diff != "" {
  38. t.Errorf("%d no match (-theirs +ours):\n%s", id, diff)
  39. }
  40. }
  41. })
  42. t.Run("encode", func(t *testing.T) {
  43. f, err := os.Open(filepath.Join("..", "testdata", "inputs.json"))
  44. if errors.Is(err, os.ErrNotExist) {
  45. t.Skip("no inputs.json")
  46. } else if err != nil {
  47. t.Fatal(err)
  48. }
  49. defer f.Close()
  50. var inputs []struct {
  51. Values []byte `json:"base64"`
  52. IDs []int32 `json:"ids"`
  53. }
  54. if err := json.NewDecoder(f).Decode(&inputs); err != nil {
  55. t.Fatal(err)
  56. }
  57. for i, input := range inputs {
  58. t.Run(strconv.Itoa(i), func(t *testing.T) {
  59. ids, err := ours.(model.TextProcessor).Encode(string(input.Values))
  60. if err != nil {
  61. t.Fatal(err)
  62. }
  63. if diff := cmp.Diff(input.IDs, ids); diff != "" {
  64. t.Errorf("%s: no match (-theirs +ours):\n%s", input.Values, diff)
  65. }
  66. })
  67. }
  68. })
  69. }