model.go

package mllama

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
)
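
// Model is the Llama 3.2 Vision ("mllama") architecture: a text decoder
// conditioned on image features from a vision encoder. The gguf struct tags
// map each component to its tensor prefix in the model file; Projector is
// the linear layer that maps vision outputs into the text model's
// cross-attention space.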
type Model struct {
	model.Base
	model.BytePairEncoding

	*VisionModel `gguf:"v,vision"`
	*TextModel

	Projector *nn.Linear `gguf:"mm.0"`

	ImageProcessor
}
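
// New constructs the model from GGUF metadata: the byte pair encoding
// tokenizer (pretokenizer regex, vocabulary, merges, and BOS/EOS token IDs)
// plus the image processor and the vision and text submodels.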
func New(c ml.Config) (model.Model, error) {
	return &Model{
		BytePairEncoding: model.NewBytePairEncoding(
			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Types:  c.Uints("tokenizer.ggml.token_type"),
				Merges: c.Strings("tokenizer.ggml.merges"),
				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
			},
		),
		ImageProcessor: newImageProcessor(c),
		VisionModel:    newVisionModel(c),
		TextModel:      newTextModel(c),
	}, nil
}
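
// Forward computes one step of the model. When an image is supplied, only
// the first one is used: it is preprocessed, encoded by the vision model,
// and projected to produce cross-attention states for the text decoder.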
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
	var crossAttentionStates ml.Tensor
	if opts.Images != nil {
		f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(opts.Images[0])
		if err != nil {
			return nil, err
		}
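
		// Copy the preprocessed pixels into a tensor shaped
		// (imageSize, imageSize, numChannels, maxNumTiles).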
		pixelValues, err := ctx.FromFloatSlice(f32s,
			m.ImageProcessor.imageSize,
			m.ImageProcessor.imageSize,
			m.ImageProcessor.numChannels,
			m.ImageProcessor.maxNumTiles,
		)
		if err != nil {
			return nil, err
		}
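
		// The aspect ratio ID identifies the tile layout chosen during preprocessing.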
		aspectRatio, err := ctx.FromIntSlice([]int32{int32(aspectRatioID)}, 1)
		if err != nil {
			return nil, err
		}
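
		// One position per vision token; 1601 is presumably 40x40 = 1600
		// patches (560px image, 14px patches) plus one class token.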
		positions := make([]int32, 1601)
		for i := range positions {
			positions[i] = int32(i)
		}

		positionIDs, err := ctx.FromIntSlice(positions, len(positions))
		if err != nil {
			return nil, err
		}
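
		// Encode the image, then project the vision features into the text
		// model's cross-attention space.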
		crossAttentionStates = m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
		crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates)
	}
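
	// Load the input token IDs and their positions into tensors.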
	inputs, err := ctx.FromIntSlice(opts.Inputs(), len(opts.Inputs()))
	if err != nil {
		return nil, err
	}

	positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions()))
	if err != nil {
		return nil, err
	}

	// TODO: attention mask, cross attention mask
	hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, opts.Cache)
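
	// Only the hidden state at the last position is needed to predict the next token.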
	outputs, err := ctx.FromIntSlice([]int32{int32(len(opts.Positions())) - 1}, 1)
	if err != nil {
		return nil, err
	}

	return hiddenState.Rows(ctx, outputs), nil
}
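
// Register the architecture so models tagged "mllama" resolve to New.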
func init() {
	model.Register("mllama", New)
}