
chat api (#991)

- update chat docs
- add messages chat endpoint
- remove deprecated context and template generate parameters from docs
- context and template are still supported for the time being and will continue to work as expected
- add partial response to chat history
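For reference, a minimal sketch of driving the new `/api/chat` endpoint through the Go client added in this change (model name and question are placeholders; assumes the `api` package builds as shown in the diff below):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/jmorganca/ollama/api"
)

func main() {
	// Build a client from the environment (OLLAMA_HOST or the default address),
	// the same way the CLI does.
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.ChatRequest{
		Model: "llama2", // placeholder model name
		Messages: []api.Message{
			{Role: "user", Content: "why is the sky blue?"},
		},
	}

	// Chat streams partial responses; each chunk carries part of the assistant message.
	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		if resp.Message != nil {
			fmt.Print(resp.Message.Content)
		}
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println()
}
```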
Bruce MacDonald 1 year ago
parent
commit
7a0899d62d
9 changed files with 665 additions and 254 deletions
  1. api/client.go (+13 -0)
  2. api/types.go (+52 -22)
  3. cmd/cmd.go (+115 -119)
  4. docs/api.md (+142 -10)
  5. llm/llama.go (+34 -24)
  6. llm/llm.go (+1 -1)
  7. server/images.go (+64 -19)
  8. server/images_test.go (+4 -6)
  9. server/routes.go (+240 -53)

+ 13 - 0
api/client.go

@@ -221,6 +221,19 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate
 	})
 }
 
+type ChatResponseFunc func(ChatResponse) error
+
+func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
+	return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error {
+		var resp ChatResponse
+		if err := json.Unmarshal(bts, &resp); err != nil {
+			return err
+		}
+
+		return fn(resp)
+	})
+}
+
 type PullProgressFunc func(ProgressResponse) error
 
 func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {

+ 52 - 22
api/types.go

@@ -36,7 +36,7 @@ type GenerateRequest struct {
 	Prompt   string `json:"prompt"`
 	System   string `json:"system"`
 	Template string `json:"template"`
-	Context  []int  `json:"context,omitempty"`
+	Context  []int  `json:"context,omitempty"` // DEPRECATED: context is deprecated, use the /chat endpoint instead for chat history
 	Stream   *bool  `json:"stream,omitempty"`
 	Raw      bool   `json:"raw,omitempty"`
 	Format   string `json:"format"`
@@ -44,6 +44,41 @@ type GenerateRequest struct {
 	Options map[string]interface{} `json:"options"`
 }
 
+type ChatRequest struct {
+	Model    string    `json:"model"`
+	Messages []Message `json:"messages"`
+	Template string    `json:"template"`
+	Stream   *bool     `json:"stream,omitempty"`
+	Format   string    `json:"format"`
+
+	Options map[string]interface{} `json:"options"`
+}
+
+type Message struct {
+	Role    string `json:"role"` // one of ["system", "user", "assistant"]
+	Content string `json:"content"`
+}
+
+type ChatResponse struct {
+	Model     string    `json:"model"`
+	CreatedAt time.Time `json:"created_at"`
+	Message   *Message  `json:"message,omitempty"`
+
+	Done    bool  `json:"done"`
+	Context []int `json:"context,omitempty"`
+
+	EvalMetrics
+}
+
+type EvalMetrics struct {
+	TotalDuration      time.Duration `json:"total_duration,omitempty"`
+	LoadDuration       time.Duration `json:"load_duration,omitempty"`
+	PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
+	PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
+	EvalCount          int           `json:"eval_count,omitempty"`
+	EvalDuration       time.Duration `json:"eval_duration,omitempty"`
+}
+
 // Options specfied in GenerateRequest, if you add a new option here add it to the API docs also
 type Options struct {
 	Runner
@@ -173,39 +208,34 @@ type GenerateResponse struct {
 	Done    bool  `json:"done"`
 	Context []int `json:"context,omitempty"`
 
-	TotalDuration      time.Duration `json:"total_duration,omitempty"`
-	LoadDuration       time.Duration `json:"load_duration,omitempty"`
-	PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
-	PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
-	EvalCount          int           `json:"eval_count,omitempty"`
-	EvalDuration       time.Duration `json:"eval_duration,omitempty"`
+	EvalMetrics
 }
 
-func (r *GenerateResponse) Summary() {
-	if r.TotalDuration > 0 {
-		fmt.Fprintf(os.Stderr, "total duration:       %v\n", r.TotalDuration)
+func (m *EvalMetrics) Summary() {
+	if m.TotalDuration > 0 {
+		fmt.Fprintf(os.Stderr, "total duration:       %v\n", m.TotalDuration)
 	}
 
-	if r.LoadDuration > 0 {
-		fmt.Fprintf(os.Stderr, "load duration:        %v\n", r.LoadDuration)
+	if m.LoadDuration > 0 {
+		fmt.Fprintf(os.Stderr, "load duration:        %v\n", m.LoadDuration)
 	}
 
-	if r.PromptEvalCount > 0 {
-		fmt.Fprintf(os.Stderr, "prompt eval count:    %d token(s)\n", r.PromptEvalCount)
+	if m.PromptEvalCount > 0 {
+		fmt.Fprintf(os.Stderr, "prompt eval count:    %d token(s)\n", m.PromptEvalCount)
 	}
 
-	if r.PromptEvalDuration > 0 {
-		fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", r.PromptEvalDuration)
-		fmt.Fprintf(os.Stderr, "prompt eval rate:     %.2f tokens/s\n", float64(r.PromptEvalCount)/r.PromptEvalDuration.Seconds())
+	if m.PromptEvalDuration > 0 {
+		fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", m.PromptEvalDuration)
+		fmt.Fprintf(os.Stderr, "prompt eval rate:     %.2f tokens/s\n", float64(m.PromptEvalCount)/m.PromptEvalDuration.Seconds())
 	}
 
-	if r.EvalCount > 0 {
-		fmt.Fprintf(os.Stderr, "eval count:           %d token(s)\n", r.EvalCount)
+	if m.EvalCount > 0 {
+		fmt.Fprintf(os.Stderr, "eval count:           %d token(s)\n", m.EvalCount)
 	}
 
-	if r.EvalDuration > 0 {
-		fmt.Fprintf(os.Stderr, "eval duration:        %s\n", r.EvalDuration)
-		fmt.Fprintf(os.Stderr, "eval rate:            %.2f tokens/s\n", float64(r.EvalCount)/r.EvalDuration.Seconds())
+	if m.EvalDuration > 0 {
+		fmt.Fprintf(os.Stderr, "eval duration:        %s\n", m.EvalDuration)
+		fmt.Fprintf(os.Stderr, "eval rate:            %.2f tokens/s\n", float64(m.EvalCount)/m.EvalDuration.Seconds())
 	}
 }
 

+ 115 - 119
cmd/cmd.go

@@ -159,7 +159,54 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
-	return RunGenerate(cmd, args)
+	interactive := true
+
+	opts := runOptions{
+		Model:    name,
+		WordWrap: os.Getenv("TERM") == "xterm-256color",
+		Options:  map[string]interface{}{},
+	}
+
+	format, err := cmd.Flags().GetString("format")
+	if err != nil {
+		return err
+	}
+	opts.Format = format
+
+	prompts := args[1:]
+
+	// prepend stdin to the prompt if provided
+	if !term.IsTerminal(int(os.Stdin.Fd())) {
+		in, err := io.ReadAll(os.Stdin)
+		if err != nil {
+			return err
+		}
+
+		prompts = append([]string{string(in)}, prompts...)
+		opts.WordWrap = false
+		interactive = false
+	}
+	msg := api.Message{
+		Role:    "user",
+		Content: strings.Join(prompts, " "),
+	}
+	opts.Messages = append(opts.Messages, msg)
+	if len(prompts) > 0 {
+		interactive = false
+	}
+
+	nowrap, err := cmd.Flags().GetBool("nowordwrap")
+	if err != nil {
+		return err
+	}
+	opts.WordWrap = !nowrap
+
+	if !interactive {
+		_, err := chat(cmd, opts)
+		return err
+	}
+
+	return chatInteractive(cmd, opts)
 }
 
 func PushHandler(cmd *cobra.Command, args []string) error {
@@ -411,83 +458,26 @@ func PullHandler(cmd *cobra.Command, args []string) error {
 	return nil
 }
 
-func RunGenerate(cmd *cobra.Command, args []string) error {
-	interactive := true
-
-	opts := generateOptions{
-		Model:    args[0],
-		WordWrap: os.Getenv("TERM") == "xterm-256color",
-		Options:  map[string]interface{}{},
-	}
-
-	format, err := cmd.Flags().GetString("format")
-	if err != nil {
-		return err
-	}
-	opts.Format = format
-
-	prompts := args[1:]
-
-	// prepend stdin to the prompt if provided
-	if !term.IsTerminal(int(os.Stdin.Fd())) {
-		in, err := io.ReadAll(os.Stdin)
-		if err != nil {
-			return err
-		}
-
-		prompts = append([]string{string(in)}, prompts...)
-		opts.WordWrap = false
-		interactive = false
-	}
-	opts.Prompt = strings.Join(prompts, " ")
-	if len(prompts) > 0 {
-		interactive = false
-	}
-
-	nowrap, err := cmd.Flags().GetBool("nowordwrap")
-	if err != nil {
-		return err
-	}
-	opts.WordWrap = !nowrap
-
-	if !interactive {
-		return generate(cmd, opts)
-	}
-
-	return generateInteractive(cmd, opts)
-}
-
-type generateContextKey string
-
-type generateOptions struct {
+type runOptions struct {
 	Model    string
-	Prompt   string
+	Messages []api.Message
 	WordWrap bool
 	Format   string
-	System   string
 	Template string
 	Options  map[string]interface{}
 }
 
-func generate(cmd *cobra.Command, opts generateOptions) error {
+func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
-		return err
+		return nil, err
 	}
 
 	p := progress.NewProgress(os.Stderr)
 	defer p.StopAndClear()
-
 	spinner := progress.NewSpinner("")
 	p.Add("", spinner)
 
-	var latest api.GenerateResponse
-
-	generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
-	if !ok {
-		generateContext = []int{}
-	}
-
 	termWidth, _, err := term.GetSize(int(os.Stdout.Fd()))
 	if err != nil {
 		opts.WordWrap = false
@@ -506,24 +496,24 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
 
 	var currentLineLength int
 	var wordBuffer string
+	var latest api.ChatResponse
+	var fullResponse strings.Builder
+	var role string
 
-	request := api.GenerateRequest{
-		Model:    opts.Model,
-		Prompt:   opts.Prompt,
-		Context:  generateContext,
-		Format:   opts.Format,
-		System:   opts.System,
-		Template: opts.Template,
-		Options:  opts.Options,
-	}
-	fn := func(response api.GenerateResponse) error {
+	fn := func(response api.ChatResponse) error {
 		p.StopAndClear()
-
 		latest = response
+		if response.Message == nil {
+			// warm-up response or done
+			return nil
+		}
+		role = response.Message.Role
+		content := response.Message.Content
+		fullResponse.WriteString(content)
 
 		termWidth, _, _ = term.GetSize(int(os.Stdout.Fd()))
 		if opts.WordWrap && termWidth >= 10 {
-			for _, ch := range response.Response {
+			for _, ch := range content {
 				if currentLineLength+1 > termWidth-5 {
 					if len(wordBuffer) > termWidth-10 {
 						fmt.Printf("%s%c", wordBuffer, ch)
@@ -551,7 +541,7 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
 				}
 			}
 		} else {
-			fmt.Printf("%s%s", wordBuffer, response.Response)
+			fmt.Printf("%s%s", wordBuffer, content)
 			if len(wordBuffer) > 0 {
 				wordBuffer = ""
 			}
@@ -560,35 +550,35 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
 		return nil
 	}
 
-	if err := client.Generate(cancelCtx, &request, fn); err != nil {
+	req := &api.ChatRequest{
+		Model:    opts.Model,
+		Messages: opts.Messages,
+		Format:   opts.Format,
+		Template: opts.Template,
+		Options:  opts.Options,
+	}
+	if err := client.Chat(cancelCtx, req, fn); err != nil {
 		if errors.Is(err, context.Canceled) {
-			return nil
+			return nil, nil
 		}
-		return err
+		return nil, err
 	}
-	if opts.Prompt != "" {
+
+	if len(opts.Messages) > 0 {
 		fmt.Println()
 		fmt.Println()
 	}
 
-	if !latest.Done {
-		return nil
-	}
-
 	verbose, err := cmd.Flags().GetBool("verbose")
 	if err != nil {
-		return err
+		return nil, err
 	}
 
 	if verbose {
 		latest.Summary()
 	}
 
-	ctx := cmd.Context()
-	ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
-	cmd.SetContext(ctx)
-
-	return nil
+	return &api.Message{Role: role, Content: fullResponse.String()}, nil
 }
 
 type MultilineState int
@@ -600,13 +590,10 @@ const (
 	MultilineTemplate
 )
 
-func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
+func chatInteractive(cmd *cobra.Command, opts runOptions) error {
 	// load the model
-	loadOpts := generateOptions{
-		Model:  opts.Model,
-		Prompt: "",
-	}
-	if err := generate(cmd, loadOpts); err != nil {
+	loadOpts := runOptions{Model: opts.Model}
+	if _, err := chat(cmd, loadOpts); err != nil {
 		return err
 	}
 
@@ -677,7 +664,9 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 	defer fmt.Printf(readline.EndBracketedPaste)
 
 	var multiline MultilineState
-	var prompt string
+	var content string
+	var systemContent string
+	opts.Messages = make([]api.Message, 0)
 
 	for {
 		line, err := scanner.Readline()
@@ -691,7 +680,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 			}
 
 			scanner.Prompt.UseAlt = false
-			prompt = ""
+			content = ""
 
 			continue
 		case err != nil:
@@ -699,37 +688,37 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 		}
 
 		switch {
-		case strings.HasPrefix(prompt, `"""`):
+		case strings.HasPrefix(content, `"""`):
 			// if the prompt so far starts with """ then we're in multiline mode
 			// and we need to keep reading until we find a line that ends with """
 			cut, found := strings.CutSuffix(line, `"""`)
-			prompt += cut + "\n"
+			content += cut + "\n"
 
 			if !found {
 				continue
 			}
 
-			prompt = strings.TrimPrefix(prompt, `"""`)
+			content = strings.TrimPrefix(content, `"""`)
 			scanner.Prompt.UseAlt = false
 
 			switch multiline {
 			case MultilineSystem:
-				opts.System = prompt
-				prompt = ""
+				systemContent = content
+				content = ""
 				fmt.Println("Set system template.\n")
 			case MultilineTemplate:
-				opts.Template = prompt
-				prompt = ""
+				opts.Template = content
+				content = ""
 				fmt.Println("Set model template.\n")
 			}
 			multiline = MultilineNone
-		case strings.HasPrefix(line, `"""`) && len(prompt) == 0:
+		case strings.HasPrefix(line, `"""`) && len(content) == 0:
 			scanner.Prompt.UseAlt = true
 			multiline = MultilinePrompt
-			prompt += line + "\n"
+			content += line + "\n"
 			continue
 		case scanner.Pasting:
-			prompt += line + "\n"
+			content += line + "\n"
 			continue
 		case strings.HasPrefix(line, "/list"):
 			args := strings.Fields(line)
@@ -791,17 +780,17 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 					line = strings.TrimPrefix(line, `"""`)
 					if strings.HasPrefix(args[2], `"""`) {
 						cut, found := strings.CutSuffix(line, `"""`)
-						prompt += cut + "\n"
+						content += cut + "\n"
 						if found {
-							opts.System = prompt
+							systemContent = content
 							if args[1] == "system" {
 								fmt.Println("Set system template.\n")
 							} else {
 								fmt.Println("Set prompt template.\n")
 							}
-							prompt = ""
+							content = ""
 						} else {
-							prompt = `"""` + prompt
+							content = `"""` + content
 							if args[1] == "system" {
 								multiline = MultilineSystem
 							} else {
@@ -810,7 +799,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 							scanner.Prompt.UseAlt = true
 						}
 					} else {
-						opts.System = line
+						systemContent = line
 						fmt.Println("Set system template.\n")
 					}
 				default:
@@ -858,8 +847,8 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 					}
 				case "system":
 					switch {
-					case opts.System != "":
-						fmt.Println(opts.System + "\n")
+					case systemContent != "":
+						fmt.Println(systemContent + "\n")
 					case resp.System != "":
 						fmt.Println(resp.System + "\n")
 					default:
@@ -899,16 +888,23 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 			fmt.Printf("Unknown command '%s'. Type /? for help\n", args[0])
 			continue
 		default:
-			prompt += line
+			content += line
 		}
 
-		if len(prompt) > 0 && multiline == MultilineNone {
-			opts.Prompt = prompt
-			if err := generate(cmd, opts); err != nil {
+		if len(content) > 0 && multiline == MultilineNone {
+			if systemContent != "" {
+				opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: systemContent})
+			}
+			opts.Messages = append(opts.Messages, api.Message{Role: "user", Content: content})
+			assistant, err := chat(cmd, opts)
+			if err != nil {
 				return err
 			}
+			if assistant != nil {
+				opts.Messages = append(opts.Messages, *assistant)
+			}
 
-			prompt = ""
+			content = ""
 		}
 	}
 }

+ 142 - 10
docs/api.md

@@ -24,7 +24,7 @@ All durations are returned in nanoseconds.
 
 ### Streaming responses
 
-Certain endpoints stream responses as JSON objects delineated with the newline (`\n`) character.
+Certain endpoints stream responses as JSON objects.
 
 ## Generate a completion
 
@@ -32,10 +32,12 @@ Certain endpoints stream responses as JSON objects delineated with the newline (
 POST /api/generate
 ```
 
-Generate a response for a given prompt with a provided model. This is a streaming endpoint, so will be a series of responses. The final response object will include statistics and additional data from the request.
+Generate a response for a given prompt with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request.
 
 ### Parameters
 
+`model` is required.
+
 - `model`: (required) the [model name](#model-names)
 - `prompt`: the prompt to generate a response for
 
@@ -43,11 +45,10 @@ Advanced parameters (optional):
 
 - `format`: the format to return a response in. Currently the only accepted value is `json`
 - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
-- `system`: system prompt to (overrides what is defined in the `Modelfile`)
 - `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
-- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
+- `system`: system prompt to (overrides what is defined in the `Modelfile`)
 - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
-- `raw`: if `true` no formatting will be applied to the prompt and no context will be returned. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API, and are managing history yourself.
+- `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API.
 
 ### JSON mode
 
@@ -57,7 +58,7 @@ Enable JSON mode by setting the `format` parameter to `json`. This will structur
 
 ### Examples
 
-#### Request
+#### Request (Prompt)
 
 ```shell
 curl http://localhost:11434/api/generate -d '{
@@ -89,7 +90,7 @@ The final response in the stream also includes additional data about the generat
 - `prompt_eval_duration`: time spent in nanoseconds evaluating the prompt
 - `eval_count`: number of tokens in the response
 - `eval_duration`: time in nanoseconds spent generating the response
-- `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory
+- `context`: deprecated, an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory
 - `response`: empty if the response was streamed, if not streamed, this will contain the full response
 
 To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration`.
@@ -114,6 +115,8 @@ To calculate how fast the response is generated in tokens per second (token/s),
 
 #### Request (No streaming)
 
+A response can be received in one reply when streaming is off.
+
 ```shell
 curl http://localhost:11434/api/generate -d '{
   "model": "llama2",
@@ -144,9 +147,9 @@ If `stream` is set to `false`, the response will be a single JSON object:
 }
 ```
 
-#### Request (Raw mode)
+#### Request (Raw Mode)
 
-In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting and context.
+In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting.
 
 ```shell
 curl http://localhost:11434/api/generate -d '{
@@ -164,6 +167,7 @@ curl http://localhost:11434/api/generate -d '{
   "model": "mistral",
   "created_at": "2023-11-03T15:36:02.583064Z",
   "response": " The sky appears blue because of a phenomenon called Rayleigh scattering.",
+  "context": [1, 2, 3],
   "done": true,
   "total_duration": 14648695333,
   "load_duration": 3302671417,
@@ -275,7 +279,6 @@ curl http://localhost:11434/api/generate -d '{
   "model": "llama2",
   "created_at": "2023-08-04T19:22:45.499127Z",
   "response": "The sky is blue because it is the color of the sky.",
-  "context": [1, 2, 3],
   "done": true,
   "total_duration": 5589157167,
   "load_duration": 3013701500,
@@ -288,6 +291,135 @@ curl http://localhost:11434/api/generate -d '{
 }
 ```
 
+## Send Chat Messages
+```shell
+POST /api/chat
+```
+
+Generate the next message in a chat with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request.
+
+### Parameters
+
+`model` is required.
+
+- `model`: (required) the [model name](#model-names)
+- `messages`: the messages of the chat, this can be used to keep a chat memory
+
+Advanced parameters (optional):
+
+- `format`: the format to return a response in. Currently the only accepted value is `json`
+- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
+- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
+- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
+
+### Examples
+
+#### Request
+Send a chat message with a streaming response.
+
+```shell
+curl http://localhost:11434/api/chat -d '{
+  "model": "llama2",
+  "messages": [
+    {
+      "role": "user",
+      "content": "why is the sky blue?"
+    }
+  ]
+}'
+```
+
+#### Response
+
+A stream of JSON objects is returned:
+
+```json
+{
+  "model": "llama2",
+  "created_at": "2023-08-04T08:52:19.385406455-07:00",
+  "message": {
+    "role": "assistant",
+    "content": "The"
+  },
+  "done": false
+}
+```
+
+Final response:
+
+```json
+{
+  "model": "llama2",
+  "created_at": "2023-08-04T19:22:45.499127Z",
+  "done": true,
+  "total_duration": 5589157167,
+  "load_duration": 3013701500,
+  "sample_count": 114,
+  "sample_duration": 81442000,
+  "prompt_eval_count": 46,
+  "prompt_eval_duration": 1160282000,
+  "eval_count": 113,
+  "eval_duration": 1325948000
+}
+```
+
+#### Request (With History)
+Send a chat message with a conversation history.
+
+```shell
+curl http://localhost:11434/api/chat -d '{
+  "model": "llama2",
+  "messages": [
+    {
+      "role": "user",
+      "content": "why is the sky blue?"
+    },
+    {
+      "role": "assistant",
+      "content": "due to rayleigh scattering."
+    },
+    {
+      "role": "user",
+      "content": "how is that different than mie scattering?"
+    }
+  ]
+}'
+```
+
+#### Response
+
+A stream of JSON objects is returned:
+
+```json
+{
+  "model": "llama2",
+  "created_at": "2023-08-04T08:52:19.385406455-07:00",
+  "message": {
+    "role": "assistant",
+    "content": "The"
+  },
+  "done": false
+}
+```
+
+Final response:
+
+```json
+{
+  "model": "llama2",
+  "created_at": "2023-08-04T19:22:45.499127Z",
+  "done": true,
+  "total_duration": 5589157167,
+  "load_duration": 3013701500,
+  "sample_count": 114,
+  "sample_duration": 81442000,
+  "prompt_eval_count": 46,
+  "prompt_eval_duration": 1160282000,
+  "eval_count": 113,
+  "eval_duration": 1325948000
+}
+```
+
 ## Create a Model
 
 ```shell

+ 34 - 24
llm/llama.go

@@ -531,21 +531,31 @@ type prediction struct {
 
 const maxBufferSize = 512 * format.KiloByte
 
-func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, format string, fn func(api.GenerateResponse)) error {
-	prevConvo, err := llm.Decode(ctx, prevContext)
-	if err != nil {
-		return err
-	}
-
-	// Remove leading spaces from prevConvo if present
-	prevConvo = strings.TrimPrefix(prevConvo, " ")
-
-	var nextContext strings.Builder
-	nextContext.WriteString(prevConvo)
-	nextContext.WriteString(prompt)
-
+type PredictRequest struct {
+	Model            string
+	Prompt           string
+	Format           string
+	CheckpointStart  time.Time
+	CheckpointLoaded time.Time
+}
+
+type PredictResponse struct {
+	Model              string
+	CreatedAt          time.Time
+	TotalDuration      time.Duration
+	LoadDuration       time.Duration
+	Content            string
+	Done               bool
+	PromptEvalCount    int
+	PromptEvalDuration time.Duration
+	EvalCount          int
+	EvalDuration       time.Duration
+	Context            []int
+}
+
+func (llm *llama) Predict(ctx context.Context, predict PredictRequest, fn func(PredictResponse)) error {
 	request := map[string]any{
-		"prompt":            nextContext.String(),
+		"prompt":            predict.Prompt,
 		"stream":            true,
 		"n_predict":         llm.NumPredict,
 		"n_keep":            llm.NumKeep,
@@ -567,7 +577,7 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 		"stop":              llm.Stop,
 	}
 
-	if format == "json" {
+	if predict.Format == "json" {
 		request["grammar"] = jsonGrammar
 	}
 
@@ -624,25 +634,25 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 				}
 
 				if p.Content != "" {
-					fn(api.GenerateResponse{Response: p.Content})
-					nextContext.WriteString(p.Content)
+					fn(PredictResponse{
+						Model:     predict.Model,
+						CreatedAt: time.Now().UTC(),
+						Content:   p.Content,
+					})
 				}
 
 				if p.Stop {
-					embd, err := llm.Encode(ctx, nextContext.String())
-					if err != nil {
-						return fmt.Errorf("encoding context: %v", err)
-					}
+					fn(PredictResponse{
+						Model:         predict.Model,
+						CreatedAt:     time.Now().UTC(),
+						TotalDuration: time.Since(predict.CheckpointStart),
 
-					fn(api.GenerateResponse{
 						Done:               true,
-						Context:            embd,
 						PromptEvalCount:    p.Timings.PromptN,
 						PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
 						EvalCount:          p.Timings.PredictedN,
 						EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
 					})
-
 					return nil
 				}
 			}

+ 1 - 1
llm/llm.go

@@ -14,7 +14,7 @@ import (
 )
 
 type LLM interface {
-	Predict(context.Context, []int, string, string, func(api.GenerateResponse)) error
+	Predict(context.Context, PredictRequest, func(PredictResponse)) error
 	Embedding(context.Context, string) ([]float64, error)
 	Encode(context.Context, string) ([]int, error)
 	Decode(context.Context, []int) (string, error)

+ 64 - 19
server/images.go

@@ -47,37 +47,82 @@ type Model struct {
 	Options       map[string]interface{}
 }
 
-func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
-	t := m.Template
-	if request.Template != "" {
-		t = request.Template
-	}
+type PromptVars struct {
+	System   string
+	Prompt   string
+	Response string
+}
 
-	tmpl, err := template.New("").Parse(t)
+func (m *Model) Prompt(p PromptVars) (string, error) {
+	var prompt strings.Builder
+	tmpl, err := template.New("").Parse(m.Template)
 	if err != nil {
 		return "", err
 	}
 
-	var vars struct {
-		First  bool
-		System string
-		Prompt string
+	if p.System == "" {
+		// use the default system prompt for this model if one is not specified
+		p.System = m.System
+	}
+
+	var sb strings.Builder
+	if err := tmpl.Execute(&sb, p); err != nil {
+		return "", err
 	}
+	prompt.WriteString(sb.String())
+	prompt.WriteString(p.Response)
+	return prompt.String(), nil
+}
 
-	vars.First = len(request.Context) == 0
-	vars.System = m.System
-	vars.Prompt = request.Prompt
+func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
+	// build the prompt from the list of messages
+	var prompt strings.Builder
+	currentVars := PromptVars{}
 
-	if request.System != "" {
-		vars.System = request.System
+	writePrompt := func() error {
+		p, err := m.Prompt(currentVars)
+		if err != nil {
+			return err
+		}
+		prompt.WriteString(p)
+		currentVars = PromptVars{}
+		return nil
 	}
 
-	var sb strings.Builder
-	if err := tmpl.Execute(&sb, vars); err != nil {
-		return "", err
+	for _, msg := range msgs {
+		switch msg.Role {
+		case "system":
+			if currentVars.Prompt != "" || currentVars.System != "" {
+				if err := writePrompt(); err != nil {
+					return "", err
+				}
+			}
+			currentVars.System = msg.Content
+		case "user":
+			if currentVars.Prompt != "" || currentVars.System != "" {
+				if err := writePrompt(); err != nil {
+					return "", err
+				}
+			}
+			currentVars.Prompt = msg.Content
+		case "assistant":
+			currentVars.Response = msg.Content
+			if err := writePrompt(); err != nil {
+				return "", err
+			}
+		default:
+			return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
+		}
+	}
+
+	// Append the last set of vars if they are non-empty
+	if currentVars.Prompt != "" || currentVars.System != "" {
+		if err := writePrompt(); err != nil {
+			return "", err
+		}
 	}
 	}
 
-	return sb.String(), nil
+	return prompt.String(), nil
 }
 
 type ManifestV2 struct {
+ 4 - 6
server/images_test.go

@@ -2,17 +2,15 @@ package server
 
 
 import (
 import (
 	"testing"
 	"testing"
-
-	"github.com/jmorganca/ollama/api"
 )
 )
 
 func TestModelPrompt(t *testing.T) {
-	var m Model
-	req := api.GenerateRequest{
+	m := Model{
 		Template: "a{{ .Prompt }}b",
-		Prompt:   "<h1>",
 	}
-	s, err := m.Prompt(req)
+	s, err := m.Prompt(PromptVars{
+		Prompt: "<h1>",
+	})
 	if err != nil {
 		t.Fatal(err)
 	}
+ 240 - 53
server/routes.go

@@ -60,17 +60,26 @@ var loaded struct {
 var defaultSessionDuration = 5 * time.Minute
 var defaultSessionDuration = 5 * time.Minute
 
 // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
+func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sessionDuration time.Duration) (*Model, error) {
+	model, err := GetModel(modelName)
+	if err != nil {
+		return nil, err
+	}
+
+	workDir := c.GetString("workDir")
+
 	opts := api.DefaultOptions()
 	opts := api.DefaultOptions()
 	if err := opts.FromMap(model.Options); err != nil {
 		log.Printf("could not load model options: %v", err)
-		return err
+		return nil, err
 	}
 
 	if err := opts.FromMap(reqOpts); err != nil {
-		return err
+		return nil, err
 	}
 
+	ctx := c.Request.Context()
+
 	// check if the loaded model is still running in a subprocess, in case something unexpected happened
 	if loaded.runner != nil {
 		if err := loaded.runner.Ping(ctx); err != nil {
 				err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
 				err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
 			}
 
-			return err
+			return nil, err
 		}
 
 		loaded.Model = model
@@ -140,7 +149,7 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
 	}
 
 	loaded.expireTimer.Reset(sessionDuration)
-	return nil
+	return model, nil
 }
 
 func GenerateHandler(c *gin.Context) {
 		return
 		return
 	}
 
-	model, err := GetModel(req.Model)
+	sessionDuration := defaultSessionDuration
+	model, err := load(c, req.Model, req.Options, sessionDuration)
 	if err != nil {
 		var pErr *fs.PathError
-		if errors.As(err, &pErr) {
+		switch {
+		case errors.As(err, &pErr):
 			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-			return
+		case errors.Is(err, api.ErrInvalidOpts):
+			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		default:
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		}
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
-	workDir := c.GetString("workDir")
-
-	// TODO: set this duration from the request if specified
-	sessionDuration := defaultSessionDuration
-	if err := load(c.Request.Context(), workDir, model, req.Options, sessionDuration); err != nil {
-		if errors.Is(err, api.ErrInvalidOpts) {
-			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-			return
-		}
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+	// an empty request loads the model
+	if req.Prompt == "" && req.Template == "" && req.System == "" {
+		c.JSON(http.StatusOK, api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true})
 		return
 	}
 
 	checkpointLoaded := time.Now()
 
-	if !req.Raw {
-		prompt, err = model.Prompt(req)
+	var prompt string
+	sendContext := false
+	switch {
+	case req.Raw:
+		prompt = req.Prompt
+	case req.Prompt != "":
+		if req.Template != "" {
+			// override the default model template
+			model.Template = req.Template
+		}
+
+		var rebuild strings.Builder
+		if req.Context != nil {
+			// TODO: context is deprecated, at some point the context logic within this conditional should be removed
+			prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context)
+			if err != nil {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
+
+			// Remove leading spaces from prevCtx if present
+			prevCtx = strings.TrimPrefix(prevCtx, " ")
+			rebuild.WriteString(prevCtx)
+		}
+		p, err := model.Prompt(PromptVars{
+			System: req.System,
+			Prompt: req.Prompt,
+		})
 		if err != nil {
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
 		}
+		rebuild.WriteString(p)
+		prompt = rebuild.String()
+		sendContext = true
 	}
 
 	ch := make(chan any)
+	var generated strings.Builder
 	go func() {
 		defer close(ch)
-		// an empty request loads the model
-		if req.Prompt == "" && req.Template == "" && req.System == "" {
-			ch <- api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}
-			return
-		}
 
-		fn := func(r api.GenerateResponse) {
+		fn := func(r llm.PredictResponse) {
+			// Update model expiration
 			loaded.expireAt = time.Now().Add(sessionDuration)
 			loaded.expireTimer.Reset(sessionDuration)
 
-			r.CreatedAt = time.Now().UTC()
-			if r.Done {
-				r.TotalDuration = time.Since(checkpointStart)
-				r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+			// Build up the full response
+			if _, err := generated.WriteString(r.Content); err != nil {
+				ch <- gin.H{"error": err.Error()}
+				return
 			}
 			}
 
-				// in raw mode the client must manage history on their own
-				r.Context = nil
+			resp := api.GenerateResponse{
+				Model:     r.Model,
+				CreatedAt: r.CreatedAt,
+				Done:      r.Done,
+				Response:  r.Content,
+				EvalMetrics: api.EvalMetrics{
+					TotalDuration:      r.TotalDuration,
+					LoadDuration:       r.LoadDuration,
+					PromptEvalCount:    r.PromptEvalCount,
+					PromptEvalDuration: r.PromptEvalDuration,
+					EvalCount:          r.EvalCount,
+					EvalDuration:       r.EvalDuration,
+				},
 			}
 			}
 
+			if r.Done && sendContext {
+				embd, err := loaded.runner.Encode(c.Request.Context(), req.Prompt+generated.String())
+				if err != nil {
+					ch <- gin.H{"error": err.Error()}
+					return
+				}
+				r.Context = embd
+			}
+
+			ch <- resp
 		}
 		}
 
+		// Start prediction
+		predictReq := llm.PredictRequest{
+			Model:            model.Name,
+			Prompt:           prompt,
+			Format:           req.Format,
+			CheckpointStart:  checkpointStart,
+			CheckpointLoaded: checkpointLoaded,
+		}
+		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
 
 	if req.Stream != nil && !*req.Stream {
-		var response api.GenerateResponse
-		generated := ""
+		// Wait for the channel to close
+		var r api.GenerateResponse
+		var sb strings.Builder
 		for resp := range ch {
-				generated += r.Response
-				response = r
-			} else {
+			var ok bool
+			if r, ok = resp.(api.GenerateResponse); !ok {
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
 			}
+			sb.WriteString(r.Response)
 		}
-		response.Response = generated
-		c.JSON(http.StatusOK, response)
+		r.Response = sb.String()
+		c.JSON(http.StatusOK, r)
 		return
 	}
 
 	streamResponse(c, ch)
 }
 
-func EmbeddingHandler(c *gin.Context) {
+func ChatHandler(c *gin.Context) {
 	loaded.mu.Lock()
 	defer loaded.mu.Unlock()
 
+	checkpointStart := time.Now()
+
+	var req api.ChatRequest
 	err := c.ShouldBindJSON(&req)
 	err := c.ShouldBindJSON(&req)
 	switch {
 	case errors.Is(err, io.EOF):
@@ -276,23 +335,150 @@ func EmbeddingHandler(c *gin.Context) {
 		return
 	}
 
+	// validate the request
+	switch {
+	case req.Model == "":
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
 		return
+	case len(req.Format) > 0 && req.Format != "json":
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
+		return
 	}
 
+	sessionDuration := defaultSessionDuration
+	model, err := load(c, req.Model, req.Options, sessionDuration)
 	if err != nil {
 	if err != nil {
+		var pErr *fs.PathError
+		switch {
+		case errors.As(err, &pErr):
+			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
+		case errors.Is(err, api.ErrInvalidOpts):
+			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		default:
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		}
 		return
 	}
 
-	workDir := c.GetString("workDir")
-	if err := load(c.Request.Context(), workDir, model, req.Options, 5*time.Minute); err != nil {
+	// an empty request loads the model
+	if len(req.Messages) == 0 {
+		c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true})
+		return
+	}
+
+	checkpointLoaded := time.Now()
+
+	if req.Template != "" {
+		// override the default model template
+		model.Template = req.Template
+	}
+	prompt, err := model.ChatPrompt(req.Messages)
+	if err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
+	ch := make(chan any)
+
+	go func() {
+		defer close(ch)
+
+		fn := func(r llm.PredictResponse) {
+			// Update model expiration
+			loaded.expireAt = time.Now().Add(sessionDuration)
+			loaded.expireTimer.Reset(sessionDuration)
+
+			resp := api.ChatResponse{
+				Model:     r.Model,
+				CreatedAt: r.CreatedAt,
+				Done:      r.Done,
+				EvalMetrics: api.EvalMetrics{
+					TotalDuration:      r.TotalDuration,
+					LoadDuration:       r.LoadDuration,
+					PromptEvalCount:    r.PromptEvalCount,
+					PromptEvalDuration: r.PromptEvalDuration,
+					EvalCount:          r.EvalCount,
+					EvalDuration:       r.EvalDuration,
+				},
+			}
+
+			if !r.Done {
+				resp.Message = &api.Message{Role: "assistant", Content: r.Content}
+			}
+
+			ch <- resp
+		}
+
+		// Start prediction
+		predictReq := llm.PredictRequest{
+			Model:            model.Name,
+			Prompt:           prompt,
+			Format:           req.Format,
+			CheckpointStart:  checkpointStart,
+			CheckpointLoaded: checkpointLoaded,
+		}
+		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
+			ch <- gin.H{"error": err.Error()}
+		}
+	}()
+
+	if req.Stream != nil && !*req.Stream {
+		// Wait for the channel to close
+		var r api.ChatResponse
+		var sb strings.Builder
+		for resp := range ch {
+			var ok bool
+			if r, ok = resp.(api.ChatResponse); !ok {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
+			if r.Message != nil {
+				sb.WriteString(r.Message.Content)
+			}
+		}
+		r.Message = &api.Message{Role: "assistant", Content: sb.String()}
+		c.JSON(http.StatusOK, r)
+		return
+	}
+
+	streamResponse(c, ch)
+}
+
+func EmbeddingHandler(c *gin.Context) {
+	loaded.mu.Lock()
+	defer loaded.mu.Unlock()
+
+	var req api.EmbeddingRequest
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	if req.Model == "" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
+		return
+	}
+
+	sessionDuration := defaultSessionDuration
+	_, err = load(c, req.Model, req.Options, sessionDuration)
+	if err != nil {
+		var pErr *fs.PathError
+		switch {
+		case errors.As(err, &pErr):
+			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
+		case errors.Is(err, api.ErrInvalidOpts):
+			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		default:
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		}
+		return
+	}
+
 	if !loaded.Options.EmbeddingOnly {
 		c.JSON(http.StatusBadRequest, gin.H{"error": "embedding option must be set to true"})
 		return
@@ -767,6 +953,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
 
 	r.POST("/api/pull", PullModelHandler)
 	r.POST("/api/generate", GenerateHandler)
+	r.POST("/api/chat", ChatHandler)
 	r.POST("/api/embeddings", EmbeddingHandler)
 	r.POST("/api/create", CreateModelHandler)
 	r.POST("/api/push", PushModelHandler)