model.go 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. package mllama
  2. import (
  3. "github.com/ollama/ollama/ml"
  4. "github.com/ollama/ollama/ml/nn"
  5. "github.com/ollama/ollama/model"
  6. )
  7. type Model struct {
  8. model.Base
  9. *VisionModel `gguf:"v,vision"`
  10. *TextModel
  11. Projector *nn.Linear `gguf:"mm.0"`
  12. ImageProcessor
  13. TextProcessor
  14. }
  15. func New(c ml.Config) (model.Model, error) {
  16. return &Model{
  17. ImageProcessor: newImageProcessor(c),
  18. VisionModel: newVisionModel(c),
  19. TextProcessor: newTextProcessor(c),
  20. TextModel: newTextModel(c),
  21. }, nil
  22. }
  23. func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
  24. var crossAttentionStates ml.Tensor
  25. if opts.Images != nil {
  26. f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(opts.Images[0])
  27. if err != nil {
  28. return nil, err
  29. }
  30. pixelValues, err := ctx.FromFloatSlice(f32s,
  31. m.ImageProcessor.imageSize,
  32. m.ImageProcessor.imageSize,
  33. m.ImageProcessor.numChannels,
  34. m.ImageProcessor.maxNumTiles,
  35. )
  36. if err != nil {
  37. return nil, err
  38. }
  39. aspectRatio, err := ctx.FromIntSlice([]int32{int32(aspectRatioID)}, 1)
  40. if err != nil {
  41. return nil, err
  42. }
  43. positions := make([]int32, 1601)
  44. for i := range positions {
  45. positions[i] = int32(i)
  46. }
  47. positionIDs, err := ctx.FromIntSlice(positions, len(positions))
  48. if err != nil {
  49. return nil, err
  50. }
  51. crossAttentionStates = m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
  52. crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates)
  53. }
  54. inputs, err := ctx.FromIntSlice(opts.Inputs(), len(opts.Inputs()))
  55. if err != nil {
  56. return nil, err
  57. }
  58. positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions()))
  59. if err != nil {
  60. return nil, err
  61. }
  62. // TODO: attention mask, cross attention mask
  63. hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, opts.Cache)
  64. outputs, err := ctx.FromIntSlice([]int32{int32(len(opts.Positions())) - 1}, 1)
  65. if err != nil {
  66. return nil, err
  67. }
  68. return hiddenState.Rows(ctx, outputs), nil
  69. }
  70. func init() {
  71. model.Register("mllama", New)
  72. }