backend.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. package ml
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "fmt"
  6. "os"
  7. "strconv"
  8. "strings"
  9. )
// Config provides typed access to model configuration metadata. Each accessor
// takes a key plus an optional default; presumably the default (or the zero
// value) is returned when the key is absent — confirm against implementations.
type Config interface {
	// Architecture returns the model architecture identifier.
	Architecture() string

	// String returns the string value stored under the given key.
	String(string, ...string) string
	// Uint returns the uint32 value stored under the given key.
	Uint(string, ...uint32) uint32
	// Float returns the float32 value stored under the given key.
	Float(string, ...float32) float32
	// Bool returns the bool value stored under the given key.
	Bool(string, ...bool) bool

	// Strings returns the string-slice value stored under the given key.
	Strings(string, ...[]string) []string
	// Uints returns the uint32-slice value stored under the given key.
	Uints(string, ...[]uint32) []uint32
}
// Backend is implemented by model execution backends registered via
// RegisterBackend. It exposes model metadata, named weight tensors, and
// compute contexts.
type Backend interface {
	// Config returns the model's configuration metadata.
	Config() Config
	// Get returns the tensor with the given name.
	// NOTE(review): behavior for a missing name (nil vs. panic) is not visible
	// here — confirm against implementations.
	Get(name string) Tensor
	// NewContext creates a new context for building and computing tensors.
	NewContext() Context
	// SystemInfo returns a human-readable description of the backend/system.
	SystemInfo() string
}
// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
	// CacheConfig reports the cache layout requirements of this backend.
	CacheConfig() CacheConfig
}
// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels.
type CacheConfig struct {
	// CachePadding specifies the multiple for the number of tokens of cache history
	// that will be returned from cache Get for k, v and mask. The capacity of the
	// cache itself will also be increased to a multiple of this size if needed.
	CachePadding int

	// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
	// and returns the permuted version via Get. This uses the cache copy operation
	// to avoid a Contiguous call on the permuted tensor.
	PermutedV bool
}
// BackendParams controls how the backend loads and executes models.
type BackendParams struct {
	// NumThreads sets the number of threads to use if running on the CPU.
	NumThreads int

	// MainGPU is the index of the primary GPU to use.
	MainGPU int

	// NumGPULayers is the number of layers to offload to GPUs.
	NumGPULayers int

	// TensorSplit is the fraction of the model to offload to each GPU.
	// NOTE(review): presumably one entry per GPU — confirm against backends.
	TensorSplit []float32
}
// backends maps a registered backend name to its factory function.
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
  55. func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
  56. if _, ok := backends[name]; ok {
  57. panic("backend: backend already registered")
  58. }
  59. backends[name] = f
  60. }
  61. func NewBackend(f *os.File, params BackendParams) (Backend, error) {
  62. if backend, ok := backends["ggml"]; ok {
  63. return backend(f, params)
  64. }
  65. return nil, fmt.Errorf("unsupported backend")
  66. }
// Context is a scratch area for building and computing tensors. Contexts are
// created from a Backend via NewContext and released with Close.
type Context interface {
	// Zeros allocates a zero-initialized tensor with the given type and shape.
	Zeros(dtype DType, shape ...int) Tensor
	// FromFloatSlice creates a tensor with the given shape from float32 data.
	FromFloatSlice(s []float32, shape ...int) (Tensor, error)
	// FromIntSlice creates a tensor with the given shape from int32 data.
	FromIntSlice(s []int32, shape ...int) (Tensor, error)

	// Forward adds the given tensors to the compute graph, returning the
	// context for chaining (see Dump: ctx.Forward(t).Compute(t)).
	Forward(...Tensor) Context
	// Compute runs the graph, materializing data for the given tensors.
	Compute(...Tensor)
	// MaxTensors reports the maximum number of tensors the context can hold.
	MaxTensors() int
	// Close releases resources held by the context.
	Close()
}
// Tensor is an n-dimensional array of numeric data managed by a Backend.
// Operations take a Context and return a result tensor.
// NOTE(review): Bytes may be nil until the tensor has been computed (dump
// forwards and computes before reading) — data is not necessarily eager.
type Tensor interface {
	// Dim returns the size of dimension n.
	Dim(n int) int
	// Stride returns the stride of dimension n.
	Stride(n int) int
	// Shape returns the sizes of all dimensions.
	Shape() []int
	// DType returns the element data type.
	DType() DType

	// Bytes returns the raw backing bytes, or nil if not yet computed.
	Bytes() []byte
	// Floats returns the tensor data as float32 values.
	Floats() []float32

	// Arithmetic and matrix operations.
	Add(ctx Context, t2 Tensor) Tensor
	Mul(ctx Context, t2 Tensor) Tensor
	Mulmat(ctx Context, t2 Tensor) Tensor
	MulmatFullPrec(ctx Context, t2 Tensor) Tensor

	// Normalization and scaling.
	Softmax(ctx Context) Tensor
	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
	RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
	Scale(ctx Context, s float64) Tensor

	// Conv2D applies a 2D convolution with stride (s0, s1), padding (p0, p1)
	// and dilation (d0, d1).
	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
	// RoPE applies rotary position embeddings.
	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor

	// Activation functions.
	Tanh(ctx Context) Tensor
	GELU(ctx Context) Tensor
	SILU(ctx Context) Tensor

	// Shape and layout manipulation.
	Reshape(ctx Context, shape ...int) Tensor
	View(ctx Context, offset int, shape ...int) Tensor
	Permute(ctx Context, shape ...int) Tensor
	Contiguous(ctx Context) Tensor
	Pad(ctx Context, shape ...int) Tensor
	Unpad(ctx Context, shape ...int) Tensor

	// Combining and copying tensors.
	Stack(ctx Context, dim int, s ...Tensor) Tensor
	Concat(ctx Context, t2 Tensor, dim int) Tensor
	Rows(ctx Context, t2 Tensor) Tensor
	Copy(ctx Context, t2 Tensor) Tensor
}
// ScaledDotProductAttention implements a fused attention
// operation equivalent to following code on a tensor named
// query:
//
// query = query.Permute(ctx, 0, 2, 1, 3)
// key = key.Permute(ctx, 0, 2, 1, 3)
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
// kq := key.MulmatFullPrec(ctx, query)
//
// kq = kq.Scale(ctx, scale)
//
// if mask != nil {
// kq = kq.Add(ctx, mask)
// }
//
// kq = kq.Softmax(ctx)
//
// kqv := value.Mulmat(ctx, kq)
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
type ScaledDotProductAttention interface {
	// ScaledDotProductAttention is called on the query tensor; mask may be nil.
	ScaledDotProductAttention(ctx Context, key, value, mask Tensor, scale float64) Tensor
}
  130. type number interface {
  131. ~int | ~int8 | ~int16 | ~int32 | ~int64 |
  132. ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
  133. ~float32 | ~float64 |
  134. ~complex64 | ~complex128
  135. }
  136. func mul[T number](s ...T) T {
  137. p := T(1)
  138. for _, v := range s {
  139. p *= v
  140. }
  141. return p
  142. }
// DumpOptions controls how Dump renders a tensor.
type DumpOptions struct {
	// Items is the number of elements to print at the beginning and end of each dimension.
	Items int

	// Precision is the number of decimal places to print. Applies to float32 and float64.
	Precision int
}
  149. func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
  150. if len(opts) < 1 {
  151. opts = append(opts, DumpOptions{
  152. Items: 3,
  153. Precision: 4,
  154. })
  155. }
  156. switch t.DType() {
  157. case DTypeF32:
  158. return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
  159. return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
  160. })
  161. case DTypeF16:
  162. f32 := ctx.Zeros(DTypeF32, t.Shape()...)
  163. f32 = t.Copy(ctx, f32)
  164. return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
  165. return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
  166. })
  167. case DTypeI32:
  168. return dump[[]int32](ctx, t, opts[0].Items, func(i int32) string {
  169. return strconv.FormatInt(int64(i), 10)
  170. })
  171. default:
  172. return "<unsupported>"
  173. }
  174. }
// dump renders the contents of t as a nested bracketed array. items is the
// number of elements printed at the head and tail of each dimension before
// the middle is elided with "...". fn formats one element.
//
// It panics if the tensor's bytes cannot be decoded into a slice of E.
func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
	// Materialize the tensor's data if it has not been computed yet.
	if t.Bytes() == nil {
		ctx.Forward(t).Compute(t)
	}
	// Decode the raw little-endian bytes into a flat slice holding one entry
	// per element (product of all dimension sizes).
	s := make(S, mul(t.Shape()...))
	if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
		panic(err)
	}
	shape := t.Shape()
	var sb strings.Builder
	// f recursively prints one dimension: dims is the remaining shape and
	// stride is the flat index of the first element of the current sub-array.
	var f func([]int, int)
	f = func(dims []int, stride int) {
		// Indentation grows with depth so nested rows align under their "[".
		prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
		fmt.Fprint(&sb, "[")
		defer func() { fmt.Fprint(&sb, "]") }()
		for i := 0; i < dims[0]; i++ {
			// Elide the middle of this dimension, keeping `items` elements at
			// each end.
			if i >= items && i < dims[0]-items {
				fmt.Fprint(&sb, "..., ")
				// skip to next printable element
				skip := dims[0] - 2*items
				if len(dims) > 1 {
					// Advance the flat offset past all elided sub-arrays.
					stride += mul(append(dims[1:], skip)...)
					fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
				}
				i += skip - 1
			} else if len(dims) > 1 {
				// Recurse into the next dimension, then step the offset to
				// this sub-array's sibling.
				f(dims[1:], stride)
				stride += mul(dims[1:]...)
				if i < dims[0]-1 {
					fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
				}
			} else {
				// Innermost dimension: print the formatted element itself.
				fmt.Fprint(&sb, fn(s[stride+i]))
				if i < dims[0]-1 {
					fmt.Fprint(&sb, ", ")
				}
			}
		}
	}
	f(shape, 0)
	return sb.String()
}
// DType identifies the element data type of a Tensor.
type DType int

const (
	// DTypeOther is any data type not listed below; Dump renders it as
	// "<unsupported>".
	DTypeOther DType = iota
	// DTypeF32 is 32-bit floating point.
	DTypeF32
	// DTypeF16 is 16-bit (half precision) floating point; Dump converts it to
	// F32 before reading values on the host.
	DTypeF16
	// DTypeI32 is 32-bit signed integer.
	DTypeI32
)