// model.go (2.2 KB)
  1. package mllama
  2. import (
  3. "sync"
  4. "github.com/ollama/ollama/cache"
  5. "github.com/ollama/ollama/ml"
  6. "github.com/ollama/ollama/ml/nn"
  7. "github.com/ollama/ollama/model"
  8. )
  9. type Model struct {
  10. model.Base
  11. *VisionModel `gguf:"v,vision"`
  12. *TextModel
  13. Projector *nn.Linear `gguf:"mm.0"`
  14. ImageProcessor
  15. TextProcessor
  16. start sync.Once
  17. tCache *cache.TensorCache
  18. }
  19. func New(c ml.Config) (model.Model, error) {
  20. return &Model{
  21. ImageProcessor: newImageProcessor(c),
  22. VisionModel: newVisionModel(c),
  23. TextProcessor: newTextProcessor(c),
  24. TextModel: newTextModel(c),
  25. }, nil
  26. }
  27. func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
  28. m.start.Do(func() {
  29. m.tCache = cache.NewTensorCache(m.Backend())
  30. })
  31. var crossAttentionStates ml.Tensor
  32. if opts.Images != nil {
  33. f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(opts.Images[0])
  34. if err != nil {
  35. return nil, err
  36. }
  37. pixelValues, err := ctx.FromFloatSlice(f32s,
  38. m.ImageProcessor.imageSize,
  39. m.ImageProcessor.imageSize,
  40. m.ImageProcessor.numChannels,
  41. m.ImageProcessor.maxNumTiles,
  42. )
  43. if err != nil {
  44. return nil, err
  45. }
  46. aspectRatio, err := ctx.FromIntSlice([]int32{int32(aspectRatioID)}, 1)
  47. if err != nil {
  48. return nil, err
  49. }
  50. positions := make([]int32, 1601)
  51. for i := range positions {
  52. positions[i] = int32(i)
  53. }
  54. positionIDs, err := ctx.FromIntSlice(positions, len(positions))
  55. if err != nil {
  56. return nil, err
  57. }
  58. crossAttentionStates = m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
  59. crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates)
  60. }
  61. inputs, err := ctx.FromIntSlice(opts.Inputs(), len(opts.Inputs()))
  62. if err != nil {
  63. return nil, err
  64. }
  65. positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions()))
  66. if err != nil {
  67. return nil, err
  68. }
  69. // TODO: attention mask, cross attention mask
  70. hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, opts.Cache, m.tCache)
  71. outputs, err := ctx.FromIntSlice(opts.Outputs(), len(opts.Outputs()))
  72. if err != nil {
  73. return nil, err
  74. }
  75. return hiddenState.Rows(ctx, outputs), nil
  76. }
  77. func init() {
  78. model.Register("mllama", New)
  79. }