Prechádzať zdrojové kódy

model: add bos token if configured

Michael Yang 2 mesiacov pred
rodič
commit
53d2990d9b

+ 5 - 1
fs/ggml/ggml.go

@@ -100,6 +100,10 @@ func (kv KV) Float(key string, defaultValue ...float32) float32 {
 	return keyValue(kv, key, append(defaultValue, 0)...)
 }
 
+func (kv KV) Bool(key string, defaultValue ...bool) bool {
+	return keyValue(kv, key, append(defaultValue, false)...)
+}
+
 func (kv KV) Strings(key string, defaultValue ...[]string) []string {
 	r := keyValue(kv, key, &array{})
 	s := make([]string, r.size)
@@ -120,7 +124,7 @@ func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
 	return s
 }
 
-func keyValue[T string | uint32 | uint64 | float32 | *array](kv KV, key string, defaultValue ...T) T {
+func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
 	if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
 		key = kv.Architecture() + "." + key
 	}

+ 1 - 0
ml/backend.go

@@ -14,6 +14,7 @@ type Config interface {
 	String(string, ...string) string
 	Uint(string, ...uint32) uint32
 	Float(string, ...float32) float32
+	Bool(string, ...bool) bool
 
 	Strings(string, ...[]string) []string
 	Uints(string, ...[]uint32) []uint32

+ 2 - 0
model/models/llama/model.go

@@ -37,7 +37,9 @@ func New(c ml.Config) (model.Model, error) {
 				Types:  c.Uints("tokenizer.ggml.token_type"),
 				Merges: c.Strings("tokenizer.ggml.merges"),
 				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
+				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
 				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
+				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
 			},
 		),
 		Layers: make([]Layer, c.Uint("block_count")),

+ 2 - 0
model/models/mllama/model.go

@@ -33,7 +33,9 @@ func New(c ml.Config) (model.Model, error) {
 				Types:  c.Uints("tokenizer.ggml.token_type"),
 				Merges: c.Strings("tokenizer.ggml.merges"),
 				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
+				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
 				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
+				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
 			},
 		),
 		ImageProcessor: newImageProcessor(c),

+ 22 - 1
model/process_text.go

@@ -30,7 +30,8 @@ type Vocabulary struct {
 	Scores []uint32
 	Merges []string
 
-	BOS, EOS int32
+	BOS, EOS       int32
+	AddBOS, AddEOS bool
 
 	specialOnce sync.Once
 	special     []string
@@ -281,6 +282,26 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
 		}
 	}
 
+	if len(ids) > 0 {
+		if bpe.vocab.AddBOS {
+			if ids[0] == bpe.vocab.BOS {
+				slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS)
+			}
+
+			slog.Debug("adding bos token to prompt", "id", bpe.vocab.BOS)
+			ids = append([]int32{bpe.vocab.BOS}, ids...)
+		}
+
+		if bpe.vocab.AddEOS {
+			if ids[len(ids)-1] == bpe.vocab.EOS {
+				slog.Warn("adding eos token to prompt which already has it", "id", bpe.vocab.EOS)
+			}
+
+			slog.Debug("adding eos token to prompt", "id", bpe.vocab.EOS)
+			ids = append(ids, bpe.vocab.EOS)
+		}
+	}
+
 	return ids, nil
 }