backend.go

package ml

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"os"
	"strconv"
	"strings"
)
// Config provides access to model configuration metadata. Each getter takes
// a key and an optional default that is returned when the key is not present.
type Config interface {
	Architecture() string

	String(string, ...string) string
	Uint(string, ...uint32) uint32
	Float(string, ...float32) float32
	Bool(string, ...bool) bool

	Strings(string, ...[]string) []string
	Uints(string, ...[]uint32) []uint32
}
// Backend is implemented by model execution backends. It exposes the model's
// configuration, its weight tensors by name, and new compute contexts.
type Backend interface {
	Config() Config
	Get(name string) Tensor
	NewContext() Context
}
// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
	CacheConfig() CacheConfig
}
// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels.
type CacheConfig struct {
	// CachePadding specifies the multiple for the number of tokens of cache history
	// that will be returned from cache Get for k, v and mask. The capacity of the
	// cache itself will also be increased to a multiple of this size if needed.
	CachePadding int

	// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
	// and returns the permuted version via Get. This uses the cache copy operation
	// to avoid a Contiguous call on the permuted tensor.
	PermutedV bool

	// MaskDType specifies the data type for generating the mask. If unset it will
	// default to DTypeF32.
	MaskDType DType

	// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
	// Any position that does not correspond to an actual token will be filled with -Inf.
	MaskBatchPadding int
}
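
// A hedged sketch of a backend opting in via BackendCacheConfig; the values
// are illustrative assumptions, not requirements:
//
//	func (b *myBackend) CacheConfig() CacheConfig {
//		return CacheConfig{
//			CachePadding:     32,       // pad cache history to a multiple of 32 tokens
//			PermutedV:        true,     // store v pre-permuted for the fused kernel
//			MaskDType:        DTypeF16, // generate the mask in half precision
//			MaskBatchPadding: 8,        // pad the mask batch dimension; extra rows are -Inf
//		}
//	}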
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
	// NumThreads sets the number of threads to use if running on the CPU
	NumThreads int

	// MainGPU is the index of the primary GPU to use
	MainGPU int

	// NumGPULayers is the number of layers to offload to GPUs
	NumGPULayers int

	// TensorSplit is the fraction of the model to offload to each GPU
	TensorSplit []float32

	// FlashAttention indicates that we should use a fused flash attention kernel
	FlashAttention bool
}
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))

// RegisterBackend registers a backend constructor under the given name.
// It panics if a backend with the same name is already registered.
func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
	if _, ok := backends[name]; ok {
		panic("backend: backend already registered")
	}

	backends[name] = f
}
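
// A minimal registration sketch; "mybackend" and newMyBackend are hypothetical
// names used only for illustration:
//
//	func init() {
//		RegisterBackend("mybackend", func(f *os.File, params BackendParams) (Backend, error) {
//			return newMyBackend(f, params)
//		})
//	}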
// NewBackend constructs a Backend for the model file f. Currently only a
// backend registered as "ggml" is consulted.
func NewBackend(f *os.File, params BackendParams) (Backend, error) {
	if backend, ok := backends["ggml"]; ok {
		return backend(f, params)
	}

	return nil, fmt.Errorf("unsupported backend")
}
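
// A usage sketch, assuming a model file on disk; the path and parameter
// values are illustrative:
//
//	f, err := os.Open("model.gguf")
//	if err != nil {
//		return err
//	}
//	defer f.Close()
//
//	b, err := NewBackend(f, BackendParams{NumThreads: 8, NumGPULayers: 32})
//	if err != nil {
//		return err
//	}
//	fmt.Println(b.Config().Architecture())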
// Context allocates tensors and builds and executes compute graphs over them.
type Context interface {
	Empty(dtype DType, shape ...int) Tensor
	Zeros(dtype DType, shape ...int) Tensor
	FromFloatSlice(s []float32, shape ...int) (Tensor, error)
	FromIntSlice(s []int32, shape ...int) (Tensor, error)

	Forward(...Tensor) Context
	Compute(...Tensor)
	MaxTensors() int
	Close()
}
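
// A minimal sketch of the intended flow, assuming a Backend b: allocate
// inputs, compose operations, schedule the result with Forward, then run
// the graph with Compute:
//
//	ctx := b.NewContext()
//	defer ctx.Close()
//
//	a, err := ctx.FromFloatSlice([]float32{1, 2, 3, 4}, 2, 2)
//	if err != nil {
//		return err
//	}
//	w, err := ctx.FromFloatSlice([]float32{1, 0, 0, 1}, 2, 2)
//	if err != nil {
//		return err
//	}
//
//	out := w.Mulmat(ctx, a)
//	ctx.Forward(out).Compute(out)
//	fmt.Println(out.Floats())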
// Tensor is a multi-dimensional array of values and the operations over it.
type Tensor interface {
	Dim(n int) int
	Stride(n int) int

	Shape() []int
	DType() DType
	Bytes() []byte
	Floats() []float32

	Add(ctx Context, t2 Tensor) Tensor
	Mul(ctx Context, t2 Tensor) Tensor
	Mulmat(ctx Context, t2 Tensor) Tensor
	MulmatFullPrec(ctx Context, t2 Tensor) Tensor

	Softmax(ctx Context) Tensor
	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
	RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
	Scale(ctx Context, s float64) Tensor

	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor

	Tanh(ctx Context) Tensor
	GELU(ctx Context) Tensor
	SILU(ctx Context) Tensor

	Reshape(ctx Context, shape ...int) Tensor
	View(ctx Context, offset int, shape ...int) Tensor
	Permute(ctx Context, shape ...int) Tensor
	Contiguous(ctx Context) Tensor
	Pad(ctx Context, shape ...int) Tensor
	Unpad(ctx Context, shape ...int) Tensor

	Stack(ctx Context, dim int, s ...Tensor) Tensor
	Concat(ctx Context, t2 Tensor, dim int) Tensor
	Rows(ctx Context, t2 Tensor) Tensor
	Copy(ctx Context, t2 Tensor) Tensor
}
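
// A hedged sketch of composing Tensor operations into a SwiGLU-style
// feed-forward block; the gate, up, and down weight tensors are assumptions
// supplied by the caller (e.g. via Backend.Get), and the weight-first Mulmat
// convention is assumed:
//
//	func feedForward(ctx Context, hidden, gate, up, down Tensor) Tensor {
//		gated := gate.Mulmat(ctx, hidden).SILU(ctx)
//		return down.Mulmat(ctx, gated.Mul(ctx, up.Mulmat(ctx, hidden)))
//	}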
// ScaledDotProductAttention implements a fused attention
// operation equivalent to the following code on a tensor named
// query:
//
//	query = query.Permute(ctx, 0, 2, 1, 3)
//	key = key.Permute(ctx, 0, 2, 1, 3)
//	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
//	kq := key.MulmatFullPrec(ctx, query)
//
//	kq = kq.Scale(ctx, scale)
//
//	if mask != nil {
//		kq = kq.Add(ctx, mask)
//	}
//
//	kq = kq.Softmax(ctx)
//
//	kqv := value.Mulmat(ctx, kq)
//	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
type ScaledDotProductAttention interface {
	ScaledDotProductAttention(ctx Context, key, value, mask Tensor, scale float64) Tensor
}
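
// A hedged sketch of how a caller might prefer the fused kernel when the
// query tensor implements ScaledDotProductAttention, falling back to the
// equivalent unfused operations from the doc comment above otherwise:
//
//	func attention(ctx Context, query, key, value, mask Tensor, scale float64) Tensor {
//		if sdpa, ok := query.(ScaledDotProductAttention); ok {
//			return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
//		}
//
//		query = query.Permute(ctx, 0, 2, 1, 3)
//		key = key.Permute(ctx, 0, 2, 1, 3)
//		value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
//		kq := key.MulmatFullPrec(ctx, query).Scale(ctx, scale)
//		if mask != nil {
//			kq = kq.Add(ctx, mask)
//		}
//		kq = kq.Softmax(ctx)
//
//		kqv := value.Mulmat(ctx, kq)
//		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
//	}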
// number is a constraint satisfied by the built-in numeric types.
type number interface {
	~int | ~int8 | ~int16 | ~int32 | ~int64 |
		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
		~float32 | ~float64 |
		~complex64 | ~complex128
}

// mul returns the product of s, e.g. the element count of a shape.
func mul[T number](s ...T) T {
	p := T(1)
	for _, v := range s {
		p *= v
	}

	return p
}
type DumpOptions struct {
	// Items is the number of elements to print at the beginning and end of each dimension.
	Items int

	// Precision is the number of decimal places to print. Applies to float32 and float64.
	Precision int
}
// Dump formats t as a human-readable nested array for debugging, eliding the
// middle of each dimension. If no options are given, it prints 3 items per
// edge with 4 decimal places.
func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
	if len(opts) < 1 {
		opts = append(opts, DumpOptions{
			Items:     3,
			Precision: 4,
		})
	}

	switch t.DType() {
	case DTypeF32:
		return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
			return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
		})
	case DTypeF16:
		f32 := ctx.Empty(DTypeF32, t.Shape()...)
		f32 = t.Copy(ctx, f32)
		return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
			return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
		})
	case DTypeI32:
		return dump[[]int32](ctx, t, opts[0].Items, func(i int32) string {
			return strconv.FormatInt(int64(i), 10)
		})
	default:
		return "<unsupported>"
	}
}
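
// A usage sketch; large dimensions are elided with "..." past the first and
// last Items entries:
//
//	fmt.Println(Dump(ctx, out, DumpOptions{Items: 2, Precision: 2}))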
// dump materializes t if needed, decodes its bytes into a flat slice, then
// recursively walks the shape, printing the first and last items of each
// dimension and eliding the rest.
func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
	if t.Bytes() == nil {
		ctx.Forward(t).Compute(t)
	}

	s := make(S, mul(t.Shape()...))
	if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
		panic(err)
	}

	shape := t.Shape()

	var sb strings.Builder
	var f func([]int, int)
	f = func(dims []int, stride int) {
		prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
		fmt.Fprint(&sb, "[")
		defer func() { fmt.Fprint(&sb, "]") }()
		for i := 0; i < dims[0]; i++ {
			if i >= items && i < dims[0]-items {
				fmt.Fprint(&sb, "..., ")
				// skip to next printable element
				skip := dims[0] - 2*items
				if len(dims) > 1 {
					stride += mul(append(dims[1:], skip)...)
					fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
				}
				i += skip - 1
			} else if len(dims) > 1 {
				f(dims[1:], stride)
				stride += mul(dims[1:]...)
				if i < dims[0]-1 {
					fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
				}
			} else {
				fmt.Fprint(&sb, fn(s[stride+i]))
				if i < dims[0]-1 {
					fmt.Fprint(&sb, ", ")
				}
			}
		}
	}
	f(shape, 0)

	return sb.String()
}
// DType identifies the element type of a tensor.
type DType int

const (
	DTypeOther DType = iota
	DTypeF32
	DTypeF16
	DTypeI32
)