model.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. package gemma3
  import (
	"bytes"
	"encoding/binary"
	"fmt"
	"hash/fnv"
	"image"
	"slices"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)
  14. type Model struct {
  15. model.Base
  16. model.SentencePieceModel
  17. *VisionModel `gguf:"v,vision"`
  18. *TextModel
  19. *MultiModalProjector `gguf:"mm"`
  20. ImageProcessor
  21. }
  22. var _ model.MultimodalProcessor = (*Model)(nil)
  23. type MultiModalProjector struct {
  24. SoftEmbNorm *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
  25. InputProjection *nn.Linear `gguf:"mm_input_projection"`
  26. }
  27. func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
  28. visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)
  29. // TODO: inputProjection must be transposed since they're incompatible with visionOutputs
  30. visionOutputs = p.InputProjection.Weight.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mulmat(ctx, visionOutputs)
  31. return visionOutputs
  32. }
  33. func New(c ml.Config) (model.Model, error) {
  34. m := Model{
  35. SentencePieceModel: model.NewSentencePieceModel(
  36. c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
  37. &model.Vocabulary{
  38. Values: c.Strings("tokenizer.ggml.tokens"),
  39. Scores: c.Floats("tokenizer.ggml.scores"),
  40. Types: c.Uints("tokenizer.ggml.token_type"),
  41. BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
  42. AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
  43. EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
  44. AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
  45. },
  46. ),
  47. ImageProcessor: newImageProcessor(c),
  48. VisionModel: newVisionModel(c),
  49. TextModel: newTextModel(c),
  50. }
  51. slidingWindowLen := int32(c.Uint("text.attention.sliding_window"))
  52. m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
  53. return &m, nil
  54. }
  55. func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
  56. image, _, err := image.Decode(bytes.NewReader(multimodalData))
  57. if err != nil {
  58. return nil, err
  59. }
  60. f32s, err := m.ImageProcessor.ProcessImage(image)
  61. if err != nil {
  62. return nil, err
  63. }
  64. pixelValues, err := ctx.Input().FromFloatSlice(f32s,
  65. m.ImageProcessor.imageSize,
  66. m.ImageProcessor.imageSize,
  67. m.ImageProcessor.numChannels,
  68. )
  69. if err != nil {
  70. return nil, err
  71. }
  72. positionIDs, err := ctx.FromIntSlice([]int32{0}, 1)
  73. if err != nil {
  74. return nil, err
  75. }
  76. visionOutputs := m.VisionModel.Forward(ctx, pixelValues, positionIDs)
  77. visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
  78. patchesPerImage := m.ImageProcessor.imageSize / m.ImageProcessor.patchSize
  79. kernelSize := patchesPerImage * patchesPerImage / 256
  80. visionOutputs = visionOutputs.AvgPool1D(ctx, kernelSize, kernelSize, 0)
  81. visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
  82. visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.VisionModel.eps)
  83. return visionOutputs, nil
  84. }
  85. func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
  86. var images []input.Input
  87. fnvHash := fnv.New64a()
  88. for i := range inputs {
  89. if inputs[i].Multimodal == nil {
  90. if len(images) > 0 {
  91. inputs[i].Multimodal = images[0].Multimodal
  92. inputs[i].MultimodalHash = images[0].MultimodalHash
  93. for j := 1; j < len(images); j++ {
  94. inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
  95. fnvHash.Reset()
  96. binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
  97. binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
  98. inputs[i].MultimodalHash = fnvHash.Sum64()
  99. }
  100. images = nil
  101. }
  102. } else {
  103. images = append(images, inputs[i])
  104. inputs[i].Token = -1
  105. }
  106. }
  107. inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 })
  108. return inputs, nil
  109. }
  110. func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
  111. var embeddings ml.Tensor
  112. if opts.Multimodal != nil {
  113. embeddings = opts.Multimodal[0].Multimodal.(ml.Tensor)
  114. }
  115. inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
  116. if err != nil {
  117. return nil, err
  118. }
  119. positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
  120. if err != nil {
  121. return nil, err
  122. }
  123. outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
  124. if err != nil {
  125. return nil, err
  126. }
  127. return m.TextModel.Forward(ctx, inputs, positions, embeddings, outputs, m.Cache), nil
  128. }
  129. func init() {
  130. model.Register("gemma3", New)
  131. }