backend.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. package ml
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "fmt"
  6. "os"
  7. "strconv"
  8. "strings"
  9. )
  // Config exposes typed accessors for configuration values.
  //
  // Apart from Architecture, every method takes a string followed by optional
  // variadic values of its result type. NOTE(review): these look like a lookup
  // key plus fallback default(s); confirm against the implementations, which
  // are not visible in this file.
  type Config interface {
  	// Architecture returns an architecture identifier as a string.
  	Architecture() string

  	// String-valued lookups: single value and slice form.
  	String(string, ...string) string
  	Strings(string, ...[]string) []string

  	// Unsigned 32-bit integer lookups: single value and slice form.
  	Uint(string, ...uint32) uint32
  	Uints(string, ...[]uint32) []uint32

  	// 32-bit float lookups: single value and slice form.
  	Float(string, ...float32) float32
  	Floats(string, ...[]float32) []float32
  }
  19. type Backend interface {
  20. Config() Config
  21. Get(name string) Tensor
  22. NewContext() Context
  23. SystemInfo() string
  24. }
  25. var backends = make(map[string]func(*os.File) (Backend, error))
  26. func RegisterBackend(name string, f func(*os.File) (Backend, error)) {
  27. if _, ok := backends[name]; ok {
  28. panic("backend: backend already registered")
  29. }
  30. backends[name] = f
  31. }
  32. func NewBackend(f *os.File) (Backend, error) {
  33. if backend, ok := backends["ggml"]; ok {
  34. return backend(f)
  35. }
  36. return nil, fmt.Errorf("unsupported backend")
  37. }
  38. type Context interface {
  39. Zeros(dtype DType, shape ...int) Tensor
  40. FromFloatSlice(s []float32, shape ...int) (Tensor, error)
  41. FromIntSlice(s []int32, shape ...int) (Tensor, error)
  42. Forward(Tensor)
  43. Compute(...Tensor)
  44. MaxTensors() int
  45. Close()
  46. }
  47. type Tensor interface {
  48. Dim(n int) int
  49. Stride(n int) int
  50. Shape() []int
  51. DType() DType
  52. Bytes() []byte
  53. Floats() []float32
  54. Add(ctx Context, t2 Tensor) Tensor
  55. Mul(ctx Context, t2 Tensor) Tensor
  56. Mulmat(ctx Context, t2 Tensor) Tensor
  57. MulmatFullPrec(ctx Context, t2 Tensor) Tensor
  58. Softmax(ctx Context) Tensor
  59. LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
  60. RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
  61. Scale(ctx Context, s float64) Tensor
  62. Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
  63. RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
  64. Tanh(ctx Context) Tensor
  65. GELU(ctx Context) Tensor
  66. SILU(ctx Context) Tensor
  67. Reshape(ctx Context, shape ...int) Tensor
  68. View(ctx Context, offset int, shape ...int) Tensor
  69. Permute(ctx Context, shape ...int) Tensor
  70. Contiguous(ctx Context) Tensor
  71. Pad(ctx Context, shape ...int) Tensor
  72. Unpad(ctx Context, shape ...int) Tensor
  73. Stack(ctx Context, dim int, s ...Tensor) Tensor
  74. Concat(ctx Context, t2 Tensor, dim int) Tensor
  75. Rows(ctx Context, t2 Tensor) Tensor
  76. Copy(ctx Context, t2 Tensor) Tensor
  77. }
  // number is the type constraint satisfied by every type whose underlying
  // type is one of Go's sized or unsized integer, floating-point or complex
  // types. uintptr is not included.
  type number interface {
  	~int | ~uint |
  		~int8 | ~int16 | ~int32 | ~int64 |
  		~uint8 | ~uint16 | ~uint32 | ~uint64 |
  		~float32 | ~float64 | ~complex64 | ~complex128
  }
  84. func mul[T number](s ...T) T {
  85. p := T(1)
  86. for _, v := range s {
  87. p *= v
  88. }
  89. return p
  90. }
  // DumpOptions controls how Dump renders a tensor. When Dump is called
  // without options it uses Items: 3 and Precision: 4.
  type DumpOptions struct {
  	// Items is how many leading and how many trailing entries of each
  	// dimension are printed; entries in between are elided.
  	Items int

  	// Precision is the number of digits printed after the decimal point.
  	// Dump applies it to DTypeF32 and DTypeF16 tensors only.
  	Precision int
  }
  97. func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
  98. if len(opts) < 1 {
  99. opts = append(opts, DumpOptions{
  100. Items: 3,
  101. Precision: 4,
  102. })
  103. }
  104. switch t.DType() {
  105. case DTypeF32:
  106. return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
  107. return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
  108. })
  109. case DTypeF16:
  110. f32 := ctx.Zeros(DTypeF32, t.Shape()...)
  111. f32 = t.Copy(ctx, f32)
  112. return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
  113. return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
  114. })
  115. case DTypeI32:
  116. return dump[[]int32](ctx, t, opts[0].Items, func(i int32) string {
  117. return strconv.FormatInt(int64(i), 10)
  118. })
  119. default:
  120. return "<unsupported>"
  121. }
  122. }
  123. func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
  124. if t.Bytes() == nil {
  125. ctx.Forward(t)
  126. ctx.Compute(t)
  127. }
  128. s := make(S, mul(t.Shape()...))
  129. if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
  130. panic(err)
  131. }
  132. shape := t.Shape()
  133. var sb strings.Builder
  134. var f func([]int, int)
  135. f = func(dims []int, stride int) {
  136. prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
  137. fmt.Fprint(&sb, "[")
  138. defer func() { fmt.Fprint(&sb, "]") }()
  139. for i := 0; i < dims[0]; i++ {
  140. if i >= items && i < dims[0]-items {
  141. fmt.Fprint(&sb, "..., ")
  142. // skip to next printable element
  143. skip := dims[0] - 2*items
  144. if len(dims) > 1 {
  145. stride += mul(append(dims[1:], skip)...)
  146. fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
  147. }
  148. i += skip - 1
  149. } else if len(dims) > 1 {
  150. f(dims[1:], stride)
  151. stride += mul(dims[1:]...)
  152. if i < dims[0]-1 {
  153. fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
  154. }
  155. } else {
  156. fmt.Fprint(&sb, fn(s[stride+i]))
  157. if i < dims[0]-1 {
  158. fmt.Fprint(&sb, ", ")
  159. }
  160. }
  161. }
  162. }
  163. f(shape, 0)
  164. return sb.String()
  165. }
  // DType identifies the element type of a Tensor.
  type DType int

  // Element types known to this package. Dump handles DTypeF32, DTypeF16 and
  // DTypeI32 and reports every other value as "<unsupported>".
  const (
  	// DTypeOther is the zero value of DType.
  	DTypeOther DType = iota
  	// DTypeF32 data is decoded by Dump as float32 values.
  	DTypeF32
  	// DTypeF16 data is copied into a DTypeF32 tensor by Dump before printing.
  	DTypeF16
  	// DTypeI32 data is decoded by Dump as int32 values.
  	DTypeI32
  )