|
@@ -0,0 +1,232 @@
|
|
|
+package model
|
|
|
+
|
|
|
+import (
|
|
|
+ "fmt"
|
|
|
+ "iter"
|
|
|
+ "strings"
|
|
|
+ //"unicode/utf8"
|
|
|
+
|
|
|
+ "github.com/dlclark/regexp2"
|
|
|
+ queue "github.com/emirpasic/gods/queues/priorityqueue"
|
|
|
+)
|
|
|
+
|
|
|
+const spmWhitespaceSep = "▁"
|
|
|
+
|
|
|
+func replaceWhitespaceBySeperator(s string) string {
|
|
|
+ return strings.ReplaceAll(s, " ", spmWhitespaceSep)
|
|
|
+}
|
|
|
+
|
|
|
+type SentencePieceModel struct {
|
|
|
+ maxTokenLen int
|
|
|
+ pre *regexp2.Regexp
|
|
|
+ vocab *Vocabulary
|
|
|
+}
|
|
|
+
|
|
|
+func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
|
|
|
+ fmt.Printf("Tokens (%d): %5s %5s %5s ...\n", len(vocab.Values), vocab.Values[0], vocab.Values[1], vocab.Values[2])
|
|
|
+ fmt.Printf("Scores (%d): %0.3f %0.3f %0.3f ...\n", len(vocab.Scores), vocab.Scores[0], vocab.Scores[1], vocab.Scores[2])
|
|
|
+ fmt.Printf("Types (%d): %5d %5d %5d ...\n", len(vocab.Types), vocab.Types[0], vocab.Types[1], vocab.Types[2])
|
|
|
+
|
|
|
+ counter := map[int]int{}
|
|
|
+ var maxTokenLen int
|
|
|
+
|
|
|
+ for cnt, _ := range vocab.Types {
|
|
|
+ switch vocab.Types[cnt] {
|
|
|
+ case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
|
|
|
+ maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
|
|
|
+ fallthrough
|
|
|
+ default:
|
|
|
+ counter[int(vocab.Types[cnt])] += 1
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ fmt.Printf("Normal: %d\n", counter[TOKEN_TYPE_NORMAL])
|
|
|
+ fmt.Printf("Unknown: %d\n", counter[TOKEN_TYPE_UNKNOWN])
|
|
|
+ fmt.Printf("Control: %d\n", counter[TOKEN_TYPE_CONTROL])
|
|
|
+ fmt.Printf("User Defined: %d\n", counter[TOKEN_TYPE_USER_DEFINED])
|
|
|
+ fmt.Printf("Unused: %d\n", counter[TOKEN_TYPE_UNUSED])
|
|
|
+ fmt.Printf("Byte: %d\n", counter[TOKEN_TYPE_BYTE])
|
|
|
+ fmt.Printf("Max token len: %d\n", maxTokenLen)
|
|
|
+
|
|
|
+ return SentencePieceModel{
|
|
|
+ maxTokenLen: maxTokenLen,
|
|
|
+ pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
|
|
|
+ vocab: vocab,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (spm SentencePieceModel) Is(id int32, special Special) bool {
|
|
|
+ return spm.vocab.Is(id, special)
|
|
|
+}
|
|
|
+
|
|
|
+func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
|
|
|
+ return func(yield func(string) bool) {
|
|
|
+ for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
|
|
|
+ if !yield(m.String()) {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
|
|
|
+ fragments := []fragment{{value: s}}
|
|
|
+ for _, special := range spm.vocab.SpecialVocabulary() {
|
|
|
+ // TODO: process special tokens concurrently
|
|
|
+ id := spm.vocab.Encode(special)
|
|
|
+ for i := 0; i < len(fragments); i++ {
|
|
|
+ frag := fragments[i]
|
|
|
+ if len(frag.ids) > 0 {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ var middle []fragment
|
|
|
+ switch i := strings.Index(frag.value, special); {
|
|
|
+ case i < 0:
|
|
|
+ middle = append(middle, frag)
|
|
|
+ case i > 0:
|
|
|
+ middle = append(middle, fragment{value: frag.value[:i]})
|
|
|
+ fallthrough
|
|
|
+ default:
|
|
|
+ middle = append(middle, fragment{value: special, ids: []int32{id}})
|
|
|
+ if rest := frag.value[i+len(special):]; rest != "" {
|
|
|
+ middle = append(middle, fragment{value: rest})
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ fmt.Printf("frags = %#v\n", fragments)
|
|
|
+
|
|
|
+ var ids []int32
|
|
|
+ for _, frag := range fragments {
|
|
|
+ if len(frag.ids) > 0 {
|
|
|
+ ids = append(ids, frag.ids...)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ for split := range spm.split(frag.value) {
|
|
|
+ split = replaceWhitespaceBySeperator(split)
|
|
|
+
|
|
|
+ var sb strings.Builder
|
|
|
+ sb.Write([]byte(split))
|
|
|
+ if id := spm.vocab.Encode(sb.String()); id >= 0 {
|
|
|
+ ids = append(ids, id)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ runes := []rune(sb.String())
|
|
|
+ pq := queue.NewWith(func(a, b any) int {
|
|
|
+ priA := a.(*candidate)
|
|
|
+ priB := b.(*candidate)
|
|
|
+ if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) {
|
|
|
+ return 1
|
|
|
+ }
|
|
|
+ return -1
|
|
|
+ })
|
|
|
+
|
|
|
+ merges := make([]merge, len(runes))
|
|
|
+ for r := range runes {
|
|
|
+ merges[r] = merge{
|
|
|
+ p: r - 1,
|
|
|
+ n: r + 1,
|
|
|
+ runes: []rune{runes[r]},
|
|
|
+ }
|
|
|
+ }
|
|
|
+ fmt.Printf("remaining runes = %#v\n", runes)
|
|
|
+ fmt.Printf("merges = %#v\n", merges)
|
|
|
+
|
|
|
+ pairwise := func(a, b int) *candidate {
|
|
|
+ if a < 0 || b >= len(runes) {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ left, right := string(merges[a].runes), string(merges[b].runes)
|
|
|
+ fmt.Printf("looking up '%s'\n", left+right)
|
|
|
+ if id := spm.vocab.Encode(left + right); id >= 0 {
|
|
|
+ return &candidate{
|
|
|
+ a: a,
|
|
|
+ b: b,
|
|
|
+ length: len(left + " " + right),
|
|
|
+ score: spm.vocab.Scores[id],
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ for i := range len(runes) - 1 {
|
|
|
+ if pair := pairwise(i, i+1); pair != nil {
|
|
|
+ pq.Enqueue(pair)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ pqv := pq.Values()
|
|
|
+ for _, v := range pqv {
|
|
|
+ e := v.(*candidate)
|
|
|
+ fmt.Printf("candidate = %#v\n", e)
|
|
|
+ }
|
|
|
+
|
|
|
+ for !pq.Empty() {
|
|
|
+ v, _ := pq.Dequeue()
|
|
|
+ pair := v.(*candidate)
|
|
|
+ left, right := merges[pair.a], merges[pair.b]
|
|
|
+
|
|
|
+ if len(left.runes) == 0 || len(right.runes) == 0 {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ merges[pair.a].runes = append(left.runes, right.runes...)
|
|
|
+ merges[pair.b].runes = nil
|
|
|
+ merges[pair.a].n = right.n
|
|
|
+ if right.n < len(merges) {
|
|
|
+ merges[right.n].p = pair.a
|
|
|
+ }
|
|
|
+
|
|
|
+ if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
|
|
+ pq.Enqueue(pair)
|
|
|
+ }
|
|
|
+
|
|
|
+ if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
|
|
+ pq.Enqueue(pair)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ fmt.Printf("merges = %#v\n", merges)
|
|
|
+
|
|
|
+ for _, merge := range merges {
|
|
|
+ if len(merge.runes) > 0 {
|
|
|
+ if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
|
|
|
+ ids = append(ids, id)
|
|
|
+ } else {
|
|
|
+ fmt.Printf("!!! missing token for '%s'\n", string(merge.runes))
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ }
|
|
|
+ fmt.Printf("tokens = %#v\n", ids)
|
|
|
+
|
|
|
+ return ids, nil
|
|
|
+}
|
|
|
+
|
|
|
+type candidate struct {
|
|
|
+ a, b int
|
|
|
+ score float32
|
|
|
+ length int
|
|
|
+}
|
|
|
+
|
|
|
+func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
|
|
|
+ var sb strings.Builder
|
|
|
+ for _, id := range ids {
|
|
|
+ for _, r := range spm.vocab.Decode(id) {
|
|
|
+ // todo - do we need to introspect the chars here?
|
|
|
+ if err := sb.WriteByte(byte(r)); err != nil {
|
|
|
+ return "", err
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return sb.String(), nil
|
|
|
+}
|