package convert

import (
	"cmp"
	"encoding/json"
	"errors"
	"fmt"
	"io/fs"
	"log/slog"
	"os"
	"reflect"
	"slices"

	"google.golang.org/protobuf/proto"

	"github.com/ollama/ollama/convert/sentencepiece"
)
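
// parseSentencePiece builds a Vocabulary from the SentencePiece tokenizer.model
// found in fsys, promotes any additional special tokens declared in
// special_tokens_map.json to control tokens, and merges tokens declared in
// added_tokens.json.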
func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
	slog.Debug("using spm vocabulary")

	ast, err := parseAdditionalSpecialTokens(fsys)
	if err != nil {
		return nil, err
	}

	bts, err := fs.ReadFile(fsys, "tokenizer.model")
	if err != nil {
		return nil, err
	}

	var spm sentencepiece.ModelProto
	if err := proto.Unmarshal(bts, &spm); err != nil {
		return nil, err
	}

	v := Vocabulary{Model: "llama"}
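
	// Copy each piece and its score into the vocabulary. Pieces with a special
	// type (UNKNOWN, CONTROL, UNUSED, BYTE) keep that type; everything else is
	// recorded as NORMAL unless it matches an additional special token, in
	// which case it is promoted to CONTROL.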
	for _, piece := range spm.GetPieces() {
		v.Tokens = append(v.Tokens, piece.GetPiece())
		v.Scores = append(v.Scores, piece.GetScore())

		switch t := piece.GetType(); t {
		case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
			sentencepiece.ModelProto_SentencePiece_CONTROL,
			sentencepiece.ModelProto_SentencePiece_UNUSED,
			sentencepiece.ModelProto_SentencePiece_BYTE:
			v.Types = append(v.Types, int32(t))
		default:
			tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
			for _, t := range ast {
				if t.Content == piece.GetPiece() {
					tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
					break
				}
			}

			v.Types = append(v.Types, tt)
		}
	}
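
	// Merge added_tokens.json, if present. Each added token must either
	// duplicate an existing entry (logged and skipped) or extend the
	// vocabulary contiguously; any other id is an error.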
	f, err := fsys.Open("added_tokens.json")
	if errors.Is(err, os.ErrNotExist) {
		return &v, nil
	} else if err != nil {
		return nil, err
	}
	defer f.Close()

	var atm map[string]int
	if err := json.NewDecoder(f).Decode(&atm); err != nil {
		return nil, err
	}

	type t struct {
		id      int
		content string
	}

	var ts []t
	for content, id := range atm {
		ts = append(ts, t{id, content})
	}

	slices.SortFunc(ts, func(i, j t) int {
		return cmp.Compare(i.id, j.id)
	})

	for _, t := range ts {
		if t.id < len(v.Tokens) {
			if v.Tokens[t.id] == t.content {
				slog.Warn("tokenizer", "duplicate token", t.content, "id", t.id)
				continue
			}

			return nil, fmt.Errorf("token mismatch: %s != %s at pos [%d]", t.content, v.Tokens[t.id], t.id)
		}

		if t.id != len(v.Tokens) {
			return nil, fmt.Errorf("invalid token id: [%d] at pos [%d]", t.id, len(v.Tokens))
		}

		v.Tokens = append(v.Tokens, t.content)
		// added tokens get a fixed low score and the user-defined token type
		v.Scores = append(v.Scores, -1000.0)
		v.Types = append(v.Types, tokenTypeUserDefined)
	}

	return &v, nil
}
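
// specialToken mirrors the token objects found in Hugging Face tokenizer
// metadata such as special_tokens_map.json.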
type specialToken struct {
	Content    string `json:"content"`
	Lstrip     bool   `json:"lstrip"`
	Normalized bool   `json:"normalized"`
	Rstrip     bool   `json:"rstrip"`
	SingleWord bool   `json:"single_word"`
}
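
// parseAdditionalSpecialTokens reads the additional_special_tokens entry of
// special_tokens_map.json, which may be either a list of strings or a list of
// token objects. A missing file is not an error and yields no tokens.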
func parseAdditionalSpecialTokens(fsys fs.FS) ([]specialToken, error) {
	f, err := fsys.Open("special_tokens_map.json")
	if errors.Is(err, os.ErrNotExist) {
		return nil, nil
	} else if err != nil {
		return nil, err
	}
	defer f.Close()

	var m struct {
		AdditionalSpecialTokens any `json:"additional_special_tokens"`
	}

	if err := json.NewDecoder(f).Decode(&m); err != nil {
		return nil, err
	}

	var ast []specialToken

	switch st := m.AdditionalSpecialTokens.(type) {
	case []string:
		for _, s := range st {
			ast = append(ast, specialToken{Content: s})
		}
	case []any:
		for _, s := range st {
			// marshal and unmarshal the object to get the special token
			tMap, ok := s.(map[string]any)
			if !ok {
				return nil, fmt.Errorf("unexpected additional special token: %v", s)
			}

			data, err := json.Marshal(tMap)
			if err != nil {
				return nil, err
			}

			var token specialToken
			if err := json.Unmarshal(data, &token); err != nil {
				return nil, err
			}

			ast = append(ast, token)
		}
	default:
		slog.Warn("special token", "unknown token", reflect.TypeOf(st))
	}

	slog.Debug("spm tokenizer", "additional tokens", ast)

	return ast, nil
}
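
// The sketch below is illustrative only and is not part of the converter: it
// shows how parseSentencePiece might be driven against a model directory on
// disk. The directory path is a placeholder assumption for the example.
func exampleParseSentencePiece() error {
	// os.DirFS exposes a local model directory (containing tokenizer.model and,
	// optionally, added_tokens.json and special_tokens_map.json) as the fs.FS
	// that parseSentencePiece expects.
	vocab, err := parseSentencePiece(os.DirFS("/path/to/model"))
	if err != nil {
		return err
	}

	slog.Info("parsed spm vocabulary", "tokens", len(vocab.Tokens))
	return nil
}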