model.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. package model
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. _ "image/jpeg"
  7. _ "image/png"
  8. "log/slog"
  9. "os"
  10. "reflect"
  11. "strconv"
  12. "strings"
  13. _ "golang.org/x/image/bmp"
  14. _ "golang.org/x/image/tiff"
  15. _ "golang.org/x/image/webp"
  16. fs "github.com/ollama/ollama/fs/ggml"
  17. "github.com/ollama/ollama/kvcache"
  18. "github.com/ollama/ollama/ml"
  19. _ "github.com/ollama/ollama/ml/backend"
  20. "github.com/ollama/ollama/model/input"
  21. )
  22. var ErrNoVisionModel = errors.New("this model is missing data required for image input")
  23. // Model implements a specific model architecture, defining the forward pass and any model-specific configuration
  24. type Model interface {
  25. Forward(ml.Context, input.Batch) (ml.Tensor, error)
  26. Backend() ml.Backend
  27. Config() config
  28. }
  29. // MultimodalProcessor must be implemented by multimodal models.
  30. type MultimodalProcessor interface {
  31. // EncodeMultimodal processes a single input (such as an image) and
  32. // generates an output (typically an embedding) that can be used by the model.
  33. //
  34. // The return value is most typically an ml.Tensor, however, different
  35. // type are possible, such as an object containing a tensor plus
  36. // additional metadata, a slice of tensors or even just the original input.
  37. //
  38. // The result may be cached by the runner.
  39. EncodeMultimodal(ml.Context, []byte) (any, error)
  40. // PostTokenize is called after tokenization to allow the model to edit the
  41. // input stream to correctly arrange multimodal elements.
  42. //
  43. // The input is a slice of tokens with the results of EncodeMultimodal interleaved
  44. // in the order that the user provided them. Each element of the slice will be
  45. // either a single token or single multimodal object.
  46. //
  47. // The model must ensure that inputs are stored according to how they will be
  48. // processed and stored in the cache. For example, Llava-style models should insert
  49. // placeholder tokens equal to the feature size of the corresponding image with
  50. // the image itself attached to and split across these tokens. When Forward is called
  51. // a partial subset of these tokens may be submitted according to the batch size.
  52. //
  53. // This function is also responsible for updating MultimodalHash for any Multimodal
  54. // that is modified to ensure that there is a unique hash value that accurately
  55. // represents the contents.
  56. PostTokenize([]input.Input) ([]input.Input, error)
  57. }
  58. // Base implements the common fields and methods for all models
  59. type Base struct {
  60. b ml.Backend
  61. config
  62. }
  63. type config struct {
  64. Cache kvcache.Cache
  65. }
  66. // Backend returns the underlying backend that will run the model
  67. func (m *Base) Backend() ml.Backend {
  68. return m.b
  69. }
  70. func (m *Base) Config() config {
  71. return m.config
  72. }
  73. var models = make(map[string]func(ml.Config) (Model, error))
  74. // Register registers a model constructor for the given architecture
  75. func Register(name string, f func(ml.Config) (Model, error)) {
  76. if _, ok := models[name]; ok {
  77. panic("model: model already registered")
  78. }
  79. models[name] = f
  80. }
  81. // New initializes a new model instance with the provided configuration based on the metadata in the model file
  82. func New(ctx context.Context, modelPath string, params ml.BackendParams) (Model, error) {
  83. r, err := os.Open(modelPath)
  84. if err != nil {
  85. return nil, err
  86. }
  87. defer r.Close()
  88. b, err := ml.NewBackend(ctx, r, params)
  89. if err != nil {
  90. return nil, err
  91. }
  92. arch := b.Config().Architecture()
  93. f, ok := models[arch]
  94. if !ok {
  95. return nil, fmt.Errorf("unsupported model architecture %q", arch)
  96. }
  97. m, err := f(b.Config())
  98. if err != nil {
  99. return nil, err
  100. }
  101. base := Base{b: b, config: m.Config()}
  102. v := reflect.ValueOf(m)
  103. v.Elem().Set(populateFields(base, v.Elem()))
  104. return m, nil
  105. }
  106. func NewTextProcessor(s string) (TextProcessor, error) {
  107. r, err := os.Open(s)
  108. if err != nil {
  109. return nil, err
  110. }
  111. defer r.Close()
  112. meta, _, err := fs.Decode(r, -1)
  113. if err != nil {
  114. return nil, err
  115. }
  116. return getTextProcessor(meta.KV())
  117. }
  118. func getTextProcessor(kv fs.KV) (TextProcessor, error) {
  119. arch := kv.Architecture()
  120. f, ok := models[arch]
  121. if !ok {
  122. return nil, fmt.Errorf("unsupported model architecture %q", arch)
  123. }
  124. m, err := f(kv)
  125. if err != nil {
  126. return nil, err
  127. }
  128. tp, ok := m.(TextProcessor)
  129. if !ok {
  130. return nil, fmt.Errorf("%v is not a TextProcessor", m)
  131. }
  132. return tp, nil
  133. }
  134. func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
  135. t := v.Type()
  136. if t.Kind() == reflect.Struct {
  137. allNil := true
  138. for i := range t.NumField() {
  139. tt := t.Field(i).Type
  140. vv := v.Field(i)
  141. if !vv.CanSet() {
  142. continue
  143. }
  144. // make a copy
  145. tagsCopy := tags
  146. if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
  147. tagsCopy = append(tagsCopy, ParseTags(tag))
  148. }
  149. if tt == reflect.TypeOf((*Base)(nil)).Elem() {
  150. vv.Set(reflect.ValueOf(base))
  151. } else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
  152. var fn func([]Tag) [][]string
  153. fn = func(tags []Tag) (values [][]string) {
  154. if len(tags) < 1 {
  155. return nil
  156. }
  157. values = [][]string{{tags[0].Name}}
  158. for _, alt := range tags[0].Alternate {
  159. values = append(values, []string{alt})
  160. }
  161. for i, value := range values {
  162. for _, rest := range fn(tags[1:]) {
  163. value = append(value, rest...)
  164. }
  165. values[i] = value
  166. }
  167. return values
  168. }
  169. names := fn(tagsCopy)
  170. for _, name := range names {
  171. if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
  172. slog.Debug("found tensor", "", tensor)
  173. vv.Set(reflect.ValueOf(tensor))
  174. break
  175. }
  176. }
  177. } else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
  178. setPointer(base, vv, tagsCopy)
  179. } else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
  180. for i := range vv.Len() {
  181. vvv := vv.Index(i)
  182. if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
  183. setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)}))
  184. } else {
  185. vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
  186. }
  187. }
  188. }
  189. if !canNil(tt) || !vv.IsNil() {
  190. allNil = false
  191. }
  192. }
  193. if allNil {
  194. return reflect.Zero(t)
  195. }
  196. }
  197. return v
  198. }
  199. func setPointer(base Base, v reflect.Value, tags []Tag) {
  200. vv := v
  201. if v.Kind() == reflect.Interface {
  202. if v.IsNil() {
  203. return
  204. }
  205. vv = vv.Elem()
  206. }
  207. vv = vv.Elem()
  208. if v.IsNil() {
  209. vv = reflect.New(v.Type().Elem()).Elem()
  210. }
  211. if f := populateFields(base, vv, tags...); f.CanAddr() {
  212. v.Set(f.Addr())
  213. }
  214. }
  215. type Tag struct {
  216. Name string
  217. Alternate []string
  218. }
  219. func ParseTags(s string) (tag Tag) {
  220. parts := strings.Split(s, ",")
  221. if len(parts) > 0 {
  222. tag.Name = parts[0]
  223. for _, part := range parts[1:] {
  224. if value, ok := strings.CutPrefix(part, "alt:"); ok {
  225. tag.Alternate = append(tag.Alternate, value)
  226. }
  227. }
  228. }
  229. return
  230. }
  231. func canNil(t reflect.Type) bool {
  232. return t.Kind() == reflect.Chan ||
  233. t.Kind() == reflect.Func ||
  234. t.Kind() == reflect.Interface ||
  235. t.Kind() == reflect.Map ||
  236. t.Kind() == reflect.Pointer ||
  237. t.Kind() == reflect.Slice
  238. }
  239. func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
  240. if len(batch.Positions) != len(batch.Sequences) {
  241. return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
  242. }
  243. if len(batch.Positions) < 1 {
  244. return nil, errors.New("batch size cannot be less than 1")
  245. }
  246. var err error
  247. batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
  248. if err != nil {
  249. return nil, err
  250. }
  251. cache := m.Config().Cache
  252. if cache != nil {
  253. err := cache.StartForward(ctx, batch)
  254. if err != nil {
  255. return nil, err
  256. }
  257. }
  258. t, err := m.Forward(ctx, batch)
  259. if err != nil {
  260. return nil, err
  261. }
  262. ctx.Forward(t).Compute(t)
  263. return t, nil
  264. }