backend.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. package ml
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "fmt"
  6. "os"
  7. "strconv"
  8. "strings"
  9. )
  10. type Config interface {
  11. Architecture() string
  12. String(string, ...string) string
  13. Uint(string, ...uint32) uint32
  14. Float(string, ...float32) float32
  15. Strings(string, ...[]string) []string
  16. Uints(string, ...[]uint32) []uint32
  17. }
  18. type Backend interface {
  19. Config() Config
  20. Get(name string) Tensor
  21. NewContext() Context
  22. }
  23. var backends = make(map[string]func(*os.File) (Backend, error))
  24. func RegisterBackend(name string, f func(*os.File) (Backend, error)) {
  25. if _, ok := backends[name]; ok {
  26. panic("backend: backend already registered")
  27. }
  28. backends[name] = f
  29. }
  30. func NewBackend(f *os.File) (Backend, error) {
  31. if backend, ok := backends["ggml"]; ok {
  32. return backend(f)
  33. }
  34. return nil, fmt.Errorf("unsupported backend")
  35. }
  36. // RopeType specifies the type of RoPE (Rotary Position Embedding) to use, these types are implemented in the backend
  37. type RopeType int
  38. const (
  39. RopeTypeStandard RopeType = iota
  40. _ // not yet used
  41. RopeTypeNeoX
  42. )
  43. // RopeConfig contains all configuration for the RoPE (Rotary Position Embedding) operation
  44. type RopeConfig struct {
  45. // PositionIDs contains the position indices for each token in the sequence
  46. // These indices are used to calculate the rotary embeddings
  47. PositionIDs Tensor
  48. // RopeFactors is an optional tensor containing pre-computed rotation factors
  49. RopeFactors Tensor
  50. // RopeDim specifies the dimension size for the rotary embeddings
  51. RopeDim uint32
  52. // RopeType indicates which RoPE variant to use (e.g. normal or neox)
  53. RopeType RopeType
  54. // OrigCtxLen stores the original context length the model was trained with
  55. OrigCtxLen int
  56. // RopeBase is the base value used in the frequency calculation
  57. RopeBase float32
  58. // RopeScale is a scaling factor applied to position indices
  59. RopeScale float32
  60. // YaRN parameters can be added here if they need to be configurable
  61. }
  62. type Context interface {
  63. Zeros(dtype DType, shape ...int) Tensor
  64. FromFloatSlice(s []float32, shape ...int) (Tensor, error)
  65. FromIntSlice(s []int32, shape ...int) (Tensor, error)
  66. Forward(Tensor)
  67. Compute(...Tensor)
  68. MaxTensors() int
  69. Close()
  70. }
  71. type Tensor interface {
  72. Dim(n int) int
  73. Stride(n int) int
  74. Shape() []int
  75. DType() DType
  76. Bytes() []byte
  77. Floats() []float32
  78. Add(ctx Context, t2 Tensor) Tensor
  79. Mul(ctx Context, t2 Tensor) Tensor
  80. Mulmat(ctx Context, t2 Tensor) Tensor
  81. MulmatFullPrec(ctx Context, t2 Tensor) Tensor
  82. Softmax(ctx Context) Tensor
  83. LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
  84. RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
  85. Scale(ctx Context, s float64) Tensor
  86. Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
  87. RoPE(ctx Context, rc RopeConfig) Tensor
  88. Tanh(ctx Context) Tensor
  89. GELU(ctx Context) Tensor
  90. SILU(ctx Context) Tensor
  91. Reshape(ctx Context, shape ...int) Tensor
  92. View(ctx Context, offset int, shape ...int) Tensor
  93. Permute(ctx Context, shape ...int) Tensor
  94. Contiguous(ctx Context) Tensor
  95. Pad(ctx Context, shape ...int) Tensor
  96. Unpad(ctx Context, shape ...int) Tensor
  97. Stack(ctx Context, dim int, s ...Tensor) Tensor
  98. Concat(ctx Context, t2 Tensor, dim int) Tensor
  99. Rows(ctx Context, t2 Tensor) Tensor
  100. Copy(ctx Context, t2 Tensor) Tensor
  101. }
  102. type number interface {
  103. ~int | ~int8 | ~int16 | ~int32 | ~int64 |
  104. ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
  105. ~float32 | ~float64 |
  106. ~complex64 | ~complex128
  107. }
  108. func mul[T number](s ...T) T {
  109. p := T(1)
  110. for _, v := range s {
  111. p *= v
  112. }
  113. return p
  114. }
  115. type DumpOptions struct {
  116. // Items is the number of elements to print at the beginning and end of each dimension.
  117. Items int
  118. // Precision is the number of decimal places to print. Applies to float32 and float64.
  119. Precision int
  120. }
  121. func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
  122. if len(opts) < 1 {
  123. opts = append(opts, DumpOptions{
  124. Items: 3,
  125. Precision: 4,
  126. })
  127. }
  128. switch t.DType() {
  129. case DTypeF32:
  130. return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
  131. return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
  132. })
  133. case DTypeF16:
  134. f32 := ctx.Zeros(DTypeF32, t.Shape()...)
  135. f32 = t.Copy(ctx, f32)
  136. return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
  137. return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
  138. })
  139. case DTypeI32:
  140. return dump[[]int32](ctx, t, opts[0].Items, func(i int32) string {
  141. return strconv.FormatInt(int64(i), 10)
  142. })
  143. default:
  144. return "<unsupported>"
  145. }
  146. }
  147. func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
  148. if t.Bytes() == nil {
  149. ctx.Forward(t)
  150. ctx.Compute(t)
  151. }
  152. s := make(S, mul(t.Shape()...))
  153. if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
  154. panic(err)
  155. }
  156. shape := t.Shape()
  157. var sb strings.Builder
  158. var f func([]int, int)
  159. f = func(dims []int, stride int) {
  160. prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
  161. fmt.Fprint(&sb, "[")
  162. defer func() { fmt.Fprint(&sb, "]") }()
  163. for i := 0; i < dims[0]; i++ {
  164. if i >= items && i < dims[0]-items {
  165. fmt.Fprint(&sb, "..., ")
  166. // skip to next printable element
  167. skip := dims[0] - 2*items
  168. if len(dims) > 1 {
  169. stride += mul(append(dims[1:], skip)...)
  170. fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
  171. }
  172. i += skip - 1
  173. } else if len(dims) > 1 {
  174. f(dims[1:], stride)
  175. stride += mul(dims[1:]...)
  176. if i < dims[0]-1 {
  177. fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
  178. }
  179. } else {
  180. fmt.Fprint(&sb, fn(s[stride+i]))
  181. if i < dims[0]-1 {
  182. fmt.Fprint(&sb, ", ")
  183. }
  184. }
  185. }
  186. }
  187. f(shape, 0)
  188. return sb.String()
  189. }
  190. type DType int
  191. const (
  192. DTypeOther DType = iota
  193. DTypeF32
  194. DTypeF16
  195. DTypeI32
  196. )