model.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. package gemma3
  2. import (
  3. "bytes"
  4. "image"
  5. "math"
  6. "slices"
  7. "github.com/ollama/ollama/kvcache"
  8. "github.com/ollama/ollama/ml"
  9. "github.com/ollama/ollama/ml/nn"
  10. "github.com/ollama/ollama/model"
  11. "github.com/ollama/ollama/model/input"
  12. )
  13. type Model struct {
  14. model.Base
  15. model.SentencePieceModel
  16. *VisionModel `gguf:"v,vision"`
  17. *TextModel
  18. *MultiModalProjector `gguf:"mm"`
  19. ImageProcessor
  20. }
  21. var _ model.MultimodalProcessor = (*Model)(nil)
  22. type MultiModalProjector struct {
  23. SoftEmbNorm *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
  24. InputProjection *nn.Linear `gguf:"mm_input_projection"`
  25. tokensPerImage int
  26. }
  27. func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor {
  28. l := visionOutputs.Dim(0)
  29. visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
  30. patchesPerImage := imageSize / patchSize
  31. visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l)
  32. kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage)))
  33. visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0)
  34. visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l)
  35. visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
  36. visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)
  37. // TODO: inputProjection must be transposed since they're incompatible with visionOutputs
  38. visionOutputs = p.InputProjection.Weight.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mulmat(ctx, visionOutputs)
  39. return visionOutputs
  40. }
  41. func New(c ml.Config) (model.Model, error) {
  42. m := Model{
  43. SentencePieceModel: model.NewSentencePieceModel(
  44. c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
  45. &model.Vocabulary{
  46. Values: c.Strings("tokenizer.ggml.tokens"),
  47. Scores: c.Floats("tokenizer.ggml.scores"),
  48. Types: c.Uints("tokenizer.ggml.token_type"),
  49. BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
  50. AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
  51. EOS: int32(1),
  52. AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
  53. EOT: int32(106),
  54. AddEOT: c.Bool("tokenizer.ggml.add_eot_token", false),
  55. },
  56. ),
  57. ImageProcessor: newImageProcessor(c),
  58. VisionModel: newVisionModel(c),
  59. TextModel: newTextModel(c),
  60. MultiModalProjector: &MultiModalProjector{
  61. tokensPerImage: int(c.Uint("mm_tokens_per_image", 256)),
  62. },
  63. }
  64. slidingWindowLen := int32(c.Uint("attention.sliding_window"))
  65. m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
  66. return &m, nil
  67. }
  68. func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
  69. if len(m.VisionModel.Layers) == 0 {
  70. return nil, model.ErrNoVisionModel
  71. }
  72. image, _, err := image.Decode(bytes.NewReader(multimodalData))
  73. if err != nil {
  74. return nil, err
  75. }
  76. f32s, err := m.ImageProcessor.ProcessImage(image)
  77. if err != nil {
  78. return nil, err
  79. }
  80. pixelValues, err := ctx.Input().FromFloatSlice(f32s,
  81. m.ImageProcessor.imageSize,
  82. m.ImageProcessor.imageSize,
  83. m.ImageProcessor.numChannels,
  84. )
  85. if err != nil {
  86. return nil, err
  87. }
  88. visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
  89. visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
  90. return visionOutputs, nil
  91. }
  92. func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
  93. var result []input.Input
  94. for _, inp := range inputs {
  95. if inp.Multimodal == nil {
  96. result = append(result, inp)
  97. } else {
  98. inputMultimodal := inp.Multimodal.(ml.Tensor)
  99. result = append(result,
  100. input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
  101. input.Input{Token: 255999}, // "<start_of_image>""
  102. input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
  103. )
  104. // add image token placeholders
  105. result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
  106. result = append(result,
  107. input.Input{Token: 256000}, // <end_of_image>
  108. input.Input{Token: 108}, // "\n\n"
  109. )
  110. }
  111. }
  112. return result, nil
  113. }
  114. func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
  115. inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
  116. if err != nil {
  117. return nil, err
  118. }
  119. positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
  120. if err != nil {
  121. return nil, err
  122. }
  123. outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
  124. if err != nil {
  125. return nil, err
  126. }
  127. return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
  128. }
  129. func init() {
  130. model.Register("gemma3", New)
  131. }