backend.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. package ml
  2. import (
  3. "bytes"
  4. "cmp"
  5. "encoding/binary"
  6. "fmt"
  7. "os"
  8. "strconv"
  9. "strings"
  10. )
  // Config is a read-only accessor for named configuration values.
// NOTE(review): each typed getter takes a key followed by variadic values
// that look like optional defaults; how implementations use them (first
// element only? returned when the key is absent?) is not visible here —
// confirm against the backend implementations.
type Config interface {
	// Architecture returns the architecture name reported by the config.
	Architecture() string

	// Scalar lookups.
	String(key string, defaultValue ...string) string
	Uint(key string, defaultValue ...uint32) uint32
	Float(key string, defaultValue ...float32) float32

	// Slice lookups.
	Strings(key string, defaultValue ...[]string) []string
	Uints(key string, defaultValue ...[]uint32) []uint32
}
  19. type Backend interface {
  20. Config() Config
  21. Get(name string) Tensor
  22. NewContext() Context
  23. SystemInfo() string
  24. }
  25. var backends = make(map[string]func(*os.File) (Backend, error))
  26. func RegisterBackend(name string, f func(*os.File) (Backend, error)) {
  27. if _, ok := backends[name]; ok {
  28. panic("backend: backend already registered")
  29. }
  30. backends[name] = f
  31. }
  32. func NewBackend(f *os.File) (Backend, error) {
  33. if backend, ok := backends[cmp.Or(os.Getenv("OLLAMA_BACKEND"), "ggml")]; ok {
  34. return backend(f)
  35. }
  36. return nil, fmt.Errorf("unsupported backend")
  37. }
  38. type Context interface {
  39. Zeros(dtype DType, shape ...int) Tensor
  40. FromFloatSlice(s []float32, shape ...int) (Tensor, error)
  41. FromIntSlice(s []int32, shape ...int) (Tensor, error)
  42. Forward(Tensor)
  43. Compute(...Tensor)
  44. MaxTensors() int
  45. Close()
  46. Timing() []OpTiming
  47. }
  // OpType classifies an operation performed during a forward pass.
type OpType string

// Known OpType values. Each value equals its constant's name, except
// ComputeOp, whose value is "Compute".
const (
	ComputeOp  OpType = "Compute"
	Contiguous OpType = "Contiguous"
	Copy       OpType = "Copy"
	Input      OpType = "Input"
	Permute    OpType = "Permute"
	Reshape    OpType = "Reshape"
	Transpose  OpType = "Transpose"
	View       OpType = "View"
)

// OpTiming is one timing record for a single operation; Context.Timing
// returns a slice of them. Field order is significant: keep it stable for
// positional composite literals.
type OpTiming struct {
	// Type classifies the operation.
	Type OpType

	// Operation names or describes the operation.
	// NOTE(review): presumably the backend's op/node name — confirm.
	Operation string

	// Duration is the measured time for the operation.
	// NOTE(review): the unit is not visible here — confirm with the producer.
	Duration float64

	// Order presumably records the operation's position in execution order —
	// confirm with the producer.
	Order int
}
  67. type Tensor interface {
  68. Dim(n int) int
  69. Stride(n int) int
  70. Shape() []int
  71. DType() DType
  72. Bytes() []byte
  73. Floats() []float32
  74. Add(ctx Context, t2 Tensor) Tensor
  75. Mul(ctx Context, t2 Tensor) Tensor
  76. Mulmat(ctx Context, t2 Tensor) Tensor
  77. MulmatFullPrec(ctx Context, t2 Tensor) Tensor
  78. Softmax(ctx Context) Tensor
  79. LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
  80. RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
  81. Scale(ctx Context, s float64) Tensor
  82. Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
  83. RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor
  84. Tanh(ctx Context) Tensor
  85. GELU(ctx Context) Tensor
  86. SILU(ctx Context) Tensor
  87. Reshape(ctx Context, shape ...int) Tensor
  88. View(ctx Context, offset int, shape ...int) Tensor
  89. Permute(ctx Context, shape ...int) Tensor
  90. Contiguous(ctx Context) Tensor
  91. Pad(ctx Context, shape ...int) Tensor
  92. Unpad(ctx Context, shape ...int) Tensor
  93. Stack(ctx Context, dim int, s ...Tensor) Tensor
  94. Concat(ctx Context, t2 Tensor, dim int) Tensor
  95. Rows(ctx Context, t2 Tensor) Tensor
  96. Copy(ctx Context, t2 Tensor) Tensor
  97. }
  // number is the set of built-in numeric types, and types derived from them,
// that support multiplication. It constrains mul and dump.
type number interface {
	~int | ~int8 | ~int16 | ~int32 | ~int64 |
		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
		~float32 | ~float64 | ~complex64 | ~complex128
}

// mul returns the product of every value in s. An empty s yields 1, the
// multiplicative identity. dump uses it to turn a shape into an element
// count. Integer products wrap on overflow like ordinary Go arithmetic.
func mul[T number](s ...T) T {
	product := T(1)
	for i := range s {
		product *= s[i]
	}
	return product
}
  // DumpOptions controls how Dump renders a tensor. When no options are
// passed, Dump uses Items: 3 and Precision: 4. An explicitly passed zero
// value is used as-is.
type DumpOptions struct {
	// Items is the number of elements to print at the beginning and at the
	// end of each dimension. Elements in between are elided with "...".
	Items int

	// Precision is the number of digits printed after the decimal point for
	// floating-point tensors: F32, and F16, which Dump converts to F32 first.
	// It has no effect on integer tensors.
	Precision int
}
  117. func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
  118. if len(opts) < 1 {
  119. opts = append(opts, DumpOptions{
  120. Items: 3,
  121. Precision: 4,
  122. })
  123. }
  124. switch t.DType() {
  125. case DTypeF32:
  126. return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
  127. return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
  128. })
  129. case DTypeF16:
  130. f32 := ctx.Zeros(DTypeF32, t.Shape()...)
  131. f32 = t.Copy(ctx, f32)
  132. return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
  133. return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
  134. })
  135. case DTypeI32:
  136. return dump[[]int32](ctx, t, opts[0].Items, func(i int32) string {
  137. return strconv.FormatInt(int64(i), 10)
  138. })
  139. default:
  140. return "<unsupported>"
  141. }
  142. }
  143. func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
  144. if t.Bytes() == nil {
  145. ctx.Forward(t)
  146. ctx.Compute(t)
  147. }
  148. s := make(S, mul(t.Shape()...))
  149. if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
  150. panic(err)
  151. }
  152. shape := t.Shape()
  153. var sb strings.Builder
  154. var f func([]int, int)
  155. f = func(dims []int, stride int) {
  156. prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
  157. fmt.Fprint(&sb, "[")
  158. defer func() { fmt.Fprint(&sb, "]") }()
  159. for i := 0; i < dims[0]; i++ {
  160. if i >= items && i < dims[0]-items {
  161. fmt.Fprint(&sb, "..., ")
  162. // skip to next printable element
  163. skip := dims[0] - 2*items
  164. if len(dims) > 1 {
  165. stride += mul(append(dims[1:], skip)...)
  166. fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
  167. }
  168. i += skip - 1
  169. } else if len(dims) > 1 {
  170. f(dims[1:], stride)
  171. stride += mul(dims[1:]...)
  172. if i < dims[0]-1 {
  173. fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
  174. }
  175. } else {
  176. fmt.Fprint(&sb, fn(s[stride+i]))
  177. if i < dims[0]-1 {
  178. fmt.Fprint(&sb, ", ")
  179. }
  180. }
  181. }
  182. }
  183. f(shape, 0)
  184. return sb.String()
  185. }
  // DType identifies the element type of a tensor, as reported by
// Tensor.DType. Values are assigned sequentially from zero.
type DType int

const (
	// DTypeOther is the zero value. Dump reports it, and any value not
	// listed below, as "<unsupported>".
	DTypeOther DType = iota

	// DTypeF32 tensors are decoded by Dump as float32 values.
	DTypeF32

	// DTypeF16 is presumably half precision. Dump converts it to DTypeF32
	// before printing.
	DTypeF16

	// DTypeI32 tensors are decoded by Dump as int32 values.
	DTypeI32
)