Browse Source

fix: fixes a memory leak in bfloat16 package

This change vendors in the bfloat16 package from
github.com/d4l3k/go-bfloat16/ and fixes a memory leak which
was being caused by using unsafe pointers instead of the
math package.
Patrick Devine 1 month ago
parent
commit
c75b428249
6 changed files with 133 additions and 4 deletions
  1. 2 1
      convert/reader_safetensors.go
  2. 0 1
      go.mod
  3. 0 2
      go.sum
  4. 21 0
      types/bfloat16/LICENSE
  5. 57 0
      types/bfloat16/bfloat16.go
  6. 53 0
      types/bfloat16/bfloat16_test.go

+ 2 - 1
convert/reader_safetensors.go

@@ -11,9 +11,10 @@ import (
 	"slices"
 	"strings"
 
-	"github.com/d4l3k/go-bfloat16"
 	"github.com/x448/float16"
 	"golang.org/x/exp/maps"
+
+	"github.com/ollama/ollama/types/bfloat16"
 )
 
 type safetensorMetadata struct {

+ 0 - 1
go.mod

@@ -16,7 +16,6 @@ require (
 
 require (
 	github.com/agnivade/levenshtein v1.1.1
-	github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
 	github.com/dlclark/regexp2 v1.11.4
 	github.com/emirpasic/gods/v2 v2.0.0-alpha
 	github.com/google/go-cmp v0.6.0

+ 0 - 2
go.sum

@@ -35,8 +35,6 @@ github.com/containerd/console v1.0.3 h1:lIr7SlA5PxZyMV30bDW0MGbiOPXwc63yRuCP0ARu
 github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U=
 github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
 github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
-github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 h1:cBzrdJPAFBsgCrDPnZxlp1dF2+k4r1kVpD7+1S1PVjY=
-github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLcxEuYUlAd/EXyjc/v55nd3+47YAgWbSXVxPrNI=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=

+ 21 - 0
types/bfloat16/LICENSE

@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 Tristan Rice
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

+ 57 - 0
types/bfloat16/bfloat16.go

@@ -0,0 +1,57 @@
+// Vendored code from https://github.com/d4l3k/go-bfloat16
+// unsafe pointer replaced by "math"
+package bfloat16
+
+import "math"
+
+type BF16 uint16
+
+func FromBytes(buf []byte) BF16 {
+	return BF16(uint16(buf[0]) + uint16(buf[1])<<8)
+}
+
+func ToBytes(b BF16) []byte {
+	return []byte{byte(b & 0xFF), byte(b >> 8)}
+}
+
+func Decode(buf []byte) []BF16 {
+	var out []BF16
+	for i := 0; i < len(buf); i += 2 {
+		out = append(out, FromBytes(buf[i:]))
+	}
+	return out
+}
+
+func Encode(f []BF16) []byte {
+	var out []byte
+	for _, a := range f {
+		out = append(out, ToBytes(a)...)
+	}
+	return out
+}
+
+func DecodeFloat32(buf []byte) []float32 {
+	var out []float32
+	for i := 0; i < len(buf); i += 2 {
+		out = append(out, ToFloat32(FromBytes(buf[i:])))
+	}
+	return out
+}
+
+func EncodeFloat32(f []float32) []byte {
+	var out []byte
+	for _, a := range f {
+		out = append(out, ToBytes(FromFloat32(a))...)
+	}
+	return out
+}
+
+func ToFloat32(b BF16) float32 {
+	u32 := uint32(b) << 16
+	return math.Float32frombits(u32)
+}
+
+func FromFloat32(f float32) BF16 {
+	u32 := math.Float32bits(f)
+	return BF16(u32 >> 16)
+}

+ 53 - 0
types/bfloat16/bfloat16_test.go

@@ -0,0 +1,53 @@
+package bfloat16
+
+import (
+	"crypto/rand"
+	"reflect"
+	"testing"
+)
+
+func randomBytes(n int) []byte {
+	out := make([]byte, n)
+	if _, err := rand.Read(out); err != nil {
+		panic(err)
+	}
+	return out
+}
+
+func TestEncodeDecode(t *testing.T) {
+	b := randomBytes(1024)
+	bf16 := Decode(b)
+	out := Encode(bf16)
+	if !reflect.DeepEqual(b, out) {
+		t.Fatalf("%+v != %+v", b, out)
+	}
+}
+
+func TestEncodeDecodeFloat32(t *testing.T) {
+	b := randomBytes(1024)
+	bf16 := DecodeFloat32(b)
+	out := EncodeFloat32(bf16)
+	if !reflect.DeepEqual(b, out) {
+		t.Fatalf("%+v != %+v", b, out)
+	}
+}
+
+func TestBasicFloat32(t *testing.T) {
+	var in float32 = 1.0
+	out := ToFloat32(FromFloat32(in))
+	if !reflect.DeepEqual(in, out) {
+		t.Fatalf("%+v != %+v", in, out)
+	}
+}
+
+func TestComplexFloat32(t *testing.T) {
+	var in float32 = 123456789123456789.123456789
+	var want float32 = 123286039799267328.0
+	out := ToFloat32(FromFloat32(in))
+	if in == out {
+		t.Fatalf("no loss of precision")
+	}
+	if out != want {
+		t.Fatalf("%.16f != %.16f", want, out)
+	}
+}