backend.go

package ml

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"os"
	"strconv"
	"strings"
)

// Config provides access to model metadata. Each accessor takes a key and an
// optional default value that is returned when the key is absent.
type Config interface {
	Architecture() string

	String(string, ...string) string
	Uint(string, ...uint32) uint32
	Float(string, ...float32) float32
	Bool(string, ...bool) bool

	Strings(string, ...[]string) []string
	Uints(string, ...[]uint32) []uint32
}
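
// For example, a model implementation might read its hyperparameters like
// this (an illustrative sketch; c is a Config and the metadata keys shown
// are hypothetical, depending on the model file):
//
//	heads := c.Uint("attention.head_count")
//	eps := c.Float("attention.layer_norm_rms_epsilon", 1e-5)
//	tied := c.Bool("tie_word_embeddings", false)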

// Backend is a computation backend that owns a loaded model's weights and
// creates Contexts for building and running graphs.
type Backend interface {
	Config() Config
	Get(name string) Tensor
	NewContext() Context
	NewContextSize(size int) Context
}

// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
	CacheConfig() CacheConfig
}

// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels.
type CacheConfig struct {
	// CachePadding specifies the multiple for the number of tokens of cache history
	// that will be returned from cache Get for k, v and mask. The capacity of the
	// cache itself will also be increased to a multiple of this size if needed.
	CachePadding int

	// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
	// and returns the permuted version via Get. This uses the cache copy operation
	// to avoid a Contiguous call on the permuted tensor.
	PermutedV bool

	// MaskDType specifies the data type for generating the mask. If unset it will
	// default to DTypeF32.
	MaskDType DType

	// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
	// Any position that does not correspond to an actual token will be filled with -Inf.
	MaskBatchPadding int
}
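
// A backend whose attention kernel wants padded, permuted cache output might
// report, for example (an illustrative sketch; myBackend and the concrete
// values are hypothetical):
//
//	func (b *myBackend) CacheConfig() CacheConfig {
//		return CacheConfig{
//			CachePadding:     32,
//			PermutedV:        true,
//			MaskDType:        DTypeF16,
//			MaskBatchPadding: 32,
//		}
//	}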

// BackendParams controls how the backend loads and executes models.
type BackendParams struct {
	// NumThreads sets the number of threads to use if running on the CPU
	NumThreads int

	// MainGPU is the index of the primary GPU to use
	MainGPU int

	// NumGPULayers is the number of layers to offload to GPUs
	NumGPULayers int

	// TensorSplit is the fraction of the model to offload to each GPU
	TensorSplit []float32

	// FlashAttention indicates that we should use a fused flash attention kernel
	FlashAttention bool
}
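
// For example, to run mostly on GPU with the work split evenly across two
// devices (an illustrative sketch; all values are hypothetical):
//
//	params := BackendParams{
//		NumThreads:     8,
//		MainGPU:        0,
//		NumGPULayers:   32,
//		TensorSplit:    []float32{0.5, 0.5},
//		FlashAttention: true,
//	}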

var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))

// RegisterBackend registers a backend constructor under name. It panics if a
// backend with the same name has already been registered.
func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
	if _, ok := backends[name]; ok {
		panic("backend: backend already registered: " + name)
	}

	backends[name] = f
}
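
// A backend typically registers itself from an init function in its own
// package (a sketch; newGGMLBackend is a hypothetical constructor):
//
//	func init() {
//		RegisterBackend("ggml", func(f *os.File, params BackendParams) (Backend, error) {
//			return newGGMLBackend(f, params)
//		})
//	}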

// NewBackend constructs a Backend for the model file f. Only the "ggml"
// backend is currently wired up.
func NewBackend(f *os.File, params BackendParams) (Backend, error) {
	if backend, ok := backends["ggml"]; ok {
		return backend(f, params)
	}

	return nil, fmt.Errorf("unsupported backend")
}
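
// Typical usage (a sketch; the file name is hypothetical and error handling
// is abbreviated):
//
//	f, err := os.Open("model.bin")
//	if err != nil {
//		return err
//	}
//	defer f.Close()
//
//	b, err := NewBackend(f, BackendParams{NumGPULayers: 32})
//	if err != nil {
//		return err
//	}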

// Context records tensor operations into a graph that is executed by Compute.
type Context interface {
	Empty(dtype DType, shape ...int) Tensor
	Zeros(dtype DType, shape ...int) Tensor
	FromFloatSlice(s []float32, shape ...int) (Tensor, error)
	FromIntSlice(s []int32, shape ...int) (Tensor, error)

	Forward(...Tensor) Context
	Compute(...Tensor)
	MaxGraphNodes() int
	Close()

	// Input returns a context appropriate for creating input tensors
	Input() Context

	// Output returns a context appropriate for creating output tensors
	Output() Context

	// Layer returns a context appropriate for creating intermediate tensors
	Layer(int) Context
}
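
// A forward pass typically builds the graph lazily and then executes it (a
// sketch; the forward call is a hypothetical stand-in for model-specific
// graph building):
//
//	ctx := b.NewContext()
//	defer ctx.Close()
//
//	ids, err := ctx.Input().FromIntSlice([]int32{1, 2, 3}, 3)
//	if err != nil {
//		return err
//	}
//
//	logits := forward(ctx, ids)
//	ctx.Forward(logits).Compute(logits)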

// Tensor is a multi-dimensional array of values. Operations return new
// tensors; they are recorded on the Context and executed by Compute.
type Tensor interface {
	Dim(n int) int
	Stride(n int) int

	Shape() []int
	DType() DType

	Bytes() []byte
	Floats() []float32

	Add(ctx Context, t2 Tensor) Tensor
	Mul(ctx Context, t2 Tensor) Tensor
	Mulmat(ctx Context, t2 Tensor) Tensor
	MulmatFullPrec(ctx Context, t2 Tensor) Tensor

	Softmax(ctx Context) Tensor
	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
	RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
	Scale(ctx Context, s float64) Tensor

	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor

	Tanh(ctx Context) Tensor
	GELU(ctx Context) Tensor
	SILU(ctx Context) Tensor

	Reshape(ctx Context, shape ...int) Tensor
	View(ctx Context, offset int, shape ...int) Tensor
	Permute(ctx Context, shape ...int) Tensor
	Contiguous(ctx Context) Tensor
	Pad(ctx Context, shape ...int) Tensor
	Unpad(ctx Context, shape ...int) Tensor

	Stack(ctx Context, dim int, s ...Tensor) Tensor
	Concat(ctx Context, t2 Tensor, dim int) Tensor
	Rows(ctx Context, t2 Tensor) Tensor
	Copy(ctx Context, t2 Tensor) Tensor
}
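
// Because every operation returns a Tensor, graph construction reads as a
// chain of calls. For example, a gated feed-forward block might look like
// this (a sketch; w1, w2 and x are hypothetical weight and input tensors):
//
//	h := w1.Mulmat(ctx, x).SILU(ctx)
//	y := w2.Mulmat(ctx, h)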

// ScaledDotProductAttention implements a fused attention
// operation equivalent to the following code on a tensor named
// query:
//
//	query = query.Permute(ctx, 0, 2, 1, 3)
//	key = key.Permute(ctx, 0, 2, 1, 3)
//	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
//	kq := key.MulmatFullPrec(ctx, query)
//
//	kq = kq.Scale(ctx, scale)
//
//	if mask != nil {
//		kq = kq.Add(ctx, mask)
//	}
//
//	kq = kq.Softmax(ctx)
//
//	kqv := value.Mulmat(ctx, kq)
//	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
type ScaledDotProductAttention interface {
	ScaledDotProductAttention(ctx Context, key, value, mask Tensor, scale float64) Tensor
}
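
// Callers can probe for the fused kernel at runtime and fall back to the
// unfused sequence above when the backend does not provide it (a sketch):
//
//	if sdpa, ok := query.(ScaledDotProductAttention); ok {
//		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
//	}
//	// ...otherwise build the permute/Mulmat/Scale/Softmax graph shown above.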

// number is a constraint satisfied by all of Go's built-in numeric types.
type number interface {
	~int | ~int8 | ~int16 | ~int32 | ~int64 |
		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
		~float32 | ~float64 |
		~complex64 | ~complex128
}

// mul returns the product of s, e.g. mul(2, 3, 4) == 24. An empty argument
// list yields 1.
func mul[T number](s ...T) T {
	p := T(1)
	for _, v := range s {
		p *= v
	}

	return p
}

// DumpOptions configures the output of Dump.
type DumpOptions struct {
	// Items is the number of elements to print at the beginning and end of each dimension.
	Items int

	// Precision is the number of decimal places to print. Applies to float32 and float64.
	Precision int
}

// Dump formats t as a human-readable, nested array string, eliding interior
// elements of large dimensions. If no options are given, it prints 3 items
// at each edge with 4 decimal places. Half-precision and quantized tensors
// are converted to float32 before printing.
func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
	if len(opts) < 1 {
		opts = append(opts, DumpOptions{
			Items:     3,
			Precision: 4,
		})
	}

	switch t.DType() {
	case DTypeF32:
		return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
			return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
		})
	case DTypeF16, DTypeQ80, DTypeQ40:
		f32 := ctx.Empty(DTypeF32, t.Shape()...)
		f32 = t.Copy(ctx, f32)
		return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
			return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
		})
	case DTypeI32:
		return dump[[]int32](ctx, t, opts[0].Items, func(i int32) string {
			return strconv.FormatInt(int64(i), 10)
		})
	default:
		return "<unsupported>"
	}
}
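
// For example, to print five elements at each edge of every dimension with
// two decimal places (a sketch; logits is a hypothetical tensor):
//
//	fmt.Println(Dump(ctx, logits, DumpOptions{Items: 5, Precision: 2}))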

// dump computes t if it has not been materialized yet, copies its raw bytes
// into a typed slice, and recursively formats it dimension by dimension,
// printing at most items elements at each edge of every dimension.
func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
	if t.Bytes() == nil {
		ctx.Forward(t).Compute(t)
	}

	s := make(S, mul(t.Shape()...))
	if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
		panic(err)
	}

	shape := t.Shape()

	var sb strings.Builder
	var f func([]int, int)
	f = func(dims []int, stride int) {
		prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
		fmt.Fprint(&sb, "[")
		defer func() { fmt.Fprint(&sb, "]") }()
		for i := 0; i < dims[0]; i++ {
			if i >= items && i < dims[0]-items {
				fmt.Fprint(&sb, "..., ")
				// skip to next printable element
				skip := dims[0] - 2*items
				if len(dims) > 1 {
					stride += mul(append(dims[1:], skip)...)
					fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
				}
				i += skip - 1
			} else if len(dims) > 1 {
				f(dims[1:], stride)
				stride += mul(dims[1:]...)
				if i < dims[0]-1 {
					fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
				}
			} else {
				fmt.Fprint(&sb, fn(s[stride+i]))
				if i < dims[0]-1 {
					fmt.Fprint(&sb, ", ")
				}
			}
		}
	}

	f(shape, 0)

	return sb.String()
}

// DType is the data type of a Tensor's elements.
type DType int

const (
	DTypeOther DType = iota
	DTypeF32
	DTypeF16
	DTypeQ80
	DTypeQ40
	DTypeI32
)