backend.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. package ml
  import (
  	"bytes"
  	"encoding/binary"
  	"fmt"
  	"os"
  	"strconv"
  	"strings"
  )
  // Config gives typed access to a backend's configuration values.
  //
  // Every lookup takes a string as its first argument. It is presumably the
  // configuration key. The variadic argument presumably supplies a fallback
  // used when the key is absent.
  // NOTE(review): confirm both against the backend implementations.
  type Config interface {
  	// Architecture returns a string naming the model architecture.
  	Architecture() string

  	// String-valued lookups: single value and list.
  	String(string, ...string) string
  	Strings(string, ...[]string) []string

  	// Unsigned-integer lookups: single value and list.
  	Uint(string, ...uint32) uint32
  	Uints(string, ...[]uint32) []uint32

  	// Floating-point lookup.
  	Float(string, ...float32) float32
  }
  17. type Backend interface {
  18. Config() Config
  19. Get(name string) Tensor
  20. NewContext() Context
  21. }
  // Trace data returned by Context.GetTrace.
  //
  // Both types are plain data with JSON tags. encoding/json emits keys in the
  // field order below, so the order is part of the serialized form.
  type (
  	// GraphLayer is a single traced entry: a name and a shape.
  	GraphLayer struct {
  		Name  string  `json:"name"`
  		Shape []int64 `json:"shape"`
  	}

  	// Graph is the ordered list of traced entries, serialized under "graph".
  	Graph struct {
  		Graph []GraphLayer `json:"graph"`
  	}
  )
  29. var backends = make(map[string]func(*os.File) (Backend, error))
  30. func RegisterBackend(name string, f func(*os.File) (Backend, error)) {
  31. if _, ok := backends[name]; ok {
  32. panic("backend: backend already registered")
  33. }
  34. backends[name] = f
  35. }
  36. func NewBackend(f *os.File) (Backend, error) {
  37. if backend, ok := backends["ggml"]; ok {
  38. return backend(f)
  39. }
  40. return nil, fmt.Errorf("unsupported backend")
  41. }
  42. type Context interface {
  43. Zeros(dtype DType, shape ...int64) Tensor
  44. FromFloatSlice(s []float32, shape ...int) (Tensor, error)
  45. FromIntSlice(s []int32, shape ...int) (Tensor, error)
  46. Forward(Tensor)
  47. Compute(Tensor) Tensor
  48. Close() error
  49. SetDebug(bool)
  50. Trace(string, Tensor)
  51. GetTrace() Graph
  52. }
  53. type Tensor interface {
  54. Dim(n int) int64
  55. Stride(n int) int64
  56. Shape() []int64
  57. DType() DType
  58. Bytes() []byte
  59. Floats() []float32
  60. Add(ctx Context, t2 Tensor) Tensor
  61. Mul(ctx Context, t2 Tensor) Tensor
  62. Mulmat(ctx Context, t2 Tensor) Tensor
  63. Softmax(ctx Context) Tensor
  64. LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
  65. RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
  66. Scale(ctx Context, s float64) Tensor
  67. Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
  68. RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor
  69. Tanh(ctx Context) Tensor
  70. GELU(ctx Context) Tensor
  71. SILU(ctx Context) Tensor
  72. Reshape(ctx Context, shape ...int64) Tensor
  73. View(ctx Context, offset int, shape ...int) Tensor
  74. Permute(ctx Context, shape ...int) Tensor
  75. Contiguous(ctx Context) Tensor
  76. Pad(ctx Context, shape ...int64) Tensor
  77. Unpad(ctx Context, shape ...int64) Tensor
  78. Stack(ctx Context, dim int, s ...Tensor) Tensor
  79. Concat(ctx Context, t2 Tensor, dim int) Tensor
  80. Rows(ctx Context, t2 Tensor) Tensor
  81. Copy(ctx Context, t2 Tensor) Tensor
  82. }
  // number is every built-in numeric kind that supports *. The ~ on each term
  // also admits named types whose underlying type is one of these.
  type number interface {
  	~int8 | ~int16 | ~int32 | ~int64 | ~int |
  		~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uint |
  		~float32 | ~float64 | ~complex64 | ~complex128
  }

  // mul multiplies its arguments together, yielding 1 for an empty argument
  // list. It does not check for overflow. dump uses it to turn a shape into
  // an element count.
  func mul[T number](s ...T) T {
  	product := T(1)
  	for i := range s {
  		product *= s[i]
  	}
  	return product
  }
  96. type DumpOptions struct {
  97. // Items is the number of elements to print at the beginning and end of each dimension.
  98. Items int64
  99. // Precision is the number of decimal places to print. Applies to float32 and float64.
  100. Precision int
  101. }
  102. func Dump(t Tensor, opts ...DumpOptions) string {
  103. if len(opts) < 1 {
  104. opts = append(opts, DumpOptions{
  105. Items: 3,
  106. Precision: 4,
  107. })
  108. }
  109. switch t.DType() {
  110. case DTypeF32:
  111. return dump[[]float32](t, opts[0])
  112. case DTypeI32:
  113. return dump[[]int32](t, opts[0])
  114. default:
  115. return "<unsupported>"
  116. }
  117. }
  118. func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string {
  119. bts := t.Bytes()
  120. if bts == nil {
  121. return "<nil>"
  122. }
  123. s := make(S, mul(t.Shape()...))
  124. if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
  125. panic(err)
  126. }
  127. shape := t.Shape()
  128. var sb strings.Builder
  129. var f func([]int64, int64)
  130. f = func(dims []int64, stride int64) {
  131. prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
  132. fmt.Fprint(&sb, "[")
  133. defer func() { fmt.Fprint(&sb, "]") }()
  134. for i := int64(0); i < dims[0]; i++ {
  135. if i >= opts.Items && i < dims[0]-opts.Items {
  136. fmt.Fprint(&sb, "..., ")
  137. // skip to next printable element
  138. skip := dims[0] - 2*opts.Items
  139. if len(dims) > 1 {
  140. stride += mul(append(dims[1:], skip)...)
  141. fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
  142. }
  143. i += skip - 1
  144. } else if len(dims) > 1 {
  145. f(dims[1:], stride)
  146. stride += mul(dims[1:]...)
  147. if i < dims[0]-1 {
  148. fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
  149. }
  150. } else {
  151. fmt.Fprint(&sb, s[stride+i])
  152. if i < dims[0]-1 {
  153. fmt.Fprint(&sb, ", ")
  154. }
  155. }
  156. }
  157. }
  158. f(shape, 0)
  159. return sb.String()
  160. }
  // DType identifies a Tensor's element type. Dump knows how to format
  // DTypeF32 and DTypeI32. DTypeOther presumably stands for every other
  // element type. NOTE(review): confirm against the backend's mapping.
  type DType int

  // The values are written out explicitly, so inserting a new constant
  // cannot silently renumber the existing ones.
  const (
  	DTypeF32   DType = 0
  	DTypeI32   DType = 1
  	DTypeOther DType = 2
  )