backend.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. package ml
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "fmt"
  6. "os"
  7. "strconv"
  8. "strings"
  9. )
  // Config provides typed, key-based lookups of a backend's configuration.
// NOTE(review): the keys are presumably model metadata entries; confirm
// against the concrete implementation.
type Config interface {
	// Architecture returns the model architecture name.
	Architecture() string

	// Scalar lookups. The variadic argument presumably supplies the value to
	// return when key is absent — NOTE(review): confirm in the implementation.
	String(key string, defaultValue ...string) string
	Uint(key string, defaultValue ...uint32) uint32
	Float(key string, defaultValue ...float32) float32

	// Slice lookups, presumably with the same default-value convention.
	Strings(key string, defaultValue ...[]string) []string
	Uints(key string, defaultValue ...[]uint32) []uint32
}
  18. type Backend interface {
  19. Config() Config
  20. Get(name string) Tensor
  21. NewContext() Context
  22. SystemInfo() string
  23. }
  // BackendParams controls how the backend loads and executes models.
type BackendParams struct {
	// NumThreads is the number of threads to use when running on the CPU.
	NumThreads int

	// MainGPU is the index of the primary GPU.
	MainGPU int

	// NumGPULayers is how many layers to offload to GPUs.
	NumGPULayers int

	// TensorSplit holds, for each GPU, the fraction of the model to offload
	// to that GPU.
	TensorSplit []float32
}
  35. var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
  36. func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
  37. if _, ok := backends[name]; ok {
  38. panic("backend: backend already registered")
  39. }
  40. backends[name] = f
  41. }
  42. func NewBackend(f *os.File, params BackendParams) (Backend, error) {
  43. if backend, ok := backends["ggml"]; ok {
  44. return backend(f, params)
  45. }
  46. return nil, fmt.Errorf("unsupported backend")
  47. }
  48. type Context interface {
  49. Zeros(dtype DType, shape ...int) Tensor
  50. FromFloatSlice(s []float32, shape ...int) (Tensor, error)
  51. FromIntSlice(s []int32, shape ...int) (Tensor, error)
  52. Forward(Tensor)
  53. Compute(...Tensor)
  54. MaxTensors() int
  55. Close()
  56. }
  57. type Tensor interface {
  58. Dim(n int) int
  59. Stride(n int) int
  60. Shape() []int
  61. DType() DType
  62. Bytes() []byte
  63. Floats() []float32
  64. Add(ctx Context, t2 Tensor) Tensor
  65. Mul(ctx Context, t2 Tensor) Tensor
  66. Mulmat(ctx Context, t2 Tensor) Tensor
  67. MulmatFullPrec(ctx Context, t2 Tensor) Tensor
  68. Softmax(ctx Context) Tensor
  69. LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
  70. RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
  71. Scale(ctx Context, s float64) Tensor
  72. Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
  73. RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor
  74. Tanh(ctx Context) Tensor
  75. GELU(ctx Context) Tensor
  76. SILU(ctx Context) Tensor
  77. Reshape(ctx Context, shape ...int) Tensor
  78. View(ctx Context, offset int, shape ...int) Tensor
  79. Permute(ctx Context, shape ...int) Tensor
  80. Contiguous(ctx Context) Tensor
  81. Pad(ctx Context, shape ...int) Tensor
  82. Unpad(ctx Context, shape ...int) Tensor
  83. Stack(ctx Context, dim int, s ...Tensor) Tensor
  84. Concat(ctx Context, t2 Tensor, dim int) Tensor
  85. Rows(ctx Context, t2 Tensor) Tensor
  86. Copy(ctx Context, t2 Tensor) Tensor
  87. }
  88. // ScaledDotProductAttention implements a fused attention
  89. // operation equivalent to following code on a tensor named
  90. // query:
  91. //
  92. // kq := key.MulmatFullPrec(ctx, query)
  93. //
  94. // kq = kq.Scale(ctx, scale)
  95. //
  96. // if mask != nil {
  97. // kq = kq.Add(ctx, mask)
  98. // }
  99. //
  100. // kq = kq.Softmax(ctx)
  101. //
  102. // kqv := value.Mulmat(ctx, kq)
  103. // return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
  104. type ScaledDotProductAttention interface {
  105. ScaledDotProductAttention(ctx Context, key, value, mask Tensor, scale float64) Tensor
  106. }
  // number is the set of built-in numeric types, and types derived from them,
// that support the arithmetic operators used by helpers such as mul.
type number interface {
	~int8 | ~int16 | ~int32 | ~int64 | ~int |
		~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uint |
		~float32 | ~float64 |
		~complex64 | ~complex128
}
  113. func mul[T number](s ...T) T {
  114. p := T(1)
  115. for _, v := range s {
  116. p *= v
  117. }
  118. return p
  119. }
  // DumpOptions configures Dump.
type DumpOptions struct {
	// Items is how many elements to print at the start and at the end of each
	// dimension. Elements in between are elided.
	Items int

	// Precision is the number of digits printed after the decimal point for
	// floating-point tensors.
	Precision int
}
  126. func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
  127. if len(opts) < 1 {
  128. opts = append(opts, DumpOptions{
  129. Items: 3,
  130. Precision: 4,
  131. })
  132. }
  133. switch t.DType() {
  134. case DTypeF32:
  135. return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
  136. return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
  137. })
  138. case DTypeF16:
  139. f32 := ctx.Zeros(DTypeF32, t.Shape()...)
  140. f32 = t.Copy(ctx, f32)
  141. return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
  142. return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
  143. })
  144. case DTypeI32:
  145. return dump[[]int32](ctx, t, opts[0].Items, func(i int32) string {
  146. return strconv.FormatInt(int64(i), 10)
  147. })
  148. default:
  149. return "<unsupported>"
  150. }
  151. }
  152. func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
  153. if t.Bytes() == nil {
  154. ctx.Forward(t)
  155. ctx.Compute(t)
  156. }
  157. s := make(S, mul(t.Shape()...))
  158. if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
  159. panic(err)
  160. }
  161. shape := t.Shape()
  162. var sb strings.Builder
  163. var f func([]int, int)
  164. f = func(dims []int, stride int) {
  165. prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
  166. fmt.Fprint(&sb, "[")
  167. defer func() { fmt.Fprint(&sb, "]") }()
  168. for i := 0; i < dims[0]; i++ {
  169. if i >= items && i < dims[0]-items {
  170. fmt.Fprint(&sb, "..., ")
  171. // skip to next printable element
  172. skip := dims[0] - 2*items
  173. if len(dims) > 1 {
  174. stride += mul(append(dims[1:], skip)...)
  175. fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
  176. }
  177. i += skip - 1
  178. } else if len(dims) > 1 {
  179. f(dims[1:], stride)
  180. stride += mul(dims[1:]...)
  181. if i < dims[0]-1 {
  182. fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
  183. }
  184. } else {
  185. fmt.Fprint(&sb, fn(s[stride+i]))
  186. if i < dims[0]-1 {
  187. fmt.Fprint(&sb, ", ")
  188. }
  189. }
  190. }
  191. }
  192. f(shape, 0)
  193. return sb.String()
  194. }
  // DType identifies the element type of a Tensor.
type DType int

// Supported element types. DTypeOther is the zero value, used for any type not
// listed here.
const (
	DTypeOther DType = iota // unrecognized or unsupported element type
	DTypeF32                // 32-bit float
	DTypeF16                // 16-bit float
	DTypeI32                // 32-bit signed integer
)