|
@@ -21,7 +21,7 @@ const (
|
|
|
type TextProcessor interface {
|
|
|
Encode(string) ([]int32, error)
|
|
|
Decode([]int32) (string, error)
|
|
|
- Is(uint32, Special) bool
|
|
|
+ Is(int32, Special) bool
|
|
|
}
|
|
|
|
|
|
type Vocabulary struct {
|
|
@@ -30,7 +30,7 @@ type Vocabulary struct {
|
|
|
Scores []uint32
|
|
|
Merges []string
|
|
|
|
|
|
- BOS, EOS uint32
|
|
|
+ BOS, EOS int32
|
|
|
|
|
|
specialOnce sync.Once
|
|
|
special []string
|
|
@@ -42,7 +42,7 @@ type Vocabulary struct {
|
|
|
merge map[string]int32
|
|
|
}
|
|
|
|
|
|
-func (v *Vocabulary) Is(id uint32, special Special) bool {
|
|
|
+func (v *Vocabulary) Is(id int32, special Special) bool {
|
|
|
switch special {
|
|
|
case SpecialBOS:
|
|
|
return id == v.BOS
|
|
@@ -111,7 +111,7 @@ func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (bpe BytePairEncoding) Is(id uint32, special Special) bool {
|
|
|
+func (bpe BytePairEncoding) Is(id int32, special Special) bool {
|
|
|
return bpe.vocab.Is(id, special)
|
|
|
}
|
|
|
|