Browse Source

image processing for llama3.2

Patrick Devine 7 months ago
parent
commit
f8ed545cbb
7 changed files with 579 additions and 7 deletions
  1. 1 0
      go.mod
  2. 2 0
      go.sum
  3. 4 2
      llm/server.go
  4. 238 0
      server/imageproc/images.go
  5. 305 0
      server/imageproc/images_test.go
  6. 28 4
      server/prompt.go
  7. 1 1
      server/prompt_test.go

+ 1 - 0
go.mod

@@ -22,6 +22,7 @@ require (
 	github.com/mattn/go-runewidth v0.0.14
 	github.com/nlpodyssey/gopickle v0.3.0
 	github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
+	golang.org/x/image v0.14.0
 )
 
 require (

+ 2 - 0
go.sum

@@ -230,6 +230,8 @@ golang.org/x/image v0.0.0-20200430140353-33d19683fad8/go.mod h1:FeLwcggjj3mMvU+o
 golang.org/x/image v0.0.0-20200618115811-c13761719519/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
 golang.org/x/image v0.0.0-20201208152932-35266b937fa6/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
 golang.org/x/image v0.0.0-20210216034530-4410531fe030/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
+golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
+golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
 golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
 golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
 golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=

+ 4 - 2
llm/server.go

@@ -673,8 +673,10 @@ ws ::= ([ \t\n] ws)?
 const maxBufferSize = 512 * format.KiloByte
 
 type ImageData struct {
-	Data []byte `json:"data"`
-	ID   int    `json:"id"`
+	Data          []byte    `json:"data"`
+	ID            int       `json:"id"`
+	ImageData     []float32 `json:"image_data"`
+	AspectRatioID int       `json:"aspect_ratio_id"`
 }
 
 type completion struct {

+ 238 - 0
server/imageproc/images.go

@@ -0,0 +1,238 @@
+package imageproc
+
+import (
+	"bytes"
+	"fmt"
+	"image"
+	_ "image/jpeg"
+	_ "image/png"
+	"math"
+
+	"golang.org/x/image/draw"
+)
+
+func GetSupportedAspectRatios(maxTiles int) []image.Point {
+	ratios := []image.Point{}
+
+	for w := range maxTiles {
+		for h := range maxTiles {
+			if (w+1)*(h+1) <= maxTiles {
+				ratios = append(ratios, image.Point{w + 1, h + 1})
+			}
+		}
+	}
+
+	return ratios
+}
+
+func clip(a, a_min, a_max int) int {
+	if a < a_min {
+		return a_min
+	} else if a > a_max {
+		return a_max
+	}
+
+	return a
+}
+
+func min(a, b int) int {
+	if a < b {
+		return a
+	}
+	return b
+}
+
+func GetImageSizeFitToCanvas(imageSize, canvasSize image.Point, tileSize int) image.Point {
+	targetWidth := clip(imageSize.X, tileSize, canvasSize.X)
+	targetHeight := clip(imageSize.Y, tileSize, canvasSize.Y)
+
+	scaleWidth := float64(targetWidth) / float64(imageSize.X)
+	scaleHeight := float64(targetHeight) / float64(imageSize.Y)
+
+	var w, h int
+
+	if scaleWidth < scaleHeight {
+		w = targetWidth
+		h = min(int(math.Floor(float64(imageSize.Y)*scaleWidth)), targetHeight)
+	} else {
+		w = min(int(math.Floor(float64(imageSize.X)*scaleHeight)), targetWidth)
+		h = targetHeight
+	}
+
+	return image.Point{w, h}
+}
+
+func GetOptimalTiledCanvas(imageSize image.Point, maxImageTiles, tileSize int) image.Point {
+	possibleTileArrangements := GetSupportedAspectRatios(maxImageTiles)
+	possibleCanvasSizes := []image.Point{}
+	for _, pta := range possibleTileArrangements {
+		possibleCanvasSizes = append(possibleCanvasSizes, image.Point{pta.X * tileSize, pta.Y * tileSize})
+	}
+
+	scales := []float64{}
+
+	for _, pcs := range possibleCanvasSizes {
+		scaleHeight := float64(pcs.Y) / float64(imageSize.Y)
+		scaleWidth := float64(pcs.X) / float64(imageSize.X)
+
+		if scaleWidth > scaleHeight {
+			scales = append(scales, scaleHeight)
+		} else {
+			scales = append(scales, scaleWidth)
+		}
+	}
+
+	var minUpscale float64
+	var maxDownscale float64
+	var upscale bool
+
+	for _, s := range scales {
+		if s > 1.0 {
+			upscale = true
+			if minUpscale == 0 {
+				minUpscale = s
+			} else {
+				minUpscale = math.Min(minUpscale, s)
+			}
+		} else {
+			maxDownscale = math.Max(maxDownscale, s)
+		}
+	}
+
+	selectedScale := maxDownscale
+	if upscale {
+		selectedScale = minUpscale
+	}
+
+	selectedCanvas := possibleCanvasSizes[0]
+	for n, pcs := range possibleCanvasSizes {
+		if scales[n] == selectedScale {
+			// choose the largest possible canvas
+			if pcs.X*pcs.Y > selectedCanvas.X*selectedCanvas.Y {
+				selectedCanvas = pcs
+			}
+		}
+	}
+	return selectedCanvas
+}
+
+func SplitToTiles(img image.Image, numTilesSize image.Point) []image.Image {
+	b := img.Bounds()
+	width := b.Max.X - b.Min.X
+	height := b.Max.Y - b.Min.Y
+	tileHeight := height / numTilesSize.Y
+	tileWidth := width / numTilesSize.X
+
+	images := []image.Image{}
+
+	for h := range numTilesSize.Y {
+		for w := range numTilesSize.X {
+			rect := image.Rect(tileWidth*w, tileHeight*h, tileWidth*(w+1), tileHeight*(h+1))
+			images = append(images, img.(interface {
+				SubImage(image.Rectangle) image.Image
+			}).SubImage(rect))
+		}
+	}
+
+	return images
+}
+
+func ResizeImage(img image.Image, outputSize image.Point, maxImageTiles int) (image.Image, image.Point) {
+	b := img.Bounds()
+	tileSize := outputSize.Y
+
+	canvasSize := GetOptimalTiledCanvas(b.Max, maxImageTiles, tileSize)
+	aspectRatio := image.Point{canvasSize.X / tileSize, canvasSize.Y / tileSize}
+
+	newSize := GetImageSizeFitToCanvas(b.Max, canvasSize, tileSize)
+
+	dst := image.NewRGBA(image.Rect(0, 0, newSize.X, newSize.Y))
+	draw.ApproxBiLinear.Scale(dst, dst.Rect, img, b, draw.Over, nil)
+
+	return dst, aspectRatio
+}
+
+func PadImage(img image.Image, outputSize, aspectRatio image.Point) image.Image {
+	paddedSize := image.Point{
+		X: outputSize.X * aspectRatio.X,
+		Y: outputSize.Y * aspectRatio.Y,
+	}
+
+	dst := image.NewRGBA(image.Rect(0, 0, paddedSize.X, paddedSize.Y))
+	centerX := (paddedSize.X - img.Bounds().Max.X) / 2
+	centerY := (paddedSize.Y - img.Bounds().Max.Y) / 2
+	pos := image.Rect(centerX, centerY, centerX+img.Bounds().Max.X, centerY+img.Bounds().Max.Y)
+
+	draw.Draw(dst, pos, img, image.Point{0, 0}, draw.Over)
+
+	return dst
+}
+
+func PackImages(img image.Image, aspectRatio image.Point, mean, std [3]float32) []float32 {
+	subImages := SplitToTiles(img, aspectRatio)
+
+	var pixelVals []float32
+
+	for _, subImg := range subImages {
+		bounds := subImg.Bounds()
+		rVals := []float32{}
+		gVals := []float32{}
+		bVals := []float32{}
+		for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
+			for x := bounds.Min.X; x < bounds.Max.X; x++ {
+				c := subImg.At(x, y)
+				r, g, b, _ := c.RGBA()
+				rVal := float32(r>>8) / 255.0
+				gVal := float32(g>>8) / 255.0
+				bVal := float32(b>>8) / 255.0
+
+				rVal = (rVal - mean[0]) / std[0]
+				gVal = (gVal - mean[1]) / std[1]
+				bVal = (bVal - mean[2]) / std[2]
+
+				rVals = append(rVals, rVal)
+				gVals = append(gVals, gVal)
+				bVals = append(bVals, bVal)
+			}
+		}
+		pixelVals = append(pixelVals, rVals...)
+		pixelVals = append(pixelVals, gVals...)
+		pixelVals = append(pixelVals, bVals...)
+	}
+
+	return pixelVals
+}
+
+func Preprocess(imageData []byte) ([]float32, int, error) {
+	// todo: need guard in here for bad image data
+
+	// mllama values
+	outputSize := image.Point{560, 560}
+	maxTiles := 4
+
+	// clip values
+	mean := [3]float32{0.48145466, 0.4578275, 0.40821073}
+	std := [3]float32{0.26862954, 0.26130258, 0.27577711}
+
+	img, _, err := image.Decode(bytes.NewReader(imageData))
+	if err != nil {
+		return nil, 0, fmt.Errorf("failed to decode image: %w", err)
+	}
+
+	newImage, aspectRatio := ResizeImage(img, outputSize, maxTiles)
+	newImage = PadImage(newImage, outputSize, aspectRatio)
+
+	// todo: need to scale (dim) by 1/256
+
+	data := PackImages(newImage, aspectRatio, mean, std)
+	supportedRatios := GetSupportedAspectRatios(maxTiles)
+	var aspectRatioIndex int
+	for n, r := range supportedRatios {
+		if r == aspectRatio {
+			aspectRatioIndex = n+1
+			break
+		}
+	}
+
+	return data, aspectRatioIndex, nil
+}

+ 305 - 0
server/imageproc/images_test.go

@@ -0,0 +1,305 @@
+package imageproc
+
+import (
+	"image"
+	"reflect"
+	"testing"
+)
+
+func testEq(a, b any) bool {
+	va := reflect.ValueOf(a)
+	vb := reflect.ValueOf(b)
+
+	if va.Kind() != reflect.Slice || vb.Kind() != reflect.Slice {
+		return false
+	}
+
+	if va.Len() != vb.Len() {
+		return false
+	}
+
+	for i := range va.Len() {
+		if !reflect.DeepEqual(va.Index(i).Interface(), vb.Index(i).Interface()) {
+			return false
+		}
+	}
+	return true
+}
+
+func TestAspectRatios(t *testing.T) {
+	type AspectCase struct {
+		MaxTiles int
+		Expected []image.Point
+	}
+
+	cases := []AspectCase{
+		{
+			MaxTiles: 1,
+			Expected: []image.Point{{1, 1}},
+		},
+		{
+			MaxTiles: 2,
+			Expected: []image.Point{{1, 1}, {1, 2}, {2, 1}},
+		},
+		{
+			MaxTiles: 3,
+			Expected: []image.Point{{1, 1}, {1, 2}, {1, 3}, {2, 1}, {3, 1}},
+		},
+		{
+			MaxTiles: 4,
+			Expected: []image.Point{{1, 1}, {1, 2}, {1, 3}, {1, 4}, {2, 1}, {2, 2}, {3, 1}, {4, 1}},
+		},
+	}
+
+	for _, c := range cases {
+		actual := GetSupportedAspectRatios(c.MaxTiles)
+
+		if !testEq(actual, c.Expected) {
+			t.Errorf("incorrect aspect ratio: '%#v'. expected: '%#v'", actual, c.Expected)
+		}
+	}
+}
+
+func TestGetImageSizeFitToCanvas(t *testing.T) {
+	type ImageSizeCase struct {
+		ImageRect  image.Point
+		CanvasRect image.Point
+		TileSize   int
+		Expected   image.Point
+	}
+
+	cases := []ImageSizeCase{
+		{
+			ImageRect:  image.Point{400, 400},
+			CanvasRect: image.Point{640, 480},
+			TileSize:   200,
+			Expected:   image.Point{400, 400},
+		},
+		{
+			ImageRect:  image.Point{1024, 768},
+			CanvasRect: image.Point{640, 480},
+			TileSize:   200,
+			Expected:   image.Point{640, 480},
+		},
+		{
+			ImageRect:  image.Point{500, 500},
+			CanvasRect: image.Point{1000, 1000},
+			TileSize:   750,
+			Expected:   image.Point{750, 750},
+		},
+		{
+			ImageRect:  image.Point{500, 1000},
+			CanvasRect: image.Point{2000, 2000},
+			TileSize:   2000,
+			Expected:   image.Point{1000, 2000},
+		},
+		{
+			ImageRect:  image.Point{4000, 3000},
+			CanvasRect: image.Point{2000, 1000},
+			TileSize:   1000,
+			Expected:   image.Point{1333, 1000},
+		},
+		{
+			ImageRect:  image.Point{667, 1000},
+			CanvasRect: image.Point{1000, 1000},
+			TileSize:   560,
+			Expected:   image.Point{667, 1000},
+		},
+	}
+
+	for _, c := range cases {
+		actual := GetImageSizeFitToCanvas(c.ImageRect, c.CanvasRect, c.TileSize)
+
+		if actual != c.Expected {
+			t.Errorf("incorrect image rect: '%#v'. expected: '%#v'", actual, c.Expected)
+		}
+	}
+}
+
+func TestGetOptimalTiledCanvas(t *testing.T) {
+	type TiledCanvasSizeCase struct {
+		ImageSize     image.Point
+		MaxImageTiles int
+		TileSize      int
+		Expected      image.Point
+	}
+
+	cases := []TiledCanvasSizeCase{
+		{
+			ImageSize:     image.Point{1024, 768},
+			MaxImageTiles: 4,
+			TileSize:      1000,
+			Expected:      image.Point{4000, 1000},
+		},
+		{
+			ImageSize:     image.Point{1024, 768},
+			MaxImageTiles: 4,
+			TileSize:      560,
+			Expected:      image.Point{1120, 1120},
+		},
+	}
+
+	for _, c := range cases {
+		actual := GetOptimalTiledCanvas(c.ImageSize, c.MaxImageTiles, c.TileSize)
+
+		if actual != c.Expected {
+			t.Errorf("incorrect tiled canvas: '%#v'. expected: '%#v'", actual, c.Expected)
+		}
+	}
+}
+
+func TestSplitToTiles(t *testing.T) {
+	type SplitCase struct {
+		TestImage    image.Image
+		NumTilesSize image.Point
+		Expected     []image.Image
+	}
+
+	cases := []SplitCase{
+		{
+			TestImage:    image.NewRGBA(image.Rect(0, 0, 1024, 768)),
+			NumTilesSize: image.Point{1, 1},
+			Expected:     []image.Image{image.NewRGBA(image.Rect(0, 0, 1024, 768))},
+		},
+		{
+			TestImage:    image.NewRGBA(image.Rect(0, 0, 1000, 500)),
+			NumTilesSize: image.Point{2, 1},
+			Expected: []image.Image{
+				image.NewRGBA(image.Rect(0, 0, 500, 500)),
+				image.NewRGBA(image.Rect(500, 0, 1000, 500)),
+			},
+		},
+		{
+			TestImage:    image.NewRGBA(image.Rect(0, 0, 1000, 1000)),
+			NumTilesSize: image.Point{2, 2},
+			Expected: []image.Image{
+				image.NewRGBA(image.Rect(0, 0, 500, 500)),
+				image.NewRGBA(image.Rect(500, 0, 1000, 500)),
+				image.NewRGBA(image.Rect(0, 500, 500, 1000)),
+				image.NewRGBA(image.Rect(500, 500, 1000, 1000)),
+			},
+		},
+	}
+
+	for _, c := range cases {
+		actual := SplitToTiles(c.TestImage, c.NumTilesSize)
+
+		if len(actual) != len(c.Expected) {
+			t.Errorf("incorrect number of images '%d': expected: '%d'", len(actual), len(c.Expected))
+		}
+
+		for i := range actual {
+			if actual[i].Bounds() != c.Expected[i].Bounds() {
+				t.Errorf("image size incorrect: '%#v': expected: '%#v'", actual[i].Bounds(), c.Expected[i].Bounds())
+			}
+		}
+	}
+}
+
+func TestResize(t *testing.T) {
+	type ResizeCase struct {
+		TestImage           image.Image
+		OutputSize          image.Point
+		MaxImageTiles       int
+		ExpectedImage       image.Image
+		ExpectedAspectRatio image.Point
+	}
+
+	cases := []ResizeCase{
+		{
+			TestImage:           image.NewRGBA(image.Rect(0, 0, 200, 200)),
+			OutputSize:          image.Point{100, 100},
+			MaxImageTiles:       1,
+			ExpectedImage:       image.NewRGBA(image.Rect(0, 0, 100, 100)),
+			ExpectedAspectRatio: image.Point{1, 1},
+		},
+		{
+			TestImage:           image.NewRGBA(image.Rect(0, 0, 200, 200)),
+			OutputSize:          image.Point{100, 100},
+			MaxImageTiles:       2,
+			ExpectedImage:       image.NewRGBA(image.Rect(0, 0, 100, 100)),
+			ExpectedAspectRatio: image.Point{1, 2},
+		},
+		{
+			TestImage:           image.NewRGBA(image.Rect(0, 0, 2560, 1920)),
+			OutputSize:          image.Point{560, 560},
+			MaxImageTiles:       4,
+			ExpectedImage:       image.NewRGBA(image.Rect(0, 0, 1120, 840)),
+			ExpectedAspectRatio: image.Point{2, 2},
+		},
+		{
+			TestImage:           image.NewRGBA(image.Rect(0, 0, 1024, 768)),
+			OutputSize:          image.Point{560, 560},
+			MaxImageTiles:       4,
+			ExpectedImage:       image.NewRGBA(image.Rect(0, 0, 1024, 768)),
+			ExpectedAspectRatio: image.Point{2, 2},
+		},
+	}
+
+	for _, c := range cases {
+		actualImage, actualAspectRatio := ResizeImage(c.TestImage, c.OutputSize, c.MaxImageTiles)
+
+		if actualImage.Bounds() != c.ExpectedImage.Bounds() {
+			t.Errorf("image size incorrect: '%#v': expected: '%#v'", actualImage.Bounds(), c.ExpectedImage.Bounds())
+		}
+
+		if actualAspectRatio != c.ExpectedAspectRatio {
+			t.Errorf("canvas size incorrect: '%#v': expected: '%#v'", actualAspectRatio, c.ExpectedAspectRatio)
+		}
+	}
+}
+
+func TestPad(t *testing.T) {
+	type PadCase struct {
+		TestImage   image.Image
+		OutputSize  image.Point
+		AspectRatio image.Point
+		Expected    image.Image
+	}
+
+	cases := []PadCase{
+		{
+			TestImage:   image.NewRGBA(image.Rect(0, 0, 1000, 667)),
+			OutputSize:  image.Point{560, 560},
+			AspectRatio: image.Point{2, 2},
+			Expected:    image.NewRGBA(image.Rect(0, 0, 1120, 1120)),
+		},
+	}
+
+	for _, c := range cases {
+		actual := PadImage(c.TestImage, c.OutputSize, c.AspectRatio)
+
+		if actual.Bounds() != c.Expected.Bounds() {
+			t.Errorf("image size incorrect: '%#v': expected: '%#v'", actual.Bounds(), c.Expected.Bounds())
+		}
+	}
+}
+
+func TestPackImages(t *testing.T) {
+	type PackCase struct {
+		TestImage   image.Image
+		AspectRatio image.Point
+	}
+
+	mean := [3]float32{0.48145466, 0.4578275, 0.40821073}
+	std := [3]float32{0.26862954, 0.26130258, 0.27577711}
+
+	cases := []PackCase{
+		{
+			TestImage:   image.NewRGBA(image.Rect(0, 0, 1120, 1120)),
+			AspectRatio: image.Point{2, 2},
+		},
+		{
+			TestImage:   image.NewRGBA(image.Rect(0, 0, 560, 560)),
+			AspectRatio: image.Point{1, 1},
+		},
+		{
+			TestImage:   image.NewRGBA(image.Rect(0, 0, 1120, 560)),
+			AspectRatio: image.Point{1, 2},
+		},
+	}
+
+	for _, c := range cases {
+		PackImages(c.TestImage, c.AspectRatio, mean, std)
+	}
+}

+ 28 - 4
server/prompt.go

@@ -7,6 +7,7 @@ import (
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/server/imageproc"
 	"github.com/ollama/ollama/template"
 )
 
@@ -61,14 +62,37 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 		return "", nil, err
 	}
 
+	preprocess := checkMllamaModelFamily(m)
+
 	for _, m := range msgs[n:] {
 		for _, i := range m.Images {
-			images = append(images, llm.ImageData{
-				ID:   len(images),
-				Data: i,
-			})
+			if preprocess {
+				data, aspectRatioID, err := imageproc.Preprocess(i)
+				if err != nil {
+					return "", nil, err
+				}
+				images = append(images, llm.ImageData{
+					ID:            len(images),
+					ImageData:     data,
+					AspectRatioID: aspectRatioID,
+				})
+			} else {
+				images = append(images, llm.ImageData{
+					ID:   len(images),
+					Data: i,
+				})
+			}
 		}
 	}
 
 	return b.String(), images, nil
 }
+
+func checkMllamaModelFamily(m *Model) bool {
+	for _, arch := range m.Config.ModelFamilies {
+		if arch == "mllama" {
+			return true
+		}
+	}
+	return false
+}

+ 1 - 1
server/prompt_test.go

@@ -203,7 +203,7 @@ func TestChatPrompt(t *testing.T) {
 				}
 
 				if !bytes.Equal(images[i].Data, tt.images[i]) {
-					t.Errorf("expected %q, got %q", tt.images[i], images[i])
+					t.Errorf("expected %q, got %q", tt.images[i], images[i].Data)
 				}
 			}
 		})