model.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. package model
  2. import (
  3. "errors"
  4. "fmt"
  5. "image"
  6. _ "image/jpeg"
  7. _ "image/png"
  8. "log/slog"
  9. "os"
  10. "reflect"
  11. "strconv"
  12. "strings"
  13. _ "golang.org/x/image/bmp"
  14. _ "golang.org/x/image/tiff"
  15. _ "golang.org/x/image/webp"
  16. "github.com/ollama/ollama/kvcache"
  17. "github.com/ollama/ollama/ml"
  18. _ "github.com/ollama/ollama/ml/backend"
  19. )
  // Options bundles the inputs handed to a model for a single forward pass.
  type Options struct {
  	// Inputs is the input batch.
  	// NOTE(review): presumably token IDs — confirm against the model implementations.
  	Inputs []int32

  	// Positions and Sequences are handed to the cache's StartForward by
  	// Forward, which requires them to be non-empty and of equal length.
  	Positions []int32
  	Sequences []int

  	// Outputs is passed through to the model untouched by this package.
  	// NOTE(review): presumably selects which results to produce — confirm.
  	Outputs []int32

  	// Images holds decoded images supplied alongside the inputs, if any.
  	Images []image.Image
  }
  28. type config struct {
  29. Cache kvcache.Cache
  30. }
  31. // Base implements the common fields and methods for all models
  32. type Base struct {
  33. b ml.Backend
  34. config
  35. }
  36. // Backend returns the underlying backend that will run the model
  37. func (m *Base) Backend() ml.Backend {
  38. return m.b
  39. }
  40. func (m *Base) Config() config {
  41. return m.config
  42. }
  43. // Model implements a specific model architecture, defining the forward pass and any model-specific configuration
  44. type Model interface {
  45. Forward(ml.Context, Options) (ml.Tensor, error)
  46. Backend() ml.Backend
  47. Config() config
  48. }
  49. var models = make(map[string]func(ml.Config) (Model, error))
  50. // Register registers a model constructor for the given architecture
  51. func Register(name string, f func(ml.Config) (Model, error)) {
  52. if _, ok := models[name]; ok {
  53. panic("model: model already registered")
  54. }
  55. models[name] = f
  56. }
  57. // New initializes a new model instance with the provided configuration based on the metadata in the model file
  58. func New(modelPath string) (Model, error) {
  59. r, err := os.Open(modelPath)
  60. if err != nil {
  61. return nil, err
  62. }
  63. defer r.Close()
  64. b, err := ml.NewBackend(r)
  65. if err != nil {
  66. return nil, err
  67. }
  68. arch := b.Config().Architecture()
  69. f, ok := models[arch]
  70. if !ok {
  71. return nil, fmt.Errorf("unsupported model architecture %q", arch)
  72. }
  73. m, err := f(b.Config())
  74. if err != nil {
  75. return nil, err
  76. }
  77. base := Base{b: b, config: m.Config()}
  78. v := reflect.ValueOf(m)
  79. v.Elem().Set(populateFields(base, v.Elem()))
  80. return m, nil
  81. }
  82. func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
  83. t := v.Type()
  84. if t.Kind() == reflect.Struct {
  85. allNil := true
  86. for i := range t.NumField() {
  87. tt := t.Field(i).Type
  88. vv := v.Field(i)
  89. if !vv.CanSet() {
  90. continue
  91. }
  92. // make a copy
  93. tagsCopy := tags
  94. if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
  95. tagsCopy = append(tagsCopy, ParseTags(tag))
  96. }
  97. if tt == reflect.TypeOf((*Base)(nil)).Elem() {
  98. vv.Set(reflect.ValueOf(base))
  99. } else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
  100. var fn func([]Tag) [][]string
  101. fn = func(tags []Tag) (values [][]string) {
  102. if len(tags) < 1 {
  103. return nil
  104. }
  105. values = [][]string{{tags[0].Name}}
  106. for _, alt := range tags[0].Alternate {
  107. values = append(values, []string{alt})
  108. }
  109. for i, value := range values {
  110. for _, rest := range fn(tags[1:]) {
  111. value = append(value, rest...)
  112. }
  113. values[i] = value
  114. }
  115. return values
  116. }
  117. names := fn(tagsCopy)
  118. for _, name := range names {
  119. if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
  120. slog.Debug("found tensor", "", tensor)
  121. vv.Set(reflect.ValueOf(tensor))
  122. break
  123. }
  124. }
  125. } else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
  126. setPointer(base, vv, tagsCopy)
  127. } else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
  128. for i := range vv.Len() {
  129. vvv := vv.Index(i)
  130. if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
  131. setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)}))
  132. } else {
  133. vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
  134. }
  135. }
  136. }
  137. if !canNil(tt) || !vv.IsNil() {
  138. allNil = false
  139. }
  140. }
  141. if allNil {
  142. return reflect.Zero(t)
  143. }
  144. }
  145. return v
  146. }
  147. func setPointer(base Base, v reflect.Value, tags []Tag) {
  148. vv := v
  149. if v.Kind() == reflect.Interface {
  150. if v.IsNil() {
  151. return
  152. }
  153. vv = vv.Elem()
  154. }
  155. vv = vv.Elem()
  156. if v.IsNil() {
  157. vv = reflect.New(v.Type().Elem()).Elem()
  158. }
  159. if f := populateFields(base, vv, tags...); f.CanAddr() {
  160. v.Set(f.Addr())
  161. }
  162. }
  // Tag is one segment of a tensor name path, as declared in a `gguf` struct
  // tag.
  type Tag struct {
  	// Name is the primary name of the segment.
  	Name string

  	// Alternate lists other names accepted in place of Name.
  	Alternate []string
  }

  // ParseTags parses the comma-separated contents of a `gguf` struct tag.
  // The first element becomes the tag's Name; every later element of the form
  // "alt:<value>" contributes <value> to Alternate, and any other element is
  // ignored.
  func ParseTags(s string) (tag Tag) {
  	// Split with a non-empty separator always yields at least one element,
  	// so fields[0] is safe even for an empty s.
  	fields := strings.Split(s, ",")
  	tag.Name = fields[0]

  	for _, field := range fields[1:] {
  		if !strings.HasPrefix(field, "alt:") {
  			continue
  		}
  		tag.Alternate = append(tag.Alternate, strings.TrimPrefix(field, "alt:"))
  	}
  	return tag
  }
  // canNil reports whether values of type t can be nil, i.e. whether t is a
  // channel, function, interface, map, pointer or slice type.
  func canNil(t reflect.Type) bool {
  	switch t.Kind() {
  	case reflect.Chan,
  		reflect.Func,
  		reflect.Interface,
  		reflect.Map,
  		reflect.Pointer,
  		reflect.Slice:
  		return true
  	default:
  		return false
  	}
  }
  187. func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
  188. if len(opts.Positions) != len(opts.Sequences) {
  189. return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
  190. }
  191. if len(opts.Positions) < 1 {
  192. return nil, errors.New("batch size cannot be less than 1")
  193. }
  194. cache := m.Config().Cache
  195. if cache != nil {
  196. err := cache.StartForward(ctx, opts.Positions, opts.Sequences)
  197. if err != nil {
  198. return nil, err
  199. }
  200. }
  201. t, err := m.Forward(ctx, opts)
  202. if err != nil {
  203. return nil, err
  204. }
  205. ctx.Forward(t)
  206. ctx.Compute(t)
  207. return t, nil
  208. }