model_vision.go

package mllama

import (
	"math"
	"slices"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)
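
// batchSize is fixed at 1: the vision tower encodes a single image,
// split into tiles, per forward pass.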
var batchSize int64 = 1
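
// VisionSelfAttention is multi-head self-attention for the vision encoder.
// Weights are loaded from the gguf tensors named in the struct tags; Gate is
// optional and only present in some checkpoints.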
type VisionSelfAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_out"`

	Gate ml.Tensor `gguf:"attn_gate"`
}
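
// Forward computes scaled dot-product attention over hiddenState and applies
// the output projection, followed by the gate when present.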
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
	headDim := opts.hiddenSize / opts.numHeads

	// Project to Q/K/V and split the hidden dimension into attention heads.
	query := sa.Query.Forward(ctx, hiddenState)
	query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)

	key := sa.Key.Forward(ctx, hiddenState)
	key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)

	// Value is permuted differently from query/key so the attention-weighted
	// sum below is a single matrix multiplication.
	value := sa.Value.Forward(ctx, hiddenState)
	value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

	// softmax(QK^T / sqrt(headDim)) V
	scores := key.Mulmat(ctx, query)
	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
	scores = scores.Softmax(ctx)

	attention := value.Mulmat(ctx, scores)

	// Merge the heads back into a single hidden dimension.
	attention = attention.Reshape(ctx, headDim, attention.Dim(1), opts.numHeads, batchSize)
	attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)

	hiddenState = sa.Output.Forward(ctx, attention)
	if sa.Gate != nil {
		hiddenState = hiddenState.Mul(ctx, sa.Gate)
	}

	return hiddenState
}
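
// VisionMLP is the feed-forward block of a vision encoder layer.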
type VisionMLP struct {
	Down *nn.Linear `gguf:"ffn_down"`
	Up   *nn.Linear `gguf:"ffn_up"`

	Gate ml.Tensor `gguf:"ffn_gate"`
}
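
// Forward applies the two projections with a GELU activation in between,
// then the optional gate.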
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
	hiddenState = mlp.Down.Forward(ctx, hiddenState).GELU(ctx)
	hiddenState = mlp.Up.Forward(ctx, hiddenState)
	if mlp.Gate != nil {
		hiddenState = hiddenState.Mul(ctx, mlp.Gate)
	}

	return hiddenState
}
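
// VisionEncoderLayer is a pre-norm transformer layer: layer norm and
// self-attention, then layer norm and MLP, each with a residual connection.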
type VisionEncoderLayer struct {
	AttentionNorm *nn.LayerNorm `gguf:"ln1"`
	SelfAttention *VisionSelfAttention
	MLPNorm       *nn.LayerNorm `gguf:"ln2"`
	MLP           *VisionMLP
}
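
// Forward runs one encoder layer over hiddenState.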
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
	residual := hiddenState

	// self-attention
	hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	// feed-forward
	hiddenState = e.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = e.MLP.Forward(ctx, hiddenState, opts)

	return hiddenState.Add(ctx, residual)
}
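
// VisionEncoder is a stack of encoder layers.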
type VisionEncoder struct {
	Layers []VisionEncoderLayer
}
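
// Forward passes hiddenState through every layer and also returns the hidden
// states captured at the layers listed in intermediateLayersIndices.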
func (e *VisionEncoder) Forward(ctx ml.Context, hiddenState ml.Tensor, intermediateLayersIndices []uint32, opts *VisionModelOptions) (ml.Tensor, []ml.Tensor) {
	var intermediateHiddenStates []ml.Tensor
	for i, layer := range e.Layers {
		// Capture the input to each requested layer, inserting a unit
		// dimension so the captured states can be stacked later.
		if slices.Contains(intermediateLayersIndices, uint32(i)) {
			intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append([]int64{1}, hiddenState.Shape()...)...))
		}

		hiddenState = layer.Forward(ctx, hiddenState, opts)
	}

	return hiddenState, intermediateHiddenStates
}
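
// PrecomputedAspectRatioEmbedding adds a learned per-tile embedding, selected
// by aspect-ratio ID and optionally gated.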
type PrecomputedAspectRatioEmbedding struct {
	Embedding *nn.Embedding
	Gate      ml.Tensor `gguf:"gate"`
}
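
// Forward adds the (optionally gated) aspect-ratio embedding to hiddenState.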
func (e *PrecomputedAspectRatioEmbedding) Forward(ctx ml.Context, hiddenState ml.Tensor, aspectRatioIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
	embeddings := e.Embedding.Forward(ctx, aspectRatioIDs)
	embeddings = embeddings.Reshape(ctx, opts.hiddenSize, 1, opts.numTiles)
	if e.Gate != nil {
		embeddings = embeddings.Mul(ctx, e.Gate)
	}

	return hiddenState.Add(ctx, embeddings)
}
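
// PrecomputedPositionEmbedding combines gated position embeddings with gated
// tile-position embeddings.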
type PrecomputedPositionEmbedding struct {
	PositionEmbedding         *nn.Embedding `gguf:"position_embd"`
	PositionEmbeddingGate     ml.Tensor     `gguf:"position_embd.gate"`
	TilePositionEmbedding     *nn.Embedding `gguf:"tile_position_embd"`
	TilePositionEmbeddingGate ml.Tensor     `gguf:"tile_position_embd.gate"`
}
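
// Forward adds per-position and per-tile position embeddings, each scaled by
// its gate when one is present.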
func (e *PrecomputedPositionEmbedding) Forward(ctx ml.Context, hiddenState, positionIDs, aspectRatioIDs ml.Tensor, numPositions int64, opts *VisionModelOptions) ml.Tensor {
	positionEmbedding := e.PositionEmbedding.Forward(ctx, positionIDs)
	if e.PositionEmbeddingGate != nil {
		positionEmbedding = positionEmbedding.Mul(ctx, e.PositionEmbeddingGate)
	}

	hiddenState = hiddenState.Add(ctx, positionEmbedding)

	tilePositionEmbedding := e.TilePositionEmbedding.Forward(ctx, aspectRatioIDs)
	tilePositionEmbedding = tilePositionEmbedding.Reshape(ctx, opts.hiddenSize, numPositions, opts.numTiles)
	if e.TilePositionEmbeddingGate != nil {
		tilePositionEmbedding = tilePositionEmbedding.Mul(ctx, e.TilePositionEmbeddingGate)
	}

	return hiddenState.Add(ctx, tilePositionEmbedding)
}
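
// VisionModelOptions holds the hyperparameters shared by the vision submodules.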
type VisionModelOptions struct {
	hiddenSize, numHeads, numTiles int64
	imageSize, patchSize           int
	eps                            float32

	intermediateLayersIndices []uint32
}
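
// VisionModel is the mllama vision tower: patch embedding, tile and position
// embeddings, a local transformer, and a global transformer.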
type VisionModel struct {
	PatchEmbeddings *nn.Conv2D `gguf:"patch_embd"`

	PreTilePositionEmbedding  *PrecomputedAspectRatioEmbedding `gguf:"pre_tile_position_embd"`
	PostTilePositionEmbedding *PrecomputedAspectRatioEmbedding `gguf:"post_tile_position_embd"`
	PositionEmbedding         *PrecomputedPositionEmbedding

	PreLayerNorm   *nn.LayerNorm `gguf:"pre_ln"`
	PostLayerNorm  *nn.LayerNorm `gguf:"post_ln"`
	ClassEmbedding ml.Tensor     `gguf:"class_embd"`

	Transformer       *VisionEncoder `gguf:"blk"`
	GlobalTransformer *VisionEncoder `gguf:"global.blk"`

	*VisionModelOptions
}
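
// Forward encodes pixelValues into image features, concatenating the global
// transformer output with the stacked intermediate hidden states along the
// feature dimension.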
func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRatioIDs ml.Tensor) ml.Tensor {
	numPatches := int64((m.imageSize / m.patchSize) * (m.imageSize / m.patchSize))
	numPositions := numPatches
	if m.ClassEmbedding != nil {
		numPositions++
	}

	// Convert each tile into a sequence of patch embeddings.
	hiddenState := m.PatchEmbeddings.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
	hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize, m.numTiles)
	hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)

	// Prepend one class embedding per tile, then add position embeddings.
	hiddenState = m.PreTilePositionEmbedding.Forward(ctx, hiddenState, aspectRatioIDs, m.VisionModelOptions)
	hiddenState = m.ClassEmbedding.Stack(ctx, 2, slices.Repeat([]ml.Tensor{m.ClassEmbedding}, int(m.numTiles)-1)...).Concat(ctx, hiddenState, 1)
	hiddenState = m.PositionEmbedding.Forward(ctx, hiddenState, positionIDs, aspectRatioIDs, numPositions, m.VisionModelOptions)
	hiddenState = m.PreLayerNorm.Forward(ctx, hiddenState, m.eps)

	// Pad the sequence length up to the next multiple of 8; the outer %8
	// avoids adding a full block when the length is already aligned.
	numPaddingPatches := (8 - hiddenState.Dim(1)%8) % 8
	hiddenState = hiddenState.Pad(ctx, 0, numPaddingPatches, 0, 0)

	// Flatten the tiles into one sequence for the local transformer.
	hiddenState = hiddenState.Reshape(ctx, hiddenState.Dim(0), hiddenState.Dim(1)*hiddenState.Dim(2), batchSize)
	hiddenState, intermediateHiddenStates := m.Transformer.Forward(ctx, hiddenState, m.intermediateLayersIndices, m.VisionModelOptions)

	hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)

	// Re-split into tiles for the post-tile embedding, then flatten again
	// for the global transformer.
	hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize)
	hiddenState = m.PostTilePositionEmbedding.Forward(ctx, hiddenState, aspectRatioIDs, m.VisionModelOptions)
	hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, m.numTiles*(numPositions+numPaddingPatches), batchSize)
	hiddenState, _ = m.GlobalTransformer.Forward(ctx, hiddenState, nil, m.VisionModelOptions)

	// Stack the captured intermediate states along the feature dimension and
	// strip the padding from both outputs.
	hiddenStates := intermediateHiddenStates[0].Stack(ctx, 0, intermediateHiddenStates[1:]...)
	hiddenStates = hiddenStates.Reshape(ctx, int64(len(intermediateHiddenStates))*m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize)
	hiddenStates = hiddenStates.Unpad(ctx, 0, numPaddingPatches, 0, 0)

	hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize)
	hiddenState = hiddenState.Unpad(ctx, 0, numPaddingPatches, 0, 0)

	return hiddenState.Concat(ctx, hiddenStates, 0)
}
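
// newVisionModel sizes the encoder stacks and hyperparameters from the
// model's gguf configuration.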
func newVisionModel(c ml.Config) *VisionModel {
	return &VisionModel{
		Transformer:       &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count"))},
		GlobalTransformer: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.global.block_count"))},

		VisionModelOptions: &VisionModelOptions{
			hiddenSize: int64(c.Uint("vision.embedding_length")),
			numHeads:   int64(c.Uint("vision.attention.head_count")),
			numTiles:   int64(c.Uint("vision.max_num_tiles")),

			imageSize: int(c.Uint("vision.image_size")),
			patchSize: int(c.Uint("vision.patch_size")),

			eps: c.Float("vision.attention.layer_norm_epsilon"),

			intermediateLayersIndices: c.Uints("vision.intermediate_layers_indices"),
		},
	}
}