tokenizer_spm.go 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. package convert
  2. import (
  3. "cmp"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io/fs"
  8. "os"
  9. "slices"
  10. "google.golang.org/protobuf/proto"
  11. "github.com/ollama/ollama/convert/sentencepiece"
  12. )
  13. func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
  14. bts, err := fs.ReadFile(fsys, "tokenizer.model")
  15. if err != nil {
  16. return nil, err
  17. }
  18. var spm sentencepiece.ModelProto
  19. if err := proto.Unmarshal(bts, &spm); err != nil {
  20. return nil, err
  21. }
  22. v := Vocabulary{Model: "llama"}
  23. for _, piece := range spm.GetPieces() {
  24. v.Tokens = append(v.Tokens, piece.GetPiece())
  25. v.Scores = append(v.Scores, piece.GetScore())
  26. switch t := piece.GetType(); t {
  27. case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
  28. sentencepiece.ModelProto_SentencePiece_CONTROL,
  29. sentencepiece.ModelProto_SentencePiece_UNUSED,
  30. sentencepiece.ModelProto_SentencePiece_BYTE:
  31. v.Types = append(v.Types, int32(t))
  32. default:
  33. v.Types = append(v.Types, int32(sentencepiece.ModelProto_SentencePiece_NORMAL))
  34. }
  35. }
  36. f, err := fsys.Open("added_tokens.json")
  37. if errors.Is(err, os.ErrNotExist) {
  38. return &v, nil
  39. } else if err != nil {
  40. return nil, err
  41. }
  42. defer f.Close()
  43. var atm map[string]int
  44. if err := json.NewDecoder(f).Decode(&atm); err != nil {
  45. return nil, err
  46. }
  47. type t struct {
  48. id int
  49. content string
  50. }
  51. var ts []t
  52. for content, id := range atm {
  53. ts = append(ts, t{id, content})
  54. }
  55. slices.SortFunc(ts, func(i, j t) int {
  56. return cmp.Compare(i.id, j.id)
  57. })
  58. n := len(v.Tokens)
  59. for i, t := range ts {
  60. if t.id != i+n {
  61. return nil, fmt.Errorf("invalid token id: %d", t.id)
  62. }
  63. v.Tokens = append(v.Tokens, t.content)
  64. v.Scores = append(v.Scores, -1000.0)
  65. v.Types = append(v.Types, tokenTypeUserDefined)
  66. }
  67. return &v, nil
  68. }