process_text.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. package model
  2. import (
  3. "cmp"
  4. "iter"
  5. "log/slog"
  6. "slices"
  7. "strings"
  8. "sync"
  9. "github.com/dlclark/regexp2"
  10. heap "github.com/emirpasic/gods/v2/trees/binaryheap"
  11. )
  // Special names a category of special token that a token ID can be
  // tested against (see Vocabulary.Is).
  type Special int32

  const (
  	// SpecialBOS matches the vocabulary's BOS token.
  	SpecialBOS Special = 0
  	// SpecialEOS matches the vocabulary's EOS or EOT token.
  	SpecialEOS Special = 1
  )
  // Token type values, compared against the entries of Vocabulary.Types.
  // They are untyped integer constants numbered from 1.
  const (
  	TOKEN_TYPE_NORMAL  = 1
  	TOKEN_TYPE_UNKNOWN = 2
  	// TOKEN_TYPE_CONTROL marks tokens that SpecialVocabulary reports as
  	// special.
  	TOKEN_TYPE_CONTROL      = 3
  	TOKEN_TYPE_USER_DEFINED = 4
  	TOKEN_TYPE_UNUSED       = 5
  	TOKEN_TYPE_BYTE         = 6
  )
  25. type TextProcessor interface {
  26. Encode(s string, addSpecial bool) ([]int32, error)
  27. Decode([]int32) (string, error)
  28. Is(int32, Special) bool
  29. Vocab() *Vocabulary
  30. }
  // Vocabulary holds the token table of a model together with lazily built
  // lookup indexes. It contains sync.Once fields, so it should be shared by
  // pointer rather than copied.
  type Vocabulary struct {
  	// Values holds the token strings; a token's ID is its index.
  	Values []string
  	// Types holds one token type per entry of Values (see TOKEN_TYPE_*).
  	Types []uint32
  	// Scores is not read by the code in this part of the file.
  	Scores []float32
  	// Merges holds merge rules written as "left right"; a rule's index is
  	// its rank.
  	Merges []string

  	// IDs of the special tokens.
  	BOS int32
  	EOS int32
  	EOT int32

  	// Whether the corresponding special token is added when encoding.
  	AddBOS bool
  	AddEOS bool
  	AddEOT bool

  	// special caches the result of SpecialVocabulary.
  	specialOnce sync.Once
  	special     []string

  	// values maps a token string to its ID; built on first Encode.
  	valuesOnce sync.Once
  	values     map[string]int32

  	// merge maps a "left right" rule to its rank; built on first Merge.
  	mergeOnce sync.Once
  	merge     map[string]int32
  }
  45. func (v *Vocabulary) Is(id int32, special Special) bool {
  46. switch special {
  47. case SpecialBOS:
  48. return id == v.BOS
  49. case SpecialEOS:
  50. return id == v.EOS || id == v.EOT
  51. default:
  52. return false
  53. }
  54. }
  55. func (v *Vocabulary) Encode(s string) int32 {
  56. v.valuesOnce.Do(func() {
  57. v.values = make(map[string]int32, len(v.Values))
  58. for i, value := range v.Values {
  59. v.values[value] = int32(i)
  60. }
  61. })
  62. if id, ok := v.values[s]; ok {
  63. return id
  64. }
  65. return -1
  66. }
  67. func (v *Vocabulary) Decode(id int32) string {
  68. return v.Values[id]
  69. }
  70. func (v *Vocabulary) SpecialVocabulary() []string {
  71. v.specialOnce.Do(func() {
  72. for i := range v.Values {
  73. if slices.Contains([]int{105, 106}, i) {
  74. v.special = append(v.special, v.Values[i])
  75. } else if v.Types[i] == TOKEN_TYPE_CONTROL {
  76. v.special = append(v.special, v.Values[i])
  77. }
  78. }
  79. })
  80. return v.special
  81. }
  82. func (v *Vocabulary) Merge(left, right string) int {
  83. v.mergeOnce.Do(func() {
  84. v.merge = make(map[string]int32, len(v.Merges))
  85. for i, merge := range v.Merges {
  86. v.merge[merge] = int32(i)
  87. }
  88. })
  89. if id, ok := v.merge[left+" "+right]; ok {
  90. return int(id)
  91. }
  92. return -1
  93. }
  94. type BytePairEncoding struct {
  95. pre *regexp2.Regexp
  96. vocab *Vocabulary
  97. }
  98. func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding {
  99. return BytePairEncoding{
  100. pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
  101. vocab: vocab,
  102. }
  103. }
  104. func (bpe BytePairEncoding) Is(id int32, special Special) bool {
  105. return bpe.vocab.Is(id, special)
  106. }
  107. func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
  108. return func(yield func(string) bool) {
  109. for m, _ := bpe.pre.FindStringMatch(s); m != nil; m, _ = bpe.pre.FindNextMatch(m) {
  110. if !yield(m.String()) {
  111. break
  112. }
  113. }
  114. }
  115. }
  // fragment is a piece of the input text. When ids is non-empty the piece
  // has already been resolved to those token IDs (it is a special token);
  // otherwise it still has to be tokenized.
  type fragment struct {
  	// value is the text of the piece.
  	value string
  	// ids holds the token IDs for value, if already known.
  	ids []int32
  }
  // pair is a candidate merge of two neighbouring symbols.
  type pair struct {
  	// a is the index of the left symbol in the merges slice.
  	a int
  	// b is the index of the right symbol in the merges slice.
  	b int
  	// rank is the rank of the merge rule; lower ranks are applied first.
  	rank int
  	// value is the concatenated text of both symbols at the time the pair
  	// was created; it is used to detect candidates that have gone stale.
  	value string
  }
  // merge is one symbol in the sequence being merged, linked to its
  // neighbours by index.
  type merge struct {
  	// p is the index of the previous symbol (-1 for the first one).
  	p int
  	// n is the index of the next symbol (one past the end for the last one).
  	n int
  	// runes is the symbol's text; it is nil once the symbol has been
  	// absorbed into its left neighbour.
  	runes []rune
  }
  131. func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
  132. fragments := []fragment{{value: s}}
  133. for _, special := range bpe.vocab.SpecialVocabulary() {
  134. // TODO: process special tokens concurrently
  135. id := bpe.vocab.Encode(special)
  136. for i := 0; i < len(fragments); i++ {
  137. frag := fragments[i]
  138. if len(frag.ids) > 0 {
  139. continue
  140. }
  141. var middle []fragment
  142. switch i := strings.Index(frag.value, special); {
  143. case i < 0:
  144. middle = append(middle, frag)
  145. case i > 0:
  146. middle = append(middle, fragment{value: frag.value[:i]})
  147. fallthrough
  148. default:
  149. middle = append(middle, fragment{value: special, ids: []int32{id}})
  150. if rest := frag.value[i+len(special):]; rest != "" {
  151. middle = append(middle, fragment{value: rest})
  152. }
  153. }
  154. fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
  155. }
  156. }
  157. var ids []int32
  158. for _, frag := range fragments {
  159. if len(frag.ids) > 0 {
  160. ids = append(ids, frag.ids...)
  161. continue
  162. }
  163. for split := range bpe.split(frag.value) {
  164. // TODO: process splits concurrently
  165. var sb strings.Builder
  166. for _, b := range []byte(split) {
  167. r := rune(b)
  168. switch {
  169. case r == 0x00ad:
  170. r = 0x0143
  171. case r <= 0x0020:
  172. r = r + 0x0100
  173. case r >= 0x007e && r <= 0x00a0:
  174. r = r + 0x00a2
  175. }
  176. sb.WriteRune(r)
  177. }
  178. // short circuit if the fragment is in the vocabulary
  179. if id := bpe.vocab.Encode(sb.String()); id >= 0 {
  180. ids = append(ids, id)
  181. continue
  182. }
  183. runes := []rune(sb.String())
  184. merges := make([]merge, len(runes))
  185. for r := range runes {
  186. merges[r] = merge{
  187. p: r - 1,
  188. n: r + 1,
  189. runes: []rune{runes[r]},
  190. }
  191. }
  192. pairwise := func(a, b int) *pair {
  193. if a < 0 || b >= len(runes) {
  194. return nil
  195. }
  196. left, right := string(merges[a].runes), string(merges[b].runes)
  197. rank := bpe.vocab.Merge(left, right)
  198. if rank < 0 {
  199. return nil
  200. }
  201. return &pair{
  202. a: a,
  203. b: b,
  204. rank: rank,
  205. value: left + right,
  206. }
  207. }
  208. pairs := heap.NewWith(func(i, j *pair) int {
  209. return cmp.Compare(i.rank, j.rank)
  210. })
  211. for i := range len(runes) - 1 {
  212. if pair := pairwise(i, i+1); pair != nil {
  213. pairs.Push(pair)
  214. }
  215. }
  216. for !pairs.Empty() {
  217. pair, _ := pairs.Pop()
  218. left, right := merges[pair.a], merges[pair.b]
  219. if len(left.runes) == 0 || len(right.runes) == 0 ||
  220. string(left.runes)+string(right.runes) != pair.value {
  221. continue
  222. }
  223. merges[pair.a].runes = append(left.runes, right.runes...)
  224. merges[pair.b].runes = nil
  225. merges[pair.a].n = right.n
  226. if right.n < len(merges) {
  227. merges[right.n].p = pair.a
  228. }
  229. if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
  230. pairs.Push(pair)
  231. }
  232. if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
  233. pairs.Push(pair)
  234. }
  235. }
  236. for _, merge := range merges {
  237. if len(merge.runes) > 0 {
  238. // TODO: handle the edge case where the rune isn't in the vocabulary
  239. if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
  240. ids = append(ids, id)
  241. }
  242. }
  243. }
  244. }
  245. }
  246. if addSpecial && len(ids) > 0 {
  247. if bpe.vocab.AddBOS {
  248. if ids[0] == bpe.vocab.BOS {
  249. slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS)
  250. }
  251. slog.Debug("adding bos token to prompt", "id", bpe.vocab.BOS)
  252. ids = append([]int32{bpe.vocab.BOS}, ids...)
  253. }
  254. if bpe.vocab.AddEOS {
  255. if ids[len(ids)-1] == bpe.vocab.EOS {
  256. slog.Warn("adding eos token to prompt which already has it", "id", bpe.vocab.EOS)
  257. }
  258. slog.Debug("adding eos token to prompt", "id", bpe.vocab.EOS)
  259. ids = append(ids, bpe.vocab.EOS)
  260. }
  261. }
  262. return ids, nil
  263. }
  264. func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
  265. var sb strings.Builder
  266. for _, id := range ids {
  267. for _, r := range bpe.vocab.Decode(id) {
  268. switch {
  269. case r == 0x0100:
  270. // this produces 0x00 aka NULL
  271. continue
  272. case r == 0x0143:
  273. r = 0x00ad
  274. case r > 0x0100 && r <= 0x0120:
  275. r = r - 0x0100
  276. case r > 0x0120 && r <= 0x0142:
  277. r = r - 0x00a2
  278. }
  279. // NOTE: not using WriteRune here because it writes the UTF-8
  280. // encoding of the rune which is _not_ what we want
  281. if err := sb.WriteByte(byte(r)); err != nil {
  282. return "", err
  283. }
  284. }
  285. }
  286. return sb.String(), nil
  287. }