process_text.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. package model
import (
	"cmp"
	"fmt"
	"iter"
	"log/slog"
	"slices"
	"strings"
	"sync"

	"github.com/dlclark/regexp2"
	heap "github.com/emirpasic/gods/v2/trees/binaryheap"
)
  11. type Special int32
  12. const (
  13. SpecialBOS Special = iota
  14. SpecialEOS
  15. )
  16. const (
  17. TOKEN_TYPE_NORMAL = iota + 1
  18. TOKEN_TYPE_UNKNOWN
  19. TOKEN_TYPE_CONTROL
  20. TOKEN_TYPE_USER_DEFINED
  21. TOKEN_TYPE_UNUSED
  22. TOKEN_TYPE_BYTE
  23. )
  24. type TextProcessor interface {
  25. Encode(string) ([]int32, error)
  26. Decode([]int32) (string, error)
  27. Is(int32, Special) bool
  28. }
  29. type Vocabulary struct {
  30. Values []string
  31. Types []uint32
  32. Scores []float32
  33. Merges []string
  34. BOS, EOS int32
  35. specialOnce sync.Once
  36. special []string
  37. valuesOnce sync.Once
  38. values map[string]int32
  39. mergeOnce sync.Once
  40. merge map[string]int32
  41. }
  42. func (v *Vocabulary) Is(id int32, special Special) bool {
  43. switch special {
  44. case SpecialBOS:
  45. return id == v.BOS
  46. case SpecialEOS:
  47. return id == v.EOS
  48. default:
  49. return false
  50. }
  51. }
  52. func (v *Vocabulary) Encode(s string) int32 {
  53. v.valuesOnce.Do(func() {
  54. v.values = make(map[string]int32, len(v.Values))
  55. for i, value := range v.Values {
  56. v.values[value] = int32(i)
  57. }
  58. })
  59. if id, ok := v.values[s]; ok {
  60. return id
  61. }
  62. return -1
  63. }
  64. func (v *Vocabulary) Decode(id int32) string {
  65. return v.Values[id]
  66. }
  67. func (v *Vocabulary) SpecialVocabulary() []string {
  68. v.specialOnce.Do(func() {
  69. for i := range v.Values {
  70. if v.Types[i] == TOKEN_TYPE_CONTROL {
  71. v.special = append(v.special, v.Values[i])
  72. }
  73. }
  74. })
  75. return v.special
  76. }
  77. func (v *Vocabulary) Merge(left, right string) int {
  78. v.mergeOnce.Do(func() {
  79. v.merge = make(map[string]int32, len(v.Merges))
  80. for i, merge := range v.Merges {
  81. v.merge[merge] = int32(i)
  82. }
  83. })
  84. if id, ok := v.merge[left+" "+right]; ok {
  85. return int(id)
  86. }
  87. return -1
  88. }
  89. type BytePairEncoding struct {
  90. pre *regexp2.Regexp
  91. vocab *Vocabulary
  92. }
  93. func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding {
  94. return BytePairEncoding{
  95. pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
  96. vocab: vocab,
  97. }
  98. }
  99. func (bpe BytePairEncoding) Is(id int32, special Special) bool {
  100. return bpe.vocab.Is(id, special)
  101. }
  102. func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
  103. return func(yield func(string) bool) {
  104. for m, _ := bpe.pre.FindStringMatch(s); m != nil; m, _ = bpe.pre.FindNextMatch(m) {
  105. if !yield(m.String()) {
  106. break
  107. }
  108. }
  109. }
  110. }
  111. // fragment is a string fragment and their corresponding token IDs
  112. type fragment struct {
  113. value string
  114. ids []int32
  115. }
  116. // pair is a pair of runes and its rank
  117. type pair struct {
  118. a, b int
  119. rank int
  120. value string
  121. }
  122. type merge struct {
  123. p, n int
  124. runes []rune
  125. }
  126. func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
  127. fragments := []fragment{{value: s}}
  128. for _, special := range bpe.vocab.SpecialVocabulary() {
  129. // TODO: process special tokens concurrently
  130. id := bpe.vocab.Encode(special)
  131. for i := 0; i < len(fragments); i++ {
  132. frag := fragments[i]
  133. if len(frag.ids) > 0 {
  134. continue
  135. }
  136. var middle []fragment
  137. switch i := strings.Index(frag.value, special); {
  138. case i < 0:
  139. middle = append(middle, frag)
  140. case i > 0:
  141. middle = append(middle, fragment{value: frag.value[:i]})
  142. fallthrough
  143. default:
  144. middle = append(middle, fragment{value: special, ids: []int32{id}})
  145. if rest := frag.value[i+len(special):]; rest != "" {
  146. middle = append(middle, fragment{value: rest})
  147. }
  148. }
  149. fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
  150. }
  151. }
  152. var ids []int32
  153. for _, frag := range fragments {
  154. if len(frag.ids) > 0 {
  155. ids = append(ids, frag.ids...)
  156. slog.Debug("encoded", "text", frag.value, "ids", frag.ids, "special", true)
  157. continue
  158. }
  159. for split := range bpe.split(frag.value) {
  160. // TODO: process splits concurrently
  161. var sb strings.Builder
  162. for _, b := range []byte(split) {
  163. r := rune(b)
  164. switch {
  165. case r == 0x00ad:
  166. r = 0x0143
  167. case r <= 0x0020:
  168. r = r + 0x0100
  169. case r >= 0x007e && r <= 0x00a0:
  170. r = r + 0x00a2
  171. }
  172. sb.WriteRune(r)
  173. }
  174. // short circuit if the fragment is in the vocabulary
  175. if id := bpe.vocab.Encode(sb.String()); id >= 0 {
  176. ids = append(ids, id)
  177. slog.Debug("encoded", "text", sb.String(), "ids", []int32{id})
  178. continue
  179. }
  180. runes := []rune(sb.String())
  181. merges := make([]merge, len(runes))
  182. for r := range runes {
  183. merges[r] = merge{
  184. p: r - 1,
  185. n: r + 1,
  186. runes: []rune{runes[r]},
  187. }
  188. }
  189. pairwise := func(a, b int) *pair {
  190. if a < 0 || b >= len(runes) {
  191. return nil
  192. }
  193. left, right := string(merges[a].runes), string(merges[b].runes)
  194. rank := bpe.vocab.Merge(left, right)
  195. if rank < 0 {
  196. return nil
  197. }
  198. return &pair{
  199. a: a,
  200. b: b,
  201. rank: rank,
  202. value: left + right,
  203. }
  204. }
  205. pairs := heap.NewWith(func(i, j *pair) int {
  206. return cmp.Compare(i.rank, j.rank)
  207. })
  208. for i := range len(runes) - 1 {
  209. if pair := pairwise(i, i+1); pair != nil {
  210. pairs.Push(pair)
  211. }
  212. }
  213. for !pairs.Empty() {
  214. pair, _ := pairs.Pop()
  215. left, right := merges[pair.a], merges[pair.b]
  216. if len(left.runes) == 0 || len(right.runes) == 0 ||
  217. string(left.runes)+string(right.runes) != pair.value {
  218. continue
  219. }
  220. merges[pair.a].runes = append(left.runes, right.runes...)
  221. merges[pair.b].runes = nil
  222. merges[pair.a].n = right.n
  223. if right.n < len(merges) {
  224. merges[right.n].p = pair.a
  225. }
  226. if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
  227. pairs.Push(pair)
  228. }
  229. if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
  230. pairs.Push(pair)
  231. }
  232. }
  233. for _, merge := range merges {
  234. if len(merge.runes) > 0 {
  235. // TODO: handle the edge case where the rune isn't in the vocabulary
  236. if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
  237. ids = append(ids, id)
  238. slog.Debug("encoded", "text", string(merge.runes), "ids", []int32{id})
  239. }
  240. }
  241. }
  242. }
  243. }
  244. return ids, nil
  245. }
  246. func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
  247. var sb strings.Builder
  248. for _, id := range ids {
  249. for _, r := range bpe.vocab.Decode(id) {
  250. switch {
  251. case r == 0x0100:
  252. // this produces 0x00 aka NULL
  253. continue
  254. case r == 0x0143:
  255. r = 0x00ad
  256. case r > 0x0100 && r <= 0x0120:
  257. r = r - 0x0100
  258. case r > 0x0120 && r <= 0x0142:
  259. r = r - 0x00a2
  260. }
  261. // NOTE: not using WriteRune here because it writes the UTF-8
  262. // encoding of the rune which is _not_ what we want
  263. if err := sb.WriteByte(byte(r)); err != nil {
  264. return "", err
  265. }
  266. }
  267. }
  268. slog.Debug("decoded", "ids", ids, "text", sb.String())
  269. return sb.String(), nil
  270. }