model.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. package mllama
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "fmt"
  6. "hash/fnv"
  7. "image"
  8. "slices"
  9. "github.com/ollama/ollama/kvcache"
  10. "github.com/ollama/ollama/ml"
  11. "github.com/ollama/ollama/ml/nn"
  12. "github.com/ollama/ollama/model"
  13. "github.com/ollama/ollama/model/input"
  14. )
  // Model is the "mllama" multimodal model registered by this package. It
  // embeds a byte-pair tokenizer, a vision model, a text model and an image
  // processor, plus a linear Projector that EncodeMultimodal applies to the
  // vision model's output.
  //
  // NOTE(review): the gguf struct tags appear to name the tensor prefixes each
  // field is loaded from — confirm against the loader in the model package.
  // Field order may matter to that loader, so do not reorder casually.
  15. type Model struct {
  // NOTE(review): presumably supplies the Cache field that New assigns and
  // Forward reads — confirm in the model package.
  16. model.Base
  // Tokenizer; configured in New from the tokenizer.ggml.* config keys.
  17. model.BytePairEncoding
  18. *VisionModel `gguf:"v,vision"`
  19. *TextModel
  // Applied to the vision model's output in EncodeMultimodal.
  20. Projector *nn.Linear `gguf:"mm.0"`
  // Turns a decoded image into pixel values plus an aspect-ratio ID.
  21. ImageProcessor
  22. }
  // Layer-kind identifiers. NOTE(review): by their names these presumably
  // distinguish cross-attention from self-attention decoder layers; their
  // consumers are outside this file — confirm before changing the values.
  const (
  	crossAttentionLayer = 0
  	selfAttentionLayer  = 1
  )
  27. func New(c ml.Config) (model.Model, error) {
  28. // Verify unified config
  29. if c.Uint("vision.block_count") == 0 {
  30. return nil, fmt.Errorf("non-unified vision model not supported")
  31. }
  32. m := Model{
  33. BytePairEncoding: model.NewBytePairEncoding(
  34. c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
  35. &model.Vocabulary{
  36. Values: c.Strings("tokenizer.ggml.tokens"),
  37. Types: c.Uints("tokenizer.ggml.token_type"),
  38. Merges: c.Strings("tokenizer.ggml.merges"),
  39. BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
  40. AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
  41. EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
  42. AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
  43. },
  44. ),
  45. ImageProcessor: newImageProcessor(c),
  46. VisionModel: newVisionModel(c),
  47. TextModel: newTextModel(c),
  48. }
  49. encoderCache := kvcache.NewEncoderCache()
  50. encoderCache.SetConfig(ml.CacheConfig{})
  51. m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
  52. return &m, nil
  53. }
  54. func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
  55. if len(m.VisionModel.Transformer.Layers) == 0 || len(m.GlobalTransformer.Layers) == 0 {
  56. return nil, model.ErrNoVisionModel
  57. }
  58. image, _, err := image.Decode(bytes.NewReader(multimodalData))
  59. if err != nil {
  60. return nil, err
  61. }
  62. f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(image)
  63. if err != nil {
  64. return nil, err
  65. }
  66. pixelValues, err := ctx.Input().FromFloatSlice(f32s,
  67. m.ImageProcessor.imageSize,
  68. m.ImageProcessor.imageSize,
  69. m.ImageProcessor.numChannels,
  70. m.ImageProcessor.maxNumTiles,
  71. )
  72. if err != nil {
  73. return nil, err
  74. }
  75. aspectRatio, err := ctx.Input().FromIntSlice([]int32{int32(aspectRatioID)}, 1)
  76. if err != nil {
  77. return nil, err
  78. }
  79. positions := make([]int32, 1601)
  80. for i := range positions {
  81. positions[i] = int32(i)
  82. }
  83. positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
  84. if err != nil {
  85. return nil, err
  86. }
  87. crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
  88. return m.Projector.Forward(ctx, crossAttentionStates), nil
  89. }
  90. func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
  91. var images []input.Input
  92. fnvHash := fnv.New64a()
  93. for i := range inputs {
  94. if inputs[i].Multimodal == nil {
  95. if len(images) > 0 {
  96. inputs[i].Multimodal = []ml.Tensor{images[0].Multimodal.(ml.Tensor)}
  97. inputs[i].MultimodalHash = images[0].MultimodalHash
  98. for j := 1; j < len(images); j++ {
  99. inputs[i].Multimodal = append(inputs[i].Multimodal.([]ml.Tensor), images[0].Multimodal.(ml.Tensor))
  100. fnvHash.Reset()
  101. binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
  102. binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
  103. inputs[i].MultimodalHash = fnvHash.Sum64()
  104. }
  105. images = nil
  106. }
  107. } else {
  108. images = append(images, inputs[i])
  109. inputs[i].Token = -1
  110. }
  111. }
  112. inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 })
  113. return inputs, nil
  114. }
  115. func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
  116. var crossAttentionStates ml.Tensor
  117. if len(opts.Multimodal) > 0 {
  118. images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor)
  119. if len(images) > 0 {
  120. crossAttentionStates = images[len(images)-1]
  121. }
  122. }
  123. inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
  124. if err != nil {
  125. return nil, err
  126. }
  127. positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
  128. if err != nil {
  129. return nil, err
  130. }
  131. outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
  132. if err != nil {
  133. return nil, err
  134. }
  135. // TODO: attention mask, cross attention mask
  136. return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
  137. }
  138. func init() {
  139. model.Register("mllama", New)
  140. }