Переглянути джерело

separate prompt into template and system

Michael Yang 1 рік тому
батько
коміт
df146c41e2
3 змінених файлів з 113 додано та 85 видалено
  1. 58 43
      parser/parser.go
  2. 53 32
      server/images.go
  3. 2 10
      server/routes.go

+ 58 - 43
parser/parser.go

@@ -2,76 +2,91 @@ package parser
 
 import (
 	"bufio"
+	"bytes"
+	"errors"
 	"fmt"
 	"io"
-	"strings"
 )
 
 type Command struct {
 	Name string
-	Arg  string
+	Args string
+}
+
+func (c *Command) Reset() {
+	c.Name = ""
+	c.Args = ""
 }
 
 func Parse(reader io.Reader) ([]Command, error) {
 	var commands []Command
-	var foundModel bool
+
+	var command, modelCommand Command
 
 	scanner := bufio.NewScanner(reader)
-	multiline := false
-	var multilineCommand *Command
+	scanner.Split(scanModelfile)
 	for scanner.Scan() {
-		line := scanner.Text()
-		if multiline {
-			// If we're in a multiline string and the line is """, end the multiline string.
-			if strings.TrimSpace(line) == `"""` {
-				multiline = false
-				commands = append(commands, *multilineCommand)
-			} else {
-				// Otherwise, append the line to the multiline string.
-				multilineCommand.Arg += "\n" + line
-			}
-			continue
-		}
-		fields := strings.Fields(line)
+		line := scanner.Bytes()
+
+		fields := bytes.SplitN(line, []byte(" "), 2)
 		if len(fields) == 0 {
 			continue
 		}
 
-		command := Command{}
-		switch strings.ToUpper(fields[0]) {
+		switch string(bytes.ToUpper(fields[0])) {
 		case "FROM":
 			command.Name = "model"
-			command.Arg = fields[1]
-			if command.Arg == "" {
-				return nil, fmt.Errorf("no model specified in FROM line")
-			}
-			foundModel = true
-		case "PROMPT", "LICENSE":
-			command.Name = strings.ToLower(fields[0])
-			if fields[1] == `"""` {
-				multiline = true
-				multilineCommand = &command
-				multilineCommand.Arg = ""
-			} else {
-				command.Arg = strings.Join(fields[1:], " ")
-			}
+			command.Args = string(fields[1])
+			// copy command for validation
+			modelCommand = command
+		case "LICENSE", "TEMPLATE", "SYSTEM":
+			command.Name = string(bytes.ToLower(fields[0]))
+			command.Args = string(fields[1])
 		case "PARAMETER":
-			command.Name = fields[1]
-			command.Arg = strings.Join(fields[2:], " ")
+			fields = bytes.SplitN(fields[1], []byte(" "), 2)
+			command.Name = string(fields[0])
+			command.Args = string(fields[1])
 		default:
 			continue
 		}
-		if !multiline {
-			commands = append(commands, command)
-		}
+
+		commands = append(commands, command)
+		command.Reset()
 	}
 
-	if !foundModel {
+	if modelCommand.Args == "" {
 		return nil, fmt.Errorf("no FROM line for the model was specified")
 	}
 
-	if multiline {
-		return nil, fmt.Errorf("unclosed multiline string")
-	}
 	return commands, scanner.Err()
 }
+
+func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) {
+	if atEOF || len(data) == 0 {
+		return 0, nil, nil
+	}
+
+	newline := bytes.IndexByte(data, '\n')
+
+	if start := bytes.Index(data, []byte(`"""`)); start >= 0 && start < newline {
+		end := bytes.Index(data[start+3:], []byte(`"""`))
+		if end < 0 {
+			return 0, nil, errors.New(`unterminated multiline string: """`)
+		}
+
+		n := start + 3 + end + 3
+		return n, bytes.Replace(data[:n], []byte(`"""`), []byte(""), 2), nil
+	}
+
+	if start := bytes.Index(data, []byte(`'''`)); start >= 0 && start < newline {
+		end := bytes.Index(data[start+3:], []byte(`'''`))
+		if end < 0 {
+			return 0, nil, errors.New("unterminated multiline string: '''")
+		}
+
+		n := start + 3 + end + 3
+		return n, bytes.Replace(data[:n], []byte("'''"), []byte(""), 2), nil
+	}
+
+	return bufio.ScanLines(data, atEOF)
+}

+ 53 - 32
server/images.go

@@ -16,6 +16,7 @@ import (
 	"reflect"
 	"strconv"
 	"strings"
+	"text/template"
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/parser"
@@ -24,10 +25,33 @@ import (
 type Model struct {
 	Name      string `json:"name"`
 	ModelPath string
-	Prompt    string
+	Template  string
+	System    string
 	Options   api.Options
 }
 
+func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
+	tmpl, err := template.New("").Parse(m.Template)
+	if err != nil {
+		return "", err
+	}
+
+	var vars struct {
+		System string
+		Prompt string
+	}
+
+	vars.System = m.System
+	vars.Prompt = request.Prompt
+
+	var sb strings.Builder
+	if err := tmpl.Execute(&sb, vars); err != nil {
+		return "", err
+	}
+
+	return sb.String(), nil
+}
+
 type ManifestV2 struct {
 	SchemaVersion int      `json:"schemaVersion"`
 	MediaType     string   `json:"mediaType"`
@@ -71,20 +95,19 @@ func GetManifest(mp ModelPath) (*ManifestV2, error) {
 	if err != nil {
 		return nil, err
 	}
+
 	if _, err = os.Stat(fp); err != nil && !errors.Is(err, os.ErrNotExist) {
 		return nil, fmt.Errorf("couldn't find model '%s'", mp.GetShortTagname())
 	}
 
 	var manifest *ManifestV2
 
-	f, err := os.Open(fp)
+	bts, err := os.ReadFile(fp)
 	if err != nil {
 		return nil, fmt.Errorf("couldn't open file '%s'", fp)
 	}
 
-	decoder := json.NewDecoder(f)
-	err = decoder.Decode(&manifest)
-	if err != nil {
+	if err := json.Unmarshal(bts, &manifest); err != nil {
 		return nil, err
 	}
 
@@ -112,12 +135,20 @@ func GetModel(name string) (*Model, error) {
 		switch layer.MediaType {
 		case "application/vnd.ollama.image.model":
 			model.ModelPath = filename
-		case "application/vnd.ollama.image.prompt":
-			data, err := os.ReadFile(filename)
+		case "application/vnd.ollama.image.template":
+			bts, err := os.ReadFile(filename)
+			if err != nil {
+				return nil, err
+			}
+
+			model.Template = string(bts)
+		case "application/vnd.ollama.image.system":
+			bts, err := os.ReadFile(filename)
 			if err != nil {
 				return nil, err
 			}
-			model.Prompt = string(data)
+
+			model.System = string(bts)
 		case "application/vnd.ollama.image.params":
 			params, err := os.Open(filename)
 			if err != nil {
@@ -156,13 +187,13 @@ func CreateModel(name string, path string, fn func(status string)) error {
 	params := make(map[string]string)
 
 	for _, c := range commands {
-		log.Printf("[%s] - %s\n", c.Name, c.Arg)
+		log.Printf("[%s] - %s\n", c.Name, c.Args)
 		switch c.Name {
 		case "model":
 			fn("looking for model")
-			mf, err := GetManifest(ParseModelPath(c.Arg))
+			mf, err := GetManifest(ParseModelPath(c.Args))
 			if err != nil {
-				fp := c.Arg
+				fp := c.Args
 
 				// If filePath starts with ~/, replace it with the user's home directory.
 				if strings.HasPrefix(fp, "~/") {
@@ -183,7 +214,7 @@ func CreateModel(name string, path string, fn func(status string)) error {
 				fn("creating model layer")
 				file, err := os.Open(fp)
 				if err != nil {
-					fn(fmt.Sprintf("couldn't find model '%s'", c.Arg))
+					fn(fmt.Sprintf("couldn't find model '%s'", c.Args))
 					return fmt.Errorf("failed to open file: %v", err)
 				}
 				defer file.Close()
@@ -206,31 +237,21 @@ func CreateModel(name string, path string, fn func(status string)) error {
 					layers = append(layers, newLayer)
 				}
 			}
-		case "prompt":
-			fn("creating prompt layer")
+		case "license", "template", "system":
+			fn(fmt.Sprintf("creating %s layer", c.Name))
 			// remove the prompt layer if one exists
-			layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.prompt")
+			mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
+			layers = removeLayerFromLayers(layers, mediaType)
 
-			prompt := strings.NewReader(c.Arg)
-			l, err := CreateLayer(prompt)
-			if err != nil {
-				fn(fmt.Sprintf("couldn't create prompt layer: %v", err))
-				return fmt.Errorf("failed to create layer: %v", err)
-			}
-			l.MediaType = "application/vnd.ollama.image.prompt"
-			layers = append(layers, l)
-		case "license":
-			fn("creating license layer")
-			license := strings.NewReader(c.Arg)
-			l, err := CreateLayer(license)
+			layer, err := CreateLayer(strings.NewReader(c.Args))
 			if err != nil {
-				fn(fmt.Sprintf("couldn't create license layer: %v", err))
-				return fmt.Errorf("failed to create layer: %v", err)
+				return err
 			}
-			l.MediaType = "application/vnd.ollama.image.license"
-			layers = append(layers, l)
+
+			layer.MediaType = mediaType
+			layers = append(layers, layer)
 		default:
-			params[c.Name] = c.Arg
+			params[c.Name] = c.Args
 		}
 	}
 

+ 2 - 10
server/routes.go

@@ -9,7 +9,6 @@ import (
 	"os"
 	"path/filepath"
 	"strings"
-	"text/template"
 	"time"
 
 	"dario.cat/mergo"
@@ -54,19 +53,12 @@ func generate(c *gin.Context) {
 		return
 	}
 
-	templ, err := template.New("").Parse(model.Prompt)
+	prompt, err := model.Prompt(req)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
-	var sb strings.Builder
-	if err = templ.Execute(&sb, req); err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-	req.Prompt = sb.String()
-
 	llm, err := llama.New(model.ModelPath, opts)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -77,7 +69,7 @@ func generate(c *gin.Context) {
 	ch := make(chan any)
 	go func() {
 		defer close(ch)
-		llm.Predict(req.Context, req.Prompt, func(r api.GenerateResponse) {
+		llm.Predict(req.Context, prompt, func(r api.GenerateResponse) {
 			r.Model = req.Model
 			r.CreatedAt = time.Now().UTC()
 			if r.Done {