model.go 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. package gemma3
  2. import (
  3. "fmt"
  4. "github.com/ollama/ollama/kvcache"
  5. "github.com/ollama/ollama/ml"
  6. "github.com/ollama/ollama/model"
  7. "github.com/ollama/ollama/model/input"
  8. )
  9. type Model struct {
  10. model.Base
  11. model.SentencePieceModel
  12. //*VisionModel `gguf:"v,vision"`
  13. *TextModel
  14. //Projector *nn.Linear `gguf:"mm.0"`
  15. ImageProcessor
  16. }
  17. func New(c ml.Config) (model.Model, error) {
  18. // Verify unified config
  19. if c.Uint("vision.block_count") == 0 {
  20. return nil, fmt.Errorf("non-unified vision model not supported")
  21. }
  22. m := Model{
  23. SentencePieceModel: model.NewSentencePieceModel(
  24. c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
  25. &model.Vocabulary{
  26. Values: c.Strings("tokenizer.ggml.tokens"),
  27. Scores: c.Floats("tokenizer.ggml.scores"),
  28. Types: c.Uints("tokenizer.ggml.token_type"),
  29. BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
  30. AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
  31. EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
  32. AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
  33. },
  34. ),
  35. ImageProcessor: newImageProcessor(c),
  36. //VisionModel: newVisionModel(c),
  37. TextModel: newTextModel(c),
  38. }
  39. slidingWindowLen := int32(c.Uint("text.attention.sliding_window"))
  40. m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
  41. return &m, nil
  42. }
  43. func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
  44. inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
  45. if err != nil {
  46. return nil, err
  47. }
  48. positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
  49. if err != nil {
  50. return nil, err
  51. }
  52. outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
  53. if err != nil {
  54. return nil, err
  55. }
  56. return m.TextModel.Forward(ctx, inputs, positions, outputs, m.Cache), nil
  57. }
  58. func init() {
  59. model.Register("gemma3", New)
  60. }