Merge branch 'main' into royh-precision

royjhan committed 9 months ago
Commit: 3971c2333f

100 changed files with 3439 additions and 1441 deletions
1. .github/workflows/release.yaml (+1 -1)
2. .github/workflows/test.yaml (+3 -1)
3. Dockerfile (+2 -2)
4. README.md (+9 -2)
5. api/types.go (+42 -57)
6. api/types_test.go (+71 -4)
7. app/ollama.iss (+4 -0)
8. cmd/cmd.go (+56 -48)
9. cmd/interactive.go (+14 -47)
10. docs/api.md (+1 -1)
11. docs/development.md (+1 -1)
12. docs/faq.md (+16 -0)
13. docs/gpu.md (+1 -1)
14. docs/openai.md (+1 -1)
15. docs/troubleshooting.md (+14 -5)
16. docs/windows.md (+1 -1)
17. envconfig/config.go (+39 -12)
18. envconfig/config_test.go (+17 -0)
19. go.mod (+2 -1)
20. gpu/amd_common.go (+11 -12)
21. gpu/amd_hip_windows.go (+2 -3)
22. gpu/amd_windows.go (+12 -14)
23. gpu/assets.go (+19 -12)
24. gpu/gpu.go (+51 -2)
25. gpu/gpu_darwin.go (+2 -1)
26. gpu/gpu_info_darwin.h (+1 -0)
27. gpu/gpu_info_darwin.m (+24 -2)
28. gpu/gpu_info_nvcuda.c (+16 -15)
29. gpu/gpu_info_nvcuda.h (+5 -1)
30. gpu/gpu_linux.go (+9 -8)
31. gpu/gpu_windows.go (+1 -1)
32. gpu/types.go (+8 -1)
33. llm/ext_server/CMakeLists.txt (+12 -13)
34. llm/ext_server/server.cpp (+49 -10)
35. llm/generate/gen_darwin.sh (+8 -8)
36. llm/generate/gen_linux.sh (+19 -19)
37. llm/generate/gen_windows.ps1 (+24 -33)
38. llm/ggla.go (+11 -2)
39. llm/ggml.go (+89 -17)
40. llm/ggml_test.go (+1 -0)
41. llm/gguf.go (+92 -38)
42. llm/llama.cpp (+1 -1)
43. llm/llm.go (+9 -8)
44. llm/memory.go (+2 -2)
45. llm/memory_test.go (+11 -8)
46. llm/patches/01-load-progress.diff (+7 -7)
47. llm/patches/03-load_exception.diff (+6 -18)
48. llm/patches/04-metal.diff (+3 -3)
49. llm/patches/05-default-pretokenizer.diff (+11 -11)
50. llm/patches/06-qwen2.diff (+3 -3)
51. llm/patches/07-embeddings.diff (+45 -0)
52. llm/patches/08-clip-unicode.diff (+42 -0)
53. llm/patches/09-pooling.diff (+60 -0)
54. llm/payload.go (+23 -21)
55. llm/server.go (+60 -34)
56. llm/status.go (+1 -0)
57. openai/openai.go (+375 -13)
58. openai/openai_test.go (+271 -0)
59. parser/parser.go (+2 -2)
60. parser/parser_test.go (+67 -3)
61. scripts/build_windows.ps1 (+4 -1)
62. scripts/install.sh (+1 -1)
63. scripts/rh_linux_deps.sh (+11 -0)
64. server/images.go (+67 -32)
65. server/manifest.go (+11 -9)
66. server/manifest_test.go (+1 -1)
67. server/model.go (+41 -24)
68. server/model_test.go (+112 -0)
69. server/modelpath.go (+3 -18)
70. server/prompt.go (+54 -192)
71. server/prompt_test.go (+159 -159)
72. server/routes.go (+188 -390)
73. server/routes_create_test.go (+2 -2)
74. server/routes_test.go (+56 -0)
75. server/sched.go (+108 -29)
76. server/sched_test.go (+76 -42)
77. template/alfred.gotmpl (+0 -0)
78. template/alpaca.gotmpl (+2 -1)
79. template/chatml.gotmpl (+1 -1)
80. template/chatqa.gotmpl (+2 -1)
81. template/codellama-70b-instruct.gotmpl (+10 -0)
82. template/falcon-instruct.gotmpl (+5 -0)
83. template/gemma-instruct.gotmpl (+5 -0)
84. template/granite-instruct.gotmpl (+3 -3)
85. template/index.json (+0 -0)
86. template/llama2-chat.gotmpl (+6 -0)
87. template/llama3-instruct.gotmpl (+0 -0)
88. template/magicoder.gotmpl (+2 -1)
89. template/mistral-instruct.gotmpl (+3 -0)
90. template/openchat.gotmpl (+1 -0)
91. template/phi-3.gotmpl (+1 -1)
92. template/solar-instruct.gotmpl (+2 -1)
93. template/starcoder2-instruct.gotmpl (+0 -1)
94. template/template.go (+362 -0)
95. template/template_test.go (+361 -0)
96. template/testdata/alfred.gotmpl/system-user-assistant-user (+1 -0)
97. template/testdata/alfred.gotmpl/user (+1 -0)
98. template/testdata/alfred.gotmpl/user-assistant-user (+1 -0)
99. template/testdata/alpaca.gotmpl/system-user-assistant-user (+12 -0)
100. template/testdata/alpaca.gotmpl/user (+4 -0)

+ 1 - 1
.github/workflows/release.yaml

@@ -147,7 +147,7 @@ jobs:
         run: |
           $ErrorActionPreference = "Stop"
           write-host "downloading AMD HIP Installer"
-          Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
+          Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
           write-host "Installing AMD HIP"
           Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
           write-host "Completed AMD HIP"

+ 3 - 1
.github/workflows/test.yaml

@@ -58,6 +58,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     env:
       GOARCH: ${{ matrix.arch }}
+      CGO_ENABLED: '1'
     steps:
       - uses: actions/checkout@v4
       - uses: actions/setup-go@v5
@@ -79,6 +80,7 @@ jobs:
       - run: go generate -x ./...
         if: ${{ ! startsWith(matrix.os, 'windows-') }}
         name: 'Unix Go Generate'
+      - run: go build .
       - uses: actions/upload-artifact@v4
         with:
           name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
@@ -167,7 +169,7 @@ jobs:
         run: |
           $ErrorActionPreference = "Stop"
           write-host "downloading AMD HIP Installer"
-          Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
+          Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
           write-host "Installing AMD HIP"
           Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
           write-host "Completed AMD HIP"

+ 2 - 2
Dockerfile

@@ -70,12 +70,12 @@ RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
 FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
 RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
 
-FROM --platform=linux/arm64 centos:7 AS cpu-builder-arm64
+FROM --platform=linux/arm64 rockylinux:8 AS cpu-builder-arm64
 ARG CMAKE_VERSION
 ARG GOLANG_VERSION
 COPY ./scripts/rh_linux_deps.sh /
 RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
-ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
+ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH
 COPY --from=llm-code / /go/src/github.com/ollama/ollama/
 ARG OLLAMA_CUSTOM_CPU_DEFS
 ARG CGO_CFLAGS

+ 9 - 2
README.md

@@ -53,8 +53,8 @@ Here are some example models that can be downloaded:
 | Llama 3            | 70B        | 40GB  | `ollama run llama3:70b`        |
 | Phi 3 Mini         | 3.8B       | 2.3GB | `ollama run phi3`              |
 | Phi 3 Medium       | 14B        | 7.9GB | `ollama run phi3:medium`       |
-| Gemma              | 2B         | 1.4GB | `ollama run gemma:2b`          |
-| Gemma              | 7B         | 4.8GB | `ollama run gemma:7b`          |
+| Gemma 2            | 9B         | 5.5GB | `ollama run gemma2`            |
+| Gemma 2            | 27B        | 16GB  | `ollama run gemma2:27b`        |
 | Mistral            | 7B         | 4.1GB | `ollama run mistral`           |
 | Moondream 2        | 1.4B       | 829MB | `ollama run moondream`         |
 | Neural Chat        | 7B         | 4.1GB | `ollama run neural-chat`       |
@@ -182,6 +182,12 @@ $ ollama run llama3 "Summarize this file: $(cat README.md)"
  Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
 ```
 
+### Show model information
+
+```
+ollama show llama3
+```
+
 ### List models on your computer
 
 ```
@@ -286,6 +292,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
 - [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
 - [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
+- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
 
 ### Terminal
 

+ 42 - 57
api/types.go

@@ -159,49 +159,18 @@ type Options struct {
 
 // Runner options which must be set when the model is loaded into memory
 type Runner struct {
-	UseNUMA   bool     `json:"numa,omitempty"`
-	NumCtx    int      `json:"num_ctx,omitempty"`
-	NumBatch  int      `json:"num_batch,omitempty"`
-	NumGPU    int      `json:"num_gpu,omitempty"`
-	MainGPU   int      `json:"main_gpu,omitempty"`
-	LowVRAM   bool     `json:"low_vram,omitempty"`
-	F16KV     bool     `json:"f16_kv,omitempty"`
-	LogitsAll bool     `json:"logits_all,omitempty"`
-	VocabOnly bool     `json:"vocab_only,omitempty"`
-	UseMMap   TriState `json:"use_mmap,omitempty"`
-	UseMLock  bool     `json:"use_mlock,omitempty"`
-	NumThread int      `json:"num_thread,omitempty"`
-}
-
-type TriState int
-
-const (
-	TriStateUndefined TriState = -1
-	TriStateFalse     TriState = 0
-	TriStateTrue      TriState = 1
-)
-
-func (b *TriState) UnmarshalJSON(data []byte) error {
-	var v bool
-	if err := json.Unmarshal(data, &v); err != nil {
-		return err
-	}
-	if v {
-		*b = TriStateTrue
-	}
-	*b = TriStateFalse
-	return nil
-}
-
-func (b *TriState) MarshalJSON() ([]byte, error) {
-	if *b == TriStateUndefined {
-		return nil, nil
-	}
-	var v bool
-	if *b == TriStateTrue {
-		v = true
-	}
-	return json.Marshal(v)
+	UseNUMA   bool  `json:"numa,omitempty"`
+	NumCtx    int   `json:"num_ctx,omitempty"`
+	NumBatch  int   `json:"num_batch,omitempty"`
+	NumGPU    int   `json:"num_gpu,omitempty"`
+	MainGPU   int   `json:"main_gpu,omitempty"`
+	LowVRAM   bool  `json:"low_vram,omitempty"`
+	F16KV     bool  `json:"f16_kv,omitempty"`
+	LogitsAll bool  `json:"logits_all,omitempty"`
+	VocabOnly bool  `json:"vocab_only,omitempty"`
+	UseMMap   *bool `json:"use_mmap,omitempty"`
+	UseMLock  bool  `json:"use_mlock,omitempty"`
+	NumThread int   `json:"num_thread,omitempty"`
 }
 
 // EmbeddingRequest is the request passed to [Client.Embeddings].
@@ -345,6 +314,13 @@ type ProcessModelResponse struct {
 	SizeVRAM  int64        `json:"size_vram"`
 }
 
+type RetrieveModelResponse struct {
+	Id      string `json:"id"`
+	Object  string `json:"object"`
+	Created int64  `json:"created"`
+	OwnedBy string `json:"owned_by"`
+}
+
 type TokenResponse struct {
 	Token string `json:"token"`
 }
@@ -437,19 +413,6 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
 				continue
 			}
 
-			if reflect.PointerTo(field.Type()) == reflect.TypeOf((*TriState)(nil)) {
-				val, ok := val.(bool)
-				if !ok {
-					return fmt.Errorf("option %q must be of type boolean", key)
-				}
-				if val {
-					field.SetInt(int64(TriStateTrue))
-				} else {
-					field.SetInt(int64(TriStateFalse))
-				}
-				continue
-			}
-
 			switch field.Kind() {
 			case reflect.Int:
 				switch t := val.(type) {
@@ -496,6 +459,17 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
 					slice[i] = str
 				}
 				field.Set(reflect.ValueOf(slice))
+			case reflect.Pointer:
+				var b bool
+				if field.Type() == reflect.TypeOf(&b) {
+					val, ok := val.(bool)
+					if !ok {
+						return fmt.Errorf("option %q must be of type boolean", key)
+					}
+					field.Set(reflect.ValueOf(&val))
+				} else {
+					return fmt.Errorf("unknown type loading config params: %v %v", field.Kind(), field.Type())
+				}
 			default:
 				return fmt.Errorf("unknown type loading config params: %v", field.Kind())
 			}
@@ -538,7 +512,7 @@ func DefaultOptions() Options {
 			LowVRAM:   false,
 			F16KV:     true,
 			UseMLock:  false,
-			UseMMap:   TriStateUndefined,
+			UseMMap:   nil,
 			UseNUMA:   false,
 		},
 	}
@@ -635,6 +609,17 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) {
 				case reflect.Slice:
 					// TODO: only string slices are supported right now
 					out[key] = vals
+				case reflect.Pointer:
+					var b bool
+					if field.Type() == reflect.TypeOf(&b) {
+						boolVal, err := strconv.ParseBool(vals[0])
+						if err != nil {
+							return nil, fmt.Errorf("invalid bool value %s", vals)
+						}
+						out[key] = &boolVal
+					} else {
+						return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
+					}
 				default:
 					return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
 				}

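For context on the `api/types.go` change above: swapping the hand-rolled `TriState` for a `*bool` lets `encoding/json` distinguish "not set" (nil, omitted) from an explicit `false`. A minimal, self-contained sketch of that behavior, not code from this commit:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// runnerOpts mirrors the pattern adopted in api/types.go: a *bool with
// omitempty is omitted when nil and serialized when explicitly set.
type runnerOpts struct {
	UseMMap *bool `json:"use_mmap,omitempty"`
}

func main() {
	t := true
	unset, _ := json.Marshal(runnerOpts{})
	set, _ := json.Marshal(runnerOpts{UseMMap: &t})
	fmt.Println(string(unset), string(set)) // {} {"use_mmap":true}

	var decoded runnerOpts
	_ = json.Unmarshal([]byte(`{"use_mmap":false}`), &decoded)
	fmt.Println(decoded.UseMMap != nil, *decoded.UseMMap) // true false
}
```
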
+ 71 - 4
api/types_test.go

@@ -2,6 +2,7 @@ package api
 
 import (
 	"encoding/json"
+	"fmt"
 	"math"
 	"testing"
 	"time"
@@ -107,25 +108,27 @@ func TestDurationMarshalUnmarshal(t *testing.T) {
 }
 
 func TestUseMmapParsingFromJSON(t *testing.T) {
+	tr := true
+	fa := false
 	tests := []struct {
 		name string
 		req  string
-		exp  TriState
+		exp  *bool
 	}{
 		{
 			name: "Undefined",
 			req:  `{ }`,
-			exp:  TriStateUndefined,
+			exp:  nil,
 		},
 		{
 			name: "True",
 			req:  `{ "use_mmap": true }`,
-			exp:  TriStateTrue,
+			exp:  &tr,
 		},
 		{
 			name: "False",
 			req:  `{ "use_mmap": false }`,
-			exp:  TriStateFalse,
+			exp:  &fa,
 		},
 	}
 
@@ -141,3 +144,67 @@ func TestUseMmapParsingFromJSON(t *testing.T) {
 		})
 	}
 }
+
+func TestUseMmapFormatParams(t *testing.T) {
+	tr := true
+	fa := false
+	tests := []struct {
+		name string
+		req  map[string][]string
+		exp  *bool
+		err  error
+	}{
+		{
+			name: "True",
+			req: map[string][]string{
+				"use_mmap": {"true"},
+			},
+			exp: &tr,
+			err: nil,
+		},
+		{
+			name: "False",
+			req: map[string][]string{
+				"use_mmap": {"false"},
+			},
+			exp: &fa,
+			err: nil,
+		},
+		{
+			name: "Numeric True",
+			req: map[string][]string{
+				"use_mmap": {"1"},
+			},
+			exp: &tr,
+			err: nil,
+		},
+		{
+			name: "Numeric False",
+			req: map[string][]string{
+				"use_mmap": {"0"},
+			},
+			exp: &fa,
+			err: nil,
+		},
+		{
+			name: "invalid string",
+			req: map[string][]string{
+				"use_mmap": {"foo"},
+			},
+			exp: nil,
+			err: fmt.Errorf("invalid bool value [foo]"),
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			resp, err := FormatParams(test.req)
+			require.Equal(t, test.err, err)
+			respVal, ok := resp["use_mmap"]
+			if test.exp != nil {
+				assert.True(t, ok, "resp: %v", resp)
+				assert.Equal(t, *test.exp, *respVal.(*bool))
+			}
+		})
+	}
+}

+ 4 - 0
app/ollama.iss

@@ -127,6 +127,10 @@ Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\models"
 Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\history"
 ; NOTE: if the user has a custom OLLAMA_MODELS it will be preserved
 
+[InstallDelete]
+Type: filesandordirs; Name: "{%TEMP}\ollama*"
+Type: filesandordirs; Name: "{%LOCALAPPDATA}\Programs\Ollama"
+
 [Messages]
 WizardReady=Ollama Windows Preview
 ReadyLabel1=%nLet's get you up and running with your own large language models.

+ 56 - 48
cmd/cmd.go

@@ -162,9 +162,6 @@ func tempZipFiles(path string) (string, error) {
 	}
 	defer tempfile.Close()
 
-	zipfile := zip.NewWriter(tempfile)
-	defer zipfile.Close()
-
 	detectContentType := func(path string) (string, error) {
 		f, err := os.Open(path)
 		if err != nil {
@@ -233,6 +230,9 @@ func tempZipFiles(path string) (string, error) {
 		files = append(files, tks...)
 	}
 
+	zipfile := zip.NewWriter(tempfile)
+	defer zipfile.Close()
+
 	for _, file := range files {
 		f, err := os.Open(file)
 		if err != nil {
@@ -287,38 +287,12 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
 }
 
 func RunHandler(cmd *cobra.Command, args []string) error {
-	client, err := api.ClientFromEnvironment()
-	if err != nil {
-		return err
-	}
-
-	name := args[0]
-
-	// check if the model exists on the server
-	show, err := client.Show(cmd.Context(), &api.ShowRequest{Name: name})
-	var statusError api.StatusError
-	switch {
-	case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
-		if err := PullHandler(cmd, []string{name}); err != nil {
-			return err
-		}
-
-		show, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name})
-		if err != nil {
-			return err
-		}
-	case err != nil:
-		return err
-	}
-
 	interactive := true
 
 	opts := runOptions{
-		Model:       args[0],
-		WordWrap:    os.Getenv("TERM") == "xterm-256color",
-		Options:     map[string]interface{}{},
-		MultiModal:  slices.Contains(show.Details.Families, "clip"),
-		ParentModel: show.Details.ParentModel,
+		Model:    args[0],
+		WordWrap: os.Getenv("TERM") == "xterm-256color",
+		Options:  map[string]interface{}{},
 	}
 
 	format, err := cmd.Flags().GetString("format")
@@ -362,11 +336,38 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	}
 	opts.WordWrap = !nowrap
 
-	if !interactive {
-		return generate(cmd, opts)
+	// Fill out the rest of the options based on information about the
+	// model.
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		return err
+	}
+
+	name := args[0]
+	info, err := func() (*api.ShowResponse, error) {
+		showReq := &api.ShowRequest{Name: name}
+		info, err := client.Show(cmd.Context(), showReq)
+		var se api.StatusError
+		if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
+			if err := PullHandler(cmd, []string{name}); err != nil {
+				return nil, err
+			}
+			return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
+		}
+		return info, err
+	}()
+	if err != nil {
+		return err
 	}
 
-	return generateInteractive(cmd, opts)
+	opts.MultiModal = slices.Contains(info.Details.Families, "clip")
+	opts.ParentModel = info.Details.ParentModel
+	opts.Messages = append(opts.Messages, info.Messages...)
+
+	if interactive {
+		return generateInteractive(cmd, opts)
+	}
+	return generate(cmd, opts)
 }
 
 func errFromUnknownKey(unknownKeyErr error) error {
@@ -623,13 +624,13 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 		return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
 	}
 
-	if flagsSet == 1 {
-		req := api.ShowRequest{Name: args[0]}
-		resp, err := client.Show(cmd.Context(), &req)
-		if err != nil {
-			return err
-		}
+	req := api.ShowRequest{Name: args[0]}
+	resp, err := client.Show(cmd.Context(), &req)
+	if err != nil {
+		return err
+	}
 
+	if flagsSet == 1 {
 		switch showType {
 		case "license":
 			fmt.Println(resp.License)
@@ -646,12 +647,12 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 		return nil
 	}
 
-	req := api.ShowRequest{Name: args[0]}
-	resp, err := client.Show(cmd.Context(), &req)
-	if err != nil {
-		return err
-	}
+	showInfo(resp)
 
+	return nil
+}
+
+func showInfo(resp *api.ShowResponse) {
 	arch := resp.ModelInfo["general.architecture"].(string)
 
 	modelData := [][]string{
@@ -676,6 +677,15 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 			{"projection dimensionality", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.projection_dim"].(float64))},
 		}
 
+		if projectorType, ok := resp.ProjectorInfo["clip.projector_type"]; ok {
+			projectorData = append(projectorData, []string{"projector type", projectorType.(string)})
+		}
+
+		projectorData = append(projectorData,
+			[]string{"embedding length", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.embedding_length"].(float64))},
+			[]string{"projection dimensionality", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.projection_dim"].(float64))},
+		)
+
 		mainTableData = append(mainTableData,
 			[]string{"Projector"},
 			[]string{renderSubTable(projectorData, false)},
@@ -704,8 +714,6 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 	}
 
 	table.Render()
-
-	return nil
 }
 
 func renderSubTable(data [][]string, file bool) string {

+ 14 - 47
cmd/interactive.go

@@ -31,65 +31,40 @@ const (
 )
 
 func loadModel(cmd *cobra.Command, opts *runOptions) error {
-	client, err := api.ClientFromEnvironment()
-	if err != nil {
-		return err
-	}
-
 	p := progress.NewProgress(os.Stderr)
 	defer p.StopAndClear()
 
 	spinner := progress.NewSpinner("")
 	p.Add("", spinner)
 
-	showReq := api.ShowRequest{Name: opts.Model}
-	showResp, err := client.Show(cmd.Context(), &showReq)
+	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err
 	}
-	opts.MultiModal = slices.Contains(showResp.Details.Families, "clip")
-	opts.ParentModel = showResp.Details.ParentModel
-
-	if len(showResp.Messages) > 0 {
-		opts.Messages = append(opts.Messages, showResp.Messages...)
-	}
 
 	chatReq := &api.ChatRequest{
-		Model:    opts.Model,
-		Messages: []api.Message{},
+		Model:     opts.Model,
+		KeepAlive: opts.KeepAlive,
 	}
 
-	if opts.KeepAlive != nil {
-		chatReq.KeepAlive = opts.KeepAlive
-	}
-
-	err = client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error {
+	return client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error {
 		p.StopAndClear()
-		if len(opts.Messages) > 0 {
-			for _, msg := range opts.Messages {
-				switch msg.Role {
-				case "user":
-					fmt.Printf(">>> %s\n", msg.Content)
-				case "assistant":
-					state := &displayResponseState{}
-					displayResponse(msg.Content, opts.WordWrap, state)
-					fmt.Println()
-					fmt.Println()
-				}
+		for _, msg := range opts.Messages {
+			switch msg.Role {
+			case "user":
+				fmt.Printf(">>> %s\n", msg.Content)
+			case "assistant":
+				state := &displayResponseState{}
+				displayResponse(msg.Content, opts.WordWrap, state)
+				fmt.Println()
+				fmt.Println()
 			}
 		}
 		return nil
 	})
-	if err != nil {
-		return err
-	}
-
-	return nil
 }
 
 func generateInteractive(cmd *cobra.Command, opts runOptions) error {
-	opts.Messages = make([]api.Message, 0)
-
 	err := loadModel(cmd, &opts)
 	if err != nil {
 		return err
@@ -429,15 +404,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 
 				switch args[1] {
 				case "info":
-					fmt.Println("Model details:")
-					if len(resp.Details.Families) > 0 {
-						fmt.Printf("Family              %s\n", strings.Join(resp.Details.Families, ", "))
-					} else if resp.Details.Family != "" {
-						fmt.Printf("Family              %s\n", resp.Details.Family)
-					}
-					fmt.Printf("Parameter Size      %s\n", resp.Details.ParameterSize)
-					fmt.Printf("Quantization Level  %s\n", resp.Details.QuantizationLevel)
-					fmt.Println("")
+					showInfo(resp)
 				case "license":
 					if resp.License == "" {
 						fmt.Println("No license was specified for this model.")

+ 1 - 1
docs/api.md

@@ -26,7 +26,7 @@ All durations are returned in nanoseconds.
 
 ### Streaming responses
 
-Certain endpoints stream responses as JSON objects and can optional return non-streamed responses.
+Certain endpoints stream responses as JSON objects. Streaming can be disabled by providing `{"stream": false}` for these endpoints.
 
 ## Generate a completion
 

+ 1 - 1
docs/development.md

@@ -104,7 +104,7 @@ like to use. For example, to compile an optimized binary for an Intel i9-9880H,
 you might use:
 
 ```
-OLLAMA_CUSTOM_CPU_DEFS="-DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_F16C=on -DLLAMA_FMA=on" go generate ./...
+OLLAMA_CUSTOM_CPU_DEFS="-DGGML_AVX=on -DGGML_AVX2=on -DGGML_F16C=on -DGGML_FMA=on" go generate ./...
 go build .
 ```
 

+ 16 - 0
docs/faq.md

@@ -257,3 +257,19 @@ If you wish to override the `OLLAMA_KEEP_ALIVE` setting, use the `keep_alive` AP
 ## How do I manage the maximum number of requests the Ollama server can queue?
 
 If too many requests are sent to the server, it will respond with a 503 error indicating the server is overloaded.  You can adjust how many requests may be queued by setting `OLLAMA_MAX_QUEUE`.
+
+## How does Ollama handle concurrent requests?
+
+Ollama supports two levels of concurrent processing.  If your system has sufficient available memory (system memory when using CPU inference, or VRAM for GPU inference), then multiple models can be loaded at the same time.  For a given model, if there is sufficient available memory when the model is loaded, it is configured to allow parallel request processing.
+
+If there is insufficient available memory to load a new model request while one or more models are already loaded, all new requests will be queued until the new model can be loaded.  As prior models become idle, one or more will be unloaded to make room for the new model.  Queued requests are processed in order.  When using GPU inference, new models must fit completely in VRAM to allow concurrent model loads.
+
+Parallel request processing for a given model multiplies the context size by the number of parallel requests.  For example, a 2K context with 4 parallel requests results in an 8K context and additional memory allocation.
+
+The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:
+
+- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory.  The default is 3 * the number of GPUs or 3 for CPU inference.
+- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time.  The default will auto-select either 4 or 1 based on available memory.
+- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512.
+
+Note: Windows with Radeon GPUs currently defaults to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting.  Once ROCm v6.2 is available, Windows Radeon will follow the defaults above.  You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs' VRAM.

+ 1 - 1
docs/gpu.md

@@ -18,7 +18,7 @@ Check your compute compatibility to see if your card is supported:
 |                    | Quadro              | `RTX 8000` `RTX 6000` `RTX 5000` `RTX 4000`                                                                 |
 | 7.0                | NVIDIA              | `TITAN V` `V100` `Quadro GV100`                                                                             |
 | 6.1                | NVIDIA TITAN        | `TITAN Xp` `TITAN X`                                                                                        |
-|                    | GeForce GTX         | `GTX 1080 Ti` `GTX 1080` `GTX 1070 Ti` `GTX 1070` `GTX 1060` `GTX 1050`                                     |
+|                    | GeForce GTX         | `GTX 1080 Ti` `GTX 1080` `GTX 1070 Ti` `GTX 1070` `GTX 1060` `GTX 1050 Ti` `GTX 1050`                       |
 |                    | Quadro              | `P6000` `P5200` `P4200` `P3200` `P5000` `P4000` `P3000` `P2200` `P2000` `P1000` `P620` `P600` `P500` `P520` |
 |                    | Tesla               | `P40` `P4`                                                                                                  |
 | 6.0                | NVIDIA              | `Tesla P100` `Quadro GP100`                                                                                 |

+ 1 - 1
docs/openai.md

@@ -65,6 +65,7 @@ curl http://localhost:11434/v1/chat/completions \
             }
         ]
     }'
+
 ```
 
 ## Endpoints
@@ -104,7 +105,6 @@ curl http://localhost:11434/v1/chat/completions \
 
 #### Notes
 
-- `finish_reason` will always be `stop`
 - `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
 
 ## Models

+ 14 - 5
docs/troubleshooting.md

@@ -70,14 +70,18 @@ curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION="0.1.29" sh
 
 If your system is configured with the "noexec" flag where Ollama stores its temporary executable files, you can specify an alternate location by setting OLLAMA_TMPDIR to a location writable by the user ollama runs as. For example OLLAMA_TMPDIR=/usr/share/ollama/
 
-## Container fails to run on NVIDIA GPU
+## NVIDIA GPU Discovery
 
-Make sure you've set up the container runtime first as described in [docker.md](./docker.md)
+When Ollama starts up, it takes inventory of the GPUs present in the system to determine compatibility and how much VRAM is available.  Sometimes this discovery can fail to find your GPUs.  In general, running the latest driver will yield the best results.
 
-Sometimes the container runtime can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem
+### Linux NVIDIA Troubleshooting
 
-- Is the container runtime working?  Try `docker run --gpus all ubuntu nvidia-smi` - if this doesn't work, Ollama wont be able to see your NVIDIA GPU.
-- Is the uvm driver not loaded? `sudo nvidia-modprobe -u`
+If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker.md](./docker.md)
+
+Sometimes Ollama can have difficulty initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem:
+
+- If you are using a container, is the container runtime working?  Try `docker run --gpus all ubuntu nvidia-smi` - if this doesn't work, Ollama won't be able to see your NVIDIA GPU.
+- Is the uvm driver loaded? `sudo nvidia-modprobe -u`
 - Try reloading the nvidia_uvm driver - `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm`
 - Try rebooting
 - Make sure you're running the latest nvidia drivers
@@ -85,3 +89,8 @@ Sometimes the container runtime can have difficulties initializing the GPU. When
 If none of those resolve the problem, gather additional information and file an issue:
 - Set `CUDA_ERROR_LEVEL=50` and try again to get more diagnostic logs
 - Check dmesg for any errors `sudo dmesg | grep -i nvrm` and `sudo dmesg | grep -i nvidia`
+
+
+## Windows Terminal Errors
+
+Older versions of Windows 10 (e.g., 21H1) are known to have a bug where the standard terminal program does not display control characters correctly.  This can result in long runs of control sequences like `←[?25h←[?25l` being displayed, sometimes erroring with `The parameter is incorrect`.  To resolve this problem, please update to Win 10 22H1 or newer.

+ 1 - 1
docs/windows.md

@@ -19,7 +19,7 @@ Logs will often be helpful in diagnosing the problem (see
 
 ## System Requirements
 
-* Windows 10 or newer, Home or Pro
+* Windows 10 22H2 or newer, Home or Pro
 * NVIDIA 452.39 or newer Drivers if you have an NVIDIA card
 * AMD Radeon Driver https://www.amd.com/en/support if you have a Radeon card
 

+ 39 - 12
envconfig/config.go

@@ -4,12 +4,14 @@ import (
 	"errors"
 	"fmt"
 	"log/slog"
+	"math"
 	"net"
 	"os"
 	"path/filepath"
 	"runtime"
 	"strconv"
 	"strings"
+	"time"
 )
 
 type OllamaHost struct {
@@ -34,17 +36,17 @@ var (
 	// Set via OLLAMA_HOST in the environment
 	Host *OllamaHost
 	// Set via OLLAMA_KEEP_ALIVE in the environment
-	KeepAlive string
+	KeepAlive time.Duration
 	// Set via OLLAMA_LLM_LIBRARY in the environment
 	LLMLibrary string
 	// Set via OLLAMA_MAX_LOADED_MODELS in the environment
 	MaxRunners int
 	// Set via OLLAMA_MAX_QUEUE in the environment
 	MaxQueuedRequests int
-	// Set via OLLAMA_MODELS in the environment
-	ModelsDir string
 	// Set via OLLAMA_MAX_VRAM in the environment
 	MaxVRAM uint64
+	// Set via OLLAMA_MODELS in the environment
+	ModelsDir string
 	// Set via OLLAMA_NOHISTORY in the environment
 	NoHistory bool
 	// Set via OLLAMA_NOPRUNE in the environment
@@ -85,13 +87,13 @@ func AsMap() map[string]EnvVar {
 		"OLLAMA_HOST":              {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"},
 		"OLLAMA_KEEP_ALIVE":        {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"},
 		"OLLAMA_LLM_LIBRARY":       {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
-		"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models (default 1)"},
+		"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"},
 		"OLLAMA_MAX_QUEUE":         {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
 		"OLLAMA_MAX_VRAM":          {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"},
 		"OLLAMA_MODELS":            {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"},
 		"OLLAMA_NOHISTORY":         {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
 		"OLLAMA_NOPRUNE":           {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
-		"OLLAMA_NUM_PARALLEL":      {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests (default 1)"},
+		"OLLAMA_NUM_PARALLEL":      {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"},
 		"OLLAMA_ORIGINS":           {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"},
 		"OLLAMA_RUNNERS_DIR":       {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
 		"OLLAMA_SCHED_SPREAD":      {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"},
@@ -129,9 +131,10 @@ func clean(key string) string {
 
 func init() {
 	// default values
-	NumParallel = 1
-	MaxRunners = 1
+	NumParallel = 0 // Autoselect
+	MaxRunners = 0  // Autoselect
 	MaxQueuedRequests = 512
+	KeepAlive = 5 * time.Minute
 
 	LoadConfig()
 }
@@ -205,8 +208,8 @@ func LoadConfig() {
 
 	if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" {
 		val, err := strconv.Atoi(onp)
-		if err != nil || val <= 0 {
-			slog.Error("invalid setting must be greater than zero", "OLLAMA_NUM_PARALLEL", onp, "error", err)
+		if err != nil {
+			slog.Error("invalid setting, ignoring", "OLLAMA_NUM_PARALLEL", onp, "error", err)
 		} else {
 			NumParallel = val
 		}
@@ -251,7 +254,7 @@ func LoadConfig() {
 	if maxRunners != "" {
 		m, err := strconv.Atoi(maxRunners)
 		if err != nil {
-			slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
+			slog.Error("invalid setting, ignoring", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
 		} else {
 			MaxRunners = m
 		}
@@ -260,13 +263,16 @@ func LoadConfig() {
 	if onp := os.Getenv("OLLAMA_MAX_QUEUE"); onp != "" {
 		p, err := strconv.Atoi(onp)
 		if err != nil || p <= 0 {
-			slog.Error("invalid setting", "OLLAMA_MAX_QUEUE", onp, "error", err)
+			slog.Error("invalid setting, ignoring", "OLLAMA_MAX_QUEUE", onp, "error", err)
 		} else {
 			MaxQueuedRequests = p
 		}
 	}
 
-	KeepAlive = clean("OLLAMA_KEEP_ALIVE")
+	ka := clean("OLLAMA_KEEP_ALIVE")
+	if ka != "" {
+		loadKeepAlive(ka)
+	}
 
 	var err error
 	ModelsDir, err = getModelsDir()
@@ -344,3 +350,24 @@ func getOllamaHost() (*OllamaHost, error) {
 		Port:   port,
 	}, nil
 }
+
+func loadKeepAlive(ka string) {
+	v, err := strconv.Atoi(ka)
+	if err != nil {
+		d, err := time.ParseDuration(ka)
+		if err == nil {
+			if d < 0 {
+				KeepAlive = time.Duration(math.MaxInt64)
+			} else {
+				KeepAlive = d
+			}
+		}
+	} else {
+		d := time.Duration(v) * time.Second
+		if d < 0 {
+			KeepAlive = time.Duration(math.MaxInt64)
+		} else {
+			KeepAlive = d
+		}
+	}
+}

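For reference on the new `OLLAMA_KEEP_ALIVE` handling: a small sketch that mirrors the semantics of `loadKeepAlive` above (bare integers are seconds, other values are parsed as Go durations, negative values mean "keep loaded indefinitely"). `parseKeepAlive` is a hypothetical helper, not part of the package:

```go
package main

import (
	"fmt"
	"math"
	"strconv"
	"time"
)

// parseKeepAlive interprets a keep-alive string: a bare integer is seconds,
// anything else is tried as a Go duration string, and negative values pin
// the model in memory indefinitely. Unparseable input keeps the fallback.
func parseKeepAlive(s string, fallback time.Duration) time.Duration {
	var d time.Duration
	if v, err := strconv.Atoi(s); err == nil {
		d = time.Duration(v) * time.Second
	} else if parsed, err := time.ParseDuration(s); err == nil {
		d = parsed
	} else {
		return fallback
	}
	if d < 0 {
		return time.Duration(math.MaxInt64)
	}
	return d
}

func main() {
	for _, s := range []string{"3", "1h", "-1"} {
		fmt.Println(s, "->", parseKeepAlive(s, 5*time.Minute))
	}
	// 3 -> 3s, 1h -> 1h0m0s, -1 -> 2562047h47m16.854775807s (effectively forever)
}
```
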
+ 17 - 0
envconfig/config_test.go

@@ -2,8 +2,10 @@ package envconfig
 
 import (
 	"fmt"
+	"math"
 	"net"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -23,6 +25,21 @@ func TestConfig(t *testing.T) {
 	t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
 	LoadConfig()
 	require.True(t, FlashAttention)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "")
+	LoadConfig()
+	require.Equal(t, 5*time.Minute, KeepAlive)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "3")
+	LoadConfig()
+	require.Equal(t, 3*time.Second, KeepAlive)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "1h")
+	LoadConfig()
+	require.Equal(t, 1*time.Hour, KeepAlive)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "-1s")
+	LoadConfig()
+	require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "-1")
+	LoadConfig()
+	require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
 }
 
 func TestClientFromEnvironment(t *testing.T) {

+ 2 - 1
go.mod

@@ -18,6 +18,7 @@ require (
 require (
 	github.com/agnivade/levenshtein v1.1.1
 	github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
+	github.com/google/go-cmp v0.6.0
 	github.com/mattn/go-runewidth v0.0.14
 	github.com/nlpodyssey/gopickle v0.3.0
 	github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
@@ -71,7 +72,7 @@ require (
 	golang.org/x/net v0.25.0 // indirect
 	golang.org/x/sys v0.20.0
 	golang.org/x/term v0.20.0
-	golang.org/x/text v0.15.0 // indirect
+	golang.org/x/text v0.15.0
 	google.golang.org/protobuf v1.34.1
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )

+ 11 - 12
gpu/amd_common.go

@@ -49,9 +49,17 @@ func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
 }
 
 func commonAMDValidateLibDir() (string, error) {
-	// We try to favor system paths first, so that we can wire up the subprocess to use
-	// the system version.  Only use our bundled version if the system version doesn't work
-	// This gives users a more recovery options if versions have subtle problems at runtime
+	// Favor our bundled version
+
+	// Installer payload location if we're running the installed binary
+	exe, err := os.Executable()
+	if err == nil {
+		rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
+		if rocmLibUsable(rocmTargetDir) {
+			slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
+			return rocmTargetDir, nil
+		}
+	}
 
 	// Prefer explicit HIP env var
 	hipPath := os.Getenv("HIP_PATH")
@@ -87,14 +95,5 @@ func commonAMDValidateLibDir() (string, error) {
 		}
 	}
 
-	// Installer payload location if we're running the installed binary
-	exe, err := os.Executable()
-	if err == nil {
-		rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
-		if rocmLibUsable(rocmTargetDir) {
-			slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
-			return rocmTargetDir, nil
-		}
-	}
 	return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
 }

+ 2 - 3
gpu/amd_hip_windows.go

@@ -84,9 +84,8 @@ func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) {
 	}
 
 	slog.Debug("hipDriverGetVersion", "version", version)
-	// TODO - this isn't actually right, but the docs claim hipDriverGetVersion isn't accurate anyway...
-	driverMajor = version / 1000
-	driverMinor = (version - (driverMajor * 1000)) / 10
+	driverMajor = version / 10000000
+	driverMinor = (version - (driverMajor * 10000000)) / 100000
 
 	return driverMajor, driverMinor, nil
 }

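The new divisors above imply that HIP reports the driver version as `major*10000000 + minor*100000 + patch`. A worked example with a made-up value, just to show the arithmetic:

```go
package main

import "fmt"

func main() {
	// Hypothetical value from hipDriverGetVersion; real values vary by driver.
	version := 60140091

	driverMajor := version / 10000000                        // 6
	driverMinor := (version - driverMajor*10000000) / 100000 // 1
	fmt.Println(driverMajor, driverMinor)                    // prints: 6 1
}
```
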
+ 12 - 14
gpu/amd_windows.go

@@ -22,8 +22,8 @@ const (
 
 var (
 	// Used to validate if the given ROCm lib is usable
-	ROCmLibGlobs          = []string{"hipblas.dll", "rocblas"}                 // TODO - probably include more coverage of files here...
-	RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
+	ROCmLibGlobs          = []string{"hipblas.dll", "rocblas"}                 // This is not sufficient to discern v5 vs v6
+	RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\6.1\\bin"} // TODO glob?
 )
 
 func AMDGetGPUInfo() []RocmGPUInfo {
@@ -35,12 +35,11 @@ func AMDGetGPUInfo() []RocmGPUInfo {
 	}
 	defer hl.Release()
 
-	// TODO - this reports incorrect version information, so omitting for now
-	// driverMajor, driverMinor, err := hl.AMDDriverVersion()
-	// if err != nil {
-	// 	// For now this is benign, but we may eventually need to fail compatibility checks
-	// 	slog.Debug("error looking up amd driver version", "error", err)
-	// }
+	driverMajor, driverMinor, err := hl.AMDDriverVersion()
+	if err != nil {
+		// For now this is benign, but we may eventually need to fail compatibility checks
+		slog.Debug("error looking up amd driver version", "error", err)
+	}
 
 	// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
 	count := hl.HipGetDeviceCount()
@@ -115,8 +114,6 @@ func AMDGetGPUInfo() []RocmGPUInfo {
 			continue
 		}
 
-		// TODO revisit this once ROCm v6 is available on windows.
-		// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable
 		slog.Debug("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory))
 		slog.Debug("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory))
 		gpuInfo := RocmGPUInfo{
@@ -126,15 +123,16 @@ func AMDGetGPUInfo() []RocmGPUInfo {
 					TotalMemory: totalMemory,
 					FreeMemory:  freeMemory,
 				},
+				// Free memory reporting on Windows is not reliable until we bump to ROCm v6.2
+				UnreliableFreeMemory: true,
+
 				ID:             strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices
 				DependencyPath: libDir,
 				MinimumMemory:  rocmMinimumMemory,
 				Name:           name,
 				Compute:        gfx,
-
-				// TODO - this information isn't accurate on windows, so don't report it until we find the right way to retrieve
-				// DriverMajor:    driverMajor,
-				// DriverMinor:    driverMinor,
+				DriverMajor:    driverMajor,
+				DriverMinor:    driverMinor,
 			},
 			index: i,
 		}

+ 19 - 12
gpu/assets.go

@@ -77,20 +77,27 @@ func cleanupTmpDirs() {
 			continue
 		}
 		raw, err := os.ReadFile(filepath.Join(d, "ollama.pid"))
-		if err == nil {
-			pid, err := strconv.Atoi(string(raw))
-			if err == nil {
-				if proc, err := os.FindProcess(pid); err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
-					// Another running ollama, ignore this tmpdir
-					continue
-				}
-			}
-		} else {
-			slog.Debug("failed to open ollama.pid", "path", d, "error", err)
+		if err != nil {
+			slog.Warn("failed to read ollama.pid", "path", d, "error", err)
+			// No pid, ignore this tmpdir
+			continue
 		}
-		err = os.RemoveAll(d)
+
+		pid, err := strconv.Atoi(string(raw))
 		if err != nil {
-			slog.Debug("unable to cleanup stale tmpdir", "path", d, "error", err)
+			slog.Warn("failed to parse pid", "path", d, "error", err)
+			continue
+		}
+
+		proc, err := os.FindProcess(pid)
+		if err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
+			slog.Warn("found running ollama", "pid", pid, "path", d)
+			// Another running ollama, ignore this tmpdir
+			continue
+		}
+
+		if err := os.Remove(d); err != nil {
+			slog.Warn("unable to cleanup stale tmpdir", "path", d, "error", err)
 		}
 	}
 }

+ 51 - 2
gpu/gpu.go

@@ -202,7 +202,7 @@ func GetGPUInfo() GpuInfoList {
 	}()
 
 	if !bootstrapped {
-		slog.Debug("Detecting GPUs")
+		slog.Info("looking for compatible GPUs")
 		needRefresh = false
 		cpuCapability = GetCPUCapability()
 		var memInfo C.mem_info_t
@@ -274,6 +274,28 @@ func GetGPUInfo() GpuInfoList {
 				gpuInfo.DriverMajor = driverMajor
 				gpuInfo.DriverMinor = driverMinor
 
+				// query the management library as well so we can record any skew between the two
+				// which represents overhead on the GPU we must set aside on subsequent updates
+				if cHandles.nvml != nil {
+					C.nvml_get_free(*cHandles.nvml, C.int(gpuInfo.index), &memInfo.free, &memInfo.total, &memInfo.used)
+					if memInfo.err != nil {
+						slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
+						C.free(unsafe.Pointer(memInfo.err))
+					} else {
+						if memInfo.free != 0 && uint64(memInfo.free) > gpuInfo.FreeMemory {
+							gpuInfo.OSOverhead = uint64(memInfo.free) - gpuInfo.FreeMemory
+							slog.Info("detected OS VRAM overhead",
+								"id", gpuInfo.ID,
+								"library", gpuInfo.Library,
+								"compute", gpuInfo.Compute,
+								"driver", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor),
+								"name", gpuInfo.Name,
+								"overhead", format.HumanBytes2(gpuInfo.OSOverhead),
+							)
+						}
+					}
+				}
+
 				// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
 				cudaGPUs = append(cudaGPUs, gpuInfo)
 			}
@@ -320,6 +342,9 @@ func GetGPUInfo() GpuInfoList {
 
 		rocmGPUs = AMDGetGPUInfo()
 		bootstrapped = true
+		if len(cudaGPUs) == 0 && len(rocmGPUs) == 0 && len(oneapiGPUs) == 0 {
+			slog.Info("no compatible GPUs were discovered")
+		}
 	}
 
 	// For detected GPUs, load library if not loaded
@@ -335,14 +360,17 @@ func GetGPUInfo() GpuInfoList {
 					"before",
 					"total", format.HumanBytes2(cpus[0].TotalMemory),
 					"free", format.HumanBytes2(cpus[0].FreeMemory),
+					"free_swap", format.HumanBytes2(cpus[0].FreeSwap),
 				),
 				slog.Group(
 					"now",
 					"total", format.HumanBytes2(mem.TotalMemory),
 					"free", format.HumanBytes2(mem.FreeMemory),
+					"free_swap", format.HumanBytes2(mem.FreeSwap),
 				),
 			)
 			cpus[0].FreeMemory = mem.FreeMemory
+			cpus[0].FreeSwap = mem.FreeSwap
 		}
 
 		var memInfo C.mem_info_t
@@ -371,9 +399,14 @@ func GetGPUInfo() GpuInfoList {
 				slog.Warn("error looking up nvidia GPU memory")
 				continue
 			}
+			if cHandles.nvml != nil && gpu.OSOverhead > 0 {
+				// When using the management library update based on recorded overhead
+				memInfo.free -= C.uint64_t(gpu.OSOverhead)
+			}
 			slog.Debug("updating cuda memory data",
 				"gpu", gpu.ID,
 				"name", gpu.Name,
+				"overhead", format.HumanBytes2(gpu.OSOverhead),
 				slog.Group(
 					"before",
 					"total", format.HumanBytes2(gpu.TotalMemory),
@@ -514,7 +547,23 @@ func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) {
 		defer C.free(unsafe.Pointer(lib))
 		C.nvcuda_init(lib, &resp)
 		if resp.err != nil {
-			slog.Debug("Unable to load nvcuda", "library", libPath, "error", C.GoString(resp.err))
+			// Decide what log level based on the type of error message to help users understand why
+			msg := C.GoString(resp.err)
+			switch resp.cudaErr {
+			case C.CUDA_ERROR_INSUFFICIENT_DRIVER, C.CUDA_ERROR_SYSTEM_DRIVER_MISMATCH:
+				slog.Warn("version mismatch between driver and cuda driver library - reboot or upgrade may be required", "library", libPath, "error", msg)
+			case C.CUDA_ERROR_NO_DEVICE:
+				slog.Info("no nvidia devices detected", "library", libPath)
+			case C.CUDA_ERROR_UNKNOWN:
+				slog.Warn("unknown error initializing cuda driver library", "library", libPath, "error", msg)
+				slog.Warn("see https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for more information")
+			default:
+				if strings.Contains(msg, "wrong ELF class") {
+					slog.Debug("skipping 32bit library", "library", libPath)
+				} else {
+					slog.Info("unable to load cuda driver library", "library", libPath, "error", msg)
+				}
+			}
 			C.free(unsafe.Pointer(resp.err))
 		} else {
 			return int(resp.num_devices), &resp.ch, libPath

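On the `OSOverhead` bookkeeping introduced above: at bootstrap the code records how much more free VRAM the management library reports than the CUDA driver library, and subsequent updates set that skew aside. A rough sketch with invented numbers, not the real query paths:

```go
package main

import "fmt"

func main() {
	// Invented readings; in gpu.go these come from NVML and the CUDA driver library.
	nvmlFree := uint64(23) << 30 // 23 GiB free per the management library
	cudaFree := uint64(22) << 30 // 22 GiB free per the driver library

	// Bootstrap: record the skew as OS overhead to discount later.
	var osOverhead uint64
	if nvmlFree > cudaFree {
		osOverhead = nvmlFree - cudaFree
	}

	// Later refresh: subtract the recorded overhead from the reported free VRAM.
	refreshedFree := uint64(20)<<30 - osOverhead
	fmt.Printf("overhead=%d MiB, usable=%d MiB\n", osOverhead>>20, refreshedFree>>20)
}
```
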
+ 2 - 1
gpu/gpu_darwin.go

@@ -56,7 +56,8 @@ func GetCPUInfo() GpuInfoList {
 func GetCPUMem() (memInfo, error) {
 	return memInfo{
 		TotalMemory: uint64(C.getPhysicalMemory()),
-		FreeMemory:  0,
+		FreeMemory:  uint64(C.getFreeMemory()),
+		// FreeSwap omitted as Darwin uses dynamic paging
 	}, nil
 }
 

+ 1 - 0
gpu/gpu_info_darwin.h

@@ -2,3 +2,4 @@
 #include <stdint.h>
 uint64_t getRecommendedMaxVRAM();
 uint64_t getPhysicalMemory();
+uint64_t getFreeMemory();

+ 24 - 2
gpu/gpu_info_darwin.m

@@ -1,4 +1,5 @@
-// go:build darwin
+#import <Foundation/Foundation.h>
+#import <mach/mach.h>
 #include "gpu_info_darwin.h"
 
 uint64_t getRecommendedMaxVRAM() {
@@ -8,6 +9,27 @@ uint64_t getRecommendedMaxVRAM() {
   return result;
 }
 
+// getPhysicalMemory returns the total physical memory in bytes
 uint64_t getPhysicalMemory() {
-  return [[NSProcessInfo processInfo] physicalMemory];
+  return [NSProcessInfo processInfo].physicalMemory;
+}
+
+// getFreeMemory returns the total free memory in bytes, including inactive
+// memory that can be reclaimed by the system.
+uint64_t getFreeMemory() {
+  mach_port_t host_port = mach_host_self();
+  mach_msg_type_number_t host_size = sizeof(vm_statistics64_data_t) / sizeof(integer_t);
+  vm_size_t pagesize;
+  vm_statistics64_data_t vm_stat;
+
+  host_page_size(host_port, &pagesize);
+  if (host_statistics64(host_port, HOST_VM_INFO64, (host_info64_t)&vm_stat, &host_size) != KERN_SUCCESS) {
+    return 0;
+  }
+
+  uint64_t free_memory = (uint64_t)vm_stat.free_count * pagesize;
+  free_memory += (uint64_t)vm_stat.speculative_count * pagesize;
+  free_memory += (uint64_t)vm_stat.inactive_count * pagesize;
+
+  return free_memory;
 }

+ 16 - 15
gpu/gpu_info_nvcuda.c

@@ -7,6 +7,7 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
   CUresult ret;
   resp->err = NULL;
   resp->num_devices = 0;
+  resp->cudaErr = CUDA_SUCCESS;
   const int buflen = 256;
   char buf[buflen + 1];
   int i;
@@ -38,6 +39,7 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
             nvcuda_lib_path, msg);
     free(msg);
     resp->err = strdup(buf);
+    resp->cudaErr = -1;
     return;
   }
 
@@ -52,6 +54,7 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
               msg);
       free(msg);
       resp->err = strdup(buf);
+      resp->cudaErr = -1;
       return;
     }
   }
@@ -61,12 +64,9 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
     LOG(resp->ch.verbose, "cuInit err: %d\n", ret);
     UNLOAD_LIBRARY(resp->ch.handle);
     resp->ch.handle = NULL;
-    if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
-      resp->err = strdup("your nvidia driver is too old or missing.  If you have a CUDA GPU please upgrade to run ollama");
-      return;
-    }
-    snprintf(buf, buflen, "nvcuda init failure: %d", ret);
+    snprintf(buf, buflen, "cuda driver library init failure: %d", ret);
     resp->err = strdup(buf);
+    resp->cudaErr = ret;
     return;
   }
 
@@ -91,6 +91,7 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
     resp->ch.handle = NULL;
     snprintf(buf, buflen, "unable to get device count: %d", ret);
     resp->err = strdup(buf);
+    resp->cudaErr = ret;
     return;
   }
 }
@@ -106,13 +107,13 @@ void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
   CUuuid uuid = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
 
   if (h.handle == NULL) {
-    resp->err = strdup("nvcuda handle isn't initialized");
+    resp->err = strdup("cuda driver library handle isn't initialized");
     return;
   }
 
   ret = (*h.cuDeviceGet)(&device, i);
   if (ret != CUDA_SUCCESS) {
-    snprintf(buf, buflen, "nvcuda device failed to initialize");
+    snprintf(buf, buflen, "cuda driver library device failed to initialize");
     resp->err = strdup(buf);
     return;
   }
@@ -168,14 +169,14 @@ void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
   // To get memory we have to set (and release) a context
   ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
   if (ret != CUDA_SUCCESS) {
-    snprintf(buf, buflen, "nvcuda failed to get device context %d", ret);
+    snprintf(buf, buflen, "cuda driver library failed to get device context %d", ret);
     resp->err = strdup(buf);
     return;
   }
 
   ret = (*h.cuMemGetInfo_v2)(&memInfo.free, &memInfo.total);
   if (ret != CUDA_SUCCESS) {
-    snprintf(buf, buflen, "nvcuda device memory info lookup failure %d", ret);
+    snprintf(buf, buflen, "cuda driver library device memory info lookup failure %d", ret);
     resp->err = strdup(buf);
     // Best effort on failure...
     (*h.cuCtxDestroy)(ctx);
@@ -193,7 +194,7 @@ void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
 
   ret = (*h.cuCtxDestroy)(ctx);
   if (ret != CUDA_SUCCESS) {
-    LOG(1, "nvcuda failed to release device context %d", ret);
+    LOG(1, "cuda driver library failed to release device context %d", ret);
   }
 }
 
@@ -206,7 +207,7 @@ void nvcuda_get_free(nvcuda_handle_t h, int i, uint64_t *free, uint64_t *total)
 
   ret = (*h.cuDeviceGet)(&device, i);
   if (ret != CUDA_SUCCESS) {
-    LOG(1, "nvcuda device failed to initialize");
+    LOG(1, "cuda driver library device failed to initialize");
     return;
   }
 
@@ -214,13 +215,13 @@ void nvcuda_get_free(nvcuda_handle_t h, int i, uint64_t *free, uint64_t *total)
   // To get memory we have to set (and release) a context
   ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
   if (ret != CUDA_SUCCESS) {
-    LOG(1, "nvcuda failed to get device context %d", ret);
+    LOG(1, "cuda driver library failed to get device context %d", ret);
     return;
   }
 
   ret = (*h.cuMemGetInfo_v2)(free, total);
   if (ret != CUDA_SUCCESS) {
-    LOG(1, "nvcuda device memory info lookup failure %d", ret);
+    LOG(1, "cuda driver library device memory info lookup failure %d", ret);
     // Best effort on failure...
     (*h.cuCtxDestroy)(ctx);
     return;
@@ -228,12 +229,12 @@ void nvcuda_get_free(nvcuda_handle_t h, int i, uint64_t *free, uint64_t *total)
 
   ret = (*h.cuCtxDestroy)(ctx);
   if (ret != CUDA_SUCCESS) {
-    LOG(1, "nvcuda failed to release device context %d", ret);
+    LOG(1, "cuda driver library failed to release device context %d", ret);
   }
 }
 
 void nvcuda_release(nvcuda_handle_t h) {
-  LOG(h.verbose, "releasing nvcuda library\n");
+  LOG(h.verbose, "releasing cuda driver library\n");
   UNLOAD_LIBRARY(h.handle);
   // TODO and other context release logic?
   h.handle = NULL;

+ 5 - 1
gpu/gpu_info_nvcuda.h

@@ -7,9 +7,12 @@
 typedef enum cudaError_enum {
   CUDA_SUCCESS = 0,
   CUDA_ERROR_INVALID_VALUE = 1,
-  CUDA_ERROR_MEMORY_ALLOCATION = 2,
+  CUDA_ERROR_OUT_OF_MEMORY = 2,
   CUDA_ERROR_NOT_INITIALIZED = 3,
   CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
+  CUDA_ERROR_NO_DEVICE = 100,
+  CUDA_ERROR_SYSTEM_DRIVER_MISMATCH = 803,
+  CUDA_ERROR_UNKNOWN = 999,
   // Other values omitted for now...
 } CUresult;
 
@@ -64,6 +67,7 @@ typedef struct nvcuda_init_resp {
   char *err;  // If err is non-null handle is invalid
   nvcuda_handle_t ch;
   int num_devices;
+  CUresult cudaErr;
 } nvcuda_init_resp_t;
 
 void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp);

+ 9 - 8
gpu/gpu_linux.go

@@ -50,7 +50,7 @@ var OneapiMgmtName = "libze_intel_gpu.so"
 
 func GetCPUMem() (memInfo, error) {
 	var mem memInfo
-	var total, available, free, buffers, cached uint64
+	var total, available, free, buffers, cached, freeSwap uint64
 	f, err := os.Open("/proc/meminfo")
 	if err != nil {
 		return mem, err
@@ -70,20 +70,21 @@ func GetCPUMem() (memInfo, error) {
 			_, err = fmt.Sscanf(line, "Buffers:%d", &buffers)
 		case strings.HasPrefix(line, "Cached:"):
 			_, err = fmt.Sscanf(line, "Cached:%d", &cached)
+		case strings.HasPrefix(line, "SwapFree:"):
+			_, err = fmt.Sscanf(line, "SwapFree:%d", &freeSwap)
 		default:
 			continue
 		}
 		if err != nil {
 			return mem, err
 		}
-
-		if total > 0 && available > 0 {
-			mem.TotalMemory = total * format.KibiByte
-			mem.FreeMemory = available * format.KibiByte
-			return mem, nil
-		}
 	}
 	mem.TotalMemory = total * format.KibiByte
-	mem.FreeMemory = (free + buffers + cached) * format.KibiByte
+	mem.FreeSwap = freeSwap * format.KibiByte
+	if available > 0 {
+		mem.FreeMemory = available * format.KibiByte
+	} else {
+		mem.FreeMemory = (free + buffers + cached) * format.KibiByte
+	}
 	return mem, nil
 }

+ 1 - 1
gpu/gpu_windows.go

@@ -51,5 +51,5 @@ func GetCPUMem() (memInfo, error) {
 	if r1 == 0 {
 		return memInfo{}, fmt.Errorf("GlobalMemoryStatusEx failed: %w", err)
 	}
-	return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys}, nil
+	return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys, FreeSwap: memStatus.AvailPageFile}, nil
 }

+ 8 - 1
gpu/types.go

@@ -10,6 +10,7 @@ import (
 type memInfo struct {
 	TotalMemory uint64 `json:"total_memory,omitempty"`
 	FreeMemory  uint64 `json:"free_memory,omitempty"`
+	FreeSwap    uint64 `json:"free_swap,omitempty"`
 }
 
 // Beginning of an `ollama info` command
@@ -29,6 +30,11 @@ type GpuInfo struct {
 	// Extra environment variables specific to the GPU as list of [key,value]
 	EnvWorkarounds [][2]string `json:"envs,omitempty"`
 
+	// Set to true if we can NOT reliably discover FreeMemory.  A value of true indicates
+	// the FreeMemory is best effort, and may over or under report actual memory usage
+	// False indicates FreeMemory can generally be trusted on this GPU
+	UnreliableFreeMemory bool
+
 	// GPU information
 	ID      string `json:"gpu_id"`  // string to use for selection of this specific GPU
 	Name    string `json:"name"`    // user friendly name if available
@@ -47,7 +53,8 @@ type CPUInfo struct {
 
 type CudaGPUInfo struct {
 	GpuInfo
-	index int //nolint:unused,nolintlint
+	OSOverhead uint64 // Memory overhead between the driver library and management library
+	index      int    //nolint:unused,nolintlint
 }
 type CudaGPUInfoList []CudaGPUInfo
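
A short illustration of how the new fields might be consumed; the local type and the 10% safety margin are assumptions for the example, not scheduler behavior:

    package main

    import "fmt"

    // Local mirror of the fields added above; illustrative only.
    type cudaGPU struct {
    	FreeMemory           uint64
    	OSOverhead           uint64 // gap between driver- and management-library reports
    	UnreliableFreeMemory bool
    }

    func usableVRAM(g cudaGPU) uint64 {
    	free := g.FreeMemory
    	if g.OSOverhead < free {
    		free -= g.OSOverhead // hold back memory the OS/driver appears to reserve
    	}
    	if g.UnreliableFreeMemory {
    		// A caller may want an extra margin when the number is best effort.
    		free = free * 9 / 10
    	}
    	return free
    }

    func main() {
    	g := cudaGPU{FreeMemory: 8 << 30, OSOverhead: 512 << 20, UnreliableFreeMemory: true}
    	fmt.Printf("plan against %d MiB\n", usableVRAM(g)/(1<<20))
    }
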
 

+ 12 - 13
llm/ext_server/CMakeLists.txt

@@ -1,14 +1,13 @@
-
-set(TARGET ollama_llama_server)
-option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
-include_directories(${CMAKE_CURRENT_SOURCE_DIR})
-add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h)
-install(TARGETS ${TARGET} RUNTIME)
-target_compile_definitions(${TARGET} PRIVATE
-    SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
-)
-target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
-if (WIN32)
-    TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
-endif()
+set(TARGET ollama_llama_server)
+option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
+include_directories(${CMAKE_CURRENT_SOURCE_DIR})
+add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h)
+install(TARGETS ${TARGET} RUNTIME)
+target_compile_definitions(${TARGET} PRIVATE
+    SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
+)
+target_link_libraries(${TARGET} PRIVATE ggml llama common llava ${CMAKE_THREAD_LIBS_INIT})
+if (WIN32)
+    TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
+endif()
 target_compile_features(${TARGET} PRIVATE cxx_std_11)

+ 49 - 10
llm/ext_server/server.cpp

@@ -1382,12 +1382,50 @@ struct llama_server_context
         }
     }
 
+    std::string common_prefix(const std::string& str1, const std::string& str2) {
+        auto mismatch_pair = std::mismatch(str1.begin(), str1.end(), str2.begin());
+        return std::string(str1.begin(), mismatch_pair.first);
+    }
+
+    // Find the slot that has the greatest common prefix
+    server_slot *prefix_slot(const json &prompt) {
+        if (!prompt.is_string()) {
+            return nullptr;
+        }
+
+        std::string prompt_str = prompt.get<std::string>();
+        server_slot *slot = nullptr;
+        size_t longest = 0;
+
+        for (server_slot &s : slots) {
+            if (s.available() && s.prompt.is_string()) {
+                std::string s_prompt = s.prompt.get<std::string>();
+                std::string prefix = common_prefix(s_prompt, prompt_str);
+
+                if (prefix.size() > longest) {
+                    slot = &s;
+                    longest = prefix.size();
+                }
+            }
+        }
+
+        if (!slot) {
+            return get_slot(-1);
+        }
+
+        LOG_DEBUG("slot with common prefix found", {{
+            "slot_id", slot->id,
+            "characters", longest
+        }});
+        return slot;
+    }
+
     void process_single_task(task_server& task)
     {
         switch (task.type)
         {
             case TASK_TYPE_COMPLETION: {
-                server_slot *slot = get_slot(json_value(task.data, "slot_id", -1));
+                server_slot *slot = prefix_slot(task.data["prompt"]);
                 if (slot == nullptr)
                 {
                     // if no slot is available, we defer this task for processing later
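
The slot selection above reduces to a longest-common-prefix search over the prompts cached in the free slots, so a follow-up request can reuse an existing KV cache. A language-agnostic sketch of the idea in Go (types and names are illustrative, not the server's):

    package main

    import "fmt"

    type slot struct {
    	id     int
    	prompt string
    	busy   bool
    }

    func commonPrefixLen(a, b string) int {
    	n := 0
    	for n < len(a) && n < len(b) && a[n] == b[n] {
    		n++
    	}
    	return n
    }

    // pickSlot returns the free slot whose cached prompt shares the longest
    // prefix with the incoming prompt, or -1 if none shares anything.
    func pickSlot(slots []slot, prompt string) int {
    	best, bestLen := -1, 0
    	for i, s := range slots {
    		if s.busy {
    			continue
    		}
    		if l := commonPrefixLen(s.prompt, prompt); l > bestLen {
    			best, bestLen = i, l
    		}
    	}
    	return best
    }

    func main() {
    	slots := []slot{
    		{id: 0, prompt: "You are a helpful assistant. Hi"},
    		{id: 1, prompt: "Translate to French:"},
    	}
    	fmt.Println(pickSlot(slots, "You are a helpful assistant. What is 2+2?")) // 0
    }
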
@@ -1654,22 +1692,23 @@ struct llama_server_context
                     if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx)
                     {
                         const int n_left = slot.n_ctx - slot.params.n_keep;
-                        const int n_block_size = n_left / 2;
-                        const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+                        const int n_shift = n_left / 2;
+                        const int n_erase = slot.n_prompt_tokens - slot.params.n_keep - n_shift;
 
                         std::vector<llama_token> new_tokens(
                             prompt_tokens.begin(),
                             prompt_tokens.begin() + slot.params.n_keep);
                         new_tokens.insert(
                             new_tokens.end(),
-                            prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
+                            prompt_tokens.begin() + slot.params.n_keep + n_erase,
                             prompt_tokens.end());
 
-                        LOG_VERBOSE("input truncated", {
-                            {"n_ctx",      slot.n_ctx},
-                            {"n_keep",     slot.params.n_keep},
-                            {"n_left",     n_left},
-                            {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+                        LOG_INFO("input truncated", {
+                            {"n_ctx",        slot.n_ctx},
+                            {"n_keep",       slot.params.n_keep},
+                            {"n_left",       n_left},
+                            {"n_shift",      n_shift},
+                            {"n_erase",      n_erase},
                         });
                         slot.truncated = true;
                         prompt_tokens = new_tokens;
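
The revised truncation keeps the first n_keep tokens, erases the next n_erase, and retains the tail, so the result always fits within n_keep + n_left/2 tokens. A worked example with made-up sizes:

    package main

    import "fmt"

    // truncate keeps the first nKeep tokens, drops nErase tokens after them,
    // and keeps the rest, mirroring the arithmetic in the hunk above.
    func truncate(tokens []int, nCtx, nKeep int) []int {
    	if len(tokens) < nCtx {
    		return tokens
    	}
    	nLeft := nCtx - nKeep
    	nShift := nLeft / 2
    	nErase := len(tokens) - nKeep - nShift

    	out := append([]int{}, tokens[:nKeep]...)
    	return append(out, tokens[nKeep+nErase:]...)
    }

    func main() {
    	tokens := make([]int, 10) // stand-in prompt of 10 tokens
    	for i := range tokens {
    		tokens[i] = i
    	}
    	// n_ctx=8, n_keep=2: n_left=6, n_shift=3, n_erase=5
    	fmt.Println(truncate(tokens, 8, 2)) // [0 1 7 8 9]
    }
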
@@ -1704,7 +1743,7 @@ struct llama_server_context
                             slot.n_past -= 1;
                         }
 
-                        slot.n_prompt_tokens_processed = slot.n_prompt_tokens - slot.n_past;
+                        slot.n_prompt_tokens_processed = slot.n_prompt_tokens;
 
                         if (slot.ga_n != 1)
                         {

+ 8 - 8
llm/generate/gen_darwin.sh

@@ -18,16 +18,16 @@ sign() {
     fi
 }
 
-COMMON_DARWIN_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DLLAMA_METAL_MACOSX_VERSION_MIN=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DLLAMA_METAL_EMBED_LIBRARY=on -DLLAMA_OPENMP=off"
+COMMON_DARWIN_DEFS="-DBUILD_SHARED_LIBS=off -DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DLLAMA_METAL_MACOSX_VERSION_MIN=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DGGML_METAL_EMBED_LIBRARY=on -DGGML_OPENMP=off"
 
 case "${GOARCH}" in
 "amd64")
-    COMMON_CPU_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=off -DLLAMA_NATIVE=off"
+    COMMON_CPU_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DGGML_METAL=off -DGGML_NATIVE=off"
 
     # Static build for linking into the Go binary
     init_vars
     CMAKE_TARGETS="--target llama --target ggml"
-    CMAKE_DEFS="${COMMON_CPU_DEFS} -DBUILD_SHARED_LIBS=off -DLLAMA_BLAS=off -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+    CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_BLAS=off -DGGML_ACCELERATE=off -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off ${CMAKE_DEFS}"
     BUILD_DIR="../build/darwin/${ARCH}_static"
     echo "Building static library"
     build
@@ -37,7 +37,7 @@ case "${GOARCH}" in
         # CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
         #
         init_vars
-        CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_BLAS=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+        CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_ACCELERATE=off -DGGML_BLAS=off -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off ${CMAKE_DEFS}"
         BUILD_DIR="../build/darwin/${ARCH}/cpu"
         echo "Building LCD CPU"
         build
@@ -49,7 +49,7 @@ case "${GOARCH}" in
         # Approximately 400% faster than LCD on same CPU
         #
         init_vars
-        CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_BLAS=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+        CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_ACCELERATE=off -DGGML_BLAS=off -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off ${CMAKE_DEFS}"
         BUILD_DIR="../build/darwin/${ARCH}/cpu_avx"
         echo "Building AVX CPU"
         build
@@ -61,7 +61,7 @@ case "${GOARCH}" in
         # Approximately 10% faster than AVX on same CPU
         #
         init_vars
-        CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=on -DLLAMA_BLAS=off -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
+        CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_ACCELERATE=on -DGGML_BLAS=off -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on ${CMAKE_DEFS}"
         BUILD_DIR="../build/darwin/${ARCH}/cpu_avx2"
         echo "Building AVX2 CPU"
         EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation"
@@ -75,14 +75,14 @@ case "${GOARCH}" in
     # Static build for linking into the Go binary
     init_vars
     CMAKE_TARGETS="--target llama --target ggml"
-    CMAKE_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DLLAMA_BLAS=off -DCMAKE_SYSTEM_NAME=Darwin -DBUILD_SHARED_LIBS=off -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=off -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+    CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} ${CMAKE_DEFS}"
     BUILD_DIR="../build/darwin/${ARCH}_static"
     echo "Building static library"
     build
 
     if [ -z "$OLLAMA_SKIP_METAL_GENERATE" ]; then
         init_vars
-        CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
+        CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} ${CMAKE_DEFS}"
         BUILD_DIR="../build/darwin/${ARCH}/metal"
         EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders"
         build

+ 19 - 19
llm/generate/gen_linux.sh

@@ -51,7 +51,7 @@ if [ -z "${CUDACXX}" ]; then
         export CUDACXX=$(command -v nvcc)
     fi
 fi
-COMMON_CMAKE_DEFS="-DCMAKE_POSITION_INDEPENDENT_CODE=on -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off -DLLAMA_OPENMP=off"
+COMMON_CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DCMAKE_POSITION_INDEPENDENT_CODE=on -DGGML_NATIVE=off -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_OPENMP=off"
 source $(dirname $0)/gen_common.sh
 init_vars
 git_module_setup
@@ -64,7 +64,7 @@ if [ -z "${OLLAMA_SKIP_STATIC_GENERATE}" -o "${OLLAMA_CPU_TARGET}" = "static" ];
     # Static build for linking into the Go binary
     init_vars
     CMAKE_TARGETS="--target llama --target ggml"
-    CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off -DLLAMA_OPENMP=off ${CMAKE_DEFS}"
+    CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DGGML_NATIVE=off -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_OPENMP=off ${CMAKE_DEFS}"
     BUILD_DIR="../build/linux/${ARCH}_static"
     echo "Building static library"
     build
@@ -77,29 +77,29 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
     if [ -n "${OLLAMA_CUSTOM_CPU_DEFS}" ]; then
         init_vars
         echo "OLLAMA_CUSTOM_CPU_DEFS=\"${OLLAMA_CUSTOM_CPU_DEFS}\""
-        CMAKE_DEFS="${OLLAMA_CUSTOM_CPU_DEFS} -DCMAKE_POSITION_INDEPENDENT_CODE=on ${CMAKE_DEFS}"
+        CMAKE_DEFS="${OLLAMA_CUSTOM_CPU_DEFS} -DBUILD_SHARED_LIBS=off -DCMAKE_POSITION_INDEPENDENT_CODE=on ${CMAKE_DEFS}"
         BUILD_DIR="../build/linux/${ARCH}/cpu"
         echo "Building custom CPU"
         build
         compress
     else
         # Darwin Rosetta x86 emulation does NOT support AVX, AVX2, AVX512
-        # -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
-        # -DLLAMA_F16C -- 2012 Intel Ivy Bridge & AMD 2011 Bulldozer (No significant improvement over just AVX)
-        # -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
-        # -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
+        # -DGGML_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
+        # -DGGML_F16C -- 2012 Intel Ivy Bridge & AMD 2011 Bulldozer (No significant improvement over just AVX)
+        # -DGGML_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
+        # -DGGML_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
         # Note: the following seem to yield slower results than AVX2 - ymmv
-        # -DLLAMA_AVX512 -- 2017 Intel Skylake and High End DeskTop (HEDT)
-        # -DLLAMA_AVX512_VBMI -- 2018 Intel Cannon Lake
-        # -DLLAMA_AVX512_VNNI -- 2021 Intel Alder Lake
+        # -DGGML_AVX512 -- 2017 Intel Skylake and High End DeskTop (HEDT)
+        # -DGGML_AVX512_VBMI -- 2018 Intel Cannon Lake
+        # -DGGML_AVX512_VNNI -- 2021 Intel Alder Lake
 
-        COMMON_CPU_DEFS="-DCMAKE_POSITION_INDEPENDENT_CODE=on -DLLAMA_NATIVE=off -DLLAMA_OPENMP=off"
+        COMMON_CPU_DEFS="-DBUILD_SHARED_LIBS=off -DCMAKE_POSITION_INDEPENDENT_CODE=on -DGGML_NATIVE=off -DGGML_OPENMP=off"
         if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "cpu" ]; then
             #
             # CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
             #
             init_vars
-            CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+            CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off ${CMAKE_DEFS}"
             BUILD_DIR="../build/linux/${ARCH}/cpu"
             echo "Building LCD CPU"
             build
@@ -116,7 +116,7 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
                 # Approximately 400% faster than LCD on same CPU
                 #
                 init_vars
-                CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+                CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off ${CMAKE_DEFS}"
                 BUILD_DIR="../build/linux/${ARCH}/cpu_avx"
                 echo "Building AVX CPU"
                 build
@@ -129,7 +129,7 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
                 # Approximately 10% faster than AVX on same CPU
                 #
                 init_vars
-                CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
+                CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on ${CMAKE_DEFS}"
                 BUILD_DIR="../build/linux/${ARCH}/cpu_avx2"
                 echo "Building AVX2 CPU"
                 build
@@ -170,15 +170,15 @@ if [ -z "${OLLAMA_SKIP_CUDA_GENERATE}" -a -d "${CUDA_LIB_DIR}" ]; then
         #
         # CUDA compute < 6.0 lacks proper FP16 support on ARM.
         # Disabling has minimal performance effect while maintaining compatibility.
-        ARM64_DEFS="-DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_CUDA_F16=off"
+        ARM64_DEFS="-DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_CUDA_F16=off"
     fi
     # Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
     if [ -n "${OLLAMA_CUSTOM_CUDA_DEFS}" ]; then
         echo "OLLAMA_CUSTOM_CUDA_DEFS=\"${OLLAMA_CUSTOM_CUDA_DEFS}\""
-        CMAKE_CUDA_DEFS="-DLLAMA_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}"
+        CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}"
         echo "Building custom CUDA GPU"
     else
-        CMAKE_CUDA_DEFS="-DLLAMA_CUDA=on -DCMAKE_CUDA_FLAGS=-t8 -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}"
+        CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_FLAGS=-t8 -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}"
     fi
     CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS} ${CMAKE_CUDA_DEFS}"
     BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
@@ -216,7 +216,7 @@ if [ -z "${OLLAMA_SKIP_ONEAPI_GENERATE}" -a -d "${ONEAPI_ROOT}" ]; then
     init_vars
     source ${ONEAPI_ROOT}/setvars.sh --force # set up environment variables for oneAPI
     CC=icx
-    CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_SYCL=ON -DLLAMA_SYCL_F16=OFF"
+    CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL=ON -DGGML_SYCL_F16=OFF"
     BUILD_DIR="../build/linux/${ARCH}/oneapi"
     EXTRA_LIBS="-fsycl -Wl,-rpath,${ONEAPI_ROOT}/compiler/latest/lib,-rpath,${ONEAPI_ROOT}/mkl/latest/lib,-rpath,${ONEAPI_ROOT}/tbb/latest/lib,-rpath,${ONEAPI_ROOT}/compiler/latest/opt/oclfpga/linux64/lib -lOpenCL -lmkl_core -lmkl_sycl_blas -lmkl_intel_ilp64 -lmkl_tbb_thread -ltbb"
     DEBUG_FLAGS="" # icx compiles with -O0 if we pass -g, so we must remove it
@@ -254,7 +254,7 @@ if [ -z "${OLLAMA_SKIP_ROCM_GENERATE}" -a -d "${ROCM_PATH}" ]; then
         ROCM_VARIANT=_v$(ls ${ROCM_PATH}/lib/librocblas.so.*.*.????? | cut -f5 -d. || true)
     fi
     init_vars
-    CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
+    CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DGGML_HIPBLAS=on -DLLAMA_CUDA_NO_PEER_COPY=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
     # Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
     if [ -n "${OLLAMA_CUSTOM_ROCM_DEFS}" ]; then
         echo "OLLAMA_CUSTOM_ROCM_DEFS=\"${OLLAMA_CUSTOM_ROCM_DEFS}\""

+ 24 - 33
llm/generate/gen_windows.ps1

@@ -6,18 +6,9 @@ function amdGPUs {
     if ($env:AMDGPU_TARGETS) {
         return $env:AMDGPU_TARGETS
     }
-    # TODO - load from some common data file for linux + windows build consistency
+    # Currently supported rocblas list from ROCm v6.1.2 on Windows
     $GPU_LIST = @(
-        "gfx900"
         "gfx906:xnack-"
-        "gfx908:xnack-"
-        "gfx90a:xnack+"
-        "gfx90a:xnack-"
-        "gfx940"
-        "gfx941"
-        "gfx942"
-        "gfx1010"
-        "gfx1012"
         "gfx1030"
         "gfx1100"
         "gfx1101"
@@ -39,8 +30,8 @@ function init_vars {
     }
     $script:cmakeDefs = @(
         "-DBUILD_SHARED_LIBS=on",
-        "-DLLAMA_NATIVE=off",
-        "-DLLAMA_OPENMP=off"
+        "-DGGML_NATIVE=off",
+        "-DGGML_OPENMP=off"
         )
     $script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
     $script:ARCH = $Env:PROCESSOR_ARCHITECTURE.ToLower()
@@ -182,9 +173,9 @@ function cleanup {
 }
 
 
-# -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
-# -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
-# -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
+# -DGGML_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
+# -DGGML_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
+# -DGGML_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
 
 
 function build_static() {
@@ -204,13 +195,13 @@ function build_static() {
             "-DCMAKE_C_COMPILER=gcc.exe",
             "-DCMAKE_CXX_COMPILER=g++.exe",
             "-DBUILD_SHARED_LIBS=off",
-            "-DLLAMA_NATIVE=off",
-            "-DLLAMA_AVX=off",
-            "-DLLAMA_AVX2=off",
-            "-DLLAMA_AVX512=off",
-            "-DLLAMA_F16C=off",
-            "-DLLAMA_FMA=off",
-            "-DLLAMA_OPENMP=off")
+            "-DGGML_NATIVE=off",
+            "-DGGML_AVX=off",
+            "-DGGML_AVX2=off",
+            "-DGGML_AVX512=off",
+            "-DGGML_F16C=off",
+            "-DGGML_FMA=off",
+            "-DGGML_OPENMP=off")
         $script:buildDir="../build/windows/${script:ARCH}_static"
         write-host "Building static library"
         build
@@ -224,7 +215,7 @@ function build_cpu($gen_arch) {
     if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu"))) {
         # remaining llama.cpp builds use MSVC 
         init_vars
-        $script:cmakeDefs = $script:commonCpuDefs + @("-A", $gen_arch, "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
+        $script:cmakeDefs = $script:commonCpuDefs + @("-A", $gen_arch, "-DGGML_AVX=off", "-DGGML_AVX2=off", "-DGGML_AVX512=off", "-DGGML_FMA=off", "-DGGML_F16C=off") + $script:cmakeDefs
         $script:buildDir="../build/windows/${script:ARCH}/cpu"
         $script:distDir="$script:DIST_BASE\cpu"
         write-host "Building LCD CPU"
@@ -239,7 +230,7 @@ function build_cpu($gen_arch) {
 function build_cpu_avx() {
     if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx"))) {
         init_vars
-        $script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
+        $script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DGGML_AVX=on", "-DGGML_AVX2=off", "-DGGML_AVX512=off", "-DGGML_FMA=off", "-DGGML_F16C=off") + $script:cmakeDefs
         $script:buildDir="../build/windows/${script:ARCH}/cpu_avx"
         $script:distDir="$script:DIST_BASE\cpu_avx"
         write-host "Building AVX CPU"
@@ -254,7 +245,7 @@ function build_cpu_avx() {
 function build_cpu_avx2() {
     if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx2"))) {
         init_vars
-        $script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
+        $script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DGGML_AVX=on", "-DGGML_AVX2=on", "-DGGML_AVX512=off", "-DGGML_FMA=on", "-DGGML_F16C=on") + $script:cmakeDefs
         $script:buildDir="../build/windows/${script:ARCH}/cpu_avx2"
         $script:distDir="$script:DIST_BASE\cpu_avx2"
         write-host "Building AVX2 CPU"
@@ -279,9 +270,9 @@ function build_cuda() {
         $script:distDir="$script:DIST_BASE\cuda$script:CUDA_VARIANT"
         $script:cmakeDefs += @(
             "-A", "x64",
-            "-DLLAMA_CUDA=ON",
-            "-DLLAMA_AVX=on",
-            "-DLLAMA_AVX2=off",
+            "-DGGML_CUDA=ON",
+            "-DGGML_AVX=on",
+            "-DGGML_AVX2=off",
             "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR",
             "-DCMAKE_CUDA_FLAGS=-t8",
             "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}"
@@ -319,7 +310,7 @@ function build_oneapi() {
     $script:distDir ="$script:DIST_BASE\oneapi$script:ONEAPI_VARIANT"
     $script:cmakeDefs += @(
       "-G", "MinGW Makefiles",
-      "-DLLAMA_SYCL=ON",
+      "-DGGML_SYCL=ON",
       "-DCMAKE_C_COMPILER=icx",
       "-DCMAKE_CXX_COMPILER=icx",
       "-DCMAKE_BUILD_TYPE=Release"
@@ -365,10 +356,11 @@ function build_rocm() {
             "-G", "Ninja", 
             "-DCMAKE_C_COMPILER=clang.exe",
             "-DCMAKE_CXX_COMPILER=clang++.exe",
-            "-DLLAMA_HIPBLAS=on",
+            "-DGGML_HIPBLAS=on",
+            "-DLLAMA_CUDA_NO_PEER_COPY=on",
             "-DHIP_PLATFORM=amd",
-            "-DLLAMA_AVX=on",
-            "-DLLAMA_AVX2=off",
+            "-DGGML_AVX=on",
+            "-DGGML_AVX2=off",
             "-DCMAKE_POSITION_INDEPENDENT_CODE=on",
             "-DAMDGPU_TARGETS=$(amdGPUs)",
             "-DGPU_TARGETS=$(amdGPUs)"
@@ -394,7 +386,6 @@ function build_rocm() {
         sign
         install
 
-        # Assumes v5.7, may need adjustments for v6
         rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
         md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null
         cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"

+ 11 - 2
llm/ggla.go

@@ -53,7 +53,7 @@ func (llm *ggla) Tensors() Tensors {
 	return llm.tensors
 }
 
-func (llm *ggla) decode(rs io.ReadSeeker) error {
+func (llm *ggla) decode(rs io.ReadSeeker) (retErr error) {
 	var r uint32
 	if err := binary.Read(rs, binary.LittleEndian, &r); err != nil {
 		return err
@@ -69,9 +69,18 @@ func (llm *ggla) decode(rs io.ReadSeeker) error {
 	for {
 		var dims uint32
 		if err := binary.Read(rs, binary.LittleEndian, &dims); err != nil {
+			if errors.Is(err, io.EOF) {
+				return nil
+			}
 			return err
 		}
 
+		defer func() {
+			if errors.Is(retErr, io.EOF) {
+				retErr = io.ErrUnexpectedEOF
+			}
+		}()
+
 		var namesize uint32
 		if err := binary.Read(rs, binary.LittleEndian, &namesize); err != nil {
 			return err
@@ -108,7 +117,7 @@ func (llm *ggla) decode(rs io.ReadSeeker) error {
 			return err
 		}
 
-		if _, err := rs.Seek((offset+31)&-32, io.SeekStart); err != nil {
+		if _, err := rs.Seek((offset+31)&-32-offset, io.SeekCurrent); err != nil {
 			return err
 		}
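
The seek fix is relative rather than absolute: the reader already sits at offset, so it only needs to skip the padding up to the next 32-byte boundary. A quick check of the arithmetic with arbitrary offsets:

    package main

    import "fmt"

    // pad32 returns how many bytes to skip from offset to reach the next
    // 32-byte boundary, i.e. ((offset+31) & -32) - offset.
    func pad32(offset int64) int64 {
    	return ((offset + 31) & -32) - offset
    }

    func main() {
    	for _, off := range []int64{0, 1, 31, 32, 100} {
    		fmt.Printf("offset %3d -> skip %2d -> aligned %3d\n", off, pad32(off), off+pad32(off))
    	}
    }
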
 

+ 89 - 17
llm/ggml.go

@@ -6,6 +6,8 @@ import (
 	"fmt"
 	"io"
 	"strings"
+
+	"github.com/ollama/ollama/util/bufioutil"
 )
 
 type GGML struct {
@@ -69,6 +71,30 @@ func (kv KV) HeadCountKV() uint64 {
 	return 1
 }
 
+func (kv KV) EmbeddingHeadCount() uint64 {
+	if heads := kv.HeadCount(); heads > 0 {
+		return kv.EmbeddingLength() / kv.HeadCount()
+	}
+
+	return 0
+}
+
+func (kv KV) EmbeddingHeadCountK() uint64 {
+	if k := kv.u64(fmt.Sprintf("%s.attention.key_length", kv.Architecture())); k > 0 {
+		return k
+	}
+
+	return kv.EmbeddingHeadCount()
+}
+
+func (kv KV) EmbeddingHeadCountV() uint64 {
+	if v := kv.u64(fmt.Sprintf("%s.attention.value_length", kv.Architecture())); v > 0 {
+		return v
+	}
+
+	return kv.EmbeddingHeadCount()
+}
+
 func (kv KV) GQA() uint64 {
 	return kv.HeadCount() / kv.HeadCountKV()
 }
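
EmbeddingHeadCountK/V fall back to embedding_length / head_count when a model declares no explicit key or value head size. With illustrative hyperparameters (not taken from a specific model file) the numbers work out as follows:

    package main

    import "fmt"

    func main() {
    	// Illustrative hyperparameters; not from a real GGUF file.
    	const (
    		embeddingLength = 4096
    		headCount       = 32
    		headCountKV     = 8
    	)
    	embeddingHeads := embeddingLength / headCount // 128, used when key/value lengths are absent
    	gqa := headCount / headCountKV                // 4
    	fmt.Println(embeddingHeads, gqa)
    }
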
@@ -254,7 +280,18 @@ func DetectGGMLType(b []byte) string {
 	}
 }
 
-func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
+// DecodeGGML decodes a GGML model from the given reader.
+//
+// It collects array values for arrays with a size less than or equal to
+// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
+// the maxArraySize is negative, all arrays are collected.
+func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
+	if maxArraySize == 0 {
+		maxArraySize = 1024
+	}
+
+	rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
+
 	var magic uint32
 	if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
 		return nil, 0, err
@@ -267,17 +304,15 @@ func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
 	case FILE_MAGIC_GGLA:
 		c = &containerGGLA{}
 	case FILE_MAGIC_GGUF_LE:
-		c = &containerGGUF{ByteOrder: binary.LittleEndian}
+		c = &containerGGUF{ByteOrder: binary.LittleEndian, maxArraySize: maxArraySize}
 	case FILE_MAGIC_GGUF_BE:
-		c = &containerGGUF{ByteOrder: binary.BigEndian}
+		c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
 	default:
 		return nil, 0, errors.New("invalid file magic")
 	}
 
 	model, err := c.Decode(rs)
-	if errors.Is(err, io.EOF) {
-		// noop
-	} else if err != nil {
+	if err != nil {
 		return nil, 0, err
 	}
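
A minimal caller sketch for the new signature, assuming the import path github.com/ollama/ollama/llm and a hypothetical model path:

    package main

    import (
    	"fmt"
    	"log"
    	"os"

    	"github.com/ollama/ollama/llm"
    )

    func main() {
    	f, err := os.Open("/path/to/model.gguf") // hypothetical path
    	if err != nil {
    		log.Fatal(err)
    	}
    	defer f.Close()

    	// 0 => default cap of 1024 elements per collected array; pass a negative
    	// value to collect everything (e.g. when the token list itself is needed).
    	ggml, n, err := llm.DecodeGGML(f, 0)
    	if err != nil {
    		log.Fatal(err)
    	}
    	fmt.Println("architecture:", ggml.KV().Architecture(), "bytes read:", n)
    }
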
 
@@ -297,7 +332,10 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 	embedding := llm.KV().EmbeddingLength()
 	heads := llm.KV().HeadCount()
 	headsKV := llm.KV().HeadCountKV()
-	vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any)))
+	vocab := uint64(llm.KV()["tokenizer.ggml.tokens"].(*array).size)
+
+	embeddingHeads := llm.KV().EmbeddingHeadCount()
+	embeddingHeadsK := llm.KV().EmbeddingHeadCountK()
 
 	layers := llm.Tensors().Layers()
 
@@ -308,7 +346,7 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 		partialOffload = 4 * batch * embedding
 		partialOffload += max(
 			// 4*batch*(4+6*embedding+context*(2*heads)+llm.KV().GQA()),
-			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
+			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embeddingHeads*headsKV),
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
 		)
 
@@ -316,21 +354,30 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 			// mixtral 8x22b
 			ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
 			partialOffload = max(
-				3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embedding/heads*headsKV),
-				4*(context*batch*heads+context*embedding/heads*headsKV+batch*1024+embedding/heads*headsKV*batch),
+				3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
+				4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
 			)
 		} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
 			// mixtral 8x7b
 			ffnGateWeight1 := ffnGateWeight.Shape[1]
 			fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
 			partialOffload = max(
-				4*batch*(3+embedding/heads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
+				4*batch*(3+embeddingHeads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
 				4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
 			)
 		}
-	case "gemma":
-		fullOffload = 4 * batch * (embedding + vocab)
-		partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128
+	case "gemma", "gemma2":
+		fullOffload = max(
+			4*batch*(embedding+vocab),
+			4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
+		)
+
+		partialOffload = max(
+			4*embedding*batch+embedding*vocab*105/128+4*vocab*batch,
+			4*batch*(2*embedding+1+2*embeddingHeadsK*heads+context+context*heads)+
+				4*embeddingHeadsK*context*8+
+				embedding*embeddingHeadsK*heads*9/16,
+		)
 	case "command-r":
 		fullOffload = max(
 			4*batch*(embedding+vocab),
@@ -368,16 +415,41 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 			fullOffload,
 		)
 	case "deepseek2":
-		keys := uint64(llm.KV()["deepseek2.attention.key_length"].(uint32))
 		fullOffload = max(
 			4*batch*(3*embedding+vocab),
-			4*batch*(3*embedding+2+context*(1+headsKV)+2*keys*headsKV),
+			4*batch*(3*embedding+2+context*(1+headsKV)+2*embeddingHeadsK*headsKV),
 		)
 
 		partialOffload = max(
 			4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
-			4*batch*(2*embedding+1+2*keys*headsKV+context+context*headsKV)+4*keys*context*headsKV+embedding*keys*headsKV*9/16,
+			4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
 		)
+	case "chatglm":
+		fullOffload = 4 * batch * (embedding + vocab)
+		partialOffload = 4*batch*(embedding+vocab) + embedding*vocab*105/128
+		if qkvBias, ok := layers["blk.0"]["attn_qkv.bias"]; ok {
+			fullOffload = max(
+				fullOffload,
+				4*batch*(2+
+					2*embedding+
+					context+
+					context*heads+
+					embeddingHeadsK*heads+
+					qkvBias.Shape[0]),
+			)
+
+			partialOffload = max(
+				partialOffload,
+				4*batch*(1+
+					2*embedding+
+					embeddingHeadsK*heads+
+					context+
+					context*heads)+
+					4*embeddingHeadsK*context+
+					4*context*embeddingHeadsK+
+					4*qkvBias.Shape[0],
+			)
+		}
 	}
 
 	return

+ 1 - 0
llm/ggml_test.go

@@ -0,0 +1 @@
+package llm

+ 92 - 38
llm/gguf.go

@@ -3,11 +3,10 @@ package llm
 import (
 	"bytes"
 	"encoding/binary"
+	"encoding/json"
 	"fmt"
 	"io"
 	"strings"
-
-	"log/slog"
 )
 
 type containerGGUF struct {
@@ -29,6 +28,12 @@ type containerGGUF struct {
 		NumTensor uint64
 		NumKV     uint64
 	}
+
+	maxArraySize int
+}
+
+func (c *containerGGUF) canCollectArray(size int) bool {
+	return c.maxArraySize < 0 || size <= c.maxArraySize
 }
 
 func (c *containerGGUF) Name() string {
@@ -54,7 +59,6 @@ func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) {
 	}
 
 	model := newGGUF(c)
-	slog.Debug(fmt.Sprintf("model = %#v", model))
 	if err := model.Decode(rs); err != nil {
 		return nil, err
 	}
@@ -85,6 +89,8 @@ type gguf struct {
 	tensors []*Tensor
 
 	parameters uint64
+
+	scratch [16 << 10]byte
 }
 
 func newGGUF(container *containerGGUF) *gguf {
@@ -181,34 +187,34 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
 	}
 
 	// decode tensors
-	for i := 0; uint64(i) < llm.numTensor(); i++ {
+	for range llm.numTensor() {
 		name, err := readGGUFString(llm, rs)
 		if err != nil {
-			return err
+			return fmt.Errorf("failed to read tensor name: %w", err)
 		}
 
 		// dims is the number of dimensions in the tensor
 		dims, err := readGGUF[uint32](llm, rs)
 		if err != nil {
-			return err
+			return fmt.Errorf("failed to read tensor dimensions: %w", err)
 		}
 
 		shape := [4]uint64{1, 1, 1, 1}
 		for i := 0; uint32(i) < dims; i++ {
 			shape[i], err = readGGUF[uint64](llm, rs)
 			if err != nil {
-				return err
+				return fmt.Errorf("failed to read tensor shape: %w", err)
 			}
 		}
 
 		kind, err := readGGUF[uint32](llm, rs)
 		if err != nil {
-			return err
+			return fmt.Errorf("failed to read tensor kind: %w", err)
 		}
 
 		offset, err := readGGUF[uint64](llm, rs)
 		if err != nil {
-			return err
+			return fmt.Errorf("failed to read tensor offset: %w", err)
 		}
 
 		tensor := Tensor{
@@ -230,24 +236,19 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
 		alignment = 32
 	}
 
-	offset, err := rs.Seek(0, io.SeekCurrent)
-	if err != nil {
-		return err
-	}
-
-	padding := llm.padding(offset, int64(alignment))
-	if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
-		return err
-	}
-
 	for _, tensor := range llm.tensors {
-		if _, err := rs.Seek(int64(tensor.Size()), io.SeekCurrent); err != nil {
-			return err
+		offset, err := rs.Seek(0, io.SeekCurrent)
+		if err != nil {
+			return fmt.Errorf("failed to get current offset: %w", err)
 		}
 
-		padding := llm.padding(int64(tensor.Size()), int64(alignment))
+		padding := llm.padding(offset, int64(alignment))
 		if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
-			return err
+			return fmt.Errorf("failed to seek to init padding: %w", err)
+		}
+
+		if _, err := rs.Seek(int64(tensor.Size()), io.SeekCurrent); err != nil {
+			return fmt.Errorf("failed to seek to tensor: %w", err)
 		}
 	}
 
@@ -285,22 +286,48 @@ func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
 	return b.String(), nil
 }
 
+func discardGGUFString(llm *gguf, r io.Reader) error {
+	buf := llm.scratch[:8]
+	_, err := io.ReadFull(r, buf)
+	if err != nil {
+		return err
+	}
+
+	size := int(llm.ByteOrder.Uint64(buf))
+	for size > 0 {
+		n, err := r.Read(llm.scratch[:min(size, cap(llm.scratch))])
+		if err != nil {
+			return err
+		}
+		size -= n
+	}
+	return nil
+}
+
 func readGGUFString(llm *gguf, r io.Reader) (string, error) {
 	if llm.Version == 1 {
 		return readGGUFV1String(llm, r)
 	}
 
-	var length uint64
-	if err := binary.Read(r, llm.ByteOrder, &length); err != nil {
+	buf := llm.scratch[:8]
+	_, err := io.ReadFull(r, buf)
+	if err != nil {
 		return "", err
 	}
 
-	var b bytes.Buffer
-	if _, err := io.CopyN(&b, r, int64(length)); err != nil {
-		return "", err
+	length := int(llm.ByteOrder.Uint64(buf))
+	if length > len(llm.scratch) {
+		buf = make([]byte, length)
+	} else {
+		buf = llm.scratch[:length]
 	}
+	clear(buf)
 
-	return b.String(), nil
+	_, err = io.ReadFull(r, buf)
+	if err != nil {
+		return "", err
+	}
+	return string(buf), nil
 }
 
 func writeGGUFString(llm *gguf, w io.Writer, s string) error {
@@ -316,7 +343,16 @@ func writeGGUFString(llm *gguf, w io.Writer, s string) error {
 	return err
 }
 
-func readGGUFV1Array(llm *gguf, r io.Reader) (a []any, err error) {
+type array struct {
+	size   int
+	values []any
+}
+
+func (a *array) MarshalJSON() ([]byte, error) {
+	return json.Marshal(a.values)
+}
+
+func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
 	t, err := readGGUF[uint32](llm, r)
 	if err != nil {
 		return nil, err
@@ -327,7 +363,12 @@ func readGGUFV1Array(llm *gguf, r io.Reader) (a []any, err error) {
 		return nil, err
 	}
 
-	for i := 0; uint32(i) < n; i++ {
+	a := &array{size: int(n)}
+	if llm.canCollectArray(int(n)) {
+		a.values = make([]any, 0, int(n))
+	}
+
+	for i := range n {
 		var e any
 		switch t {
 		case ggufTypeUint8:
@@ -361,13 +402,15 @@ func readGGUFV1Array(llm *gguf, r io.Reader) (a []any, err error) {
 			return nil, err
 		}
 
-		a = append(a, e)
+		if a.values != nil {
+			a.values[i] = e
+		}
 	}
 
-	return
+	return a, nil
 }
 
-func readGGUFArray(llm *gguf, r io.Reader) (a []any, err error) {
+func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
 	if llm.Version == 1 {
 		return readGGUFV1Array(llm, r)
 	}
@@ -382,7 +425,12 @@ func readGGUFArray(llm *gguf, r io.Reader) (a []any, err error) {
 		return nil, err
 	}
 
-	for i := 0; uint64(i) < n; i++ {
+	a := &array{size: int(n)}
+	if llm.canCollectArray(int(n)) {
+		a.values = make([]any, int(n))
+	}
+
+	for i := range n {
 		var e any
 		switch t {
 		case ggufTypeUint8:
@@ -408,7 +456,11 @@ func readGGUFArray(llm *gguf, r io.Reader) (a []any, err error) {
 		case ggufTypeBool:
 			e, err = readGGUF[bool](llm, r)
 		case ggufTypeString:
-			e, err = readGGUFString(llm, r)
+			if a.values != nil {
+				e, err = readGGUFString(llm, r)
+			} else {
+				err = discardGGUFString(llm, r)
+			}
 		default:
 			return nil, fmt.Errorf("invalid array type: %d", t)
 		}
@@ -416,10 +468,12 @@ func readGGUFArray(llm *gguf, r io.Reader) (a []any, err error) {
 			return nil, err
 		}
 
-		a = append(a, e)
+		if a.values != nil {
+			a.values[i] = e
+		}
 	}
 
-	return
+	return a, nil
 }
 
 func writeGGUFArray[S ~[]E, E any](llm *gguf, w io.Writer, t uint32, s S) error {
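
The net effect of maxArraySize: oversized metadata arrays (token lists, merges) still report their size, but their elements are read and dropped rather than materialized. A condensed sketch of the collect-or-discard decision (stand-in types; next is a placeholder for the per-element reader):

    package main

    import "fmt"

    // Illustrative stand-in for the array bookkeeping above.
    type array struct {
    	size   int
    	values []any
    }

    func readArray(size, maxArraySize int, next func(i int) any) *array {
    	a := &array{size: size}
    	if maxArraySize < 0 || size <= maxArraySize {
    		a.values = make([]any, size)
    	}
    	for i := 0; i < size; i++ {
    		v := next(i) // in the real decoder this reads (or discards) one element
    		if a.values != nil {
    			a.values[i] = v
    		}
    	}
    	return a
    }

    func main() {
    	big := readArray(100000, 1024, func(i int) any { return i })
    	small := readArray(3, 1024, func(i int) any { return i })
    	fmt.Println(big.size, big.values == nil)   // 100000 true  (size kept, values dropped)
    	fmt.Println(small.size, len(small.values)) // 3 3
    }
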

+ 1 - 1
llm/llama.cpp

@@ -1 +1 @@
-Subproject commit 7c26775adb579e92b59c82e8084c07a1d0f75e9c
+Subproject commit a8db2a9ce64cd4417f6a312ab61858f17f0f8584

+ 9 - 8
llm/llm.go

@@ -1,12 +1,13 @@
 package llm
 
-// #cgo CFLAGS: -Illama.cpp
-// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/build/darwin/arm64_static/libllama.a -lstdc++
-// #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/build/darwin/x86_64_static/libllama.a -lstdc++
-// #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++
-// #cgo windows,arm64 LDFLAGS: ${SRCDIR}/build/windows/arm64_static/libllama.a -static -lstdc++
-// #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++
-// #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++
+// #cgo CFLAGS: -Illama.cpp -Illama.cpp/include -Illama.cpp/ggml/include
+// #cgo LDFLAGS: -lllama -lggml -lstdc++ -lpthread
+// #cgo darwin,arm64 LDFLAGS: -L${SRCDIR}/build/darwin/arm64_static -L${SRCDIR}/build/darwin/arm64_static/src -L${SRCDIR}/build/darwin/arm64_static/ggml/src -framework Accelerate -framework Metal
+// #cgo darwin,amd64 LDFLAGS: -L${SRCDIR}/build/darwin/x86_64_static -L${SRCDIR}/build/darwin/x86_64_static/src -L${SRCDIR}/build/darwin/x86_64_static/ggml/src
+// #cgo windows,amd64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/amd64_static -L${SRCDIR}/build/windows/amd64_static/src -L${SRCDIR}/build/windows/amd64_static/ggml/src
+// #cgo windows,arm64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/arm64_static -L${SRCDIR}/build/windows/arm64_static/src -L${SRCDIR}/build/windows/arm64_static/ggml/src
+// #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/linux/x86_64_static -L${SRCDIR}/build/linux/x86_64_static/src -L${SRCDIR}/build/linux/x86_64_static/ggml/src
+// #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
 // #include <stdlib.h>
 // #include "llama.h"
 import "C"
@@ -32,7 +33,7 @@ func Quantize(infile, outfile string, ftype fileType) error {
 	params.ftype = ftype.Value()
 
 	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
-		return fmt.Errorf("llama_model_quantize: %d", rc)
+		return fmt.Errorf("failed to quantize model. This model architecture may not be supported, or you may need to upgrade Ollama to the latest version")
 	}
 
 	return nil

+ 2 - 2
llm/memory.go

@@ -115,8 +115,8 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
 		slog.Warn("model missing blk.0 layer size")
 	}
 
-	// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
-	var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
+	// fp16 k,v = sizeof(float16) * n_ctx * n_layer * (n_embd_head_k + n_embd_head_v) * n_head_kv
+	var kv uint64 = 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * (ggml.KV().EmbeddingHeadCountK() + ggml.KV().EmbeddingHeadCountV()) * ggml.KV().HeadCountKV()
 
 	// KV is proportional to the number of layers
 	layerSize += kv / ggml.KV().BlockCount()
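
The corrected estimate sizes the fp16 K/V cache from the per-head key and value dimensions rather than embedding_length / head_count, which matters for models where those differ. A worked example with illustrative numbers: 2 bytes × 2048 ctx × 32 layers × (128 + 128) × 8 kv-heads = 256 MiB.

    package main

    import "fmt"

    func main() {
    	// Illustrative hyperparameters (roughly llama-style, not from a real file).
    	const (
    		bytesPerElem = 2 // fp16
    		nCtx         = 2048
    		nLayer       = 32
    		headDimK     = 128 // n_embd_head_k
    		headDimV     = 128 // n_embd_head_v
    		nHeadKV      = 8
    	)
    	kv := uint64(bytesPerElem) * nCtx * nLayer * (headDimK + headDimV) * nHeadKV
    	fmt.Printf("KV cache: %d bytes (%.0f MiB), %d bytes per layer\n",
    		kv, float64(kv)/(1<<20), kv/nLayer)
    }
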

+ 11 - 8
llm/memory_test.go

@@ -22,13 +22,14 @@ func TestEstimateGPULayers(t *testing.T) {
 	defer f.Close()
 	gguf := NewGGUFV3(binary.LittleEndian)
 	inputLayerCount := 5
+
 	tensors := []Tensor{
-		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
-		{Name: "blk.1.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
-		{Name: "blk.2.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
-		{Name: "blk.3.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
-		{Name: "blk.4.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
-		{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
+		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
+		{Name: "blk.1.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
+		{Name: "blk.2.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
+		{Name: "blk.3.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
+		{Name: "blk.4.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
+		{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
 	}
 	assert.Len(t, tensors, inputLayerCount+1)
 	err = gguf.Encode(f, KV{
@@ -45,8 +46,10 @@ func TestEstimateGPULayers(t *testing.T) {
 	}, tensors)
 	require.NoError(t, err)
 
-	ggml, err := LoadModel(f.Name())
-	require.NoError(t, err)
+	ggml, err := LoadModel(f.Name(), 0)
+	if err != nil {
+		t.Fatal(err)
+	}
 
 	// Simple CPU scenario
 	gpus := []gpu.GpuInfo{

+ 7 - 7
llm/patches/01-load-progress.diff

@@ -1,8 +1,8 @@
 diff --git a/common/common.cpp b/common/common.cpp
-index 73ff0e85..6adb1a92 100644
+index 2c05a4d4..927f0e3d 100644
 --- a/common/common.cpp
 +++ b/common/common.cpp
-@@ -2447,6 +2447,8 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
+@@ -2093,6 +2093,8 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
      mparams.use_mmap        = params.use_mmap;
      mparams.use_mlock       = params.use_mlock;
      mparams.check_tensors   = params.check_tensors;
@@ -12,10 +12,10 @@ index 73ff0e85..6adb1a92 100644
          mparams.kv_overrides = NULL;
      } else {
 diff --git a/common/common.h b/common/common.h
-index 58ed72f4..0bb2605e 100644
+index 65c0ef81..ebca2c77 100644
 --- a/common/common.h
 +++ b/common/common.h
-@@ -180,6 +180,13 @@ struct gpt_params {
+@@ -184,6 +184,13 @@ struct gpt_params {
      std::string mmproj = "";        // path to multimodal projector
      std::vector<std::string> image; // path to image file(s)
  
@@ -26,6 +26,6 @@ index 58ed72f4..0bb2605e 100644
 +    // context pointer passed to the progress callback
 +    void * progress_callback_user_data;
 +
-     // server params
-     int32_t port           = 8080;         // server listens on this network port
-     int32_t timeout_read   = 600;          // http read timeout in seconds
+     // embedding
+     bool embedding         = false; // get only sentence embedding
+     int32_t embd_normalize = 2;     // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)

+ 6 - 18
llm/patches/03-load_exception.diff

@@ -1,17 +1,8 @@
-From 544a2d2e646d39e878d87dfbb3398a356bc560ab Mon Sep 17 00:00:00 2001
-From: Michael Yang <mxyng@pm.me>
-Date: Thu, 23 May 2024 11:18:45 -0700
-Subject: [PATCH] throw exception on load errors
-
----
- llama.cpp | 25 ++++++++++++++++---------
- 1 file changed, 16 insertions(+), 9 deletions(-)
-
-diff --git a/llama.cpp b/llama.cpp
-index 15c66077..8ba90b6a 100644
---- a/llama.cpp
-+++ b/llama.cpp
-@@ -6346,7 +6346,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
+diff --git a/src/llama.cpp b/src/llama.cpp
+index 73f52435..58a00fb1 100644
+--- a/src/llama.cpp
++++ b/src/llama.cpp
+@@ -7241,7 +7241,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
          }
      } catch (const std::exception & err) {
          LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
@@ -20,7 +11,7 @@ index 15c66077..8ba90b6a 100644
      }
  
      return 0;
-@@ -15600,16 +15600,23 @@ struct llama_model * llama_load_model_from_file(
+@@ -17564,16 +17564,23 @@ struct llama_model * llama_load_model_from_file(
          }
          model->rpc_servers.push_back(servers);
      }
@@ -52,6 +43,3 @@ index 15c66077..8ba90b6a 100644
      }
  
      return model;
--- 
-2.45.1
-

+ 3 - 3
llm/patches/04-metal.diff

@@ -1,7 +1,7 @@
-diff --git a/ggml-metal.m b/ggml-metal.m
+diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
 index 0207b787..b5e9884b 100644
---- a/ggml-metal.m
-+++ b/ggml-metal.m
+--- a/ggml/src/ggml-metal.m
++++ b/ggml/src/ggml-metal.m
 @@ -1396,27 +1396,23 @@ static enum ggml_status ggml_metal_graph_compute(
                          // to the matrix-vector kernel
                          int ne11_mm_min = 1;

+ 11 - 11
llm/patches/05-default-pretokenizer.diff

@@ -1,11 +1,11 @@
-diff --git a/llama.cpp b/llama.cpp
-index 61948751..4b72a293 100644
---- a/llama.cpp
-+++ b/llama.cpp
-@@ -4824,16 +4824,7 @@ static void llm_load_vocab(
- 
-         // for now, only BPE models have pre-tokenizers
+diff --git a/src/llama.cpp b/src/llama.cpp
+index 2b9ace28..172640e2 100644
+--- a/src/llama.cpp
++++ b/src/llama.cpp
+@@ -5357,16 +5357,7 @@ static void llm_load_vocab(
          if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
+             vocab.tokenizer_add_space_prefix = false;
+             vocab.tokenizer_clean_spaces = true;
 -            if (tokenizer_pre.empty()) {
 -                LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
 -                LLAMA_LOG_WARN("%s:                                             \n", __func__);
@@ -20,13 +20,13 @@ index 61948751..4b72a293 100644
                  vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
              } else if (
                      tokenizer_pre == "llama3"   ||
-@@ -4888,7 +4879,8 @@ static void llm_load_vocab(
-                 tokenizer_pre == "poro-chat") {
-                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO;
+@@ -5439,7 +5430,8 @@ static void llm_load_vocab(
+                 tokenizer_pre == "jais") {
+                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
              } else {
 -                throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
 +                LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
 +                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
              }
-         } else {
+         } else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
              vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;

+ 3 - 3
llm/patches/06-qwen2.diff

@@ -1,7 +1,7 @@
-diff --git a/llama.cpp b/llama.cpp
+diff --git a/src/llama.cpp b/src/llama.cpp
 index 40d2ec2c..f34eb79a 100644
---- a/llama.cpp
-+++ b/llama.cpp
+--- a/src/llama.cpp
++++ b/src/llama.cpp
 @@ -6943,7 +6943,7 @@ static struct ggml_tensor * llm_build_kqv(
          struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
          cb(kq, "kq", il);

+ 45 - 0
llm/patches/07-embeddings.diff

@@ -0,0 +1,45 @@
+diff --git a/src/llama.cpp b/src/llama.cpp
+index 1fe2b9f7..a43312a7 100644
+--- a/src/llama.cpp
++++ b/src/llama.cpp
+@@ -13689,7 +13689,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
+     const auto n_embd  = hparams.n_embd;
+ 
+     // TODO: use a per-batch flag for logits presence instead
+-    const bool has_logits = !cparams.embeddings;
++    const bool has_logits =  cparams.causal_attn;
+     const bool has_embd   =  lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE));
+ 
+     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
+@@ -13959,17 +13959,25 @@ static int llama_decode_internal(
+             // no output
+             res  = nullptr;
+             embd = nullptr;
+-        } else if (cparams.embeddings) {
+-            res = nullptr; // do not extract logits for embedding case
+-            embd = gf->nodes[gf->n_nodes - 1];
+-            if (strcmp(embd->name, "result_embd_pooled") != 0) {
+-                embd = gf->nodes[gf->n_nodes - 2];
++        }
++
++        if (cparams.embeddings) {
++            for (int i = gf->n_nodes - 1; i >= 0; --i) {
++                embd = gf->nodes[i];
++                if (strcmp(embd->name, "result_embd_pooled") == 0) {
++                    break;
++                }
+             }
+             GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
+-        } else {
++         } else {
+             embd = nullptr; // do not extract embeddings when not needed
+             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
+         }
++
++        if (!cparams.causal_attn) {
++            res = nullptr; // do not extract logits when not needed
++        }
++
+         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+ 
+         ggml_backend_sched_alloc_graph(lctx.sched, gf);

+ 42 - 0
llm/patches/08-clip-unicode.diff

@@ -0,0 +1,42 @@
+diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
+index 95fbe3d0..5a02a6ec 100644
+--- a/examples/llava/clip.cpp
++++ b/examples/llava/clip.cpp
+@@ -32,6 +33,14 @@
+ #include <cinttypes>
+ #include <limits>
+ 
++#if defined(_WIN32)
++#define WIN32_LEAN_AND_MEAN
++#ifndef NOMINMAX
++    #define NOMINMAX
++#endif
++#include <windows.h>
++#endif
++
+ //#define CLIP_DEBUG_FUNCTIONS
+ 
+ // RGB uint8 image
+@@ -1055,7 +1064,22 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
+             return nullptr;
+         }
+ 
++#ifdef _WIN32
++        int wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, NULL, 0);
++        if (!wlen) {
++            return NULL;
++        }
++        wchar_t * wbuf = (wchar_t *) malloc(wlen * sizeof(wchar_t));
++        wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, wbuf, wlen);
++        if (!wlen) {
++            free(wbuf);
++            return NULL;
++        }
++        auto fin = std::ifstream(wbuf, std::ios::binary);
++        free(wbuf);
++#else
+         auto fin = std::ifstream(fname, std::ios::binary);
++#endif
+         if (!fin) {
+             LOG_TEE("cannot open model file for loading tensors\n");
+             clip_free(new_clip);

+ 60 - 0
llm/patches/09-pooling.diff

@@ -0,0 +1,60 @@
+diff --git a/src/llama.cpp b/src/llama.cpp
+index 721b8f4e..cfe7ac40 100644
+--- a/src/llama.cpp
++++ b/src/llama.cpp
+@@ -8420,14 +8420,14 @@ struct llm_build_context {
+     }
+ 
+     struct ggml_tensor * build_inp_mean() {
+-        lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
++        lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, cparams.n_seq_max);
+         cb(lctx.inp_mean, "inp_mean", -1);
+         ggml_set_input(lctx.inp_mean);
+         return lctx.inp_mean;
+     }
+ 
+     struct ggml_tensor * build_inp_cls() {
+-        lctx.inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
++        lctx.inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_seq_max);
+         cb(lctx.inp_cls, "inp_cls", -1);
+         ggml_set_input(lctx.inp_cls);
+         return lctx.inp_cls;
+@@ -13847,19 +13847,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
+         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
+ 
+         float * data = (float *) lctx.inp_mean->data;
+-        memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
++        memset(lctx.inp_mean->data, 0, n_tokens * cparams.n_seq_max * ggml_element_size(lctx.inp_mean));
+ 
+         std::vector<uint64_t> sum(n_tokens, 0);
+         for (int i = 0; i < n_tokens; ++i) {
+             const llama_seq_id seq_id = batch.seq_id[i][0];
+-
+-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
+-
+             sum[seq_id] += 1;
+         }
+ 
+-        std::vector<float> div(n_tokens, 0.0f);
+-        for (int i = 0; i < n_tokens; ++i) {
++        std::vector<float> div(cparams.n_seq_max, 0.0f);
++        for (uint32_t i = 0; i < cparams.n_seq_max; ++i) {
+             const uint64_t s = sum[i];
+             if (s > 0) {
+                 div[i] = 1.0f/float(s);
+@@ -13879,14 +13876,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
+         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
+ 
+         uint32_t * data = (uint32_t *) lctx.inp_cls->data;
+-        memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
++        memset(lctx.inp_cls->data, 0, cparams.n_seq_max * ggml_element_size(lctx.inp_cls));
+ 
+         for (int i = 0; i < n_tokens; ++i) {
+             const llama_seq_id seq_id = batch.seq_id[i][0];
+             const llama_pos    pos    = batch.pos[i];
+-
+-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
+-
+             if (pos == 0) {
+                 data[seq_id] = i;
+             }

+ 23 - 21
llm/payload.go

@@ -38,7 +38,7 @@ func Init() error {
 	}
 
 	var variants []string
-	for v := range availableServers() {
+	for v := range getAvailableServers() {
 		variants = append(variants, v)
 	}
 	slog.Info(fmt.Sprintf("Dynamic LLM libraries %v", variants))
@@ -50,7 +50,7 @@ func Init() error {
 // binary names may contain an optional variant separated by '_'
 // For example, "ollama_rocm_v6" and "ollama_rocm_v5" or "ollama_cpu" and "ollama_cpu_avx2"
 // Any library without a variant is the lowest common denominator
-func availableServers() map[string]string {
+func getAvailableServers() map[string]string {
 	payloadsDir, err := gpu.PayloadsDir()
 	if err != nil {
 		slog.Error("payload lookup error", "error", err)
@@ -80,7 +80,7 @@ func availableServers() map[string]string {
 // TODO - switch to metadata based mapping
 func serversForGpu(info gpu.GpuInfo) []string {
 	// glob workDir for files that start with ollama_
-	availableServers := availableServers()
+	availableServers := getAvailableServers()
 	requested := info.Library
 	if info.Variant != gpu.CPUCapabilityNone {
 		requested += "_" + info.Variant.String()
@@ -115,27 +115,29 @@ func serversForGpu(info gpu.GpuInfo) []string {
 		servers = append(servers, alt...)
 	}
 
-	// Load up the best CPU variant if not primary requested
-	if info.Library != "cpu" {
-		variant := gpu.GetCPUCapability()
-		// If no variant, then we fall back to default
-		// If we have a variant, try that if we find an exact match
-		// Attempting to run the wrong CPU instructions will panic the
-		// process
-		if variant != gpu.CPUCapabilityNone {
-			for cmp := range availableServers {
-				if cmp == "cpu_"+variant.String() {
-					servers = append(servers, cmp)
-					break
+	if !(runtime.GOOS == "darwin" && runtime.GOARCH == "arm64") {
+		// Load up the best CPU variant if not primary requested
+		if info.Library != "cpu" {
+			variant := gpu.GetCPUCapability()
+			// If no variant, then we fall back to default
+			// If we have a variant, try that if we find an exact match
+			// Attempting to run the wrong CPU instructions will panic the
+			// process
+			if variant != gpu.CPUCapabilityNone {
+				for cmp := range availableServers {
+					if cmp == "cpu_"+variant.String() {
+						servers = append(servers, cmp)
+						break
+					}
 				}
+			} else {
+				servers = append(servers, "cpu")
 			}
-		} else {
-			servers = append(servers, "cpu")
 		}
-	}
 
-	if len(servers) == 0 {
-		servers = []string{"cpu"}
+		if len(servers) == 0 {
+			servers = []string{"cpu"}
+		}
 	}
 
 	return servers
@@ -147,7 +149,7 @@ func serverForCpu() string {
 		return "metal"
 	}
 	variant := gpu.GetCPUCapability()
-	availableServers := availableServers()
+	availableServers := getAvailableServers()
 	if variant != gpu.CPUCapabilityNone {
 		for cmp := range availableServers {
 			if cmp == "cpu_"+variant.String() {
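
getAvailableServers returns a map of runner name to payload directory, and the selection above prefers an exact cpu_<variant> match before falling back to the plain cpu runner (a fallback now skipped entirely on darwin/arm64, where serverForCpu returns "metal"). A simplified sketch of that preference, with made-up payload paths:

// Minimal sketch of CPU-variant selection, assuming hypothetical payload paths.
package main

import "fmt"

func pickCPUServer(available map[string]string, variant string) string {
	if variant != "" {
		if _, ok := available["cpu_"+variant]; ok {
			return "cpu_" + variant // exact variant match, e.g. cpu_avx2
		}
	}
	return "cpu" // lowest common denominator
}

func main() {
	available := map[string]string{
		"cpu":      "/tmp/runners/cpu",
		"cpu_avx2": "/tmp/runners/cpu_avx2",
		"cuda_v11": "/tmp/runners/cuda_v11",
	}
	fmt.Println(pickCPUServer(available, "avx2"))   // cpu_avx2
	fmt.Println(pickCPUServer(available, "avx512")) // cpu
}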

+ 60 - 34
llm/server.go

@@ -60,7 +60,12 @@ type llmServer struct {
 	sem *semaphore.Weighted
 }
 
-func LoadModel(model string) (*GGML, error) {
+// LoadModel will load a model from disk. The model must be in the GGML format.
+//
+// It collects array values for arrays with a size less than or equal to
+// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
+// the maxArraySize is negative, all arrays are collected.
+func LoadModel(model string, maxArraySize int) (*GGML, error) {
 	if _, err := os.Stat(model); err != nil {
 		return nil, err
 	}
@@ -71,17 +76,29 @@ func LoadModel(model string) (*GGML, error) {
 	}
 	defer f.Close()
 
-	ggml, _, err := DecodeGGML(f)
+	ggml, _, err := DecodeGGML(f, maxArraySize)
 	return ggml, err
 }
 
 // NewLlamaServer will run a server for the given GPUs
 // The gpu list must be a single family.
-func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
+func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
 	var err error
 	var cpuRunner string
 	var estimate MemoryEstimate
-	var systemMemory uint64
+	var systemTotalMemory uint64
+	var systemFreeMemory uint64
+	var systemSwapFreeMemory uint64
+
+	systemMemInfo, err := gpu.GetCPUMem()
+	if err != nil {
+		slog.Error("failed to lookup system memory", "error", err)
+	} else {
+		systemTotalMemory = systemMemInfo.TotalMemory
+		systemFreeMemory = systemMemInfo.FreeMemory
+		systemSwapFreeMemory = systemMemInfo.FreeSwap
+		slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))
+	}
 
 	// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
 	if opts.NumGPU == 0 {
@@ -91,19 +108,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		cpuRunner = serverForCpu()
 		estimate = EstimateGPULayers(gpus, ggml, projectors, opts)
 	} else {
-		if gpus[0].Library == "metal" {
-			memInfo, err := gpu.GetCPUMem()
-			if err != nil {
-				slog.Error("failed to lookup system memory", "error", err)
-			} else {
-				systemMemory = memInfo.TotalMemory
-				slog.Debug("system memory", "total", format.HumanBytes2(systemMemory))
-			}
-		}
 		estimate = EstimateGPULayers(gpus, ggml, projectors, opts)
 
 		switch {
-		case gpus[0].Library == "metal" && estimate.VRAMSize > systemMemory:
+		case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory:
 			// disable partial offloading when model is greater than total system memory as this
 			// can lead to locking up the system
 			opts.NumGPU = 0
@@ -116,6 +124,16 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		}
 	}
 
+	// On linux, over-allocating CPU memory will almost always result in an error
+	if runtime.GOOS == "linux" {
+		systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
+		available := min(systemTotalMemory, systemFreeMemory+systemSwapFreeMemory)
+		if systemMemoryRequired > available {
+			slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", available, "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "swap", format.HumanBytes2(systemSwapFreeMemory))
+			return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available))
+		}
+	}
+
 	estimate.log()
 
 	// Loop through potential servers
@@ -125,7 +143,20 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
 	}
 
-	availableServers := availableServers()
+	availableServers := getAvailableServers()
+	if len(availableServers) == 0 {
+		if runtime.GOOS != "windows" {
+			slog.Warn("llama server binary disappeared, reinitializing payloads")
+			err = Init()
+			if err != nil {
+				slog.Warn("failed to reinitialize payloads", "error", err)
+				return nil, err
+			}
+			availableServers = getAvailableServers()
+		} else {
+			return nil, finalErr
+		}
+	}
 	var servers []string
 	if cpuRunner != "" {
 		servers = []string{cpuRunner}
@@ -202,7 +233,8 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		if g.Library == "metal" &&
 			uint64(opts.NumGPU) > 0 &&
 			uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
-			opts.UseMMap = api.TriStateFalse
+			opts.UseMMap = new(bool)
+			*opts.UseMMap = false
 		}
 	}
 
@@ -211,7 +243,12 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 	}
 
 	// Windows CUDA should not use mmap for best performance
-	if (runtime.GOOS == "windows" && gpus[0].Library == "cuda") || opts.UseMMap == api.TriStateFalse {
+	// Linux  with a model larger than free space, mmap leads to thrashing
+	// For CPU loads we want the memory to be allocated, not FS cache
+	if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && opts.UseMMap == nil) ||
+		(runtime.GOOS == "linux" && systemFreeMemory < estimate.TotalSize && opts.UseMMap == nil) ||
+		(gpus[0].Library == "cpu" && opts.UseMMap == nil) ||
+		(opts.UseMMap != nil && !*opts.UseMMap) {
 		params = append(params, "--no-mmap")
 	}
 
@@ -223,25 +260,12 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		params = append(params, "--numa")
 	}
 
-	numParallel := envconfig.NumParallel
-
-	// TODO (jmorganca): multimodal models don't support parallel yet
-	// see https://github.com/ollama/ollama/issues/4165
-	if len(projectors) > 0 {
-		numParallel = 1
-		slog.Warn("multimodal models don't support parallel requests yet")
-	}
-
 	params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
 
 	if estimate.TensorSplit != "" {
 		params = append(params, "--tensor-split", estimate.TensorSplit)
 	}
 
-	if estimate.TensorSplit != "" {
-		params = append(params, "--tensor-split", estimate.TensorSplit)
-	}
-
 	for i := range len(servers) {
 		dir := availableServers[servers[i]]
 		if dir == "" {
@@ -408,7 +432,7 @@ func projectorMemoryRequirements(filename string) uint64 {
 	}
 	defer file.Close()
 
-	ggml, _, err := DecodeGGML(file)
+	ggml, _, err := DecodeGGML(file, 0)
 	if err != nil {
 		return 0
 	}
@@ -558,6 +582,9 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 			if s.status != nil && s.status.LastErrMsg != "" {
 				msg = s.status.LastErrMsg
 			}
+			if strings.Contains(msg, "unknown model") {
+				return fmt.Errorf("this model is not supported by your version of Ollama. You may need to upgrade")
+			}
 			return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
 		default:
 		}
@@ -660,7 +687,7 @@ type CompletionRequest struct {
 	Prompt  string
 	Format  string
 	Images  []ImageData
-	Options api.Options
+	Options *api.Options
 }
 
 type CompletionResponse struct {
@@ -680,10 +707,9 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	}
 	defer s.sem.Release(1)
 
-	// only allow maximum 10 "context shifts" to avoid infinite generation
+	// put an upper limit on num_predict to avoid the model running on forever
 	if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
 		req.Options.NumPredict = 10 * s.options.NumCtx
-		slog.Debug("setting token limit to 10x num_ctx", "num_ctx", s.options.NumCtx, "num_predict", req.Options.NumPredict)
 	}
 
 	request := map[string]any{
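
The new Linux guard earlier in this file reduces to simple arithmetic: the portion of the load that must stay in system RAM (the total estimate minus the VRAM estimate) has to fit within min(total RAM, free RAM + free swap). A rough, self-contained sketch with made-up sizes:

// Sketch of the Linux system-memory check; the sizes below are arbitrary examples.
package main

import "fmt"

func fitsInSystemMemory(totalEstimate, vramEstimate, total, free, freeSwap uint64) bool {
	required := totalEstimate - vramEstimate
	available := min(total, free+freeSwap)
	return required <= available
}

func main() {
	const GiB = 1 << 30
	fmt.Println(fitsInSystemMemory(20*GiB, 16*GiB, 32*GiB, 6*GiB, 2*GiB)) // true: 4 GiB needed, 8 GiB available
	fmt.Println(fitsInSystemMemory(40*GiB, 16*GiB, 32*GiB, 6*GiB, 2*GiB)) // false: 24 GiB needed, 8 GiB available
}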

+ 1 - 0
llm/status.go

@@ -25,6 +25,7 @@ var errorPrefixes = []string{
 	"CUDA error",
 	"cudaMalloc failed",
 	"\"ERR\"",
+	"error loading model",
 }
 
 func (w *StatusWriter) Write(b []byte) (int, error) {
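
The added prefix lets lines such as "error loading model: ..." in the runner's output be caught and, presumably, recorded as the LastErrMsg consulted in server.go above (which now maps "unknown model" to a friendlier upgrade hint). A minimal sketch of the prefix-matching idea, not the actual StatusWriter:

// Sketch: scan runner output for known error prefixes and keep the last match.
package main

import (
	"fmt"
	"strings"
)

var errorPrefixes = []string{"CUDA error", "cudaMalloc failed", "error loading model"}

func lastError(output string) string {
	var last string
	for _, line := range strings.Split(output, "\n") {
		for _, prefix := range errorPrefixes {
			if strings.Contains(line, prefix) {
				last = line
			}
		}
	}
	return last
}

func main() {
	out := "loading tensors...\nerror loading model: unknown architecture\n"
	fmt.Println(lastError(out)) // error loading model: unknown architecture
}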

+ 375 - 13
openai/openai.go

@@ -12,6 +12,7 @@ import (
 
 	"github.com/gin-gonic/gin"
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/types/model"
 )
 
 type Error struct {
@@ -42,6 +43,12 @@ type ChunkChoice struct {
 	FinishReason *string `json:"finish_reason"`
 }
 
+type CompleteChunkChoice struct {
+	Text         string  `json:"text"`
+	Index        int     `json:"index"`
+	FinishReason *string `json:"finish_reason"`
+}
+
 type Usage struct {
 	PromptTokens     int `json:"prompt_tokens"`
 	CompletionTokens int `json:"completion_tokens"`
@@ -85,6 +92,51 @@ type ChatCompletionChunk struct {
 	Choices           []ChunkChoice `json:"choices"`
 }
 
+// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
+type CompletionRequest struct {
+	Model            string   `json:"model"`
+	Prompt           string   `json:"prompt"`
+	FrequencyPenalty float32  `json:"frequency_penalty"`
+	MaxTokens        *int     `json:"max_tokens"`
+	PresencePenalty  float32  `json:"presence_penalty"`
+	Seed             *int     `json:"seed"`
+	Stop             any      `json:"stop"`
+	Stream           bool     `json:"stream"`
+	Temperature      *float32 `json:"temperature"`
+	TopP             float32  `json:"top_p"`
+}
+
+type Completion struct {
+	Id                string                `json:"id"`
+	Object            string                `json:"object"`
+	Created           int64                 `json:"created"`
+	Model             string                `json:"model"`
+	SystemFingerprint string                `json:"system_fingerprint"`
+	Choices           []CompleteChunkChoice `json:"choices"`
+	Usage             Usage                 `json:"usage,omitempty"`
+}
+
+type CompletionChunk struct {
+	Id                string                `json:"id"`
+	Object            string                `json:"object"`
+	Created           int64                 `json:"created"`
+	Choices           []CompleteChunkChoice `json:"choices"`
+	Model             string                `json:"model"`
+	SystemFingerprint string                `json:"system_fingerprint"`
+}
+
+type Model struct {
+	Id      string `json:"id"`
+	Object  string `json:"object"`
+	Created int64  `json:"created"`
+	OwnedBy string `json:"owned_by"`
+}
+
+type ListCompletion struct {
+	Object string  `json:"object"`
+	Data   []Model `json:"data"`
+}
+
 func NewError(code int, message string) ErrorResponse {
 	var etype string
 	switch code {
@@ -145,7 +197,79 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
 	}
 }
 
-func fromRequest(r ChatCompletionRequest) api.ChatRequest {
+func toCompletion(id string, r api.GenerateResponse) Completion {
+	return Completion{
+		Id:                id,
+		Object:            "text_completion",
+		Created:           r.CreatedAt.Unix(),
+		Model:             r.Model,
+		SystemFingerprint: "fp_ollama",
+		Choices: []CompleteChunkChoice{{
+			Text:  r.Response,
+			Index: 0,
+			FinishReason: func(reason string) *string {
+				if len(reason) > 0 {
+					return &reason
+				}
+				return nil
+			}(r.DoneReason),
+		}},
+		Usage: Usage{
+			// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
+			PromptTokens:     r.PromptEvalCount,
+			CompletionTokens: r.EvalCount,
+			TotalTokens:      r.PromptEvalCount + r.EvalCount,
+		},
+	}
+}
+
+func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
+	return CompletionChunk{
+		Id:                id,
+		Object:            "text_completion",
+		Created:           time.Now().Unix(),
+		Model:             r.Model,
+		SystemFingerprint: "fp_ollama",
+		Choices: []CompleteChunkChoice{{
+			Text:  r.Response,
+			Index: 0,
+			FinishReason: func(reason string) *string {
+				if len(reason) > 0 {
+					return &reason
+				}
+				return nil
+			}(r.DoneReason),
+		}},
+	}
+}
+
+func toListCompletion(r api.ListResponse) ListCompletion {
+	var data []Model
+	for _, m := range r.Models {
+		data = append(data, Model{
+			Id:      m.Name,
+			Object:  "model",
+			Created: m.ModifiedAt.Unix(),
+			OwnedBy: model.ParseName(m.Name).Namespace,
+		})
+	}
+
+	return ListCompletion{
+		Object: "list",
+		Data:   data,
+	}
+}
+
+func toModel(r api.ShowResponse, m string) Model {
+	return Model{
+		Id:      m,
+		Object:  "model",
+		Created: r.ModifiedAt.Unix(),
+		OwnedBy: model.ParseName(m).Namespace,
+	}
+}
+
+func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
 	var messages []api.Message
 	for _, msg := range r.Messages {
 		messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
@@ -156,7 +280,7 @@ func fromRequest(r ChatCompletionRequest) api.ChatRequest {
 	switch stop := r.Stop.(type) {
 	case string:
 		options["stop"] = []string{stop}
-	case []interface{}:
+	case []any:
 		var stops []string
 		for _, s := range stop {
 			if str, ok := s.(string); ok {
@@ -208,13 +332,82 @@ func fromRequest(r ChatCompletionRequest) api.ChatRequest {
 	}
 }
 
-type writer struct {
+func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
+	options := make(map[string]any)
+
+	switch stop := r.Stop.(type) {
+	case string:
+		options["stop"] = []string{stop}
+	case []any:
+		var stops []string
+		for _, s := range stop {
+			if str, ok := s.(string); ok {
+				stops = append(stops, str)
+			} else {
+				return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
+			}
+		}
+		options["stop"] = stops
+	}
+
+	if r.MaxTokens != nil {
+		options["num_predict"] = *r.MaxTokens
+	}
+
+	if r.Temperature != nil {
+		options["temperature"] = *r.Temperature * 2.0
+	} else {
+		options["temperature"] = 1.0
+	}
+
+	if r.Seed != nil {
+		options["seed"] = *r.Seed
+	}
+
+	options["frequency_penalty"] = r.FrequencyPenalty * 2.0
+
+	options["presence_penalty"] = r.PresencePenalty * 2.0
+
+	if r.TopP != 0.0 {
+		options["top_p"] = r.TopP
+	} else {
+		options["top_p"] = 1.0
+	}
+
+	return api.GenerateRequest{
+		Model:   r.Model,
+		Prompt:  r.Prompt,
+		Options: options,
+		Stream:  &r.Stream,
+	}, nil
+}
+
+type BaseWriter struct {
+	gin.ResponseWriter
+}
+
+type ChatWriter struct {
 	stream bool
 	id     string
-	gin.ResponseWriter
+	BaseWriter
+}
+
+type CompleteWriter struct {
+	stream bool
+	id     string
+	BaseWriter
+}
+
+type ListWriter struct {
+	BaseWriter
+}
+
+type RetrieveWriter struct {
+	BaseWriter
+	model string
 }
 
-func (w *writer) writeError(code int, data []byte) (int, error) {
+func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
 	var serr api.StatusError
 	err := json.Unmarshal(data, &serr)
 	if err != nil {
@@ -230,7 +423,7 @@ func (w *writer) writeError(code int, data []byte) (int, error) {
 	return len(data), nil
 }
 
-func (w *writer) writeResponse(data []byte) (int, error) {
+func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 	var chatResponse api.ChatResponse
 	err := json.Unmarshal(data, &chatResponse)
 	if err != nil {
@@ -270,7 +463,7 @@ func (w *writer) writeResponse(data []byte) (int, error) {
 	return len(data), nil
 }
 
-func (w *writer) Write(data []byte) (int, error) {
+func (w *ChatWriter) Write(data []byte) (int, error) {
 	code := w.ResponseWriter.Status()
 	if code != http.StatusOK {
 		return w.writeError(code, data)
@@ -279,7 +472,176 @@ func (w *writer) Write(data []byte) (int, error) {
 	return w.writeResponse(data)
 }
 
-func Middleware() gin.HandlerFunc {
+func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
+	var generateResponse api.GenerateResponse
+	err := json.Unmarshal(data, &generateResponse)
+	if err != nil {
+		return 0, err
+	}
+
+	// completion chunk
+	if w.stream {
+		d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
+		if err != nil {
+			return 0, err
+		}
+
+		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
+		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+		if err != nil {
+			return 0, err
+		}
+
+		if generateResponse.Done {
+			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
+			if err != nil {
+				return 0, err
+			}
+		}
+
+		return len(data), nil
+	}
+
+	// completion
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *CompleteWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
+func (w *ListWriter) writeResponse(data []byte) (int, error) {
+	var listResponse api.ListResponse
+	err := json.Unmarshal(data, &listResponse)
+	if err != nil {
+		return 0, err
+	}
+
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *ListWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
+func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
+	var showResponse api.ShowResponse
+	err := json.Unmarshal(data, &showResponse)
+	if err != nil {
+		return 0, err
+	}
+
+	// retrieve completion
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *RetrieveWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
+func ListMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		w := &ListWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+		}
+
+		c.Writer = w
+
+		c.Next()
+	}
+}
+
+func RetrieveMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var b bytes.Buffer
+		if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
+		c.Request.Body = io.NopCloser(&b)
+
+		// response writer
+		w := &RetrieveWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+			model:      c.Param("model"),
+		}
+
+		c.Writer = w
+
+		c.Next()
+	}
+}
+
+func CompletionsMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var req CompletionRequest
+		err := c.ShouldBindJSON(&req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		var b bytes.Buffer
+		genReq, err := fromCompleteRequest(req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		if err := json.NewEncoder(&b).Encode(genReq); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
+		c.Request.Body = io.NopCloser(&b)
+
+		w := &CompleteWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+			stream:     req.Stream,
+			id:         fmt.Sprintf("cmpl-%d", rand.Intn(999)),
+		}
+
+		c.Writer = w
+
+		c.Next()
+	}
+}
+
+func ChatMiddleware() gin.HandlerFunc {
 	return func(c *gin.Context) {
 		var req ChatCompletionRequest
 		err := c.ShouldBindJSON(&req)
@@ -294,17 +656,17 @@ func Middleware() gin.HandlerFunc {
 		}
 
 		var b bytes.Buffer
-		if err := json.NewEncoder(&b).Encode(fromRequest(req)); err != nil {
+		if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
 			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
 			return
 		}
 
 		c.Request.Body = io.NopCloser(&b)
 
-		w := &writer{
-			ResponseWriter: c.Writer,
-			stream:         req.Stream,
-			id:             fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
+		w := &ChatWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+			stream:     req.Stream,
+			id:         fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
 		}
 
 		c.Writer = w
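
With CompletionsMiddleware, ListMiddleware and RetrieveMiddleware in place, the OpenAI-compatible surface now covers legacy completions and model listing as well as chat. Assuming the handlers end up mounted on OpenAI-style paths (the /v1/completions path, the llama3 model name and the default :11434 port below are all assumptions), a client call would look roughly like this; note that a request temperature of 0.8 is doubled to 1.6 by fromCompleteRequest above:

// Hypothetical client call against an OpenAI-style completions endpoint.
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	body, _ := json.Marshal(map[string]any{
		"model":       "llama3", // assumed model name
		"prompt":      "Why is the sky blue?",
		"max_tokens":  64,
		"temperature": 0.8, // mapped to Ollama temperature 1.6 by fromCompleteRequest
		"stream":      false,
	})

	resp, err := http.Post("http://localhost:11434/v1/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	defer resp.Body.Close()

	var completion struct {
		Choices []struct {
			Text string `json:"text"`
		} `json:"choices"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&completion); err != nil {
		fmt.Println("decode failed:", err)
		return
	}
	if len(completion.Choices) > 0 {
		fmt.Println(completion.Choices[0].Text)
	}
}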

+ 271 - 0
openai/openai_test.go

@@ -0,0 +1,271 @@
+package openai
+
+import (
+	"bytes"
+	"encoding/json"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/ollama/ollama/api"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestMiddlewareRequests(t *testing.T) {
+	type testCase struct {
+		Name     string
+		Method   string
+		Path     string
+		Handler  func() gin.HandlerFunc
+		Setup    func(t *testing.T, req *http.Request)
+		Expected func(t *testing.T, req *http.Request)
+	}
+
+	var capturedRequest *http.Request
+
+	captureRequestMiddleware := func() gin.HandlerFunc {
+		return func(c *gin.Context) {
+			bodyBytes, _ := io.ReadAll(c.Request.Body)
+			c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+			capturedRequest = c.Request
+			c.Next()
+		}
+	}
+
+	testCases := []testCase{
+		{
+			Name:    "chat handler",
+			Method:  http.MethodPost,
+			Path:    "/api/chat",
+			Handler: ChatMiddleware,
+			Setup: func(t *testing.T, req *http.Request) {
+				body := ChatCompletionRequest{
+					Model:    "test-model",
+					Messages: []Message{{Role: "user", Content: "Hello"}},
+				}
+
+				bodyBytes, _ := json.Marshal(body)
+
+				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+				req.Header.Set("Content-Type", "application/json")
+			},
+			Expected: func(t *testing.T, req *http.Request) {
+				var chatReq api.ChatRequest
+				if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
+					t.Fatal(err)
+				}
+
+				if chatReq.Messages[0].Role != "user" {
+					t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
+				}
+
+				if chatReq.Messages[0].Content != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
+				}
+			},
+		},
+		{
+			Name:    "completions handler",
+			Method:  http.MethodPost,
+			Path:    "/api/generate",
+			Handler: CompletionsMiddleware,
+			Setup: func(t *testing.T, req *http.Request) {
+				temp := float32(0.8)
+				body := CompletionRequest{
+					Model:       "test-model",
+					Prompt:      "Hello",
+					Temperature: &temp,
+					Stop:        []string{"\n", "stop"},
+				}
+
+				bodyBytes, _ := json.Marshal(body)
+
+				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+				req.Header.Set("Content-Type", "application/json")
+			},
+			Expected: func(t *testing.T, req *http.Request) {
+				var genReq api.GenerateRequest
+				if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
+					t.Fatal(err)
+				}
+
+				if genReq.Prompt != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
+				}
+
+				if genReq.Options["temperature"] != 1.6 {
+					t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
+				}
+
+				stopTokens, ok := genReq.Options["stop"].([]any)
+
+				if !ok {
+					t.Fatalf("expected stop tokens to be a list")
+				}
+
+				if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
+					t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
+				}
+			},
+		},
+	}
+
+	gin.SetMode(gin.TestMode)
+	router := gin.New()
+
+	endpoint := func(c *gin.Context) {
+		c.Status(http.StatusOK)
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.Name, func(t *testing.T) {
+			router = gin.New()
+			router.Use(captureRequestMiddleware())
+			router.Use(tc.Handler())
+			router.Handle(tc.Method, tc.Path, endpoint)
+			req, _ := http.NewRequest(tc.Method, tc.Path, nil)
+
+			if tc.Setup != nil {
+				tc.Setup(t, req)
+			}
+
+			resp := httptest.NewRecorder()
+			router.ServeHTTP(resp, req)
+
+			tc.Expected(t, capturedRequest)
+		})
+	}
+}
+
+func TestMiddlewareResponses(t *testing.T) {
+	type testCase struct {
+		Name     string
+		Method   string
+		Path     string
+		TestPath string
+		Handler  func() gin.HandlerFunc
+		Endpoint func(c *gin.Context)
+		Setup    func(t *testing.T, req *http.Request)
+		Expected func(t *testing.T, resp *httptest.ResponseRecorder)
+	}
+
+	testCases := []testCase{
+		{
+			Name:     "completions handler error forwarding",
+			Method:   http.MethodPost,
+			Path:     "/api/generate",
+			TestPath: "/api/generate",
+			Handler:  CompletionsMiddleware,
+			Endpoint: func(c *gin.Context) {
+				c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
+			},
+			Setup: func(t *testing.T, req *http.Request) {
+				body := CompletionRequest{
+					Model:  "test-model",
+					Prompt: "Hello",
+				}
+
+				bodyBytes, _ := json.Marshal(body)
+
+				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+				req.Header.Set("Content-Type", "application/json")
+			},
+			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusBadRequest {
+					t.Fatalf("expected 400, got %d", resp.Code)
+				}
+
+				if !strings.Contains(resp.Body.String(), `"invalid request"`) {
+					t.Fatalf("error was not forwarded")
+				}
+			},
+		},
+		{
+			Name:     "list handler",
+			Method:   http.MethodGet,
+			Path:     "/api/tags",
+			TestPath: "/api/tags",
+			Handler:  ListMiddleware,
+			Endpoint: func(c *gin.Context) {
+				c.JSON(http.StatusOK, api.ListResponse{
+					Models: []api.ListModelResponse{
+						{
+							Name: "Test Model",
+						},
+					},
+				})
+			},
+			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
+				assert.Equal(t, http.StatusOK, resp.Code)
+
+				var listResp ListCompletion
+				if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
+					t.Fatal(err)
+				}
+
+				if listResp.Object != "list" {
+					t.Fatalf("expected list, got %s", listResp.Object)
+				}
+
+				if len(listResp.Data) != 1 {
+					t.Fatalf("expected 1, got %d", len(listResp.Data))
+				}
+
+				if listResp.Data[0].Id != "Test Model" {
+					t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
+				}
+			},
+		},
+		{
+			Name:     "retrieve model",
+			Method:   http.MethodGet,
+			Path:     "/api/show/:model",
+			TestPath: "/api/show/test-model",
+			Handler:  RetrieveMiddleware,
+			Endpoint: func(c *gin.Context) {
+				c.JSON(http.StatusOK, api.ShowResponse{
+					ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
+				})
+			},
+			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
+				var retrieveResp Model
+				if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
+					t.Fatal(err)
+				}
+
+				if retrieveResp.Object != "model" {
+					t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
+				}
+
+				if retrieveResp.Id != "test-model" {
+					t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
+				}
+			},
+		},
+	}
+
+	gin.SetMode(gin.TestMode)
+	router := gin.New()
+
+	for _, tc := range testCases {
+		t.Run(tc.Name, func(t *testing.T) {
+			router = gin.New()
+			router.Use(tc.Handler())
+			router.Handle(tc.Method, tc.Path, tc.Endpoint)
+			req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
+
+			if tc.Setup != nil {
+				tc.Setup(t, req)
+			}
+
+			resp := httptest.NewRecorder()
+			router.ServeHTTP(resp, req)
+
+			tc.Expected(t, resp)
+		})
+	}
+}

+ 2 - 2
parser/parser.go

@@ -124,7 +124,7 @@ func ParseFile(r io.Reader) (*File, error) {
 			case stateComment, stateNil:
 				// pass
 			case stateValue:
-				s, ok := unquote(b.String())
+				s, ok := unquote(strings.TrimSpace(b.String()))
 				if !ok || isSpace(r) {
 					if _, err := b.WriteRune(r); err != nil {
 						return nil, err
@@ -158,7 +158,7 @@ func ParseFile(r io.Reader) (*File, error) {
 	case stateComment, stateNil:
 		// pass; nothing to flush
 	case stateValue:
-		s, ok := unquote(b.String())
+		s, ok := unquote(strings.TrimSpace(b.String()))
 		if !ok {
 			return nil, io.ErrUnexpectedEOF
 		}

+ 67 - 3
parser/parser_test.go

@@ -22,7 +22,13 @@ ADAPTER adapter1
 LICENSE MIT
 PARAMETER param1 value1
 PARAMETER param2 value2
-TEMPLATE template1
+TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
+
+{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
+
+{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
+
+{{ .Response }}<|eot_id|>"""    
 `
 
 	reader := strings.NewReader(input)
@@ -36,7 +42,40 @@ TEMPLATE template1
 		{Name: "license", Args: "MIT"},
 		{Name: "param1", Args: "value1"},
 		{Name: "param2", Args: "value2"},
-		{Name: "template", Args: "template1"},
+		{Name: "template", Args: "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>"},
+	}
+
+	assert.Equal(t, expectedCommands, modelfile.Commands)
+}
+
+func TestParseFileTrimSpace(t *testing.T) {
+	input := `
+FROM "     model 1"
+ADAPTER      adapter3
+LICENSE "MIT       "
+PARAMETER param1        value1
+PARAMETER param2    value2
+TEMPLATE """   {{ if .System }}<|start_header_id|>system<|end_header_id|>
+
+{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
+
+{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
+
+{{ .Response }}<|eot_id|>   """    
+`
+
+	reader := strings.NewReader(input)
+
+	modelfile, err := ParseFile(reader)
+	require.NoError(t, err)
+
+	expectedCommands := []Command{
+		{Name: "model", Args: "     model 1"},
+		{Name: "adapter", Args: "adapter3"},
+		{Name: "license", Args: "MIT       "},
+		{Name: "param1", Args: "value1"},
+		{Name: "param2", Args: "value2"},
+		{Name: "template", Args: "   {{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>   "},
 	}
 
 	assert.Equal(t, expectedCommands, modelfile.Commands)
@@ -48,6 +87,26 @@ func TestParseFileFrom(t *testing.T) {
 		expected []Command
 		err      error
 	}{
+		{
+			"FROM \"FOO  BAR  \"",
+			[]Command{{Name: "model", Args: "FOO  BAR  "}},
+			nil,
+		},
+		{
+			"FROM \"FOO BAR\"\nPARAMETER param1 value1",
+			[]Command{{Name: "model", Args: "FOO BAR"}, {Name: "param1", Args: "value1"}},
+			nil,
+		},
+		{
+			"FROM     FOOO BAR    ",
+			[]Command{{Name: "model", Args: "FOOO BAR"}},
+			nil,
+		},
+		{
+			"FROM /what/is/the path ",
+			[]Command{{Name: "model", Args: "/what/is/the path"}},
+			nil,
+		},
 		{
 			"FROM foo",
 			[]Command{{Name: "model", Args: "foo"}},
@@ -86,6 +145,11 @@ func TestParseFileFrom(t *testing.T) {
 			[]Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}},
 			nil,
 		},
+		{
+			"PARAMETER what the \nFROM lemons make lemonade ",
+			[]Command{{Name: "what", Args: "the"}, {Name: "model", Args: "lemons make lemonade"}},
+			nil,
+		},
 	}
 
 	for _, c := range cases {
@@ -399,7 +463,7 @@ func TestParseFileParameters(t *testing.T) {
 		"mirostat_eta 1.0":             {"mirostat_eta", "1.0"},
 		"penalize_newline true":        {"penalize_newline", "true"},
 		"stop ### User:":               {"stop", "### User:"},
-		"stop ### User: ":              {"stop", "### User: "},
+		"stop ### User: ":              {"stop", "### User:"},
 		"stop \"### User:\"":           {"stop", "### User:"},
 		"stop \"### User: \"":          {"stop", "### User: "},
 		"stop \"\"\"### User:\"\"\"":   {"stop", "### User:"},

+ 4 - 1
scripts/build_windows.ps1

@@ -107,9 +107,12 @@ function gatherDependencies() {
 
     # TODO - this varies based on host build system and MSVC version - drive from dumpbin output
     # currently works for Win11 + MSVC 2019 + Cuda V11
-    cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140.dll" "${script:DEPS_DIR}\ollama_runners\"
+    cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140*.dll" "${script:DEPS_DIR}\ollama_runners\"
     cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\ollama_runners\"
     cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\ollama_runners\"
+    foreach ($part in $("runtime", "stdio", "filesystem", "math", "convert", "heap", "string", "time", "locale", "environment")) {
+        cp "$env:VCToolsRedistDir\..\..\..\Tools\Llvm\x64\bin\api-ms-win-crt-${part}*.dll" "${script:DEPS_DIR}\ollama_runners\"
+    }
 
 
     cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\"

+ 1 - 1
scripts/install.sh

@@ -279,7 +279,7 @@ if ! check_gpu nvidia-smi || [ -z "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\
     case $OS_NAME in
         centos|rhel) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -d '.' -f 1) ;;
         rocky) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -c1) ;;
-        fedora) [ $OS_VERSION -lt '37' ] && install_cuda_driver_yum $OS_NAME $OS_VERSION || install_cuda_driver_yum $OS_NAME '37';;
+        fedora) [ $OS_VERSION -lt '39' ] && install_cuda_driver_yum $OS_NAME $OS_VERSION || install_cuda_driver_yum $OS_NAME '39';;
         amzn) install_cuda_driver_yum 'fedora' '37' ;;
         debian) install_cuda_driver_apt $OS_NAME $OS_VERSION ;;
         ubuntu) install_cuda_driver_apt $OS_NAME $(echo $OS_VERSION | sed 's/\.//') ;;

+ 11 - 0
scripts/rh_linux_deps.sh

@@ -6,10 +6,21 @@ set -ex
 MACHINE=$(uname -m)
 
 if grep -i "centos" /etc/system-release >/dev/null; then
+    # As of 7/1/2024 mirrorlist.centos.org has been taken offline, so adjust accordingly
+    sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo
+    sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo
+    sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo
+
     # Centos 7 derivatives have too old of a git version to run our generate script
     # uninstall and ignore failures
     yum remove -y git
     yum -y install epel-release centos-release-scl
+
+    # The release packages reinstate the mirrors, undo that again
+    sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo
+    sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo
+    sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo
+
     yum -y install dnf
     if [ "${MACHINE}" = "x86_64" ]; then
         yum -y install https://repo.ius.io/ius-release-el7.rpm

+ 67 - 32
server/images.go

@@ -28,11 +28,18 @@ import (
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/parser"
+	"github.com/ollama/ollama/template"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
 )
 
+var errCapabilityCompletion = errors.New("completion")
+
+type Capability string
+
+const CapabilityCompletion = Capability("completion")
+
 type registryOptions struct {
 	Insecure bool
 	Username string
@@ -48,16 +55,50 @@ type Model struct {
 	ParentModel    string
 	AdapterPaths   []string
 	ProjectorPaths []string
-	Template       string
 	System         string
 	License        []string
 	Digest         string
 	Options        map[string]interface{}
 	Messages       []Message
+
+	Template *template.Template
 }
 
-func (m *Model) IsEmbedding() bool {
-	return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert")
+// CheckCapabilities checks if the model has the specified capabilities returning an error describing
+// any missing or unknown capabilities
+func (m *Model) CheckCapabilities(caps ...Capability) error {
+	var errs []error
+	for _, cap := range caps {
+		switch cap {
+		case CapabilityCompletion:
+			f, err := os.Open(m.ModelPath)
+			if err != nil {
+				slog.Error("couldn't open model file", "error", err)
+				continue
+			}
+			defer f.Close()
+
+			// TODO(mxyng): decode the GGML into model to avoid doing this multiple times
+			ggml, _, err := llm.DecodeGGML(f, 0)
+			if err != nil {
+				slog.Error("couldn't decode ggml", "error", err)
+				continue
+			}
+
+			if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
+				errs = append(errs, errCapabilityCompletion)
+			}
+		default:
+			slog.Error("unknown capability", "capability", cap)
+			return fmt.Errorf("unknown capability: %s", cap)
+		}
+	}
+
+	if err := errors.Join(errs...); err != nil {
+		return fmt.Errorf("missing capabilities: %w", errors.Join(errs...))
+	}
+
+	return nil
 }
 
 func (m *Model) String() string {
@@ -82,10 +123,10 @@ func (m *Model) String() string {
 		})
 	}
 
-	if m.Template != "" {
+	if m.Template != nil {
 		modelfile.Commands = append(modelfile.Commands, parser.Command{
 			Name: "template",
-			Args: m.Template,
+			Args: m.Template.String(),
 		})
 	}
 
@@ -135,13 +176,6 @@ type Message struct {
 	Content string `json:"content"`
 }
 
-type ManifestV2 struct {
-	SchemaVersion int      `json:"schemaVersion"`
-	MediaType     string   `json:"mediaType"`
-	Config        *Layer   `json:"config"`
-	Layers        []*Layer `json:"layers"`
-}
-
 type ConfigV2 struct {
 	ModelFormat   string   `json:"model_format"`
 	ModelFamily   string   `json:"model_family"`
@@ -160,7 +194,7 @@ type RootFS struct {
 	DiffIDs []string `json:"diff_ids"`
 }
 
-func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
+func GetManifest(mp ModelPath) (*Manifest, string, error) {
 	fp, err := mp.GetManifestPath()
 	if err != nil {
 		return nil, "", err
@@ -170,7 +204,7 @@ func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
 		return nil, "", err
 	}
 
-	var manifest *ManifestV2
+	var manifest *Manifest
 
 	bts, err := os.ReadFile(fp)
 	if err != nil {
@@ -198,8 +232,7 @@ func GetModel(name string) (*Model, error) {
 		Name:      mp.GetFullTagname(),
 		ShortName: mp.GetShortTagname(),
 		Digest:    digest,
-		Template:  "{{ .Prompt }}",
-		License:   []string{},
+		Template:  template.DefaultTemplate,
 	}
 
 	filename, err := GetBlobsPath(manifest.Config.Digest)
@@ -235,27 +268,24 @@ func GetModel(name string) (*Model, error) {
 			model.AdapterPaths = append(model.AdapterPaths, filename)
 		case "application/vnd.ollama.image.projector":
 			model.ProjectorPaths = append(model.ProjectorPaths, filename)
-		case "application/vnd.ollama.image.template":
+		case "application/vnd.ollama.image.prompt",
+			"application/vnd.ollama.image.template":
 			bts, err := os.ReadFile(filename)
 			if err != nil {
 				return nil, err
 			}
 
-			model.Template = string(bts)
-		case "application/vnd.ollama.image.system":
-			bts, err := os.ReadFile(filename)
+			model.Template, err = template.Parse(string(bts))
 			if err != nil {
 				return nil, err
 			}
-
-			model.System = string(bts)
-		case "application/vnd.ollama.image.prompt":
+		case "application/vnd.ollama.image.system":
 			bts, err := os.ReadFile(filename)
 			if err != nil {
 				return nil, err
 			}
 
-			model.Template = string(bts)
+			model.System = string(bts)
 		case "application/vnd.ollama.image.params":
 			params, err := os.Open(filename)
 			if err != nil {
@@ -414,17 +444,22 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 							return err
 						}
 
-						layers, err := parseFromFile(ctx, temp, "", fn)
+						layer, err := NewLayer(temp, baseLayer.MediaType)
 						if err != nil {
 							return err
 						}
 
-						if len(layers) != 1 {
-							return errors.New("quantization failed")
+						if _, err := temp.Seek(0, io.SeekStart); err != nil {
+							return err
+						}
+
+						ggml, _, err := llm.DecodeGGML(temp, 0)
+						if err != nil {
+							return err
 						}
 
-						baseLayer.Layer = layers[0].Layer
-						baseLayer.GGML = layers[0].GGML
+						baseLayer.Layer = layer
+						baseLayer.GGML = ggml
 					}
 				}
 
@@ -817,7 +852,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
 	mp := ParseModelPath(name)
 
-	var manifest *ManifestV2
+	var manifest *Manifest
 	var err error
 	var noprune string
 
@@ -924,7 +959,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 	return nil
 }
 
-func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*ManifestV2, error) {
+func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
 	requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
 
 	headers := make(http.Header)
@@ -935,7 +970,7 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio
 	}
 	defer resp.Body.Close()
 
-	var m *ManifestV2
+	var m *Manifest
 	if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
 		return nil, err
 	}
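
CheckCapabilities above replaces the old IsEmbedding check: a model whose GGML metadata carries a pooling_type key is treated as embedding-only and therefore lacks the completion capability. A stripped-down sketch of the same collect-and-join error pattern, with a boolean standing in for the GGML inspection:

// Sketch of the capability-check pattern; isEmbeddingOnly replaces the GGML lookup.
package main

import (
	"errors"
	"fmt"
)

type Capability string

var errCompletion = errors.New("completion")

func checkCapabilities(isEmbeddingOnly bool, caps ...Capability) error {
	var errs []error
	for _, c := range caps {
		switch c {
		case "completion":
			if isEmbeddingOnly {
				errs = append(errs, errCompletion)
			}
		default:
			return fmt.Errorf("unknown capability: %s", c)
		}
	}

	if err := errors.Join(errs...); err != nil {
		return fmt.Errorf("missing capabilities: %w", err)
	}

	return nil
}

func main() {
	fmt.Println(checkCapabilities(true, "completion"))  // missing capabilities: completion
	fmt.Println(checkCapabilities(false, "completion")) // <nil>
}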

+ 11 - 9
server/manifest.go

@@ -14,7 +14,10 @@ import (
 )
 
 type Manifest struct {
-	ManifestV2
+	SchemaVersion int      `json:"schemaVersion"`
+	MediaType     string   `json:"mediaType"`
+	Config        *Layer   `json:"config"`
+	Layers        []*Layer `json:"layers"`
 
 	filepath string
 	fi       os.FileInfo
@@ -66,7 +69,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
 
 	p := filepath.Join(manifests, n.Filepath())
 
-	var m ManifestV2
+	var m Manifest
 	f, err := os.Open(p)
 	if err != nil {
 		return nil, err
@@ -83,12 +86,11 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
 		return nil, err
 	}
 
-	return &Manifest{
-		ManifestV2: m,
-		filepath:   p,
-		fi:         fi,
-		digest:     fmt.Sprintf("%x", sha256sum.Sum(nil)),
-	}, nil
+	m.filepath = p
+	m.fi = fi
+	m.digest = fmt.Sprintf("%x", sha256sum.Sum(nil))
+
+	return &m, nil
 }
 
 func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
@@ -108,7 +110,7 @@ func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
 	}
 	defer f.Close()
 
-	m := ManifestV2{
+	m := Manifest{
 		SchemaVersion: 2,
 		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
 		Config:        config,
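
Folding ManifestV2 into Manifest means the manifest JSON on disk decodes straight into a single struct. A small illustration with a made-up manifest; the local Layer type here carries only the fields needed for the demo:

// Sketch: decode a manifest document into the flattened struct shape.
package main

import (
	"encoding/json"
	"fmt"
)

type Layer struct {
	MediaType string `json:"mediaType"`
	Digest    string `json:"digest"`
	Size      int64  `json:"size"`
}

type Manifest struct {
	SchemaVersion int      `json:"schemaVersion"`
	MediaType     string   `json:"mediaType"`
	Config        *Layer   `json:"config"`
	Layers        []*Layer `json:"layers"`
}

func main() {
	raw := `{
	  "schemaVersion": 2,
	  "mediaType": "application/vnd.docker.distribution.manifest.v2+json",
	  "config": {"mediaType": "application/vnd.docker.container.image.v1+json", "digest": "sha256:aaaa", "size": 12},
	  "layers": [{"mediaType": "application/vnd.ollama.image.model", "digest": "sha256:bbbb", "size": 34}]
	}`

	var m Manifest
	if err := json.Unmarshal([]byte(raw), &m); err != nil {
		panic(err)
	}

	fmt.Println(m.SchemaVersion, len(m.Layers), m.Layers[0].MediaType) // 2 1 application/vnd.ollama.image.model
}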

+ 1 - 1
server/manifest_test.go

@@ -25,7 +25,7 @@ func createManifest(t *testing.T, path, name string) {
 	}
 	defer f.Close()
 
-	if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil {
+	if err := json.NewEncoder(f).Encode(Manifest{}); err != nil {
 		t.Fatal(err)
 	}
 }

+ 41 - 24
server/model.go

@@ -15,7 +15,7 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/convert"
 	"github.com/ollama/ollama/llm"
-	"github.com/ollama/ollama/templates"
+	"github.com/ollama/ollama/template"
 	"github.com/ollama/ollama/types/model"
 )
 
@@ -63,7 +63,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
 			}
 			defer blob.Close()
 
-			ggml, _, err := llm.DecodeGGML(blob)
+			ggml, _, err := llm.DecodeGGML(blob, 0)
 			if err != nil {
 				return nil, err
 			}
@@ -77,62 +77,79 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
 	return layers, nil
 }
 
-func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
+func extractFromZipFile(p string, file *os.File, fn func(api.ProgressResponse)) error {
 	stat, err := file.Stat()
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	r, err := zip.NewReader(file, stat.Size())
 	if err != nil {
-		return nil, err
-	}
-
-	tempdir, err := os.MkdirTemp(filepath.Dir(file.Name()), "")
-	if err != nil {
-		return nil, err
+		return err
 	}
-	defer os.RemoveAll(tempdir)
 
 	fn(api.ProgressResponse{Status: "unpacking model metadata"})
 	for _, f := range r.File {
+		if !filepath.IsLocal(f.Name) {
+			return fmt.Errorf("%w: %s", zip.ErrInsecurePath, f.Name)
+		}
+
+		n := filepath.Join(p, f.Name)
+		if err := os.MkdirAll(filepath.Dir(n), 0o750); err != nil {
+			return err
+		}
+
 		// TODO(mxyng): this should not write out all files to disk
-		outfile, err := os.Create(filepath.Join(tempdir, f.Name))
+		outfile, err := os.Create(n)
 		if err != nil {
-			return nil, err
+			return err
 		}
 		defer outfile.Close()
 
 		infile, err := f.Open()
 		if err != nil {
-			return nil, err
+			return err
 		}
 		defer infile.Close()
 
 		if _, err = io.Copy(outfile, infile); err != nil {
-			return nil, err
+			return err
 		}
 
 		if err := outfile.Close(); err != nil {
-			return nil, err
+			return err
 		}
 
 		if err := infile.Close(); err != nil {
-			return nil, err
+			return err
 		}
 	}
 
-	mf, err := convert.GetModelFormat(tempdir)
+	return nil
+}
+
+func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
+	tempDir, err := os.MkdirTemp(filepath.Dir(file.Name()), "")
+	if err != nil {
+		return nil, err
+	}
+	defer os.RemoveAll(tempDir)
+
+	if err := extractFromZipFile(tempDir, file, fn); err != nil {
+		return nil, err
+	}
+
+	mf, err := convert.GetModelFormat(tempDir)
 	if err != nil {
 		return nil, err
 	}
 
-	params, err := mf.GetParams(tempdir)
+	params, err := mf.GetParams(tempDir)
 	if err != nil {
 		return nil, err
 	}
 
-	mArch, err := mf.GetModelArch("", tempdir, params)
+	mArch, err := mf.GetModelArch("", tempDir, params)
 	if err != nil {
 		return nil, err
 	}
@@ -150,7 +167,7 @@ func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(a
 
 	// TODO(mxyng): this should write directly into a layer
 	// e.g. NewLayer(arch.Reader(), "application/vnd.ollama.image.model")
-	temp, err := os.CreateTemp(tempdir, "fp16")
+	temp, err := os.CreateTemp(tempDir, "fp16")
 	if err != nil {
 		return nil, err
 	}
@@ -176,7 +193,7 @@ func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(a
 	}
 	defer bin.Close()
 
-	ggml, _, err := llm.DecodeGGML(bin)
+	ggml, _, err := llm.DecodeGGML(bin, 0)
 	if err != nil {
 		return nil, err
 	}
@@ -210,7 +227,7 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
 
 	var offset int64
 	for offset < stat.Size() {
-		ggml, n, err := llm.DecodeGGML(file)
+		ggml, n, err := llm.DecodeGGML(file, 0)
 		if errors.Is(err, io.EOF) {
 			break
 		} else if err != nil {
@@ -239,7 +256,7 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
 func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
 	for _, layer := range layers {
 		if s := layer.GGML.KV().ChatTemplate(); s != "" {
-			if t, err := templates.NamedTemplate(s); err != nil {
+			if t, err := template.Named(s); err != nil {
 				slog.Debug("template detection", "error", err)
 			} else {
 				tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
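
The extractFromZipFile change above guards against zip-slip by rejecting any archive entry whose name is not local to the extraction directory, via filepath.IsLocal (Go 1.20+); the table-driven test that follows exercises exactly these cases. A quick illustration of how IsLocal classifies paths (Unix-style separators):

// Demonstrates filepath.IsLocal on a few representative entry names.
package main

import (
	"fmt"
	"path/filepath"
)

func main() {
	for _, name := range []string{
		"good",
		"path/../to/good",        // cleans to "to/good": still local
		"../../../../etc/passwd", // escapes the root: rejected
		"/absolute/path",         // absolute: rejected
	} {
		fmt.Printf("%-26s local=%v\n", name, filepath.IsLocal(name))
	}
}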

+ 112 - 0
server/model_test.go

@@ -0,0 +1,112 @@
+package server
+
+import (
+	"archive/zip"
+	"bytes"
+	"errors"
+	"io"
+	"os"
+	"path/filepath"
+	"slices"
+	"strings"
+	"testing"
+
+	"github.com/ollama/ollama/api"
+)
+
+func createZipFile(t *testing.T, name string) *os.File {
+	t.Helper()
+
+	f, err := os.CreateTemp(t.TempDir(), "")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	zf := zip.NewWriter(f)
+	defer zf.Close()
+
+	zh, err := zf.CreateHeader(&zip.FileHeader{Name: name})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if _, err := io.Copy(zh, bytes.NewReader([]byte(""))); err != nil {
+		t.Fatal(err)
+	}
+
+	return f
+}
+
+func TestExtractFromZipFile(t *testing.T) {
+	cases := []struct {
+		name   string
+		expect []string
+		err    error
+	}{
+		{
+			name:   "good",
+			expect: []string{"good"},
+		},
+		{
+			name:   strings.Join([]string{"path", "..", "to", "good"}, string(os.PathSeparator)),
+			expect: []string{filepath.Join("to", "good")},
+		},
+		{
+			name:   strings.Join([]string{"path", "..", "to", "..", "good"}, string(os.PathSeparator)),
+			expect: []string{"good"},
+		},
+		{
+			name:   strings.Join([]string{"path", "to", "..", "..", "good"}, string(os.PathSeparator)),
+			expect: []string{"good"},
+		},
+		{
+			name: strings.Join([]string{"..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "bad"}, string(os.PathSeparator)),
+			err:  zip.ErrInsecurePath,
+		},
+		{
+			name: strings.Join([]string{"path", "..", "..", "to", "bad"}, string(os.PathSeparator)),
+			err:  zip.ErrInsecurePath,
+		},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			f := createZipFile(t, tt.name)
+			defer f.Close()
+
+			tempDir := t.TempDir()
+			if err := extractFromZipFile(tempDir, f, func(api.ProgressResponse) {}); !errors.Is(err, tt.err) {
+				t.Fatal(err)
+			}
+
+			var matches []string
+			if err := filepath.Walk(tempDir, func(p string, fi os.FileInfo, err error) error {
+				if err != nil {
+					return err
+				}
+
+				if !fi.IsDir() {
+					matches = append(matches, p)
+				}
+
+				return nil
+			}); err != nil {
+				t.Fatal(err)
+			}
+
+			var actual []string
+			for _, match := range matches {
+				rel, err := filepath.Rel(tempDir, match)
+				if err != nil {
+					t.Error(err)
+				}
+
+				actual = append(actual, rel)
+			}
+
+			if !slices.Equal(actual, tt.expect) {
+				t.Fatalf("expected %d files, got %d", len(tt.expect), len(matches))
+			}
+		})
+	}
+}

+ 3 - 18
server/modelpath.go

@@ -103,18 +103,9 @@ func (mp ModelPath) GetShortTagname() string {
 	return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
 }
 
-// modelsDir returns the value of the OLLAMA_MODELS environment variable or the user's home directory if OLLAMA_MODELS is not set.
-// The models directory is where Ollama stores its model files and manifests.
-func modelsDir() (string, error) {
-	return envconfig.ModelsDir, nil
-}
-
 // GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
 func (mp ModelPath) GetManifestPath() (string, error) {
-	dir, err := modelsDir()
-	if err != nil {
-		return "", err
-	}
+	dir := envconfig.ModelsDir
 
 	return filepath.Join(dir, "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil
 }
@@ -127,10 +118,7 @@ func (mp ModelPath) BaseURL() *url.URL {
 }
 
 func GetManifestPath() (string, error) {
-	dir, err := modelsDir()
-	if err != nil {
-		return "", err
-	}
+	dir := envconfig.ModelsDir
 
 	path := filepath.Join(dir, "manifests")
 	if err := os.MkdirAll(path, 0o755); err != nil {
@@ -141,10 +129,7 @@ func GetManifestPath() (string, error) {
 }
 
 func GetBlobsPath(digest string) (string, error) {
-	dir, err := modelsDir()
-	if err != nil {
-		return "", err
-	}
+	dir := envconfig.ModelsDir
 
 	// only accept actual sha256 digests
 	pattern := "^sha256[:-][0-9a-fA-F]{64}$"

+ 54 - 192
server/prompt.go

@@ -1,221 +1,83 @@
 package server
 
 import (
-	"fmt"
+	"bytes"
+	"context"
 	"log/slog"
-	"strings"
-	"text/template"
-	"text/template/parse"
+	"slices"
 
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/template"
 )
 
-// isResponseNode checks if the node contains .Response
-func isResponseNode(node *parse.ActionNode) bool {
-	for _, cmd := range node.Pipe.Cmds {
-		for _, arg := range cmd.Args {
-			if fieldNode, ok := arg.(*parse.FieldNode); ok && len(fieldNode.Ident) > 0 {
-				if fieldNode.Ident[0] == "Response" {
-					return true
-				}
-			}
+type tokenizeFunc func(context.Context, string) ([]int, error)
+
+// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
+// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
+// latest message and 2) system messages
+func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
+	// pull out any system messages which should always be included in the prompt
+	var system []api.Message
+	msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
+		if m.Role == "system" {
+			system = append(system, m)
+			return true
 		}
-	}
-	return false
-}
-
-// formatTemplateForResponse formats the template AST to:
-// 1. remove all nodes after the first .Response (if generate=true)
-// 2. add a .Response node to the end if it doesn't exist
-// TODO(jmorganca): this should recursively cut the template before the first .Response
-func formatTemplateForResponse(tmpl *template.Template, generate bool) {
-	var found bool
-	for i, node := range tmpl.Tree.Root.Nodes {
-		if actionNode, ok := node.(*parse.ActionNode); ok {
-			if isResponseNode(actionNode) {
-				found = true
-				if generate {
-					tmpl.Tree.Root.Nodes = tmpl.Tree.Root.Nodes[:i+1]
-					break
-				}
-			}
-		}
-	}
 
-	if !found {
-		// add the response node if it doesn't exist
-		responseFieldNode := &parse.FieldNode{NodeType: parse.NodeField, Ident: []string{"Response"}}
-		responsePipeNode := &parse.PipeNode{NodeType: parse.NodePipe, Cmds: []*parse.CommandNode{{NodeType: parse.NodeCommand, Args: []parse.Node{responseFieldNode}}}}
-		responseActionNode := &parse.ActionNode{NodeType: parse.NodeAction, Pipe: responsePipeNode}
-		tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, responseActionNode)
-	}
-}
-
-// Prompt renders a prompt from a template. If generate is set to true,
-// the response and parts of the template following it are not rendered
-func Prompt(tmpl, system, prompt, response string, generate bool) (string, error) {
-	parsed, err := template.New("").Option("missingkey=zero").Parse(tmpl)
-	if err != nil {
-		return "", err
-	}
-
-	formatTemplateForResponse(parsed, generate)
-
-	vars := map[string]any{
-		"System":   system,
-		"Prompt":   prompt,
-		"Response": response,
-	}
+		return false
+	})
 
-	var sb strings.Builder
-	if err := parsed.Execute(&sb, vars); err != nil {
-		return "", err
+	if len(system) == 0 && m.System != "" {
+		// add model system prompt since it wasn't provided
+		system = append(system, api.Message{Role: "system", Content: m.System})
 	}
 
-	return sb.String(), nil
-}
-
-func countTokens(tmpl string, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
-	rendered, err := Prompt(tmpl, system, prompt, response, false)
-	if err != nil {
-		return 0, err
-	}
-
-	tokens, err := encode(rendered)
-	if err != nil {
-		slog.Error("failed to encode prompt", "err", err)
-		return 0, err
-	}
-
-	return len(tokens), err
-}
-
-// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
-func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
-	type prompt struct {
-		System   string
-		Prompt   string
-		Response string
-
-		images []int
-		tokens int
-	}
-
-	var p prompt
-
-	// iterate through messages to build up {system,user,response} prompts
-	var imgId int
-	var prompts []prompt
-	for _, msg := range messages {
-		switch strings.ToLower(msg.Role) {
-		case "system":
-			if p.System != "" || p.Prompt != "" || p.Response != "" {
-				prompts = append(prompts, p)
-				p = prompt{}
-			}
-
-			p.System = msg.Content
-		case "user":
-			if p.Prompt != "" || p.Response != "" {
-				prompts = append(prompts, p)
-				p = prompt{}
-			}
-
-			var sb strings.Builder
-			for range msg.Images {
-				fmt.Fprintf(&sb, "[img-%d] ", imgId)
-				p.images = append(p.images, imgId)
-				imgId += 1
-			}
-
-			sb.WriteString(msg.Content)
-			p.Prompt = sb.String()
-		case "assistant":
-			if p.Response != "" {
-				prompts = append(prompts, p)
-				p = prompt{}
-			}
-
-			p.Response = msg.Content
-		default:
-			return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
+	// always include the last message
+	n := len(msgs) - 1
+	// in reverse, find all messages that fit into context window
+	for i := n - 1; i >= 0; i-- {
+		var b bytes.Buffer
+		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
+			return "", nil, err
 		}
-	}
-
-	// add final prompt
-	if p.System != "" || p.Prompt != "" || p.Response != "" {
-		prompts = append(prompts, p)
-	}
 
-	// calculate token lengths for each prompt, estimating 768 tokens per images
-	for i, p := range prompts {
-		tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
+		s, err := tokenize(ctx, b.String())
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 
-		prompts[i].tokens = tokens + len(prompts[i].images)*768
-	}
-
-	// truncate images and prompts starting from the beginning of the list
-	// until either one prompt remains or the total tokens fits the context window
-	// TODO (jmorganca): this doesn't account for the context window room required for the response
-	for {
-		var required int
-		for _, p := range prompts {
-			required += p.tokens
+		c := len(s)
+		if m.ProjectorPaths != nil {
+			for _, m := range msgs[i:] {
+				// each image is represented as a 768-token embedding
+				// TODO: get embedding length from projector metadata
+				c += 768 * len(m.Images)
+			}
 		}
 
-		required += 1 // for bos token
-
-		if required <= window {
-			slog.Debug("prompt now fits in context window", "required", required, "window", window)
+		if c > opts.NumCtx {
+			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
 			break
+		} else {
+			n = i
 		}
+	}
 
-		prompt := &prompts[0]
-
-		if len(prompt.images) > 1 {
-			img := prompt.images[0]
-			slog.Debug("prompt longer than context window, removing image", "id", img, "required", required, "window", window)
-			prompt.images = prompt.images[1:]
-			prompt.Prompt = strings.Replace(prompt.Prompt, fmt.Sprintf(" [img-%d]", img), "", 1)
-			prompt.tokens -= 768
-			continue
-		}
-
-		if len(prompts) > 1 {
-			slog.Debug("required tokens longer than context window, removing first prompt", "prompt", prompts[0].tokens, "required", required, "window", window)
-			system := prompt.System
-			prompts = prompts[1:]
-
-			if system != "" && prompts[0].System == "" {
-				prompts[0].System = system
-
-				tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
-				if err != nil {
-					return "", err
-				}
-
-				prompts[0].tokens = tokens + len(prompts[0].images)*768
-			}
-
-			continue
-		}
-
-		// stop truncating if there's only one prompt left
-		break
+	// truncate any messages that do not fit into the context window
+	var b bytes.Buffer
+	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
+		return "", nil, err
 	}
 
-	var sb strings.Builder
-	for i, p := range prompts {
-		// last prompt should leave the response unrendered (for completion)
-		rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1)
-		if err != nil {
-			return "", err
+	for _, m := range msgs[n:] {
+		for _, i := range m.Images {
+			images = append(images, llm.ImageData{
+				ID:   len(images),
+				Data: i,
+			})
 		}
-		sb.WriteString(rendered)
 	}
 
-	return sb.String(), nil
+	return b.String(), images, nil
 }
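
The new chatPrompt above scans the history in reverse and keeps the longest suffix of messages (plus any system messages) whose rendered form still fits in opts.NumCtx, always retaining the latest message. A minimal, self-contained sketch of that reverse-scan idea follows; it uses plain strings and a toy one-token-per-word tokenizer instead of the package's Model, template, and api.Message types, so all names below are illustrative:

```go
package main

import (
	"fmt"
	"strings"
)

// truncate keeps the longest suffix of msgs whose combined "token" count fits
// within limit, always retaining the last message (msgs is assumed non-empty).
// It mirrors the reverse scan in chatPrompt, but with plain strings and a toy
// one-token-per-word tokenizer.
func truncate(msgs []string, limit int) []string {
	count := func(ss []string) int {
		var n int
		for _, s := range ss {
			n += len(strings.Fields(s))
		}
		return n
	}

	n := len(msgs) - 1 // always include the last message
	for i := n - 1; i >= 0; i-- {
		if count(msgs[i:]) > limit {
			break // older messages no longer fit; stop growing the window
		}
		n = i
	}
	return msgs[n:]
}

func main() {
	msgs := []string{
		"You're a test, Harry!",
		"I-I'm a what?",
		"A test. And a thumping good one at that, I'd wager.",
	}
	fmt.Println(truncate(msgs, 64)) // everything fits
	fmt.Println(truncate(msgs, 1))  // only the latest message survives
}
```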

+ 159 - 159
server/prompt_test.go

@@ -1,204 +1,204 @@
 package server
 
 import (
+	"bytes"
+	"context"
 	"strings"
 	"testing"
 
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/template"
 )
 
-func TestPrompt(t *testing.T) {
-	tests := []struct {
-		name     string
-		template string
-		system   string
-		prompt   string
-		response string
-		generate bool
-		want     string
-	}{
-		{
-			name:     "simple prompt",
-			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
-			system:   "You are a Wizard.",
-			prompt:   "What are the potion ingredients?",
-			want:     "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
-		},
-		{
-			name:     "implicit response",
-			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
-			system:   "You are a Wizard.",
-			prompt:   "What are the potion ingredients?",
-			response: "I don't know.",
-			want:     "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.",
-		},
-		{
-			name:     "response",
-			template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
-			system:   "You are a Wizard.",
-			prompt:   "What are the potion ingredients?",
-			response: "I don't know.",
-			want:     "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
-		},
-		{
-			name:     "cut",
-			template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
-			system:   "You are a Wizard.",
-			prompt:   "What are the potion ingredients?",
-			response: "I don't know.",
-			generate: true,
-			want:     "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.",
-		},
-		{
-			name:     "nocut",
-			template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
-			system:   "You are a Wizard.",
-			prompt:   "What are the potion ingredients?",
-			response: "I don't know.",
-			want:     "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.</assistant>",
-		},
+func tokenize(_ context.Context, s string) (tokens []int, err error) {
+	for range strings.Fields(s) {
+		tokens = append(tokens, len(tokens))
 	}
 
-	for _, tc := range tests {
-		t.Run(tc.name, func(t *testing.T) {
-			got, err := Prompt(tc.template, tc.system, tc.prompt, tc.response, tc.generate)
-			if err != nil {
-				t.Errorf("error = %v", err)
-			}
-
-			if got != tc.want {
-				t.Errorf("got = %v, want %v", got, tc.want)
-			}
-		})
-	}
+	return
 }
 
 func TestChatPrompt(t *testing.T) {
-	tests := []struct {
-		name     string
-		template string
-		messages []api.Message
-		window   int
-		want     string
+	type expect struct {
+		prompt string
+		images [][]byte
+	}
+
+	cases := []struct {
+		name  string
+		limit int
+		msgs  []api.Message
+		expect
 	}{
 		{
-			name:     "simple prompt",
-			template: "[INST] {{ .Prompt }} [/INST]",
-			messages: []api.Message{
-				{Role: "user", Content: "Hello"},
+			name:  "messages",
+			limit: 64,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+			},
+			expect: expect{
+				prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
 			},
-			window: 1024,
-			want:   "[INST] Hello [/INST]",
-		},
-		{
-			name:     "with system message",
-			template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
-			messages: []api.Message{
-				{Role: "system", Content: "You are a Wizard."},
-				{Role: "user", Content: "Hello"},
-			},
-			window: 1024,
-			want:   "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
 		},
 		{
-			name:     "with response",
-			template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}",
-			messages: []api.Message{
-				{Role: "system", Content: "You are a Wizard."},
-				{Role: "user", Content: "Hello"},
-				{Role: "assistant", Content: "I am?"},
-			},
-			window: 1024,
-			want:   "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST] I am?",
+			name:  "truncate messages",
+			limit: 1,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+			},
+			expect: expect{
+				prompt: "A test. And a thumping good one at that, I'd wager. ",
+			},
 		},
 		{
-			name:     "with implicit response",
-			template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
-			messages: []api.Message{
-				{Role: "system", Content: "You are a Wizard."},
-				{Role: "user", Content: "Hello"},
-				{Role: "assistant", Content: "I am?"},
-			},
-			window: 1024,
-			want:   "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]I am?",
+			name:  "truncate messages with image",
+			limit: 64,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}},
+			},
+			expect: expect{
+				prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
+				images: [][]byte{
+					[]byte("something"),
+				},
+			},
 		},
 		{
-			name:     "with conversation",
-			template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
-			messages: []api.Message{
-				{Role: "system", Content: "You are a Wizard."},
-				{Role: "user", Content: "What are the potion ingredients?"},
-				{Role: "assistant", Content: "sugar"},
-				{Role: "user", Content: "Anything else?"},
-			},
-			window: 1024,
-			want:   "[INST] <<SYS>>You are a Wizard.<</SYS>> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ",
+			name:  "truncate messages with images",
+			limit: 64,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
+			},
+			expect: expect{
+				prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
+				images: [][]byte{
+					[]byte("somethingelse"),
+				},
+			},
 		},
 		{
-			name:     "with truncation",
-			template: "{{ .System }} {{ .Prompt }} {{ .Response }} ",
-			messages: []api.Message{
-				{Role: "system", Content: "You are a Wizard."},
-				{Role: "user", Content: "Hello"},
-				{Role: "assistant", Content: "I am?"},
-				{Role: "user", Content: "Why is the sky blue?"},
-				{Role: "assistant", Content: "The sky is blue from rayleigh scattering"},
-			},
-			window: 10,
-			want:   "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering",
+			name:  "messages with images",
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
+			},
+			expect: expect{
+				prompt: "[img-0] You're a test, Harry! I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
+				images: [][]byte{
+					[]byte("something"),
+					[]byte("somethingelse"),
+				},
+			},
 		},
 		{
-			name:     "images",
-			template: "{{ .System }} {{ .Prompt }}",
-			messages: []api.Message{
-				{Role: "system", Content: "You are a Wizard."},
-				{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
-			},
-			window: 1024,
-			want:   "You are a Wizard. [img-0] Hello",
+			name:  "message with image tag",
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
+			},
+			expect: expect{
+				prompt: "You're a test, Harry! [img-0] I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
+				images: [][]byte{
+					[]byte("something"),
+					[]byte("somethingelse"),
+				},
+			},
 		},
 		{
-			name:     "images truncated",
-			template: "{{ .System }} {{ .Prompt }}",
-			messages: []api.Message{
-				{Role: "system", Content: "You are a Wizard."},
-				{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
-			},
-			window: 1024,
-			want:   "You are a Wizard. [img-0] [img-1] Hello",
+			name:  "messages with interleaved images",
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "user", Images: []api.ImageData{[]byte("something")}},
+				{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+			},
+			expect: expect{
+				prompt: "You're a test, Harry!\n\n[img-0]\n\n[img-1] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
+				images: [][]byte{
+					[]byte("something"),
+					[]byte("somethingelse"),
+				},
+			},
 		},
 		{
-			name:     "empty list",
-			template: "{{ .System }} {{ .Prompt }}",
-			messages: []api.Message{},
-			window:   1024,
-			want:     "",
+			name:  "truncate message with interleaved images",
+			limit: 1024,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "user", Images: []api.ImageData{[]byte("something")}},
+				{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+			},
+			expect: expect{
+				prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
+				images: [][]byte{
+					[]byte("somethingelse"),
+				},
+			},
 		},
 		{
-			name:     "empty prompt",
-			template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
-			messages: []api.Message{
-				{Role: "user", Content: ""},
+			name:  "message with system prompt",
+			limit: 2048,
+			msgs: []api.Message{
+				{Role: "system", Content: "You are the Test Who Lived."},
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+			},
+			expect: expect{
+				prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
 			},
-			window: 1024,
-			want:   "",
 		},
 	}
 
-	encode := func(s string) ([]int, error) {
-		words := strings.Fields(s)
-		return make([]int, len(words)), nil
+	tmpl, err := template.Parse(`
+{{- if .System }}{{ .System }} {{ end }}
+{{- if .Prompt }}{{ .Prompt }} {{ end }}
+{{- if .Response }}{{ .Response }} {{ end }}`)
+	if err != nil {
+		t.Fatal(err)
 	}
 
-	for _, tc := range tests {
-		t.Run(tc.name, func(t *testing.T) {
-			got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
+			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
+			prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs)
 			if err != nil {
-				t.Errorf("error = %v", err)
+				t.Fatal(err)
 			}
 
-			if got != tc.want {
-				t.Errorf("got: %q, want: %q", got, tc.want)
+			if tt.prompt != prompt {
+				t.Errorf("expected %q, got %q", tt.prompt, prompt)
+			}
+
+			if len(images) != len(tt.images) {
+				t.Fatalf("expected %d images, got %d", len(tt.images), len(images))
+			}
+
+			for i := range images {
+				if images[i].ID != i {
+					t.Errorf("expected ID %d, got %d", i, images[i].ID)
+				}
+
+				if !bytes.Equal(images[i].Data, tt.images[i]) {
+					t.Errorf("expected %q, got %q", tt.images[i], images[i])
+				}
 			}
 		})
 	}
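
The rewritten test drives chatPrompt through a flat template that prints .System, .Prompt, and .Response in turn. The snippet below renders the same template shape with the standard library's text/template to show what one flattened turn looks like; it stands in for the ollama/template package and template.Values used by the real tests, and is only a sketch of the template shape, not the package's API:

```go
package main

import (
	"os"
	"text/template"
)

func main() {
	// The same flat template the test parses, rendered here with the standard
	// library for illustration; the real tests go through the ollama/template
	// package and template.Values{Messages: ...}.
	tmpl := template.Must(template.New("chat").Parse(`
{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}{{ .Prompt }} {{ end }}
{{- if .Response }}{{ .Response }} {{ end }}`))

	// One flattened turn; the server expands a message history into successive
	// System/Prompt/Response values shaped like this.
	if err := tmpl.Execute(os.Stdout, map[string]string{
		"System": "You are the Test Who Lived.",
		"Prompt": "You're a test, Harry!",
	}); err != nil {
		panic(err)
	}
	// Output: "You are the Test Who Lived. You're a test, Harry! "
}
```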

+ 188 - 390
server/routes.go

@@ -1,15 +1,14 @@
 package server
 
 import (
+	"bytes"
 	"cmp"
 	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
-	"io/fs"
 	"log/slog"
-	"math"
 	"net"
 	"net/http"
 	"net/netip"
@@ -17,7 +16,6 @@ import (
 	"os/signal"
 	"path/filepath"
 	"slices"
-	"strconv"
 	"strings"
 	"syscall"
 	"time"
@@ -31,6 +29,7 @@ import (
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/parser"
+	"github.com/ollama/ollama/template"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
@@ -55,7 +54,7 @@ func init() {
 	gin.SetMode(mode)
 }
 
-var defaultSessionDuration = 5 * time.Minute
+var errRequired = errors.New("is required")
 
 func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
 	opts := api.DefaultOptions()
@@ -70,164 +69,140 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
 	return opts, nil
 }
 
-func isSupportedImageType(image []byte) bool {
-	contentType := http.DetectContentType(image)
-	allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
-	return slices.Contains(allowedTypes, contentType)
-}
-
-func (s *Server) GenerateHandler(c *gin.Context) {
-	checkpointStart := time.Now()
-	var req api.GenerateRequest
-	err := c.ShouldBindJSON(&req)
-
-	switch {
-	case errors.Is(err, io.EOF):
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
-		return
-	case err != nil:
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		return
+// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
+// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
+func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
+	if name == "" {
+		return nil, nil, nil, fmt.Errorf("model %w", errRequired)
 	}
 
-	// validate the request
-	switch {
-	case req.Model == "":
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
-		return
-	case len(req.Format) > 0 && req.Format != "json":
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
-		return
-	case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
-		return
+	model, err := GetModel(name)
+	if err != nil {
+		return nil, nil, nil, err
 	}
 
-	for _, img := range req.Images {
-		if !isSupportedImageType(img) {
-			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
-			return
-		}
+	if err := model.CheckCapabilities(caps...); err != nil {
+		return nil, nil, nil, fmt.Errorf("%s %w", name, err)
 	}
 
-	model, err := GetModel(req.Model)
+	opts, err := modelOptions(model, requestOpts)
 	if err != nil {
-		var pErr *fs.PathError
-		if errors.As(err, &pErr) {
-			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-			return
-		}
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
+		return nil, nil, nil, err
 	}
 
-	if model.IsEmbedding() {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support generate"})
-		return
+	runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
+	var runner *runnerRef
+	select {
+	case runner = <-runnerCh:
+	case err = <-errCh:
+		return nil, nil, nil, err
 	}
 
-	opts, err := modelOptions(model, req.Options)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+	return runner.llama, model, &opts, nil
+}
+
+func (s *Server) GenerateHandler(c *gin.Context) {
+	var req api.GenerateRequest
+	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	} else if err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
-	var sessionDuration time.Duration
-	if req.KeepAlive == nil {
-		sessionDuration = getDefaultSessionDuration()
-	} else {
-		sessionDuration = req.KeepAlive.Duration
+	if req.Format != "" && req.Format != "json" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
+		return
+	} else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
+		return
 	}
 
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
-	var runner *runnerRef
-	select {
-	case runner = <-rCh:
-	case err = <-eCh:
-		handleErrorResponse(c, err)
+	caps := []Capability{CapabilityCompletion}
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+	if errors.Is(err, errCapabilityCompletion) {
+		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
+		return
+	} else if err != nil {
+		handleScheduleError(c, req.Model, err)
 		return
 	}
 
-	// an empty request loads the model
-	// note: for a short while template was used in lieu
-	// of `raw` mode so we need to check for it too
-	if req.Prompt == "" && req.Template == "" && req.System == "" {
+	if req.Prompt == "" {
 		c.JSON(http.StatusOK, api.GenerateResponse{
-			CreatedAt:  time.Now().UTC(),
 			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
 			Done:       true,
 			DoneReason: "load",
 		})
 		return
 	}
 
-	checkpointLoaded := time.Now()
-
-	var prompt string
-	switch {
-	case req.Raw:
-		prompt = req.Prompt
-	case req.Prompt != "":
-		if req.Template == "" {
-			req.Template = model.Template
-		}
+	images := make([]llm.ImageData, len(req.Images))
+	for i := range req.Images {
+		images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
+	}
 
-		if req.System == "" {
-			req.System = model.System
+	prompt := req.Prompt
+	if !req.Raw {
+		var msgs []api.Message
+		if req.System != "" {
+			msgs = append(msgs, api.Message{Role: "system", Content: req.System})
+		} else if m.System != "" {
+			msgs = append(msgs, api.Message{Role: "system", Content: m.System})
 		}
 
-		slog.Debug("generate handler", "prompt", req.Prompt)
-		slog.Debug("generate handler", "template", req.Template)
-		slog.Debug("generate handler", "system", req.System)
-
-		var sb strings.Builder
-		for i := range req.Images {
-			fmt.Fprintf(&sb, "[img-%d] ", i)
+		for _, i := range images {
+			msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
 		}
 
-		sb.WriteString(req.Prompt)
+		msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
 
-		p, err := Prompt(req.Template, req.System, sb.String(), "", true)
-		if err != nil {
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-			return
+		tmpl := m.Template
+		if req.Template != "" {
+			tmpl, err = template.Parse(req.Template)
+			if err != nil {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
 		}
 
-		sb.Reset()
+		var b bytes.Buffer
 		if req.Context != nil {
-			prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
+			s, err := r.Detokenize(c.Request.Context(), req.Context)
 			if err != nil {
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
 			}
 
-			sb.WriteString(prev)
+			b.WriteString(s)
 		}
 
-		sb.WriteString(p)
+		if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
 
-		prompt = sb.String()
+		prompt = b.String()
 	}
 
-	slog.Debug("generate handler", "prompt", prompt)
+	slog.Debug("generate request", "prompt", prompt, "images", images)
 
 	ch := make(chan any)
-	var generated strings.Builder
 	go func() {
 		defer close(ch)
-
-		fn := func(r llm.CompletionResponse) {
-			// Build up the full response
-			if _, err := generated.WriteString(r.Content); err != nil {
-				ch <- gin.H{"error": err.Error()}
-				return
-			}
-
-			resp := api.GenerateResponse{
+		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
+			Prompt:  prompt,
+			Images:  images,
+			Format:  req.Format,
+			Options: opts,
+		}, func(r llm.CompletionResponse) {
+			ch <- api.GenerateResponse{
 				Model:      req.Model,
 				CreatedAt:  time.Now().UTC(),
-				Done:       r.Done,
 				Response:   r.Content,
+				Done:       r.Done,
 				DoneReason: r.DoneReason,
 				Metrics: api.Metrics{
 					PromptEvalCount:    r.PromptEvalCount,
@@ -236,156 +211,54 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 					EvalDuration:       r.EvalDuration,
 				},
 			}
-
-			if r.Done {
-				resp.TotalDuration = time.Since(checkpointStart)
-				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
-
-				if !req.Raw {
-					p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false)
-					if err != nil {
-						c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-						return
-					}
-
-					// TODO (jmorganca): encode() should not strip special tokens
-					tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
-					if err != nil {
-						ch <- gin.H{"error": err.Error()}
-						return
-					}
-
-					resp.Context = append(req.Context, tokens...)
-				}
-			}
-
-			ch <- resp
-		}
-
-		var images []llm.ImageData
-		for i := range req.Images {
-			images = append(images, llm.ImageData{
-				ID:   i,
-				Data: req.Images[i],
-			})
-		}
-
-		// Start prediction
-		req := llm.CompletionRequest{
-			Prompt:  prompt,
-			Format:  req.Format,
-			Images:  images,
-			Options: opts,
-		}
-		if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
+		}); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
 
 	if req.Stream != nil && !*req.Stream {
-		// Accumulate responses into the final response
-		var final api.GenerateResponse
+		var r api.GenerateResponse
 		var sb strings.Builder
-		for resp := range ch {
-			switch r := resp.(type) {
+		for rr := range ch {
+			switch t := rr.(type) {
 			case api.GenerateResponse:
-				sb.WriteString(r.Response)
-				final = r
+				sb.WriteString(t.Response)
+				r = t
 			case gin.H:
-				if errorMsg, ok := r["error"].(string); ok {
-					c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
-					return
-				} else {
-					c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
-					return
+				msg, ok := t["error"].(string)
+				if !ok {
+					msg = "unexpected error format in response"
 				}
+
+				c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
+				return
 			default:
-				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
+				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
 				return
 			}
 		}
 
-		final.Response = sb.String()
-		c.JSON(http.StatusOK, final)
+		r.Response = sb.String()
+		c.JSON(http.StatusOK, r)
 		return
 	}
 
 	streamResponse(c, ch)
 }
 
-func getDefaultSessionDuration() time.Duration {
-	if envconfig.KeepAlive != "" {
-		v, err := strconv.Atoi(envconfig.KeepAlive)
-		if err != nil {
-			d, err := time.ParseDuration(envconfig.KeepAlive)
-			if err != nil {
-				return defaultSessionDuration
-			}
-
-			if d < 0 {
-				return time.Duration(math.MaxInt64)
-			}
-
-			return d
-		}
-
-		d := time.Duration(v) * time.Second
-		if d < 0 {
-			return time.Duration(math.MaxInt64)
-		}
-		return d
-	}
-
-	return defaultSessionDuration
-}
-
 func (s *Server) EmbeddingsHandler(c *gin.Context) {
 	var req api.EmbeddingRequest
-	err := c.ShouldBindJSON(&req)
-	switch {
-	case errors.Is(err, io.EOF):
+	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
 		return
-	case err != nil:
+	} else if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
-	if req.Model == "" {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
-		return
-	}
-
-	model, err := GetModel(req.Model)
+	r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
-		var pErr *fs.PathError
-		if errors.As(err, &pErr) {
-			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-			return
-		}
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	opts, err := modelOptions(model, req.Options)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	var sessionDuration time.Duration
-	if req.KeepAlive == nil {
-		sessionDuration = getDefaultSessionDuration()
-	} else {
-		sessionDuration = req.KeepAlive.Duration
-	}
-
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
-	var runner *runnerRef
-	select {
-	case runner = <-rCh:
-	case err = <-eCh:
-		handleErrorResponse(c, err)
+		handleScheduleError(c, req.Model, err)
 		return
 	}
 
@@ -395,17 +268,14 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
+	embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
 		return
 	}
 
-	resp := api.EmbeddingResponse{
-		Embedding: embedding,
-	}
-	c.JSON(http.StatusOK, resp)
+	c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding})
 }
 
 func (s *Server) PullModelHandler(c *gin.Context) {
@@ -680,12 +550,15 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	}
 
 	if req.Template != "" {
-		m.Template = req.Template
+		m.Template, err = template.Parse(req.Template)
+		if err != nil {
+			return nil, err
+		}
 	}
 
-	msgs := make([]api.Message, 0)
-	for _, msg := range m.Messages {
-		msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
+	msgs := make([]api.Message, len(m.Messages))
+	for i, msg := range m.Messages {
+		msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
 	}
 
 	n := model.ParseName(req.Model)
@@ -701,7 +574,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	resp := &api.ShowResponse{
 		License:    strings.Join(m.License, "\n"),
 		System:     m.System,
-		Template:   m.Template,
+		Template:   m.Template.String(),
 		Details:    modelDetails,
 		Messages:   msgs,
 		ModifiedAt: manifest.fi.ModTime(),
@@ -754,7 +627,11 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 }
 
 func getKVData(digest string, verbose bool) (llm.KV, error) {
-	kvData, err := llm.LoadModel(digest)
+	maxArraySize := 0
+	if verbose {
+		maxArraySize = -1
+	}
+	kvData, err := llm.LoadModel(digest, maxArraySize)
 	if err != nil {
 		return nil, err
 	}
@@ -1035,7 +912,10 @@ func (s *Server) GenerateRoutes() http.Handler {
 	r.GET("/api/ps", s.ProcessHandler)
 
 	// Compatibility endpoints
-	r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler)
+	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
+	r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
+	r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
+	r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
 
 	for _, method := range []string{http.MethodGet, http.MethodHead} {
 		r.Handle(method, "/", func(c *gin.Context) {
@@ -1101,11 +981,20 @@ func Serve(ln net.Listener) error {
 	schedCtx, schedDone := context.WithCancel(ctx)
 	sched := InitScheduler(schedCtx)
 	s := &Server{addr: ln.Addr(), sched: sched}
-	r := s.GenerateRoutes()
+
+	http.Handle("/", s.GenerateRoutes())
 
 	slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
 	srvr := &http.Server{
-		Handler: r,
+		// Use http.DefaultServeMux so we get net/http/pprof for
+		// free.
+		//
+		// TODO(bmizerany): Decide if we want to make this
+		// configurable so it is not exposed by default, or allow
+		// users to bind it to a different port. This was a quick
+		// and easy way to get pprof, but it may not be the best
+		// way.
+		Handler: nil,
 	}
 
 	// listen for a ctrl+c and stop any loaded llm
@@ -1224,142 +1113,63 @@ func (s *Server) ProcessHandler(c *gin.Context) {
 		models = append(models, mr)
 	}
 
-	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
-}
-
-// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
-func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
-	encode := func(s string) ([]int, error) {
-		return runner.llama.Tokenize(ctx, s)
-	}
-
-	prompt, err := ChatPrompt(template, messages, numCtx, encode)
-	if err != nil {
-		return "", err
-	}
+	slices.SortStableFunc(models, func(i, j api.ProcessModelResponse) int {
+		// longest duration remaining listed first
+		return cmp.Compare(j.ExpiresAt.Unix(), i.ExpiresAt.Unix())
+	})
 
-	return prompt, nil
+	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
 }
 
 func (s *Server) ChatHandler(c *gin.Context) {
-	checkpointStart := time.Now()
-
 	var req api.ChatRequest
-	err := c.ShouldBindJSON(&req)
-	switch {
-	case errors.Is(err, io.EOF):
+	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
 		return
-	case err != nil:
+	} else if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
-	// validate the request
-	switch {
-	case req.Model == "":
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
-		return
-	case len(req.Format) > 0 && req.Format != "json":
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
-		return
-	}
-
-	model, err := GetModel(req.Model)
-	if err != nil {
-		var pErr *fs.PathError
-		if errors.As(err, &pErr) {
-			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-			return
-		}
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	if model.IsEmbedding() {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support chat"})
-		return
-	}
-
-	opts, err := modelOptions(model, req.Options)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+	caps := []Capability{CapabilityCompletion}
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+	if errors.Is(err, errCapabilityCompletion) {
+		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
 		return
-	}
-
-	var sessionDuration time.Duration
-	if req.KeepAlive == nil {
-		sessionDuration = getDefaultSessionDuration()
-	} else {
-		sessionDuration = req.KeepAlive.Duration
-	}
-
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
-	var runner *runnerRef
-	select {
-	case runner = <-rCh:
-	case err = <-eCh:
-		handleErrorResponse(c, err)
-		return
-	}
-
-	checkpointLoaded := time.Now()
-
-	// if the first message is not a system message, then add the model's default system message
-	if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
-		req.Messages = append([]api.Message{
-			{
-				Role:    "system",
-				Content: model.System,
-			},
-		}, req.Messages...)
-	}
-
-	prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
-	if err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	} else if err != nil {
+		handleScheduleError(c, req.Model, err)
 		return
 	}
 
-	// an empty request loads the model
-	if len(req.Messages) == 0 || prompt == "" {
-		resp := api.ChatResponse{
-			CreatedAt:  time.Now().UTC(),
+	if len(req.Messages) == 0 {
+		c.JSON(http.StatusOK, api.ChatResponse{
 			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
+			Message:    api.Message{Role: "assistant"},
 			Done:       true,
 			DoneReason: "load",
-			Message:    api.Message{Role: "assistant"},
-		}
-		c.JSON(http.StatusOK, resp)
+		})
 		return
 	}
 
-	// only send images that are in the prompt
-	var i int
-	var images []llm.ImageData
-	for _, m := range req.Messages {
-		for _, img := range m.Images {
-			if !isSupportedImageType(img) {
-				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
-				return
-			}
-
-			if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
-				images = append(images, llm.ImageData{Data: img, ID: i})
-			}
-			i += 1
-		}
+	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
 	}
 
-	slog.Debug("chat handler", "prompt", prompt, "images", len(images))
+	slog.Debug("chat request", "images", len(images), "prompt", prompt)
 
 	ch := make(chan any)
-
 	go func() {
 		defer close(ch)
-
-		fn := func(r llm.CompletionResponse) {
-			resp := api.ChatResponse{
+		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
+			Prompt:  prompt,
+			Images:  images,
+			Format:  req.Format,
+			Options: opts,
+		}, func(r llm.CompletionResponse) {
+			ch <- api.ChatResponse{
 				Model:      req.Model,
 				CreatedAt:  time.Now().UTC(),
 				Message:    api.Message{Role: "assistant", Content: r.Content},
@@ -1372,64 +1182,52 @@ func (s *Server) ChatHandler(c *gin.Context) {
 					EvalDuration:       r.EvalDuration,
 				},
 			}
-
-			if r.Done {
-				resp.TotalDuration = time.Since(checkpointStart)
-				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
-			}
-
-			ch <- resp
-		}
-
-		if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:  prompt,
-			Format:  req.Format,
-			Images:  images,
-			Options: opts,
-		}, fn); err != nil {
+		}); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
 
 	if req.Stream != nil && !*req.Stream {
-		// Accumulate responses into the final response
-		var final api.ChatResponse
+		var r api.ChatResponse
 		var sb strings.Builder
-		for resp := range ch {
-			switch r := resp.(type) {
+		for rr := range ch {
+			switch t := rr.(type) {
 			case api.ChatResponse:
-				sb.WriteString(r.Message.Content)
-				final = r
+				sb.WriteString(t.Message.Content)
+				r = t
 			case gin.H:
-				if errorMsg, ok := r["error"].(string); ok {
-					c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
-					return
-				} else {
-					c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
-					return
+				msg, ok := t["error"].(string)
+				if !ok {
+					msg = "unexpected error format in response"
 				}
+
+				c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
+				return
 			default:
-				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
+				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
 				return
 			}
 		}
 
-		final.Message = api.Message{Role: "assistant", Content: sb.String()}
-		c.JSON(http.StatusOK, final)
+		r.Message.Content = sb.String()
+		c.JSON(http.StatusOK, r)
 		return
 	}
 
 	streamResponse(c, ch)
 }
 
-func handleErrorResponse(c *gin.Context, err error) {
-	if errors.Is(err, context.Canceled) {
+func handleScheduleError(c *gin.Context, name string, err error) {
+	switch {
+	case errors.Is(err, errRequired):
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	case errors.Is(err, context.Canceled):
 		c.JSON(499, gin.H{"error": "request canceled"})
-		return
-	}
-	if errors.Is(err, ErrMaxQueue) {
+	case errors.Is(err, ErrMaxQueue):
 		c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
-		return
+	case errors.Is(err, os.ErrNotExist):
+		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
+	default:
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 	}
-	c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 }
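
Both handlers now share the same shape: a goroutine streams llm.CompletionResponse values into a channel, and when the client asks for a non-streaming reply the handler drains that channel into a single response. A stripped-down sketch of that accumulation loop, with a hypothetical chunk type standing in for api.GenerateResponse/api.ChatResponse and gin.H:

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// chunk is a stand-in for the streamed response values the handlers put on
// the channel.
type chunk struct {
	Content string
	Done    bool
}

// collect drains a channel of streamed values into one final string, the same
// accumulation the handlers perform when the request sets "stream": false.
func collect(ch <-chan any) (string, error) {
	var sb strings.Builder
	for v := range ch {
		switch t := v.(type) {
		case chunk:
			sb.WriteString(t.Content)
		case error:
			return "", t
		default:
			return "", errors.New("unexpected response")
		}
	}
	return sb.String(), nil
}

func main() {
	ch := make(chan any, 2)
	ch <- chunk{Content: "Hello, "}
	ch <- chunk{Content: "world.", Done: true}
	close(ch)

	out, err := collect(ch)
	fmt.Println(out, err) // Hello, world. <nil>
}
```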

+ 2 - 2
server/routes_create_test.go

@@ -545,9 +545,9 @@ func TestCreateDetectTemplate(t *testing.T) {
 		}
 
 		checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
-			filepath.Join(p, "blobs", "sha256-2f8e594e6f34b1b4d36a246628eeb3365ce442303d656f1fcc69e821722acea0"),
-			filepath.Join(p, "blobs", "sha256-542b217f179c7825eeb5bca3c77d2b75ed05bafbd3451d9188891a60a85337c6"),
 			filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
+			filepath.Join(p, "blobs", "sha256-c608dc615584cd20d9d830363dabf8a4783ae5d34245c3d8c115edb3bc7b28e4"),
+			filepath.Join(p, "blobs", "sha256-f836ee110db21567f826332e4cedd746c06d10664fd5a9ea3659e3683a944510"),
 		})
 	})
 

+ 56 - 0
server/routes_test.go

@@ -20,6 +20,7 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
@@ -105,6 +106,24 @@ func Test_Routes(t *testing.T) {
 				assert.Empty(t, len(modelList.Models))
 			},
 		},
+		{
+			Name:   "openai empty list",
+			Method: http.MethodGet,
+			Path:   "/v1/models",
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				assert.Equal(t, "application/json", contentType)
+				body, err := io.ReadAll(resp.Body)
+				require.NoError(t, err)
+
+				var modelList openai.ListCompletion
+				err = json.Unmarshal(body, &modelList)
+				require.NoError(t, err)
+
+				assert.Equal(t, "list", modelList.Object)
+				assert.Empty(t, modelList.Data)
+			},
+		},
 		{
 			Name:   "Tags Handler (yes tags)",
 			Method: http.MethodGet,
@@ -128,6 +147,25 @@ func Test_Routes(t *testing.T) {
 				assert.Equal(t, "test-model:latest", modelList.Models[0].Name)
 			},
 		},
+		{
+			Name:   "openai list models with tags",
+			Method: http.MethodGet,
+			Path:   "/v1/models",
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				assert.Equal(t, "application/json", contentType)
+				body, err := io.ReadAll(resp.Body)
+				require.NoError(t, err)
+
+				var modelList openai.ListCompletion
+				err = json.Unmarshal(body, &modelList)
+				require.NoError(t, err)
+
+				assert.Len(t, modelList.Data, 1)
+				assert.Equal(t, "test-model:latest", modelList.Data[0].Id)
+				assert.Equal(t, "library", modelList.Data[0].OwnedBy)
+			},
+		},
 		{
 			Name:   "Create Model Handler",
 			Method: http.MethodPost,
@@ -216,6 +254,24 @@ func Test_Routes(t *testing.T) {
 				assert.InDelta(t, 0, showResp.ModelInfo["general.parameter_count"], 1e-9, "Parameter count should be 0")
 			},
 		},
+		{
+			Name:   "openai retrieve model handler",
+			Method: http.MethodGet,
+			Path:   "/v1/models/show-model",
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				assert.Equal(t, "application/json", contentType)
+				body, err := io.ReadAll(resp.Body)
+				require.NoError(t, err)
+
+				var retrieveResp api.RetrieveModelResponse
+				err = json.Unmarshal(body, &retrieveResp)
+				require.NoError(t, err)
+
+				assert.Equal(t, "show-model", retrieveResp.Id)
+				assert.Equal(t, "library", retrieveResp.OwnedBy)
+			},
+		},
 	}
 
 	t.Setenv("OLLAMA_MODELS", t.TempDir())

+ 108 - 29
server/sched.go

@@ -23,7 +23,8 @@ type LlmRequest struct {
 	ctx             context.Context //nolint:containedctx
 	model           *Model
 	opts            api.Options
-	sessionDuration time.Duration
+	origNumCtx      int // Track the initial ctx request
+	sessionDuration *api.Duration
 	successCh       chan *runnerRef
 	errCh           chan error
 	schedAttempts   uint
@@ -38,13 +39,23 @@ type Scheduler struct {
 	loaded   map[string]*runnerRef
 	loadedMu sync.Mutex
 
-	loadFn       func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
-	newServerFn  func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error)
+	loadFn       func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int)
+	newServerFn  func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error)
 	getGpuFn     func() gpu.GpuInfoList
 	getCpuFn     func() gpu.GpuInfoList
 	reschedDelay time.Duration
 }
 
+// Default automatic value for number of models we allow per GPU
+// Model will still need to fit in VRAM, but loading many small models
+// on a large GPU can cause stalling
+var defaultModelsPerGPU = 3
+
+// Default automatic value for parallel setting
+// Model will still need to fit in VRAM. If this setting won't fit,
+// we'll back off down to 1 to try to get it to fit.
+var defaultParallel = 4
+
 var ErrMaxQueue = fmt.Errorf("server busy, please try again.  maximum pending requests exceeded")
 
 func InitScheduler(ctx context.Context) *Scheduler {
@@ -64,14 +75,11 @@ func InitScheduler(ctx context.Context) *Scheduler {
 }
 
 // context must be canceled to decrement ref count and release the runner
-func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
-	// allocate a large enough kv cache for all parallel requests
+func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
 	if opts.NumCtx < 4 {
 		opts.NumCtx = 4
 	}
 
-	opts.NumCtx *= envconfig.NumParallel
-
 	req := &LlmRequest{
 		ctx:             c,
 		model:           model,
@@ -110,11 +118,21 @@ func (s *Scheduler) processPending(ctx context.Context) {
 		case pending := <-s.pendingReqCh:
 			// Block other requests until we get this pending request running
 			pending.schedAttempts++
+			if pending.origNumCtx == 0 {
+				pending.origNumCtx = pending.opts.NumCtx
+			}
 
 			if pending.ctx.Err() != nil {
 				slog.Debug("pending request cancelled or timed out, skipping scheduling")
 				continue
 			}
+			numParallel := envconfig.NumParallel
+			// TODO (jmorganca): multimodal models don't support parallel yet
+			// see https://github.com/ollama/ollama/issues/4165
+			if len(pending.model.ProjectorPaths) > 0 && numParallel != 1 {
+				numParallel = 1
+				slog.Warn("multimodal models don't support parallel requests yet")
+			}
 
 			for {
 				var runnerToExpire *runnerRef
@@ -143,8 +161,28 @@ func (s *Scheduler) processPending(ctx context.Context) {
 						gpus = s.getGpuFn()
 					}
 
+					if envconfig.MaxRunners <= 0 {
+						// No user specified MaxRunners, so figure out what automatic setting to use
+						// If all GPUs have reliable free memory reporting, defaultModelsPerGPU * the number of GPUs
+						// if any GPU has unreliable free memory reporting, 1x the number of GPUs
+						allReliable := true
+						for _, gpu := range gpus {
+							if gpu.UnreliableFreeMemory {
+								allReliable = false
+								break
+							}
+						}
+						if allReliable {
+							envconfig.MaxRunners = defaultModelsPerGPU * len(gpus)
+							slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", envconfig.MaxRunners, "gpu_count", len(gpus))
+						} else {
+							slog.Info("one or more GPUs detected that are unable to accurately report free memory - disabling default concurrency")
+							envconfig.MaxRunners = len(gpus)
+						}
+					}
+
 					// Load model for fitting
-					ggml, err := llm.LoadModel(pending.model.ModelPath)
+					ggml, err := llm.LoadModel(pending.model.ModelPath, 0)
 					if err != nil {
 						pending.errCh <- err
 						break
@@ -152,26 +190,33 @@ func (s *Scheduler) processPending(ctx context.Context) {
 
 					// Evaluate if the model will fit in the available system memory, or if we should unload a model first
 					if len(gpus) == 1 && gpus[0].Library == "cpu" {
+						// simplifying assumption of defaultParallel when in CPU mode
+						if numParallel <= 0 {
+							numParallel = defaultParallel
+						}
+
+						pending.opts.NumCtx = pending.origNumCtx * numParallel
+
 						if loadedCount == 0 {
 							slog.Debug("cpu mode with first model, loading")
-							s.loadFn(pending, ggml, gpus)
+							s.loadFn(pending, ggml, gpus, numParallel)
 							break
 						}
 						runnerToExpire = s.maybeFindCPURunnerToUnload(pending, ggml, gpus)
 						if runnerToExpire == nil {
 							slog.Debug("cpu mode with available system memory or first model, loading")
-							s.loadFn(pending, ggml, gpus)
+							s.loadFn(pending, ggml, gpus, numParallel)
 							break
 						}
 						// else we need to expire a runner
 					} else if loadedCount == 0 {
 						// No models loaded. Load the model but prefer the best fit.
 						slog.Debug("loading first model", "model", pending.model.ModelPath)
-						g := pickBestFitGPUs(pending, ggml, gpus)
+						g := pickBestFitGPUs(pending, ggml, gpus, &numParallel)
 						if g != nil {
 							gpus = g
 						}
-						s.loadFn(pending, ggml, gpus)
+						s.loadFn(pending, ggml, gpus, numParallel)
 						break
 					}
 
@@ -186,10 +231,10 @@ func (s *Scheduler) processPending(ctx context.Context) {
 
 						// Update free memory from currently loaded models
 						s.updateFreeSpace(availGpus)
-						fitGpus := pickBestFitGPUs(pending, ggml, availGpus)
+						fitGpus := pickBestFitGPUs(pending, ggml, availGpus, &numParallel)
 						if fitGpus != nil {
 							slog.Debug("new model fits with existing models, loading")
-							s.loadFn(pending, ggml, fitGpus)
+							s.loadFn(pending, ggml, fitGpus, numParallel)
 							break
 						}
 
@@ -341,7 +386,9 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
 		runner.expireTimer.Stop()
 		runner.expireTimer = nil
 	}
-	runner.sessionDuration = pending.sessionDuration
+	if pending.sessionDuration != nil {
+		runner.sessionDuration = pending.sessionDuration.Duration
+	}
 	pending.successCh <- runner
 	go func() {
 		<-pending.ctx.Done()
@@ -350,8 +397,15 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
 	}()
 }
 
-func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) {
-	llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
+func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
+	if numParallel < 1 {
+		numParallel = 1
+	}
+	sessionDuration := envconfig.KeepAlive
+	if req.sessionDuration != nil {
+		sessionDuration = req.sessionDuration.Duration
+	}
+	llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
 	if err != nil {
 		// some older models are not compatible with newer versions of llama.cpp
 		// show a generalized compatibility error until there is a better way to
@@ -368,13 +422,14 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
 		modelPath:       req.model.ModelPath,
 		llama:           llama,
 		Options:         &req.opts,
-		sessionDuration: req.sessionDuration,
+		sessionDuration: sessionDuration,
 		gpus:            gpus,
 		estimatedVRAM:   llama.EstimatedVRAM(),
 		estimatedTotal:  llama.EstimatedTotal(),
 		loading:         true,
 		refCount:        1,
 	}
+	runner.numParallel = numParallel
 	runner.refMu.Lock()
 
 	s.loadedMu.Lock()
@@ -483,8 +538,9 @@ type runnerRef struct {
 	expireTimer     *time.Timer
 	expiresAt       time.Time
 
-	model     *Model
-	modelPath string
+	model       *Model
+	modelPath   string
+	numParallel int
 	*api.Options
 }
 
@@ -525,6 +581,9 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
 		optsNew.NumGPU = -1
 	}
 
+	// Normalize the NumCtx for parallelism
+	optsExisting.NumCtx = optsExisting.NumCtx / runner.numParallel
+
 	ctx, cancel := context.WithTimeout(ctx, timeout)
 	defer cancel()
 	if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
@@ -611,22 +670,38 @@ func (a ByDuration) Less(i, j int) bool {
 
 // pickBestFitGPUs will try to find the optimal placement of the model in the available GPUs where the model fully fits
 // If the model can not be fit fully within the available GPU(s) nil is returned
-func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) gpu.GpuInfoList {
+// If numParallel is <= 0, this will attempt to optimize parallelism based on available VRAM, and adjust
+// opts.NumCtx accordingly
+func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel *int) gpu.GpuInfoList {
 	var estimatedVRAM uint64
+
+	var numParallelToTry []int
+	if *numParallel <= 0 {
+		// If no specific parallel setting was provided, try larger then smaller, always end with 1
+		numParallelToTry = append(numParallelToTry, defaultParallel, 1)
+	} else {
+		numParallelToTry = []int{*numParallel}
+	}
+
 	for _, gl := range gpus.ByLibrary() {
 		var ok bool
 		sgl := append(make(gpu.GpuInfoList, 0, len(gl)), gl...)
 
 		// TODO - potentially sort by performance capability, existing models loaded, etc.
+		// TODO - Eliminate any GPUs that already have envconfig.MaxRunners loaded on them
 		// Note: at present, this will favor more VRAM over faster GPU speed in mixed setups
 		sort.Sort(sort.Reverse(gpu.ByFreeMemory(sgl)))
 
 		// First attempt to fit the model into a single GPU
-		if !envconfig.SchedSpread {
-			for _, g := range sgl {
-				if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
-					slog.Debug("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
-					return []gpu.GpuInfo{g}
+		for _, p := range numParallelToTry {
+			req.opts.NumCtx = req.origNumCtx * p
+			if !envconfig.SchedSpread {
+				for _, g := range sgl {
+					if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
+						slog.Info("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "parallel", p, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
+						*numParallel = p
+						return []gpu.GpuInfo{g}
+					}
 				}
 			}
 		}
@@ -636,9 +711,13 @@ func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) gpu.
 		// - try subsets of GPUs instead of just falling back to 1 or all in a family
 
 		// Now try all the GPUs
-		if ok, estimatedVRAM = llm.PredictServerFit(sgl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
-			slog.Debug("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "required", format.HumanBytes2(estimatedVRAM))
-			return sgl
+		for _, p := range numParallelToTry {
+			req.opts.NumCtx = req.origNumCtx * p
+			if ok, estimatedVRAM = llm.PredictServerFit(sgl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
+				slog.Info("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "parallel", p, "required", format.HumanBytes2(estimatedVRAM))
+				*numParallel = p
+				return sgl
+			}
 		}
 	}
 	return nil
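
The hunk above scales the requested context window by each candidate parallel factor (req.opts.NumCtx = req.origNumCtx * p) and keeps the first combination that llm.PredictServerFit reports as fitting in free VRAM, falling back from the default parallel count to 1. The standalone sketch below, which is not part of the diff, illustrates just that fallback; fitsInVRAM is a hypothetical stand-in for the PredictServerFit call, and defaultParallel mirrors the package-level value used in the hunk.

package main

import "fmt"

// pickParallel mirrors the fallback in pickBestFitGPUs: try the default
// parallel factor first, then 1, scaling the base context window by each
// candidate and returning the first configuration that fits.
func pickParallel(baseCtx, defaultParallel int, fitsInVRAM func(numCtx int) bool) (numCtx, parallel int, ok bool) {
	for _, p := range []int{defaultParallel, 1} {
		if ctx := baseCtx * p; fitsInVRAM(ctx) {
			return ctx, p, true
		}
	}
	return 0, 0, false
}

func main() {
	// Pretend only configurations needing at most 8192 tokens of total context fit.
	fits := func(numCtx int) bool { return numCtx <= 8192 }

	fmt.Println(pickParallel(2048, 4, fits)) // 8192 4 true  (4 slots of 2048 fit)
	fmt.Println(pickParallel(8192, 4, fits)) // 8192 1 true  (falls back to a single slot)
}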

+ 76 - 42
server/sched_test.go

@@ -44,14 +44,14 @@ func TestLoad(t *testing.T) {
 		opts:            api.DefaultOptions(),
 		successCh:       make(chan *runnerRef, 1),
 		errCh:           make(chan error, 1),
-		sessionDuration: 2,
+		sessionDuration: &api.Duration{Duration: 2 * time.Second},
 	}
 	// Fail to load model first
-	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
+	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
 		return nil, fmt.Errorf("something failed to load model blah")
 	}
 	gpus := gpu.GpuInfoList{}
-	s.load(req, ggml, gpus)
+	s.load(req, ggml, gpus, 0)
 	require.Empty(t, req.successCh)
 	require.Len(t, req.errCh, 1)
 	s.loadedMu.Lock()
@@ -61,10 +61,10 @@ func TestLoad(t *testing.T) {
 	require.Contains(t, err.Error(), "this model may be incompatible")
 
 	server := &mockLlm{estimatedVRAM: 10, estimatedVRAMByGPU: map[string]uint64{}}
-	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
+	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
 		return server, nil
 	}
-	s.load(req, ggml, gpus)
+	s.load(req, ggml, gpus, 0)
 	select {
 	case err := <-req.errCh:
 		require.NoError(t, err)
@@ -78,12 +78,12 @@ func TestLoad(t *testing.T) {
 
 	req.model.ModelPath = "dummy_model_path"
 	server.waitResp = fmt.Errorf("wait failure")
-	s.load(req, ggml, gpus)
+	s.load(req, ggml, gpus, 0)
 	select {
 	case err := <-req.errCh:
 		require.Contains(t, err.Error(), "wait failure")
 	case resp := <-req.successCh:
-		t.Errorf("unexpected success %v", resp)
+		t.Fatalf("unexpected success %v", resp)
 	}
 	s.loadedMu.Lock()
 	runner := s.loaded["dummy_model_path"]
@@ -102,7 +102,7 @@ type bundle struct {
 	ggml    *llm.GGML
 }
 
-func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
+func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
 	return scenario.srv, nil
 }
 
@@ -128,21 +128,21 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
 		"tokenizer.ggml.scores":         []float32{0},
 		"tokenizer.ggml.token_type":     []int32{0},
 	}, []llm.Tensor{
-		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
-		{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
+		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
+		{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
 	})
 	require.NoError(t, err)
 
 	fname := f.Name()
 	model := &Model{Name: modelName, ModelPath: fname}
-	scenario.ggml, err = llm.LoadModel(model.ModelPath)
+	scenario.ggml, err = llm.LoadModel(model.ModelPath, 0)
 	require.NoError(t, err)
 
 	scenario.req = &LlmRequest{
 		ctx:             scenario.ctx,
 		model:           model,
 		opts:            api.DefaultOptions(),
-		sessionDuration: 5 * time.Millisecond,
+		sessionDuration: &api.Duration{Duration: 5 * time.Millisecond},
 		successCh:       make(chan *runnerRef, 1),
 		errCh:           make(chan error, 1),
 	}
@@ -156,18 +156,18 @@ func TestRequests(t *testing.T) {
 
 	// Same model, same request
 	scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
-	scenario1a.req.sessionDuration = 5 * time.Millisecond
+	scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
 	scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
 	scenario1b.req.model = scenario1a.req.model
 	scenario1b.ggml = scenario1a.ggml
-	scenario1b.req.sessionDuration = 0
+	scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
 
 	// simple reload of same model
 	scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
 	tmpModel := *scenario1a.req.model
 	scenario2a.req.model = &tmpModel
 	scenario2a.ggml = scenario1a.ggml
-	scenario2a.req.sessionDuration = 5 * time.Millisecond
+	scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
 
 	// Multiple loaded models
 	scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
@@ -199,8 +199,10 @@ func TestRequests(t *testing.T) {
 		require.Equal(t, resp.llama, scenario1a.srv)
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario1a.req.errCh)
+	case err := <-scenario1a.req.errCh:
+		t.Fatal(err.Error())
 	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
 	}
 
 	// Same runner as first request due to not needing a reload
@@ -212,8 +214,10 @@ func TestRequests(t *testing.T) {
 		require.Equal(t, resp.llama, scenario1a.srv)
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario1b.req.errCh)
+	case err := <-scenario1b.req.errCh:
+		t.Fatal(err.Error())
 	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
 	}
 
 	// Trigger a reload
@@ -230,8 +234,10 @@ func TestRequests(t *testing.T) {
 		require.Equal(t, resp.llama, scenario2a.srv)
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario2a.req.errCh)
+	case err := <-scenario2a.req.errCh:
+		t.Fatal(err.Error())
 	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
 	}
 
 	envconfig.MaxRunners = 1
@@ -246,8 +252,10 @@ func TestRequests(t *testing.T) {
 		require.Equal(t, resp.llama, scenario3a.srv)
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario3a.req.errCh)
+	case err := <-scenario3a.req.errCh:
+		t.Fatal(err.Error())
 	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
 	}
 	s.loadedMu.Lock()
 	require.Len(t, s.loaded, 1)
@@ -262,8 +270,10 @@ func TestRequests(t *testing.T) {
 		require.Equal(t, resp.llama, scenario3b.srv)
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario3b.req.errCh)
+	case err := <-scenario3b.req.errCh:
+		t.Fatal(err.Error())
 	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
 	}
 	s.loadedMu.Lock()
 	require.Len(t, s.loaded, 2)
@@ -278,8 +288,10 @@ func TestRequests(t *testing.T) {
 		require.Equal(t, resp.llama, scenario3c.srv)
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario3c.req.errCh)
+	case err := <-scenario3c.req.errCh:
+		t.Fatal(err.Error())
 	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
 	}
 	s.loadedMu.Lock()
 	require.Len(t, s.loaded, 3)
@@ -306,7 +318,7 @@ func TestRequests(t *testing.T) {
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario3d.req.errCh)
 	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
 	}
 	s.loadedMu.Lock()
 	require.Len(t, s.loaded, 2)
@@ -318,11 +330,11 @@ func TestGetRunner(t *testing.T) {
 	defer done()
 
 	scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
-	scenario1a.req.sessionDuration = 0
+	scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
 	scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
-	scenario1b.req.sessionDuration = 0
+	scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
 	scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
-	scenario1c.req.sessionDuration = 0
+	scenario1c.req.sessionDuration = &api.Duration{Duration: 0}
 	envconfig.MaxQueuedRequests = 1
 	s := InitScheduler(ctx)
 	s.getGpuFn = func() gpu.GpuInfoList {
@@ -349,7 +361,7 @@ func TestGetRunner(t *testing.T) {
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, errCh1a)
 	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
 	}
 	scenario1a.ctxDone()
 	s.loadedMu.Lock()
@@ -400,9 +412,9 @@ func TestPrematureExpired(t *testing.T) {
 		slog.Info("sending premature expired event now")
 		s.expiredCh <- resp // Shouldn't happen in real life, but make sure it's safe
 	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
 	}
-	time.Sleep(scenario1a.req.sessionDuration)
+	time.Sleep(scenario1a.req.sessionDuration.Duration)
 	scenario1a.ctxDone()
 	time.Sleep(20 * time.Millisecond)
 	require.LessOrEqual(t, len(s.finishedReqCh), 1)
@@ -423,11 +435,11 @@ func TestUseLoadedRunner(t *testing.T) {
 		ctx:             ctx,
 		opts:            api.DefaultOptions(),
 		successCh:       make(chan *runnerRef, 1),
-		sessionDuration: 2,
+		sessionDuration: &api.Duration{Duration: 2},
 	}
 	finished := make(chan *LlmRequest)
 	llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
-	r1 := &runnerRef{llama: llm1, sessionDuration: 1}
+	r1 := &runnerRef{llama: llm1, sessionDuration: 1, numParallel: 1}
 	req.useLoadedRunner(r1, finished)
 	require.Equal(t, uint(1), r1.refCount)
 	require.Equal(t, time.Duration(2), r1.sessionDuration)
@@ -435,7 +447,7 @@ func TestUseLoadedRunner(t *testing.T) {
 	case success := <-req.successCh:
 		require.Equal(t, r1, success)
 	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
 	}
 	done()
 	fin := <-finished
@@ -461,8 +473,8 @@ func TestUpdateFreeSpace(t *testing.T) {
 	gpus[1].FreeMemory = 1900
 	llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{"1": 50, "2": 50}}
 	llm2 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{"1": 125, "2": 75}}
-	r1 := &runnerRef{llama: llm1, gpus: gpus}
-	r2 := &runnerRef{llama: llm2, gpus: gpus}
+	r1 := &runnerRef{llama: llm1, gpus: gpus, numParallel: 1}
+	r2 := &runnerRef{llama: llm2, gpus: gpus, numParallel: 1}
 
 	s := InitScheduler(ctx)
 	s.loadedMu.Lock()
@@ -513,8 +525,8 @@ func TestFindRunnerToUnload(t *testing.T) {
 	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
 	defer done()
 
-	r1 := &runnerRef{refCount: 1, sessionDuration: 1}
-	r2 := &runnerRef{sessionDuration: 2}
+	r1 := &runnerRef{refCount: 1, sessionDuration: 1, numParallel: 1}
+	r2 := &runnerRef{sessionDuration: 2, numParallel: 1}
 
 	s := InitScheduler(ctx)
 	s.loadedMu.Lock()
@@ -536,9 +548,13 @@ func TestNeedsReload(t *testing.T) {
 	llm := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
 	do := api.DefaultOptions()
 	runner := &runnerRef{
-		model:   &Model{AdapterPaths: []string{"adapter1"}, ProjectorPaths: []string{"projector1"}},
-		Options: &do,
-		llama:   llm,
+		model: &Model{
+			AdapterPaths:   []string{"adapter1"},
+			ProjectorPaths: []string{"projector1"},
+		},
+		Options:     &do,
+		llama:       llm,
+		numParallel: 1,
 	}
 	req := &LlmRequest{
 		model: &Model{
@@ -581,8 +597,8 @@ func TestUnloadAllRunners(t *testing.T) {
 	s := InitScheduler(ctx)
 	s.unloadAllRunners()
 
-	r1 := &runnerRef{llama: llm1}
-	r2 := &runnerRef{llama: llm2}
+	r1 := &runnerRef{llama: llm1, numParallel: 1}
+	r2 := &runnerRef{llama: llm2, numParallel: 1}
 
 	s.loadedMu.Lock()
 	s.loaded["a"] = r1
@@ -596,14 +612,32 @@ func TestUnloadAllRunners(t *testing.T) {
 
 func TestUnload(t *testing.T) {
 	llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
-	r1 := &runnerRef{llama: llm1}
-	r2 := &runnerRef{model: &Model{AdapterPaths: []string{"A"}}}
+	r1 := &runnerRef{llama: llm1, numParallel: 1}
+	r2 := &runnerRef{model: &Model{AdapterPaths: []string{"A"}}, numParallel: 1}
 	r1.unload()
 	require.True(t, llm1.closeCalled)
 	r2.unload()
 	require.Nil(t, r2.model)
 }
 
+func TestAlreadyCanceled(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
+	defer done()
+	dctx, done2 := context.WithCancel(ctx)
+	done2()
+	scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
+	scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
+	s := InitScheduler(ctx)
+	slog.Info("scenario1a")
+	s.pendingReqCh <- scenario1a.req
+	require.Len(t, s.pendingReqCh, 1)
+	s.Run(ctx)
+	time.Sleep(5 * time.Millisecond)
+	require.Empty(t, s.pendingReqCh)
+	require.Empty(t, scenario1a.req.errCh)
+	require.Empty(t, scenario1a.req.successCh)
+}
+
 type mockLlm struct {
 	pingResp           error
 	waitResp           error

+ 0 - 0
templates/alfred.gotmpl → template/alfred.gotmpl


+ 2 - 1
templates/alpaca.gotmpl → template/alpaca.gotmpl

@@ -4,4 +4,5 @@
 {{ .Prompt }}
 
 {{ end }}### Response:
-{{ .Response }}
+{{ .Response }}
+

+ 1 - 1
templates/chatml.gotmpl → template/chatml.gotmpl

@@ -3,4 +3,4 @@
 {{ end }}{{ if .Prompt }}<|im_start|>user
 {{ .Prompt }}<|im_end|>
 {{ end }}<|im_start|>assistant
-{{ .Response }}<|im_end|>
+{{ .Response }}<|im_end|>

+ 2 - 1
templates/chatqa.gotmpl → template/chatqa.gotmpl

@@ -2,4 +2,5 @@
 
 {{ end }}{{ if .Prompt }}User: {{ .Prompt }}
 
-{{ end }}Assistant: <|begin_of_text|>{{ .Response }}
+{{ end }}Assistant: {{ .Response }}
+

+ 10 - 0
template/codellama-70b-instruct.gotmpl

@@ -0,0 +1,10 @@
+{{ if .System }}Source: system
+
+ {{ .System }} <step> {{ end }}Source: user
+
+ {{ .Prompt }} <step> Source: assistant
+{{- if not .Response }}
+Destination: user
+{{- end }}
+
+ {{ .Response }} <step> 

+ 5 - 0
template/falcon-instruct.gotmpl

@@ -0,0 +1,5 @@
+{{ if .System }}System: {{ .System }}
+{{ end }}{{ if .Prompt }}User:
+{{ .Prompt }}
+{{ end }}Falcon:
+{{ .Response }}

+ 5 - 0
template/gemma-instruct.gotmpl

@@ -0,0 +1,5 @@
+<start_of_turn>user
+{{ if .System }}{{ .System }}
+{{ end }}{{ .Prompt }}<end_of_turn>
+<start_of_turn>model
+{{ .Response }}<end_of_turn>

+ 3 - 3
templates/granite-instruct.gotmpl → template/granite-instruct.gotmpl

@@ -1,9 +1,9 @@
-{{ if .System }}
-System:
+{{ if .System }}System:
 {{ .System }}
 
 {{ end }}{{ if .Prompt }}Question:
 {{ .Prompt }}
 
 {{ end }}Answer:
-{{ .Response }}
+{{ .Response }}
+

+ 0 - 0
templates/index.json → template/index.json


+ 6 - 0
template/llama2-chat.gotmpl

@@ -0,0 +1,6 @@
+[INST] <<SYS>>
+{{- if .System }}
+{{ .System }}
+{{ end }}<</SYS>>
+
+{{ .Prompt }} [/INST] {{ .Response }}</s><s>

+ 0 - 0
templates/llama3-instruct.gotmpl → template/llama3-instruct.gotmpl


+ 2 - 1
templates/magicoder.gotmpl → template/magicoder.gotmpl

@@ -4,4 +4,5 @@
 {{ .Prompt }}
 
 {{ end }}@@ Response
-{{ .Response }}
+{{ .Response }}
+

+ 3 - 0
template/mistral-instruct.gotmpl

@@ -0,0 +1,3 @@
+[INST] {{ if .System }}{{ .System }}
+
+{{ end }}{{ .Prompt }}[/INST] {{ .Response }}</s>

+ 1 - 0
template/openchat.gotmpl

@@ -0,0 +1 @@
+{{ if .System }}GPT4 Correct System: {{ .System }}<|end_of_turn|>{{ end }}GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|>

+ 1 - 1
templates/phi-3.gotmpl → template/phi-3.gotmpl

@@ -3,4 +3,4 @@
 {{ end }}{{ if .Prompt }}<|user|>
 {{ .Prompt }}<|end|>
 {{ end }}<|assistant|>
-{{ .Response }}<|end|>
+{{ .Response }}<|end|>

+ 2 - 1
templates/solar-instruct.gotmpl → template/solar-instruct.gotmpl

@@ -5,4 +5,5 @@
 {{ .Prompt }}
 
 {{ end }}### Assistant:
-{{ .Response }}
+{{ .Response }}</s>
+

+ 0 - 1
templates/starcoder2-instruct.gotmpl → template/starcoder2-instruct.gotmpl

@@ -3,7 +3,6 @@
 {{ end }}{{ if .Prompt }}### Instruction
 {{ .Prompt }}
 
-
 {{ end }}### Response
 {{ .Response }}<|endoftext|>
 

+ 362 - 0
template/template.go

@@ -0,0 +1,362 @@
+package template
+
+import (
+	"bytes"
+	"embed"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"math"
+	"slices"
+	"strings"
+	"sync"
+	"text/template"
+	"text/template/parse"
+
+	"github.com/agnivade/levenshtein"
+	"github.com/ollama/ollama/api"
+	"golang.org/x/exp/maps"
+)
+
+//go:embed index.json
+var indexBytes []byte
+
+//go:embed *.gotmpl
+var templatesFS embed.FS
+
+var templatesOnce = sync.OnceValues(func() ([]*named, error) {
+	var templates []*named
+	if err := json.Unmarshal(indexBytes, &templates); err != nil {
+		return nil, err
+	}
+
+	for _, t := range templates {
+		bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
+		if err != nil {
+			return nil, err
+		}
+
+		// normalize line endings
+		t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
+	}
+
+	return templates, nil
+})
+
+type named struct {
+	Name     string `json:"name"`
+	Template string `json:"template"`
+	Bytes    []byte
+}
+
+func (t named) Reader() io.Reader {
+	return bytes.NewReader(t.Bytes)
+}
+
+func Named(s string) (*named, error) {
+	templates, err := templatesOnce()
+	if err != nil {
+		return nil, err
+	}
+
+	var template *named
+	score := math.MaxInt
+	for _, t := range templates {
+		if s := levenshtein.ComputeDistance(s, t.Template); s < score {
+			score = s
+			template = t
+		}
+	}
+
+	if score < 100 {
+		return template, nil
+	}
+
+	return nil, errors.New("no matching template found")
+}
+
+var DefaultTemplate, _ = Parse("{{ .Prompt }}")
+
+type Template struct {
+	*template.Template
+	raw string
+}
+
+// response is a template node that can be added to templates that don't already have one
+var response = parse.ActionNode{
+	NodeType: parse.NodeAction,
+	Pipe: &parse.PipeNode{
+		NodeType: parse.NodePipe,
+		Cmds: []*parse.CommandNode{
+			{
+				NodeType: parse.NodeCommand,
+				Args: []parse.Node{
+					&parse.FieldNode{
+						NodeType: parse.NodeField,
+						Ident:    []string{"Response"},
+					},
+				},
+			},
+		},
+	},
+}
+
+func Parse(s string) (*Template, error) {
+	tmpl := template.New("").Option("missingkey=zero")
+
+	tmpl, err := tmpl.Parse(s)
+	if err != nil {
+		return nil, err
+	}
+
+	t := Template{Template: tmpl, raw: s}
+	if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
+		// touch up the template and append {{ .Response }}
+		tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
+	}
+
+	return &t, nil
+}
+
+func (t *Template) String() string {
+	return t.raw
+}
+
+func (t *Template) Vars() []string {
+	var vars []string
+	for _, tt := range t.Templates() {
+		for _, n := range tt.Root.Nodes {
+			vars = append(vars, parseNode(n)...)
+		}
+	}
+
+	set := make(map[string]struct{})
+	for _, n := range vars {
+		set[strings.ToLower(n)] = struct{}{}
+	}
+
+	vars = maps.Keys(set)
+	slices.Sort(vars)
+	return vars
+}
+
+type Values struct {
+	Messages []api.Message
+
+	// forceLegacy is a flag used to test compatibility with legacy templates
+	forceLegacy bool
+}
+
+func (t *Template) Execute(w io.Writer, v Values) error {
+	system, collated := collate(v.Messages)
+	if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
+		return t.Template.Execute(w, map[string]any{
+			"System":   system,
+			"Messages": collated,
+		})
+	}
+
+	var b bytes.Buffer
+	var prompt, response string
+	for i, m := range collated {
+		switch m.Role {
+		case "system":
+			system = m.Content
+		case "user":
+			prompt = m.Content
+		case "assistant":
+			response = m.Content
+		}
+
+		if i != len(collated)-1 && prompt != "" && response != "" {
+			if err := t.Template.Execute(&b, map[string]any{
+				"System":   system,
+				"Prompt":   prompt,
+				"Response": response,
+			}); err != nil {
+				return err
+			}
+
+			system = ""
+			prompt = ""
+			response = ""
+		}
+	}
+
+	var cut bool
+	nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
+		switch t := n.(type) {
+		case *parse.ActionNode:
+		case *parse.FieldNode:
+			if slices.Contains(t.Ident, "Response") {
+				cut = true
+			}
+		}
+
+		return cut
+	})
+
+	tree := parse.Tree{Root: nodes.(*parse.ListNode)}
+	if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
+		"System": "",
+		"Prompt": prompt,
+	}); err != nil {
+		return err
+	}
+
+	_, err := io.Copy(w, &b)
+	return err
+}
+
+// collate messages based on role. consecutive messages of the same role are merged
+// into a single message. collate also collects and returns all system messages.
+// collate mutates message content adding image tags ([img-%d]) as needed
+func collate(msgs []api.Message) (string, []*api.Message) {
+	var n int
+
+	var system []string
+	var collated []*api.Message
+	for i := range msgs {
+		msg := msgs[i]
+		for range msg.Images {
+			imageTag := fmt.Sprintf("[img-%d]", n)
+			if !strings.Contains(msg.Content, "[img]") {
+				msg.Content = strings.TrimSpace("[img] " + msg.Content)
+			}
+
+			msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
+			n++
+		}
+
+		if msg.Role == "system" {
+			system = append(system, msg.Content)
+		}
+
+		if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
+			collated[len(collated)-1].Content += "\n\n" + msg.Content
+		} else {
+			collated = append(collated, &msg)
+		}
+	}
+
+	return strings.Join(system, "\n\n"), collated
+}
+
+func parseNode(n parse.Node) []string {
+	switch n := n.(type) {
+	case *parse.ActionNode:
+		return parseNode(n.Pipe)
+	case *parse.IfNode:
+		names := parseNode(n.Pipe)
+		names = append(names, parseNode(n.List)...)
+		if n.ElseList != nil {
+			names = append(names, parseNode(n.ElseList)...)
+		}
+		return names
+	case *parse.RangeNode:
+		names := parseNode(n.Pipe)
+		names = append(names, parseNode(n.List)...)
+		if n.ElseList != nil {
+			names = append(names, parseNode(n.ElseList)...)
+		}
+		return names
+	case *parse.WithNode:
+		names := parseNode(n.Pipe)
+		names = append(names, parseNode(n.List)...)
+		if n.ElseList != nil {
+			names = append(names, parseNode(n.ElseList)...)
+		}
+		return names
+	case *parse.PipeNode:
+		var names []string
+		for _, c := range n.Cmds {
+			for _, a := range c.Args {
+				names = append(names, parseNode(a)...)
+			}
+		}
+		return names
+	case *parse.ListNode:
+		var names []string
+		for _, n := range n.Nodes {
+			names = append(names, parseNode(n)...)
+		}
+
+		return names
+	case *parse.FieldNode:
+		return n.Ident
+	case *parse.TemplateNode:
+		return parseNode(n.Pipe)
+	}
+
+	return nil
+}
+
+// deleteNode walks the node list and deletes nodes that match the predicate
+// this is currently to remove the {{ .Response }} node from templates
+func deleteNode(n parse.Node, fn func(parse.Node) bool) parse.Node {
+	var walk func(n parse.Node) parse.Node
+	walk = func(n parse.Node) parse.Node {
+		if fn(n) {
+			return nil
+		}
+
+		switch t := n.(type) {
+		case *parse.ListNode:
+			var nodes []parse.Node
+			for _, c := range t.Nodes {
+				if n := walk(c); n != nil {
+					nodes = append(nodes, n)
+				}
+			}
+
+			t.Nodes = nodes
+			return t
+		case *parse.IfNode:
+			t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
+		case *parse.WithNode:
+			t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
+		case *parse.RangeNode:
+			t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
+		case *parse.BranchNode:
+			t.List = walk(t.List).(*parse.ListNode)
+			if t.ElseList != nil {
+				t.ElseList = walk(t.ElseList).(*parse.ListNode)
+			}
+		case *parse.ActionNode:
+			n := walk(t.Pipe)
+			if n == nil {
+				return nil
+			}
+
+			t.Pipe = n.(*parse.PipeNode)
+		case *parse.PipeNode:
+			var commands []*parse.CommandNode
+			for _, c := range t.Cmds {
+				var args []parse.Node
+				for _, a := range c.Args {
+					if n := walk(a); n != nil {
+						args = append(args, n)
+					}
+				}
+
+				if len(args) == 0 {
+					return nil
+				}
+
+				c.Args = args
+				commands = append(commands, c)
+			}
+
+			if len(commands) == 0 {
+				return nil
+			}
+
+			t.Cmds = commands
+		}
+
+		return n
+	}
+
+	return walk(n)
+}
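
Taken together, the new template package parses a prompt template, appends a {{ .Response }} node when neither .Response nor .Messages is referenced, and renders either the legacy System/Prompt/Response values or a collated message history. Below is a minimal usage sketch, not part of the diff, using only the Parse, Vars, Execute, and Values identifiers defined above and a chatml-style template taken from this change.

package main

import (
	"fmt"
	"os"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/template"
)

func main() {
	// Parse appends {{ .Response }} because this template references neither
	// .Response nor .Messages.
	tmpl, err := template.Parse(`{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
`)
	if err != nil {
		panic(err)
	}

	fmt.Println(tmpl.Vars()) // [prompt response system]

	// Execute collates consecutive same-role messages, then uses the legacy
	// rendering path because the template has no .Messages range.
	err = tmpl.Execute(os.Stdout, template.Values{Messages: []api.Message{
		{Role: "system", Content: "You are a helpful assistant."},
		{Role: "user", Content: "Hello, how are you?"},
		{Role: "assistant", Content: "I'm doing great. How can I help you today?"},
		{Role: "user", Content: "I'd like to show off how chat templating works!"},
	}})
	if err != nil {
		panic(err)
	}
}

Rendering from the full message history rather than a single prompt/response pair appears to be what allows server/prompt.go and server/routes.go to shrink elsewhere in this diff.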

+ 361 - 0
template/template_test.go

@@ -0,0 +1,361 @@
+package template
+
+import (
+	"bufio"
+	"bytes"
+	"encoding/json"
+	"io"
+	"os"
+	"path/filepath"
+	"slices"
+	"strings"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/llm"
+)
+
+func TestNamed(t *testing.T) {
+	f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer f.Close()
+
+	scanner := bufio.NewScanner(f)
+	for scanner.Scan() {
+		var ss map[string]string
+		if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
+			t.Fatal(err)
+		}
+
+		for k, v := range ss {
+			t.Run(k, func(t *testing.T) {
+				kv := llm.KV{"tokenizer.chat_template": v}
+				s := kv.ChatTemplate()
+				r, err := Named(s)
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				if r.Name != k {
+					t.Errorf("expected %q, got %q", k, r.Name)
+				}
+
+				var b bytes.Buffer
+				if _, err := io.Copy(&b, r.Reader()); err != nil {
+					t.Fatal(err)
+				}
+
+				tmpl, err := Parse(b.String())
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				if tmpl.Tree.Root.String() == "" {
+					t.Errorf("empty %s template", k)
+				}
+			})
+		}
+	}
+}
+
+func TestTemplate(t *testing.T) {
+	cases := make(map[string][]api.Message)
+	for _, mm := range [][]api.Message{
+		{
+			{Role: "user", Content: "Hello, how are you?"},
+		},
+		{
+			{Role: "user", Content: "Hello, how are you?"},
+			{Role: "assistant", Content: "I'm doing great. How can I help you today?"},
+			{Role: "user", Content: "I'd like to show off how chat templating works!"},
+		},
+		{
+			{Role: "system", Content: "You are a helpful assistant."},
+			{Role: "user", Content: "Hello, how are you?"},
+			{Role: "assistant", Content: "I'm doing great. How can I help you today?"},
+			{Role: "user", Content: "I'd like to show off how chat templating works!"},
+		},
+	} {
+		var roles []string
+		for _, m := range mm {
+			roles = append(roles, m.Role)
+		}
+
+		cases[strings.Join(roles, "-")] = mm
+	}
+
+	matches, err := filepath.Glob("*.gotmpl")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	for _, match := range matches {
+		t.Run(match, func(t *testing.T) {
+			bts, err := os.ReadFile(match)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			tmpl, err := Parse(string(bts))
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			for n, tt := range cases {
+				var actual bytes.Buffer
+				t.Run(n, func(t *testing.T) {
+					if err := tmpl.Execute(&actual, Values{Messages: tt}); err != nil {
+						t.Fatal(err)
+					}
+
+					expect, err := os.ReadFile(filepath.Join("testdata", match, n))
+					if err != nil {
+						t.Fatal(err)
+					}
+
+					bts := actual.Bytes()
+
+					if slices.Contains([]string{"chatqa.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && bts[len(bts)-1] == ' ' {
+						t.Log("removing trailing space from output")
+						bts = bts[:len(bts)-1]
+					}
+
+					if diff := cmp.Diff(bts, expect); diff != "" {
+						t.Errorf("mismatch (-got +want):\n%s", diff)
+					}
+				})
+
+				t.Run("legacy", func(t *testing.T) {
+					t.Skip("legacy outputs are currently default outputs")
+					var legacy bytes.Buffer
+					if err := tmpl.Execute(&legacy, Values{Messages: tt, forceLegacy: true}); err != nil {
+						t.Fatal(err)
+					}
+
+					legacyBytes := legacy.Bytes()
+					if slices.Contains([]string{"chatqa.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && legacyBytes[len(legacyBytes)-1] == ' ' {
+						t.Log("removing trailing space from legacy output")
+						legacyBytes = legacyBytes[:len(legacyBytes)-1]
+					} else if slices.Contains([]string{"codellama-70b-instruct.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl"}, match) {
+						t.Skip("legacy outputs cannot be compared to messages outputs")
+					}
+
+					if diff := cmp.Diff(legacyBytes, actual.Bytes()); diff != "" {
+						t.Errorf("mismatch (-got +want):\n%s", diff)
+					}
+				})
+			}
+		})
+	}
+}
+
+func TestParse(t *testing.T) {
+	cases := []struct {
+		template string
+		vars     []string
+	}{
+		{"{{ .Prompt }}", []string{"prompt", "response"}},
+		{"{{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system"}},
+		{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
+		{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
+		{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
+		{`{{- range .Messages }}
+{{- if eq .Role "system" }}SYSTEM:
+{{- else if eq .Role "user" }}USER:
+{{- else if eq .Role "assistant" }}ASSISTANT:
+{{- end }} {{ .Content }}
+{{- end }}`, []string{"content", "messages", "role"}},
+		{`{{- if .Messages }}
+{{- range .Messages }}<|im_start|>{{ .Role }}
+{{ .Content }}<|im_end|>
+{{ end }}<|im_start|>assistant
+{{ else -}}
+{{ if .System }}<|im_start|>system
+{{ .System }}<|im_end|>
+{{ end }}{{ if .Prompt }}<|im_start|>user
+{{ .Prompt }}<|im_end|>
+{{ end }}<|im_start|>assistant
+{{ .Response }}<|im_end|>
+{{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}},
+	}
+
+	for _, tt := range cases {
+		t.Run("", func(t *testing.T) {
+			tmpl, err := Parse(tt.template)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			if diff := cmp.Diff(tmpl.Vars(), tt.vars); diff != "" {
+				t.Errorf("mismatch (-got +want):\n%s", diff)
+			}
+		})
+	}
+}
+
+func TestExecuteWithMessages(t *testing.T) {
+	type template struct {
+		name     string
+		template string
+	}
+	cases := []struct {
+		name      string
+		templates []template
+		values    Values
+		expected  string
+	}{
+		{
+			"mistral",
+			[]template{
+				{"no response", `[INST] {{ if .System }}{{ .System }}
+
+{{ end }}{{ .Prompt }}[/INST] `},
+				{"response", `[INST] {{ if .System }}{{ .System }}
+
+{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
+				{"messages", `[INST] {{ if .System }}{{ .System }}
+
+{{ end }}
+{{- range .Messages }}
+{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
+{{- end }}`},
+			},
+			Values{
+				Messages: []api.Message{
+					{Role: "user", Content: "Hello friend!"},
+					{Role: "assistant", Content: "Hello human!"},
+					{Role: "user", Content: "What is your name?"},
+				},
+			},
+			`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
+		},
+		{
+			"mistral system",
+			[]template{
+				{"no response", `[INST] {{ if .System }}{{ .System }}
+
+{{ end }}{{ .Prompt }}[/INST] `},
+				{"response", `[INST] {{ if .System }}{{ .System }}
+
+{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
+				{"messages", `[INST] {{ if .System }}{{ .System }}
+
+{{ end }}
+{{- range .Messages }}
+{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
+{{- end }}`},
+			},
+			Values{
+				Messages: []api.Message{
+					{Role: "system", Content: "You are a helpful assistant!"},
+					{Role: "user", Content: "Hello friend!"},
+					{Role: "assistant", Content: "Hello human!"},
+					{Role: "user", Content: "What is your name?"},
+				},
+			},
+			`[INST] You are a helpful assistant!
+
+Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
+		},
+		{
+			"chatml",
+			[]template{
+				// this does not have a "no response" test because it's impossible to render the same output
+				{"response", `{{ if .System }}<|im_start|>system
+{{ .System }}<|im_end|>
+{{ end }}{{ if .Prompt }}<|im_start|>user
+{{ .Prompt }}<|im_end|>
+{{ end }}<|im_start|>assistant
+{{ .Response }}<|im_end|>
+`},
+				{"messages", `
+{{- range $index, $_ := .Messages }}<|im_start|>{{ .Role }}
+{{ .Content }}<|im_end|>
+{{ end }}<|im_start|>assistant
+`},
+			},
+			Values{
+				Messages: []api.Message{
+					{Role: "system", Content: "You are a helpful assistant!"},
+					{Role: "user", Content: "Hello friend!"},
+					{Role: "assistant", Content: "Hello human!"},
+					{Role: "user", Content: "What is your name?"},
+				},
+			},
+			`<|im_start|>system
+You are a helpful assistant!<|im_end|>
+<|im_start|>user
+Hello friend!<|im_end|>
+<|im_start|>assistant
+Hello human!<|im_end|>
+<|im_start|>user
+What is your name?<|im_end|>
+<|im_start|>assistant
+`,
+		},
+		{
+			"moondream",
+			[]template{
+				// this does not have a "no response" test because it's impossible to render the same output
+				{"response", `{{ if .Prompt }}Question: {{ .Prompt }}
+
+{{ end }}Answer: {{ .Response }}
+
+`},
+				{"messages", `
+{{- range .Messages }}
+{{- if eq .Role "user" }}Question: {{ .Content }}
+
+{{ else if eq .Role "assistant" }}Answer: {{ .Content }}
+
+{{ end }}
+{{- end }}Answer: `},
+			},
+			Values{
+				Messages: []api.Message{
+					{Role: "user", Content: "What's in this image?", Images: []api.ImageData{[]byte("")}},
+					{Role: "assistant", Content: "It's a hot dog."},
+					{Role: "user", Content: "What's in _this_ image?"},
+					{Role: "user", Images: []api.ImageData{[]byte("")}},
+					{Role: "user", Content: "Is it a hot dog?"},
+				},
+			},
+			`Question: [img-0] What's in this image?
+
+Answer: It's a hot dog.
+
+Question: What's in _this_ image?
+
+[img-1]
+
+Is it a hot dog?
+
+Answer: `,
+		},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			for _, ttt := range tt.templates {
+				t.Run(ttt.name, func(t *testing.T) {
+					tmpl, err := Parse(ttt.template)
+					if err != nil {
+						t.Fatal(err)
+					}
+
+					var b bytes.Buffer
+					if err := tmpl.Execute(&b, tt.values); err != nil {
+						t.Fatal(err)
+					}
+
+					if diff := cmp.Diff(b.String(), tt.expected); diff != "" {
+						t.Errorf("mismatch (-got +want):\n%s", diff)
+					}
+				})
+			}
+		})
+	}
+}

+ 1 - 0
template/testdata/alfred.gotmpl/system-user-assistant-user

@@ -0,0 +1 @@
+<start_system>You are a helpful assistant.<end_message><start_user>Hello, how are you?<end_message><start_assistant>I'm doing great. How can I help you today?<end_message><start_user>I'd like to show off how chat templating works!<end_message><start_assistant>

+ 1 - 0
template/testdata/alfred.gotmpl/user

@@ -0,0 +1 @@
+<start_user>Hello, how are you?<end_message><start_assistant>

+ 1 - 0
template/testdata/alfred.gotmpl/user-assistant-user

@@ -0,0 +1 @@
+<start_user>Hello, how are you?<end_message><start_assistant>I'm doing great. How can I help you today?<end_message><start_user>I'd like to show off how chat templating works!<end_message><start_assistant>

+ 12 - 0
template/testdata/alpaca.gotmpl/system-user-assistant-user

@@ -0,0 +1,12 @@
+You are a helpful assistant.
+
+### Instruction:
+Hello, how are you?
+
+### Response:
+I'm doing great. How can I help you today?
+
+### Instruction:
+I'd like to show off how chat templating works!
+
+### Response:

+ 4 - 0
template/testdata/alpaca.gotmpl/user

@@ -0,0 +1,4 @@
+### Instruction:
+Hello, how are you?
+
+### Response:

Some files were not shown because too many files changed in this diff