model.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. package mllama
  2. import (
  3. "fmt"
  4. "github.com/ollama/ollama/kvcache"
  5. "github.com/ollama/ollama/ml"
  6. "github.com/ollama/ollama/ml/nn"
  7. "github.com/ollama/ollama/model"
  8. )
  9. type Model struct {
  10. model.Base
  11. model.BytePairEncoding
  12. *VisionModel `gguf:"v,vision"`
  13. *TextModel
  14. Projector *nn.Linear `gguf:"mm.0"`
  15. ImageProcessor
  16. }
  17. const (
  18. crossAttentionLayer = iota
  19. selfAttentionLayer
  20. )
  21. func New(c ml.Config) (model.Model, error) {
  22. // Verify unified config
  23. if c.Uint("vision.block_count") == 0 {
  24. return nil, fmt.Errorf("non-unified vision model not supported")
  25. }
  26. m := Model{
  27. BytePairEncoding: model.NewBytePairEncoding(
  28. c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
  29. &model.Vocabulary{
  30. Values: c.Strings("tokenizer.ggml.tokens"),
  31. Types: c.Uints("tokenizer.ggml.token_type"),
  32. Merges: c.Strings("tokenizer.ggml.merges"),
  33. BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
  34. AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
  35. EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
  36. AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
  37. },
  38. ),
  39. ImageProcessor: newImageProcessor(c),
  40. VisionModel: newVisionModel(c),
  41. TextModel: newTextModel(c),
  42. }
  43. encoderCache := kvcache.NewEncoderCache()
  44. encoderCache.SetConfig(ml.CacheConfig{})
  45. m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
  46. return &m, nil
  47. }
  48. func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
  49. var crossAttentionStates ml.Tensor
  50. if opts.Images != nil {
  51. f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(opts.Images[0])
  52. if err != nil {
  53. return nil, err
  54. }
  55. pixelValues, err := ctx.FromFloatSlice(f32s,
  56. m.ImageProcessor.imageSize,
  57. m.ImageProcessor.imageSize,
  58. m.ImageProcessor.numChannels,
  59. m.ImageProcessor.maxNumTiles,
  60. )
  61. if err != nil {
  62. return nil, err
  63. }
  64. aspectRatio, err := ctx.FromIntSlice([]int32{int32(aspectRatioID)}, 1)
  65. if err != nil {
  66. return nil, err
  67. }
  68. positions := make([]int32, 1601)
  69. for i := range positions {
  70. positions[i] = int32(i)
  71. }
  72. positionIDs, err := ctx.FromIntSlice(positions, len(positions))
  73. if err != nil {
  74. return nil, err
  75. }
  76. crossAttentionStates = m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
  77. crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates)
  78. }
  79. inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
  80. if err != nil {
  81. return nil, err
  82. }
  83. positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
  84. if err != nil {
  85. return nil, err
  86. }
  87. outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
  88. if err != nil {
  89. return nil, err
  90. }
  91. // TODO: attention mask, cross attention mask
  92. return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
  93. }
  94. func init() {
  95. model.Register("mllama", New)
  96. }