backend.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. package ml
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "fmt"
  6. "os"
  7. "strconv"
  8. "strings"
  9. )
  // Config gives typed access to a model's configuration. It exposes only
  // lookups; nothing in the interface modifies the configuration.
  //
  // Each lookup takes a string key (presumably a configuration key name) and
  // an optional variadic value of the result type, which presumably acts as a
  // fallback when the key is missing — NOTE(review): confirm against the
  // implementations.
  type Config interface {
  	// Architecture reports the model's architecture name.
  	Architecture() string

  	// Scalar lookups.
  	String(key string, defaultValue ...string) string
  	Uint(key string, defaultValue ...uint32) uint32
  	Float(key string, defaultValue ...float32) float32
  	Bool(key string, defaultValue ...bool) bool

  	// Slice lookups.
  	Strings(key string, defaultValue ...[]string) []string
  	Uints(key string, defaultValue ...[]uint32) []uint32
  }
  19. type Backend interface {
  20. Config() Config
  21. Get(name string) Tensor
  22. NewContext() Context
  23. SystemInfo() string
  24. }
  // BackendParams holds the tuning knobs passed to a backend constructor,
  // controlling how a model is loaded and executed.
  type BackendParams struct {
  	// NumThreads is the CPU thread count used when computation runs on the CPU.
  	NumThreads int

  	// MainGPU selects, by index, which GPU acts as the primary device.
  	MainGPU int

  	// NumGPULayers is how many of the model's layers are placed on GPUs.
  	NumGPULayers int

  	// TensorSplit gives, per GPU, the fraction of the model assigned to it.
  	TensorSplit []float32
  }
  36. var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
  37. func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
  38. if _, ok := backends[name]; ok {
  39. panic("backend: backend already registered")
  40. }
  41. backends[name] = f
  42. }
  43. func NewBackend(f *os.File, params BackendParams) (Backend, error) {
  44. if backend, ok := backends["ggml"]; ok {
  45. return backend(f, params)
  46. }
  47. return nil, fmt.Errorf("unsupported backend")
  48. }
  49. type Context interface {
  50. Zeros(dtype DType, shape ...int) Tensor
  51. FromFloatSlice(s []float32, shape ...int) (Tensor, error)
  52. FromIntSlice(s []int32, shape ...int) (Tensor, error)
  53. Forward(Tensor)
  54. Compute(...Tensor)
  55. MaxTensors() int
  56. Close()
  57. }
  58. type Tensor interface {
  59. Dim(n int) int
  60. Stride(n int) int
  61. Shape() []int
  62. DType() DType
  63. Bytes() []byte
  64. Floats() []float32
  65. Add(ctx Context, t2 Tensor) Tensor
  66. Mul(ctx Context, t2 Tensor) Tensor
  67. Mulmat(ctx Context, t2 Tensor) Tensor
  68. MulmatFullPrec(ctx Context, t2 Tensor) Tensor
  69. Softmax(ctx Context) Tensor
  70. LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
  71. RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
  72. Scale(ctx Context, s float64) Tensor
  73. Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
  74. RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor
  75. Tanh(ctx Context) Tensor
  76. GELU(ctx Context) Tensor
  77. SILU(ctx Context) Tensor
  78. Reshape(ctx Context, shape ...int) Tensor
  79. View(ctx Context, offset int, shape ...int) Tensor
  80. Permute(ctx Context, shape ...int) Tensor
  81. Contiguous(ctx Context) Tensor
  82. Pad(ctx Context, shape ...int) Tensor
  83. Unpad(ctx Context, shape ...int) Tensor
  84. Stack(ctx Context, dim int, s ...Tensor) Tensor
  85. Concat(ctx Context, t2 Tensor, dim int) Tensor
  86. Rows(ctx Context, t2 Tensor) Tensor
  87. Copy(ctx Context, t2 Tensor) Tensor
  88. }
  89. // ScaledDotProductAttention implements a fused attention
  90. // operation equivalent to following code on a tensor named
  91. // query:
  92. //
  93. // kq := key.MulmatFullPrec(ctx, query)
  94. //
  95. // kq = kq.Scale(ctx, scale)
  96. //
  97. // if mask != nil {
  98. // kq = kq.Add(ctx, mask)
  99. // }
  100. //
  101. // kq = kq.Softmax(ctx)
  102. //
  103. // kqv := value.Mulmat(ctx, kq)
  104. // return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  105. type ScaledDotProductAttention interface {
  106. ScaledDotProductAttention(ctx Context, key, value, mask Tensor, scale float64) Tensor
  107. }
  // number is the set of built-in numeric types, and types whose underlying
  // type is one of them, all of which support multiplication.
  type number interface {
  	~int | ~int8 | ~int16 | ~int32 | ~int64 |
  		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
  		~float32 | ~float64 |
  		~complex64 | ~complex128
  }

  // mul multiplies factors together from left to right and returns the
  // product; with no arguments it returns 1. In this package it turns a
  // tensor shape into an element count.
  func mul[T number](factors ...T) T {
  	product := T(1)
  	for i := range factors {
  		product *= factors[i]
  	}
  	return product
  }
  // DumpOptions adjusts the text produced by Dump.
  type DumpOptions struct {
  	// Items caps how many entries are shown at the start and at the end of
  	// every dimension; the entries in between are elided with "...".
  	Items int

  	// Precision is the number of digits printed after the decimal point for
  	// floating-point elements. It has no effect on integer tensors.
  	Precision int
  }
  127. func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
  128. if len(opts) < 1 {
  129. opts = append(opts, DumpOptions{
  130. Items: 3,
  131. Precision: 4,
  132. })
  133. }
  134. switch t.DType() {
  135. case DTypeF32:
  136. return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
  137. return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
  138. })
  139. case DTypeF16:
  140. f32 := ctx.Zeros(DTypeF32, t.Shape()...)
  141. f32 = t.Copy(ctx, f32)
  142. return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
  143. return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
  144. })
  145. case DTypeI32:
  146. return dump[[]int32](ctx, t, opts[0].Items, func(i int32) string {
  147. return strconv.FormatInt(int64(i), 10)
  148. })
  149. default:
  150. return "<unsupported>"
  151. }
  152. }
  153. func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
  154. if t.Bytes() == nil {
  155. ctx.Forward(t)
  156. ctx.Compute(t)
  157. }
  158. s := make(S, mul(t.Shape()...))
  159. if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
  160. panic(err)
  161. }
  162. shape := t.Shape()
  163. var sb strings.Builder
  164. var f func([]int, int)
  165. f = func(dims []int, stride int) {
  166. prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
  167. fmt.Fprint(&sb, "[")
  168. defer func() { fmt.Fprint(&sb, "]") }()
  169. for i := 0; i < dims[0]; i++ {
  170. if i >= items && i < dims[0]-items {
  171. fmt.Fprint(&sb, "..., ")
  172. // skip to next printable element
  173. skip := dims[0] - 2*items
  174. if len(dims) > 1 {
  175. stride += mul(append(dims[1:], skip)...)
  176. fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
  177. }
  178. i += skip - 1
  179. } else if len(dims) > 1 {
  180. f(dims[1:], stride)
  181. stride += mul(dims[1:]...)
  182. if i < dims[0]-1 {
  183. fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
  184. }
  185. } else {
  186. fmt.Fprint(&sb, fn(s[stride+i]))
  187. if i < dims[0]-1 {
  188. fmt.Fprint(&sb, ", ")
  189. }
  190. }
  191. }
  192. }
  193. f(shape, 0)
  194. return sb.String()
  195. }
  // DType identifies the element type of a Tensor.
  type DType int

  // Element types. The zero value, DTypeOther, stands for any type that has
  // no dedicated constant; Dump reports such tensors as "<unsupported>".
  const (
  	DTypeOther DType = 0 // any element type not listed below
  	DTypeF32   DType = 1 // float32 elements
  	DTypeF16   DType = 2 // half-precision float elements
  	DTypeI32   DType = 3 // int32 elements
  )