
image processing

Bruce MacDonald 1 month ago
parent
commit
6f34126dcc

+ 1 - 1
model/models/pixtral/imageproc.go → model/models/mistral3/imageproc.go

@@ -1,4 +1,4 @@
-package pixtral
+package mistral3
 
 import (
 	"fmt"

+ 1 - 1
model/models/pixtral/imageproc_test.go → model/models/mistral3/imageproc_test.go

@@ -1,4 +1,4 @@
-package pixtral
+package mistral3
 
 import (
 	"bytes"

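The moved file keeps pixtral's output-size computation, which ProcessImage below relies on. As an illustrative re-derivation (not the verbatim moved code): the image is scaled down, preserving aspect ratio, until its longest edge fits the configured limit, then each dimension is rounded up to a whole number of patches:

package main

import (
	"fmt"
	"image"
	"math"
)

// outputSizeFor sketches what getResizeOutputImageSize computes: shrink the
// image (never enlarge) so its longest edge fits longestEdge, then round each
// dimension up to a multiple of the patch size so only whole patches remain.
func outputSizeFor(img image.Image, longestEdge int, patch image.Point) image.Point {
	b := img.Bounds().Max
	ratio := math.Max(float64(b.X)/float64(longestEdge), float64(b.Y)/float64(longestEdge))
	if ratio > 1.0 {
		b = image.Point{
			X: int(math.Ceil(float64(b.X) / ratio)),
			Y: int(math.Ceil(float64(b.Y) / ratio)),
		}
	}
	return image.Point{
		X: int(math.Ceil(float64(b.X)/float64(patch.X))) * patch.X,
		Y: int(math.Ceil(float64(b.Y)/float64(patch.Y))) * patch.Y,
	}
}

func main() {
	img := image.NewRGBA(image.Rect(0, 0, 3000, 1500))
	// 3000x1500 scales to 1024x512, already a multiple of 16x16 patches.
	fmt.Println(outputSizeFor(img, 1024, image.Point{X: 16, Y: 16}))
}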
+ 75 - 21
model/models/mistral3/model.go

@@ -1,9 +1,14 @@
 package mistral3
 
 import (
+	"image"
+	_ "image/jpeg"
+	_ "image/png"
+
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/imageproc"
 	"github.com/ollama/ollama/model/input"
 )
 
@@ -11,14 +16,46 @@ type Model struct {
 	model.Base
 	*TextModel
 
+	ImageProcessor
+
 	// TODO: Add VisionModel field
 	// *VisionModel `gguf:"v,vision"`
 
 	// TODO: Add MultiModalProjector field for combining vision and text features
 	// *MultiModalProjector `gguf:"mm"`
+}
+
+// ImageProcessor holds the configuration used to preprocess images for the vision model
+type ImageProcessor struct {
+	imageSize   int
+	patchSize   int
+	numChannels int
+	longestEdge int
+}
+
+// newImageProcessor reads the vision preprocessing settings from the model config
+func newImageProcessor(c ml.Config) ImageProcessor {
+	return ImageProcessor{
+		imageSize:   int(c.Uint("vision.image_size", 1024)),
+		patchSize:   int(c.Uint("vision.patch_size", 16)),
+		numChannels: int(c.Uint("vision.num_channels", 3)),
+		longestEdge: int(c.Uint("vision.longest_edge", 1024)),
+	}
+}
+
+// ProcessImage resizes and normalizes an image into a flat []float32 for the vision encoder
+func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
+	// Get output size based on longest edge and patch size
+	outputSize := getResizeOutputImageSize(img, p.longestEdge, image.Point{p.patchSize, p.patchSize})
+
+	// Resize the image
+	newImage := imageproc.Composite(img)
+	newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
+
+	// Normalize image data
+	data := imageproc.Normalize(newImage, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true)
 
-	// TODO: Add ImageProcessor field
-	// ImageProcessor
+	return data, nil
 }
 
 // TODO: Implement MultimodalProcessor interface
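A quick way to sanity-check the new preprocessing path is a shape test. This is a sketch written as if inside the mistral3 package; it assumes imageproc.Normalize emits one float32 per channel per pixel, which is what the length check relies on:

package mistral3

import (
	"image"
	"testing"
)

func TestProcessImageLength(t *testing.T) {
	p := ImageProcessor{
		imageSize:   1024,
		patchSize:   16,
		numChannels: 3,
		longestEdge: 1024,
	}

	// 512x256 is within the longest-edge limit and already a whole number
	// of 16x16 patches, so the output size should match the input size.
	img := image.NewRGBA(image.Rect(0, 0, 512, 256))

	data, err := p.ProcessImage(img)
	if err != nil {
		t.Fatal(err)
	}
	if want := 512 * 256 * 3; len(data) != want {
		t.Fatalf("len(data) = %d, want %d", len(data), want)
	}
}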
@@ -32,12 +69,12 @@ func New(c ml.Config) (model.Model, error) {
 
 	m := &Model{
 		TextModel: textModel,
+		// Initialize the ImageProcessor
+		ImageProcessor: newImageProcessor(c),
+
 		// TODO: Initialize VisionModel if present
 		// VisionModel: newVisionModel(c),
 
-		// TODO: Initialize ImageProcessor
-		// ImageProcessor: newImageProcessor(c),
-
 		// TODO: Initialize MultiModalProjector
 		// MultiModalProjector: &MultiModalProjector{...},
 	}
@@ -47,21 +84,38 @@ func New(c ml.Config) (model.Model, error) {
 	return m, nil
 }
 
-// TODO: Implement EncodeMultimodal method for processing images
-// func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
-//     // Check if vision model is available
-//     // Decode image
-//     // Process the image
-//     // Pass through vision model
-//     // Project vision outputs to text embedding space
-//     // Return vision embeddings
-// }
-
-// TODO: Implement PostTokenize method to handle vision tokens
-// func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
-//     // Add special tokens around image data
-//     // Insert placeholders for image tokens
-// }
+// EncodeMultimodal processes image bytes into embeddings; stubbed until the vision model is added
+func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
+	// No vision model yet, so reject image inputs
+	return nil, model.ErrNoVisionModel
+
+	// This will be implemented when adding the vision model:
+	/*
+		img, _, err := image.Decode(bytes.NewReader(multimodalData))
+		if err != nil {
+			return nil, err
+		}
+
+		f32s, err := m.ImageProcessor.ProcessImage(img)
+		if err != nil {
+			return nil, err
+		}
+
+		pixelValues, err := ctx.Input().FromFloatSlice(f32s,
+			m.ImageProcessor.imageSize,
+			m.ImageProcessor.imageSize,
+			m.ImageProcessor.numChannels,
+		)
+		if err != nil {
+			return nil, err
+		}
+
+		// Will need VisionModel to process this
+		// visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
+		// visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs)
+		// return visionOutputs, nil
+	*/
+}
 
 func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
@@ -79,7 +133,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	// TODO: Add handling of multimodal inputs
+	// TODO: Handle multimodal inputs once the vision model is added
 	// Set image embeddings into hidden state if present in opts.Multimodal
 
 	return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
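Two notes for when the vision model is wired in. First, the commented-out FromFloatSlice call above assumes an imageSize × imageSize buffer, but ProcessImage returns data at the aspect-preserving output size, so the two shapes will need to be reconciled. Second, as a back-of-envelope check of the quantities involved with the defaults from newImageProcessor (plain ViT patch arithmetic, not code from this commit):

package main

import "fmt"

func main() {
	const (
		longestEdge = 1024 // vision.longest_edge default
		patchSize   = 16   // vision.patch_size default
		numChannels = 3    // vision.num_channels default
	)

	// A maximally sized square input after preprocessing.
	width, height := longestEdge, longestEdge

	// Float32 values produced by ProcessImage for this image.
	fmt.Println(width * height * numChannels) // 3145728

	// Vision tokens the encoder would emit: one per 16x16 patch.
	fmt.Println((width / patchSize) * (height / patchSize)) // 4096
}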