Browse Source

vocab: Use int32 for special tokens

Special tokens are currently read as uint32 from the model metadata.
However, all other parts of the system (including the tokenizer) use
int32 to represent tokens, so the high portion of the unsigned range
cannot be represented anyway. For consistency and to avoid casts,
we should just use int32 everywhere.
Jesse Gross 2 months ago
parent
commit
7916f55009
3 changed files with 8 additions and 8 deletions
  1. 2 2
      model/llama/model.go
  2. 2 2
      model/mllama/model.go
  3. 4 4
      model/process_text.go

+ 2 - 2
model/llama/model.go

@@ -35,8 +35,8 @@ func New(c ml.Config) (model.Model, error) {
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Types:  c.Uints("tokenizer.ggml.token_type"),
 				Types:  c.Uints("tokenizer.ggml.token_type"),
 				Merges: c.Strings("tokenizer.ggml.merges"),
 				Merges: c.Strings("tokenizer.ggml.merges"),
-				BOS:    c.Uint("tokenizer.ggml.bos_token_id"),
-				EOS:    c.Uint("tokenizer.ggml.eos_token_id"),
+				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
+				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
 			},
 			},
 		),
 		),
 		Layers: make([]Layer, c.Uint("block_count")),
 		Layers: make([]Layer, c.Uint("block_count")),

+ 2 - 2
model/mllama/model.go

@@ -26,8 +26,8 @@ func New(c ml.Config) (model.Model, error) {
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Types:  c.Uints("tokenizer.ggml.token_type"),
 				Types:  c.Uints("tokenizer.ggml.token_type"),
 				Merges: c.Strings("tokenizer.ggml.merges"),
 				Merges: c.Strings("tokenizer.ggml.merges"),
-				BOS:    c.Uint("tokenizer.ggml.bos_token_id"),
-				EOS:    c.Uint("tokenizer.ggml.eos_token_id"),
+				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
+				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
 			},
 			},
 		),
 		),
 		ImageProcessor: newImageProcessor(c),
 		ImageProcessor: newImageProcessor(c),

+ 4 - 4
model/process_text.go

@@ -21,7 +21,7 @@ const (
 type TextProcessor interface {
 type TextProcessor interface {
 	Encode(string) ([]int32, error)
 	Encode(string) ([]int32, error)
 	Decode([]int32) (string, error)
 	Decode([]int32) (string, error)
-	Is(uint32, Special) bool
+	Is(int32, Special) bool
 }
 }
 
 
 type Vocabulary struct {
 type Vocabulary struct {
@@ -30,7 +30,7 @@ type Vocabulary struct {
 	Scores []uint32
 	Scores []uint32
 	Merges []string
 	Merges []string
 
 
-	BOS, EOS uint32
+	BOS, EOS int32
 
 
 	specialOnce sync.Once
 	specialOnce sync.Once
 	special     []string
 	special     []string
@@ -42,7 +42,7 @@ type Vocabulary struct {
 	merge     map[string]int32
 	merge     map[string]int32
 }
 }
 
 
-func (v *Vocabulary) Is(id uint32, special Special) bool {
+func (v *Vocabulary) Is(id int32, special Special) bool {
 	switch special {
 	switch special {
 	case SpecialBOS:
 	case SpecialBOS:
 		return id == v.BOS
 		return id == v.BOS
@@ -111,7 +111,7 @@ func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding {
 	}
 	}
 }
 }
 
 
-func (bpe BytePairEncoding) Is(id uint32, special Special) bool {
+func (bpe BytePairEncoding) Is(id int32, special Special) bool {
 	return bpe.vocab.Is(id, special)
 	return bpe.vocab.Is(id, special)
 }
 }