model.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. package model
  2. import (
  3. "errors"
  4. "fmt"
  5. "image"
  6. _ "image/jpeg"
  7. _ "image/png"
  8. "log/slog"
  9. "os"
  10. "reflect"
  11. "strconv"
  12. "strings"
  13. _ "golang.org/x/image/bmp"
  14. _ "golang.org/x/image/tiff"
  15. _ "golang.org/x/image/webp"
  16. fs "github.com/ollama/ollama/fs/ggml"
  17. "github.com/ollama/ollama/kvcache"
  18. "github.com/ollama/ollama/ml"
  19. _ "github.com/ollama/ollama/ml/backend"
  20. )
  // Options describes one batch handed to a model's forward pass.
//
// The package-level Forward function rejects an Options whose Positions and
// Sequences differ in length or are empty.
type Options struct {
	// Inputs are the batch inputs (NOTE(review): presumably token IDs — confirm).
	Inputs []int32

	// Positions and Sequences are passed to the KV cache's StartForward
	// before the model runs; they must have the same non-zero length.
	Positions []int32
	Sequences []int

	// Outputs (NOTE(review): presumably the batch indices whose results are
	// needed — confirm against the model implementations).
	Outputs []int32

	// Images are decoded images attached to the batch.
	Images []image.Image
}
  29. type config struct {
  30. Cache kvcache.Cache
  31. }
  32. // Base implements the common fields and methods for all models
  33. type Base struct {
  34. b ml.Backend
  35. config
  36. }
  37. // Backend returns the underlying backend that will run the model
  38. func (m *Base) Backend() ml.Backend {
  39. return m.b
  40. }
  41. func (m *Base) Config() config {
  42. return m.config
  43. }
  44. // Model implements a specific model architecture, defining the forward pass and any model-specific configuration
  45. type Model interface {
  46. Forward(ml.Context, Options) (ml.Tensor, error)
  47. Backend() ml.Backend
  48. Config() config
  49. }
  50. var models = make(map[string]func(ml.Config) (Model, error))
  51. // Register registers a model constructor for the given architecture
  52. func Register(name string, f func(ml.Config) (Model, error)) {
  53. if _, ok := models[name]; ok {
  54. panic("model: model already registered")
  55. }
  56. models[name] = f
  57. }
  58. // New initializes a new model instance with the provided configuration based on the metadata in the model file
  59. func New(modelPath string, params ml.BackendParams) (Model, error) {
  60. r, err := os.Open(modelPath)
  61. if err != nil {
  62. return nil, err
  63. }
  64. defer r.Close()
  65. b, err := ml.NewBackend(r, params)
  66. if err != nil {
  67. return nil, err
  68. }
  69. arch := b.Config().Architecture()
  70. f, ok := models[arch]
  71. if !ok {
  72. return nil, fmt.Errorf("unsupported model architecture %q", arch)
  73. }
  74. m, err := f(b.Config())
  75. if err != nil {
  76. return nil, err
  77. }
  78. base := Base{b: b, config: m.Config()}
  79. v := reflect.ValueOf(m)
  80. v.Elem().Set(populateFields(base, v.Elem()))
  81. return m, nil
  82. }
  83. func NewTextProcessor(s string) (TextProcessor, error) {
  84. r, err := os.Open(s)
  85. if err != nil {
  86. return nil, err
  87. }
  88. defer r.Close()
  89. meta, _, err := fs.Decode(r, -1)
  90. if err != nil {
  91. return nil, err
  92. }
  93. return getTextProcessor(meta.KV())
  94. }
  95. func getTextProcessor(kv fs.KV) (TextProcessor, error) {
  96. arch := kv.Architecture()
  97. f, ok := models[arch]
  98. if !ok {
  99. return nil, fmt.Errorf("unsupported model architecture %q", arch)
  100. }
  101. m, err := f(kv)
  102. if err != nil {
  103. return nil, err
  104. }
  105. tp, ok := m.(TextProcessor)
  106. if !ok {
  107. return nil, fmt.Errorf("%v is not a TextProcessor", m)
  108. }
  109. return tp, nil
  110. }
  111. func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
  112. t := v.Type()
  113. if t.Kind() == reflect.Struct {
  114. allNil := true
  115. for i := range t.NumField() {
  116. tt := t.Field(i).Type
  117. vv := v.Field(i)
  118. if !vv.CanSet() {
  119. continue
  120. }
  121. // make a copy
  122. tagsCopy := tags
  123. if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
  124. tagsCopy = append(tagsCopy, ParseTags(tag))
  125. }
  126. if tt == reflect.TypeOf((*Base)(nil)).Elem() {
  127. vv.Set(reflect.ValueOf(base))
  128. } else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
  129. var fn func([]Tag) [][]string
  130. fn = func(tags []Tag) (values [][]string) {
  131. if len(tags) < 1 {
  132. return nil
  133. }
  134. values = [][]string{{tags[0].Name}}
  135. for _, alt := range tags[0].Alternate {
  136. values = append(values, []string{alt})
  137. }
  138. for i, value := range values {
  139. for _, rest := range fn(tags[1:]) {
  140. value = append(value, rest...)
  141. }
  142. values[i] = value
  143. }
  144. return values
  145. }
  146. names := fn(tagsCopy)
  147. for _, name := range names {
  148. if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
  149. slog.Debug("found tensor", "", tensor)
  150. vv.Set(reflect.ValueOf(tensor))
  151. break
  152. }
  153. }
  154. } else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
  155. setPointer(base, vv, tagsCopy)
  156. } else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
  157. for i := range vv.Len() {
  158. vvv := vv.Index(i)
  159. if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
  160. setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)}))
  161. } else {
  162. vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
  163. }
  164. }
  165. }
  166. if !canNil(tt) || !vv.IsNil() {
  167. allNil = false
  168. }
  169. }
  170. if allNil {
  171. return reflect.Zero(t)
  172. }
  173. }
  174. return v
  175. }
  176. func setPointer(base Base, v reflect.Value, tags []Tag) {
  177. vv := v
  178. if v.Kind() == reflect.Interface {
  179. if v.IsNil() {
  180. return
  181. }
  182. vv = vv.Elem()
  183. }
  184. vv = vv.Elem()
  185. if v.IsNil() {
  186. vv = reflect.New(v.Type().Elem()).Elem()
  187. }
  188. if f := populateFields(base, vv, tags...); f.CanAddr() {
  189. v.Set(f.Addr())
  190. }
  191. }
  // Tag is one parsed `gguf` struct tag: the name of a segment in a tensor's
// dotted path, plus any alternate names declared with "alt:".
type Tag struct {
	Name      string
	Alternate []string
}

// ParseTags parses a gguf struct tag of the form "name,alt:other,...".
//
// The first comma-separated element becomes Name. Every later element that
// starts with "alt:" contributes the text after that prefix to Alternate, in
// order. Later elements without the prefix are ignored. An empty string
// yields the zero Tag.
func ParseTags(s string) (tag Tag) {
	name, rest, more := strings.Cut(s, ",")
	tag.Name = name
	for more {
		var part string
		part, rest, more = strings.Cut(rest, ",")
		if alt, ok := strings.CutPrefix(part, "alt:"); ok {
			tag.Alternate = append(tag.Alternate, alt)
		}
	}
	return tag
}
  // canNil reports whether t is a channel, function, interface, map, pointer
// or slice type. populateFields uses it so that reflect.Value.IsNil is only
// called on fields of those kinds.
func canNil(t reflect.Type) bool {
	switch t.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
		return true
	default:
		return false
	}
}
  216. func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
  217. if len(opts.Positions) != len(opts.Sequences) {
  218. return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
  219. }
  220. if len(opts.Positions) < 1 {
  221. return nil, errors.New("batch size cannot be less than 1")
  222. }
  223. cache := m.Config().Cache
  224. if cache != nil {
  225. err := cache.StartForward(ctx, opts.Positions, opts.Sequences)
  226. if err != nil {
  227. return nil, err
  228. }
  229. }
  230. t, err := m.Forward(ctx, opts)
  231. if err != nil {
  232. return nil, err
  233. }
  234. ctx.Forward(t).Compute(t)
  235. return t, nil
  236. }