123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596 |
- package convert
- import (
- "io"
- "io/fs"
- "os"
- "path/filepath"
- "strings"
- "testing"
- "github.com/google/go-cmp/cmp"
- )
// createTokenizerFS writes each entry of files into dir and returns an fs.FS
// rooted at dir. Keys are paths relative to dir (slash-separated names such as
// "sub/file.json" are allowed; missing parent directories are created), and
// each reader supplies that file's contents. Any I/O failure stops the test
// immediately via t.Fatalf, naming the file that could not be written.
func createTokenizerFS(t *testing.T, dir string, files map[string]io.Reader) fs.FS {
	t.Helper()

	for name, r := range files {
		// The closure scopes the deferred Close to a single file instead of
		// keeping every descriptor open until createTokenizerFS returns.
		if err := func() (err error) {
			path := filepath.Join(dir, name)
			if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
				return err
			}

			f, err := os.Create(path)
			if err != nil {
				return err
			}
			// Report a Close failure too (unless an earlier error already
			// occurred): after writing, a failed close may mean the fixture
			// contents were not fully stored.
			defer func() {
				if cerr := f.Close(); err == nil {
					err = cerr
				}
			}()

			_, err = io.Copy(f, r)
			return err
		}(); err != nil {
			t.Fatalf("unexpected error writing %q: %v", name, err)
		}
	}

	return os.DirFS(dir)
}
- func TestParseTokenizer(t *testing.T) {
- cases := []struct {
- name string
- fsys fs.FS
- specialTokenTypes []string
- want *Tokenizer
- }{
- {
- name: "string chat template",
- fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
- "tokenizer.json": strings.NewReader(`{}`),
- "tokenizer_config.json": strings.NewReader(`{
- "chat_template": "<default template>"
- }`),
- }),
- want: &Tokenizer{
- Vocabulary: &Vocabulary{Model: "gpt2"},
- Pre: "default",
- Template: "<default template>",
- },
- },
- {
- name: "list chat template",
- fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
- "tokenizer.json": strings.NewReader(`{}`),
- "tokenizer_config.json": strings.NewReader(`{
- "chat_template": [
- {
- "name": "default",
- "template": "<default template>"
- },
- {
- "name": "tools",
- "template": "<tools template>"
- }
- ]
- }`),
- }),
- want: &Tokenizer{
- Vocabulary: &Vocabulary{Model: "gpt2"},
- Pre: "default",
- Template: "<default template>",
- },
- },
- }
- for _, tt := range cases {
- t.Run(tt.name, func(t *testing.T) {
- tokenizer, err := parseTokenizer(tt.fsys, tt.specialTokenTypes)
- if err != nil {
- t.Fatalf("unexpected error: %v", err)
- }
- if diff := cmp.Diff(tt.want, tokenizer); diff != "" {
- t.Errorf("unexpected tokenizer (-want +got):\n%s", diff)
- }
- })
- }
- }
|