bpe pretokenizer

Michael Yang, 11 months ago
Commit 547132e820

4 files changed, 79 insertions(+), 55 deletions(-):
  1. convert/convert.go (+2, -0)
  2. convert/llama.go (+14, -30)
  3. convert/tokenizer.go (+62, -25)
  4. llm/gguf.go (+1, -0)

convert/convert.go (+2, -0)

@@ -37,6 +37,8 @@ type Params struct {
 	Experts     int `json:"num_local_experts"`
 	ExpertsUsed int `json:"num_experts_per_tok"`
 
+	PreTokenizer string
+
 	ByteOrder
 }
 

convert/llama.go (+14, -30)

@@ -2,9 +2,9 @@ package convert
 
 import (
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"io"
-	"log/slog"
 	"os"
 	"path/filepath"
 	"regexp"
@@ -134,44 +134,27 @@ func (m *LlamaModel) GetTensors() error {
 }
 
 func (m *LlamaModel) LoadVocab() error {
-	v := &Vocab{
-		Tokens: []string{},
-		Types:  []int32{},
-		Merges: []string{},
-	}
+	v := &Vocab{}
 
 	tokpath := filepath.Join(m.Path, "tokenizer.json")
-	slog.Debug(fmt.Sprintf("looking for %s", tokpath))
-	if _, err := os.Stat(tokpath); !os.IsNotExist(err) {
-		t, err := newTokenizer(tokpath)
+	pre, ts, merges, err := parseTokens(tokpath)
+	if errors.Is(err, os.ErrNotExist) {
+		v, err = LoadSentencePieceTokens(m.Path, m.Params)
 		if err != nil {
 			return err
 		}
-
-		for _, tok := range t.Model.Tokens {
-			v.Tokens = append(v.Tokens, tok.Content)
-			var tokType int32
-			switch {
-			case tok.Special:
-				tokType = 3
-			case tok.UserDefined:
-				tokType = 4
-			default:
-				tokType = 1
-			}
-			v.Types = append(v.Types, tokType)
-		}
-		v.Merges = t.Model.Merges
+	} else if err != nil {
+		return err
 	} else {
-		slog.Debug("loading sentence piece vocab")
-		v, err = LoadSentencePieceTokens(m.Path, m.Params)
-		if err != nil {
-			return err
+		for _, t := range ts {
+			v.Tokens = append(v.Tokens, t.Content)
+			v.Types = append(v.Types, t.Type())
 		}
 
-		slog.Debug("vocab loaded")
-
+		m.Params.PreTokenizer = pre
+		v.Merges = merges
 	}
+
 	m.Vocab = v
 
 	return nil
@@ -194,6 +177,7 @@ func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
 		"general.file_type":                      uint32(2),
 		"tokenizer.ggml.model":                   "gpt2",
 
+		"tokenizer.ggml.pre":        m.Params.PreTokenizer,
 		"tokenizer.ggml.tokens":     m.Vocab.Tokens,
 		"tokenizer.ggml.token_type": m.Vocab.Types,
 

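The integer token types above (1, 3, 4, now produced by Token.Type() in convert/tokenizer.go) correspond to llama.cpp's GGUF token-type enum (NORMAL=1, CONTROL=3, USER_DEFINED=4). A small reference sketch with illustrative constant names that are not defined anywhere in this commit:

package convert

// Sketch only: named constants for the token-type values written to
// "tokenizer.ggml.token_type". Token.Type() returns the raw integers
// directly; these names mirror llama.cpp's token-type enum and are
// not part of the commit.
const (
	tokenTypeNormal      int32 = 1 // plain entries from model.vocab
	tokenTypeControl     int32 = 3 // tokens flagged "special" in tokenizer.json
	tokenTypeUserDefined int32 = 4 // entries coming from added_tokens
)
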
convert/tokenizer.go (+62, -25)

@@ -1,15 +1,30 @@
 package convert
 
 import (
+	"cmp"
+	"crypto/sha256"
 	"encoding/json"
-	"io/ioutil"
+	"fmt"
+	"log/slog"
 	"os"
+	"slices"
+
+	"golang.org/x/exp/maps"
 )
 
 type Tokenizer struct {
 	Version     string         `json:"version"`
 	AddedTokens []Token        `json:"added_tokens"`
 	Model       TokenizerModel `json:"model"`
+
+	PreTokenizer struct {
+		PreTokenziers []struct {
+			Type    string `json:"type"`
+			Pattern struct {
+				Regex string `json:"Regex"`
+			} `json:"pattern"`
+		} `json:"pretokenizers"`
+	} `json:"pre_tokenizer"`
 }
 
 type TokenizerModel struct {
@@ -26,47 +41,69 @@ type Token struct {
 	UserDefined bool
 }
 
-func (t *Tokenizer) getMaxID() int {
-	var maxID int
-	for _, v := range t.Model.Vocab {
-		maxID = max(maxID, v)
+func (t *Token) Type() int32 {
+	switch {
+	case t.Special:
+		return 3
+	case t.UserDefined:
+		return 4
+	default:
+		return 1
 	}
+}
 
-	for _, v := range t.AddedTokens {
-		maxID = max(maxID, v.ID)
-	}
-	return maxID
+func (t *Tokenizer) maxID() int {
+	return max(
+		slices.Max(maps.Values(t.Model.Vocab)),
+		slices.MaxFunc(t.AddedTokens, func(a, b Token) int {
+			return cmp.Compare(a.ID, b.ID)
+		}).ID,
+	)
 }
 
-func newTokenizer(dirpath string) (*Tokenizer, error) {
+func parseTokens(dirpath string) (pre string, tokens []Token, merges []string, err error) {
 	f, err := os.Open(dirpath)
 	if err != nil {
 		panic(err)
 	}
 	defer f.Close()
 
-	data, err := ioutil.ReadAll(f)
-	if err != nil {
-		return nil, err
+	var t Tokenizer
+	if err := json.NewDecoder(f).Decode(&t); err != nil {
+		return "", nil, nil, err
 	}
 
-	var tdata Tokenizer
-
-	if err := json.Unmarshal(data, &tdata); err != nil {
-		return nil, err
+	tokens = make([]Token, t.maxID()+1)
+	for k, v := range t.Model.Vocab {
+		tokens[v] = Token{ID: v, Content: k, Special: false, UserDefined: false}
 	}
 
-	maxID := tdata.getMaxID()
-	tdata.Model.Tokens = make([]Token, maxID+1)
+	for _, v := range t.AddedTokens {
+		v.UserDefined = true
+		tokens[v.ID] = v
+	}
 
-	for k, v := range tdata.Model.Vocab {
-		tdata.Model.Tokens[v] = Token{ID: v, Content: k, Special: false, UserDefined: false}
+	sha256sum := sha256.New()
+	for _, pt := range t.PreTokenizer.PreTokenziers {
+		switch pt.Type {
+		case "Split":
+			if pt.Pattern.Regex != "" {
+				sha256sum.Write([]byte(pt.Pattern.Regex))
+			}
+		}
 	}
 
-	for _, v := range tdata.AddedTokens {
-		v.UserDefined = true
-		tdata.Model.Tokens[v.ID] = v
+	switch digest := fmt.Sprintf("%x", sha256sum.Sum(nil)); digest {
+	case "d98f9631be1e9607a9848c26c1f9eac1aa9fc21ac6ba82a2fc0741af9780a48f":
+		pre = "llama-bpe"
+	case "03df5c5863ad70781dcfdef491ead25140f895fe8010964be0daefe27be32b02":
+		pre = "deepseek-llm"
+	case "21cde974d587f0d54dc8d56b183cc1e6239600172035c68fbd6d4b9f8da0576e":
+		pre = "deepseek-coder"
+	default:
+		slog.Warn("unknown pretokenizer, using default", "digest", digest)
+		pre = "default"
 	}
 
-	return &tdata, nil
+	return pre, tokens, t.Model.Merges, nil
 }

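The tokenizer.ggml.pre value is selected by hashing the regex patterns of the "Split" entries under pre_tokenizer in tokenizer.json. A standalone sketch (not part of this commit; the struct mirrors the fields parseTokens decodes, and the type name and command-line usage are illustrative) that reproduces the same digest when adding support for a new model:

package main

import (
	"crypto/sha256"
	"encoding/json"
	"fmt"
	"os"
)

// tokenizerPre mirrors only the pre_tokenizer fields that parseTokens reads.
type tokenizerPre struct {
	PreTokenizer struct {
		Pretokenizers []struct {
			Type    string `json:"type"`
			Pattern struct {
				Regex string `json:"Regex"`
			} `json:"pattern"`
		} `json:"pretokenizers"`
	} `json:"pre_tokenizer"`
}

func main() {
	// Usage (illustrative): go run pre_digest.go path/to/tokenizer.json
	f, err := os.Open(os.Args[1])
	if err != nil {
		panic(err)
	}
	defer f.Close()

	var t tokenizerPre
	if err := json.NewDecoder(f).Decode(&t); err != nil {
		panic(err)
	}

	// Hash only the regex patterns of "Split" pretokenizers, in file order,
	// matching the loop in parseTokens above.
	sum := sha256.New()
	for _, pt := range t.PreTokenizer.Pretokenizers {
		if pt.Type == "Split" && pt.Pattern.Regex != "" {
			sum.Write([]byte(pt.Pattern.Regex))
		}
	}

	fmt.Printf("%x\n", sum.Sum(nil))
}

Running this against a model's tokenizer.json prints the hex digest to compare with the cases in the switch above.
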
llm/gguf.go (+1, -0)

@@ -480,6 +480,7 @@ var ggufKVOrder = map[string][]string{
 		"gemma.attention.key_length",
 		"gemma.attention.value_length",
 		"general.file_type",
+		"tokenizer.ggml.pre",
 		"tokenizer.ggml.model",
 		"tokenizer.ggml.tokens",
 		"tokenizer.ggml.scores",