Browse Source

fix unmarshaling merges

Michael Yang 4 tháng trước cách đây
mục cha
commit
4456012956
2 tập tin đã thay đổi với 79 bổ sung4 xóa
  1. 23 4
      convert/tokenizer.go
  2. 56 0
      convert/tokenizer_test.go

+ 23 - 4
convert/tokenizer.go

@@ -10,6 +10,7 @@ import (
 	"log/slog"
 	"os"
 	"slices"
+	"strings"
 
 	"golang.org/x/exp/maps"
 )
@@ -60,7 +61,25 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
 			addedTokens[t.Content] = t
 		}
 
-		t.Merges = tt.Model.Merges
+		if len(tt.Model.Merges) == 0 {
+			// noop; merges is empty
+		} else if err := json.Unmarshal(tt.Model.Merges, &t.Merges); err == nil {
+			// noop; merges is []string
+		} else if merges, err := func() ([][]string, error) {
+			var merges [][]string
+			if err := json.Unmarshal(tt.Model.Merges, &merges); err != nil {
+				return nil, err
+			}
+
+			return merges, nil
+		}(); err == nil {
+			t.Merges = make([]string, len(merges))
+			for i := range merges {
+				t.Merges[i] = strings.Join(merges[i], " ")
+			}
+		} else {
+			return nil, fmt.Errorf("could not parse tokenizer merges. expected []string or [][]string: %w", err)
+		}
 
 		sha256sum := sha256.New()
 		for _, pt := range tt.PreTokenizer.PreTokenizers {
@@ -156,9 +175,9 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
 type tokenizer struct {
 	AddedTokens []token `json:"added_tokens"`
 	Model       struct {
-		Type   string         `json:"type"`
-		Vocab  map[string]int `json:"vocab"`
-		Merges []string       `json:"merges"`
+		Type   string          `json:"type"`
+		Vocab  map[string]int  `json:"vocab"`
+		Merges json.RawMessage `json:"merges"`
 	} `json:"model"`
 
 	PreTokenizer struct {

+ 56 - 0
convert/tokenizer_test.go

@@ -191,6 +191,62 @@ func TestParseTokenizer(t *testing.T) {
 				Pre: "default",
 			},
 		},
+		{
+			name: "list string merges",
+			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
+				"tokenizer.json": strings.NewReader(`{
+					"model": {
+						"merges": [
+							"a b",
+							"c d",
+							"e f"
+						]
+					}
+				}`),
+			}),
+			want: &Tokenizer{
+				Vocabulary: &Vocabulary{
+					Model: "gpt2",
+				},
+				Merges: []string{
+					"a b",
+					"c d",
+					"e f",
+				},
+				Pre: "default",
+			},
+		},
+		{
+			name: "list list string merges",
+			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
+				"tokenizer.json": strings.NewReader(`{
+					"model": {
+						"merges": [
+							[
+								"a", "b"
+							],
+							[
+								"c", "d"
+							],
+							[
+								"e", "f"
+							]
+						]
+					}
+				}`),
+			}),
+			want: &Tokenizer{
+				Vocabulary: &Vocabulary{
+					Model: "gpt2",
+				},
+				Merges: []string{
+					"a b",
+					"c d",
+					"e f",
+				},
+				Pre: "default",
+			},
+		},
 	}
 
 	for _, tt := range cases {