model.go 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. package model
  2. import (
  3. "errors"
  4. "fmt"
  5. _ "image/jpeg"
  6. _ "image/png"
  7. "log/slog"
  8. "os"
  9. "reflect"
  10. "strconv"
  11. "strings"
  12. _ "golang.org/x/image/bmp"
  13. _ "golang.org/x/image/tiff"
  14. _ "golang.org/x/image/webp"
  15. fs "github.com/ollama/ollama/fs/ggml"
  16. "github.com/ollama/ollama/kvcache"
  17. "github.com/ollama/ollama/ml"
  18. _ "github.com/ollama/ollama/ml/backend"
  19. )
  20. // Input represents one element of the input stream: a text token,
  // a multimodal element, or both.
  21. type Input struct {
  22. // Token is a single element of text.
  23. Token int32
  24. // Multimodal is opaque data representing a non-text
  25. // element such as an image (or part of one if the image
  26. // can be processed in pieces). It may accompany Token
  27. // or appear on its own.
  28. Multimodal any
  29. // MultimodalHash identifies the data stored in
  30. // Multimodal; it is used for caching and for comparing
  31. // that data for equality.
  32. MultimodalHash uint64
  33. }
  34. // MultimodalIndex is a multimodal element (such as an image)
  35. // together with an index into the slice of Inputs with the
  36. // corresponding token. Note that the index is not the same
  37. // as the position; to find the position, look up the index
  38. // in the Positions slice.
  39. type MultimodalIndex struct {
  // Index is the offset into Options.Inputs of the token this element
  // belongs to.
  40. Index int
  // Multimodal is the element itself (the same kind of opaque value as
  // Input.Multimodal).
  41. Multimodal any
  42. }
  43. // Options contains the inputs for a model forward pass.
  // The package-level Forward function rejects an Options whose Positions
  // and Sequences differ in length or are empty.
  44. type Options struct {
  // Inputs holds the batch's tokens. NOTE(review): presumably Input.Token
  // values — confirm.
  45. Inputs []int32
  // Multimodal lists the batch's non-text elements, each paired with the
  // index of the Inputs entry it belongs to.
  46. Multimodal []MultimodalIndex
  // Positions holds the position of each entry; MultimodalIndex.Index is
  // used to index into it.
  47. Positions []int32
  // Sequences must have the same length as Positions. NOTE(review):
  // presumably one sequence identifier per entry — confirm.
  48. Sequences []int
  // NOTE(review): Outputs presumably lists the entries whose outputs the
  // caller needs — confirm against the model implementations.
  49. Outputs []int32
  50. }
  // config holds model-level settings read by this package.
  51. type config struct {
  // Cache is the model's KV cache. The package-level Forward function calls
  // its StartForward before each forward pass and skips that step when
  // Cache is nil.
  52. Cache kvcache.Cache
  53. }
  54. // Base implements the common fields and methods for all models.
  // A field of type Base in a model struct is set by New (through
  // populateFields) to hold the backend and the model's config.
  55. type Base struct {
  56. b ml.Backend
  57. config
  58. }
  59. // Backend returns the underlying backend that will run the model.
  60. func (m *Base) Backend() ml.Backend {
  61. return m.b
  62. }
  // Config returns the model's configuration, including its KV cache (which
  // may be nil).
  63. func (m *Base) Config() config {
  64. return m.config
  65. }
  66. // Model implements a specific model architecture, defining the forward pass and any model-specific configuration.
  // Backend and Config can be provided by embedding Base. Base's methods
  // have pointer receivers, so a struct embedding Base satisfies Model only
  // through a pointer.
  67. type Model interface {
  // Forward builds the computation for one batch and returns its output
  // tensor. The package-level Forward function validates opts first and
  // computes the returned tensor afterwards.
  68. Forward(ml.Context, Options) (ml.Tensor, error)
  // Backend returns the underlying backend that will run the model.
  69. Backend() ml.Backend
  // Config returns the model's configuration, including its KV cache.
  70. Config() config
  71. }
  72. // MultimodalProcessor must be implemented by multimodal models.
  73. type MultimodalProcessor interface {
  74. // EncodeMultimodal processes a single input (such as an image) and
  75. // generates an output (typically an embedding) that can be used by the model.
  76. //
  77. // The return value is most typically an ml.Tensor; however, other
  78. // types are possible, such as an object containing a tensor plus
  79. // additional metadata, a slice of tensors or even just the original input.
  80. //
  81. // The result may be cached by the runner.
  82. EncodeMultimodal(ml.Context, []byte) (any, error)
  83. // PostTokenize is called after tokenization to allow the model to edit the
  84. // input stream to correctly arrange multimodal elements.
  85. //
  86. // The input is a slice of tokens with the results of EncodeMultimodal interleaved
  87. // in the order that the user provided them. Each element of the slice will be
  88. // either a single token or a single multimodal object.
  89. //
  90. // The model must ensure that inputs are stored according to how they will be
  91. // processed and stored in the cache. For example, Llava-style models should insert
  92. // placeholder tokens equal to the feature size of the corresponding image with
  93. // the image itself attached to and split across these tokens. When Forward is called
  94. // a partial subset of these tokens may be submitted according to the batch size.
  95. //
  96. // This function is also responsible for updating MultimodalHash for any Multimodal
  97. // that is modified to ensure that there is a unique hash value that accurately
  98. // represents the contents.
  99. PostTokenize(ml.Context, []Input) ([]Input, error)
  100. }
  101. var models = make(map[string]func(ml.Config) (Model, error))
  102. // Register registers a model constructor for the given architecture
  103. func Register(name string, f func(ml.Config) (Model, error)) {
  104. if _, ok := models[name]; ok {
  105. panic("model: model already registered")
  106. }
  107. models[name] = f
  108. }
  109. // New initializes a new model instance with the provided configuration based on the metadata in the model file
  110. func New(modelPath string, params ml.BackendParams) (Model, error) {
  111. r, err := os.Open(modelPath)
  112. if err != nil {
  113. return nil, err
  114. }
  115. defer r.Close()
  116. b, err := ml.NewBackend(r, params)
  117. if err != nil {
  118. return nil, err
  119. }
  120. arch := b.Config().Architecture()
  121. f, ok := models[arch]
  122. if !ok {
  123. return nil, fmt.Errorf("unsupported model architecture %q", arch)
  124. }
  125. m, err := f(b.Config())
  126. if err != nil {
  127. return nil, err
  128. }
  129. base := Base{b: b, config: m.Config()}
  130. v := reflect.ValueOf(m)
  131. v.Elem().Set(populateFields(base, v.Elem()))
  132. return m, nil
  133. }
  134. func NewTextProcessor(s string) (TextProcessor, error) {
  135. r, err := os.Open(s)
  136. if err != nil {
  137. return nil, err
  138. }
  139. defer r.Close()
  140. meta, _, err := fs.Decode(r, -1)
  141. if err != nil {
  142. return nil, err
  143. }
  144. return getTextProcessor(meta.KV())
  145. }
  146. func getTextProcessor(kv fs.KV) (TextProcessor, error) {
  147. arch := kv.Architecture()
  148. f, ok := models[arch]
  149. if !ok {
  150. return nil, fmt.Errorf("unsupported model architecture %q", arch)
  151. }
  152. m, err := f(kv)
  153. if err != nil {
  154. return nil, err
  155. }
  156. tp, ok := m.(TextProcessor)
  157. if !ok {
  158. return nil, fmt.Errorf("%v is not a TextProcessor", m)
  159. }
  160. return tp, nil
  161. }
  162. func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
  163. t := v.Type()
  164. if t.Kind() == reflect.Struct {
  165. allNil := true
  166. for i := range t.NumField() {
  167. tt := t.Field(i).Type
  168. vv := v.Field(i)
  169. if !vv.CanSet() {
  170. continue
  171. }
  172. // make a copy
  173. tagsCopy := tags
  174. if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
  175. tagsCopy = append(tagsCopy, ParseTags(tag))
  176. }
  177. if tt == reflect.TypeOf((*Base)(nil)).Elem() {
  178. vv.Set(reflect.ValueOf(base))
  179. } else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
  180. var fn func([]Tag) [][]string
  181. fn = func(tags []Tag) (values [][]string) {
  182. if len(tags) < 1 {
  183. return nil
  184. }
  185. values = [][]string{{tags[0].Name}}
  186. for _, alt := range tags[0].Alternate {
  187. values = append(values, []string{alt})
  188. }
  189. for i, value := range values {
  190. for _, rest := range fn(tags[1:]) {
  191. value = append(value, rest...)
  192. }
  193. values[i] = value
  194. }
  195. return values
  196. }
  197. names := fn(tagsCopy)
  198. for _, name := range names {
  199. if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
  200. slog.Debug("found tensor", "", tensor)
  201. vv.Set(reflect.ValueOf(tensor))
  202. break
  203. }
  204. }
  205. } else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
  206. setPointer(base, vv, tagsCopy)
  207. } else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
  208. for i := range vv.Len() {
  209. vvv := vv.Index(i)
  210. if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
  211. setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)}))
  212. } else {
  213. vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
  214. }
  215. }
  216. }
  217. if !canNil(tt) || !vv.IsNil() {
  218. allNil = false
  219. }
  220. }
  221. if allNil {
  222. return reflect.Zero(t)
  223. }
  224. }
  225. return v
  226. }
  // setPointer populates the value behind v, a settable pointer or interface
  // Value taken from a struct field or slice/array element, using base and
  // tags in the same way as populateFields.
  //
  //   - A nil interface is left untouched.
  //   - A nil pointer gets a freshly allocated element to populate. v is set
  //     to point at it only when populateFields returns an addressable
  //     result, i.e. not the zero value it returns when nothing was found,
  //     so the pointer stays nil if the pointee would be empty.
  //   - A non-nil pointer, or an interface holding one, is populated in place
  //     through the existing pointee. v is then set to the pointee's address
  //     unless populateFields returned the zero value.
  //
  // NOTE(review): an interface holding a non-pointer value panics at the
  // second vv.Elem() call. One holding a typed nil pointer panics inside
  // populateFields, because v.IsNil() is false for the interface and no
  // element is allocated. Confirm models never store such values in fields.
  227. func setPointer(base Base, v reflect.Value, tags []Tag) {
  // vv walks from v down to the pointee; v itself is kept for the final Set.
  228. vv := v
  229. if v.Kind() == reflect.Interface {
  230. if v.IsNil() {
  231. return
  232. }
  // Unwrap the interface to the value it holds (expected to be a pointer).
  233. vv = vv.Elem()
  234. }
  // Dereference to the pointee (a zero Value if the pointer is nil).
  235. vv = vv.Elem()
  // True only for a nil pointer: nil interfaces have already returned above.
  236. if v.IsNil() {
  237. vv = reflect.New(v.Type().Elem()).Elem()
  238. }
  // reflect.Zero results are not addressable, so an all-nil pointee leaves v unchanged.
  239. if f := populateFields(base, vv, tags...); f.CanAddr() {
  240. v.Set(f.Addr())
  241. }
  242. }
  // Tag is a parsed `gguf` struct tag: the name component a field
  // contributes to a tensor name, plus alternate spellings to try instead.
  type Tag struct {
  	Name      string
  	Alternate []string
  }

  // ParseTags parses a `gguf` struct tag value of the form
  // "name[,alt:other]...". The first comma-separated element becomes Name.
  // Each later element starting with "alt:" adds an Alternate, with the
  // prefix removed. Other later elements are ignored.
  func ParseTags(s string) (tag Tag) {
  	name, rest, more := strings.Cut(s, ",")
  	tag.Name = name
  	for more {
  		var part string
  		part, rest, more = strings.Cut(rest, ",")
  		if alt, ok := strings.CutPrefix(part, "alt:"); ok {
  			tag.Alternate = append(tag.Alternate, alt)
  		}
  	}
  	return tag
  }
  // canNil reports whether t has one of the nillable kinds populateFields
  // checks with reflect.Value.IsNil: chan, func, interface, map, pointer or
  // slice. reflect.UnsafePointer is not counted.
  func canNil(t reflect.Type) bool {
  	switch t.Kind() {
  	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
  		return true
  	default:
  		return false
  	}
  }
  267. func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
  268. if len(opts.Positions) != len(opts.Sequences) {
  269. return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
  270. }
  271. if len(opts.Positions) < 1 {
  272. return nil, errors.New("batch size cannot be less than 1")
  273. }
  274. cache := m.Config().Cache
  275. if cache != nil {
  276. err := cache.StartForward(ctx, opts.Positions, opts.Sequences)
  277. if err != nil {
  278. return nil, err
  279. }
  280. }
  281. t, err := m.Forward(ctx, opts)
  282. if err != nil {
  283. return nil, err
  284. }
  285. ctx.Forward(t).Compute(t)
  286. return t, nil
  287. }