backend.go

package ml

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"os"
	"slices"
	"strconv"
	"strings"
)
// Config provides typed accessors for model configuration values, with an
// optional default returned when a key is not present.
type Config interface {
	Architecture() string

	String(string, ...string) string
	Uint(string, ...uint32) uint32
	Float(string, ...float32) float32
	Bool(string, ...bool) bool

	Strings(string, ...[]string) []string
	Uints(string, ...[]uint32) []uint32
	Floats(string, ...[]float32) []float32
}
// Backend is implemented by execution backends. It exposes the model
// configuration, access to weights by name, and context creation.
type Backend interface {
	Config() Config
	Get(name string) Tensor
	NewContext() Context
	NewContextSize(size int) Context
}
// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
	CacheConfig() CacheConfig
}

// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels.
type CacheConfig struct {
	// CachePadding specifies the multiple for the number of tokens of cache history
	// that will be returned from cache Get for k, v and mask. The capacity of the
	// cache itself will also be increased to a multiple of this size if needed.
	CachePadding int

	// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
	// and returns the permuted version via Get. This uses the cache copy operation
	// to avoid a Contiguous call on the permuted tensor.
	PermutedV bool

	// MaskDType specifies the data type for generating the mask. If unset it will
	// default to DTypeF32.
	MaskDType DType

	// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
	// Any position that does not correspond to an actual token will be filled with -Inf.
	MaskBatchPadding int
}
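// As an illustration only: a backend that wants padded cache output and
// pre-permuted v tensors might satisfy BackendCacheConfig as sketched below.
// The myBackend type and the specific padding values are hypothetical, not
// part of this package.
//
//	func (b *myBackend) CacheConfig() CacheConfig {
//		return CacheConfig{
//			CachePadding:     32,
//			PermutedV:        true,
//			MaskDType:        DTypeF16,
//			MaskBatchPadding: 32,
//		}
//	}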
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
	// Progress is a callback function that allows reporting percentage completion
	// of model loading
	Progress func(float32)

	// NumThreads sets the number of threads to use if running on the CPU
	NumThreads int

	// MainGPU is the index of the primary GPU to use
	MainGPU int

	// NumGPULayers is the number of layers to offload to GPUs
	NumGPULayers int

	// TensorSplit is the fraction of the model to offload to each GPU
	TensorSplit []float32

	// FlashAttention indicates that we should use a fused flash attention kernel
	FlashAttention bool
}
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))

// RegisterBackend registers a backend constructor under the given name,
// panicking if that name is already registered.
func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
	if _, ok := backends[name]; ok {
		panic("backend: backend already registered")
	}

	backends[name] = f
}
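// A minimal usage sketch: a backend implementation would typically register
// itself from an init function so NewBackend can look it up by name. The
// newGGMLBackend constructor is assumed here for illustration.
//
//	func init() {
//		RegisterBackend("ggml", func(f *os.File, params BackendParams) (Backend, error) {
//			return newGGMLBackend(f, params)
//		})
//	}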
// NewBackend constructs a Backend for the given model file using the
// registered "ggml" backend, if present.
func NewBackend(f *os.File, params BackendParams) (Backend, error) {
	if backend, ok := backends["ggml"]; ok {
		return backend(f, params)
	}

	return nil, fmt.Errorf("unsupported backend")
}
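// A hedged usage sketch: opening a model file and creating a backend. The
// file name and parameter values are illustrative only.
//
//	f, err := os.Open("model.gguf")
//	if err != nil {
//		panic(err)
//	}
//	defer f.Close()
//
//	b, err := NewBackend(f, BackendParams{
//		NumThreads:   8,
//		NumGPULayers: 32,
//		Progress:     func(p float32) { fmt.Printf("%.0f%%\n", p*100) },
//	})
//	if err != nil {
//		panic(err)
//	}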
// Context allocates tensors and builds and executes compute graphs.
type Context interface {
	Empty(dtype DType, shape ...int) Tensor
	Zeros(dtype DType, shape ...int) Tensor
	FromFloatSlice(s []float32, shape ...int) (Tensor, error)
	FromIntSlice(s []int32, shape ...int) (Tensor, error)

	Forward(...Tensor) Context
	Compute(...Tensor)
	MaxGraphNodes() int
	Close()

	// Input returns a context appropriate for creating input tensors
	Input() Context

	// Output returns a context appropriate for creating output tensors
	Output() Context

	// Layer returns a context appropriate for creating intermediate tensors
	Layer(int) Context
}
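// A minimal sketch of the build-then-compute flow, assuming a Backend b has
// already been created; the tensor values and shapes are arbitrary.
//
//	ctx := b.NewContext()
//	defer ctx.Close()
//
//	a, _ := ctx.Input().FromFloatSlice([]float32{1, 2, 3, 4}, 2, 2)
//	w, _ := ctx.Input().FromFloatSlice([]float32{5, 6, 7, 8}, 2, 2)
//
//	out := w.Mulmat(ctx, a)
//	ctx.Forward(out).Compute(out)
//	fmt.Println(out.Floats())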
// Tensor is an n-dimensional array of values and the operations defined on it.
type Tensor interface {
	Dim(n int) int
	Stride(n int) int

	Shape() []int
	DType() DType

	Bytes() []byte
	Floats() []float32

	Add(ctx Context, t2 Tensor) Tensor
	Mul(ctx Context, t2 Tensor) Tensor
	Mulmat(ctx Context, t2 Tensor) Tensor
	MulmatFullPrec(ctx Context, t2 Tensor) Tensor

	Softmax(ctx Context) Tensor
	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
	RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
	Scale(ctx Context, s float64) Tensor

	AvgPool2D(ctx Context, k, s int, p float32) Tensor
	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor

	Tanh(ctx Context) Tensor
	GELU(ctx Context) Tensor
	SILU(ctx Context) Tensor

	Reshape(ctx Context, shape ...int) Tensor
	View(ctx Context, offset int, shape ...int) Tensor
	Permute(ctx Context, shape ...int) Tensor
	Contiguous(ctx Context) Tensor
	Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor

	Pad(ctx Context, shape ...int) Tensor
	Unpad(ctx Context, shape ...int) Tensor

	Stack(ctx Context, dim int, s ...Tensor) Tensor
	Concat(ctx Context, t2 Tensor, dim int) Tensor
	Rows(ctx Context, t2 Tensor) Tensor
	Copy(ctx Context, t2 Tensor) Tensor
}
// ScaledDotProductAttention implements a fused attention
// operation equivalent to the following code on a tensor named
// query:
//
// query = query.Permute(ctx, 0, 2, 1, 3)
// key = key.Permute(ctx, 0, 2, 1, 3)
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
// kq := key.MulmatFullPrec(ctx, query)
//
// kq = kq.Scale(ctx, scale)
//
// if mask != nil {
//	kq = kq.Add(ctx, mask)
// }
//
// kq = kq.Softmax(ctx)
//
// kqv := value.Mulmat(ctx, kq)
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
type ScaledDotProductAttention interface {
	ScaledDotProductAttention(ctx Context, key, value, mask Tensor, scale float64) Tensor
}
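// A hedged sketch of how a caller might prefer the fused path when the query
// tensor's implementation provides it, and otherwise fall back to the unfused
// sequence documented above; the attention variable is assumed.
//
//	if sdpa, ok := query.(ScaledDotProductAttention); ok {
//		attention = sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
//	} else {
//		// unfused path as in the comment above
//	}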
// number is a constraint covering the numeric element types handled by dump.
type number interface {
	~int | ~int8 | ~int16 | ~int32 | ~int64 |
		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
		~float32 | ~float64 |
		~complex64 | ~complex128
}
// mul returns the product of its arguments; applied to a tensor shape this is
// the total number of elements.
func mul[T number](s ...T) T {
	p := T(1)
	for _, v := range s {
		p *= v
	}

	return p
}
// DumpOptions controls the formatting used by Dump.
type DumpOptions struct {
	// Items is the number of elements to print at the beginning and end of each dimension.
	Items int

	// Precision is the number of decimal places to print. Applies to float32 and float64.
	Precision int
}
// Dump returns a printable representation of t for debugging, converting f16
// and quantized tensors to f32 first. If no options are given, it prints 3
// items per dimension edge with 4 decimal places.
func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
	if len(opts) < 1 {
		opts = append(opts, DumpOptions{
			Items:     3,
			Precision: 4,
		})
	}

	switch t.DType() {
	case DTypeF32:
		return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
			return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
		})
	case DTypeF16, DTypeQ80, DTypeQ40:
		f32 := ctx.Empty(DTypeF32, t.Shape()...)
		f32 = t.Copy(ctx, f32)
		return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
			return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
		})
	case DTypeI32:
		return dump[[]int32](ctx, t, opts[0].Items, func(i int32) string {
			return strconv.FormatInt(int64(i), 10)
		})
	default:
		return "<unsupported>"
	}
}
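// A small usage sketch, assuming ctx is an existing Context; the values are
// arbitrary.
//
//	t, _ := ctx.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6}, 3, 2)
//	fmt.Println(Dump(ctx, t, DumpOptions{Items: 2, Precision: 2}))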
// dump reads t's raw bytes into a slice of E and formats the values
// recursively, one bracketed level per dimension, eliding interior elements
// beyond items with "...".
func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
	if t.Bytes() == nil {
		ctx.Forward(t).Compute(t)
	}

	s := make(S, mul(t.Shape()...))
	if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
		panic(err)
	}

	shape := t.Shape()
	slices.Reverse(shape)

	var sb strings.Builder
	var f func([]int, int)
	f = func(dims []int, stride int) {
		prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
		sb.WriteString("[")
		defer func() { sb.WriteString("]") }()
		for i := 0; i < dims[0]; i++ {
			if i >= items && i < dims[0]-items {
				sb.WriteString("..., ")
				// skip to next printable element
				skip := dims[0] - 2*items
				if len(dims) > 1 {
					stride += mul(append(dims[1:], skip)...)
					fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
				}
				i += skip - 1
			} else if len(dims) > 1 {
				f(dims[1:], stride)
				stride += mul(dims[1:]...)
				if i < dims[0]-1 {
					fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
				}
			} else {
				text := fn(s[stride+i])
				if len(text) > 0 && text[0] != '-' {
					sb.WriteString(" ")
				}

				sb.WriteString(text)
				if i < dims[0]-1 {
					sb.WriteString(", ")
				}
			}
		}
	}
	f(shape, 0)

	return sb.String()
}
// DType identifies a tensor's element data type.
type DType int

const (
	DTypeOther DType = iota
	DTypeF32
	DTypeF16
	DTypeQ80
	DTypeQ40
	DTypeI32
)