multimodal_proj.go 1.1 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. package mistral3
  2. import (
  3. "github.com/ollama/ollama/ml"
  4. "github.com/ollama/ollama/ml/nn"
  5. )
  6. type MultiModalProjector struct {
  7. Norm *nn.RMSNorm `gguf:"norm"`
  8. Projection *nn.Linear `gguf:"projection"`
  9. spatialMergeSize int
  10. imageTokenIndex int
  11. hasBias bool
  12. }
  13. func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
  14. // Apply normalization
  15. visionOutputs = p.Norm.Forward(ctx, visionOutputs, eps)
  16. // If the spatial merge size is > 1, average pool the patches
  17. if p.spatialMergeSize > 1 {
  18. // Implementation depends on how the model handles spatial merging
  19. // For simplicity, we'll use a spatial pooling approach
  20. visionOutputs = visionOutputs.AvgPool2D(ctx, p.spatialMergeSize, p.spatialMergeSize, 0)
  21. }
  22. // Project to text embedding dimension
  23. return p.Projection.Forward(ctx, visionOutputs)
  24. }
  25. func newMultiModalProjector(c ml.Config) *MultiModalProjector {
  26. return &MultiModalProjector{
  27. spatialMergeSize: int(c.Uint("spatial_merge_size", 2)),
  28. imageTokenIndex: int(c.Uint("image_token_index", 10)),
  29. hasBias: c.Bool("mm.projector_bias", false),
  30. }
  31. }