// model_vision.go
  1. package gemma3
  2. import (
  3. "math"
  4. "slices"
  5. "github.com/ollama/ollama/ml"
  6. "github.com/ollama/ollama/ml/nn"
  7. )
  8. var batchSize int = 1
  9. type VisionSelfAttention struct {
  10. Query *nn.Linear `gguf:"attn_q"`
  11. Key *nn.Linear `gguf:"attn_k"`
  12. Value *nn.Linear `gguf:"attn_v"`
  13. Output *nn.Linear `gguf:"attn_output"`
  14. }
  15. func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
  16. headDim := opts.hiddenSize / opts.numHeads
  17. query := sa.Query.Forward(ctx, hiddenState)
  18. key := sa.Key.Forward(ctx, hiddenState)
  19. value := sa.Value.Forward(ctx, hiddenState)
  20. query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize).Permute(ctx, 0, 2, 1, 3)
  21. key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize).Permute(ctx, 0, 2, 1, 3)
  22. value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize).Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
  23. scores := key.Mulmat(ctx, query)
  24. scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
  25. scores = scores.Softmax(ctx)
  26. attention := value.Mulmat(ctx, scores)
  27. attention = attention.Reshape(ctx, headDim, attention.Dim(1), opts.numHeads, batchSize)
  28. attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  29. attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
  30. hiddenState = sa.Output.Forward(ctx, attention)
  31. return hiddenState
  32. }
  33. type VisionMLP struct {
  34. FC1 *nn.Linear `gguf:"fc1"`
  35. FC2 *nn.Linear `gguf:"fc2"`
  36. }
  37. func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
  38. hiddenState = mlp.FC1.Forward(ctx, hiddenState).GELU(ctx)
  39. hiddenState = mlp.FC2.Forward(ctx, hiddenState)
  40. return hiddenState
  41. }
  42. type VisionEncoderLayer struct {
  43. LayerNorm1 *nn.LayerNorm `gguf:"layer_norm1"`
  44. SelfAttention *VisionSelfAttention
  45. LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`
  46. MLP *VisionMLP `gguf:"mlp"`
  47. }
  48. func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
  49. residual := hiddenState
  50. // self attention
  51. hiddenState = e.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
  52. hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
  53. hiddenState = hiddenState.Add(ctx, residual)
  54. residual = hiddenState
  55. // feed forward
  56. hiddenState = e.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
  57. hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
  58. return hiddenState.Add(ctx, residual)
  59. }
  60. type VisionEncoder struct {
  61. Layers []VisionEncoderLayer
  62. }
  63. func (e *VisionEncoder) Forward(ctx ml.Context, hiddenState ml.Tensor, intermediateLayersIndices []uint32, opts *VisionModelOptions) (ml.Tensor, []ml.Tensor) {
  64. var intermediateHiddenStates []ml.Tensor
  65. for i, layer := range e.Layers {
  66. if slices.Contains(intermediateLayersIndices, uint32(i)) {
  67. intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append([]int{1}, hiddenState.Shape()...)...))
  68. }
  69. hiddenState = layer.Forward(ctx, hiddenState, opts)
  70. }
  71. return hiddenState, intermediateHiddenStates
  72. }
  73. type PrecomputedAspectRatioEmbedding struct {
  74. Embedding *nn.Embedding
  75. Gate ml.Tensor `gguf:"gate"`
  76. }
  77. func (e *PrecomputedAspectRatioEmbedding) Forward(ctx ml.Context, hiddenState ml.Tensor, aspectRatioIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
  78. embeddings := e.Embedding.Forward(ctx, aspectRatioIDs)
  79. embeddings = embeddings.Reshape(ctx, opts.hiddenSize, 1, opts.numTiles)
  80. if e.Gate != nil {
  81. embeddings = embeddings.Mul(ctx, e.Gate)
  82. }
  83. return hiddenState.Add(ctx, embeddings)
  84. }
  85. type PrecomputedPositionEmbedding struct {
  86. PositionEmbedding *nn.Embedding `gguf:"position_embd"`
  87. PositionEmbeddingGate ml.Tensor `gguf:"position_embd.gate"`
  88. }
  89. func (e *PrecomputedPositionEmbedding) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, numPositions int, opts *VisionModelOptions) ml.Tensor {
  90. positionEmbedding := e.PositionEmbedding.Forward(ctx, positionIDs)
  91. if e.PositionEmbeddingGate != nil {
  92. positionEmbedding = positionEmbedding.Mul(ctx, e.PositionEmbeddingGate)
  93. }
  94. return hiddenState.Add(ctx, positionEmbedding)
  95. }
  96. type VisionModelOptions struct {
  97. hiddenSize, numHeads, numTiles int
  98. imageSize, patchSize int
  99. eps float32
  100. }
  101. type VisionModel struct {
  102. PatchEmbedding *nn.Conv2D `gguf:"patch_embedding"`
  103. PositionEmbedding *nn.Embedding `gguf:"position_embedding"`
  104. PostLayerNorm *nn.LayerNorm `gguf:"post_layernorm"`
  105. Encoder *VisionEncoder `gguf:"blk"`
  106. *VisionModelOptions
  107. }
  108. func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs ml.Tensor) ml.Tensor {
  109. numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)
  110. hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
  111. hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
  112. hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
  113. positions := m.PositionEmbedding.Forward(ctx, positionIDs)
  114. hiddenState = hiddenState.Add(ctx, positions)
  115. for _, layer := range m.Encoder.Layers {
  116. hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
  117. }
  118. hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
  119. return hiddenState
  120. }
  121. func newVisionModel(c ml.Config) *VisionModel {
  122. return &VisionModel{
  123. Encoder: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count"))},
  124. VisionModelOptions: &VisionModelOptions{
  125. hiddenSize: int(c.Uint("vision.embedding_length")),
  126. numHeads: int(c.Uint("vision.attention.head_count")),
  127. imageSize: int(c.Uint("vision.image_size")),
  128. patchSize: int(c.Uint("vision.patch_size")),
  129. eps: c.Float("vision.attention.layer_norm_epsilon"),
  130. },
  131. }
  132. }