Browse Source

add sentence piece tokenizer

Patrick Devine 2 months ago
parent
commit
8cf1ea4fd8
7 changed files with 263 additions and 9 deletions
  1. 9 0
      fs/ggml/ggml.go
  2. 1 0
      go.mod
  3. 2 0
      go.sum
  4. 1 0
      ml/backend.go
  5. 5 7
      model/models/gemma2/model.go
  6. 13 2
      model/process_text.go
  7. 232 0
      model/process_text_spm.go

+ 9 - 0
fs/ggml/ggml.go

@@ -120,6 +120,15 @@ func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
 	return s
 }
 
+func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
+	r := keyValue(kv, key, &array{})
+	s := make([]float32, r.size)
+	for i := range r.size {
+		s[i] = float32(r.values[i].(float32))
+	}
+	return s
+}
+
 func keyValue[T string | uint32 | uint64 | float32 | *array](kv KV, key string, defaultValue ...T) T {
 	if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
 		key = kv.Architecture() + "." + key

+ 1 - 0
go.mod

@@ -18,6 +18,7 @@ require (
 	github.com/agnivade/levenshtein v1.1.1
 	github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
 	github.com/dlclark/regexp2 v1.11.4
+	github.com/emirpasic/gods v1.18.1
 	github.com/emirpasic/gods/v2 v2.0.0-alpha
 	github.com/google/go-cmp v0.6.0
 	github.com/mattn/go-runewidth v0.0.14

+ 2 - 0
go.sum

@@ -44,6 +44,8 @@ github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+
 github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA=
 github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo=
 github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
+github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
+github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
 github.com/emirpasic/gods/v2 v2.0.0-alpha h1:dwFlh8pBg1VMOXWGipNMRt8v96dKAIvBehtCt6OtunU=
 github.com/emirpasic/gods/v2 v2.0.0-alpha/go.mod h1:W0y4M2dtBB9U5z3YlghmpuUhiaZT2h6yoeE+C1sCp6A=
 github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=

+ 1 - 0
ml/backend.go

@@ -17,6 +17,7 @@ type Config interface {
 
 	Strings(string, ...[]string) []string
 	Uints(string, ...[]uint32) []uint32
+	Floats(string, ...[]float32) []float32
 }
 
 type Backend interface {

+ 5 - 7
model/models/gemma2/model.go

@@ -1,7 +1,6 @@
 package gemma2
 
 import (
-	"fmt"
 	"math"
 
 	"github.com/ollama/ollama/kvcache"
@@ -20,7 +19,7 @@ type Options struct {
 
 type Model struct {
 	model.Base
-	model.BytePairEncoding
+	model.SentencePieceModel
 
 	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
 	Layers         []Layer       `gguf:"blk"`
@@ -32,10 +31,11 @@ type Model struct {
 
 func New(c ml.Config) (model.Model, error) {
 	m := Model{
-		BytePairEncoding: model.NewBytePairEncoding(
+		SentencePieceModel: model.NewSentencePieceModel(
 			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
+				Scores: c.Floats("tokenizer.ggml.scores"),
 				Types:  c.Uints("tokenizer.ggml.token_type"),
 				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
 				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
@@ -55,7 +55,7 @@ func New(c ml.Config) (model.Model, error) {
 	}
 
 	slidingWindowLen := int32(c.Uint("attention.sliding_window"))
-	m.Cache = kvcache.NewWrapperCache(kvcache.NewCausalCache(m.Shift), kvcache.NewSWACache(slidingWindowLen, m.Shift))
+	m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
 
 	return &m, nil
 }
@@ -76,7 +76,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	q = q.RoPE(ctx, positionIDs, opts.RopeFactors, uint32(headDim), opts.ropeBase, opts.ropeScale)
 
 	// todo: this should be 1.0/math.Sqrt(float64(headDim)) for 27B models
-	q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
+	//q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
@@ -140,8 +140,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cach
 }
 
 func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
-	fmt.Printf("HELLO THERE!!\n")
-
 	inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
 	if err != nil {
 		return nil, err

+ 13 - 2
model/process_text.go

@@ -2,6 +2,7 @@ package model
 
 import (
 	"cmp"
+	"fmt"
 	"iter"
 	"log/slog"
 	"strings"
@@ -18,6 +19,15 @@ const (
 	SpecialEOS
 )
 
+const (
+	TOKEN_TYPE_NORMAL = iota + 1
+	TOKEN_TYPE_UNKNOWN
+	TOKEN_TYPE_CONTROL
+	TOKEN_TYPE_USER_DEFINED
+	TOKEN_TYPE_UNUSED
+	TOKEN_TYPE_BYTE
+)
+
 type TextProcessor interface {
 	Encode(string) ([]int32, error)
 	Decode([]int32) (string, error)
@@ -27,7 +37,7 @@ type TextProcessor interface {
 type Vocabulary struct {
 	Values []string
 	Types  []uint32
-	Scores []uint32
+	Scores []float32
 	Merges []string
 
 	BOS, EOS int32
@@ -75,7 +85,7 @@ func (v *Vocabulary) Decode(id int32) string {
 func (v *Vocabulary) SpecialVocabulary() []string {
 	v.specialOnce.Do(func() {
 		for i := range v.Values {
-			if v.Types[i] == 3 {
+			if v.Types[i] == TOKEN_TYPE_CONTROL {
 				v.special = append(v.special, v.Values[i])
 			}
 		}
@@ -171,6 +181,7 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
 			fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
 		}
 	}
+	fmt.Printf("frags = %#v\n", fragments)
 
 	var ids []int32
 	for _, frag := range fragments {

+ 232 - 0
model/process_text_spm.go

@@ -0,0 +1,232 @@
+package model
+
+import (
+	"fmt"
+	"iter"
+	"strings"
+	//"unicode/utf8"
+
+	"github.com/dlclark/regexp2"
+	queue "github.com/emirpasic/gods/queues/priorityqueue"
+)
+
+const spmWhitespaceSep = "▁"
+
+func replaceWhitespaceBySeperator(s string) string {
+	return strings.ReplaceAll(s, " ", spmWhitespaceSep)
+}
+
+type SentencePieceModel struct {
+	maxTokenLen int
+	pre         *regexp2.Regexp
+	vocab       *Vocabulary
+}
+
+func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
+	fmt.Printf("Tokens (%d): %5s %5s %5s ...\n", len(vocab.Values), vocab.Values[0], vocab.Values[1], vocab.Values[2])
+	fmt.Printf("Scores (%d): %0.3f %0.3f %0.3f ...\n", len(vocab.Scores), vocab.Scores[0], vocab.Scores[1], vocab.Scores[2])
+	fmt.Printf("Types  (%d): %5d %5d %5d ...\n", len(vocab.Types), vocab.Types[0], vocab.Types[1], vocab.Types[2])
+
+	counter := map[int]int{}
+	var maxTokenLen int
+
+	for cnt, _ := range vocab.Types {
+		switch vocab.Types[cnt] {
+		case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
+			maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
+			fallthrough
+		default:
+			counter[int(vocab.Types[cnt])] += 1
+		}
+	}
+
+	fmt.Printf("Normal: %d\n", counter[TOKEN_TYPE_NORMAL])
+	fmt.Printf("Unknown: %d\n", counter[TOKEN_TYPE_UNKNOWN])
+	fmt.Printf("Control: %d\n", counter[TOKEN_TYPE_CONTROL])
+	fmt.Printf("User Defined: %d\n", counter[TOKEN_TYPE_USER_DEFINED])
+	fmt.Printf("Unused: %d\n", counter[TOKEN_TYPE_UNUSED])
+	fmt.Printf("Byte: %d\n", counter[TOKEN_TYPE_BYTE])
+	fmt.Printf("Max token len: %d\n", maxTokenLen)
+
+	return SentencePieceModel{
+		maxTokenLen: maxTokenLen,
+		pre:         regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
+		vocab:       vocab,
+	}
+}
+
+func (spm SentencePieceModel) Is(id int32, special Special) bool {
+	return spm.vocab.Is(id, special)
+}
+
+func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
+	return func(yield func(string) bool) {
+		for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
+			if !yield(m.String()) {
+				break
+			}
+		}
+	}
+}
+
+func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
+	fragments := []fragment{{value: s}}
+	for _, special := range spm.vocab.SpecialVocabulary() {
+		// TODO: process special tokens concurrently
+		id := spm.vocab.Encode(special)
+		for i := 0; i < len(fragments); i++ {
+			frag := fragments[i]
+			if len(frag.ids) > 0 {
+				continue
+			}
+
+			var middle []fragment
+			switch i := strings.Index(frag.value, special); {
+			case i < 0:
+				middle = append(middle, frag)
+			case i > 0:
+				middle = append(middle, fragment{value: frag.value[:i]})
+				fallthrough
+			default:
+				middle = append(middle, fragment{value: special, ids: []int32{id}})
+				if rest := frag.value[i+len(special):]; rest != "" {
+					middle = append(middle, fragment{value: rest})
+				}
+			}
+
+			fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
+		}
+	}
+	fmt.Printf("frags = %#v\n", fragments)
+
+	var ids []int32
+	for _, frag := range fragments {
+		if len(frag.ids) > 0 {
+			ids = append(ids, frag.ids...)
+			continue
+		}
+
+		for split := range spm.split(frag.value) {
+			split = replaceWhitespaceBySeperator(split)
+
+			var sb strings.Builder
+			sb.Write([]byte(split))
+			if id := spm.vocab.Encode(sb.String()); id >= 0 {
+				ids = append(ids, id)
+				continue
+			}
+
+			runes := []rune(sb.String())
+			pq := queue.NewWith(func(a, b any) int {
+				priA := a.(*candidate)
+				priB := b.(*candidate)
+				if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) {
+					return 1
+				}
+				return -1
+			})
+
+			merges := make([]merge, len(runes))
+			for r := range runes {
+				merges[r] = merge{
+					p:     r - 1,
+					n:     r + 1,
+					runes: []rune{runes[r]},
+				}
+			}
+			fmt.Printf("remaining runes = %#v\n", runes)
+			fmt.Printf("merges = %#v\n", merges)
+
+			pairwise := func(a, b int) *candidate {
+				if a < 0 || b >= len(runes) {
+					return nil
+				}
+
+				left, right := string(merges[a].runes), string(merges[b].runes)
+				fmt.Printf("looking up '%s'\n", left+right)
+				if id := spm.vocab.Encode(left + right); id >= 0 {
+					return &candidate{
+						a:      a,
+						b:      b,
+						length: len(left + " " + right),
+						score:  spm.vocab.Scores[id],
+					}
+				}
+				return nil
+			}
+
+			for i := range len(runes) - 1 {
+				if pair := pairwise(i, i+1); pair != nil {
+					pq.Enqueue(pair)
+				}
+			}
+
+			pqv := pq.Values()
+			for _, v := range pqv {
+				e := v.(*candidate)
+				fmt.Printf("candidate = %#v\n", e)
+			}
+
+			for !pq.Empty() {
+				v, _ := pq.Dequeue()
+				pair := v.(*candidate)
+				left, right := merges[pair.a], merges[pair.b]
+
+				if len(left.runes) == 0 || len(right.runes) == 0 {
+					continue
+				}
+
+				merges[pair.a].runes = append(left.runes, right.runes...)
+				merges[pair.b].runes = nil
+				merges[pair.a].n = right.n
+				if right.n < len(merges) {
+					merges[right.n].p = pair.a
+				}
+
+				if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
+					pq.Enqueue(pair)
+				}
+
+				if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
+					pq.Enqueue(pair)
+				}
+			}
+
+			fmt.Printf("merges = %#v\n", merges)
+
+			for _, merge := range merges {
+				if len(merge.runes) > 0 {
+					if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
+						ids = append(ids, id)
+					} else {
+						fmt.Printf("!!! missing token for '%s'\n", string(merge.runes))
+					}
+				}
+			}
+		}
+
+	}
+	fmt.Printf("tokens = %#v\n", ids)
+
+	return ids, nil
+}
+
+type candidate struct {
+	a, b   int
+	score  float32
+	length int
+}
+
+func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
+	var sb strings.Builder
+	for _, id := range ids {
+		for _, r := range spm.vocab.Decode(id) {
+			// todo - do we need to introspect the chars here?
+			if err := sb.WriteByte(byte(r)); err != nil {
+				return "", err
+			}
+		}
+	}
+
+	return sb.String(), nil
+}