imageproc.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. package mllama
  2. import (
  3. "fmt"
  4. "image"
  5. _ "image/jpeg"
  6. _ "image/png"
  7. "io"
  8. "math"
  9. "slices"
  10. "golang.org/x/image/draw"
  11. "github.com/ollama/ollama/model/imageproc"
  12. )
  13. func getSupportedAspectRatios(maxTiles int) []image.Point {
  14. ratios := []image.Point{}
  15. for w := range maxTiles {
  16. for h := range maxTiles {
  17. if (w+1)*(h+1) <= maxTiles {
  18. ratios = append(ratios, image.Point{w + 1, h + 1})
  19. }
  20. }
  21. }
  22. return ratios
  23. }
  24. func clip(a, a_min, a_max int) int {
  25. if a < a_min {
  26. return a_min
  27. } else if a > a_max {
  28. return a_max
  29. }
  30. return a
  31. }
  32. func getOptimalTiledCanvas(imageSize image.Point, maxImageTiles, tileSize int) image.Point {
  33. possibleTileArrangements := getSupportedAspectRatios(maxImageTiles)
  34. possibleCanvasSizes := []image.Point{}
  35. for _, pta := range possibleTileArrangements {
  36. possibleCanvasSizes = append(possibleCanvasSizes, image.Point{pta.X * tileSize, pta.Y * tileSize})
  37. }
  38. scales := []float64{}
  39. for _, pcs := range possibleCanvasSizes {
  40. scaleHeight := float64(pcs.Y) / float64(imageSize.Y)
  41. scaleWidth := float64(pcs.X) / float64(imageSize.X)
  42. if scaleWidth > scaleHeight {
  43. scales = append(scales, scaleHeight)
  44. } else {
  45. scales = append(scales, scaleWidth)
  46. }
  47. }
  48. var minUpscale float64
  49. var maxDownscale float64
  50. var upscale bool
  51. for _, s := range scales {
  52. if s > 1.0 {
  53. upscale = true
  54. if minUpscale == 0 {
  55. minUpscale = s
  56. } else {
  57. minUpscale = math.Min(minUpscale, s)
  58. }
  59. } else {
  60. maxDownscale = math.Max(maxDownscale, s)
  61. }
  62. }
  63. selectedScale := maxDownscale
  64. if upscale {
  65. selectedScale = minUpscale
  66. }
  67. var selectedCanvas image.Point
  68. for n, pcs := range possibleCanvasSizes {
  69. if scales[n] == selectedScale {
  70. // choose the smallest possible canvas
  71. if selectedCanvas.X == 0 && selectedCanvas.Y == 0 {
  72. selectedCanvas = pcs
  73. } else if pcs.X*pcs.Y < selectedCanvas.X*selectedCanvas.Y {
  74. selectedCanvas = pcs
  75. }
  76. }
  77. }
  78. return selectedCanvas
  79. }
  80. func getImageSizeFitToCanvas(imageSize, canvasSize image.Point, tileSize int) image.Point {
  81. targetWidth := clip(imageSize.X, tileSize, canvasSize.X)
  82. targetHeight := clip(imageSize.Y, tileSize, canvasSize.Y)
  83. scaleWidth := float64(targetWidth) / float64(imageSize.X)
  84. scaleHeight := float64(targetHeight) / float64(imageSize.Y)
  85. var w, h int
  86. if scaleWidth < scaleHeight {
  87. w = targetWidth
  88. h = min(int(math.Floor(float64(imageSize.Y)*scaleWidth)), targetHeight)
  89. } else {
  90. w = min(int(math.Floor(float64(imageSize.X)*scaleHeight)), targetWidth)
  91. h = targetHeight
  92. }
  93. return image.Point{w, h}
  94. }
  95. func resizeImage(img image.Image, format string, outputSize image.Point, maxImageTiles int) (image.Image, image.Point) {
  96. if format == "png" {
  97. img = imageproc.Composite(img)
  98. }
  99. b := img.Bounds()
  100. tileSize := outputSize.Y
  101. canvasSize := getOptimalTiledCanvas(b.Max, maxImageTiles, tileSize)
  102. aspectRatio := image.Point{canvasSize.X / tileSize, canvasSize.Y / tileSize}
  103. newSize := getImageSizeFitToCanvas(b.Max, canvasSize, tileSize)
  104. return imageproc.Resize(img, newSize, imageproc.ResizeBilinear), aspectRatio
  105. }
  106. func padImage(img image.Image, outputSize, aspectRatio image.Point) image.Image {
  107. paddedSize := image.Point{
  108. X: outputSize.X * aspectRatio.X,
  109. Y: outputSize.Y * aspectRatio.Y,
  110. }
  111. dst := image.NewRGBA(image.Rect(0, 0, paddedSize.X, paddedSize.Y))
  112. draw.Draw(dst, img.Bounds(), img, image.Point{0, 0}, draw.Over)
  113. return dst
  114. }
  115. func splitToTiles(img image.Image, numTilesSize image.Point) []image.Image {
  116. b := img.Bounds()
  117. width := b.Max.X - b.Min.X
  118. height := b.Max.Y - b.Min.Y
  119. tileHeight := height / numTilesSize.Y
  120. tileWidth := width / numTilesSize.X
  121. images := []image.Image{}
  122. for h := range numTilesSize.Y {
  123. for w := range numTilesSize.X {
  124. rect := image.Rect(tileWidth*w, tileHeight*h, tileWidth*(w+1), tileHeight*(h+1))
  125. images = append(images, img.(interface {
  126. SubImage(image.Rectangle) image.Image
  127. }).SubImage(rect))
  128. }
  129. }
  130. return images
  131. }
  132. func packImages(img image.Image, aspectRatio image.Point) []float32 {
  133. subImages := splitToTiles(img, aspectRatio)
  134. var pixelVals []float32
  135. rescale := true
  136. channelFirst := true
  137. for _, subImg := range subImages {
  138. vals := imageproc.Normalize(subImg, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, rescale, channelFirst)
  139. pixelVals = append(pixelVals, vals...)
  140. }
  141. return pixelVals
  142. }
  143. func Preprocess(imageData io.Reader) ([]float32, map[string]any, error) {
  144. outputSize := image.Point{560, 560}
  145. maxTiles := 4
  146. img, format, err := image.Decode(imageData)
  147. if err != nil {
  148. return nil, nil, fmt.Errorf("failed to decode image: %w", err)
  149. }
  150. newImage, aspectRatio := resizeImage(img, format, outputSize, maxTiles)
  151. newImage = padImage(newImage, outputSize, aspectRatio)
  152. data := packImages(newImage, aspectRatio)
  153. aspectRatioIndex := slices.Index(getSupportedAspectRatios(maxTiles), aspectRatio) + 1
  154. opts := map[string]any{
  155. "aspectRatioIndex": aspectRatioIndex,
  156. }
  157. return data, opts, nil
  158. }