package ml

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"os"
	"slices"
	"strconv"
	"strings"
)

// Config provides typed access to model metadata. Each accessor takes a key
// and an optional fallback value that is returned when the key is absent.
type Config interface {
	Architecture() string

	String(string, ...string) string
	Uint(string, ...uint32) uint32
	Float(string, ...float32) float32
	Bool(string, ...bool) bool

	Strings(string, ...[]string) []string
	Uints(string, ...[]uint32) []uint32
	Floats(string, ...[]float32) []float32
}
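
// As a hedged sketch of typical use, a model constructor might read its
// hyperparameters from a Config. The key names and defaults below are
// illustrative only, not guaranteed by this package:
//
//	func newOptions(c Config) (heads uint32, eps float32) {
//		// The trailing argument is the fallback used when the key is absent.
//		heads = c.Uint("attention.head_count", 32)
//		eps = c.Float("attention.layer_norm_rms_epsilon", 1e-6)
//		return heads, eps
//	}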

// Backend is the interface implemented by an execution provider. It exposes
// model configuration, named weight tensors, and contexts for building
// compute graphs.
type Backend interface {
	Config() Config
	Get(name string) Tensor
	NewContext() Context
	NewContextSize(size int) Context
}
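
// A hedged sketch of how a Backend is typically consumed when assembling a
// model (the tensor name is illustrative; names vary by model format):
//
//	b, err := NewBackend(f, BackendParams{NumGPULayers: 99})
//	if err != nil {
//		return err
//	}
//
//	tokenEmbedding := b.Get("token_embd.weight")
//	ctx := b.NewContext()
//	defer ctx.Close()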

// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
	CacheConfig() CacheConfig
}

// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels.
type CacheConfig struct {
	// CachePadding specifies the multiple for the number of tokens of cache history
	// that will be returned from cache Get for k, v and mask. The capacity of the
	// cache itself will also be increased to a multiple of this size if needed.
	CachePadding int

	// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
	// and returns the permuted version via Get. This uses the cache copy operation
	// to avoid a Contiguous call on the permuted tensor.
	PermutedV bool

	// MaskDType specifies the data type for generating the mask. If unset it will
	// default to DTypeF32.
	MaskDType DType

	// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
	// Any position that does not correspond to an actual token will be filled with -Inf.
	MaskBatchPadding int
}
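
// As an illustrative sketch, a backend whose fused attention kernel wants F16
// masks and history padded to a multiple of 32 tokens might report the
// following (the receiver type and all values are hypothetical):
//
//	func (b *myBackend) CacheConfig() CacheConfig {
//		return CacheConfig{
//			CachePadding:     32,
//			PermutedV:        true,
//			MaskDType:        DTypeF16,
//			MaskBatchPadding: 32,
//		}
//	}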

// BackendParams controls how the backend loads and executes models.
type BackendParams struct {
	// NumThreads sets the number of threads to use if running on the CPU.
	NumThreads int

	// MainGPU is the index of the primary GPU to use.
	MainGPU int

	// NumGPULayers is the number of layers to offload to GPUs.
	NumGPULayers int

	// TensorSplit is the fraction of the model to offload to each GPU.
	TensorSplit []float32

	// FlashAttention indicates that we should use a fused flash attention kernel.
	FlashAttention bool
}
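
// For example, params that pin the first GPU, offload up to 99 layers, and
// keep four CPU threads for any remaining work might look like this (all
// values are illustrative):
//
//	params := BackendParams{
//		NumThreads:     4,
//		MainGPU:        0,
//		NumGPULayers:   99,
//		FlashAttention: true,
//	}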

var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))

// RegisterBackend makes a backend constructor available under the given name.
// It panics if a backend with the same name is already registered.
func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
	if _, ok := backends[name]; ok {
		panic("backend: backend already registered")
	}

	backends[name] = f
}

// NewBackend opens the model in f using the registered "ggml" backend.
func NewBackend(f *os.File, params BackendParams) (Backend, error) {
	if backend, ok := backends["ggml"]; ok {
		return backend(f, params)
	}

	return nil, fmt.Errorf("unsupported backend")
}
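
// A backend implementation typically registers itself from an init function
// and is linked in via a blank import. A minimal sketch, assuming a backend
// package whose New matches the registered constructor signature:
//
//	func init() {
//		ml.RegisterBackend("ggml", New)
//	}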

// Context builds and executes compute graphs and allocates the tensors that
// participate in them.
type Context interface {
	Empty(dtype DType, shape ...int) Tensor
	Zeros(dtype DType, shape ...int) Tensor
	FromFloatSlice(s []float32, shape ...int) (Tensor, error)
	FromIntSlice(s []int32, shape ...int) (Tensor, error)

	Forward(...Tensor) Context
	Compute(...Tensor)
	MaxGraphNodes() int
	Close()

	// Input returns a context appropriate for creating input tensors.
	Input() Context

	// Output returns a context appropriate for creating output tensors.
	Output() Context

	// Layer returns a context appropriate for creating intermediate tensors.
	Layer(int) Context
}
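
// A minimal sketch of the usual flow: allocate inputs, chain operations to
// build a graph, then schedule and run it before reading results back:
//
//	ctx := backend.NewContext()
//	defer ctx.Close()
//
//	x, err := ctx.Input().FromFloatSlice([]float32{1, 2, 3, 4}, 2, 2)
//	if err != nil {
//		return err
//	}
//
//	y := x.Add(ctx, x)
//	ctx.Forward(y).Compute(y)
//	fmt.Println(y.Floats())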

// Tensor is a multi-dimensional array of values. Operations build graph nodes
// rather than computing eagerly; see Context.Compute.
type Tensor interface {
	Dim(n int) int
	Stride(n int) int

	Shape() []int
	DType() DType

	Bytes() []byte
	Floats() []float32

	Add(ctx Context, t2 Tensor) Tensor
	Mul(ctx Context, t2 Tensor) Tensor
	Mulmat(ctx Context, t2 Tensor) Tensor
	MulmatFullPrec(ctx Context, t2 Tensor) Tensor

	Softmax(ctx Context) Tensor
	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
	RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
	Scale(ctx Context, s float64) Tensor

	AvgPool1D(ctx Context, k, s, p int) Tensor
	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor

	Tanh(ctx Context) Tensor
	GELU(ctx Context) Tensor
	SILU(ctx Context) Tensor

	Reshape(ctx Context, shape ...int) Tensor
	View(ctx Context, offset int, shape ...int) Tensor
	Permute(ctx Context, shape ...int) Tensor
	Contiguous(ctx Context) Tensor
	Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor

	Pad(ctx Context, shape ...int) Tensor
	Unpad(ctx Context, shape ...int) Tensor

	Stack(ctx Context, dim int, s ...Tensor) Tensor
	Concat(ctx Context, t2 Tensor, dim int) Tensor
	Rows(ctx Context, t2 Tensor) Tensor
	Copy(ctx Context, t2 Tensor) Tensor
}
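
// Because operations chain, higher-level blocks read naturally. A hedged
// sketch of a gated feed-forward block (the weight tensors gate, up, and
// down are assumed to exist and are not defined by this package):
//
//	hidden := gate.Mulmat(ctx, x).SILU(ctx)
//	hidden = hidden.Mul(ctx, up.Mulmat(ctx, x))
//	out := down.Mulmat(ctx, hidden)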

// ScaledDotProductAttention implements a fused attention
// operation equivalent to the following code on a tensor named
// query:
//
//	query = query.Permute(ctx, 0, 2, 1, 3)
//	key = key.Permute(ctx, 0, 2, 1, 3)
//	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
//	kq := key.MulmatFullPrec(ctx, query)
//
//	kq = kq.Scale(ctx, scale)
//
//	if mask != nil {
//		kq = kq.Add(ctx, mask)
//	}
//
//	kq = kq.Softmax(ctx)
//
//	kqv := value.Mulmat(ctx, kq)
//	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
type ScaledDotProductAttention interface {
	ScaledDotProductAttention(ctx Context, key, value, mask Tensor, scale float64) Tensor
}
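
// Callers typically probe for the fused kernel with a type assertion and fall
// back to the unfused sequence above when it is unavailable. A sketch:
//
//	if sdpa, ok := query.(ScaledDotProductAttention); ok {
//		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
//	}
//	// ...otherwise apply the permute/mulmat/softmax sequence documented above.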

// number is the constraint satisfied by every numeric element type that can
// back a Tensor.
type number interface {
	~int | ~int8 | ~int16 | ~int32 | ~int64 |
		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
		~float32 | ~float64 |
		~complex64 | ~complex128
}

// mul returns the product of its arguments, or 1 when called with none.
func mul[T number](s ...T) T {
	p := T(1)
	for _, v := range s {
		p *= v
	}

	return p
}

type DumpOptions struct {
	// Items is the number of elements to print at the beginning and end of each dimension.
	Items int

	// Precision is the number of decimal places to print. Applies to float32 and float64.
	Precision int
}

// Dump returns a human-readable rendering of t, eliding interior elements of
// large dimensions. If no options are given, it prints 3 items at each
// dimension edge with 4 decimal places.
func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
	if len(opts) < 1 {
		opts = append(opts, DumpOptions{
			Items:     3,
			Precision: 4,
		})
	}

	switch t.DType() {
	case DTypeF32:
		return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
			return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
		})
	case DTypeF16, DTypeQ80, DTypeQ40:
		// Quantized and half-precision tensors are converted to F32 first.
		f32 := ctx.Empty(DTypeF32, t.Shape()...)
		f32 = t.Copy(ctx, f32)
		return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
			return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
		})
	case DTypeI32:
		return dump[[]int32](ctx, t, opts[0].Items, func(i int32) string {
			return strconv.FormatInt(int64(i), 10)
		})
	default:
		return "<unsupported>"
	}
}
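
// Dump is a debugging aid; the tensor is computed first if it has no data
// yet. For example:
//
//	fmt.Println(Dump(ctx, logits)) // defaults: 3 items, 4 decimal places
//	fmt.Println(Dump(ctx, logits, DumpOptions{Items: 5, Precision: 2}))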

// dump materializes t if needed, reads its raw bytes into a slice of E, and
// formats it recursively, one bracketed level per dimension.
func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
	if t.Bytes() == nil {
		ctx.Forward(t).Compute(t)
	}

	s := make(S, mul(t.Shape()...))
	if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
		panic(err)
	}

	shape := t.Shape()
	slices.Reverse(shape)

	var sb strings.Builder
	var f func([]int, int)
	f = func(dims []int, stride int) {
		prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
		sb.WriteString("[")
		defer func() { sb.WriteString("]") }()
		for i := 0; i < dims[0]; i++ {
			if i >= items && i < dims[0]-items {
				sb.WriteString("..., ")
				// skip to next printable element
				skip := dims[0] - 2*items
				if len(dims) > 1 {
					stride += mul(append(dims[1:], skip)...)
					fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
				}
				i += skip - 1
			} else if len(dims) > 1 {
				// Recurse into the next dimension, then advance past it.
				f(dims[1:], stride)
				stride += mul(dims[1:]...)
				if i < dims[0]-1 {
					fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
				}
			} else {
				// Innermost dimension: pad non-negative values so columns align.
				text := fn(s[stride+i])
				if len(text) > 0 && text[0] != '-' {
					sb.WriteString(" ")
				}

				sb.WriteString(text)
				if i < dims[0]-1 {
					sb.WriteString(", ")
				}
			}
		}
	}
	f(shape, 0)

	return sb.String()
}

// DType identifies the element type of a Tensor.
type DType int

const (
	DTypeOther DType = iota
	DTypeF32
	DTypeF16
	DTypeQ80
	DTypeQ40
	DTypeI32
)