// backend.go
  1. package ml
import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"os"
	"slices"
	"strconv"
	"strings"
)
// Config provides typed access to model metadata. Each accessor takes a key
// and an optional fallback value; the variadic fallback is presumably
// returned when the key is absent — confirm with implementations.
type Config interface {
	// Architecture returns the model architecture identifier.
	Architecture() string

	// Scalar accessors.
	String(string, ...string) string
	Uint(string, ...uint32) uint32
	Float(string, ...float32) float32
	Bool(string, ...bool) bool

	// Slice accessors.
	Strings(string, ...[]string) []string
	Uints(string, ...[]uint32) []uint32
	Floats(string, ...[]float32) []float32
}
// Backend is a compute backend that holds a loaded model's weights and
// creates contexts in which tensor graphs are built and computed.
type Backend interface {
	// Config returns the model's metadata.
	Config() Config

	// Get returns the tensor with the given name.
	// NOTE(review): behavior for a missing name (nil vs panic) is
	// backend-defined — confirm with implementations.
	Get(name string) Tensor

	// NewContext creates a new compute context.
	NewContext() Context

	// NewContextSize creates a new compute context with the given size.
	// NOTE(review): size presumably bounds graph nodes (cf. Context.MaxGraphNodes)
	// — confirm with implementations.
	NewContextSize(size int) Context
}
// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
	// CacheConfig reports the cache layout requirements of this backend.
	CacheConfig() CacheConfig
}
// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels.
type CacheConfig struct {
	// CachePadding specifies the multiple for the number of tokens of cache history
	// that will be returned from cache Get for k, v and mask. The capacity of the
	// cache itself will also be increased to a multiple of this size if needed.
	CachePadding int

	// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
	// and returns the permuted version via Get. This uses the cache copy operation
	// to avoid a Contiguous call on the permuted tensor.
	PermutedV bool

	// MaskDType specifies the data type for generating the mask. If unset it will
	// default to DTypeF32.
	MaskDType DType

	// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
	// Any position that does not correspond to an actual token will be filled with -Inf.
	MaskBatchPadding int
}
// BackendParams controls how the backend loads and executes models.
type BackendParams struct {
	// Progress is a callback function that allows reporting percentage completion
	// of model loading.
	Progress func(float32)

	// NumThreads sets the number of threads to use if running on the CPU.
	NumThreads int

	// MainGPU is the index of the primary GPU to use.
	MainGPU int

	// NumGPULayers is the number of layers to offload to GPUs.
	NumGPULayers int

	// TensorSplit is the fraction of the model to offload to each GPU.
	TensorSplit []float32

	// FlashAttention indicates that we should use a fused flash attention kernel.
	FlashAttention bool
}
  68. var backends = make(map[string]func(context.Context, *os.File, BackendParams) (Backend, error))
  69. func RegisterBackend(name string, f func(context.Context, *os.File, BackendParams) (Backend, error)) {
  70. if _, ok := backends[name]; ok {
  71. panic("backend: backend already registered")
  72. }
  73. backends[name] = f
  74. }
  75. func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, error) {
  76. if backend, ok := backends["ggml"]; ok {
  77. return backend(ctx, f, params)
  78. }
  79. return nil, fmt.Errorf("unsupported backend")
  80. }
// Context is a scope in which tensors are allocated and tensor graphs are
// built and computed.
type Context interface {
	// Empty allocates a tensor of the given type and shape without
	// initializing its data (in contrast to Zeros).
	Empty(dtype DType, shape ...int) Tensor

	// Zeros allocates a zero-filled tensor of the given type and shape.
	Zeros(dtype DType, shape ...int) Tensor

	// FromFloatSlice and FromIntSlice create tensors initialized from the
	// given flat data, interpreted with the given shape.
	FromFloatSlice(s []float32, shape ...int) (Tensor, error)
	FromIntSlice(s []int32, shape ...int) (Tensor, error)

	// Forward adds tensors to the graph to be computed; Compute executes
	// the graph and materializes the given tensors' data.
	Forward(...Tensor) Context
	Compute(...Tensor)

	// MaxGraphNodes reports the maximum number of nodes this context's
	// graph can hold.
	MaxGraphNodes() int

	// Close releases the resources associated with the context.
	Close()

	// Input returns a context appropriate for creating input tensors
	Input() Context

	// Output returns a context appropriate for creating output tensors
	Output() Context

	// Layer returns a context appropriate for creating intermediate tensors
	Layer(int) Context
}
// Tensor is an n-dimensional array managed by a backend. Operations take the
// Context in which the result should be created and return a new Tensor.
type Tensor interface {
	// Dim returns the size of dimension n; Stride returns the stride of
	// dimension n. NOTE(review): stride units (elements vs bytes) are
	// backend-defined — confirm with implementations.
	Dim(n int) int
	Stride(n int) int

	Shape() []int
	DType() DType

	// Bytes returns the raw backing data; it may be nil before the tensor
	// has been computed (see dump, which calls Compute when Bytes is nil).
	Bytes() []byte
	Floats() []float32

	// Arithmetic and matrix operations.
	Add(ctx Context, t2 Tensor) Tensor
	Mul(ctx Context, t2 Tensor) Tensor
	Mulmat(ctx Context, t2 Tensor) Tensor
	// MulmatFullPrec is like Mulmat but avoids reduced-precision shortcuts
	// (used on the key-query product in attention — see the
	// ScaledDotProductAttention reference code below).
	MulmatFullPrec(ctx Context, t2 Tensor) Tensor

	// Normalization and scaling.
	Softmax(ctx Context) Tensor
	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
	RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
	Scale(ctx Context, s float64) Tensor

	// Pooling and convolution.
	AvgPool2D(ctx Context, k, s int, p float32) Tensor
	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

	// RoPE applies rotary position embeddings using the given position IDs
	// and frequency factors.
	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor

	// Activation functions.
	Tanh(ctx Context) Tensor
	GELU(ctx Context) Tensor
	SILU(ctx Context) Tensor

	// Shape and layout manipulation.
	Reshape(ctx Context, shape ...int) Tensor
	View(ctx Context, offset int, shape ...int) Tensor
	Permute(ctx Context, shape ...int) Tensor
	Contiguous(ctx Context) Tensor
	Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor

	Pad(ctx Context, shape ...int) Tensor
	Unpad(ctx Context, shape ...int) Tensor

	// Combining and selection.
	Stack(ctx Context, dim int, s ...Tensor) Tensor
	Concat(ctx Context, t2 Tensor, dim int) Tensor
	Rows(ctx Context, t2 Tensor) Tensor
	Copy(ctx Context, t2 Tensor) Tensor
}
// ScaledDotProductAttention implements a fused attention
// operation equivalent to following code on a tensor named
// query:
//
// query = query.Permute(ctx, 0, 2, 1, 3)
// key = key.Permute(ctx, 0, 2, 1, 3)
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
// kq := key.MulmatFullPrec(ctx, query)
//
// kq = kq.Scale(ctx, scale)
//
// if mask != nil {
// kq = kq.Add(ctx, mask)
// }
//
// kq = kq.Softmax(ctx)
//
// kqv := value.Mulmat(ctx, kq)
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
type ScaledDotProductAttention interface {
	// ScaledDotProductAttention is called on the query tensor (per the
	// reference code above) with the key, value and optional mask.
	ScaledDotProductAttention(ctx Context, key, value, mask Tensor, scale float64) Tensor
}
  153. type number interface {
  154. ~int | ~int8 | ~int16 | ~int32 | ~int64 |
  155. ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
  156. ~float32 | ~float64 |
  157. ~complex64 | ~complex128
  158. }
  159. func mul[T number](s ...T) T {
  160. p := T(1)
  161. for _, v := range s {
  162. p *= v
  163. }
  164. return p
  165. }
// DumpOptions controls how Dump formats a tensor.
type DumpOptions struct {
	// Items is the number of elements to print at the beginning and end of each dimension.
	Items int

	// Precision is the number of decimal places to print. Applies to float32 and float64.
	Precision int
}
  172. func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
  173. if len(opts) < 1 {
  174. opts = append(opts, DumpOptions{
  175. Items: 3,
  176. Precision: 4,
  177. })
  178. }
  179. switch t.DType() {
  180. case DTypeF32:
  181. return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
  182. return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
  183. })
  184. case DTypeF16, DTypeQ80, DTypeQ40:
  185. f32 := ctx.Empty(DTypeF32, t.Shape()...)
  186. f32 = t.Copy(ctx, f32)
  187. return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
  188. return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
  189. })
  190. case DTypeI32:
  191. return dump[[]int32](ctx, t, opts[0].Items, func(i int32) string {
  192. return strconv.FormatInt(int64(i), 10)
  193. })
  194. default:
  195. return "<unsupported>"
  196. }
  197. }
// dump decodes the raw bytes of t as a flat slice of E and formats the
// values as a nested, bracketed list. At most items elements are printed at
// the start and end of each dimension; the middle is elided with "...".
// fn converts a single element to its string form.
func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
	// Make sure the tensor has been computed so its bytes are available.
	if t.Bytes() == nil {
		ctx.Forward(t).Compute(t)
	}

	// Decode the little-endian bytes into a flat slice of len == product of dims.
	s := make(S, mul(t.Shape()...))
	if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
		panic(err)
	}

	// Reverse the shape so the recursive printer below treats shape[0] as
	// the outermost dimension while the flat offset math still matches the
	// ordering reported by Shape().
	shape := t.Shape()
	slices.Reverse(shape)

	var sb strings.Builder
	var f func([]int, int)
	// f prints one dimension: dims is the remaining (outer-to-inner) dims
	// and stride is the flat offset of the first element in this span.
	f = func(dims []int, stride int) {
		// One extra space of indent per nesting level for elided rows.
		prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
		sb.WriteString("[")
		defer func() { sb.WriteString("]") }()
		for i := 0; i < dims[0]; i++ {
			if i >= items && i < dims[0]-items {
				sb.WriteString("..., ")
				// skip to next printable element
				skip := dims[0] - 2*items
				if len(dims) > 1 {
					// Advance the flat offset past all elided sub-blocks.
					stride += mul(append(dims[1:], skip)...)
					fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
				}
				i += skip - 1
			} else if len(dims) > 1 {
				// Recurse into the inner dimensions, then advance the flat
				// offset past the elements that recursion consumed.
				f(dims[1:], stride)
				stride += mul(dims[1:]...)
				if i < dims[0]-1 {
					fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
				}
			} else {
				// Innermost dimension: pad non-negative values with a
				// space so columns align with values carrying a minus sign.
				text := fn(s[stride+i])
				if len(text) > 0 && text[0] != '-' {
					sb.WriteString(" ")
				}
				sb.WriteString(text)
				if i < dims[0]-1 {
					sb.WriteString(", ")
				}
			}
		}
	}
	f(shape, 0)

	return sb.String()
}
// DType identifies the element data type of a Tensor.
type DType int

const (
	// DTypeOther is any data type not otherwise listed here.
	DTypeOther DType = iota
	// DTypeF32 is 32-bit floating point.
	DTypeF32
	// DTypeF16 is 16-bit floating point.
	DTypeF16
	// DTypeQ80 is 8-bit quantized (presumably ggml Q8_0 — confirm with backend).
	DTypeQ80
	// DTypeQ40 is 4-bit quantized (presumably ggml Q4_0 — confirm with backend).
	DTypeQ40
	// DTypeI32 is 32-bit signed integer.
	DTypeI32
)