wordpiece

Michael Yang, 2 months ago
parent commit 35495803f3
2 changed files with 31 additions and 5 deletions:
  1. model/bert/model.go (+8 -4)
  2. model/process_text.go (+23 -1)

model/bert/model.go (+8 -4)

@@ -30,7 +30,7 @@ type Options struct {
 
 type Model struct {
 	model.Base
-	model.BytePairEncoding
+	model.WordPiece
 
 	TokenEmbedding     *nn.Embedding `gguf:"token_embd"`
 	TypeEmbedding      *nn.Embedding `gguf:"type_embd,alt:token_types"`
@@ -166,14 +166,18 @@ func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml
 func New(c ml.Config) (model.Model, error) {
 	return &Model{
 		Layers: make([]EncoderLayer, c.Uint("block_count")),
-		BytePairEncoding: model.NewBytePairEncoding(
-			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
+		WordPiece: model.NewWordPiece(
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Types:  c.Uints("tokenizer.ggml.token_type"),
-				Merges: c.Strings("tokenizer.ggml.merges"),
+				Scores: c.Uints("tokenizer.ggml.token_scores"),
 				BOS:    c.Uint("tokenizer.ggml.bos_token_id"),
 				EOS:    c.Uint("tokenizer.ggml.eos_token_id"),
+				UNK:    c.Uint("tokenizer.ggml.unknown_token_id"),
+				SEP:    c.Uint("tokenizer.ggml.separator_token_id"),
+				PAD:    c.Uint("tokenizer.ggml.padding_token_id"),
+				CLS:    c.Uint("tokenizer.ggml.cls_token_id"),
+				MASK:   c.Uint("tokenizer.ggml.mask_token_id"),
 			},
 		),
 		Options: &Options{
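
This hunk switches the bert model from byte-pair encoding to the new WordPiece tokenizer and reads the extra special-token ids (unknown, separator, padding, classification, mask) from GGUF metadata instead of BPE merges. As a rough illustration of why a BERT-style encoder wants CLS and SEP in particular, here is a minimal sketch of how such ids are conventionally used to frame an input sequence; buildBERTInput is a hypothetical helper, not part of this commit:

// buildBERTInput is a hypothetical helper (not in this commit) showing the
// usual role of the classification and separator tokens: BERT-style encoders
// frame every input as [CLS] tokens... [SEP] before embedding.
func buildBERTInput(cls, sep int32, tokens []int32) []int32 {
	out := make([]int32, 0, len(tokens)+2)
	out = append(out, cls)
	out = append(out, tokens...)
	out = append(out, sep)
	return out
}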

model/process_text.go (+23 -1)

@@ -30,7 +30,7 @@ type Vocabulary struct {
 	Scores []uint32
 	Merges []string
 
-	BOS, EOS uint32
+	BOS, EOS, UNK, SEP, PAD, CLS, MASK uint32
 
 	specialOnce sync.Once
 	special     []string
@@ -311,3 +311,25 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
 	slog.Debug("decoded", "ids", ids, "text", sb.String())
 	return sb.String(), nil
 }
+
+type WordPiece struct {
+	vocab *Vocabulary
+}
+
+func NewWordPiece(vocab *Vocabulary) WordPiece {
+	return WordPiece{
+		vocab: vocab,
+	}
+}
+
+func (wp WordPiece) Is(id uint32, special Special) bool {
+	panic("not implemented")
+}
+
+func (wp WordPiece) Encode(s string) ([]int32, error) {
+	panic("not implemented")
+}
+
+func (wp WordPiece) Decode(ids []int32) (string, error) {
+	panic("not implemented")
+}