model.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. package model
  2. import (
  3. "fmt"
  4. "image"
  5. _ "image/jpeg"
  6. _ "image/png"
  7. "log/slog"
  8. "os"
  9. "reflect"
  10. "strconv"
  11. "strings"
  12. _ "golang.org/x/image/bmp"
  13. _ "golang.org/x/image/tiff"
  14. _ "golang.org/x/image/webp"
  15. "github.com/ollama/ollama/cache"
  16. "github.com/ollama/ollama/ml"
  17. _ "github.com/ollama/ollama/ml/backend"
  18. )
  19. type Cache struct {
  20. cache.Cache
  21. cache.Options
  22. }
  23. func (c Cache) Sub(i int) Cache {
  24. if c.Cache != nil {
  25. return Cache{
  26. Cache: c.Cache.Sub(i),
  27. Options: c.Options,
  28. }
  29. }
  30. return c
  31. }
  32. func (c Cache) Put(ctx ml.Context, key, value ml.Tensor, opts cache.Options) (ml.Tensor, ml.Tensor) {
  33. if c.Cache != nil {
  34. return c.Cache.Put(ctx, key, value, opts)
  35. }
  36. return key, value
  37. }
  38. type Options struct {
  39. inputs []int32
  40. Offset int
  41. Images []image.Image
  42. Cache
  43. }
  44. func (opts Options) Inputs() []int32 {
  45. return opts.inputs[opts.Offset:]
  46. }
  47. func (opts Options) Positions() []int32 {
  48. positions := make([]int32, len(opts.inputs)-opts.Offset)
  49. for i := range positions {
  50. positions[i] = int32(opts.Offset + i)
  51. }
  52. return positions
  53. }
  54. type OptionsFunc func(Model, *Options)
  55. func WithInputIDs(ids []int32) OptionsFunc {
  56. return func(m Model, opts *Options) {
  57. opts.inputs = ids
  58. }
  59. }
  60. func WithOffset(offset int) OptionsFunc {
  61. return func(m Model, opts *Options) {
  62. opts.Offset = offset
  63. opts.Cache.Position = offset
  64. }
  65. }
  66. func WithImage(img image.Image) OptionsFunc {
  67. return func(m Model, opts *Options) {
  68. opts.Images = append(opts.Images, img)
  69. }
  70. }
  71. func WithCache(c cache.Cache) OptionsFunc {
  72. return func(m Model, opts *Options) {
  73. opts.Cache = Cache{
  74. Cache: c,
  75. Options: cache.Options{
  76. Position: opts.Offset,
  77. },
  78. }
  79. }
  80. }
  81. type Base struct {
  82. b ml.Backend
  83. }
  84. func (m *Base) Backend() ml.Backend {
  85. return m.b
  86. }
  87. type Model interface {
  88. Forward(ml.Context, Options) (ml.Tensor, error)
  89. Backend() ml.Backend
  90. }
  91. var models = make(map[string]func(ml.Config) (Model, error))
  92. func Register(name string, f func(ml.Config) (Model, error)) {
  93. if _, ok := models[name]; ok {
  94. panic("model: model already registered")
  95. }
  96. models[name] = f
  97. }
  98. func New(s string) (Model, error) {
  99. r, err := os.Open(s)
  100. if err != nil {
  101. return nil, err
  102. }
  103. defer r.Close()
  104. b, err := ml.NewBackend(r)
  105. if err != nil {
  106. return nil, err
  107. }
  108. arch := b.Config().Architecture()
  109. f, ok := models[arch]
  110. if !ok {
  111. return nil, fmt.Errorf("unsupported model architecture %q", arch)
  112. }
  113. m, err := f(b.Config())
  114. if err != nil {
  115. return nil, err
  116. }
  117. v := reflect.ValueOf(m)
  118. v.Elem().Set(populateFields(b, v))
  119. return m, nil
  120. }
  121. func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
  122. t := v.Type()
  123. if t.Kind() == reflect.Pointer {
  124. t, v = t.Elem(), v.Elem()
  125. }
  126. if t.Kind() == reflect.Struct {
  127. allNil := true
  128. for i := range t.NumField() {
  129. tt := t.Field(i).Type
  130. vv := v.Field(i)
  131. if !vv.CanSet() {
  132. continue
  133. }
  134. // make a copy
  135. tagsCopy := tags
  136. if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
  137. tagsCopy = append(tagsCopy, ParseTags(tag))
  138. }
  139. if tt == reflect.TypeOf((*Base)(nil)).Elem() {
  140. vv.Set(reflect.ValueOf(Base{b: b}))
  141. } else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
  142. var fn func([]Tag) [][]string
  143. fn = func(tags []Tag) (values [][]string) {
  144. if len(tags) < 1 {
  145. return nil
  146. }
  147. values = [][]string{{tags[0].Name}}
  148. for _, alt := range tags[0].Alternate {
  149. values = append(values, []string{alt})
  150. }
  151. for i, value := range values {
  152. for _, rest := range fn(tags[1:]) {
  153. value = append(value, rest...)
  154. }
  155. values[i] = value
  156. }
  157. return values
  158. }
  159. names := fn(tagsCopy)
  160. for _, name := range names {
  161. if tensor := b.Get(strings.Join(name, ".")); tensor != nil {
  162. slog.Debug("found tensor", "", tensor)
  163. vv.Set(reflect.ValueOf(tensor))
  164. break
  165. }
  166. }
  167. } else if tt.Kind() == reflect.Pointer {
  168. vvv := vv.Elem()
  169. if vv.IsNil() {
  170. vvv = reflect.New(tt.Elem())
  171. }
  172. if f := populateFields(b, vvv, tagsCopy...); f.CanAddr() {
  173. vv.Set(f.Addr())
  174. }
  175. } else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
  176. for i := range vv.Len() {
  177. vv.Index(i).Set(populateFields(b, vv.Index(i), append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
  178. }
  179. }
  180. if !canNil(tt) || !vv.IsNil() {
  181. allNil = false
  182. }
  183. }
  184. if allNil {
  185. return reflect.Zero(t)
  186. }
  187. }
  188. return v
  189. }
  // Tag is one parsed `gguf` struct tag: a primary name plus any alternate
  // names.
  type Tag struct {
  	Name      string
  	Alternate []string
  }

  // ParseTags parses a comma-separated tag string. The first element becomes
  // Name; every later element of the form "alt:<value>" contributes <value> to
  // Alternate, and any other element is ignored. Elements are matched exactly
  // as written, with no whitespace trimming.
  func ParseTags(s string) (tag Tag) {
  	name, rest, more := strings.Cut(s, ",")
  	tag.Name = name

  	for more {
  		var part string
  		part, rest, more = strings.Cut(rest, ",")
  		if alt, isAlt := strings.CutPrefix(part, "alt:"); isAlt {
  			tag.Alternate = append(tag.Alternate, alt)
  		}
  	}

  	return tag
  }
  // canNil reports whether values of type t may be nil, i.e. whether t's kind
  // is chan, func, interface, map, pointer or slice.
  func canNil(t reflect.Type) bool {
  	switch t.Kind() {
  	case reflect.Chan,
  		reflect.Func,
  		reflect.Interface,
  		reflect.Map,
  		reflect.Pointer,
  		reflect.Slice:
  		return true
  	default:
  		return false
  	}
  }
  214. func Forward(m Model, optsFuncs ...OptionsFunc) (ml.Tensor, error) {
  215. var opts Options
  216. for _, optsFunc := range optsFuncs {
  217. optsFunc(m, &opts)
  218. }
  219. ctx := m.Backend().NewContext()
  220. t, err := m.Forward(ctx, opts)
  221. if err != nil {
  222. return nil, err
  223. }
  224. defer ctx.Close()
  225. return ctx.Compute(t), nil
  226. }