model_vision.go

package gemma3

import (
	"math"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)
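// batchSize is fixed at 1: the vision encoder processes a single image per
// forward pass.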
var batchSize int = 1
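// VisionSelfAttention is the multi-head self-attention block of a vision
// encoder layer. The gguf tags map each projection to its tensor name in the
// model file.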
type VisionSelfAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}
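// Forward projects hiddenState to queries, keys, and values, splits them into
// numHeads heads of headDim elements each, applies scaled dot-product
// attention, and projects the result back to the hidden size.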
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
	headDim := opts.hiddenSize / opts.numHeads

	query := sa.Query.Forward(ctx, hiddenState)
	key := sa.Key.Forward(ctx, hiddenState)
	value := sa.Value.Forward(ctx, hiddenState)

	// Split the hidden dimension into (headDim, numHeads) for multi-head attention.
	query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
	key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
	value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)

	// Scaled dot-product attention with scale 1/sqrt(headDim) and no mask.
	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)

	// Merge the heads back into a single hidden dimension.
	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)

	hiddenState = sa.Output.Forward(ctx, attention)
	return hiddenState
}
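// VisionMLP is the two-layer feed-forward block of a vision encoder layer.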
type VisionMLP struct {
	FC1 *nn.Linear `gguf:"fc1"`
	FC2 *nn.Linear `gguf:"fc2"`
}
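// Forward applies the up projection with a GELU activation, then the down
// projection.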
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
	hiddenState = mlp.FC1.Forward(ctx, hiddenState).GELU(ctx)
	hiddenState = mlp.FC2.Forward(ctx, hiddenState)
	return hiddenState
}
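// VisionEncoderLayer is a single pre-norm transformer encoder layer:
// self-attention and a feed-forward block, each preceded by layer
// normalization and wrapped in a residual connection.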
type VisionEncoderLayer struct {
	LayerNorm1    *nn.LayerNorm `gguf:"layer_norm1"`
	SelfAttention *VisionSelfAttention
	LayerNorm2    *nn.LayerNorm `gguf:"layer_norm2"`
	MLP           *VisionMLP    `gguf:"mlp"`
}
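// Forward runs one encoder layer over hiddenState.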
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
	residual := hiddenState

	// Self-attention with a residual connection.
	hiddenState = e.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
	hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	// Feed-forward with a residual connection.
	hiddenState = e.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
	hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
	return hiddenState.Add(ctx, residual)
}
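// VisionModelOptions holds the hyperparameters shared by all vision encoder
// layers.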
type VisionModelOptions struct {
	hiddenSize, numHeads int
	imageSize, patchSize int
	eps                  float32
}
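// VisionModel is the vision tower: a convolutional patch embedding, learned
// position embeddings, a stack of encoder layers, and a final layer norm.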
type VisionModel struct {
	PatchEmbedding    *nn.Conv2D    `gguf:"patch_embedding"`
	PositionEmbedding *nn.Embedding `gguf:"position_embedding"`
	PostLayerNorm     *nn.LayerNorm `gguf:"post_layernorm"`

	Layers []VisionEncoderLayer `gguf:"blk"`

	*VisionModelOptions
}
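// Forward encodes pixelValues (an imageSize x imageSize image) into one
// embedding per patch.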
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
	numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)

	// Convolve the image into non-overlapping patchSize x patchSize patches
	// (stride equals patch size, no padding), then flatten to one
	// hiddenSize-dimensional embedding per patch.
	hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
	hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
	hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)

	// Add a learned position embedding for each patch index 0..numPatches-1.
	positions := make([]int32, numPatches)
	for i := range positions {
		positions[i] = int32(i)
	}

	positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
	if err != nil {
		panic(err)
	}

	hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positionIDs))

	for _, layer := range m.Layers {
		hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
	}

	hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
	return hiddenState
}
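// newVisionModel reads the vision.* hyperparameters from the model config and
// allocates the encoder layers; the weights themselves are loaded separately
// via the gguf struct tags.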
func newVisionModel(c ml.Config) *VisionModel {
	return &VisionModel{
		Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
		VisionModelOptions: &VisionModelOptions{
			hiddenSize: int(c.Uint("vision.embedding_length")),
			numHeads:   int(c.Uint("vision.attention.head_count")),
			imageSize:  int(c.Uint("vision.image_size")),
			patchSize:  int(c.Uint("vision.patch_size")),
			eps:        c.Float("vision.attention.layer_norm_epsilon"),
		},
	}
}
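
// A minimal sketch of how this tower might be driven, assuming a caller that
// already has a config, a context, and a preprocessed image tensor (the names
// config, ctx, and pixelValues below are illustrative, not part of this file):
//
//	m := newVisionModel(config)
//	patchEmbeddings := m.Forward(ctx, pixelValues)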