backend.go

package ml

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"os"
	"slices"
	"strconv"
	"strings"
)

type Config interface {
	Architecture() string

	String(string, ...string) string
	Uint(string, ...uint32) uint32
	Float(string, ...float32) float32
	Bool(string, ...bool) bool

	Strings(string, ...[]string) []string
	Uints(string, ...[]uint32) []uint32
}

type Backend interface {
	Config() Config
	Get(name string) Tensor
	NewContext() Context
	NewContextSize(size int) Context
}
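
// A sketch of how a model implementation might read its hyperparameters and
// fetch weights once a Backend has been constructed. The key and tensor names
// here are illustrative, not a fixed schema; the trailing variadic on the
// Config getters supplies a default used when the key is missing:
//
//	c := b.Config()
//	heads := c.Uint("attention.head_count")
//	eps := c.Float("attention.layer_norm_rms_epsilon", 1e-6)
//	tokEmbd := b.Get("token_embd.weight")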

// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
	CacheConfig() CacheConfig
}

// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels.
type CacheConfig struct {
	// CachePadding specifies the multiple for the number of tokens of cache history
	// that will be returned from cache Get for k, v and mask. The capacity of the
	// cache itself will also be increased to a multiple of this size if needed.
	CachePadding int

	// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
	// and returns the permuted version via Get. This uses the cache copy operation
	// to avoid a Contiguous call on the permuted tensor.
	PermutedV bool

	// MaskDType specifies the data type for generating the mask. If unset it will
	// default to DTypeF32.
	MaskDType DType

	// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
	// Any position that does not correspond to an actual token will be filled with -Inf.
	MaskBatchPadding int
}
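
// As an illustration (a hypothetical backend, not one defined in this
// package), a backend whose attention kernel wants padded, half-precision
// masks and permuted v tensors might report:
//
//	func (b *exampleBackend) CacheConfig() CacheConfig {
//		return CacheConfig{
//			CachePadding:     256,
//			PermutedV:        true,
//			MaskDType:        DTypeF16,
//			MaskBatchPadding: 32,
//		}
//	}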

// BackendParams controls how the backend loads and executes models
type BackendParams struct {
	// NumThreads sets the number of threads to use if running on the CPU
	NumThreads int

	// MainGPU is the index of the primary GPU to use
	MainGPU int

	// NumGPULayers is the number of layers to offload to GPUs
	NumGPULayers int

	// TensorSplit is the fraction of the model to offload to each GPU
	TensorSplit []float32

	// FlashAttention indicates that we should use a fused flash attention kernel
	FlashAttention bool
}

var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))

// RegisterBackend registers a backend constructor under name. It panics if a
// backend with the same name has already been registered.
func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
	if _, ok := backends[name]; ok {
		panic("backend: backend already registered")
	}

	backends[name] = f
}

// NewBackend opens a backend for the model file f. Currently only the
// registered "ggml" backend is supported.
func NewBackend(f *os.File, params BackendParams) (Backend, error) {
	if backend, ok := backends["ggml"]; ok {
		return backend(f, params)
	}

	return nil, fmt.Errorf("unsupported backend")
}
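
// A minimal sketch of registration and loading. newGGMLBackend and the model
// file name are hypothetical stand-ins, not identifiers defined in this
// package:
//
//	func init() {
//		RegisterBackend("ggml", func(f *os.File, params BackendParams) (Backend, error) {
//			return newGGMLBackend(f, params)
//		})
//	}
//
//	f, err := os.Open("model.bin")
//	if err != nil {
//		return err
//	}
//	defer f.Close()
//
//	b, err := NewBackend(f, BackendParams{NumThreads: 8, NumGPULayers: 32})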

type Context interface {
	Empty(dtype DType, shape ...int) Tensor
	Zeros(dtype DType, shape ...int) Tensor
	FromFloatSlice(s []float32, shape ...int) (Tensor, error)
	FromIntSlice(s []int32, shape ...int) (Tensor, error)

	Forward(...Tensor) Context
	Compute(...Tensor)
	MaxGraphNodes() int
	Close()

	// Input returns a context appropriate for creating input tensors
	Input() Context

	// Output returns a context appropriate for creating output tensors
	Output() Context

	// Layer returns a context appropriate for creating intermediate tensors
	Layer(int) Context
}
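
// A sketch of the intended lifecycle: create tensors on a context, queue the
// graph with Forward, evaluate it with Compute, and release the context when
// done. The values here are illustrative only:
//
//	ctx := b.NewContext()
//	defer ctx.Close()
//
//	x, err := ctx.Input().FromFloatSlice([]float32{1, 2, 3, 4}, 2, 2)
//	if err != nil {
//		return err
//	}
//
//	y := x.Add(ctx, x)
//	ctx.Forward(y).Compute(y)
//	fmt.Println(y.Floats())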

type Tensor interface {
	Dim(n int) int
	Stride(n int) int

	Shape() []int
	DType() DType

	Bytes() []byte
	Floats() []float32

	Add(ctx Context, t2 Tensor) Tensor
	Mul(ctx Context, t2 Tensor) Tensor
	Mulmat(ctx Context, t2 Tensor) Tensor
	MulmatFullPrec(ctx Context, t2 Tensor) Tensor

	Softmax(ctx Context) Tensor
	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
	RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
	Scale(ctx Context, s float64) Tensor

	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor

	Tanh(ctx Context) Tensor
	GELU(ctx Context) Tensor
	SILU(ctx Context) Tensor

	Reshape(ctx Context, shape ...int) Tensor
	View(ctx Context, offset int, shape ...int) Tensor
	Permute(ctx Context, shape ...int) Tensor
	Contiguous(ctx Context) Tensor
	Pad(ctx Context, shape ...int) Tensor
	Unpad(ctx Context, shape ...int) Tensor

	Stack(ctx Context, dim int, s ...Tensor) Tensor
	Concat(ctx Context, t2 Tensor, dim int) Tensor
	Rows(ctx Context, t2 Tensor) Tensor
	Copy(ctx Context, t2 Tensor) Tensor
}
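
// Every operation takes the Context it should be recorded on and returns a
// new Tensor, so layers compose by chaining. As a sketch, a SwiGLU-style
// feed-forward block under this convention (gate, up and down are assumed to
// be weight tensors fetched elsewhere, e.g. via Backend.Get):
//
//	func feedForward(ctx Context, x, gate, up, down Tensor) Tensor {
//		hidden := gate.Mulmat(ctx, x).SILU(ctx)
//		hidden = hidden.Mul(ctx, up.Mulmat(ctx, x))
//		return down.Mulmat(ctx, hidden)
//	}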

// ScaledDotProductAttention implements a fused attention
// operation equivalent to the following code on a tensor named
// query:
//
//	query = query.Permute(ctx, 0, 2, 1, 3)
//	key = key.Permute(ctx, 0, 2, 1, 3)
//	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
//	kq := key.MulmatFullPrec(ctx, query)
//
//	kq = kq.Scale(ctx, scale)
//
//	if mask != nil {
//		kq = kq.Add(ctx, mask)
//	}
//
//	kq = kq.Softmax(ctx)
//
//	kqv := value.Mulmat(ctx, kq)
//	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
type ScaledDotProductAttention interface {
	ScaledDotProductAttention(ctx Context, key, value, mask Tensor, scale float64) Tensor
}
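
// Since the fused kernel is optional, a caller would typically type-assert
// the query tensor and fall back to the unfused sequence above when the
// backend does not implement it. A sketch:
//
//	if sdpa, ok := query.(ScaledDotProductAttention); ok {
//		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
//	}
//	// ...otherwise run the permute/matmul/softmax sequence shown above.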

type number interface {
	~int | ~int8 | ~int16 | ~int32 | ~int64 |
		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
		~float32 | ~float64 |
		~complex64 | ~complex128
}

// mul returns the product of its arguments; with no arguments it returns 1.
func mul[T number](s ...T) T {
	p := T(1)
	for _, v := range s {
		p *= v
	}

	return p
}

type DumpOptions struct {
	// Items is the number of elements to print at the beginning and end of each dimension.
	Items int

	// Precision is the number of decimal places to print. Applies to float32 and float64.
	Precision int
}

// Dump returns a human-readable string representation of t, eliding interior
// elements of each dimension past Items and formatting floats with Precision
// decimal places. Half-precision and quantized tensors are copied into a
// float32 tensor before printing.
func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
	if len(opts) < 1 {
		opts = append(opts, DumpOptions{
			Items:     3,
			Precision: 4,
		})
	}

	switch t.DType() {
	case DTypeF32:
		return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
			return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
		})
	case DTypeF16, DTypeQ80, DTypeQ40:
		f32 := ctx.Empty(DTypeF32, t.Shape()...)
		f32 = t.Copy(ctx, f32)
		return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
			return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
		})
	case DTypeI32:
		return dump[[]int32](ctx, t, opts[0].Items, func(i int32) string {
			return strconv.FormatInt(int64(i), 10)
		})
	default:
		return "<unsupported>"
	}
}
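
// A usage sketch, assuming ctx and t from the examples above:
//
//	fmt.Println(Dump(ctx, t, DumpOptions{Items: 2, Precision: 2}))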

// dump reads the tensor's data, computing it first if it has not been
// evaluated yet, and renders it one dimension at a time, using fn to format
// each element.
func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
	if t.Bytes() == nil {
		ctx.Forward(t).Compute(t)
	}

	s := make(S, mul(t.Shape()...))
	if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
		panic(err)
	}

	shape := t.Shape()
	slices.Reverse(shape)

	var sb strings.Builder
	var f func([]int, int)
	f = func(dims []int, stride int) {
		prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
		sb.WriteString("[")
		defer func() { sb.WriteString("]") }()
		for i := 0; i < dims[0]; i++ {
			if i >= items && i < dims[0]-items {
				sb.WriteString("..., ")
				// skip to next printable element
				skip := dims[0] - 2*items
				if len(dims) > 1 {
					stride += mul(append(dims[1:], skip)...)
					fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
				}
				i += skip - 1
			} else if len(dims) > 1 {
				f(dims[1:], stride)
				stride += mul(dims[1:]...)
				if i < dims[0]-1 {
					fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
				}
			} else {
				text := fn(s[stride+i])
				if len(text) > 0 && text[0] != '-' {
					sb.WriteString(" ")
				}

				sb.WriteString(text)
				if i < dims[0]-1 {
					sb.WriteString(", ")
				}
			}
		}
	}
	f(shape, 0)

	return sb.String()
}

// DType is the data type of a Tensor.
type DType int

const (
	DTypeOther DType = iota
	DTypeF32
	DTypeF16
	DTypeQ80
	DTypeQ40
	DTypeI32
)