backend.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. package ml
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "fmt"
  6. "os"
  7. "strconv"
  8. "strings"
  9. )
  10. type Config interface {
  11. Architecture() string
  12. String(string, ...string) string
  13. Uint(string, ...uint32) uint32
  14. Float(string, ...float32) float32
  15. Strings(string, ...[]string) []string
  16. Uints(string, ...[]uint32) []uint32
  17. }
  18. type Backend interface {
  19. Config() Config
  20. Get(name string) Tensor
  21. NewContext() Context
  22. }
  23. var backends = make(map[string]func(*os.File) (Backend, error))
  24. func RegisterBackend(name string, f func(*os.File) (Backend, error)) {
  25. if _, ok := backends[name]; ok {
  26. panic("backend: backend already registered")
  27. }
  28. backends[name] = f
  29. }
  30. func NewBackend(f *os.File) (Backend, error) {
  31. if backend, ok := backends["ggml"]; ok {
  32. return backend(f)
  33. }
  34. return nil, fmt.Errorf("unsupported backend")
  35. }
  36. type Context interface {
  37. Zeros(dtype DType, shape ...int) Tensor
  38. FromFloatSlice(s []float32, shape ...int) (Tensor, error)
  39. FromIntSlice(s []int32, shape ...int) (Tensor, error)
  40. Forward(Tensor)
  41. Compute(Tensor) Tensor
  42. Close() error
  43. }
  44. type Tensor interface {
  45. Dim(n int) int64
  46. Stride(n int) int64
  47. Shape() []int64
  48. DType() DType
  49. Bytes() []byte
  50. Floats() []float32
  51. Add(ctx Context, t2 Tensor) Tensor
  52. Mul(ctx Context, t2 Tensor) Tensor
  53. Mulmat(ctx Context, t2 Tensor) Tensor
  54. Softmax(ctx Context) Tensor
  55. LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
  56. RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
  57. Scale(ctx Context, s float64) Tensor
  58. Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
  59. RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor
  60. Tanh(ctx Context) Tensor
  61. GELU(ctx Context) Tensor
  62. SILU(ctx Context) Tensor
  63. Reshape(ctx Context, shape ...int64) Tensor
  64. View(ctx Context, offset int, shape ...int) Tensor
  65. Permute(ctx Context, shape ...int) Tensor
  66. Contiguous(ctx Context) Tensor
  67. Pad(ctx Context, shape ...int64) Tensor
  68. Unpad(ctx Context, shape ...int64) Tensor
  69. Stack(ctx Context, dim int, s ...Tensor) Tensor
  70. Concat(ctx Context, t2 Tensor, dim int) Tensor
  71. Rows(ctx Context, t2 Tensor) Tensor
  72. Copy(ctx Context, t2 Tensor) Tensor
  73. }
  74. type number interface {
  75. ~int | ~int8 | ~int16 | ~int32 | ~int64 |
  76. ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
  77. ~float32 | ~float64 |
  78. ~complex64 | ~complex128
  79. }
  80. func mul[T number](s ...T) T {
  81. p := T(1)
  82. for _, v := range s {
  83. p *= v
  84. }
  85. return p
  86. }
  87. type DumpOptions struct {
  88. // Items is the number of elements to print at the beginning and end of each dimension.
  89. Items int64
  90. // Precision is the number of decimal places to print. Applies to float32 and float64.
  91. Precision int
  92. }
  93. func Dump(t Tensor, opts ...DumpOptions) string {
  94. if len(opts) < 1 {
  95. opts = append(opts, DumpOptions{
  96. Items: 3,
  97. Precision: 4,
  98. })
  99. }
  100. switch t.DType() {
  101. case DTypeF32:
  102. return dump[[]float32](t, opts[0].Items, func(f float32) string {
  103. return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
  104. })
  105. case DTypeI32:
  106. return dump[[]int32](t, opts[0].Items, func(i int32) string {
  107. return strconv.FormatInt(int64(i), 10)
  108. })
  109. default:
  110. return "<unsupported>"
  111. }
  112. }
  113. func dump[S ~[]E, E number](t Tensor, items int64, fn func(E) string) string {
  114. bts := t.Bytes()
  115. if bts == nil {
  116. return "<nil>"
  117. }
  118. s := make(S, mul(t.Shape()...))
  119. if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
  120. panic(err)
  121. }
  122. shape := t.Shape()
  123. var sb strings.Builder
  124. var f func([]int64, int64)
  125. f = func(dims []int64, stride int64) {
  126. prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
  127. fmt.Fprint(&sb, "[")
  128. defer func() { fmt.Fprint(&sb, "]") }()
  129. for i := int64(0); i < dims[0]; i++ {
  130. if i >= items && i < dims[0]-items {
  131. fmt.Fprint(&sb, "..., ")
  132. // skip to next printable element
  133. skip := dims[0] - 2*items
  134. if len(dims) > 1 {
  135. stride += mul(append(dims[1:], skip)...)
  136. fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
  137. }
  138. i += skip - 1
  139. } else if len(dims) > 1 {
  140. f(dims[1:], stride)
  141. stride += mul(dims[1:]...)
  142. if i < dims[0]-1 {
  143. fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
  144. }
  145. } else {
  146. fmt.Fprint(&sb, fn(s[stride+i]))
  147. if i < dims[0]-1 {
  148. fmt.Fprint(&sb, ", ")
  149. }
  150. }
  151. }
  152. }
  153. f(shape, 0)
  154. return sb.String()
  155. }
  156. type DType int
  157. const (
  158. DTypeF32 DType = iota
  159. DTypeI32
  160. DTypeOther
  161. )