Browse Source

throw an error when encountering unsupport tensor sizes (#6538)

Patrick Devine 8 months ago
parent
commit
6c1c1ad6a9
2 changed files with 106 additions and 0 deletions
  1. 101 0
      convert/convert_test.go
  2. 5 0
      convert/reader_safetensors.go

+ 101 - 0
convert/convert_test.go

@@ -140,6 +140,107 @@ func TestConvertFull(t *testing.T) {
 	}
 }
 
+func TestConvertInvalidDatatype(t *testing.T) {
+	f, err := os.CreateTemp(t.TempDir(), "testmodel")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer f.Close()
+
+	tempDir := t.TempDir()
+	generateSafetensorTestData(t, tempDir)
+
+	err = ConvertModel(os.DirFS(tempDir), f)
+	if err == nil || err.Error() != "unsupported safetensors model" {
+		t.Errorf("expected error but didn't get one")
+	}
+}
+
+func generateSafetensorTestData(t *testing.T, tempDir string) {
+	type tensorData struct {
+		Offsets []int  `json:"data_offsets"`
+		Type    string `json:"dtype"`
+		Shape   []int  `json:"shape"`
+	}
+	offset := 4096 * 14336
+
+	td := map[string]*tensorData{}
+	td["model.layers.0.mlp.down_proj.weight"] = &tensorData{
+		Offsets: []int{0, offset},
+		Type:    "I8",
+		Shape:   []int{4096, 14336},
+	}
+	td["model.layers.0.mlp.down_proj.weight_format"] = &tensorData{
+		Offsets: []int{offset, offset},
+		Type:    "U8",
+		Shape:   []int{},
+	}
+
+	data, err := json.Marshal(td)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	var buf bytes.Buffer
+
+	l := int64(len(data))
+	err = binary.Write(&buf, binary.LittleEndian, l)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	_, err = buf.Write(data)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	fdata, err := os.Create(filepath.Join(tempDir, "model-00001-of-00001.safetensors"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer fdata.Close()
+
+	_, err = fdata.Write(buf.Bytes())
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	configData := `
+{
+  "architectures": [
+    "LlamaForCausalLM"
+  ]
+}
+`
+
+	f, err := os.Create(filepath.Join(tempDir, "config.json"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer f.Close()
+
+	_, err = f.WriteString(configData)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	tokenizerData := `
+{
+}
+`
+
+	f, err = os.Create(filepath.Join(tempDir, "tokenizer.json"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer f.Close()
+
+	_, err = f.WriteString(tokenizerData)
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
 func TestConvertAdapter(t *testing.T) {
 	type AdapterCase struct {
 		Name     string

+ 5 - 0
convert/reader_safetensors.go

@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"encoding/binary"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"io/fs"
@@ -50,6 +51,10 @@ func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]T
 
 		for _, key := range keys {
 			if value := headers[key]; value.Type != "" {
+				// bitsandbytes quantized models are unsupported
+				if len(value.Shape) == 0 {
+					return nil, errors.New("unsupported safetensors model")
+				}
 				ts = append(ts, safetensor{
 					fs:     fsys,
 					path:   p,