model.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. package model
  2. import (
  3. "fmt"
  4. "image"
  5. _ "image/jpeg"
  6. _ "image/png"
  7. "log/slog"
  8. "os"
  9. "reflect"
  10. "strconv"
  11. "strings"
  12. _ "golang.org/x/image/bmp"
  13. _ "golang.org/x/image/tiff"
  14. _ "golang.org/x/image/webp"
  15. "github.com/ollama/ollama/cache"
  16. "github.com/ollama/ollama/ml"
  17. _ "github.com/ollama/ollama/ml/backend"
  18. )
  19. type Options struct {
  20. inputs []int32
  21. positions []int32
  22. outputs []int32
  23. sequences []int
  24. Images []image.Image
  25. cache.Cache
  26. }
  27. func (opts Options) Inputs() []int32 {
  28. return opts.inputs
  29. }
  30. func (opts Options) Positions() []int32 {
  31. return opts.positions
  32. }
  33. func (opts Options) Outputs() []int32 {
  34. return opts.outputs
  35. }
  36. type OptionsFunc func(Model, *Options)
  37. func WithInputIDs(ids []int32) OptionsFunc {
  38. return func(m Model, opts *Options) {
  39. opts.inputs = ids
  40. }
  41. }
  42. func WithPositions(pos []int32) OptionsFunc {
  43. return func(m Model, opts *Options) {
  44. opts.positions = pos
  45. }
  46. }
  47. func WithOutputs(outputs []int32) OptionsFunc {
  48. return func(m Model, opts *Options) {
  49. opts.outputs = outputs
  50. }
  51. }
  52. func WithSequences(seqs []int) OptionsFunc {
  53. return func(m Model, opts *Options) {
  54. opts.sequences = seqs
  55. }
  56. }
  57. func WithImage(img image.Image) OptionsFunc {
  58. return func(m Model, opts *Options) {
  59. opts.Images = append(opts.Images, img)
  60. }
  61. }
  62. func WithCache(c cache.Cache) OptionsFunc {
  63. return func(m Model, opts *Options) {
  64. opts.Cache = c
  65. }
  66. }
  67. type Base struct {
  68. b ml.Backend
  69. }
  70. func (m *Base) Backend() ml.Backend {
  71. return m.b
  72. }
  73. type Model interface {
  74. Forward(ml.Context, Options) (ml.Tensor, error)
  75. Backend() ml.Backend
  76. }
  77. var models = make(map[string]func(ml.Config) (Model, error))
  78. func Register(name string, f func(ml.Config) (Model, error)) {
  79. if _, ok := models[name]; ok {
  80. panic("model: model already registered")
  81. }
  82. models[name] = f
  83. }
  84. func New(s string) (Model, error) {
  85. r, err := os.Open(s)
  86. if err != nil {
  87. return nil, err
  88. }
  89. defer r.Close()
  90. b, err := ml.NewBackend(r)
  91. if err != nil {
  92. return nil, err
  93. }
  94. arch := b.Config().Architecture()
  95. f, ok := models[arch]
  96. if !ok {
  97. return nil, fmt.Errorf("unsupported model architecture %q", arch)
  98. }
  99. m, err := f(b.Config())
  100. if err != nil {
  101. return nil, err
  102. }
  103. v := reflect.ValueOf(m)
  104. v.Elem().Set(populateFields(b, v))
  105. return m, nil
  106. }
  107. func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
  108. var iface bool
  109. if v.Kind() == reflect.Interface {
  110. iface = true
  111. v = v.Elem()
  112. }
  113. t := v.Type()
  114. if t.Kind() == reflect.Pointer {
  115. t, v = t.Elem(), v.Elem()
  116. }
  117. if t.Kind() == reflect.Struct {
  118. allNil := true
  119. for i := range t.NumField() {
  120. tt := t.Field(i).Type
  121. vv := v.Field(i)
  122. if !vv.CanSet() {
  123. continue
  124. }
  125. // make a copy
  126. tagsCopy := tags
  127. if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
  128. tagsCopy = append(tagsCopy, ParseTags(tag))
  129. }
  130. if tt == reflect.TypeOf((*Base)(nil)).Elem() {
  131. vv.Set(reflect.ValueOf(Base{b: b}))
  132. } else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
  133. var fn func([]Tag) [][]string
  134. fn = func(tags []Tag) (values [][]string) {
  135. if len(tags) < 1 {
  136. return nil
  137. }
  138. values = [][]string{{tags[0].Name}}
  139. for _, alt := range tags[0].Alternate {
  140. values = append(values, []string{alt})
  141. }
  142. for i, value := range values {
  143. for _, rest := range fn(tags[1:]) {
  144. value = append(value, rest...)
  145. }
  146. values[i] = value
  147. }
  148. return values
  149. }
  150. names := fn(tagsCopy)
  151. for _, name := range names {
  152. if tensor := b.Get(strings.Join(name, ".")); tensor != nil {
  153. slog.Debug("found tensor", "", tensor)
  154. vv.Set(reflect.ValueOf(tensor))
  155. break
  156. }
  157. }
  158. } else if tt.Kind() == reflect.Pointer {
  159. vvv := vv.Elem()
  160. if vv.IsNil() {
  161. vvv = reflect.New(tt.Elem())
  162. }
  163. if f := populateFields(b, vvv, tagsCopy...); f.CanAddr() {
  164. vv.Set(f.Addr())
  165. }
  166. } else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
  167. for i := range vv.Len() {
  168. vv.Index(i).Set(populateFields(b, vv.Index(i), append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
  169. }
  170. }
  171. if !canNil(tt) || !vv.IsNil() {
  172. allNil = false
  173. }
  174. }
  175. if allNil {
  176. return reflect.Zero(t)
  177. }
  178. }
  179. if iface {
  180. return v.Addr()
  181. }
  182. return v
  183. }
  // Tag is one parsed `gguf` struct-tag value: the primary tensor-name
// component plus any alternate components declared with an "alt:" prefix.
type Tag struct {
	Name      string
	Alternate []string
}

// ParseTags parses a comma-separated gguf tag value. The first element becomes
// Name. Later elements prefixed with "alt:" are collected, prefix removed, into
// Alternate in order. Any other elements are ignored.
func ParseTags(s string) Tag {
	name, rest, more := strings.Cut(s, ",")
	tag := Tag{Name: name}
	for more {
		var part string
		part, rest, more = strings.Cut(rest, ",")
		if alt, ok := strings.CutPrefix(part, "alt:"); ok {
			tag.Alternate = append(tag.Alternate, alt)
		}
	}
	return tag
}
  // canNil reports whether t is one of the kinds whose reflect.Value supports
// IsNil and that this package treats as possibly-nil: channel, function,
// interface, map, pointer, or slice.
func canNil(t reflect.Type) bool {
	switch t.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
		return true
	default:
		return false
	}
}
  208. func Forward(ctx ml.Context, m Model, optsFuncs ...OptionsFunc) (ml.Tensor, error) {
  209. var opts Options
  210. for _, optsFunc := range optsFuncs {
  211. optsFunc(m, &opts)
  212. }
  213. err := opts.Cache.StartForward(ctx, opts.sequences)
  214. if err != nil {
  215. return nil, err
  216. }
  217. t, err := m.Forward(ctx, opts)
  218. if err != nil {
  219. return nil, err
  220. }
  221. ctx.Forward(t)
  222. return ctx.Compute(t), nil
  223. }