Browse Source

rename templates to template

Michael Yang 10 months ago
parent
commit
58e3fff311

+ 11 - 13
server/images.go

@@ -28,6 +28,7 @@ import (
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/parser"
+	"github.com/ollama/ollama/template"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
@@ -48,12 +49,13 @@ type Model struct {
 	ParentModel    string
 	AdapterPaths   []string
 	ProjectorPaths []string
-	Template       string
 	System         string
 	License        []string
 	Digest         string
 	Options        map[string]interface{}
 	Messages       []Message
+
+	Template *template.Template
 }
 
 func (m *Model) IsEmbedding() bool {
@@ -82,10 +84,10 @@ func (m *Model) String() string {
 		})
 	}
 
-	if m.Template != "" {
+	if m.Template != nil {
 		modelfile.Commands = append(modelfile.Commands, parser.Command{
 			Name: "template",
-			Args: m.Template,
+			Args: m.Template.String(),
 		})
 	}
 
@@ -191,8 +193,7 @@ func GetModel(name string) (*Model, error) {
 		Name:      mp.GetFullTagname(),
 		ShortName: mp.GetShortTagname(),
 		Digest:    digest,
-		Template:  "{{ .Prompt }}",
-		License:   []string{},
+		Template:  template.DefaultTemplate,
 	}
 
 	filename, err := GetBlobsPath(manifest.Config.Digest)
@@ -228,27 +229,24 @@ func GetModel(name string) (*Model, error) {
 			model.AdapterPaths = append(model.AdapterPaths, filename)
 		case "application/vnd.ollama.image.projector":
 			model.ProjectorPaths = append(model.ProjectorPaths, filename)
-		case "application/vnd.ollama.image.template":
+		case "application/vnd.ollama.image.prompt",
+			"application/vnd.ollama.image.template":
 			bts, err := os.ReadFile(filename)
 			if err != nil {
 				return nil, err
 			}
 
-			model.Template = string(bts)
-		case "application/vnd.ollama.image.system":
-			bts, err := os.ReadFile(filename)
+			model.Template, err = template.Parse(string(bts))
 			if err != nil {
 				return nil, err
 			}
-
-			model.System = string(bts)
-		case "application/vnd.ollama.image.prompt":
+		case "application/vnd.ollama.image.system":
 			bts, err := os.ReadFile(filename)
 			if err != nil {
 				return nil, err
 			}
 
-			model.Template = string(bts)
+			model.System = string(bts)
 		case "application/vnd.ollama.image.params":
 			params, err := os.Open(filename)
 			if err != nil {

+ 2 - 2
server/model.go

@@ -16,7 +16,7 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/convert"
 	"github.com/ollama/ollama/llm"
-	"github.com/ollama/ollama/templates"
+	"github.com/ollama/ollama/template"
 	"github.com/ollama/ollama/types/model"
 )
 
@@ -258,7 +258,7 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
 func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
 	for _, layer := range layers {
 		if s := layer.GGML.KV().ChatTemplate(); s != "" {
-			if t, err := templates.NamedTemplate(s); err != nil {
+			if t, err := template.Named(s); err != nil {
 				slog.Debug("template detection", "error", err)
 			} else {
 				tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")

+ 7 - 11
server/prompt.go

@@ -4,10 +4,11 @@ import (
 	"fmt"
 	"log/slog"
 	"strings"
-	"text/template"
+
 	"text/template/parse"
 
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/template"
 )
 
 // isResponseNode checks if the node contains .Response
@@ -53,13 +54,8 @@ func formatTemplateForResponse(tmpl *template.Template, generate bool) {
 
 // Prompt renders a prompt from a template. If generate is set to true,
 // the response and parts of the template following it are not rendered
-func Prompt(tmpl, system, prompt, response string, generate bool) (string, error) {
-	parsed, err := template.New("").Option("missingkey=zero").Parse(tmpl)
-	if err != nil {
-		return "", err
-	}
-
-	formatTemplateForResponse(parsed, generate)
+func Prompt(tmpl *template.Template, system, prompt, response string, generate bool) (string, error) {
+	formatTemplateForResponse(tmpl, generate)
 
 	vars := map[string]any{
 		"System":   system,
@@ -68,14 +64,14 @@ func Prompt(tmpl, system, prompt, response string, generate bool) (string, error
 	}
 
 	var sb strings.Builder
-	if err := parsed.Execute(&sb, vars); err != nil {
+	if err := tmpl.Execute(&sb, vars); err != nil {
 		return "", err
 	}
 
 	return sb.String(), nil
 }
 
-func countTokens(tmpl string, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
+func countTokens(tmpl *template.Template, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
 	rendered, err := Prompt(tmpl, system, prompt, response, false)
 	if err != nil {
 		return 0, err
@@ -91,7 +87,7 @@ func countTokens(tmpl string, system string, prompt string, response string, enc
 }
 
 // ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
-func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
+func ChatPrompt(tmpl *template.Template, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
 	type prompt struct {
 		System   string
 		Prompt   string

+ 13 - 2
server/prompt_test.go

@@ -5,6 +5,7 @@ import (
 	"testing"
 
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/template"
 )
 
 func TestPrompt(t *testing.T) {
@@ -61,7 +62,12 @@ func TestPrompt(t *testing.T) {
 
 	for _, tc := range tests {
 		t.Run(tc.name, func(t *testing.T) {
-			got, err := Prompt(tc.template, tc.system, tc.prompt, tc.response, tc.generate)
+			tmpl, err := template.Parse(tc.template)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			got, err := Prompt(tmpl, tc.system, tc.prompt, tc.response, tc.generate)
 			if err != nil {
 				t.Errorf("error = %v", err)
 			}
@@ -192,7 +198,12 @@ func TestChatPrompt(t *testing.T) {
 
 	for _, tc := range tests {
 		t.Run(tc.name, func(t *testing.T) {
-			got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
+			tmpl, err := template.Parse(tc.template)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			got, err := ChatPrompt(tmpl, tc.messages, tc.window, encode)
 			if err != nil {
 				t.Errorf("error = %v", err)
 			}

+ 20 - 6
server/routes.go

@@ -31,6 +31,7 @@ import (
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/parser"
+	"github.com/ollama/ollama/template"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
@@ -161,6 +162,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
+	tmpl, err := template.Parse(req.Template)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
 	checkpointLoaded := time.Now()
 
 	var prompt string
@@ -169,7 +176,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		prompt = req.Prompt
 	case req.Prompt != "":
 		if req.Template == "" {
-			req.Template = model.Template
+			model.Template, err = template.Parse(req.Template)
+			if err != nil {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
 		}
 
 		if req.System == "" {
@@ -187,7 +198,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 		sb.WriteString(req.Prompt)
 
-		p, err := Prompt(req.Template, req.System, sb.String(), "", true)
+		p, err := Prompt(tmpl, req.System, sb.String(), "", true)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
@@ -242,7 +253,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 
 				if !req.Raw {
-					p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false)
+					p, err := Prompt(tmpl, req.System, req.Prompt, generated.String(), false)
 					if err != nil {
 						c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 						return
@@ -680,7 +691,10 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	}
 
 	if req.Template != "" {
-		m.Template = req.Template
+		m.Template, err = template.Parse(req.Template)
+		if err != nil {
+			return nil, err
+		}
 	}
 
 	msgs := make([]api.Message, 0)
@@ -701,7 +715,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	resp := &api.ShowResponse{
 		License:    strings.Join(m.License, "\n"),
 		System:     m.System,
-		Template:   m.Template,
+		Template:   m.Template.String(),
 		Details:    modelDetails,
 		Messages:   msgs,
 		ModifiedAt: manifest.fi.ModTime(),
@@ -1246,7 +1260,7 @@ func (s *Server) ProcessHandler(c *gin.Context) {
 }
 
 // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
-func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
+func chatPrompt(ctx context.Context, runner *runnerRef, template *template.Template, messages []api.Message, numCtx int) (string, error) {
 	encode := func(s string) ([]int, error) {
 		return runner.llama.Tokenize(ctx, s)
 	}

+ 0 - 0
templates/alfred.gotmpl → template/alfred.gotmpl


+ 0 - 0
templates/alpaca.gotmpl → template/alpaca.gotmpl


+ 0 - 0
templates/chatml.gotmpl → template/chatml.gotmpl


+ 0 - 0
templates/chatqa.gotmpl → template/chatqa.gotmpl


+ 0 - 0
templates/codellama-70b-instruct.gotmpl → template/codellama-70b-instruct.gotmpl


+ 0 - 0
templates/falcon-instruct.gotmpl → template/falcon-instruct.gotmpl


+ 0 - 0
templates/gemma-instruct.gotmpl → template/gemma-instruct.gotmpl


+ 0 - 0
templates/granite-instruct.gotmpl → template/granite-instruct.gotmpl


+ 0 - 0
templates/index.json → template/index.json


+ 0 - 0
templates/llama2-chat.gotmpl → template/llama2-chat.gotmpl


+ 0 - 0
templates/llama3-instruct.gotmpl → template/llama3-instruct.gotmpl


+ 0 - 0
templates/magicoder.gotmpl → template/magicoder.gotmpl


+ 0 - 0
templates/mistral-instruct.gotmpl → template/mistral-instruct.gotmpl


+ 0 - 0
templates/openchat.gotmpl → template/openchat.gotmpl


+ 0 - 0
templates/phi-3.gotmpl → template/phi-3.gotmpl


+ 0 - 0
templates/solar-instruct.gotmpl → template/solar-instruct.gotmpl


+ 0 - 0
templates/starcoder2-instruct.gotmpl → template/starcoder2-instruct.gotmpl


+ 158 - 0
template/template.go

@@ -0,0 +1,158 @@
+package template
+
+import (
+	"bytes"
+	"embed"
+	"encoding/json"
+	"errors"
+	"io"
+	"math"
+	"slices"
+	"strings"
+	"sync"
+	"text/template"
+	"text/template/parse"
+
+	"github.com/agnivade/levenshtein"
+	"golang.org/x/exp/maps"
+)
+
+//go:embed index.json
+var indexBytes []byte
+
+//go:embed *.gotmpl
+var templatesFS embed.FS
+
+var templatesOnce = sync.OnceValues(func() ([]*named, error) {
+	var templates []*named
+	if err := json.Unmarshal(indexBytes, &templates); err != nil {
+		return nil, err
+	}
+
+	for _, t := range templates {
+		bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
+		if err != nil {
+			return nil, err
+		}
+
+		// normalize line endings
+		t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
+	}
+
+	return templates, nil
+})
+
+type named struct {
+	Name     string `json:"name"`
+	Template string `json:"template"`
+	Bytes    []byte
+}
+
+func (t named) Reader() io.Reader {
+	return bytes.NewReader(t.Bytes)
+}
+
+func Named(s string) (*named, error) {
+	templates, err := templatesOnce()
+	if err != nil {
+		return nil, err
+	}
+
+	var template *named
+	score := math.MaxInt
+	for _, t := range templates {
+		if s := levenshtein.ComputeDistance(s, t.Template); s < score {
+			score = s
+			template = t
+		}
+	}
+
+	if score < 100 {
+		return template, nil
+	}
+
+	return nil, errors.New("no matching template found")
+}
+
+type Template struct {
+	*template.Template
+	raw string
+}
+
+func (t *Template) String() string {
+	return t.raw
+}
+
+var DefaultTemplate, _ = Parse("{{ .Prompt }}")
+
+func Parse(s string) (*Template, error) {
+	t, err := template.New("").Option("missingkey=zero").Parse(s)
+	if err != nil {
+		return nil, err
+	}
+
+	return &Template{Template: t, raw: s}, nil
+}
+
+func (t *Template) Vars() []string {
+	var vars []string
+	for _, n := range t.Tree.Root.Nodes {
+		vars = append(vars, parseNode(n)...)
+	}
+
+	set := make(map[string]struct{})
+	for _, n := range vars {
+		set[strings.ToLower(n)] = struct{}{}
+	}
+
+	vars = maps.Keys(set)
+	slices.Sort(vars)
+	return vars
+}
+
+func parseNode(n parse.Node) []string {
+	switch n := n.(type) {
+	case *parse.ActionNode:
+		return parseNode(n.Pipe)
+	case *parse.IfNode:
+		names := parseNode(n.Pipe)
+		names = append(names, parseNode(n.List)...)
+		if n.ElseList != nil {
+			names = append(names, parseNode(n.ElseList)...)
+		}
+		return names
+	case *parse.RangeNode:
+		names := parseNode(n.Pipe)
+		names = append(names, parseNode(n.List)...)
+		if n.ElseList != nil {
+			names = append(names, parseNode(n.ElseList)...)
+		}
+		return names
+	case *parse.WithNode:
+		names := parseNode(n.Pipe)
+		names = append(names, parseNode(n.List)...)
+		if n.ElseList != nil {
+			names = append(names, parseNode(n.ElseList)...)
+		}
+		return names
+	case *parse.PipeNode:
+		var names []string
+		for _, c := range n.Cmds {
+			for _, a := range c.Args {
+				names = append(names, parseNode(a)...)
+			}
+		}
+		return names
+	case *parse.ListNode:
+		var names []string
+		for _, n := range n.Nodes {
+			names = append(names, parseNode(n)...)
+		}
+
+		return names
+	case *parse.FieldNode:
+		return n.Ident
+	}
+
+	return nil
+}

+ 89 - 0
template/template_test.go

@@ -0,0 +1,89 @@
+package template
+
+import (
+	"bufio"
+	"bytes"
+	"encoding/json"
+	"io"
+	"os"
+	"path/filepath"
+	"slices"
+	"testing"
+	"text/template"
+
+	"github.com/ollama/ollama/llm"
+)
+
+func TestNamed(t *testing.T) {
+	f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer f.Close()
+
+	scanner := bufio.NewScanner(f)
+	for scanner.Scan() {
+		var ss map[string]string
+		if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
+			t.Fatal(err)
+		}
+
+		for k, v := range ss {
+			t.Run(k, func(t *testing.T) {
+				kv := llm.KV{"tokenizer.chat_template": v}
+				s := kv.ChatTemplate()
+				r, err := Named(s)
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				if r.Name != k {
+					t.Errorf("expected %q, got %q", k, r.Name)
+				}
+
+				var b bytes.Buffer
+				if _, err := io.Copy(&b, r.Reader()); err != nil {
+					t.Fatal(err)
+				}
+
+				tmpl, err := template.New(s).Parse(b.String())
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				if tmpl.Tree.Root.String() == "" {
+					t.Errorf("empty %s template", k)
+				}
+			})
+		}
+	}
+}
+
+func TestParse(t *testing.T) {
+	cases := []struct {
+		template     string
+		capabilities []string
+	}{
+		{"{{ .Prompt }}", []string{"prompt"}},
+		{"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}},
+		{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
+		{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "system", "tools"}},
+		{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
+		{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
+		{"{{ .Prompt }} {{ .Suffix }}", []string{"prompt", "suffix"}},
+	}
+
+	for _, tt := range cases {
+		t.Run("", func(t *testing.T) {
+			tmpl, err := Parse(tt.template)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			vars := tmpl.Vars()
+			if !slices.Equal(tt.capabilities, vars) {
+				t.Errorf("expected %v, got %v", tt.capabilities, vars)
+			}
+		})
+	}
+}

+ 0 - 0
templates/testdata/templates.jsonl → template/testdata/templates.jsonl


+ 0 - 0
templates/vicuna.gotmpl → template/vicuna.gotmpl


+ 0 - 0
templates/zephyr.gotmpl → template/zephyr.gotmpl


+ 0 - 70
templates/template.go

@@ -1,70 +0,0 @@
-package templates
-
-import (
-	"bytes"
-	"embed"
-	"encoding/json"
-	"errors"
-	"io"
-	"math"
-	"sync"
-
-	"github.com/agnivade/levenshtein"
-)
-
-//go:embed index.json
-var indexBytes []byte
-
-//go:embed *.gotmpl
-var templatesFS embed.FS
-
-var templatesOnce = sync.OnceValues(func() ([]*Template, error) {
-	var templates []*Template
-	if err := json.Unmarshal(indexBytes, &templates); err != nil {
-		return nil, err
-	}
-
-	for _, t := range templates {
-		bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
-		if err != nil {
-			return nil, err
-		}
-
-		// normalize line endings
-		t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
-	}
-
-	return templates, nil
-})
-
-type Template struct {
-	Name     string `json:"name"`
-	Template string `json:"template"`
-	Bytes []byte
-}
-
-func (t Template) Reader() io.Reader {
-	return bytes.NewReader(t.Bytes)
-}
-
-func NamedTemplate(s string) (*Template, error) {
-	templates, err := templatesOnce()
-	if err != nil {
-		return nil, err
-	}
-
-	var template *Template
-	score := math.MaxInt
-	for _, t := range templates {
-		if s := levenshtein.ComputeDistance(s, t.Template); s < score {
-			score = s
-			template = t
-		}
-	}
-
-	if score < 100 {
-		return template, nil
-	}
-
-	return nil, errors.New("no matching template found")
-}

+ 0 - 59
templates/template_test.go

@@ -1,59 +0,0 @@
-package templates
-
-import (
-	"bufio"
-	"bytes"
-	"encoding/json"
-	"io"
-	"os"
-	"path/filepath"
-	"testing"
-	"text/template"
-
-	"github.com/ollama/ollama/llm"
-)
-
-func TestKVChatTemplate(t *testing.T) {
-	f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer f.Close()
-
-	scanner := bufio.NewScanner(f)
-	for scanner.Scan() {
-		var ss map[string]string
-		if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
-			t.Fatal(err)
-		}
-
-		for k, v := range ss {
-			t.Run(k, func(t *testing.T) {
-				kv := llm.KV{"tokenizer.chat_template": v}
-				s := kv.ChatTemplate()
-				r, err := NamedTemplate(s)
-				if err != nil {
-					t.Fatal(err)
-				}
-
-				if r.Name != k {
-					t.Errorf("expected %q, got %q", k, r.Name)
-				}
-
-				var b bytes.Buffer
-				if _, err := io.Copy(&b, r.Reader()); err != nil {
-					t.Fatal(err)
-				}
-
-				tmpl, err := template.New(s).Parse(b.String())
-				if err != nil {
-					t.Fatal(err)
-				}
-
-				if tmpl.Tree.Root.String() == "" {
-					t.Errorf("empty %s template", k)
-				}
-			})
-		}
-	}
-}