backend.go

package ml

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"os"
	"slices"
	"strconv"
	"strings"
)

// Config provides read access to the model's configuration: its architecture
// plus typed key/value hyperparameters. Each accessor takes a key name and an
// optional default that is returned when the key is not present.
type Config interface {
	Architecture() string

	String(string, ...string) string
	Uint(string, ...uint32) uint32
	Float(string, ...float32) float32
	Bool(string, ...bool) bool

	Strings(string, ...[]string) []string
	Uints(string, ...[]uint32) []uint32
	Floats(string, ...[]float32) []float32
}

// Backend is implemented by execution backends. It exposes the model
// configuration, looks up weight tensors by name, and creates compute
// contexts.
type Backend interface {
	Config() Config
	Get(name string) Tensor
	NewContext() Context
	NewContextSize(size int) Context
}

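// A minimal usage sketch, written from the point of view of a package that
// imports ml: open a model file, create a backend, and look up a weight by
// name. The file name, parameter values, and tensor name are illustrative.
//
//	f, err := os.Open("model.gguf")
//	if err != nil {
//		return err
//	}
//	defer f.Close()
//
//	b, err := ml.NewBackend(f, ml.BackendParams{NumThreads: 8})
//	if err != nil {
//		return err
//	}
//	embd := b.Get("token_embd.weight")
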
// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
	CacheConfig() CacheConfig
}

// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels.
type CacheConfig struct {
	// CachePadding specifies the multiple for the number of tokens of cache history
	// that will be returned from cache Get for k, v and mask. The capacity of the
	// cache itself will also be increased to a multiple of this size if needed.
	CachePadding int

	// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
	// and returns the permuted version via Get. This uses the cache copy operation
	// to avoid a Contiguous call on the permuted tensor.
	PermutedV bool

	// MaskDType specifies the data type for generating the mask. If unset it will
	// default to DTypeF32.
	MaskDType DType

	// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
	// Any position that does not correspond to an actual token will be filled with -Inf.
	MaskBatchPadding int
}

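// A minimal sketch of a backend advertising its cache requirements; the
// concrete values are illustrative, not requirements of this package.
//
//	func (b *myBackend) CacheConfig() ml.CacheConfig {
//		return ml.CacheConfig{
//			CachePadding:     32,          // return cache history in multiples of 32 tokens
//			PermutedV:        true,        // store v pre-permuted for the fused attention kernel
//			MaskDType:        ml.DTypeF16, // generate a half-precision mask
//			MaskBatchPadding: 32,          // pad the mask batch dimension; extra rows are filled with -Inf
//		}
//	}
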
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
	// NumThreads sets the number of threads to use if running on the CPU
	NumThreads int

	// MainGPU is the index of the primary GPU to use
	MainGPU int

	// NumGPULayers is the number of layers to offload to GPUs
	NumGPULayers int

	// TensorSplit is the fraction of the model to offload to each GPU
	TensorSplit []float32

	// FlashAttention indicates that we should use a fused flash attention kernel
	FlashAttention bool
}

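// A minimal sketch of typical parameters; the values (and the use of
// runtime.NumCPU) are illustrative only.
//
//	params := ml.BackendParams{
//		NumThreads:     runtime.NumCPU(),
//		NumGPULayers:   32,
//		TensorSplit:    []float32{0.5, 0.5}, // split offloaded layers evenly across two GPUs
//		FlashAttention: true,
//	}
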
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))

// RegisterBackend registers a constructor for a named backend. It panics if a
// backend with the same name has already been registered.
func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
	if _, ok := backends[name]; ok {
		panic("backend: backend already registered")
	}

	backends[name] = f
}

// NewBackend creates a Backend for the model in f using the registered "ggml"
// backend.
func NewBackend(f *os.File, params BackendParams) (Backend, error) {
	if backend, ok := backends["ggml"]; ok {
		return backend(f, params)
	}

	return nil, fmt.Errorf("unsupported backend")
}

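// Backends typically register themselves from an init function in their own
// package. A minimal sketch; newGGMLBackend is an illustrative constructor,
// while the "ggml" name matches what NewBackend looks up above.
//
//	func init() {
//		ml.RegisterBackend("ggml", func(f *os.File, params ml.BackendParams) (ml.Backend, error) {
//			return newGGMLBackend(f, params)
//		})
//	}
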
// Context creates tensors and builds and executes compute graphs over them.
type Context interface {
	Empty(dtype DType, shape ...int) Tensor
	Zeros(dtype DType, shape ...int) Tensor
	FromFloatSlice(s []float32, shape ...int) (Tensor, error)
	FromIntSlice(s []int32, shape ...int) (Tensor, error)

	Forward(...Tensor) Context
	Compute(...Tensor)
	MaxGraphNodes() int
	Close()

	// Input returns a context appropriate for creating input tensors
	Input() Context

	// Output returns a context appropriate for creating output tensors
	Output() Context

	// Layer returns a context appropriate for creating intermediate tensors
	Layer(int) Context
}

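// A minimal sketch of building and running a small graph; b is assumed to be
// a Backend, and the error from FromFloatSlice is ignored for brevity.
//
//	ctx := b.NewContext()
//	defer ctx.Close()
//
//	x, _ := ctx.FromFloatSlice([]float32{1, 2, 3, 4}, 2, 2)
//	y := x.Add(ctx, x).Softmax(ctx)
//	ctx.Forward(y).Compute(y)
//	fmt.Println(y.Floats())
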
// Tensor is an n-dimensional array of values together with the operations a
// backend can apply to it. Operations take the Context in which the graph is
// being built and return new tensors.
type Tensor interface {
	Dim(n int) int
	Stride(n int) int

	Shape() []int
	DType() DType

	Bytes() []byte
	Floats() []float32

	Add(ctx Context, t2 Tensor) Tensor
	Mul(ctx Context, t2 Tensor) Tensor
	Mulmat(ctx Context, t2 Tensor) Tensor
	MulmatFullPrec(ctx Context, t2 Tensor) Tensor

	Softmax(ctx Context) Tensor
	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
	RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
	Scale(ctx Context, s float64) Tensor

	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor

	Tanh(ctx Context) Tensor
	GELU(ctx Context) Tensor
	SILU(ctx Context) Tensor

	Reshape(ctx Context, shape ...int) Tensor
	View(ctx Context, offset int, shape ...int) Tensor
	Permute(ctx Context, shape ...int) Tensor
	Contiguous(ctx Context) Tensor

	Pad(ctx Context, shape ...int) Tensor
	Unpad(ctx Context, shape ...int) Tensor

	Stack(ctx Context, dim int, s ...Tensor) Tensor
	Concat(ctx Context, t2 Tensor, dim int) Tensor
	Rows(ctx Context, t2 Tensor) Tensor
	Copy(ctx Context, t2 Tensor) Tensor
}

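// Tensor operations compose by threading the same Context through each call.
// A minimal sketch of a gated (SwiGLU-style) MLP block; the weight tensor
// names are illustrative, and the weight is used as the Mulmat receiver as in
// the attention example below.
//
//	gate := gateWeight.Mulmat(ctx, hiddenState).SILU(ctx)
//	up := upWeight.Mulmat(ctx, hiddenState)
//	return downWeight.Mulmat(ctx, gate.Mul(ctx, up))
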
// ScaledDotProductAttention implements a fused attention
// operation equivalent to the following code on a tensor named
// query:
//
//	query = query.Permute(ctx, 0, 2, 1, 3)
//	key = key.Permute(ctx, 0, 2, 1, 3)
//	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
//	kq := key.MulmatFullPrec(ctx, query)
//
//	kq = kq.Scale(ctx, scale)
//
//	if mask != nil {
//		kq = kq.Add(ctx, mask)
//	}
//
//	kq = kq.Softmax(ctx)
//
//	kqv := value.Mulmat(ctx, kq)
//	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
type ScaledDotProductAttention interface {
	ScaledDotProductAttention(ctx Context, key, value, mask Tensor, scale float64) Tensor
}

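// Callers can prefer the fused kernel when the query tensor's implementation
// provides it, and fall back to the explicit sequence shown above otherwise.
// A minimal sketch:
//
//	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
//		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
//	}
//	// ...otherwise build the permute/matmul/softmax graph by hand.
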
// number is the constraint for numeric element types that can be decoded from
// a tensor's raw bytes.
type number interface {
	~int | ~int8 | ~int16 | ~int32 | ~int64 |
		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
		~float32 | ~float64 |
		~complex64 | ~complex128
}

// mul returns the product of its arguments. It is used to compute the number
// of elements in a tensor from its shape.
func mul[T number](s ...T) T {
	p := T(1)
	for _, v := range s {
		p *= v
	}

	return p
}

// DumpOptions controls how Dump formats a tensor.
type DumpOptions struct {
	// Items is the number of elements to print at the beginning and end of each dimension.
	Items int

	// Precision is the number of decimal places to print. Applies to float32 and float64.
	Precision int
}

// Dump returns a human-readable string representation of t, computing it
// first if necessary. If no options are given, 3 items per dimension and 4
// decimal places are used.
func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
	if len(opts) < 1 {
		opts = append(opts, DumpOptions{
			Items:     3,
			Precision: 4,
		})
	}

	switch t.DType() {
	case DTypeF32:
		return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
			return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
		})
	case DTypeF16, DTypeQ80, DTypeQ40:
		f32 := ctx.Empty(DTypeF32, t.Shape()...)
		f32 = t.Copy(ctx, f32)
		return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
			return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
		})
	case DTypeI32:
		return dump[[]int32](ctx, t, opts[0].Items, func(i int32) string {
			return strconv.FormatInt(int64(i), 10)
		})
	default:
		return "<unsupported>"
	}
}

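// A minimal usage sketch; t is assumed to be a tensor created from the same
// Context.
//
//	fmt.Println(ml.Dump(ctx, t)) // defaults: 3 items, 4 decimal places
//	fmt.Println(ml.Dump(ctx, t, ml.DumpOptions{Items: 2, Precision: 6}))
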
// dump renders t by recursively walking its dimensions, printing the first
// and last items of each dimension and eliding the middle.
func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
	if t.Bytes() == nil {
		ctx.Forward(t).Compute(t)
	}

	s := make(S, mul(t.Shape()...))
	if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
		panic(err)
	}

	shape := t.Shape()
	slices.Reverse(shape)

	var sb strings.Builder
	var f func([]int, int)
	f = func(dims []int, stride int) {
		prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
		sb.WriteString("[")
		defer func() { sb.WriteString("]") }()
		for i := 0; i < dims[0]; i++ {
			if i >= items && i < dims[0]-items {
				sb.WriteString("..., ")
				// skip to next printable element
				skip := dims[0] - 2*items
				if len(dims) > 1 {
					stride += mul(append(dims[1:], skip)...)
					fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
				}
				i += skip - 1
			} else if len(dims) > 1 {
				f(dims[1:], stride)
				stride += mul(dims[1:]...)
				if i < dims[0]-1 {
					fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
				}
			} else {
				text := fn(s[stride+i])
				if len(text) > 0 && text[0] != '-' {
					sb.WriteString(" ")
				}

				sb.WriteString(text)
				if i < dims[0]-1 {
					sb.WriteString(", ")
				}
			}
		}
	}
	f(shape, 0)

	return sb.String()
}

// DType identifies the data type of a tensor's elements.
type DType int

const (
	DTypeOther DType = iota
	DTypeF32
	DTypeF16
	DTypeQ80
	DTypeQ40
	DTypeI32
)