jmorganca 1 month ago
parent
commit
caddb1e4cf

+ 1 - 1
model/models/gemma3/process_image.go

@@ -51,7 +51,7 @@ func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
 func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
 	outputSize := image.Point{p.imageSize, p.imageSize}
 	newImage := imageproc.Composite(img)
-	newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBicubic)
+	newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
 
 	data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD)
 	return data, nil

+ 4 - 26
model/models/mistral3/model.go

@@ -46,13 +46,11 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 		return nil, model.ErrNoVisionModel
 	}
 
-	// Decode image
 	image, _, err := image.Decode(bytes.NewReader(multimodalData))
 	if err != nil {
 		return nil, err
 	}
 
-	// Process image
 	f32s, err := m.ImageProcessor.ProcessImage(image)
 	if err != nil {
 		return nil, err
@@ -100,38 +98,18 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	return result, nil
 }
 
-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
-	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
 	}
 
-	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 	if err != nil {
 		return nil, err
 	}
 
-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
-	if err != nil {
-		return nil, err
-	}
-
-	// Handle multimodal inputs
-	// var except []int
-	// hiddenState := m.TextModel.TokenEmbedding.Forward(ctx, inputs)
-
-	// for _, image := range opts.Multimodal {
-	// 	visionOutputs := image.Multimodal.(ml.Tensor)
-
-	// 	// Copy vision outputs into the hidden state
-	// 	ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
-
-	// 	for i := range visionOutputs.Dim(1) {
-	// 		except = append(except, image.Index+i)
-	// 	}
-	// }
-
-	return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
 }
 
 func init() {

+ 1 - 1
model/models/mistral3/model_text.go

@@ -116,7 +116,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 	return hiddenState.Add(ctx, residual)
 }
 
-func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
+func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
 	// Process text inputs
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)