bfloat16.go 1.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
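// Package bfloat16 converts between float32 and the 16-bit bfloat16 format,
// and packs bfloat16 values into little-endian byte slices.
//
// Vendored from https://github.com/d4l3k/go-bfloat16; the original unsafe
// pointer casts were replaced with math.Float32bits and math.Float32frombits.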
  3. package bfloat16
  4. import "math"
  5. type BF16 uint16
  6. func FromBytes(buf []byte) BF16 {
  7. return BF16(uint16(buf[0]) + uint16(buf[1])<<8)
  8. }
  9. func ToBytes(b BF16) []byte {
  10. return []byte{byte(b & 0xFF), byte(b >> 8)}
  11. }
  12. func Decode(buf []byte) []BF16 {
  13. var out []BF16
  14. for i := 0; i < len(buf); i += 2 {
  15. out = append(out, FromBytes(buf[i:]))
  16. }
  17. return out
  18. }
  19. func Encode(f []BF16) []byte {
  20. var out []byte
  21. for _, a := range f {
  22. out = append(out, ToBytes(a)...)
  23. }
  24. return out
  25. }
  26. func DecodeFloat32(buf []byte) []float32 {
  27. var out []float32
  28. for i := 0; i < len(buf); i += 2 {
  29. out = append(out, ToFloat32(FromBytes(buf[i:])))
  30. }
  31. return out
  32. }
  33. func EncodeFloat32(f []float32) []byte {
  34. var out []byte
  35. for _, a := range f {
  36. out = append(out, ToBytes(FromFloat32(a))...)
  37. }
  38. return out
  39. }
  40. func ToFloat32(b BF16) float32 {
  41. u32 := uint32(b) << 16
  42. return math.Float32frombits(u32)
  43. }
  44. func FromFloat32(f float32) BF16 {
  45. u32 := math.Float32bits(f)
  46. return BF16(u32 >> 16)
  47. }