|
@@ -46,13 +46,11 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
|
|
return nil, model.ErrNoVisionModel
|
|
|
}
|
|
|
|
|
|
- // Decode image
|
|
|
image, _, err := image.Decode(bytes.NewReader(multimodalData))
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
- // Process image
|
|
|
f32s, err := m.ImageProcessor.ProcessImage(image)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
@@ -100,38 +98,18 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|
|
return result, nil
|
|
|
}
|
|
|
|
|
|
-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
|
|
- inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
|
|
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|
|
+ positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
- positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
|
|
|
+ outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
- outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
-
|
|
|
- // Handle multimodal inputs
|
|
|
- // var except []int
|
|
|
- // hiddenState := m.TextModel.TokenEmbedding.Forward(ctx, inputs)
|
|
|
-
|
|
|
- // for _, image := range opts.Multimodal {
|
|
|
- // visionOutputs := image.Multimodal.(ml.Tensor)
|
|
|
-
|
|
|
- // // Copy vision outputs into the hidden state
|
|
|
- // ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
|
|
-
|
|
|
- // for i := range visionOutputs.Dim(1) {
|
|
|
- // except = append(except, image.Index+i)
|
|
|
- // }
|
|
|
- // }
|
|
|
-
|
|
|
- return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
|
|
|
+ return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
|
|
}
|
|
|
|
|
|
func init() {
|