process_text_test.go 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. package mllama
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "os"
  6. "path/filepath"
  7. "strconv"
  8. "testing"
  9. "github.com/google/go-cmp/cmp"
  10. "github.com/google/go-cmp/cmp/cmpopts"
  11. "github.com/ollama/ollama/model"
  12. )
  13. func TestProcessText(t *testing.T) {
  14. ours, err := model.New(filepath.Join("testdata", "model.bin"))
  15. if errors.Is(err, os.ErrNotExist) {
  16. t.Skip("no model.bin")
  17. } else if err != nil {
  18. t.Fatal(err)
  19. }
  20. t.Run("decode", func(t *testing.T) {
  21. f, err := os.Open(filepath.Join("testdata", "theirs.json"))
  22. if errors.Is(err, os.ErrNotExist) {
  23. t.Skip("no theirs.json")
  24. } else if err != nil {
  25. t.Fatal(err)
  26. }
  27. defer f.Close()
  28. var theirs [][]byte
  29. if err := json.NewDecoder(f).Decode(&theirs); err != nil {
  30. t.Fatal(err)
  31. }
  32. for id := range theirs {
  33. ids := []int32{int32(id)}
  34. s, err := ours.(model.TextProcessor).Decode(ids)
  35. if err != nil {
  36. t.Fatal(err)
  37. }
  38. if diff := cmp.Diff(string(theirs[id]), s); diff != "" {
  39. t.Errorf("%d no match (-theirs +ours):\n%s", id, diff)
  40. }
  41. }
  42. })
  43. t.Run("encode", func(t *testing.T) {
  44. f, err := os.Open(filepath.Join("..", "testdata", "inputs.json"))
  45. if errors.Is(err, os.ErrNotExist) {
  46. t.Skip("no inputs.json")
  47. } else if err != nil {
  48. t.Fatal(err)
  49. }
  50. defer f.Close()
  51. var inputs []struct {
  52. Values []byte `json:"base64"`
  53. IDs []int32 `json:"ids"`
  54. }
  55. if err := json.NewDecoder(f).Decode(&inputs); err != nil {
  56. t.Fatal(err)
  57. }
  58. for i, input := range inputs {
  59. if i == 45 {
  60. t.Skip("skip 45")
  61. }
  62. t.Run(strconv.Itoa(i), func(t *testing.T) {
  63. ids, err := ours.(model.TextProcessor).Encode(string(input.Values))
  64. if err != nil {
  65. t.Fatal(err)
  66. }
  67. if diff := cmp.Diff(input.IDs, ids, cmpopts.EquateEmpty()); diff != "" {
  68. t.Errorf("%s: no match (-theirs +ours):\n%s", input.Values, diff)
  69. }
  70. })
  71. }
  72. })
  73. }