Bläddra i källkod

simplify safetensors reading

Michael Yang 11 månader sedan
förälder
incheckning
171eb040fc
6 ändrade filer med 49 tillägg och 81 borttagningar
  1. 43 72
      convert/safetensors.go
  2. 0 1
      go.mod
  3. 0 2
      go.sum
  4. 1 1
      llm/ggla.go
  5. 3 3
      llm/ggml.go
  6. 2 2
      llm/gguf.go

+ 43 - 72
convert/safetensors.go

@@ -6,7 +6,6 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
-	"log/slog"
 	"os"
 	"path/filepath"
 	"regexp"
@@ -14,7 +13,6 @@ import (
 	"strings"
 
 	"github.com/d4l3k/go-bfloat16"
-	"github.com/mitchellh/mapstructure"
 	"github.com/x448/float16"
 
 	"github.com/ollama/ollama/llm"
@@ -29,38 +27,36 @@ type safetensorWriterTo struct {
 	filename string
 	dtype    string
 
-	start, end, padding uint64
-	repacker            func(string, []float32, []uint64) ([]float32, error)
+	offset, size int64
+	repacker     func(string, []float32, []uint64) ([]float32, error)
 }
 
-type tensorMetaData struct {
-	Type    string `mapstructure:"dtype"`
-	Shape   []int  `mapstructure:"shape"`
-	Offsets []int  `mapstructure:"data_offsets"`
+type safetensorMetadata struct {
+	Type    string   `json:"dtype"`
+	Shape   []uint64 `json:"shape"`
+	Offsets []int64  `json:"data_offsets"`
 }
 
 type SafetensorFormat struct{}
 
 func (m *SafetensorFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor, error) {
-	slog.Debug("getting tensor data")
 	var tensors []llm.Tensor
-	files, err := filepath.Glob(filepath.Join(dirpath, "/model-*.safetensors"))
+	matches, err := filepath.Glob(filepath.Join(dirpath, "*.safetensors"))
 	if err != nil {
 		return nil, err
 	}
 
 	var offset uint64
-	for _, f := range files {
+	for _, f := range matches {
 		var t []llm.Tensor
 		var err error
 		t, offset, err = m.readTensors(f, offset, params)
 		if err != nil {
-			slog.Error(err.Error())
 			return nil, err
 		}
+
 		tensors = append(tensors, t...)
 	}
-	slog.Debug(fmt.Sprintf("all tensors = %d", len(tensors)))
 	return tensors, nil
 }
 
@@ -71,76 +67,57 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
 	}
 	defer f.Close()
 
-	var jsonSize uint64
-	if err := binary.Read(f, binary.LittleEndian, &jsonSize); err != nil {
+	var n int64
+	if err := binary.Read(f, binary.LittleEndian, &n); err != nil {
 		return nil, 0, err
 	}
 
-	buf := make([]byte, jsonSize)
-	_, err = io.ReadFull(f, buf)
-	if err != nil {
+	b := bytes.NewBuffer(make([]byte, 0, n))
+	if _, err = io.CopyN(b, f, n); err != nil {
 		return nil, 0, err
 	}
 
-	d := json.NewDecoder(bytes.NewBuffer(buf))
-	d.UseNumber()
-	var parsed map[string]interface{}
-	if err = d.Decode(&parsed); err != nil {
+	var headers map[string]safetensorMetadata
+	if err := json.NewDecoder(b).Decode(&headers); err != nil {
 		return nil, 0, err
 	}
 
 	var keys []string
-	for k := range parsed {
-		keys = append(keys, k)
+	for key := range headers {
+		if !strings.HasSuffix(key, "self_attn.rotary_embd.inv_freq") {
+			keys = append(keys, key)
+		}
 	}
 
 	slices.Sort(keys)
-	slog.Info("converting layers")
 
 	var tensors []llm.Tensor
-	for _, k := range keys {
-		if strings.HasSuffix(k, "self_attn.rotary_emb.inv_freq") {
-			continue
-		}
-
-		vals := parsed[k].(map[string]interface{})
-		var data tensorMetaData
-		if err = mapstructure.Decode(vals, &data); err != nil {
-			slog.Error("couldn't decode properly")
-			return nil, 0, err
-		}
+	for _, key := range keys {
+		value := headers[key]
 
-		var size uint64
 		var kind uint32
-		switch len(data.Shape) {
+		switch len(value.Shape) {
 		case 0:
-			// metadata
+			// valuedata
 			continue
-		case 1:
-			// convert to float32
-			kind = 0
-			size = uint64(data.Shape[0] * 4)
 		case 2:
-			// convert to float16
 			kind = 1
-			size = uint64(data.Shape[0] * data.Shape[1] * 2)
 		}
 
-		ggufName, err := m.GetLayerName(k)
+		name, err := m.GetLayerName(key)
 		if err != nil {
-			slog.Error(err.Error())
 			return nil, 0, err
 		}
 
-		shape := []uint64{0, 0, 0, 0}
-		for i := range data.Shape {
-			shape[i] = uint64(data.Shape[i])
-		}
+		shape := make([]uint64, len(value.Shape))
+		copy(shape, value.Shape)
 
-		slog.Debug(fmt.Sprintf("'%45s': '%30s' %10d [%#v]", k, ggufName, size, data.Shape))
+		pad := func(s int64) int64 {
+			return 8 + n + s
+		}
 
 		t := llm.Tensor{
-			Name:   ggufName,
+			Name:   name,
 			Kind:   kind,
 			Offset: offset,
 			Shape:  shape[:],
@@ -151,19 +128,15 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
 			params:   params,
 			bo:       params.ByteOrder,
 			filename: fn,
-			dtype:    data.Type,
-			start:    uint64(data.Offsets[0]),
-			end:      uint64(data.Offsets[1]),
-			padding:  8 + jsonSize,
+			dtype:    value.Type,
+			offset:   pad(value.Offsets[0]),
+			size:     pad(value.Offsets[1]) - pad(value.Offsets[0]),
 		}
 
-		offset += size
+		offset += t.Size()
 		tensors = append(tensors, t)
 	}
 
-	slog.Debug(fmt.Sprintf("total tensors for file = %d", len(tensors)))
-	slog.Debug(fmt.Sprintf("offset = %d", offset))
-
 	return tensors, offset, nil
 }
 
@@ -176,9 +149,7 @@ func (m *SafetensorFormat) GetParams(dirpath string) (*Params, error) {
 
 	var params Params
 
-	d := json.NewDecoder(f)
-	err = d.Decode(&params)
-	if err != nil {
+	if err := json.NewDecoder(f).Decode(&params); err != nil {
 		return nil, err
 	}
 
@@ -233,34 +204,34 @@ func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
 	}
 	defer f.Close()
 
-	if _, err = f.Seek(int64(r.padding+r.start), 0); err != nil {
+	if _, err = f.Seek(r.offset, io.SeekStart); err != nil {
 		return 0, err
 	}
 
 	var f32s []float32
 	switch r.dtype {
 	case "F32":
-		f32s = make([]float32, (r.end-r.start)/4)
+		f32s = make([]float32, r.size/4)
 		if err = binary.Read(f, r.bo, f32s); err != nil {
 			return 0, err
 		}
 	case "F16":
-		bts := make([]uint16, (r.end-r.start)/2)
-		if err = binary.Read(f, r.bo, bts); err != nil {
+		u16s := make([]uint16, r.size/2)
+		if err = binary.Read(f, r.bo, u16s); err != nil {
 			return 0, err
 		}
 
-		for _, b := range bts {
+		for _, b := range u16s {
 			f32s = append(f32s, float16.Frombits(b).Float32())
 		}
 
 	case "BF16":
-		bts := make([]byte, r.end-r.start)
-		if err = binary.Read(f, r.bo, bts); err != nil {
+		u8s := make([]uint8, r.size)
+		if err = binary.Read(f, r.bo, u8s); err != nil {
 			return 0, err
 		}
 
-		f32s = bfloat16.DecodeFloat32(bts)
+		f32s = bfloat16.DecodeFloat32(u8s)
 	default:
 		return 0, fmt.Errorf("unknown data type: %s", r.dtype)
 	}

+ 0 - 1
go.mod

@@ -8,7 +8,6 @@ require (
 	github.com/gin-gonic/gin v1.10.0
 	github.com/golang/protobuf v1.5.4 // indirect
 	github.com/google/uuid v1.1.2
-	github.com/mitchellh/mapstructure v1.5.0
 	github.com/olekukonko/tablewriter v0.0.5
 	github.com/spf13/cobra v1.7.0
 	github.com/stretchr/testify v1.9.0

+ 0 - 2
go.sum

@@ -135,8 +135,6 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
 github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
 github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
 github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
-github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
-github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=

+ 1 - 1
llm/ggla.go

@@ -119,7 +119,7 @@ func (llm *ggla) decode(rs io.ReadSeeker) error {
 
 		t.Offset = uint64(offset)
 
-		if _, err := rs.Seek(int64(t.size()), io.SeekCurrent); err != nil {
+		if _, err := rs.Seek(int64(t.Size()), io.SeekCurrent); err != nil {
 			return err
 		}
 

+ 3 - 3
llm/ggml.go

@@ -106,7 +106,7 @@ type Layer map[string]*Tensor
 
 func (l Layer) size() (size uint64) {
 	for _, t := range l {
-		size += t.size()
+		size += t.Size()
 	}
 
 	return size
@@ -185,7 +185,7 @@ func (t Tensor) parameters() uint64 {
 	return count
 }
 
-func (t Tensor) size() uint64 {
+func (t Tensor) Size() uint64 {
 	return t.parameters() * t.typeSize() / t.blockSize()
 }
 
@@ -288,7 +288,7 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 			// mixtral 8x22b
 			ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
 			partialOffload = max(
-				3*ffnGateExpsWeight.size()+4*batch*(2*ff+headsKV+embedding+context+embedding/heads*headsKV),
+				3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embedding/heads*headsKV),
 				4*(context*batch*heads+context*embedding/heads*headsKV+batch*1024+embedding/heads*headsKV*batch),
 			)
 		} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {

+ 2 - 2
llm/gguf.go

@@ -241,11 +241,11 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
 	}
 
 	for _, tensor := range llm.tensors {
-		if _, err := rs.Seek(int64(tensor.size()), io.SeekCurrent); err != nil {
+		if _, err := rs.Seek(int64(tensor.Size()), io.SeekCurrent); err != nil {
 			return err
 		}
 
-		padding := llm.padding(int64(tensor.size()), int64(alignment))
+		padding := llm.padding(int64(tensor.Size()), int64(alignment))
 		if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
 			return err
 		}