Browse Source

Merge pull request #3083 from ollama/mxyng/refactor-readseeker

refactor readseeker
Michael Yang 1 year ago
parent
commit
22f326464e
3 changed files with 72 additions and 70 deletions
  1. 24 14
      llm/ggla.go
  2. 10 28
      llm/ggml.go
  3. 38 28
      llm/gguf.go

+ 24 - 14
llm/ggla.go

@@ -15,8 +15,8 @@ func (c *ContainerGGLA) Name() string {
 	return "ggla"
 }
 
-func (c *ContainerGGLA) Decode(rso *readSeekOffset) (model, error) {
-	binary.Read(rso, binary.LittleEndian, &c.version)
+func (c *ContainerGGLA) Decode(rs io.ReadSeeker) (model, error) {
+	binary.Read(rs, binary.LittleEndian, &c.version)
 
 	switch c.version {
 	case 1:
@@ -25,7 +25,7 @@ func (c *ContainerGGLA) Decode(rso *readSeekOffset) (model, error) {
 	}
 
 	model := newModelGGLA(c)
-	err := model.decode(rso)
+	err := model.decode(rs)
 	return model, err
 }
 
@@ -43,39 +43,39 @@ func newModelGGLA(container *ContainerGGLA) *ModelGGLA {
 	}
 }
 
-func (m *ModelGGLA) decode(rso *readSeekOffset) error {
+func (m *ModelGGLA) decode(rs io.ReadSeeker) error {
 	var r uint32
-	if err := binary.Read(rso, binary.LittleEndian, &r); err != nil {
+	if err := binary.Read(rs, binary.LittleEndian, &r); err != nil {
 		return err
 	}
 	m.kv["r"] = r
 
 	var alpha uint32
-	if err := binary.Read(rso, binary.LittleEndian, &alpha); err != nil {
+	if err := binary.Read(rs, binary.LittleEndian, &alpha); err != nil {
 		return err
 	}
 	m.kv["alpha"] = alpha
 
 	for {
 		var dims uint32
-		if err := binary.Read(rso, binary.LittleEndian, &dims); err != nil {
+		if err := binary.Read(rs, binary.LittleEndian, &dims); err != nil {
 			return err
 		}
 
 		var namesize uint32
-		if err := binary.Read(rso, binary.LittleEndian, &namesize); err != nil {
+		if err := binary.Read(rs, binary.LittleEndian, &namesize); err != nil {
 			return err
 		}
 
 		var t Tensor
-		if err := binary.Read(rso, binary.LittleEndian, &t.Kind); err != nil {
+		if err := binary.Read(rs, binary.LittleEndian, &t.Kind); err != nil {
 			return err
 		}
 
 		t.Shape = make([]uint64, dims)
 		for i := 0; uint32(i) < dims; i++ {
 			var shape32 uint32
-			if err := binary.Read(rso, binary.LittleEndian, &shape32); err != nil {
+			if err := binary.Read(rs, binary.LittleEndian, &shape32); err != nil {
 				return err
 			}
 
@@ -87,19 +87,29 @@ func (m *ModelGGLA) decode(rso *readSeekOffset) error {
 		slices.Reverse(t.Shape)
 
 		name := make([]byte, namesize)
-		if err := binary.Read(rso, binary.LittleEndian, &name); err != nil {
+		if err := binary.Read(rs, binary.LittleEndian, &name); err != nil {
 			return err
 		}
 
 		t.Name = string(name)
 
-		if _, err := rso.Seek((rso.offset+31)&-32, io.SeekStart); err != nil {
+		offset, err := rs.Seek(0, io.SeekCurrent)
+		if err != nil {
 			return err
 		}
 
-		t.Offset = uint64(rso.offset)
+		if _, err := rs.Seek((offset+31)&-32, io.SeekStart); err != nil {
+			return err
+		}
+
+		offset, err = rs.Seek(0, io.SeekCurrent)
+		if err != nil {
+			return err
+		}
+
+		t.Offset = uint64(offset)
 
-		if _, err := rso.Seek(int64(t.Size()), io.SeekCurrent); err != nil {
+		if _, err := rs.Seek(int64(t.Size()), io.SeekCurrent); err != nil {
 			return err
 		}
 

+ 10 - 28
llm/ggml.go

@@ -103,7 +103,7 @@ type model interface {
 
 type container interface {
 	Name() string
-	Decode(*readSeekOffset) (model, error)
+	Decode(io.ReadSeeker) (model, error)
 }
 
 const (
@@ -122,11 +122,9 @@ const (
 
 var ErrUnsupportedFormat = errors.New("unsupported model format")
 
-func DecodeGGML(r io.ReadSeeker) (*GGML, error) {
-	ro := readSeekOffset{ReadSeeker: r}
-
+func DecodeGGML(rs io.ReadSeeker) (*GGML, error) {
 	var magic uint32
-	if err := binary.Read(&ro, binary.LittleEndian, &magic); err != nil {
+	if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
 		return nil, err
 	}
 
@@ -144,38 +142,22 @@ func DecodeGGML(r io.ReadSeeker) (*GGML, error) {
 		return nil, errors.New("invalid file magic")
 	}
 
-	model, err := c.Decode(&ro)
+	model, err := c.Decode(rs)
 	if errors.Is(err, io.EOF) {
 		// noop
 	} else if err != nil {
 		return nil, err
 	}
 
+	offset, err := rs.Seek(0, io.SeekCurrent)
+	if err != nil {
+		return nil, err
+	}
+
 	// final model type
 	return &GGML{
 		container: c,
 		model:     model,
-		Size:      ro.offset,
+		Size:      offset,
 	}, nil
 }
-
-type readSeekOffset struct {
-	io.ReadSeeker
-	offset int64
-}
-
-func (rso *readSeekOffset) Seek(offset int64, whence int) (int64, error) {
-	offset, err := rso.ReadSeeker.Seek(offset, whence)
-	if err != nil {
-		return 0, err
-	}
-
-	rso.offset = offset
-	return offset, nil
-}
-
-func (rso *readSeekOffset) Read(p []byte) (int, error) {
-	n, err := rso.ReadSeeker.Read(p)
-	rso.offset += int64(n)
-	return n, err
-}

+ 38 - 28
llm/gguf.go

@@ -43,18 +43,18 @@ func (c *ContainerGGUF) Name() string {
 	return "gguf"
 }
 
-func (c *ContainerGGUF) Decode(rso *readSeekOffset) (model, error) {
-	binary.Read(rso, c.ByteOrder, &c.Version)
+func (c *ContainerGGUF) Decode(rs io.ReadSeeker) (model, error) {
+	binary.Read(rs, c.ByteOrder, &c.Version)
 
 	switch c.Version {
 	case 1:
-		binary.Read(rso, c.ByteOrder, &c.V1)
+		binary.Read(rs, c.ByteOrder, &c.V1)
 	default:
-		binary.Read(rso, c.ByteOrder, &c.V2)
+		binary.Read(rs, c.ByteOrder, &c.V2)
 	}
 
 	model := NewGGUFModel(c)
-	if err := model.Decode(rso); err != nil {
+	if err := model.Decode(rs); err != nil {
 		return nil, err
 	}
 
@@ -634,49 +634,49 @@ func (llm *GGUFModel) writeString(f *os.File, s string) error {
 	return nil
 }
 
-func (llm *GGUFModel) Decode(rso *readSeekOffset) error {
+func (llm *GGUFModel) Decode(rs io.ReadSeeker) error {
 	// decode key-values
 	for i := 0; uint64(i) < llm.NumKV(); i++ {
-		k, err := llm.readString(rso)
+		k, err := llm.readString(rs)
 		if err != nil {
 			return err
 		}
 
-		vtype := llm.readU32(rso)
+		vtype := llm.readU32(rs)
 
 		var v any
 		switch vtype {
 		case GGUFTypeUint8:
-			v = llm.readU8(rso)
+			v = llm.readU8(rs)
 		case GGUFTypeInt8:
-			v = llm.readI8(rso)
+			v = llm.readI8(rs)
 		case GGUFTypeUint16:
-			v = llm.readU16(rso)
+			v = llm.readU16(rs)
 		case GGUFTypeInt16:
-			v = llm.readI16(rso)
+			v = llm.readI16(rs)
 		case GGUFTypeUint32:
-			v = llm.readU32(rso)
+			v = llm.readU32(rs)
 		case GGUFTypeInt32:
-			v = llm.readI32(rso)
+			v = llm.readI32(rs)
 		case GGUFTypeUint64:
-			v = llm.readU64(rso)
+			v = llm.readU64(rs)
 		case GGUFTypeInt64:
-			v = llm.readI64(rso)
+			v = llm.readI64(rs)
 		case GGUFTypeFloat32:
-			v = llm.readF32(rso)
+			v = llm.readF32(rs)
 		case GGUFTypeFloat64:
-			v = llm.readF64(rso)
+			v = llm.readF64(rs)
 		case GGUFTypeBool:
-			v = llm.readBool(rso)
+			v = llm.readBool(rs)
 		case GGUFTypeString:
-			s, err := llm.readString(rso)
+			s, err := llm.readString(rs)
 			if err != nil {
 				return err
 			}
 
 			v = s
 		case GGUFTypeArray:
-			a, err := llm.readArray(rso)
+			a, err := llm.readArray(rs)
 			if err != nil {
 				return err
 			}
@@ -691,23 +691,23 @@ func (llm *GGUFModel) Decode(rso *readSeekOffset) error {
 
 	// decode tensors
 	for i := 0; uint64(i) < llm.NumTensor(); i++ {
-		name, err := llm.readString(rso)
+		name, err := llm.readString(rs)
 		if err != nil {
 			return err
 		}
 
 		// dims is the number of dimensions in the tensor
-		dims := llm.readU32(rso)
+		dims := llm.readU32(rs)
 
 		shape := [4]uint64{1, 1, 1, 1}
 		for i := 0; uint32(i) < dims; i++ {
-			shape[i] = llm.readU64(rso)
+			shape[i] = llm.readU64(rs)
 		}
 
 		tensor := Tensor{
 			Name:   name,
-			Kind:   llm.readU32(rso),
-			Offset: llm.readU64(rso),
+			Kind:   llm.readU32(rs),
+			Offset: llm.readU64(rs),
 			Shape:  shape[:],
 		}
 
@@ -720,10 +720,20 @@ func (llm *GGUFModel) Decode(rso *readSeekOffset) error {
 		alignment = 32
 	}
 
-	rso.Seek(int64(alignment)-rso.offset%int64(alignment), io.SeekCurrent)
+	offset, err := rs.Seek(0, io.SeekCurrent)
+	if err != nil {
+		return err
+	}
+
+	if _, err := rs.Seek(int64(alignment)-offset%int64(alignment), io.SeekCurrent); err != nil {
+		return err
+	}
+
 	for _, tensor := range llm.Tensors {
 		padded := (int64(tensor.Size()) + int64(alignment) - 1) & ^(int64(alignment) - 1)
-		rso.Seek(padded, io.SeekCurrent)
+		if _, err := rs.Seek(padded, io.SeekCurrent); err != nil {
+			return err
+		}
 	}
 
 	return nil