imageproc_test.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. package pixtral
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "image"
  6. "image/png"
  7. "math"
  8. "os"
  9. "testing"
  10. "github.com/google/go-cmp/cmp"
  11. )
  12. func TestGetNumImageTokens(t *testing.T) {
  13. type numImageTokensCase struct {
  14. ImageSize image.Point
  15. PatchSize image.Point
  16. Expected image.Point
  17. }
  18. cases := []numImageTokensCase{
  19. {
  20. ImageSize: image.Point{1024, 764},
  21. PatchSize: image.Point{16, 16},
  22. Expected: image.Point{64, 48},
  23. },
  24. {
  25. ImageSize: image.Point{800, 600},
  26. PatchSize: image.Point{16, 16},
  27. Expected: image.Point{50, 38},
  28. },
  29. {
  30. ImageSize: image.Point{640, 480},
  31. PatchSize: image.Point{16, 16},
  32. Expected: image.Point{40, 30},
  33. },
  34. {
  35. ImageSize: image.Point{320, 200},
  36. PatchSize: image.Point{16, 16},
  37. Expected: image.Point{20, 13},
  38. },
  39. {
  40. ImageSize: image.Point{1320, 200},
  41. PatchSize: image.Point{16, 16},
  42. Expected: image.Point{83, 13},
  43. },
  44. {
  45. ImageSize: image.Point{2000, 200},
  46. PatchSize: image.Point{16, 16},
  47. Expected: image.Point{125, 13},
  48. },
  49. {
  50. ImageSize: image.Point{10000, 200},
  51. PatchSize: image.Point{16, 16},
  52. Expected: image.Point{625, 13},
  53. },
  54. {
  55. ImageSize: image.Point{1131, 577},
  56. PatchSize: image.Point{16, 16},
  57. Expected: image.Point{71, 37},
  58. },
  59. {
  60. ImageSize: image.Point{16, 16},
  61. PatchSize: image.Point{16, 16},
  62. Expected: image.Point{1, 1},
  63. },
  64. }
  65. for _, c := range cases {
  66. actual := getNumImageTokens(c.ImageSize, c.PatchSize)
  67. if diff := cmp.Diff(actual, c.Expected); diff != "" {
  68. t.Errorf("mismatch (-got +want):\n%s", diff)
  69. }
  70. }
  71. }
  72. func TestGetResizeOutputImageSize(t *testing.T) {
  73. type resizeCase struct {
  74. Image image.Image
  75. LongestEdge int
  76. PatchSize image.Point
  77. Expected image.Point
  78. }
  79. cases := []resizeCase{
  80. {
  81. Image: image.NewRGBA(image.Rect(0, 0, 1024, 768)),
  82. LongestEdge: 1024,
  83. PatchSize: image.Point{16, 16},
  84. Expected: image.Point{1024, 768},
  85. },
  86. {
  87. Image: image.NewRGBA(image.Rect(0, 0, 1162, 690)),
  88. LongestEdge: 1024,
  89. PatchSize: image.Point{16, 16},
  90. Expected: image.Point{1024, 624},
  91. },
  92. {
  93. Image: image.NewRGBA(image.Rect(0, 0, 300, 200)),
  94. LongestEdge: 1024,
  95. PatchSize: image.Point{16, 16},
  96. Expected: image.Point{304, 208},
  97. },
  98. {
  99. Image: image.NewRGBA(image.Rect(0, 0, 1862, 522)),
  100. LongestEdge: 1024,
  101. PatchSize: image.Point{16, 16},
  102. Expected: image.Point{1024, 288},
  103. },
  104. }
  105. for _, c := range cases {
  106. actual := getResizeOutputImageSize(c.Image, c.LongestEdge, c.PatchSize)
  107. if diff := cmp.Diff(actual, c.Expected); diff != "" {
  108. t.Errorf("mismatch (-got +want):\n%s", diff)
  109. }
  110. }
  111. }
  112. func TestResize(t *testing.T) {
  113. type resizeCase struct {
  114. Image image.Image
  115. LongestEdge int
  116. PatchSize image.Point
  117. Expected image.Image
  118. }
  119. cases := []resizeCase{
  120. {
  121. Image: image.NewRGBA(image.Rect(0, 0, 1862, 522)),
  122. LongestEdge: 1024,
  123. PatchSize: image.Point{16, 16},
  124. Expected: image.NewRGBA(image.Rect(0, 0, 1024, 288)),
  125. },
  126. {
  127. Image: image.NewRGBA(image.Rect(0, 0, 10, 10)),
  128. LongestEdge: 1024,
  129. PatchSize: image.Point{16, 16},
  130. Expected: image.NewRGBA(image.Rect(0, 0, 16, 16)),
  131. },
  132. }
  133. for _, c := range cases {
  134. actual := resizeImage(c.Image, "png", c.LongestEdge, c.PatchSize)
  135. if actual.Bounds() != c.Expected.Bounds() {
  136. t.Errorf("image size incorrect: '%#v': expected: '%#v'", actual.Bounds(), c.Expected.Bounds())
  137. }
  138. }
  139. }
  140. func TestPreprocess(t *testing.T) {
  141. type preprocessCase struct {
  142. TestImage image.Image
  143. ExpectedLen int
  144. }
  145. cases := []preprocessCase{
  146. {
  147. TestImage: image.NewRGBA(image.Rect(0, 0, 10, 10)),
  148. ExpectedLen: 16 * 16 * 3 * 1,
  149. },
  150. {
  151. TestImage: image.NewRGBA(image.Rect(0, 0, 2000, 2000)),
  152. ExpectedLen: 1024 * 1024 * 3 * 1,
  153. },
  154. }
  155. for _, c := range cases {
  156. var buf bytes.Buffer
  157. err := png.Encode(&buf, c.TestImage)
  158. if err != nil {
  159. t.Fatal(err)
  160. }
  161. imgData, _, err := Preprocess(&buf)
  162. if err != nil {
  163. t.Fatalf("error processing: %q", err)
  164. }
  165. switch len(imgData) {
  166. case 0:
  167. t.Errorf("no image data returned")
  168. case c.ExpectedLen:
  169. // ok
  170. default:
  171. t.Errorf("unexpected image data length: %d, expected: %d", len(imgData), c.ExpectedLen)
  172. }
  173. }
  174. }
  175. func TestPreprocessImages(t *testing.T) {
  176. for _, testFile := range []string{"flight.png", "sportsball.png"} {
  177. f, err := os.Open(testFile)
  178. if err != nil {
  179. t.Skipf("skipping test, no test image found at %s", testFile)
  180. }
  181. defer f.Close()
  182. imgData, _, err := Preprocess(f)
  183. if err != nil {
  184. t.Fatalf("error processing: %q", err)
  185. }
  186. byteData := make([]byte, len(imgData)*4) // float32 is 4 bytes
  187. for i, f := range imgData {
  188. binary.LittleEndian.PutUint32(byteData[i*4:], math.Float32bits(f))
  189. }
  190. outputPath := "processed_" + testFile + ".bin"
  191. err = os.WriteFile(outputPath, byteData, 0o644)
  192. if err != nil {
  193. t.Fatalf("error writing processed image: %q", err)
  194. }
  195. }
  196. }