model.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. package gemma3
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "hash/fnv"
  6. "image"
  7. "github.com/ollama/ollama/kvcache"
  8. "github.com/ollama/ollama/ml"
  9. "github.com/ollama/ollama/ml/nn"
  10. "github.com/ollama/ollama/model"
  11. "github.com/ollama/ollama/model/input"
  12. )
// Model is the multimodal Gemma 3 model: embedded vision and text
// submodels, a projector applied to the vision outputs before they are
// consumed by the text model, and an image preprocessor. Tensor fields
// are bound to GGUF tensors via the struct tags.
type Model struct {
	model.Base
	model.SentencePieceModel

	*VisionModel `gguf:"v,vision"`
	*TextModel
	*MultiModalProjector `gguf:"mm"`

	ImageProcessor
}

// Compile-time assertion that *Model implements model.MultimodalProcessor.
var _ model.MultimodalProcessor = (*Model)(nil)
// MultiModalProjector holds the weights used to map vision-tower outputs
// into inputs for the text model: an RMS normalization followed by a
// linear projection (see Forward).
type MultiModalProjector struct {
	SoftEmbNorm     *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
	InputProjection *nn.Linear  `gguf:"mm_input_projection"`
}
  26. func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
  27. visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)
  28. // TODO: inputProjection must be transposed since they're incompatible with visionOutputs
  29. visionOutputs = p.InputProjection.Weight.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mulmat(ctx, visionOutputs)
  30. return visionOutputs
  31. }
  32. func New(c ml.Config) (model.Model, error) {
  33. m := Model{
  34. SentencePieceModel: model.NewSentencePieceModel(
  35. c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
  36. &model.Vocabulary{
  37. Values: c.Strings("tokenizer.ggml.tokens"),
  38. Scores: c.Floats("tokenizer.ggml.scores"),
  39. Types: c.Uints("tokenizer.ggml.token_type"),
  40. BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
  41. AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
  42. EOS: int32(1),
  43. AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
  44. EOT: int32(106),
  45. AddEOT: c.Bool("tokenizer.ggml.add_eot_token", false),
  46. },
  47. ),
  48. ImageProcessor: newImageProcessor(c),
  49. VisionModel: newVisionModel(c),
  50. TextModel: newTextModel(c),
  51. }
  52. slidingWindowLen := int32(c.Uint("attention.sliding_window"))
  53. m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
  54. return &m, nil
  55. }
  56. func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
  57. image, _, err := image.Decode(bytes.NewReader(multimodalData))
  58. if err != nil {
  59. return nil, err
  60. }
  61. f32s, err := m.ImageProcessor.ProcessImage(image)
  62. if err != nil {
  63. return nil, err
  64. }
  65. pixelValues, err := ctx.Input().FromFloatSlice(f32s,
  66. m.ImageProcessor.imageSize,
  67. m.ImageProcessor.imageSize,
  68. m.ImageProcessor.numChannels,
  69. )
  70. if err != nil {
  71. return nil, err
  72. }
  73. visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
  74. visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
  75. patchesPerImage := m.ImageProcessor.imageSize / m.ImageProcessor.patchSize
  76. kernelSize := patchesPerImage * patchesPerImage / 256
  77. visionOutputs = visionOutputs.AvgPool1D(ctx, kernelSize, kernelSize, 0)
  78. visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
  79. visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.VisionModel.eps)
  80. return visionOutputs, nil
  81. }
// imageToken is one placeholder position produced for an image: the full
// image embedding tensor plus the index (into dimension 1 of that
// tensor) that this token represents.
type imageToken struct {
	embedding ml.Tensor
	index     int
}
  86. func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
  87. var result []input.Input
  88. fnvHash := fnv.New64a()
  89. for _, inp := range inputs {
  90. if inp.Multimodal == nil {
  91. result = append(result, inp)
  92. } else {
  93. imageInputs := []input.Input{
  94. {Token: 108}, // "\n\n"
  95. {Token: 255999}, // "<start_of_image>""
  96. }
  97. result = append(result, imageInputs...)
  98. // add image embeddings
  99. inputMultimodal := inp.Multimodal.(ml.Tensor)
  100. for i := range inputMultimodal.Dim(1) {
  101. fnvHash.Reset()
  102. binary.Write(fnvHash, binary.NativeEndian, inp.MultimodalHash)
  103. fnvHash.Write([]byte{byte(i)})
  104. imageToken := imageToken{embedding: inputMultimodal, index: i}
  105. result = append(result, input.Input{Multimodal: imageToken, MultimodalHash: fnvHash.Sum64()})
  106. }
  107. // <end_of_image>
  108. result = append(result, input.Input{Token: 256000})
  109. }
  110. }
  111. return result, nil
  112. }
  113. func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
  114. inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
  115. if err != nil {
  116. return nil, err
  117. }
  118. positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
  119. if err != nil {
  120. return nil, err
  121. }
  122. outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
  123. if err != nil {
  124. return nil, err
  125. }
  126. return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
  127. }
// Register the "gemma3" architecture with the model registry at package
// load time so it can be constructed by name.
func init() {
	model.Register("gemma3", New)
}