
Merge branch 'main' into royh-batchembed

royjhan, 10 months ago
commit a5f23d766e
84 changed files with 2731 additions and 682 deletions
  1. Dockerfile (+2 -2)
  2. README.md (+9 -2)
  3. api/types.go (+31 -8)
  4. api/types_test.go (+63 -0)
  5. app/lifecycle/logging.go (+32 -0)
  6. app/lifecycle/logging_test.go (+44 -0)
  7. app/lifecycle/paths.go (+6 -5)
  8. app/lifecycle/server.go (+1 -1)
  9. app/ollama.iss (+6 -1)
  10. cmd/cmd.go (+166 -52)
  11. cmd/interactive.go (+14 -47)
  12. docs/api.md (+33 -6)
  13. docs/faq.md (+16 -0)
  14. docs/gpu.md (+1 -1)
  15. docs/openai.md (+1 -1)
  16. docs/troubleshooting.md (+1 -1)
  17. docs/windows.md (+2 -2)
  18. envconfig/config.go (+8 -8)
  19. gpu/amd_windows.go (+3 -2)
  20. gpu/assets.go (+19 -12)
  21. gpu/gpu.go (+8 -2)
  22. gpu/gpu_info_cudart.c (+1 -1)
  23. gpu/gpu_info_nvcuda.c (+1 -1)
  24. gpu/gpu_info_nvml.c (+1 -1)
  25. gpu/gpu_info_oneapi.c (+1 -1)
  26. gpu/types.go (+5 -0)
  27. llm/ext_server/server.cpp (+37 -26)
  28. llm/generate/gen_windows.ps1 (+18 -14)
  29. llm/ggla.go (+11 -2)
  30. llm/ggml.go (+71 -14)
  31. llm/ggml_test.go (+1 -0)
  32. llm/gguf.go (+92 -38)
  33. llm/memory.go (+2 -2)
  34. llm/memory_test.go (+11 -8)
  35. llm/patches/07-gemma.diff (+305 -0)
  36. llm/payload.go (+2 -2)
  37. llm/server.go (+32 -27)
  38. llm/status.go (+1 -0)
  39. openai/openai.go (+371 -13)
  40. openai/openai_test.go (+298 -0)
  41. parser/parser.go (+2 -2)
  42. parser/parser_test.go (+67 -3)
  43. scripts/build_windows.ps1 (+5 -5)
  44. scripts/install.sh (+1 -1)
  45. scripts/rh_linux_deps.sh (+11 -0)
  46. server/images.go (+58 -32)
  47. server/manifest.go (+11 -9)
  48. server/manifest_test.go (+1 -1)
  49. server/model.go (+41 -24)
  50. server/model_test.go (+112 -0)
  51. server/prompt.go (+7 -11)
  52. server/prompt_test.go (+13 -2)
  53. server/routes.go (+79 -13)
  54. server/routes_test.go (+96 -1)
  55. server/sched.go (+101 -25)
  56. server/sched_test.go (+54 -32)
  57. template/alfred.gotmpl (+0 -0)
  58. template/alpaca.gotmpl (+0 -0)
  59. template/chatml.gotmpl (+0 -0)
  60. template/chatqa.gotmpl (+0 -0)
  61. template/codellama-70b-instruct.gotmpl (+0 -0)
  62. template/falcon-instruct.gotmpl (+0 -0)
  63. template/gemma-instruct.gotmpl (+0 -0)
  64. template/granite-instruct.gotmpl (+0 -0)
  65. template/index.json (+0 -0)
  66. template/llama2-chat.gotmpl (+0 -0)
  67. template/llama3-instruct.gotmpl (+0 -0)
  68. template/magicoder.gotmpl (+0 -0)
  69. template/mistral-instruct.gotmpl (+0 -0)
  70. template/openchat.gotmpl (+0 -0)
  71. template/phi-3.gotmpl (+0 -0)
  72. template/solar-instruct.gotmpl (+0 -0)
  73. template/starcoder2-instruct.gotmpl (+0 -0)
  74. template/template.go (+158 -0)
  75. template/template_test.go (+89 -0)
  76. template/testdata/templates.jsonl (+0 -0)
  77. template/vicuna.gotmpl (+0 -0)
  78. template/zephyr.gotmpl (+0 -0)
  79. templates/template.go (+0 -70)
  80. templates/template_test.go (+0 -59)
  81. types/model/name.go (+0 -55)
  82. types/model/name_test.go (+0 -34)
  83. util/bufioutil/buffer_seeker.go (+34 -0)
  84. util/bufioutil/buffer_seeker_test.go (+64 -0)

+ 2 - 2
Dockerfile

@@ -70,12 +70,12 @@ RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
 FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
 RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
 
-FROM --platform=linux/arm64 centos:7 AS cpu-builder-arm64
+FROM --platform=linux/arm64 rockylinux:8 AS cpu-builder-arm64
 ARG CMAKE_VERSION
 ARG GOLANG_VERSION
 COPY ./scripts/rh_linux_deps.sh /
 RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
-ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
+ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH
 COPY --from=llm-code / /go/src/github.com/ollama/ollama/
 ARG OLLAMA_CUSTOM_CPU_DEFS
 ARG CGO_CFLAGS

+ 9 - 2
README.md

@@ -53,8 +53,8 @@ Here are some example models that can be downloaded:
 | Llama 3            | 70B        | 40GB  | `ollama run llama3:70b`        |
 | Phi 3 Mini         | 3.8B       | 2.3GB | `ollama run phi3`              |
 | Phi 3 Medium       | 14B        | 7.9GB | `ollama run phi3:medium`       |
-| Gemma              | 2B         | 1.4GB | `ollama run gemma:2b`          |
-| Gemma              | 7B         | 4.8GB | `ollama run gemma:7b`          |
+| Gemma 2            | 9B         | 5.5GB | `ollama run gemma2`            |
+| Gemma 2            | 27B        | 16GB  | `ollama run gemma2:27b`        |
 | Mistral            | 7B         | 4.1GB | `ollama run mistral`           |
 | Moondream 2        | 1.4B       | 829MB | `ollama run moondream`         |
 | Neural Chat        | 7B         | 4.1GB | `ollama run neural-chat`       |
@@ -182,6 +182,12 @@ $ ollama run llama3 "Summarize this file: $(cat README.md)"
  Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
 ```
 
+### Show model information
+
+```
+ollama show llama3
+```
+
 ### List models on your computer
 
 ```
@@ -286,6 +292,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
 - [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
 - [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
+- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
 
 ### Terminal
 

+ 31 - 8
api/types.go

@@ -277,6 +277,7 @@ type ShowRequest struct {
 	Model    string `json:"model"`
 	System   string `json:"system"`
 	Template string `json:"template"`
+	Verbose  bool   `json:"verbose"`
 
 	Options map[string]interface{} `json:"options"`
 
@@ -286,14 +287,16 @@ type ShowRequest struct {
 
 // ShowResponse is the response returned from [Client.Show].
 type ShowResponse struct {
-	License    string       `json:"license,omitempty"`
-	Modelfile  string       `json:"modelfile,omitempty"`
-	Parameters string       `json:"parameters,omitempty"`
-	Template   string       `json:"template,omitempty"`
-	System     string       `json:"system,omitempty"`
-	Details    ModelDetails `json:"details,omitempty"`
-	Messages   []Message    `json:"messages,omitempty"`
-	ModifiedAt time.Time    `json:"modified_at,omitempty"`
+	License       string         `json:"license,omitempty"`
+	Modelfile     string         `json:"modelfile,omitempty"`
+	Parameters    string         `json:"parameters,omitempty"`
+	Template      string         `json:"template,omitempty"`
+	System        string         `json:"system,omitempty"`
+	Details       ModelDetails   `json:"details,omitempty"`
+	Messages      []Message      `json:"messages,omitempty"`
+	ModelInfo     map[string]any `json:"model_info,omitempty"`
+	ProjectorInfo map[string]any `json:"projector_info,omitempty"`
+	ModifiedAt    time.Time      `json:"modified_at,omitempty"`
 }
 
 // CopyRequest is the request passed to [Client.Copy].
@@ -366,6 +369,13 @@ type ProcessModelResponse struct {
 	SizeVRAM  int64        `json:"size_vram"`
 }
 
+type RetrieveModelResponse struct {
+	Id      string `json:"id"`
+	Object  string `json:"object"`
+	Created int64  `json:"created"`
+	OwnedBy string `json:"owned_by"`
+}
+
 type TokenResponse struct {
 	Token string `json:"token"`
 }
@@ -629,6 +639,19 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) {
 		} else {
 			field := valueOpts.FieldByName(opt.Name)
 			if field.IsValid() && field.CanSet() {
+				if reflect.PointerTo(field.Type()) == reflect.TypeOf((*TriState)(nil)) {
+					boolVal, err := strconv.ParseBool(vals[0])
+					if err != nil {
+						return nil, fmt.Errorf("invalid bool value %s", vals)
+					}
+					if boolVal {
+						out[key] = TriStateTrue
+					} else {
+						out[key] = TriStateFalse
+					}
+					continue
+				}
+
 				switch field.Kind() {
 				case reflect.Float32:
 					floatVal, err := strconv.ParseFloat(vals[0], 32)

+ 63 - 0
api/types_test.go

@@ -2,6 +2,7 @@ package api
 
 import (
 	"encoding/json"
+	"fmt"
 	"math"
 	"testing"
 	"time"
@@ -141,3 +142,65 @@ func TestUseMmapParsingFromJSON(t *testing.T) {
 		})
 	}
 }
+
+func TestUseMmapFormatParams(t *testing.T) {
+	tests := []struct {
+		name string
+		req  map[string][]string
+		exp  TriState
+		err  error
+	}{
+		{
+			name: "True",
+			req: map[string][]string{
+				"use_mmap": []string{"true"},
+			},
+			exp: TriStateTrue,
+			err: nil,
+		},
+		{
+			name: "False",
+			req: map[string][]string{
+				"use_mmap": []string{"false"},
+			},
+			exp: TriStateFalse,
+			err: nil,
+		},
+		{
+			name: "Numeric True",
+			req: map[string][]string{
+				"use_mmap": []string{"1"},
+			},
+			exp: TriStateTrue,
+			err: nil,
+		},
+		{
+			name: "Numeric False",
+			req: map[string][]string{
+				"use_mmap": []string{"0"},
+			},
+			exp: TriStateFalse,
+			err: nil,
+		},
+		{
+			name: "invalid string",
+			req: map[string][]string{
+				"use_mmap": []string{"foo"},
+			},
+			exp: TriStateUndefined,
+			err: fmt.Errorf("invalid bool value [foo]"),
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			resp, err := FormatParams(test.req)
+			require.Equal(t, err, test.err)
+			respVal, ok := resp["use_mmap"]
+			if test.exp != TriStateUndefined {
+				assert.True(t, ok, "resp: %v", resp)
+				assert.Equal(t, test.exp, respVal)
+			}
+		})
+	}
+}

+ 32 - 0
app/lifecycle/logging.go

@@ -5,6 +5,8 @@ import (
 	"log/slog"
 	"os"
 	"path/filepath"
+	"strconv"
+	"strings"
 
 	"github.com/ollama/ollama/envconfig"
 )
@@ -24,6 +26,7 @@ func InitLogging() {
 		logFile = os.Stderr
 		// TODO - write one-line to the app.log file saying we're running in console mode to help avoid confusion
 	} else {
+		rotateLogs(AppLogFile)
 		logFile, err = os.OpenFile(AppLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
 		if err != nil {
 			slog.Error(fmt.Sprintf("failed to create server log %v", err))
@@ -46,3 +49,32 @@ func InitLogging() {
 
 	slog.Info("ollama app started")
 }
+
+func rotateLogs(logFile string) {
+	if _, err := os.Stat(logFile); os.IsNotExist(err) {
+		return
+	}
+	index := strings.LastIndex(logFile, ".")
+	pre := logFile[:index]
+	post := "." + logFile[index+1:]
+	for i := LogRotationCount; i > 0; i-- {
+		older := pre + "-" + strconv.Itoa(i) + post
+		newer := pre + "-" + strconv.Itoa(i-1) + post
+		if i == 1 {
+			newer = pre + post
+		}
+		if _, err := os.Stat(newer); err == nil {
+			if _, err := os.Stat(older); err == nil {
+				err := os.Remove(older)
+				if err != nil {
+					slog.Warn("Failed to remove older log", "older", older, "error", err)
+					continue
+				}
+			}
+			err := os.Rename(newer, older)
+			if err != nil {
+				slog.Warn("Failed to rotate log", "older", older, "newer", newer, "error", err)
+			}
+		}
+	}
+}

+ 44 - 0
app/lifecycle/logging_test.go

@@ -0,0 +1,44 @@
+package lifecycle
+
+import (
+	"os"
+	"path/filepath"
+	"strconv"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestRotateLogs(t *testing.T) {
+	logDir := t.TempDir()
+	logFile := filepath.Join(logDir, "testlog.log")
+
+	// No log exists
+	rotateLogs(logFile)
+
+	require.NoError(t, os.WriteFile(logFile, []byte("1"), 0644))
+	assert.FileExists(t, logFile)
+	// First rotation
+	rotateLogs(logFile)
+	assert.FileExists(t, filepath.Join(logDir, "testlog-1.log"))
+	assert.NoFileExists(t, filepath.Join(logDir, "testlog-2.log"))
+	assert.NoFileExists(t, logFile)
+
+	// Should be a no-op without a new log
+	rotateLogs(logFile)
+	assert.FileExists(t, filepath.Join(logDir, "testlog-1.log"))
+	assert.NoFileExists(t, filepath.Join(logDir, "testlog-2.log"))
+	assert.NoFileExists(t, logFile)
+
+	for i := 2; i <= LogRotationCount+1; i++ {
+		require.NoError(t, os.WriteFile(logFile, []byte(strconv.Itoa(i)), 0644))
+		assert.FileExists(t, logFile)
+		rotateLogs(logFile)
+		assert.NoFileExists(t, logFile)
+		for j := 1; j < i; j++ {
+			assert.FileExists(t, filepath.Join(logDir, "testlog-"+strconv.Itoa(j)+".log"))
+		}
+		assert.NoFileExists(t, filepath.Join(logDir, "testlog-"+strconv.Itoa(i+1)+".log"))
+	}
+}

+ 6 - 5
app/lifecycle/paths.go

@@ -16,11 +16,12 @@ var (
 	AppDir     = "/opt/Ollama"
 	AppDataDir = "/opt/Ollama"
 	// TODO - should there be a distinct log dir?
-	UpdateStageDir = "/tmp"
-	AppLogFile     = "/tmp/ollama_app.log"
-	ServerLogFile  = "/tmp/ollama.log"
-	UpgradeLogFile = "/tmp/ollama_update.log"
-	Installer      = "OllamaSetup.exe"
+	UpdateStageDir   = "/tmp"
+	AppLogFile       = "/tmp/ollama_app.log"
+	ServerLogFile    = "/tmp/ollama.log"
+	UpgradeLogFile   = "/tmp/ollama_update.log"
+	Installer        = "OllamaSetup.exe"
+	LogRotationCount = 5
 )
 
 func init() {

+ 1 - 1
app/lifecycle/server.go

@@ -54,7 +54,7 @@ func start(ctx context.Context, command string) (*exec.Cmd, error) {
 		return nil, fmt.Errorf("failed to spawn server stderr pipe: %w", err)
 		return nil, fmt.Errorf("failed to spawn server stderr pipe: %w", err)
 	}
 	}
 
 
-	// TODO - rotation
+	rotateLogs(ServerLogFile)
 	logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
 	logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("failed to create server log: %w", err)
 		return nil, fmt.Errorf("failed to create server log: %w", err)

+ 6 - 1
app/ollama.iss

@@ -88,10 +88,15 @@ DialogFontSize=12
 [Files]
 Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit
 Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit
-Source: "..\dist\windows-{#ARCH}\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
 Source: "..\dist\windows-{#ARCH}\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs
 Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion
 Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion
+#if DirExists("..\dist\windows-amd64\cuda")
+  Source: "..\dist\windows-amd64\cuda\*"; DestDir: "{app}\cuda\"; Flags: ignoreversion recursesubdirs
+#endif
+#if DirExists("..\dist\windows-amd64\oneapi")
+  Source: "..\dist\windows-amd64\oneapi\*"; DestDir: "{app}\oneapi\"; Flags: ignoreversion recursesubdirs
+#endif
 #if DirExists("..\dist\windows-amd64\rocm")
   Source: "..\dist\windows-amd64\rocm\*"; DestDir: "{app}\rocm\"; Flags: ignoreversion recursesubdirs
 #endif

+ 166 - 52
cmd/cmd.go

@@ -162,9 +162,6 @@ func tempZipFiles(path string) (string, error) {
 	}
 	defer tempfile.Close()
 
-	zipfile := zip.NewWriter(tempfile)
-	defer zipfile.Close()
-
 	detectContentType := func(path string) (string, error) {
 		f, err := os.Open(path)
 		if err != nil {
@@ -233,6 +230,9 @@ func tempZipFiles(path string) (string, error) {
 		files = append(files, tks...)
 	}
 
+	zipfile := zip.NewWriter(tempfile)
+	defer zipfile.Close()
+
 	for _, file := range files {
 		f, err := os.Open(file)
 		if err != nil {
@@ -287,38 +287,12 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
 }
 
 func RunHandler(cmd *cobra.Command, args []string) error {
-	client, err := api.ClientFromEnvironment()
-	if err != nil {
-		return err
-	}
-
-	name := args[0]
-
-	// check if the model exists on the server
-	show, err := client.Show(cmd.Context(), &api.ShowRequest{Name: name})
-	var statusError api.StatusError
-	switch {
-	case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
-		if err := PullHandler(cmd, []string{name}); err != nil {
-			return err
-		}
-
-		show, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name})
-		if err != nil {
-			return err
-		}
-	case err != nil:
-		return err
-	}
-
 	interactive := true
 
 	opts := runOptions{
-		Model:       args[0],
-		WordWrap:    os.Getenv("TERM") == "xterm-256color",
-		Options:     map[string]interface{}{},
-		MultiModal:  slices.Contains(show.Details.Families, "clip"),
-		ParentModel: show.Details.ParentModel,
+		Model:    args[0],
+		WordWrap: os.Getenv("TERM") == "xterm-256color",
+		Options:  map[string]interface{}{},
 	}
 
 	format, err := cmd.Flags().GetString("format")
@@ -362,11 +336,38 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	}
 	opts.WordWrap = !nowrap
 
-	if !interactive {
-		return generate(cmd, opts)
+	// Fill out the rest of the options based on information about the
+	// model.
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		return err
+	}
+
+	name := args[0]
+	info, err := func() (*api.ShowResponse, error) {
+		showReq := &api.ShowRequest{Name: name}
+		info, err := client.Show(cmd.Context(), showReq)
+		var se api.StatusError
+		if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
+			if err := PullHandler(cmd, []string{name}); err != nil {
+				return nil, err
+			}
+			return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
+		}
+		return info, err
+	}()
+	if err != nil {
+		return err
 	}
 
-	return generateInteractive(cmd, opts)
+	opts.MultiModal = slices.Contains(info.Details.Families, "clip")
+	opts.ParentModel = info.Details.ParentModel
+	opts.Messages = append(opts.Messages, info.Messages...)
+
+	if interactive {
+		return generateInteractive(cmd, opts)
+	}
+	return generate(cmd, opts)
 }
 
 func errFromUnknownKey(unknownKeyErr error) error {
@@ -579,10 +580,6 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
-	if len(args) != 1 {
-		return errors.New("missing model name")
-	}
-
 	license, errLicense := cmd.Flags().GetBool("license")
 	modelfile, errModelfile := cmd.Flags().GetBool("modelfile")
 	parameters, errParams := cmd.Flags().GetBool("parameters")
@@ -625,8 +622,6 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 
 	if flagsSet > 1 {
 		return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
-	} else if flagsSet == 0 {
-		return errors.New("one of '--license', '--modelfile', '--parameters', '--system', or '--template' must be specified")
 	}
 
 	req := api.ShowRequest{Name: args[0]}
@@ -635,22 +630,141 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
-	switch showType {
-	case "license":
-		fmt.Println(resp.License)
-	case "modelfile":
-		fmt.Println(resp.Modelfile)
-	case "parameters":
-		fmt.Println(resp.Parameters)
-	case "system":
-		fmt.Println(resp.System)
-	case "template":
-		fmt.Println(resp.Template)
+	if flagsSet == 1 {
+		switch showType {
+		case "license":
+			fmt.Println(resp.License)
+		case "modelfile":
+			fmt.Println(resp.Modelfile)
+		case "parameters":
+			fmt.Println(resp.Parameters)
+		case "system":
+			fmt.Println(resp.System)
+		case "template":
+			fmt.Println(resp.Template)
+		}
+
+		return nil
 	}
 
+	showInfo(resp)
+
 	return nil
 }
 
+func showInfo(resp *api.ShowResponse) {
+	arch := resp.ModelInfo["general.architecture"].(string)
+
+	modelData := [][]string{
+		{"arch", arch},
+		{"parameters", resp.Details.ParameterSize},
+		{"quantization", resp.Details.QuantizationLevel},
+		{"context length", fmt.Sprintf("%v", resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))},
+		{"embedding length", fmt.Sprintf("%v", resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64))},
+	}
+
+	mainTableData := [][]string{
+		{"Model"},
+		{renderSubTable(modelData, false)},
+	}
+
+	if resp.ProjectorInfo != nil {
+		projectorData := [][]string{
+			{"arch", "clip"},
+			{"parameters", format.HumanNumber(uint64(resp.ProjectorInfo["general.parameter_count"].(float64)))},
+		}
+
+		if projectorType, ok := resp.ProjectorInfo["clip.projector_type"]; ok {
+			projectorData = append(projectorData, []string{"projector type", projectorType.(string)})
+		}
+
+		projectorData = append(projectorData,
+			[]string{"embedding length", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.embedding_length"].(float64))},
+			[]string{"projection dimensionality", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.projection_dim"].(float64))},
+		)
+
+		mainTableData = append(mainTableData,
+			[]string{"Projector"},
+			[]string{renderSubTable(projectorData, false)},
+		)
+	}
+
+	if resp.Parameters != "" {
+		mainTableData = append(mainTableData, []string{"Parameters"}, []string{formatParams(resp.Parameters)})
+	}
+
+	if resp.System != "" {
+		mainTableData = append(mainTableData, []string{"System"}, []string{renderSubTable(twoLines(resp.System), true)})
+	}
+
+	if resp.License != "" {
+		mainTableData = append(mainTableData, []string{"License"}, []string{renderSubTable(twoLines(resp.License), true)})
+	}
+
+	table := tablewriter.NewWriter(os.Stdout)
+	table.SetAutoWrapText(false)
+	table.SetBorder(false)
+	table.SetAlignment(tablewriter.ALIGN_LEFT)
+
+	for _, v := range mainTableData {
+		table.Append(v)
+	}
+
+	table.Render()
+}
+
+func renderSubTable(data [][]string, file bool) string {
+	var buf bytes.Buffer
+	table := tablewriter.NewWriter(&buf)
+	table.SetAutoWrapText(!file)
+	table.SetBorder(false)
+	table.SetNoWhiteSpace(true)
+	table.SetTablePadding("\t")
+	table.SetAlignment(tablewriter.ALIGN_LEFT)
+
+	for _, v := range data {
+		table.Append(v)
+	}
+
+	table.Render()
+
+	renderedTable := buf.String()
+	lines := strings.Split(renderedTable, "\n")
+	for i, line := range lines {
+		lines[i] = "\t" + line
+	}
+
+	return strings.Join(lines, "\n")
+}
+
+func twoLines(s string) [][]string {
+	lines := strings.Split(s, "\n")
+	res := [][]string{}
+
+	count := 0
+	for _, line := range lines {
+		line = strings.TrimSpace(line)
+		if line != "" {
+			count++
+			res = append(res, []string{line})
+			if count == 2 {
+				return res
+			}
+		}
+	}
+	return res
+}
+
+func formatParams(s string) string {
+	lines := strings.Split(s, "\n")
+	table := [][]string{}
+
+	for _, line := range lines {
+		table = append(table, strings.Fields(line))
+	}
+	return renderSubTable(table, false)
+}
+
 func CopyHandler(cmd *cobra.Command, args []string) error {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {

+ 14 - 47
cmd/interactive.go

@@ -31,65 +31,40 @@ const (
 )
 
 func loadModel(cmd *cobra.Command, opts *runOptions) error {
-	client, err := api.ClientFromEnvironment()
-	if err != nil {
-		return err
-	}
-
 	p := progress.NewProgress(os.Stderr)
 	defer p.StopAndClear()
 
 	spinner := progress.NewSpinner("")
 	p.Add("", spinner)
 
-	showReq := api.ShowRequest{Name: opts.Model}
-	showResp, err := client.Show(cmd.Context(), &showReq)
+	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err
 	}
-	opts.MultiModal = slices.Contains(showResp.Details.Families, "clip")
-	opts.ParentModel = showResp.Details.ParentModel
-
-	if len(showResp.Messages) > 0 {
-		opts.Messages = append(opts.Messages, showResp.Messages...)
-	}
 
 	chatReq := &api.ChatRequest{
-		Model:    opts.Model,
-		Messages: []api.Message{},
+		Model:     opts.Model,
+		KeepAlive: opts.KeepAlive,
 	}
 
-	if opts.KeepAlive != nil {
-		chatReq.KeepAlive = opts.KeepAlive
-	}
-
-	err = client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error {
+	return client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error {
 		p.StopAndClear()
-		if len(opts.Messages) > 0 {
-			for _, msg := range opts.Messages {
-				switch msg.Role {
-				case "user":
-					fmt.Printf(">>> %s\n", msg.Content)
-				case "assistant":
-					state := &displayResponseState{}
-					displayResponse(msg.Content, opts.WordWrap, state)
-					fmt.Println()
-					fmt.Println()
-				}
+		for _, msg := range opts.Messages {
+			switch msg.Role {
+			case "user":
+				fmt.Printf(">>> %s\n", msg.Content)
+			case "assistant":
+				state := &displayResponseState{}
+				displayResponse(msg.Content, opts.WordWrap, state)
+				fmt.Println()
+				fmt.Println()
 			}
 		}
 		return nil
 	})
-	if err != nil {
-		return err
-	}
-
-	return nil
 }
 
 func generateInteractive(cmd *cobra.Command, opts runOptions) error {
-	opts.Messages = make([]api.Message, 0)
-
 	err := loadModel(cmd, &opts)
 	if err != nil {
 		return err
@@ -429,15 +404,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 
 				switch args[1] {
 				case "info":
-					fmt.Println("Model details:")
-					if len(resp.Details.Families) > 0 {
-						fmt.Printf("Family              %s\n", strings.Join(resp.Details.Families, ", "))
-					} else if resp.Details.Family != "" {
-						fmt.Printf("Family              %s\n", resp.Details.Family)
-					}
-					fmt.Printf("Parameter Size      %s\n", resp.Details.ParameterSize)
-					fmt.Printf("Quantization Level  %s\n", resp.Details.QuantizationLevel)
-					fmt.Println("")
+					showInfo(resp)
 				case "license":
 				case "license":
 					if resp.License == "" {
 					if resp.License == "" {
 						fmt.Println("No license was specified for this model.")
 						fmt.Println("No license was specified for this model.")

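Editor's note: the rewritten `loadModel` above preloads a model by sending a `ChatRequest` with no messages and only a keep-alive value. Below is a minimal standalone sketch of the same preload call against the Go API client; the model name and keep-alive duration are illustrative, not taken from this commit.

```go
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		panic(err)
	}

	// A chat request with no messages loads the model into memory without
	// generating a response; KeepAlive controls how long it stays resident.
	req := &api.ChatRequest{
		Model:     "llama3", // illustrative model name
		KeepAlive: &api.Duration{Duration: 10 * time.Minute},
	}

	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		fmt.Println("loaded:", resp.Model)
		return nil
	})
	if err != nil {
		panic(err)
	}
}
```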
+ 33 - 6
docs/api.md

@@ -26,7 +26,7 @@ All durations are returned in nanoseconds.
 
 ### Streaming responses
 
-Certain endpoints stream responses as JSON objects and can optional return non-streamed responses.
+Certain endpoints stream responses as JSON objects. Streaming can be disabled by providing `{"stream": false}` for these endpoints.
 
 ## Generate a completion
 
@@ -777,11 +777,12 @@ A single JSON object will be returned.
 POST /api/show
 ```
 
-Show information about a model including details, modelfile, template, parameters, license, and system prompt.
+Show information about a model including details, modelfile, template, parameters, license, system prompt.
 
 ### Parameters
 
 - `name`: name of the model to show
+- `verbose`: (optional) if set to `true`, returns full data for verbose response fields
 
 ### Examples
 
@@ -798,14 +799,40 @@ curl http://localhost:11434/api/show -d '{
 ```json
 {
   "modelfile": "# Modelfile generated by \"ollama show\"\n# To build a new Modelfile based on this one, replace the FROM line with:\n# FROM llava:latest\n\nFROM /Users/matt/.ollama/models/blobs/sha256:200765e1283640ffbd013184bf496e261032fa75b99498a9613be4e94d63ad52\nTEMPLATE \"\"\"{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: \"\"\"\nPARAMETER num_ctx 4096\nPARAMETER stop \"\u003c/s\u003e\"\nPARAMETER stop \"USER:\"\nPARAMETER stop \"ASSISTANT:\"",
-  "parameters": "num_ctx                        4096\nstop                           \u003c/s\u003e\nstop                           USER:\nstop                           ASSISTANT:",
-  "template": "{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: ",
+  "parameters": "num_keep                       24\nstop                           \"<|start_header_id|>\"\nstop                           \"<|end_header_id|>\"\nstop                           \"<|eot_id|>\"",
+  "template": "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>",
   "details": {
+    "parent_model": "",
     "format": "gguf",
     "family": "llama",
-    "families": ["llama", "clip"],
-    "parameter_size": "7B",
+    "families": [
+      "llama"
+    ],
+    "parameter_size": "8.0B",
     "quantization_level": "Q4_0"
+  },
+  "model_info": {
+    "general.architecture": "llama",
+    "general.file_type": 2,
+    "general.parameter_count": 8030261248,
+    "general.quantization_version": 2,
+    "llama.attention.head_count": 32,
+    "llama.attention.head_count_kv": 8,
+    "llama.attention.layer_norm_rms_epsilon": 0.00001,
+    "llama.block_count": 32,
+    "llama.context_length": 8192,
+    "llama.embedding_length": 4096,
+    "llama.feed_forward_length": 14336,
+    "llama.rope.dimension_count": 128,
+    "llama.rope.freq_base": 500000,
+    "llama.vocab_size": 128256,
+    "tokenizer.ggml.bos_token_id": 128000,
+    "tokenizer.ggml.eos_token_id": 128009,
+    "tokenizer.ggml.merges": [],            // populates if `verbose=true`
+    "tokenizer.ggml.model": "gpt2",
+    "tokenizer.ggml.pre": "llama-bpe",
+    "tokenizer.ggml.token_type": [],        // populates if `verbose=true`
+    "tokenizer.ggml.tokens": []             // populates if `verbose=true`
   }
 }
 ```

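Editor's note: the same request can also be issued from Go through the `api` package, using the `Verbose` field added to `ShowRequest` in this commit. A sketch follows; the model name and the printed fields are illustrative.

```go
package main

import (
	"context"
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		panic(err)
	}

	// Verbose asks the server to fill in the large tokenizer fields
	// (tokens, merges, token_type) instead of leaving them empty.
	resp, err := client.Show(context.Background(), &api.ShowRequest{
		Model:   "llama3", // illustrative model name
		Verbose: true,
	})
	if err != nil {
		panic(err)
	}

	fmt.Println(resp.Details.ParameterSize)
	fmt.Println(resp.ModelInfo["general.architecture"])
}
```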
+ 16 - 0
docs/faq.md

@@ -257,3 +257,19 @@ If you wish to override the `OLLAMA_KEEP_ALIVE` setting, use the `keep_alive` AP
 ## How do I manage the maximum number of requests the Ollama server can queue?
 
 If too many requests are sent to the server, it will respond with a 503 error indicating the server is overloaded.  You can adjust how many requests may be queue by setting `OLLAMA_MAX_QUEUE`.
+
+## How does Ollama handle concurrent requests?
+
+Ollama supports two levels of concurrent processing.  If your system has sufficient available memory (system memory when using CPU inference, or VRAM for GPU inference) then multiple models can be loaded at the same time.  For a given model, if there is sufficient available memory when the model is loaded, it is configured to allow parallel request processing.
+
+If there is insufficient available memory to load a new model request while one or more models are already loaded, all new requests will be queued until the new model can be loaded.  As prior models become idle, one or more will be unloaded to make room for the new model.  Queued requests will be processed in order.  When using GPU inference new models must be able to completely fit in VRAM to allow concurrent model loads.
+
+Parallel request processing for a given model results in increasing the context size by the number of parallel requests.  For example, a 2K context with 4 parallel requests will result in an 8K context and additional memory allocation.
+
+The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:
+
+- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory.  The default is 3 * the number of GPUs or 3 for CPU inference.
+- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time.  The default will auto-select either 4 or 1 based on available memory.
+- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
+
+Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting.  Once ROCm v6 is available, Windows Radeon will follow the defaults above.  You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.

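Editor's note: as a quick illustration of the context-size arithmetic in the new FAQ entry, the sketch below mirrors the "2K context with 4 parallel requests gives an 8K context" example; the numbers are taken from that example and are not a configuration recommendation.

```go
package main

import "fmt"

func main() {
	numCtx := 2048   // per-request context length
	numParallel := 4 // OLLAMA_NUM_PARALLEL

	// Parallel slots share one loaded model, so the model's context is
	// allocated as num_ctx * num_parallel (2K * 4 = 8K in the FAQ example).
	fmt.Println(numCtx * numParallel) // 8192
}
```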
+ 1 - 1
docs/gpu.md

@@ -18,7 +18,7 @@ Check your compute compatibility to see if your card is supported:
 |                    | Quadro              | `RTX 8000` `RTX 6000` `RTX 5000` `RTX 4000`                                                                 |
 | 7.0                | NVIDIA              | `TITAN V` `V100` `Quadro GV100`                                                                             |
 | 6.1                | NVIDIA TITAN        | `TITAN Xp` `TITAN X`                                                                                        |
-|                    | GeForce GTX         | `GTX 1080 Ti` `GTX 1080` `GTX 1070 Ti` `GTX 1070` `GTX 1060` `GTX 1050`                                     |
+|                    | GeForce GTX         | `GTX 1080 Ti` `GTX 1080` `GTX 1070 Ti` `GTX 1070` `GTX 1060` `GTX 1050 Ti` `GTX 1050`                       |
 |                    | Quadro              | `P6000` `P5200` `P4200` `P3200` `P5000` `P4000` `P3000` `P2200` `P2000` `P1000` `P620` `P600` `P500` `P520` |
 |                    | Tesla               | `P40` `P4`                                                                                                  |
 | 6.0                | NVIDIA              | `Tesla P100` `Quadro GP100`                                                                                 |

+ 1 - 1
docs/openai.md

@@ -65,6 +65,7 @@ curl http://localhost:11434/v1/chat/completions \
             }
         ]
     }'
+
 ```
 
 ## Endpoints
 
@@ -104,7 +105,6 @@ curl http://localhost:11434/v1/chat/completions \
 
 #### Notes
 
-- `finish_reason` will always be `stop`
 - `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
 
 ## Models

+ 1 - 1
docs/troubleshooting.md

@@ -22,7 +22,7 @@ docker logs <container-name>
 If manually running `ollama serve` in a terminal, the logs will be on that terminal.
 
 When you run Ollama on **Windows**, there are a few different locations. You can view them in the explorer window by hitting `<cmd>+R` and type in:
-- `explorer %LOCALAPPDATA%\Ollama` to view logs
+- `explorer %LOCALAPPDATA%\Ollama` to view logs.  The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
 - `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH)
 - `explorer %HOMEPATH%\.ollama` to browse where models and configuration is stored
 - `explorer %TEMP%` where temporary executable files are stored in one or more `ollama*` directories

+ 2 - 2
docs/windows.md

@@ -39,8 +39,8 @@ server.
 Ollama on Windows stores files in a few different locations.  You can view them in
 the explorer window by hitting `<cmd>+R` and type in:
 - `explorer %LOCALAPPDATA%\Ollama` contains logs, and downloaded updates
-    - *app.log* contains logs from the GUI application
-    - *server.log* contains the server logs
+    - *app.log* contains most resent logs from the GUI application
+    - *server.log* contains the most recent server logs
     - *upgrade.log* contains log output for upgrades
 - `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
 - `explorer %HOMEPATH%\.ollama` contains models and configuration

+ 8 - 8
envconfig/config.go

@@ -85,13 +85,13 @@ func AsMap() map[string]EnvVar {
 		"OLLAMA_HOST":              {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"},
 		"OLLAMA_HOST":              {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"},
 		"OLLAMA_KEEP_ALIVE":        {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"},
 		"OLLAMA_KEEP_ALIVE":        {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"},
 		"OLLAMA_LLM_LIBRARY":       {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
 		"OLLAMA_LLM_LIBRARY":       {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
-		"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models (default 1)"},
+		"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"},
 		"OLLAMA_MAX_QUEUE":         {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
 		"OLLAMA_MAX_QUEUE":         {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
 		"OLLAMA_MAX_VRAM":          {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"},
 		"OLLAMA_MAX_VRAM":          {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"},
 		"OLLAMA_MODELS":            {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"},
 		"OLLAMA_MODELS":            {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"},
 		"OLLAMA_NOHISTORY":         {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
 		"OLLAMA_NOHISTORY":         {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
 		"OLLAMA_NOPRUNE":           {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
 		"OLLAMA_NOPRUNE":           {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
-		"OLLAMA_NUM_PARALLEL":      {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests (default 1)"},
+		"OLLAMA_NUM_PARALLEL":      {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"},
 		"OLLAMA_ORIGINS":           {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"},
 		"OLLAMA_ORIGINS":           {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"},
 		"OLLAMA_RUNNERS_DIR":       {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
 		"OLLAMA_RUNNERS_DIR":       {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
 		"OLLAMA_SCHED_SPREAD":      {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"},
 		"OLLAMA_SCHED_SPREAD":      {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"},
@@ -129,8 +129,8 @@ func clean(key string) string {
 
 
 func init() {
 func init() {
 	// default values
 	// default values
-	NumParallel = 1
-	MaxRunners = 1
+	NumParallel = 0 // Autoselect
+	MaxRunners = 0  // Autoselect
 	MaxQueuedRequests = 512
 	MaxQueuedRequests = 512
 
 
 	LoadConfig()
 	LoadConfig()
@@ -205,8 +205,8 @@ func LoadConfig() {
 
 
 	if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" {
 	if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" {
 		val, err := strconv.Atoi(onp)
 		val, err := strconv.Atoi(onp)
-		if err != nil || val <= 0 {
-			slog.Error("invalid setting must be greater than zero", "OLLAMA_NUM_PARALLEL", onp, "error", err)
+		if err != nil {
+			slog.Error("invalid setting, ignoring", "OLLAMA_NUM_PARALLEL", onp, "error", err)
 		} else {
 		} else {
 			NumParallel = val
 			NumParallel = val
 		}
 		}
@@ -251,7 +251,7 @@ func LoadConfig() {
 	if maxRunners != "" {
 	if maxRunners != "" {
 		m, err := strconv.Atoi(maxRunners)
 		m, err := strconv.Atoi(maxRunners)
 		if err != nil {
 		if err != nil {
-			slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
+			slog.Error("invalid setting, ignoring", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
 		} else {
 		} else {
 			MaxRunners = m
 			MaxRunners = m
 		}
 		}
@@ -260,7 +260,7 @@ func LoadConfig() {
 	if onp := os.Getenv("OLLAMA_MAX_QUEUE"); onp != "" {
 	if onp := os.Getenv("OLLAMA_MAX_QUEUE"); onp != "" {
 		p, err := strconv.Atoi(onp)
 		p, err := strconv.Atoi(onp)
 		if err != nil || p <= 0 {
 		if err != nil || p <= 0 {
-			slog.Error("invalid setting", "OLLAMA_MAX_QUEUE", onp, "error", err)
+			slog.Error("invalid setting, ignoring", "OLLAMA_MAX_QUEUE", onp, "error", err)
 		} else {
 		} else {
 			MaxQueuedRequests = p
 			MaxQueuedRequests = p
 		}
 		}

+ 3 - 2
gpu/amd_windows.go

@@ -115,8 +115,6 @@ func AMDGetGPUInfo() []RocmGPUInfo {
 			continue
 		}
 
-		// TODO revisit this once ROCm v6 is available on windows.
-		// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable
 		slog.Debug("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory))
 		slog.Debug("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory))
 		gpuInfo := RocmGPUInfo{
@@ -126,6 +124,9 @@ func AMDGetGPUInfo() []RocmGPUInfo {
 					TotalMemory: totalMemory,
 					FreeMemory:  freeMemory,
 				},
+				// Free memory reporting on Windows is not reliable until we bump to ROCm v6.2
+				UnreliableFreeMemory: true,
+
 				ID:             strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices
 				DependencyPath: libDir,
 				MinimumMemory:  rocmMinimumMemory,

+ 19 - 12
gpu/assets.go

@@ -77,20 +77,27 @@ func cleanupTmpDirs() {
 			continue
 		}
 		raw, err := os.ReadFile(filepath.Join(d, "ollama.pid"))
-		if err == nil {
-			pid, err := strconv.Atoi(string(raw))
-			if err == nil {
-				if proc, err := os.FindProcess(pid); err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
-					// Another running ollama, ignore this tmpdir
-					continue
-				}
-			}
-		} else {
-			slog.Debug("failed to open ollama.pid", "path", d, "error", err)
+		if err != nil {
+			slog.Warn("failed to read ollama.pid", "path", d, "error", err)
+			// No pid, ignore this tmpdir
+			continue
 		}
-		err = os.RemoveAll(d)
+
+		pid, err := strconv.Atoi(string(raw))
 		if err != nil {
-			slog.Debug("unable to cleanup stale tmpdir", "path", d, "error", err)
+			slog.Warn("failed to parse pid", "path", d, "error", err)
+			continue
+		}
+
+		proc, err := os.FindProcess(pid)
+		if err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
+			slog.Warn("found running ollama", "pid", pid, "path", d)
+			// Another running ollama, ignore this tmpdir
+			continue
+		}
+
+		if err := os.Remove(d); err != nil {
+			slog.Warn("unable to cleanup stale tmpdir", "path", d, "error", err)
 		}
 	}
 }

+ 8 - 2
gpu/gpu.go

@@ -231,7 +231,7 @@ func GetGPUInfo() GpuInfoList {
 		// On windows we bundle the nvidia library one level above the runner dir
 		depPath := ""
 		if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
-			depPath = filepath.Dir(envconfig.RunnersDir)
+			depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir), "cuda")
 		}
 
 		// Load ALL libraries
@@ -282,6 +282,12 @@ func GetGPUInfo() GpuInfoList {
 		// Intel
 		if envconfig.IntelGpu {
 			oHandles = initOneAPIHandles()
+			// On windows we bundle the oneapi library one level above the runner dir
+			depPath = ""
+			if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
+				depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir), "oneapi")
+			}
+
 			for d := range oHandles.oneapi.num_drivers {
 				if oHandles.oneapi == nil {
 					// shouldn't happen
@@ -306,7 +312,7 @@ func GetGPUInfo() GpuInfoList {
 					gpuInfo.FreeMemory = uint64(memInfo.free)
 					gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
 					gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
-					// TODO dependency path?
+					gpuInfo.DependencyPath = depPath
 					oneapiGPUs = append(oneapiGPUs, gpuInfo)
 				}
 			}

+ 1 - 1
gpu/gpu_info_cudart.c

@@ -40,7 +40,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
 
   for (i = 0; l[i].s != NULL; i++) {
     *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
-    if (!l[i].p) {
+    if (!*(l[i].p)) {
       char *msg = LOAD_ERR();
      LOG(resp->ch.verbose, "dlerr: %s\n", msg);
      UNLOAD_LIBRARY(resp->ch.handle);

+ 1 - 1
gpu/gpu_info_nvcuda.c

@@ -43,7 +43,7 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
 
   for (i = 0; l[i].s != NULL; i++) {
     *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
-    if (!*l[i].p) {
+    if (!*(l[i].p)) {
      char *msg = LOAD_ERR();
      LOG(resp->ch.verbose, "dlerr: %s\n", msg);
      UNLOAD_LIBRARY(resp->ch.handle);

+ 1 - 1
gpu/gpu_info_nvml.c

@@ -42,7 +42,7 @@ void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) {
     // LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
 
     *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
-    if (!l[i].p) {
+    if (!*(l[i].p)) {
       resp->ch.handle = NULL;
       char *msg = LOAD_ERR();
       LOG(resp->ch.verbose, "dlerr: %s\n", msg);

+ 1 - 1
gpu/gpu_info_oneapi.c

@@ -50,7 +50,7 @@ void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp) {
     LOG(resp->oh.verbose, "dlsym: %s\n", l[i].s);
 
     *l[i].p = LOAD_SYMBOL(resp->oh.handle, l[i].s);
-    if (!l[i].p) {
+    if (!*(l[i].p)) {
       resp->oh.handle = NULL;
       char *msg = LOAD_ERR();
       LOG(resp->oh.verbose, "dlerr: %s\n", msg);

+ 5 - 0
gpu/types.go

@@ -29,6 +29,11 @@ type GpuInfo struct {
 	// Extra environment variables specific to the GPU as list of [key,value]
 	EnvWorkarounds [][2]string `json:"envs,omitempty"`
 
+	// Set to true if we can NOT reliably discover FreeMemory.  A value of true indicates
+	// the FreeMemory is best effort, and may over or under report actual memory usage
+	// False indicates FreeMemory can generally be trusted on this GPU
+	UnreliableFreeMemory bool
+
 	// GPU information
 	ID      string `json:"gpu_id"`  // string to use for selection of this specific GPU
 	Name    string `json:"name"`    // user friendly name if available

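Editor's note: the new `UnreliableFreeMemory` flag only marks the reported value as best effort; how consumers react is left to scheduling code. The sketch below shows one hypothetical consumer that keeps extra headroom when the flag is set; the mirrored struct and the 10% margin are illustrative and not part of this commit.

```go
package main

import "fmt"

// gpuInfo mirrors just the two fields of gpu.GpuInfo used in this sketch.
type gpuInfo struct {
	FreeMemory           uint64
	UnreliableFreeMemory bool
}

// usableMemory applies an illustrative safety margin when the reported
// free memory cannot be trusted.
func usableMemory(g gpuInfo) uint64 {
	if g.UnreliableFreeMemory {
		return g.FreeMemory / 10 * 9
	}
	return g.FreeMemory
}

func main() {
	g := gpuInfo{FreeMemory: 8 << 30, UnreliableFreeMemory: true}
	fmt.Println(usableMemory(g))
}
```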
+ 37 - 26
llm/ext_server/server.cpp

@@ -56,7 +56,6 @@ struct server_params {
     std::string hostname = "127.0.0.1";
     std::vector<std::string> api_keys;
     std::string public_path = "examples/server/public";
-    std::string chat_template = "";
     int32_t port = 8080;
     int32_t read_timeout = 600;
     int32_t write_timeout = 600;
@@ -427,16 +426,6 @@ struct llama_server_context
         return true;
     }
 
-    void validate_model_chat_template(server_params & sparams) {
-        llama_chat_message chat[] = {{"user", "test"}};
-        std::vector<char> buf(1);
-        int res = llama_chat_apply_template(model, nullptr, chat, 1, true, buf.data(), buf.size());
-        if (res < 0) {
-            LOG_ERROR("The chat template comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
-            sparams.chat_template = "chatml";
-        }
-    }
-
     void initialize() {
         // create slots
         all_slots_are_idle = true;
@@ -1661,26 +1650,41 @@ struct llama_server_context
                     }
                     slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
 
+                    char buf[256];
+                    llama_model_meta_val_str(model, "general.architecture", buf, 256);
+                    bool gemma2 = strcmp(buf, "gemma2") == 0;
+
+                    int32_t truncate_at = slot.n_ctx;
+
+                    // truncate at 2/3 of the context length for gemma2 models
+                    // as they do not support context shifts (from the sliding window implementation).
+                    // this way, prompts that almost fit the context length can still generate a full
+                    // response without a sudden stop from hitting the context limit
+                    if (gemma2) {
+                        truncate_at = 2 * slot.n_ctx / 3;
+                    }
+
                     // if input prompt is too big, truncate it, if group attention self-extend is disabled
-                    if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx)
+                    if (slot.ga_n == 1 && slot.n_prompt_tokens >= truncate_at)
                     {
                         const int n_left = slot.n_ctx - slot.params.n_keep;
-                        const int n_block_size = n_left / 2;
-                        const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+                        const int n_shift = n_left / 2;
+                        const int n_erase = slot.n_prompt_tokens - slot.params.n_keep - n_shift;
 
                         std::vector<llama_token> new_tokens(
                             prompt_tokens.begin(),
                             prompt_tokens.begin() + slot.params.n_keep);
                         new_tokens.insert(
                             new_tokens.end(),
-                            prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
+                            prompt_tokens.begin() + slot.params.n_keep + n_erase,
                             prompt_tokens.end());
 
-                        LOG_VERBOSE("input truncated", {
-                            {"n_ctx",      slot.n_ctx},
-                            {"n_keep",     slot.params.n_keep},
-                            {"n_left",     n_left},
-                            {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+                        LOG_INFO("input truncated", {
+                            {"n_ctx",        slot.n_ctx},
+                            {"n_keep",       slot.params.n_keep},
+                            {"n_left",       n_left},
+                            {"n_shift",      n_shift},
+                            {"n_erase",      n_erase},
                         });
                         slot.truncated = true;
                         prompt_tokens = new_tokens;
@@ -1689,6 +1693,19 @@ struct llama_server_context
                         GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
                     }
 
+                    // Models with sliding window attention do not work with context shifts, so
+                    // limit their prediction to the context length
+                    if (gemma2) {
+                        int32_t limit = slot.n_ctx - slot.n_prompt_tokens;
+                        slot.n_predict = limit;
+                        slot.params.n_predict = limit;
+                        LOG_INFO("model does not support sliding window, limiting generation", {
+                            {"n_ctx", slot.n_ctx},
+                            {"n_prompt_tokens", slot.n_prompt_tokens},
+                            {"n_predict", slot.n_predict}
+                        });
+                    }
+
                     if (!slot.params.cache_prompt)
                     if (!slot.params.cache_prompt)
                     {
                     {
                         llama_sampling_reset(slot.ctx_sampling);
                         llama_sampling_reset(slot.ctx_sampling);
@@ -2535,7 +2552,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
                 invalid_param = true;
                 invalid_param = true;
                 break;
                 break;
             }
             }
-            sparams.chat_template = argv[i];
         }
         }
         else if (arg == "--override-kv")
         else if (arg == "--override-kv")
         {
         {
@@ -3008,11 +3024,6 @@ int main(int argc, char **argv) {
     }
     }
     const auto model_meta = llama.model_meta();
     const auto model_meta = llama.model_meta();
 
 
-    if (sparams.chat_template.empty()) { // custom chat template is not supplied
-        // check if the template comes with the model is supported by us
-        llama.validate_model_chat_template(sparams);
-    }
-
     // Middleware for API key validation
     // Middleware for API key validation
     auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
     auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
         // If API key is not set, skip validation
         // If API key is not set, skip validation

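The server changes above do two things for gemma2-style models that cannot shift their context window: the prompt is truncated at 2/3 of the context length, and generation is capped at whatever context remains after the prompt. The same arithmetic, re-expressed as a standalone Go sketch for clarity (function and variable names are illustrative):

package main

import "fmt"

// planSlot returns the prompt-truncation threshold and the generation cap
// for a model that cannot shift its context (e.g. sliding-window attention).
func planSlot(nCtx, nPromptTokens int) (truncateAt, nPredict int) {
	truncateAt = 2 * nCtx / 3 // leave room for a full response
	if nPromptTokens > truncateAt {
		nPromptTokens = truncateAt
	}
	nPredict = nCtx - nPromptTokens // hard cap: no context shift available
	return truncateAt, nPredict
}

func main() {
	truncateAt, nPredict := planSlot(8192, 6000)
	fmt.Println(truncateAt, nPredict) // 5461 2731
}
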
+ 18 - 14
llm/generate/gen_windows.ps1

@@ -295,10 +295,12 @@ function build_cuda() {
         sign
         sign
         install
         install
 
 
-        write-host "copying CUDA dependencies to ${script:SRC_DIR}\dist\windows-${script:ARCH}\"
-        cp "${script:CUDA_LIB_DIR}\cudart64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
-        cp "${script:CUDA_LIB_DIR}\cublas64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
-        cp "${script:CUDA_LIB_DIR}\cublasLt64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
+        rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\cuda\"
+        md "${script:SRC_DIR}\dist\windows-${script:ARCH}\cuda\" -ea 0 > $null
+        write-host "copying CUDA dependencies to ${script:SRC_DIR}\dist\windows-${script:ARCH}\cuda\"
+        cp "${script:CUDA_LIB_DIR}\cudart64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\cuda\"
+        cp "${script:CUDA_LIB_DIR}\cublas64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\cuda\"
+        cp "${script:CUDA_LIB_DIR}\cublasLt64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\cuda\"
     } else {
     } else {
         write-host "Skipping CUDA generation step"
         write-host "Skipping CUDA generation step"
     }
     }
@@ -332,16 +334,18 @@ function build_oneapi() {
     sign
     sign
     install
     install
 
 
-    cp "${env:ONEAPI_ROOT}\compiler\latest\bin\libirngmd.dll" "${script:distDir}"
-    cp "${env:ONEAPI_ROOT}\compiler\latest\bin\libmmd.dll" "${script:distDir}"
-    cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_level_zero.dll" "${script:distDir}"
-    cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_unified_runtime.dll" "${script:distDir}"
-    cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_win_proxy_loader.dll" "${script:distDir}"
-    cp "${env:ONEAPI_ROOT}\compiler\latest\bin\svml_dispmd.dll" "${script:distDir}"
-    cp "${env:ONEAPI_ROOT}\compiler\latest\bin\sycl7.dll" "${script:distDir}"
-    cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_core.2.dll" "${script:distDir}"
-    cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_sycl_blas.4.dll" "${script:distDir}"
-    cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_tbb_thread.2.dll" "${script:distDir}"
+    rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\"
+    md "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\" -ea 0 > $null
+    cp "${env:ONEAPI_ROOT}\compiler\latest\bin\libirngmd.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\"
+    cp "${env:ONEAPI_ROOT}\compiler\latest\bin\libmmd.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\"
+    cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_level_zero.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\"
+    cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_unified_runtime.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\"
+    cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_win_proxy_loader.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\"
+    cp "${env:ONEAPI_ROOT}\compiler\latest\bin\svml_dispmd.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\"
+    cp "${env:ONEAPI_ROOT}\compiler\latest\bin\sycl7.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\"
+    cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_core.2.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\"
+    cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_sycl_blas.4.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\"
+    cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_tbb_thread.2.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\"
   } else {
   } else {
     Write-Host "Skipping oneAPI generation step"
     Write-Host "Skipping oneAPI generation step"
   }
   }

+ 11 - 2
llm/ggla.go

@@ -53,7 +53,7 @@ func (llm *ggla) Tensors() Tensors {
 	return llm.tensors
 	return llm.tensors
 }
 }
 
 
-func (llm *ggla) decode(rs io.ReadSeeker) error {
+func (llm *ggla) decode(rs io.ReadSeeker) (retErr error) {
 	var r uint32
 	var r uint32
 	if err := binary.Read(rs, binary.LittleEndian, &r); err != nil {
 	if err := binary.Read(rs, binary.LittleEndian, &r); err != nil {
 		return err
 		return err
@@ -69,9 +69,18 @@ func (llm *ggla) decode(rs io.ReadSeeker) error {
 	for {
 	for {
 		var dims uint32
 		var dims uint32
 		if err := binary.Read(rs, binary.LittleEndian, &dims); err != nil {
 		if err := binary.Read(rs, binary.LittleEndian, &dims); err != nil {
+			if errors.Is(err, io.EOF) {
+				return nil
+			}
 			return err
 			return err
 		}
 		}
 
 
+		defer func() {
+			if errors.Is(retErr, io.EOF) {
+				retErr = io.ErrUnexpectedEOF
+			}
+		}()
+
 		var namesize uint32
 		var namesize uint32
 		if err := binary.Read(rs, binary.LittleEndian, &namesize); err != nil {
 		if err := binary.Read(rs, binary.LittleEndian, &namesize); err != nil {
 			return err
 			return err
@@ -108,7 +117,7 @@ func (llm *ggla) decode(rs io.ReadSeeker) error {
 			return err
 			return err
 		}
 		}
 
 
-		if _, err := rs.Seek((offset+31)&-32, io.SeekStart); err != nil {
+		if _, err := rs.Seek((offset+31)&-32-offset, io.SeekCurrent); err != nil {
 			return err
 			return err
 		}
 		}
 
 

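The seek fix keeps the existing 32-byte alignment rule but expresses it as a delta from the current position (io.SeekCurrent) rather than an absolute target. A standalone sketch of that arithmetic, with hypothetical helper names:

package main

import "fmt"

// alignUp rounds n up to the next multiple of 32, the padding rule used for
// ggla tensor data: add 31, then clear the low five bits. For two's complement
// integers (n+31)&^31 is identical to (n+31)&-32.
func alignUp(n int64) int64 {
	return (n + 31) &^ 31
}

func main() {
	for _, off := range []int64{0, 1, 31, 32, 100} {
		aligned := alignUp(off)
		fmt.Printf("offset %3d -> aligned %3d (relative seek of %d)\n",
			off, aligned, aligned-off)
	}
	// offset   0 -> aligned   0 (relative seek of 0)
	// offset   1 -> aligned  32 (relative seek of 31)
	// offset  31 -> aligned  32 (relative seek of 1)
	// offset  32 -> aligned  32 (relative seek of 0)
	// offset 100 -> aligned 128 (relative seek of 28)
}
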
+ 71 - 14
llm/ggml.go

@@ -6,6 +6,8 @@ import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"strings"
 	"strings"
+
+	"github.com/ollama/ollama/util/bufioutil"
 )
 )
 
 
 type GGML struct {
 type GGML struct {
@@ -69,6 +71,30 @@ func (kv KV) HeadCountKV() uint64 {
 	return 1
 	return 1
 }
 }
 
 
+func (kv KV) EmbeddingHeadCount() uint64 {
+	if heads := kv.HeadCount(); heads > 0 {
+		return kv.EmbeddingLength() / kv.HeadCount()
+	}
+
+	return 0
+}
+
+func (kv KV) EmbeddingHeadCountK() uint64 {
+	if k := kv.u64(fmt.Sprintf("%s.attention.key_length", kv.Architecture())); k > 0 {
+		return k
+	}
+
+	return kv.EmbeddingHeadCount()
+}
+
+func (kv KV) EmbeddingHeadCountV() uint64 {
+	if v := kv.u64(fmt.Sprintf("%s.attention.value_length", kv.Architecture())); v > 0 {
+		return v
+	}
+
+	return kv.EmbeddingHeadCount()
+}
+
 func (kv KV) GQA() uint64 {
 func (kv KV) GQA() uint64 {
 	return kv.HeadCount() / kv.HeadCountKV()
 	return kv.HeadCount() / kv.HeadCountKV()
 }
 }
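
The three new helpers above fall back to embedding_length / head_count whenever a model's metadata omits explicit key or value head sizes. A worked Go example of that fallback for a hypothetical 4096-dimension, 32-head model:

package main

import "fmt"

func main() {
	const (
		embeddingLength = 4096
		headCount       = 32
	)
	var keyLength, valueLength uint64 // absent from the GGUF metadata in this example

	embeddingHeads := uint64(embeddingLength / headCount) // 128

	k := keyLength
	if k == 0 {
		k = embeddingHeads
	}
	v := valueLength
	if v == 0 {
		v = embeddingHeads
	}
	fmt.Println(embeddingHeads, k, v) // 128 128 128
}
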
@@ -254,7 +280,18 @@ func DetectGGMLType(b []byte) string {
 	}
 	}
 }
 }
 
 
-func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
+// DecodeGGML decodes a GGML model from the given reader.
+//
+// It collects array values for arrays with a size less than or equal to
+// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
+// the maxArraySize is negative, all arrays are collected.
+func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
+	if maxArraySize == 0 {
+		maxArraySize = 1024
+	}
+
+	rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
+
 	var magic uint32
 	var magic uint32
 	if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
 	if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
 		return nil, 0, err
 		return nil, 0, err
@@ -267,17 +304,15 @@ func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
 	case FILE_MAGIC_GGLA:
 	case FILE_MAGIC_GGLA:
 		c = &containerGGLA{}
 		c = &containerGGLA{}
 	case FILE_MAGIC_GGUF_LE:
 	case FILE_MAGIC_GGUF_LE:
-		c = &containerGGUF{ByteOrder: binary.LittleEndian}
+		c = &containerGGUF{ByteOrder: binary.LittleEndian, maxArraySize: maxArraySize}
 	case FILE_MAGIC_GGUF_BE:
 	case FILE_MAGIC_GGUF_BE:
-		c = &containerGGUF{ByteOrder: binary.BigEndian}
+		c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
 	default:
 	default:
 		return nil, 0, errors.New("invalid file magic")
 		return nil, 0, errors.New("invalid file magic")
 	}
 	}
 
 
 	model, err := c.Decode(rs)
 	model, err := c.Decode(rs)
-	if errors.Is(err, io.EOF) {
-		// noop
-	} else if err != nil {
+	if err != nil {
 		return nil, 0, err
 		return nil, 0, err
 	}
 	}
 
 
@@ -297,7 +332,10 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 	embedding := llm.KV().EmbeddingLength()
 	embedding := llm.KV().EmbeddingLength()
 	heads := llm.KV().HeadCount()
 	heads := llm.KV().HeadCount()
 	headsKV := llm.KV().HeadCountKV()
 	headsKV := llm.KV().HeadCountKV()
-	vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any)))
+	vocab := uint64(llm.KV()["tokenizer.ggml.tokens"].(*array).size)
+
+	embeddingHeads := llm.KV().EmbeddingHeadCount()
+	embeddingHeadsK := llm.KV().EmbeddingHeadCountK()
 
 
 	layers := llm.Tensors().Layers()
 	layers := llm.Tensors().Layers()
 
 
@@ -308,7 +346,7 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 		partialOffload = 4 * batch * embedding
 		partialOffload = 4 * batch * embedding
 		partialOffload += max(
 		partialOffload += max(
 			// 4*batch*(4+6*embedding+context*(2*heads)+llm.KV().GQA()),
 			// 4*batch*(4+6*embedding+context*(2*heads)+llm.KV().GQA()),
-			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
+			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embeddingHeads*headsKV),
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
 		)
 		)
 
 
@@ -316,21 +354,30 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 			// mixtral 8x22b
 			// mixtral 8x22b
 			ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
 			ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
 			partialOffload = max(
 			partialOffload = max(
-				3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embedding/heads*headsKV),
-				4*(context*batch*heads+context*embedding/heads*headsKV+batch*1024+embedding/heads*headsKV*batch),
+				3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
+				4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
 			)
 			)
 		} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
 		} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
 			// mixtral 8x7b
 			// mixtral 8x7b
 			ffnGateWeight1 := ffnGateWeight.Shape[1]
 			ffnGateWeight1 := ffnGateWeight.Shape[1]
 			fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
 			fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
 			partialOffload = max(
 			partialOffload = max(
-				4*batch*(3+embedding/heads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
+				4*batch*(3+embeddingHeads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
 				4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
 				4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
 			)
 			)
 		}
 		}
-	case "gemma":
-		fullOffload = 4 * batch * (embedding + vocab)
-		partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128
+	case "gemma", "gemma2":
+		fullOffload = max(
+			4*batch*(embedding+vocab),
+			4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
+		)
+
+		partialOffload = max(
+			4*embedding*batch+embedding*vocab*105/128+4*vocab*batch,
+			4*batch*(2*embedding+1+2*embeddingHeadsK*heads+context+context*heads)+
+				4*embeddingHeadsK*context*8+
+				embedding*embeddingHeadsK*heads*9/16,
+		)
 	case "command-r":
 	case "command-r":
 		fullOffload = max(
 		fullOffload = max(
 			4*batch*(embedding+vocab),
 			4*batch*(embedding+vocab),
@@ -367,6 +414,16 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 			4*batch*(vocab+2*embedding),
 			4*batch*(vocab+2*embedding),
 			fullOffload,
 			fullOffload,
 		)
 		)
+	case "deepseek2":
+		fullOffload = max(
+			4*batch*(3*embedding+vocab),
+			4*batch*(3*embedding+2+context*(1+headsKV)+2*embeddingHeadsK*headsKV),
+		)
+
+		partialOffload = max(
+			4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
+			4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
+		)
 	}
 	}
 
 
 	return
 	return

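Callers that only need scalar metadata can now bound how much array data DecodeGGML materialises. A hedged usage sketch against this package's API at this revision (placeholder file path, minimal error handling):

package main

import (
	"fmt"
	"log"
	"os"

	"github.com/ollama/ollama/llm"
)

func main() {
	f, err := os.Open("model.gguf") // placeholder path
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// 0 means "use the default cap of 1024 elements"; pass a negative value
	// to collect every array, or a small positive cap to keep memory use
	// down when only scalar metadata is needed.
	ggml, _, err := llm.DecodeGGML(f, 0)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("architecture:", ggml.KV().Architecture())
}
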
+ 1 - 0
llm/ggml_test.go

@@ -0,0 +1 @@
+package llm

+ 92 - 38
llm/gguf.go

@@ -3,11 +3,10 @@ package llm
 import (
 import (
 	"bytes"
 	"bytes"
 	"encoding/binary"
 	"encoding/binary"
+	"encoding/json"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"strings"
 	"strings"
-
-	"log/slog"
 )
 )
 
 
 type containerGGUF struct {
 type containerGGUF struct {
@@ -29,6 +28,12 @@ type containerGGUF struct {
 		NumTensor uint64
 		NumTensor uint64
 		NumKV     uint64
 		NumKV     uint64
 	}
 	}
+
+	maxArraySize int
+}
+
+func (c *containerGGUF) canCollectArray(size int) bool {
+	return c.maxArraySize < 0 || size <= c.maxArraySize
 }
 }
 
 
 func (c *containerGGUF) Name() string {
 func (c *containerGGUF) Name() string {
@@ -54,7 +59,6 @@ func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) {
 	}
 	}
 
 
 	model := newGGUF(c)
 	model := newGGUF(c)
-	slog.Debug(fmt.Sprintf("model = %#v", model))
 	if err := model.Decode(rs); err != nil {
 	if err := model.Decode(rs); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -85,6 +89,8 @@ type gguf struct {
 	tensors []*Tensor
 	tensors []*Tensor
 
 
 	parameters uint64
 	parameters uint64
+
+	scratch [16 << 10]byte
 }
 }
 
 
 func newGGUF(container *containerGGUF) *gguf {
 func newGGUF(container *containerGGUF) *gguf {
@@ -181,34 +187,34 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
 	}
 	}
 
 
 	// decode tensors
 	// decode tensors
-	for i := 0; uint64(i) < llm.numTensor(); i++ {
+	for range llm.numTensor() {
 		name, err := readGGUFString(llm, rs)
 		name, err := readGGUFString(llm, rs)
 		if err != nil {
 		if err != nil {
-			return err
+			return fmt.Errorf("failed to read tensor name: %w", err)
 		}
 		}
 
 
 		// dims is the number of dimensions in the tensor
 		// dims is the number of dimensions in the tensor
 		dims, err := readGGUF[uint32](llm, rs)
 		dims, err := readGGUF[uint32](llm, rs)
 		if err != nil {
 		if err != nil {
-			return err
+			return fmt.Errorf("failed to read tensor dimensions: %w", err)
 		}
 		}
 
 
 		shape := [4]uint64{1, 1, 1, 1}
 		shape := [4]uint64{1, 1, 1, 1}
 		for i := 0; uint32(i) < dims; i++ {
 		for i := 0; uint32(i) < dims; i++ {
 			shape[i], err = readGGUF[uint64](llm, rs)
 			shape[i], err = readGGUF[uint64](llm, rs)
 			if err != nil {
 			if err != nil {
-				return err
+				return fmt.Errorf("failed to read tensor shape: %w", err)
 			}
 			}
 		}
 		}
 
 
 		kind, err := readGGUF[uint32](llm, rs)
 		kind, err := readGGUF[uint32](llm, rs)
 		if err != nil {
 		if err != nil {
-			return err
+			return fmt.Errorf("failed to read tensor kind: %w", err)
 		}
 		}
 
 
 		offset, err := readGGUF[uint64](llm, rs)
 		offset, err := readGGUF[uint64](llm, rs)
 		if err != nil {
 		if err != nil {
-			return err
+			return fmt.Errorf("failed to read tensor offset: %w", err)
 		}
 		}
 
 
 		tensor := Tensor{
 		tensor := Tensor{
@@ -230,24 +236,19 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
 		alignment = 32
 		alignment = 32
 	}
 	}
 
 
-	offset, err := rs.Seek(0, io.SeekCurrent)
-	if err != nil {
-		return err
-	}
-
-	padding := llm.padding(offset, int64(alignment))
-	if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
-		return err
-	}
-
 	for _, tensor := range llm.tensors {
 	for _, tensor := range llm.tensors {
-		if _, err := rs.Seek(int64(tensor.Size()), io.SeekCurrent); err != nil {
-			return err
+		offset, err := rs.Seek(0, io.SeekCurrent)
+		if err != nil {
+			return fmt.Errorf("failed to get current offset: %w", err)
 		}
 		}
 
 
-		padding := llm.padding(int64(tensor.Size()), int64(alignment))
+		padding := llm.padding(offset, int64(alignment))
 		if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
 		if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
-			return err
+			return fmt.Errorf("failed to seek to init padding: %w", err)
+		}
+
+		if _, err := rs.Seek(int64(tensor.Size()), io.SeekCurrent); err != nil {
+			return fmt.Errorf("failed to seek to tensor: %w", err)
 		}
 		}
 	}
 	}
 
 
@@ -285,22 +286,48 @@ func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
 	return b.String(), nil
 	return b.String(), nil
 }
 }
 
 
+func discardGGUFString(llm *gguf, r io.Reader) error {
+	buf := llm.scratch[:8]
+	_, err := io.ReadFull(r, buf)
+	if err != nil {
+		return err
+	}
+
+	size := int(llm.ByteOrder.Uint64(buf))
+	for size > 0 {
+		n, err := r.Read(llm.scratch[:min(size, cap(llm.scratch))])
+		if err != nil {
+			return err
+		}
+		size -= n
+	}
+	return nil
+}
+
 func readGGUFString(llm *gguf, r io.Reader) (string, error) {
 func readGGUFString(llm *gguf, r io.Reader) (string, error) {
 	if llm.Version == 1 {
 	if llm.Version == 1 {
 		return readGGUFV1String(llm, r)
 		return readGGUFV1String(llm, r)
 	}
 	}
 
 
-	var length uint64
-	if err := binary.Read(r, llm.ByteOrder, &length); err != nil {
+	buf := llm.scratch[:8]
+	_, err := io.ReadFull(r, buf)
+	if err != nil {
 		return "", err
 		return "", err
 	}
 	}
 
 
-	var b bytes.Buffer
-	if _, err := io.CopyN(&b, r, int64(length)); err != nil {
-		return "", err
+	length := int(llm.ByteOrder.Uint64(buf))
+	if length > len(llm.scratch) {
+		buf = make([]byte, length)
+	} else {
+		buf = llm.scratch[:length]
 	}
 	}
+	clear(buf)
 
 
-	return b.String(), nil
+	_, err = io.ReadFull(r, buf)
+	if err != nil {
+		return "", err
+	}
+	return string(buf), nil
 }
 }
 
 
 func writeGGUFString(llm *gguf, w io.Writer, s string) error {
 func writeGGUFString(llm *gguf, w io.Writer, s string) error {
@@ -316,7 +343,16 @@ func writeGGUFString(llm *gguf, w io.Writer, s string) error {
 	return err
 	return err
 }
 }
 
 
-func readGGUFV1Array(llm *gguf, r io.Reader) (a []any, err error) {
+type array struct {
+	size   int
+	values []any
+}
+
+func (a *array) MarshalJSON() ([]byte, error) {
+	return json.Marshal(a.values)
+}
+
+func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
 	t, err := readGGUF[uint32](llm, r)
 	t, err := readGGUF[uint32](llm, r)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
@@ -327,7 +363,12 @@ func readGGUFV1Array(llm *gguf, r io.Reader) (a []any, err error) {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	for i := 0; uint32(i) < n; i++ {
+	a := &array{size: int(n)}
+	if llm.canCollectArray(int(n)) {
+		a.values = make([]any, int(n))
+	}
+
+	for i := range n {
 		var e any
 		var e any
 		switch t {
 		switch t {
 		case ggufTypeUint8:
 		case ggufTypeUint8:
@@ -361,13 +402,15 @@ func readGGUFV1Array(llm *gguf, r io.Reader) (a []any, err error) {
 			return nil, err
 			return nil, err
 		}
 		}
 
 
-		a = append(a, e)
+		if a.values != nil {
+			a.values[i] = e
+		}
 	}
 	}
 
 
-	return
+	return a, nil
 }
 }
 
 
-func readGGUFArray(llm *gguf, r io.Reader) (a []any, err error) {
+func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
 	if llm.Version == 1 {
 	if llm.Version == 1 {
 		return readGGUFV1Array(llm, r)
 		return readGGUFV1Array(llm, r)
 	}
 	}
@@ -382,7 +425,12 @@ func readGGUFArray(llm *gguf, r io.Reader) (a []any, err error) {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	for i := 0; uint64(i) < n; i++ {
+	a := &array{size: int(n)}
+	if llm.canCollectArray(int(n)) {
+		a.values = make([]any, int(n))
+	}
+
+	for i := range n {
 		var e any
 		var e any
 		switch t {
 		switch t {
 		case ggufTypeUint8:
 		case ggufTypeUint8:
@@ -408,7 +456,11 @@ func readGGUFArray(llm *gguf, r io.Reader) (a []any, err error) {
 		case ggufTypeBool:
 		case ggufTypeBool:
 			e, err = readGGUF[bool](llm, r)
 			e, err = readGGUF[bool](llm, r)
 		case ggufTypeString:
 		case ggufTypeString:
-			e, err = readGGUFString(llm, r)
+			if a.values != nil {
+				e, err = readGGUFString(llm, r)
+			} else {
+				err = discardGGUFString(llm, r)
+			}
 		default:
 		default:
 			return nil, fmt.Errorf("invalid array type: %d", t)
 			return nil, fmt.Errorf("invalid array type: %d", t)
 		}
 		}
@@ -416,10 +468,12 @@ func readGGUFArray(llm *gguf, r io.Reader) (a []any, err error) {
 			return nil, err
 			return nil, err
 		}
 		}
 
 
-		a = append(a, e)
+		if a.values != nil {
+			a.values[i] = e
+		}
 	}
 	}
 
 
-	return
+	return a, nil
 }
 }
 
 
 func writeGGUFArray[S ~[]E, E any](llm *gguf, w io.Writer, t uint32, s S) error {
 func writeGGUFArray[S ~[]E, E any](llm *gguf, w io.Writer, t uint32, s S) error {

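discardGGUFString above skips a length-prefixed string without allocating it, draining the reader through the fixed scratch buffer instead. The same pattern in isolation (a sketch; io.CopyN into io.Discard would be an equivalent alternative):

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
	"log"
)

// skipLengthPrefixed reads a little-endian uint64 length and then discards
// that many bytes through a small, reused buffer.
func skipLengthPrefixed(r io.Reader, scratch []byte) error {
	var length uint64
	if err := binary.Read(r, binary.LittleEndian, &length); err != nil {
		return err
	}
	remaining := int64(length)
	for remaining > 0 {
		n := int64(len(scratch))
		if remaining < n {
			n = remaining
		}
		read, err := io.ReadFull(r, scratch[:n])
		remaining -= int64(read)
		if err != nil {
			return err
		}
	}
	return nil
}

func main() {
	var buf bytes.Buffer
	payload := []byte("tokenizer vocabulary entry")
	binary.Write(&buf, binary.LittleEndian, uint64(len(payload)))
	buf.Write(payload)
	buf.WriteString("NEXT")

	scratch := make([]byte, 8)
	if err := skipLengthPrefixed(&buf, scratch); err != nil {
		log.Fatal(err)
	}
	fmt.Println(buf.String()) // NEXT
}
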
+ 2 - 2
llm/memory.go

@@ -115,8 +115,8 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
 		slog.Warn("model missing blk.0 layer size")
 		slog.Warn("model missing blk.0 layer size")
 	}
 	}
 
 
-	// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
-	var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
+	// fp16 k,v = sizeof(float16) * n_ctx * n_layer * (n_embd_head_k + n_embd_head_v) * n_head_kv
+	var kv uint64 = 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * (ggml.KV().EmbeddingHeadCountK() + ggml.KV().EmbeddingHeadCountV()) * ggml.KV().HeadCountKV()
 
 
 	// KV is proportional to the number of layers
 	// KV is proportional to the number of layers
 	layerSize += kv / ggml.KV().BlockCount()
 	layerSize += kv / ggml.KV().BlockCount()

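To make the revised KV-cache formula concrete, here is the arithmetic for a hypothetical model with an 8192-token context, 32 layers, 8 KV heads and 128-dimensional key/value heads (illustrative numbers, not a measurement):

package main

import "fmt"

func main() {
	const (
		bytesPerElem = 2    // fp16
		nCtx         = 8192 // context length
		nLayer       = 32   // block count
		nEmbdHeadK   = 128  // key head dimension
		nEmbdHeadV   = 128  // value head dimension
		nHeadKV      = 8    // KV heads (GQA)
	)
	kv := uint64(bytesPerElem * nCtx * nLayer * (nEmbdHeadK + nEmbdHeadV) * nHeadKV)
	fmt.Printf("KV cache: %d bytes (%.1f GiB)\n", kv, float64(kv)/(1<<30))
	// KV cache: 1073741824 bytes (1.0 GiB)
}
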
+ 11 - 8
llm/memory_test.go

@@ -22,13 +22,14 @@ func TestEstimateGPULayers(t *testing.T) {
 	defer f.Close()
 	defer f.Close()
 	gguf := NewGGUFV3(binary.LittleEndian)
 	gguf := NewGGUFV3(binary.LittleEndian)
 	inputLayerCount := 5
 	inputLayerCount := 5
+
 	tensors := []Tensor{
 	tensors := []Tensor{
-		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
-		{Name: "blk.1.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
-		{Name: "blk.2.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
-		{Name: "blk.3.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
-		{Name: "blk.4.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
-		{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
+		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
+		{Name: "blk.1.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
+		{Name: "blk.2.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
+		{Name: "blk.3.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
+		{Name: "blk.4.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
+		{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
 	}
 	}
 	assert.Len(t, tensors, inputLayerCount+1)
 	assert.Len(t, tensors, inputLayerCount+1)
 	err = gguf.Encode(f, KV{
 	err = gguf.Encode(f, KV{
@@ -45,8 +46,10 @@ func TestEstimateGPULayers(t *testing.T) {
 	}, tensors)
 	}, tensors)
 	require.NoError(t, err)
 	require.NoError(t, err)
 
 
-	ggml, err := LoadModel(f.Name())
-	require.NoError(t, err)
+	ggml, err := LoadModel(f.Name(), 0)
+	if err != nil {
+		t.Fatal(err)
+	}
 
 
 	// Simple CPU scenario
 	// Simple CPU scenario
 	gpus := []gpu.GpuInfo{
 	gpus := []gpu.GpuInfo{

+ 305 - 0
llm/patches/07-gemma.diff

@@ -0,0 +1,305 @@
+From 5cadb45f39d001ffbad95b690d6cf0abcb4a6d96 Mon Sep 17 00:00:00 2001
+From: Ollama maintainers <hello@ollama.com>
+Date: Wed, 26 Jun 2024 16:18:09 -0700
+Subject: [PATCH] Architecture support
+
+---
+ llama.cpp | 194 +++++++++++++++++++++++++++++++++++++++++++++++++++++-
+ 1 file changed, 193 insertions(+), 1 deletion(-)
+
+diff --git a/llama.cpp b/llama.cpp
+index 61948751..3b4196f5 100644
+--- a/llama.cpp
++++ b/llama.cpp
+@@ -217,6 +217,7 @@ enum llm_arch {
+     LLM_ARCH_INTERNLM2,
+     LLM_ARCH_MINICPM,
+     LLM_ARCH_GEMMA,
++    LLM_ARCH_GEMMA2,
+     LLM_ARCH_STARCODER2,
+     LLM_ARCH_MAMBA,
+     LLM_ARCH_XVERSE,
+@@ -255,6 +256,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
+     { LLM_ARCH_INTERNLM2,       "internlm2"    },
+     { LLM_ARCH_MINICPM,         "minicpm"      },
+     { LLM_ARCH_GEMMA,           "gemma"        },
++    { LLM_ARCH_GEMMA2,          "gemma2"       },
+     { LLM_ARCH_STARCODER2,      "starcoder2"   },
+     { LLM_ARCH_MAMBA,           "mamba"        },
+     { LLM_ARCH_XVERSE,          "xverse"       },
+@@ -464,10 +466,12 @@ enum llm_tensor {
+     LLM_TENSOR_ATTN_NORM,
+     LLM_TENSOR_ATTN_NORM_2,
+     LLM_TENSOR_ATTN_OUT_NORM,
++    LLM_TENSOR_ATTN_POST_NORM,
+     LLM_TENSOR_ATTN_ROT_EMBD,
+     LLM_TENSOR_FFN_GATE_INP,
+     LLM_TENSOR_FFN_GATE_INP_SHEXP,
+     LLM_TENSOR_FFN_NORM,
++    LLM_TENSOR_FFN_POST_NORM,
+     LLM_TENSOR_FFN_GATE,
+     LLM_TENSOR_FFN_DOWN,
+     LLM_TENSOR_FFN_UP,
+@@ -960,6 +964,24 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
+             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+         },
+     },
++    {
++        LLM_ARCH_GEMMA2,
++        {
++            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
++            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
++            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
++            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
++            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
++            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
++            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
++            { LLM_TENSOR_ATTN_POST_NORM,  "blk.%d.post_attention_norm" },
++            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
++            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
++            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
++            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
++            { LLM_TENSOR_FFN_POST_NORM,   "blk.%d.post_ffw_norm" },
++        },
++    },
+     {
+         LLM_ARCH_STARCODER2,
+         {
+@@ -1941,6 +1963,8 @@ enum e_model {
+     MODEL_8x22B,
+     MODEL_16x12B,
+     MODEL_10B_128x3_66B,
++    MODEL_9B,
++    MODEL_27B,
+ };
+ 
+ static const size_t kiB = 1024;
+@@ -2114,6 +2138,7 @@ struct llama_layer {
+     struct ggml_tensor * attn_out_norm_b;
+     struct ggml_tensor * attn_q_a_norm;
+     struct ggml_tensor * attn_kv_a_norm;
++    struct ggml_tensor * attn_post_norm;
+ 
+     // attention
+     struct ggml_tensor * wq;
+@@ -2136,6 +2161,7 @@ struct llama_layer {
+     // normalization
+     struct ggml_tensor * ffn_norm;
+     struct ggml_tensor * ffn_norm_b;
++    struct ggml_tensor * ffn_post_norm;
+     struct ggml_tensor * layer_out_norm;
+     struct ggml_tensor * layer_out_norm_b;
+     struct ggml_tensor * ffn_norm_exps;
+@@ -4529,6 +4555,16 @@ static void llm_load_hparams(
+                 }
+             } break;
+         case LLM_ARCH_GEMMA:
++            {
++                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
++
++                switch (hparams.n_layer) {
++                    case 18: model.type = e_model::MODEL_9B; break;
++                    case 28: model.type = e_model::MODEL_27B; break;
++                    default: model.type = e_model::MODEL_UNKNOWN;
++               }
++            } break;
++        case LLM_ARCH_GEMMA2:
+             {
+                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ 
+@@ -6305,6 +6341,40 @@ static bool llm_load_tensors(
+                         layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                     }
+                 } break;
++            case LLM_ARCH_GEMMA2:
++                {
++                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
++
++                    // output
++                    model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
++                    model.output      = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
++
++                    const int64_t n_ff          = hparams.n_ff;
++                    const int64_t n_embd_head_k = hparams.n_embd_head_k;
++                    const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
++                    const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
++
++                    for (uint32_t i = 0; i < n_layer; ++i) {
++                        ggml_context * ctx_layer = ctx_for_layer(i);
++                        ggml_context * ctx_split = ctx_for_layer_split(i);
++
++                        auto & layer = model.layers[i];
++
++                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
++
++                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * hparams.n_head});
++                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
++                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
++                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * hparams.n_head, n_embd});
++                        layer.attn_post_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd});
++
++                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
++                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
++                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
++                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
++                        layer.ffn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd});
++                    }
++                } break;
+             case LLM_ARCH_STARCODER2:
+                 {
+                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+@@ -10614,6 +10684,123 @@ struct llm_build_context {
+         return gf;
+     }
+ 
++    struct ggml_cgraph * build_gemma2() {
++        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
++
++        const int64_t n_embd_head_k = hparams.n_embd_head_k;
++
++        struct ggml_tensor * cur;
++        struct ggml_tensor * inpL;
++
++        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
++
++        inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
++        cb(inpL, "inp_scaled", -1);
++
++        // inp_pos - contains the positions
++        struct ggml_tensor * inp_pos = build_inp_pos();
++
++        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
++        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
++
++        for (int il = 0; il < n_layer; ++il) {
++            // norm
++            cur = llm_build_norm(ctx0, inpL, hparams,
++                    model.layers[il].attn_norm, NULL,
++                    LLM_NORM_RMS, cb, il);
++            cb(cur, "attn_norm", il);
++
++            // self-attention
++            {
++                // compute Q and K and RoPE them
++                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
++                cb(Qcur, "Qcur", il);
++
++                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
++                cb(Kcur, "Kcur", il);
++
++                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
++                cb(Vcur, "Vcur", il);
++
++                Qcur = ggml_rope_ext(
++                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
++                        n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
++                        ext_factor, attn_factor, beta_fast, beta_slow);
++                cb(Qcur, "Qcur", il);
++
++                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
++                cb(Qcur, "Qcur_scaled", il);
++
++                Kcur = ggml_rope_ext(
++                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
++                        n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
++                        ext_factor, attn_factor, beta_fast, beta_slow);
++                cb(Kcur, "Kcur", il);
++
++                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
++                        model.layers[il].wo, NULL,
++                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
++            }
++
++            if (il == n_layer - 1) {
++                // skip computing output for unused tokens
++                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
++                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
++                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
++            }
++
++            cur = llm_build_norm(ctx0, cur, hparams,
++                    model.layers[il].attn_post_norm, NULL,
++                    LLM_NORM_RMS, cb, il);
++            cb(cur, "attn_post_norm", il);
++
++            struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
++            cb(sa_out, "sa_out", il);
++
++            cur = llm_build_norm(ctx0, sa_out, hparams,
++                    model.layers[il].ffn_norm, NULL,
++                    LLM_NORM_RMS, cb, il);
++            cb(cur, "ffn_norm", il);
++
++            // feed-forward network
++            {
++                cur = llm_build_ffn(ctx0, cur,
++                        model.layers[il].ffn_up, NULL,
++                        model.layers[il].ffn_gate, NULL,
++                        model.layers[il].ffn_down, NULL,
++                        NULL,
++                        LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
++                cb(cur, "ffn_out", il);
++            }
++
++            cur = llm_build_norm(ctx0, cur, hparams,
++                model.layers[il].ffn_post_norm, NULL,
++                LLM_NORM_RMS, cb, -1);
++            cb(cur, "ffn_post_norm", -1);
++
++            cur = ggml_add(ctx0, cur, sa_out);
++            cb(cur, "l_out", il);
++
++            // input for next layer
++            inpL = cur;
++        }
++
++        cur = inpL;
++
++        cur = llm_build_norm(ctx0, cur, hparams,
++                model.output_norm, NULL,
++                LLM_NORM_RMS, cb, -1);
++        cb(cur, "result_norm", -1);
++
++        // lm_head
++        cur = ggml_mul_mat(ctx0, model.output, cur);
++        cb(cur, "result_output", -1);
++
++        ggml_build_forward_expand(gf, cur);
++
++        return gf;
++    }
++
+     struct ggml_cgraph * build_starcoder2() {
+         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+ 
+@@ -11847,6 +12034,10 @@ static struct ggml_cgraph * llama_build_graph(
+             {
+                 result = llm.build_gemma();
+             } break;
++        case LLM_ARCH_GEMMA2:
++            {
++                result = llm.build_gemma2();
++            } break;
+         case LLM_ARCH_STARCODER2:
+             {
+                 result = llm.build_starcoder2();
+@@ -16671,6 +16862,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
+         case LLM_ARCH_PHI2:
+         case LLM_ARCH_PHI3:
+         case LLM_ARCH_GEMMA:
++        case LLM_ARCH_GEMMA2:
+         case LLM_ARCH_STARCODER2:
+         case LLM_ARCH_GPTNEOX:
+             return LLAMA_ROPE_TYPE_NEOX;
+@@ -18551,7 +18743,7 @@ static int32_t llama_chat_apply_template_internal(
+         if (add_ass) {
+             ss << "<s>assistant\n";
+         }
+-    } else if (tmpl == "gemma" || tmpl.find("<start_of_turn>") != std::string::npos) {
++    } else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl.find("<start_of_turn>") != std::string::npos) {
+         // google/gemma-7b-it
+         std::string system_prompt = "";
+         for (auto message : chat) {
+-- 
+2.45.2
+

+ 2 - 2
llm/payload.go

@@ -58,7 +58,7 @@ func availableServers() map[string]string {
 	}
 	}
 
 
 	// glob payloadsDir for files that start with ollama_
 	// glob payloadsDir for files that start with ollama_
-	pattern := filepath.Join(payloadsDir, "*")
+	pattern := filepath.Join(payloadsDir, "*", "ollama_*")
 
 
 	files, err := filepath.Glob(pattern)
 	files, err := filepath.Glob(pattern)
 	if err != nil {
 	if err != nil {
@@ -69,7 +69,7 @@ func availableServers() map[string]string {
 	servers := make(map[string]string)
 	servers := make(map[string]string)
 	for _, file := range files {
 	for _, file := range files {
 		slog.Debug("availableServers : found", "file", file)
 		slog.Debug("availableServers : found", "file", file)
-		servers[filepath.Base(file)] = file
+		servers[filepath.Base(filepath.Dir(file))] = filepath.Dir(file)
 	}
 	}
 
 
 	return servers
 	return servers

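With the new glob, each runner variant is expected to live in its own subdirectory and the map key becomes that directory name rather than the file name. A small sketch of the resulting mapping under an assumed layout (paths and variant names are examples only):

package main

import (
	"fmt"
	"path/filepath"
)

func main() {
	// files as filepath.Glob(filepath.Join(payloadsDir, "*", "ollama_*"))
	// might return them under a hypothetical payloads directory.
	files := []string{
		"/tmp/ollama/payloads/cpu_avx2/ollama_llama_server",
		"/tmp/ollama/payloads/cuda_v11/ollama_llama_server",
	}

	servers := make(map[string]string)
	for _, file := range files {
		// key: the variant directory (e.g. "cuda_v11"), value: its full path
		servers[filepath.Base(filepath.Dir(file))] = filepath.Dir(file)
	}
	fmt.Println(servers)
	// map[cpu_avx2:/tmp/ollama/payloads/cpu_avx2 cuda_v11:/tmp/ollama/payloads/cuda_v11]
}
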
+ 32 - 27
llm/server.go

@@ -61,7 +61,12 @@ type llmServer struct {
 	sem *semaphore.Weighted
 	sem *semaphore.Weighted
 }
 }
 
 
-func LoadModel(model string) (*GGML, error) {
+// LoadModel will load a model from disk. The model must be in the GGML format.
+//
+// It collects array values for arrays with a size less than or equal to
+// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
+// the maxArraySize is negative, all arrays are collected.
+func LoadModel(model string, maxArraySize int) (*GGML, error) {
 	if _, err := os.Stat(model); err != nil {
 	if _, err := os.Stat(model); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -72,17 +77,27 @@ func LoadModel(model string) (*GGML, error) {
 	}
 	}
 	defer f.Close()
 	defer f.Close()
 
 
-	ggml, _, err := DecodeGGML(f)
+	ggml, _, err := DecodeGGML(f, maxArraySize)
 	return ggml, err
 	return ggml, err
 }
 }
 
 
 // NewLlamaServer will run a server for the given GPUs
 // NewLlamaServer will run a server for the given GPUs
 // The gpu list must be a single family.
 // The gpu list must be a single family.
-func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
+func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
 	var err error
 	var err error
 	var cpuRunner string
 	var cpuRunner string
 	var estimate MemoryEstimate
 	var estimate MemoryEstimate
-	var systemMemory uint64
+	var systemTotalMemory uint64
+	var systemFreeMemory uint64
+
+	systemMemInfo, err := gpu.GetCPUMem()
+	if err != nil {
+		slog.Error("failed to lookup system memory", "error", err)
+	} else {
+		systemTotalMemory = systemMemInfo.TotalMemory
+		systemFreeMemory = systemMemInfo.FreeMemory
+		slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory))
+	}
 
 
 	// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
 	// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
 	if opts.NumGPU == 0 {
 	if opts.NumGPU == 0 {
@@ -92,19 +107,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		cpuRunner = serverForCpu()
 		cpuRunner = serverForCpu()
 		estimate = EstimateGPULayers(gpus, ggml, projectors, opts)
 		estimate = EstimateGPULayers(gpus, ggml, projectors, opts)
 	} else {
 	} else {
-		if gpus[0].Library == "metal" {
-			memInfo, err := gpu.GetCPUMem()
-			if err != nil {
-				slog.Error("failed to lookup system memory", "error", err)
-			} else {
-				systemMemory = memInfo.TotalMemory
-				slog.Debug("system memory", "total", format.HumanBytes2(systemMemory))
-			}
-		}
 		estimate = EstimateGPULayers(gpus, ggml, projectors, opts)
 		estimate = EstimateGPULayers(gpus, ggml, projectors, opts)
 
 
 		switch {
 		switch {
-		case gpus[0].Library == "metal" && estimate.VRAMSize > systemMemory:
+		case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory:
 			// disable partial offloading when model is greater than total system memory as this
 			// disable partial offloading when model is greater than total system memory as this
 			// can lead to locking up the system
 			// can lead to locking up the system
 			opts.NumGPU = 0
 			opts.NumGPU = 0
@@ -212,7 +218,12 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 	}
 	}
 
 
 	// Windows CUDA should not use mmap for best performance
 	// Windows CUDA should not use mmap for best performance
-	if (runtime.GOOS == "windows" && gpus[0].Library == "cuda") || opts.UseMMap == api.TriStateFalse {
+	// Linux  with a model larger than free space, mmap leads to thrashing
+	// For CPU loads we want the memory to be allocated, not FS cache
+	if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && opts.UseMMap == api.TriStateUndefined) ||
+		(runtime.GOOS == "linux" && systemFreeMemory < estimate.TotalSize && opts.UseMMap == api.TriStateUndefined) ||
+		(gpus[0].Library == "cpu" && opts.UseMMap == api.TriStateUndefined) ||
+		opts.UseMMap == api.TriStateFalse {
 		params = append(params, "--no-mmap")
 		params = append(params, "--no-mmap")
 	}
 	}
 
 
@@ -224,15 +235,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		params = append(params, "--numa")
 		params = append(params, "--numa")
 	}
 	}
 
 
-	numParallel := envconfig.NumParallel
-
-	// TODO (jmorganca): multimodal models don't support parallel yet
-	// see https://github.com/ollama/ollama/issues/4165
-	if len(projectors) > 0 {
-		numParallel = 1
-		slog.Warn("multimodal models don't support parallel requests yet")
-	}
-
 	params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
 	params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
 
 
 	if estimate.TensorSplit != "" {
 	if estimate.TensorSplit != "" {
@@ -275,8 +277,8 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		if runtime.GOOS == "windows" {
 		if runtime.GOOS == "windows" {
 			pathEnv = "PATH"
 			pathEnv = "PATH"
 		}
 		}
-		// prepend the server directory to LD_LIBRARY_PATH/PATH
-		libraryPaths := []string{dir}
+		// prepend the server directory to LD_LIBRARY_PATH/PATH and the parent dir for common dependencies
+		libraryPaths := []string{dir, filepath.Dir(dir)}
 
 
 		if libraryPath, ok := os.LookupEnv(pathEnv); ok {
 		if libraryPath, ok := os.LookupEnv(pathEnv); ok {
 			// Append our runner directory to the path
 			// Append our runner directory to the path
@@ -409,7 +411,7 @@ func projectorMemoryRequirements(filename string) uint64 {
 	}
 	}
 	defer file.Close()
 	defer file.Close()
 
 
-	ggml, _, err := DecodeGGML(file)
+	ggml, _, err := DecodeGGML(file, 0)
 	if err != nil {
 	if err != nil {
 		return 0
 		return 0
 	}
 	}
@@ -559,6 +561,9 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 			if s.status != nil && s.status.LastErrMsg != "" {
 			if s.status != nil && s.status.LastErrMsg != "" {
 				msg = s.status.LastErrMsg
 				msg = s.status.LastErrMsg
 			}
 			}
+			if strings.Contains(msg, "unknown model") {
+				return fmt.Errorf("this model is not supported by your version of Ollama. You may need to upgrade")
+			}
 			return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
 			return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
 		default:
 		default:
 		}
 		}

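The --no-mmap decision now folds in three independent heuristics on top of an explicit user override. A condensed Go sketch of the predicate with simplified types (the tri-state type here is illustrative, standing in for api.TriState):

package main

import "fmt"

type triState int

const (
	triUndefined triState = iota
	triFalse
	triTrue
)

// disableMMap mirrors the shape of the check above: the heuristics only
// apply when the user has not set use_mmap explicitly, while an explicit
// "false" always disables mmap.
func disableMMap(goos, gpuLibrary string, freeMem, modelSize uint64, useMMap triState) bool {
	return (goos == "windows" && gpuLibrary == "cuda" && useMMap == triUndefined) ||
		(goos == "linux" && freeMem < modelSize && useMMap == triUndefined) ||
		(gpuLibrary == "cpu" && useMMap == triUndefined) ||
		useMMap == triFalse
}

func main() {
	// Linux CPU-only host with less free RAM than the model needs:
	fmt.Println(disableMMap("linux", "cpu", 8<<30, 12<<30, triUndefined)) // true
	// Same host, but the user forced mmap on:
	fmt.Println(disableMMap("linux", "cpu", 8<<30, 12<<30, triTrue)) // false
}
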
+ 1 - 0
llm/status.go

@@ -25,6 +25,7 @@ var errorPrefixes = []string{
 	"CUDA error",
 	"CUDA error",
 	"cudaMalloc failed",
 	"cudaMalloc failed",
 	"\"ERR\"",
 	"\"ERR\"",
+	"architecture",
 }
 }
 
 
 func (w *StatusWriter) Write(b []byte) (int, error) {
 func (w *StatusWriter) Write(b []byte) (int, error) {

+ 371 - 13
openai/openai.go

@@ -12,6 +12,7 @@ import (
 
 
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/types/model"
 )
 )
 
 
 type Error struct {
 type Error struct {
@@ -42,6 +43,12 @@ type ChunkChoice struct {
 	FinishReason *string `json:"finish_reason"`
 	FinishReason *string `json:"finish_reason"`
 }
 }
 
 
+type CompleteChunkChoice struct {
+	Text         string  `json:"text"`
+	Index        int     `json:"index"`
+	FinishReason *string `json:"finish_reason"`
+}
+
 type Usage struct {
 type Usage struct {
 	PromptTokens     int `json:"prompt_tokens"`
 	PromptTokens     int `json:"prompt_tokens"`
 	CompletionTokens int `json:"completion_tokens"`
 	CompletionTokens int `json:"completion_tokens"`
@@ -85,6 +92,51 @@ type ChatCompletionChunk struct {
 	Choices           []ChunkChoice `json:"choices"`
 	Choices           []ChunkChoice `json:"choices"`
 }
 }
 
 
+// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
+type CompletionRequest struct {
+	Model            string   `json:"model"`
+	Prompt           string   `json:"prompt"`
+	FrequencyPenalty float32  `json:"frequency_penalty"`
+	MaxTokens        *int     `json:"max_tokens"`
+	PresencePenalty  float32  `json:"presence_penalty"`
+	Seed             *int     `json:"seed"`
+	Stop             any      `json:"stop"`
+	Stream           bool     `json:"stream"`
+	Temperature      *float32 `json:"temperature"`
+	TopP             float32  `json:"top_p"`
+}
+
+type Completion struct {
+	Id                string                `json:"id"`
+	Object            string                `json:"object"`
+	Created           int64                 `json:"created"`
+	Model             string                `json:"model"`
+	SystemFingerprint string                `json:"system_fingerprint"`
+	Choices           []CompleteChunkChoice `json:"choices"`
+	Usage             Usage                 `json:"usage,omitempty"`
+}
+
+type CompletionChunk struct {
+	Id                string                `json:"id"`
+	Object            string                `json:"object"`
+	Created           int64                 `json:"created"`
+	Choices           []CompleteChunkChoice `json:"choices"`
+	Model             string                `json:"model"`
+	SystemFingerprint string                `json:"system_fingerprint"`
+}
+
+type Model struct {
+	Id      string `json:"id"`
+	Object  string `json:"object"`
+	Created int64  `json:"created"`
+	OwnedBy string `json:"owned_by"`
+}
+
+type ListCompletion struct {
+	Object string  `json:"object"`
+	Data   []Model `json:"data"`
+}
+
 func NewError(code int, message string) ErrorResponse {
 func NewError(code int, message string) ErrorResponse {
 	var etype string
 	var etype string
 	switch code {
 	switch code {
@@ -145,7 +197,79 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
 	}
 	}
 }
 }
 
 
-func fromRequest(r ChatCompletionRequest) api.ChatRequest {
+func toCompletion(id string, r api.GenerateResponse) Completion {
+	return Completion{
+		Id:                id,
+		Object:            "text_completion",
+		Created:           r.CreatedAt.Unix(),
+		Model:             r.Model,
+		SystemFingerprint: "fp_ollama",
+		Choices: []CompleteChunkChoice{{
+			Text:  r.Response,
+			Index: 0,
+			FinishReason: func(reason string) *string {
+				if len(reason) > 0 {
+					return &reason
+				}
+				return nil
+			}(r.DoneReason),
+		}},
+		Usage: Usage{
+			// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
+			PromptTokens:     r.PromptEvalCount,
+			CompletionTokens: r.EvalCount,
+			TotalTokens:      r.PromptEvalCount + r.EvalCount,
+		},
+	}
+}
+
+func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
+	return CompletionChunk{
+		Id:                id,
+		Object:            "text_completion",
+		Created:           time.Now().Unix(),
+		Model:             r.Model,
+		SystemFingerprint: "fp_ollama",
+		Choices: []CompleteChunkChoice{{
+			Text:  r.Response,
+			Index: 0,
+			FinishReason: func(reason string) *string {
+				if len(reason) > 0 {
+					return &reason
+				}
+				return nil
+			}(r.DoneReason),
+		}},
+	}
+}
+
+func toListCompletion(r api.ListResponse) ListCompletion {
+	var data []Model
+	for _, m := range r.Models {
+		data = append(data, Model{
+			Id:      m.Name,
+			Object:  "model",
+			Created: m.ModifiedAt.Unix(),
+			OwnedBy: model.ParseName(m.Name).Namespace,
+		})
+	}
+
+	return ListCompletion{
+		Object: "list",
+		Data:   data,
+	}
+}
+
+func toModel(r api.ShowResponse, m string) Model {
+	return Model{
+		Id:      m,
+		Object:  "model",
+		Created: r.ModifiedAt.Unix(),
+		OwnedBy: model.ParseName(m).Namespace,
+	}
+}
+
+func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
 	var messages []api.Message
 	var messages []api.Message
 	for _, msg := range r.Messages {
 	for _, msg := range r.Messages {
 		messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
 		messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
@@ -156,7 +280,7 @@ func fromRequest(r ChatCompletionRequest) api.ChatRequest {
 	switch stop := r.Stop.(type) {
 	switch stop := r.Stop.(type) {
 	case string:
 	case string:
 		options["stop"] = []string{stop}
 		options["stop"] = []string{stop}
-	case []interface{}:
+	case []any:
 		var stops []string
 		var stops []string
 		for _, s := range stop {
 		for _, s := range stop {
 			if str, ok := s.(string); ok {
 			if str, ok := s.(string); ok {
@@ -208,13 +332,78 @@ func fromRequest(r ChatCompletionRequest) api.ChatRequest {
 	}
 	}
 }
 }
 
 
-type writer struct {
+func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
+	options := make(map[string]any)
+
+	switch stop := r.Stop.(type) {
+	case string:
+		options["stop"] = []string{stop}
+	case []string:
+		options["stop"] = stop
+	default:
+		if r.Stop != nil {
+			return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", r.Stop)
+		}
+	}
+
+	if r.MaxTokens != nil {
+		options["num_predict"] = *r.MaxTokens
+	}
+
+	if r.Temperature != nil {
+		options["temperature"] = *r.Temperature * 2.0
+	} else {
+		options["temperature"] = 1.0
+	}
+
+	if r.Seed != nil {
+		options["seed"] = *r.Seed
+	}
+
+	options["frequency_penalty"] = r.FrequencyPenalty * 2.0
+
+	options["presence_penalty"] = r.PresencePenalty * 2.0
+
+	if r.TopP != 0.0 {
+		options["top_p"] = r.TopP
+	} else {
+		options["top_p"] = 1.0
+	}
+
+	return api.GenerateRequest{
+		Model:   r.Model,
+		Prompt:  r.Prompt,
+		Options: options,
+		Stream:  &r.Stream,
+	}, nil
+}
+
+type BaseWriter struct {
+	gin.ResponseWriter
+}
+
+type ChatWriter struct {
 	stream bool
 	id     string
-	gin.ResponseWriter
+	BaseWriter
 }

-func (w *writer) writeError(code int, data []byte) (int, error) {
+type CompleteWriter struct {
+	stream bool
+	id     string
+	BaseWriter
+}
+
+type ListWriter struct {
+	BaseWriter
+}
+
+type RetrieveWriter struct {
+	BaseWriter
+	model string
+}
+
+func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
 	var serr api.StatusError
 	err := json.Unmarshal(data, &serr)
 	if err != nil {
@@ -230,7 +419,7 @@ func (w *writer) writeError(code int, data []byte) (int, error) {
 	return len(data), nil
 }

-func (w *writer) writeResponse(data []byte) (int, error) {
+func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 	var chatResponse api.ChatResponse
 	err := json.Unmarshal(data, &chatResponse)
 	if err != nil {
@@ -270,7 +459,107 @@ func (w *writer) writeResponse(data []byte) (int, error) {
 	return len(data), nil
 }

-func (w *writer) Write(data []byte) (int, error) {
+func (w *ChatWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
+func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
+	var generateResponse api.GenerateResponse
+	err := json.Unmarshal(data, &generateResponse)
+	if err != nil {
+		return 0, err
+	}
+
+	// completion chunk
+	if w.stream {
+		d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
+		if err != nil {
+			return 0, err
+		}
+
+		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
+		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+		if err != nil {
+			return 0, err
+		}
+
+		if generateResponse.Done {
+			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
+			if err != nil {
+				return 0, err
+			}
+		}
+
+		return len(data), nil
+	}
+
+	// completion
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *CompleteWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
+func (w *ListWriter) writeResponse(data []byte) (int, error) {
+	var listResponse api.ListResponse
+	err := json.Unmarshal(data, &listResponse)
+	if err != nil {
+		return 0, err
+	}
+
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *ListWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
+func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
+	var showResponse api.ShowResponse
+	err := json.Unmarshal(data, &showResponse)
+	if err != nil {
+		return 0, err
+	}
+
+	// retrieve completion
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *RetrieveWriter) Write(data []byte) (int, error) {
 	code := w.ResponseWriter.Status()
 	if code != http.StatusOK {
 		return w.writeError(code, data)
@@ -279,7 +568,76 @@ func (w *writer) Write(data []byte) (int, error) {
 	return w.writeResponse(data)
 }

-func Middleware() gin.HandlerFunc {
+func ListMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		w := &ListWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+		}
+
+		c.Writer = w
+
+		c.Next()
+	}
+}
+
+func RetrieveMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var b bytes.Buffer
+		if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
+		c.Request.Body = io.NopCloser(&b)
+
+		// response writer
+		w := &RetrieveWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+			model:      c.Param("model"),
+		}
+
+		c.Writer = w
+
+		c.Next()
+	}
+}
+
+func CompletionsMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var req CompletionRequest
+		err := c.ShouldBindJSON(&req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		var b bytes.Buffer
+		genReq, err := fromCompleteRequest(req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		if err := json.NewEncoder(&b).Encode(genReq); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
+		c.Request.Body = io.NopCloser(&b)
+
+		w := &CompleteWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+			stream:     req.Stream,
+			id:         fmt.Sprintf("cmpl-%d", rand.Intn(999)),
+		}
+
+		c.Writer = w
+
+		c.Next()
+	}
+}
+
+func ChatMiddleware() gin.HandlerFunc {
 	return func(c *gin.Context) {
 		var req ChatCompletionRequest
 		err := c.ShouldBindJSON(&req)
@@ -294,17 +652,17 @@ func Middleware() gin.HandlerFunc {
 		}
 
 		var b bytes.Buffer
-		if err := json.NewEncoder(&b).Encode(fromRequest(req)); err != nil {
+		if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
 			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
 			return
 		}
 
 		c.Request.Body = io.NopCloser(&b)
 
-		w := &writer{
-			ResponseWriter: c.Writer,
-			stream:         req.Stream,
-			id:             fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
+		w := &ChatWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+			stream:     req.Stream,
+			id:         fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
 		}
 
 		c.Writer = w

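Note: together with the route changes further down in server/routes.go, CompletionsMiddleware translates an OpenAI-style completions call into an api.GenerateRequest via fromCompleteRequest. A minimal client sketch, assuming a local server on the default port 11434 and the standard OpenAI JSON field names (the exact struct tags live in openai.go outside this hunk):

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Assumed field names: they follow the usual OpenAI completions schema.
	body, _ := json.Marshal(map[string]any{
		"model":       "llama3",
		"prompt":      "Why is the sky blue?",
		"temperature": 0.8,
		"stream":      false,
	})

	resp, err := http.Post("http://localhost:11434/v1/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var completion struct {
		Object  string `json:"object"`
		Choices []struct {
			Text string `json:"text"`
		} `json:"choices"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&completion); err != nil {
		panic(err)
	}

	// For a non-streaming request there is a single choice.
	fmt.Println(completion.Object, completion.Choices[0].Text)
}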
+ 298 - 0
openai/openai_test.go

@@ -0,0 +1,298 @@
+package openai
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/ollama/ollama/api"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestMiddleware(t *testing.T) {
+	type testCase struct {
+		Name     string
+		Method   string
+		Path     string
+		TestPath string
+		Handler  func() gin.HandlerFunc
+		Endpoint func(c *gin.Context)
+		Setup    func(t *testing.T, req *http.Request)
+		Expected func(t *testing.T, resp *httptest.ResponseRecorder)
+	}
+
+	testCases := []testCase{
+		{
+			Name:     "chat handler",
+			Method:   http.MethodPost,
+			Path:     "/api/chat",
+			TestPath: "/api/chat",
+			Handler:  ChatMiddleware,
+			Endpoint: func(c *gin.Context) {
+				var chatReq api.ChatRequest
+				if err := c.ShouldBindJSON(&chatReq); err != nil {
+					c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
+					return
+				}
+
+				userMessage := chatReq.Messages[0].Content
+				var assistantMessage string
+
+				switch userMessage {
+				case "Hello":
+					assistantMessage = "Hello!"
+				default:
+					assistantMessage = "I'm not sure how to respond to that."
+				}
+
+				c.JSON(http.StatusOK, api.ChatResponse{
+					Message: api.Message{
+						Role:    "assistant",
+						Content: assistantMessage,
+					},
+				})
+			},
+			Setup: func(t *testing.T, req *http.Request) {
+				body := ChatCompletionRequest{
+					Model:    "test-model",
+					Messages: []Message{{Role: "user", Content: "Hello"}},
+				}
+
+				bodyBytes, _ := json.Marshal(body)
+
+				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+				req.Header.Set("Content-Type", "application/json")
+			},
+			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
+				assert.Equal(t, http.StatusOK, resp.Code)
+
+				var chatResp ChatCompletion
+				if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
+					t.Fatal(err)
+				}
+
+				if chatResp.Object != "chat.completion" {
+					t.Fatalf("expected chat.completion, got %s", chatResp.Object)
+				}
+
+				if chatResp.Choices[0].Message.Content != "Hello!" {
+					t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content)
+				}
+			},
+		},
+		{
+			Name:     "completions handler",
+			Method:   http.MethodPost,
+			Path:     "/api/generate",
+			TestPath: "/api/generate",
+			Handler:  CompletionsMiddleware,
+			Endpoint: func(c *gin.Context) {
+				c.JSON(http.StatusOK, api.GenerateResponse{
+					Response: "Hello!",
+				})
+			},
+			Setup: func(t *testing.T, req *http.Request) {
+				body := CompletionRequest{
+					Model:  "test-model",
+					Prompt: "Hello",
+				}
+
+				bodyBytes, _ := json.Marshal(body)
+
+				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+				req.Header.Set("Content-Type", "application/json")
+			},
+			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
+				assert.Equal(t, http.StatusOK, resp.Code)
+				var completionResp Completion
+				if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
+					t.Fatal(err)
+				}
+
+				if completionResp.Object != "text_completion" {
+					t.Fatalf("expected text_completion, got %s", completionResp.Object)
+				}
+
+				if completionResp.Choices[0].Text != "Hello!" {
+					t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text)
+				}
+			},
+		},
+		{
+			Name:     "completions handler with params",
+			Method:   http.MethodPost,
+			Path:     "/api/generate",
+			TestPath: "/api/generate",
+			Handler:  CompletionsMiddleware,
+			Endpoint: func(c *gin.Context) {
+				var generateReq api.GenerateRequest
+				if err := c.ShouldBindJSON(&generateReq); err != nil {
+					c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
+					return
+				}
+
+				temperature := generateReq.Options["temperature"].(float64)
+				var assistantMessage string
+
+				switch temperature {
+				case 1.6:
+					assistantMessage = "Received temperature of 1.6"
+				default:
+					assistantMessage = fmt.Sprintf("Received temperature of %f", temperature)
+				}
+
+				c.JSON(http.StatusOK, api.GenerateResponse{
+					Response: assistantMessage,
+				})
+			},
+			Setup: func(t *testing.T, req *http.Request) {
+				temp := float32(0.8)
+				body := CompletionRequest{
+					Model:       "test-model",
+					Prompt:      "Hello",
+					Temperature: &temp,
+				}
+
+				bodyBytes, _ := json.Marshal(body)
+
+				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+				req.Header.Set("Content-Type", "application/json")
+			},
+			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
+				assert.Equal(t, http.StatusOK, resp.Code)
+				var completionResp Completion
+				if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
+					t.Fatal(err)
+				}
+
+				if completionResp.Object != "text_completion" {
+					t.Fatalf("expected text_completion, got %s", completionResp.Object)
+				}
+
+				if completionResp.Choices[0].Text != "Received temperature of 1.6" {
+					t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text)
+				}
+			},
+		},
+		{
+			Name:     "completions handler with error",
+			Method:   http.MethodPost,
+			Path:     "/api/generate",
+			TestPath: "/api/generate",
+			Handler:  CompletionsMiddleware,
+			Endpoint: func(c *gin.Context) {
+				c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
+			},
+			Setup: func(t *testing.T, req *http.Request) {
+				body := CompletionRequest{
+					Model:  "test-model",
+					Prompt: "Hello",
+				}
+
+				bodyBytes, _ := json.Marshal(body)
+
+				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+				req.Header.Set("Content-Type", "application/json")
+			},
+			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusBadRequest {
+					t.Fatalf("expected 400, got %d", resp.Code)
+				}
+
+				if !strings.Contains(resp.Body.String(), `"invalid request"`) {
+					t.Fatalf("error was not forwarded")
+				}
+			},
+		},
+		{
+			Name:     "list handler",
+			Method:   http.MethodGet,
+			Path:     "/api/tags",
+			TestPath: "/api/tags",
+			Handler:  ListMiddleware,
+			Endpoint: func(c *gin.Context) {
+				c.JSON(http.StatusOK, api.ListResponse{
+					Models: []api.ListModelResponse{
+						{
+							Name: "Test Model",
+						},
+					},
+				})
+			},
+			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
+				assert.Equal(t, http.StatusOK, resp.Code)
+
+				var listResp ListCompletion
+				if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
+					t.Fatal(err)
+				}
+
+				if listResp.Object != "list" {
+					t.Fatalf("expected list, got %s", listResp.Object)
+				}
+
+				if len(listResp.Data) != 1 {
+					t.Fatalf("expected 1, got %d", len(listResp.Data))
+				}
+
+				if listResp.Data[0].Id != "Test Model" {
+					t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
+				}
+			},
+		},
+		{
+			Name:     "retrieve model",
+			Method:   http.MethodGet,
+			Path:     "/api/show/:model",
+			TestPath: "/api/show/test-model",
+			Handler:  RetrieveMiddleware,
+			Endpoint: func(c *gin.Context) {
+				c.JSON(http.StatusOK, api.ShowResponse{
+					ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
+				})
+			},
+			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
+				var retrieveResp Model
+				if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
+					t.Fatal(err)
+				}
+
+				if retrieveResp.Object != "model" {
+					t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
+				}
+
+				if retrieveResp.Id != "test-model" {
+					t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
+				}
+			},
+		},
+	}
+
+	gin.SetMode(gin.TestMode)
+	router := gin.New()
+
+	for _, tc := range testCases {
+		t.Run(tc.Name, func(t *testing.T) {
+			router = gin.New()
+			router.Use(tc.Handler())
+			router.Handle(tc.Method, tc.Path, tc.Endpoint)
+			req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
+
+			if tc.Setup != nil {
+				tc.Setup(t, req)
+			}
+
+			resp := httptest.NewRecorder()
+			router.ServeHTTP(resp, req)
+
+			tc.Expected(t, resp)
+		})
+	}
+}

+ 2 - 2
parser/parser.go

@@ -124,7 +124,7 @@ func ParseFile(r io.Reader) (*File, error) {
 			case stateComment, stateNil:
 				// pass
 			case stateValue:
-				s, ok := unquote(b.String())
+				s, ok := unquote(strings.TrimSpace(b.String()))
 				if !ok || isSpace(r) {
 					if _, err := b.WriteRune(r); err != nil {
 						return nil, err
@@ -158,7 +158,7 @@ func ParseFile(r io.Reader) (*File, error) {
 	case stateComment, stateNil:
 		// pass; nothing to flush
 	case stateValue:
-		s, ok := unquote(b.String())
+		s, ok := unquote(strings.TrimSpace(b.String()))
 		if !ok {
 			return nil, io.ErrUnexpectedEOF
 		}

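Note: the TrimSpace calls above change how unquoted values are handled — trailing whitespace on a line is now dropped before unquoting, while whitespace inside quotes survives (see the updated "stop ### User: " case and the new TestParseFileTrimSpace below). A rough sketch of the resulting behaviour, assuming only the exported ParseFile and Command shapes shown in this diff:

package main

import (
	"fmt"
	"strings"

	"github.com/ollama/ollama/parser"
)

func main() {
	// Trailing spaces after an unquoted value are trimmed; padding inside
	// quotes is kept as-is.
	modelfile := "FROM foo   \nPARAMETER stop \"### User: \"\n"

	f, err := parser.ParseFile(strings.NewReader(modelfile))
	if err != nil {
		panic(err)
	}

	for _, cmd := range f.Commands {
		fmt.Printf("%s=%q\n", cmd.Name, cmd.Args)
	}
	// per the updated tests: model="foo", then stop="### User: "
}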
+ 67 - 3
parser/parser_test.go

@@ -22,7 +22,13 @@ ADAPTER adapter1
 LICENSE MIT
 PARAMETER param1 value1
 PARAMETER param2 value2
-TEMPLATE template1
+TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
+
+{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
+
+{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
+
+{{ .Response }}<|eot_id|>"""    
 `
 
 	reader := strings.NewReader(input)
@@ -36,7 +42,40 @@ TEMPLATE template1
 		{Name: "license", Args: "MIT"},
 		{Name: "license", Args: "MIT"},
 		{Name: "param1", Args: "value1"},
 		{Name: "param1", Args: "value1"},
 		{Name: "param2", Args: "value2"},
 		{Name: "param2", Args: "value2"},
-		{Name: "template", Args: "template1"},
+		{Name: "template", Args: "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>"},
+	}
+
+	assert.Equal(t, expectedCommands, modelfile.Commands)
+}
+
+func TestParseFileTrimSpace(t *testing.T) {
+	input := `
+FROM "     model 1"
+ADAPTER      adapter3
+LICENSE "MIT       "
+PARAMETER param1        value1
+PARAMETER param2    value2
+TEMPLATE """   {{ if .System }}<|start_header_id|>system<|end_header_id|>
+
+{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
+
+{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
+
+{{ .Response }}<|eot_id|>   """    
+`
+
+	reader := strings.NewReader(input)
+
+	modelfile, err := ParseFile(reader)
+	require.NoError(t, err)
+
+	expectedCommands := []Command{
+		{Name: "model", Args: "     model 1"},
+		{Name: "adapter", Args: "adapter3"},
+		{Name: "license", Args: "MIT       "},
+		{Name: "param1", Args: "value1"},
+		{Name: "param2", Args: "value2"},
+		{Name: "template", Args: "   {{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>   "},
 	}
 
 	assert.Equal(t, expectedCommands, modelfile.Commands)
@@ -48,6 +87,26 @@ func TestParseFileFrom(t *testing.T) {
 		expected []Command
 		err      error
 	}{
+		{
+			"FROM \"FOO  BAR  \"",
+			[]Command{{Name: "model", Args: "FOO  BAR  "}},
+			nil,
+		},
+		{
+			"FROM \"FOO BAR\"\nPARAMETER param1 value1",
+			[]Command{{Name: "model", Args: "FOO BAR"}, {Name: "param1", Args: "value1"}},
+			nil,
+		},
+		{
+			"FROM     FOOO BAR    ",
+			[]Command{{Name: "model", Args: "FOOO BAR"}},
+			nil,
+		},
+		{
+			"FROM /what/is/the path ",
+			[]Command{{Name: "model", Args: "/what/is/the path"}},
+			nil,
+		},
 		{
 			"FROM foo",
 			[]Command{{Name: "model", Args: "foo"}},
@@ -86,6 +145,11 @@ func TestParseFileFrom(t *testing.T) {
 			[]Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}},
 			nil,
 		},
+		{
+			"PARAMETER what the \nFROM lemons make lemonade ",
+			[]Command{{Name: "what", Args: "the"}, {Name: "model", Args: "lemons make lemonade"}},
+			nil,
+		},
 	}
 
 	for _, c := range cases {
@@ -399,7 +463,7 @@ func TestParseFileParameters(t *testing.T) {
 		"mirostat_eta 1.0":             {"mirostat_eta", "1.0"},
 		"penalize_newline true":        {"penalize_newline", "true"},
 		"stop ### User:":               {"stop", "### User:"},
-		"stop ### User: ":              {"stop", "### User: "},
+		"stop ### User: ":              {"stop", "### User:"},
 		"stop \"### User:\"":           {"stop", "### User:"},
 		"stop \"### User: \"":          {"stop", "### User: "},
 		"stop \"\"\"### User:\"\"\"":   {"stop", "### User:"},

+ 5 - 5
scripts/build_windows.ps1

@@ -103,19 +103,19 @@ function buildApp() {
 function gatherDependencies() {
     write-host "Gathering runtime dependencies"
     cd "${script:SRC_DIR}"
-    md "${script:DEPS_DIR}" -ea 0 > $null
+    md "${script:DEPS_DIR}\ollama_runners" -ea 0 > $null
 
     # TODO - this varies based on host build system and MSVC version - drive from dumpbin output
     # currently works for Win11 + MSVC 2019 + Cuda V11
-    cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140.dll" "${script:DEPS_DIR}\"
-    cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\"
-    cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\"
+    cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140.dll" "${script:DEPS_DIR}\ollama_runners\"
+    cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\ollama_runners\"
+    cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\ollama_runners\"
 
 
     cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\"
     if ("${env:KEY_CONTAINER}") {
         write-host "about to sign"
-        foreach ($file in (get-childitem "${script:DEPS_DIR}/cu*.dll") + @("${script:SRC_DIR}\dist\ollama_welcome.ps1")){
+        foreach ($file in (get-childitem "${script:DEPS_DIR}\cuda\cu*.dll") + @("${script:SRC_DIR}\dist\ollama_welcome.ps1")){
             write-host "signing $file"
             & "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
                 /csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} $file

+ 1 - 1
scripts/install.sh

@@ -279,7 +279,7 @@ if ! check_gpu nvidia-smi || [ -z "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\
     case $OS_NAME in
         centos|rhel) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -d '.' -f 1) ;;
         rocky) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -c1) ;;
-        fedora) [ $OS_VERSION -lt '37' ] && install_cuda_driver_yum $OS_NAME $OS_VERSION || install_cuda_driver_yum $OS_NAME '37';;
+        fedora) [ $OS_VERSION -lt '39' ] && install_cuda_driver_yum $OS_NAME $OS_VERSION || install_cuda_driver_yum $OS_NAME '39';;
         amzn) install_cuda_driver_yum 'fedora' '37' ;;
         debian) install_cuda_driver_apt $OS_NAME $OS_VERSION ;;
         ubuntu) install_cuda_driver_apt $OS_NAME $(echo $OS_VERSION | sed 's/\.//') ;;

+ 11 - 0
scripts/rh_linux_deps.sh

@@ -6,10 +6,21 @@ set -ex
 MACHINE=$(uname -m)
 
 if grep -i "centos" /etc/system-release >/dev/null; then
+    # As of 7/1/2024 mirrorlist.centos.org has been taken offline, so adjust accordingly
+    sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo
+    sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo
+    sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo
+
     # Centos 7 derivatives have too old of a git version to run our generate script
     # uninstall and ignore failures
     yum remove -y git
     yum -y install epel-release centos-release-scl
+
+    # The release packages reinstate the mirrors, undo that again
+    sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo
+    sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo
+    sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo
+
     yum -y install dnf
     if [ "${MACHINE}" = "x86_64" ]; then
         yum -y install https://repo.ius.io/ius-release-el7.rpm

+ 58 - 32
server/images.go

@@ -28,11 +28,16 @@ import (
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/parser"
+	"github.com/ollama/ollama/template"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
 	"github.com/ollama/ollama/version"
 )
 )
 
 
+type Capability string
+
+const CapabilityCompletion = Capability("completion")
+
 type registryOptions struct {
 type registryOptions struct {
 	Insecure bool
 	Insecure bool
 	Username string
 	Username string
@@ -48,16 +53,43 @@ type Model struct {
 	ParentModel    string
 	AdapterPaths   []string
 	ProjectorPaths []string
-	Template       string
 	System         string
 	License        []string
 	Digest         string
 	Options        map[string]interface{}
 	Messages       []Message
+
+	Template *template.Template
 }
 
-func (m *Model) IsEmbedding() bool {
-	return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert")
+func (m *Model) Has(caps ...Capability) bool {
+	for _, cap := range caps {
+		switch cap {
+		case CapabilityCompletion:
+			f, err := os.Open(m.ModelPath)
+			if err != nil {
+				slog.Error("couldn't open model file", "error", err)
+				continue
+			}
+			defer f.Close()
+
+			// TODO(mxyng): decode the GGML into model to avoid doing this multiple times
+			ggml, _, err := llm.DecodeGGML(f, 0)
+			if err != nil {
+				slog.Error("couldn't decode ggml", "error", err)
+				continue
+			}
+
+			if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
+				return false
+			}
+		default:
+			slog.Error("unknown capability", "capability", cap)
+			return false
+		}
+	}
+
+	return true
 }
 
 func (m *Model) String() string {
@@ -82,10 +114,10 @@ func (m *Model) String() string {
 		})
 	}
 
-	if m.Template != "" {
+	if m.Template != nil {
 		modelfile.Commands = append(modelfile.Commands, parser.Command{
 			Name: "template",
-			Args: m.Template,
+			Args: m.Template.String(),
 		})
 	}
 
@@ -135,13 +167,6 @@ type Message struct {
 	Content string `json:"content"`
 }
 
-type ManifestV2 struct {
-	SchemaVersion int      `json:"schemaVersion"`
-	MediaType     string   `json:"mediaType"`
-	Config        *Layer   `json:"config"`
-	Layers        []*Layer `json:"layers"`
-}
-
 type ConfigV2 struct {
 	ModelFormat   string   `json:"model_format"`
 	ModelFamily   string   `json:"model_family"`
@@ -160,7 +185,7 @@ type RootFS struct {
 	DiffIDs []string `json:"diff_ids"`
 }
 
-func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
+func GetManifest(mp ModelPath) (*Manifest, string, error) {
 	fp, err := mp.GetManifestPath()
 	if err != nil {
 		return nil, "", err
@@ -170,7 +195,7 @@ func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
 		return nil, "", err
 	}
 
-	var manifest *ManifestV2
+	var manifest *Manifest
 
 	bts, err := os.ReadFile(fp)
 	if err != nil {
@@ -198,8 +223,7 @@ func GetModel(name string) (*Model, error) {
 		Name:      mp.GetFullTagname(),
 		ShortName: mp.GetShortTagname(),
 		Digest:    digest,
-		Template:  "{{ .Prompt }}",
-		License:   []string{},
+		Template:  template.DefaultTemplate,
 	}
 
 	filename, err := GetBlobsPath(manifest.Config.Digest)
@@ -235,27 +259,24 @@ func GetModel(name string) (*Model, error) {
 			model.AdapterPaths = append(model.AdapterPaths, filename)
 		case "application/vnd.ollama.image.projector":
 			model.ProjectorPaths = append(model.ProjectorPaths, filename)
-		case "application/vnd.ollama.image.template":
+		case "application/vnd.ollama.image.prompt",
+			"application/vnd.ollama.image.template":
 			bts, err := os.ReadFile(filename)
 			if err != nil {
 				return nil, err
 			}
 
-			model.Template = string(bts)
-		case "application/vnd.ollama.image.system":
-			bts, err := os.ReadFile(filename)
+			model.Template, err = template.Parse(string(bts))
 			if err != nil {
 				return nil, err
 			}
-
-			model.System = string(bts)
-		case "application/vnd.ollama.image.prompt":
+		case "application/vnd.ollama.image.system":
 			bts, err := os.ReadFile(filename)
 			if err != nil {
 				return nil, err
 			}
 
-			model.Template = string(bts)
+			model.System = string(bts)
 		case "application/vnd.ollama.image.params":
 			params, err := os.Open(filename)
 			if err != nil {
@@ -414,17 +435,22 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 							return err
 						}
 
-						layers, err := parseFromFile(ctx, temp, "", fn)
+						layer, err := NewLayer(temp, baseLayer.MediaType)
 						if err != nil {
 							return err
 						}
 
-						if len(layers) != 1 {
-							return errors.New("quantization failed")
+						if _, err := temp.Seek(0, io.SeekStart); err != nil {
+							return err
+						}
+
+						ggml, _, err := llm.DecodeGGML(temp, 0)
+						if err != nil {
+							return err
 						}
 
-						baseLayer.Layer = layers[0].Layer
-						baseLayer.GGML = layers[0].GGML
+						baseLayer.Layer = layer
+						baseLayer.GGML = ggml
 					}
 				}
 
@@ -817,7 +843,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
 	mp := ParseModelPath(name)
 
-	var manifest *ManifestV2
+	var manifest *Manifest
 	var err error
 	var noprune string
 
@@ -924,7 +950,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 	return nil
 }
 
-func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*ManifestV2, error) {
+func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
 	requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
 
 	headers := make(http.Header)
@@ -935,7 +961,7 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio
 	}
 	defer resp.Body.Close()
 
-	var m *ManifestV2
+	var m *Manifest
 	if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
 		return nil, err
 	}

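Note: the new Has method above replaces IsEmbedding — rather than checking hard-coded model families, it decodes the GGUF header and treats any model that publishes a "<architecture>.pooling_type" key as embedding-only. A standalone sketch of that check, assuming only the llm package calls used in this hunk (DecodeGGML, KV, Architecture):

package main

import (
	"fmt"
	"os"

	"github.com/ollama/ollama/llm"
)

// supportsCompletion mirrors the CapabilityCompletion branch of Model.Has:
// models whose metadata carries "<arch>.pooling_type" are embedding-only.
func supportsCompletion(modelPath string) (bool, error) {
	f, err := os.Open(modelPath)
	if err != nil {
		return false, err
	}
	defer f.Close()

	ggml, _, err := llm.DecodeGGML(f, 0)
	if err != nil {
		return false, err
	}

	kv := ggml.KV()
	_, isEmbedding := kv[fmt.Sprintf("%s.pooling_type", kv.Architecture())]
	return !isEmbedding, nil
}

func main() {
	ok, err := supportsCompletion(os.Args[1])
	if err != nil {
		panic(err)
	}
	fmt.Println("supports completion:", ok)
}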
+ 11 - 9
server/manifest.go

@@ -14,7 +14,10 @@ import (
 )
 
 type Manifest struct {
-	ManifestV2
+	SchemaVersion int      `json:"schemaVersion"`
+	MediaType     string   `json:"mediaType"`
+	Config        *Layer   `json:"config"`
+	Layers        []*Layer `json:"layers"`
 
 	filepath string
 	fi       os.FileInfo
@@ -66,7 +69,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
 
 	p := filepath.Join(manifests, n.Filepath())
 
-	var m ManifestV2
+	var m Manifest
 	f, err := os.Open(p)
 	if err != nil {
 		return nil, err
@@ -83,12 +86,11 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
 		return nil, err
 	}
 
-	return &Manifest{
-		ManifestV2: m,
-		filepath:   p,
-		fi:         fi,
-		digest:     fmt.Sprintf("%x", sha256sum.Sum(nil)),
-	}, nil
+	m.filepath = p
+	m.fi = fi
+	m.digest = fmt.Sprintf("%x", sha256sum.Sum(nil))
+
+	return &m, nil
 }
 
 func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
@@ -108,7 +110,7 @@ func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
 	}
 	defer f.Close()
 
-	m := ManifestV2{
+	m := Manifest{
 		SchemaVersion: 2,
 		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
 		Config:        config,

+ 1 - 1
server/manifest_test.go

@@ -25,7 +25,7 @@ func createManifest(t *testing.T, path, name string) {
 	}
 	defer f.Close()
 
-	if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil {
+	if err := json.NewEncoder(f).Encode(Manifest{}); err != nil {
 		t.Fatal(err)
 	}
 }

+ 41 - 24
server/model.go

@@ -15,7 +15,7 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/convert"
 	"github.com/ollama/ollama/convert"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/llm"
-	"github.com/ollama/ollama/templates"
+	"github.com/ollama/ollama/template"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/types/model"
 )
 )
 
 
@@ -63,7 +63,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
 			}
 			defer blob.Close()
 
-			ggml, _, err := llm.DecodeGGML(blob)
+			ggml, _, err := llm.DecodeGGML(blob, 0)
 			if err != nil {
 				return nil, err
 			}
@@ -77,62 +77,79 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
 	return layers, nil
 }
 
-func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
+func extractFromZipFile(p string, file *os.File, fn func(api.ProgressResponse)) error {
 	stat, err := file.Stat()
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	r, err := zip.NewReader(file, stat.Size())
 	if err != nil {
-		return nil, err
-	}
-
-	tempdir, err := os.MkdirTemp(filepath.Dir(file.Name()), "")
-	if err != nil {
-		return nil, err
+		return err
 	}
-	defer os.RemoveAll(tempdir)
 
 	fn(api.ProgressResponse{Status: "unpacking model metadata"})
 	for _, f := range r.File {
+		if !filepath.IsLocal(f.Name) {
+			return fmt.Errorf("%w: %s", zip.ErrInsecurePath, f.Name)
+		}
+
+		n := filepath.Join(p, f.Name)
+		if err := os.MkdirAll(filepath.Dir(n), 0o750); err != nil {
+			return err
+		}
+
 		// TODO(mxyng): this should not write out all files to disk
-		outfile, err := os.Create(filepath.Join(tempdir, f.Name))
+		outfile, err := os.Create(n)
 		if err != nil {
-			return nil, err
+			return err
 		}
 		defer outfile.Close()
 
 		infile, err := f.Open()
 		if err != nil {
-			return nil, err
+			return err
 		}
 		defer infile.Close()
 
 		if _, err = io.Copy(outfile, infile); err != nil {
-			return nil, err
+			return err
 		}
 
 		if err := outfile.Close(); err != nil {
-			return nil, err
+			return err
 		}
 
 		if err := infile.Close(); err != nil {
-			return nil, err
+			return err
 		}
 	}
 
-	mf, err := convert.GetModelFormat(tempdir)
+	return nil
+}
+
+func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
+	tempDir, err := os.MkdirTemp(filepath.Dir(file.Name()), "")
+	if err != nil {
+		return nil, err
+	}
+	defer os.RemoveAll(tempDir)
+
+	if err := extractFromZipFile(tempDir, file, fn); err != nil {
+		return nil, err
+	}
+
+	mf, err := convert.GetModelFormat(tempDir)
 	if err != nil {
 		return nil, err
 	}
 
-	params, err := mf.GetParams(tempdir)
+	params, err := mf.GetParams(tempDir)
 	if err != nil {
 		return nil, err
 	}
 
-	mArch, err := mf.GetModelArch("", tempdir, params)
+	mArch, err := mf.GetModelArch("", tempDir, params)
 	if err != nil {
 		return nil, err
 	}
@@ -150,7 +167,7 @@ func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(a
 
 	// TODO(mxyng): this should write directly into a layer
 	// e.g. NewLayer(arch.Reader(), "application/vnd.ollama.image.model")
-	temp, err := os.CreateTemp(tempdir, "fp16")
+	temp, err := os.CreateTemp(tempDir, "fp16")
 	if err != nil {
 		return nil, err
 	}
@@ -176,7 +193,7 @@ func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(a
 	}
 	defer bin.Close()
 
-	ggml, _, err := llm.DecodeGGML(bin)
+	ggml, _, err := llm.DecodeGGML(bin, 0)
 	if err != nil {
 		return nil, err
 	}
@@ -210,7 +227,7 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
 
 	var offset int64
 	for offset < stat.Size() {
-		ggml, n, err := llm.DecodeGGML(file)
+		ggml, n, err := llm.DecodeGGML(file, 0)
 		if errors.Is(err, io.EOF) {
 			break
 		} else if err != nil {
@@ -239,7 +256,7 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
 func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
 	for _, layer := range layers {
 		if s := layer.GGML.KV().ChatTemplate(); s != "" {
-			if t, err := templates.NamedTemplate(s); err != nil {
+			if t, err := template.Named(s); err != nil {
 				slog.Debug("template detection", "error", err)
 			} else {
 				tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")

+ 112 - 0
server/model_test.go

@@ -0,0 +1,112 @@
+package server
+
+import (
+	"archive/zip"
+	"bytes"
+	"errors"
+	"io"
+	"os"
+	"path/filepath"
+	"slices"
+	"strings"
+	"testing"
+
+	"github.com/ollama/ollama/api"
+)
+
+func createZipFile(t *testing.T, name string) *os.File {
+	t.Helper()
+
+	f, err := os.CreateTemp(t.TempDir(), "")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	zf := zip.NewWriter(f)
+	defer zf.Close()
+
+	zh, err := zf.CreateHeader(&zip.FileHeader{Name: name})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if _, err := io.Copy(zh, bytes.NewReader([]byte(""))); err != nil {
+		t.Fatal(err)
+	}
+
+	return f
+}
+
+func TestExtractFromZipFile(t *testing.T) {
+	cases := []struct {
+		name   string
+		expect []string
+		err    error
+	}{
+		{
+			name:   "good",
+			expect: []string{"good"},
+		},
+		{
+			name:   strings.Join([]string{"path", "..", "to", "good"}, string(os.PathSeparator)),
+			expect: []string{filepath.Join("to", "good")},
+		},
+		{
+			name:   strings.Join([]string{"path", "..", "to", "..", "good"}, string(os.PathSeparator)),
+			expect: []string{"good"},
+		},
+		{
+			name:   strings.Join([]string{"path", "to", "..", "..", "good"}, string(os.PathSeparator)),
+			expect: []string{"good"},
+		},
+		{
+			name: strings.Join([]string{"..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "bad"}, string(os.PathSeparator)),
+			err:  zip.ErrInsecurePath,
+		},
+		{
+			name: strings.Join([]string{"path", "..", "..", "to", "bad"}, string(os.PathSeparator)),
+			err:  zip.ErrInsecurePath,
+		},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			f := createZipFile(t, tt.name)
+			defer f.Close()
+
+			tempDir := t.TempDir()
+			if err := extractFromZipFile(tempDir, f, func(api.ProgressResponse) {}); !errors.Is(err, tt.err) {
+				t.Fatal(err)
+			}
+
+			var matches []string
+			if err := filepath.Walk(tempDir, func(p string, fi os.FileInfo, err error) error {
+				if err != nil {
+					return err
+				}
+
+				if !fi.IsDir() {
+					matches = append(matches, p)
+				}
+
+				return nil
+			}); err != nil {
+				t.Fatal(err)
+			}
+
+			var actual []string
+			for _, match := range matches {
+				rel, err := filepath.Rel(tempDir, match)
+				if err != nil {
+					t.Error(err)
+				}
+
+				actual = append(actual, rel)
+			}
+
+			if !slices.Equal(actual, tt.expect) {
+				t.Fatalf("expected %d files, got %d", len(tt.expect), len(matches))
+			}
+		})
+	}
+}

+ 7 - 11
server/prompt.go

@@ -4,10 +4,11 @@ import (
 	"fmt"
 	"fmt"
 	"log/slog"
 	"log/slog"
 	"strings"
 	"strings"
-	"text/template"
+
 	"text/template/parse"
 	"text/template/parse"
 
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/template"
 )
 )
 
 
 // isResponseNode checks if the node contains .Response
 // isResponseNode checks if the node contains .Response
@@ -53,13 +54,8 @@ func formatTemplateForResponse(tmpl *template.Template, generate bool) {
 
 
 // Prompt renders a prompt from a template. If generate is set to true,
 // Prompt renders a prompt from a template. If generate is set to true,
 // the response and parts of the template following it are not rendered
 // the response and parts of the template following it are not rendered
-func Prompt(tmpl, system, prompt, response string, generate bool) (string, error) {
-	parsed, err := template.New("").Option("missingkey=zero").Parse(tmpl)
-	if err != nil {
-		return "", err
-	}
-
-	formatTemplateForResponse(parsed, generate)
+func Prompt(tmpl *template.Template, system, prompt, response string, generate bool) (string, error) {
+	formatTemplateForResponse(tmpl, generate)
 
 	vars := map[string]any{
 		"System":   system,
@@ -68,14 +64,14 @@ func Prompt(tmpl, system, prompt, response string, generate bool) (string, error
 	}
 
 	var sb strings.Builder
-	if err := parsed.Execute(&sb, vars); err != nil {
+	if err := tmpl.Execute(&sb, vars); err != nil {
 		return "", err
 	}
 
 	return sb.String(), nil
 }
 
-func countTokens(tmpl string, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
+func countTokens(tmpl *template.Template, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
 	rendered, err := Prompt(tmpl, system, prompt, response, false)
 	if err != nil {
 		return 0, err
@@ -91,7 +87,7 @@ func countTokens(tmpl string, system string, prompt string, response string, enc
 }
 
 // ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
-func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
+func ChatPrompt(tmpl *template.Template, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
 	type prompt struct {
 		System   string
 		Prompt   string

+ 13 - 2
server/prompt_test.go

@@ -5,6 +5,7 @@ import (
 	"testing"
 	"testing"
 
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/template"
 )
 )
 
 
 func TestPrompt(t *testing.T) {
 func TestPrompt(t *testing.T) {
@@ -61,7 +62,12 @@ func TestPrompt(t *testing.T) {
 
 	for _, tc := range tests {
 		t.Run(tc.name, func(t *testing.T) {
-			got, err := Prompt(tc.template, tc.system, tc.prompt, tc.response, tc.generate)
+			tmpl, err := template.Parse(tc.template)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			got, err := Prompt(tmpl, tc.system, tc.prompt, tc.response, tc.generate)
 			if err != nil {
 				t.Errorf("error = %v", err)
 			}
@@ -192,7 +198,12 @@ func TestChatPrompt(t *testing.T) {
 
 	for _, tc := range tests {
 		t.Run(tc.name, func(t *testing.T) {
-			got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
+			tmpl, err := template.Parse(tc.template)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			got, err := ChatPrompt(tmpl, tc.messages, tc.window, encode)
 			if err != nil {
 				t.Errorf("error = %v", err)
 			}

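Note: Prompt and ChatPrompt now take a parsed *template.Template rather than a raw string, so callers parse once with template.Parse and reuse the result, as the updated tests above do. A minimal sketch, written as if inside package server and using only the signatures visible in this diff:

package server

import (
	"fmt"

	"github.com/ollama/ollama/template"
)

// examplePrompt parses a template once and renders it without the response
// section, which is what the generate path needs when building a prompt.
func examplePrompt() error {
	tmpl, err := template.Parse("{{ .System }} {{ .Prompt }} {{ .Response }}")
	if err != nil {
		return err
	}

	p, err := Prompt(tmpl, "You are concise.", "Why is the sky blue?", "", true)
	if err != nil {
		return err
	}

	fmt.Println(p)
	return nil
}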
+ 79 - 13
server/routes.go

@@ -31,6 +31,7 @@ import (
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/parser"
+	"github.com/ollama/ollama/template"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
 	"github.com/ollama/ollama/version"
@@ -121,8 +122,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	if model.IsEmbedding() {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support generate"})
+	if !model.Has(CapabilityCompletion) {
+		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)})
 		return
 	}
 
@@ -161,6 +162,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
+	tmpl, err := template.Parse(req.Template)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
 	checkpointLoaded := time.Now()
 
 	var prompt string
@@ -169,7 +176,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		prompt = req.Prompt
 	case req.Prompt != "":
 		if req.Template == "" {
-			req.Template = model.Template
+			tmpl = model.Template
 		}
 
 		if req.System == "" {
@@ -187,7 +194,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 		sb.WriteString(req.Prompt)
 
-		p, err := Prompt(req.Template, req.System, sb.String(), "", true)
+		p, err := Prompt(tmpl, req.System, sb.String(), "", true)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
@@ -242,7 +249,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 
 				if !req.Raw {
-					p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false)
+					p, err := Prompt(tmpl, req.System, req.Prompt, generated.String(), false)
 					if err != nil {
 						c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 						return
@@ -832,7 +839,10 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	}
 
 	if req.Template != "" {
-		m.Template = req.Template
+		m.Template, err = template.Parse(req.Template)
+		if err != nil {
+			return nil, err
+		}
 	}
 
 	msgs := make([]api.Message, 0)
@@ -853,7 +863,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	resp := &api.ShowResponse{
 	resp := &api.ShowResponse{
 		License:    strings.Join(m.License, "\n"),
 		License:    strings.Join(m.License, "\n"),
 		System:     m.System,
 		System:     m.System,
-		Template:   m.Template,
+		Template:   m.Template.String(),
 		Details:    modelDetails,
 		Details:    modelDetails,
 		Messages:   msgs,
 		Messages:   msgs,
 		ModifiedAt: manifest.fi.ModTime(),
 		ModifiedAt: manifest.fi.ModTime(),
@@ -886,9 +896,48 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	fmt.Fprint(&sb, m.String())
 	fmt.Fprint(&sb, m.String())
 	resp.Modelfile = sb.String()
 	resp.Modelfile = sb.String()
 
 
+	kvData, err := getKVData(m.ModelPath, req.Verbose)
+	if err != nil {
+		return nil, err
+	}
+	delete(kvData, "general.name")
+	delete(kvData, "tokenizer.chat_template")
+	resp.ModelInfo = kvData
+
+	if len(m.ProjectorPaths) > 0 {
+		projectorData, err := getKVData(m.ProjectorPaths[0], req.Verbose)
+		if err != nil {
+			return nil, err
+		}
+		resp.ProjectorInfo = projectorData
+	}
+
 	return resp, nil
 }
 
+func getKVData(digest string, verbose bool) (llm.KV, error) {
+	maxArraySize := 0
+	if verbose {
+		maxArraySize = -1
+	}
+	kvData, err := llm.LoadModel(digest, maxArraySize)
+	if err != nil {
+		return nil, err
+	}
+
+	kv := kvData.KV()
+
+	if !verbose {
+		for k := range kv {
+			if t, ok := kv[k].([]any); len(t) > 5 && ok {
+				kv[k] = []any{}
+			}
+		}
+	}
+
+	return kv, nil
+}
+
 func (s *Server) ListModelsHandler(c *gin.Context) {
 	ms, err := Manifests()
 	if err != nil {
@@ -1153,7 +1202,10 @@ func (s *Server) GenerateRoutes() http.Handler {
 	r.GET("/api/ps", s.ProcessHandler)
 	r.GET("/api/ps", s.ProcessHandler)
 
 
 	// Compatibility endpoints
 	// Compatibility endpoints
-	r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler)
+	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
+	r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
+	r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
+	r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
 
 
 	for _, method := range []string{http.MethodGet, http.MethodHead} {
 	for _, method := range []string{http.MethodGet, http.MethodHead} {
 		r.Handle(method, "/", func(c *gin.Context) {
 		r.Handle(method, "/", func(c *gin.Context) {
@@ -1219,11 +1271,20 @@ func Serve(ln net.Listener) error {
 	schedCtx, schedDone := context.WithCancel(ctx)
 	sched := InitScheduler(schedCtx)
 	s := &Server{addr: ln.Addr(), sched: sched}
-	r := s.GenerateRoutes()
+
+	http.Handle("/", s.GenerateRoutes())
 
 	slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
 	srvr := &http.Server{
-		Handler: r,
+		// Use http.DefaultServeMux so we get net/http/pprof for
+		// free.
+		//
+		// TODO(bmizerany): Decide if we want to make this
+		// configurable so it is not exposed by default, or allow
+		// users to bind it to a different port. This was a quick
+		// and easy way to get pprof, but it may not be the best
+		// way.
+		Handler: nil,
 	}
 
 	// listen for a ctrl+c and stop any loaded llm
@@ -1342,11 +1403,16 @@ func (s *Server) ProcessHandler(c *gin.Context) {
 		models = append(models, mr)
 	}
 
+	slices.SortStableFunc(models, func(i, j api.ProcessModelResponse) int {
+		// longest duration remaining listed first
+		return cmp.Compare(j.ExpiresAt.Unix(), i.ExpiresAt.Unix())
+	})
+
 	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
 }
 
 // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
-func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
+func chatPrompt(ctx context.Context, runner *runnerRef, template *template.Template, messages []api.Message, numCtx int) (string, error) {
 	encode := func(s string) ([]int, error) {
 		return runner.llama.Tokenize(ctx, s)
 	}
@@ -1394,8 +1460,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}
 
-	if model.IsEmbedding() {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support chat"})
+	if !model.Has(CapabilityCompletion) {
+		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)})
 		return
 	}
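This broadens the old embedding-only check: any model that lacks the completion capability is now rejected up front, and the client receives HTTP 400 with the message "<model name> does not support chat".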
 

+ 96 - 1
server/routes_test.go

@@ -20,6 +20,8 @@ import (
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
+	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
@@ -105,6 +107,24 @@ func Test_Routes(t *testing.T) {
 				assert.Empty(t, len(modelList.Models))
 			},
 		},
+		{
+			Name:   "openai empty list",
+			Method: http.MethodGet,
+			Path:   "/v1/models",
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				assert.Equal(t, "application/json", contentType)
+				body, err := io.ReadAll(resp.Body)
+				require.NoError(t, err)
+
+				var modelList openai.ListCompletion
+				err = json.Unmarshal(body, &modelList)
+				require.NoError(t, err)
+
+				assert.Equal(t, "list", modelList.Object)
+				assert.Empty(t, modelList.Data)
+			},
+		},
 		{
 			Name:   "Tags Handler (yes tags)",
 			Method: http.MethodGet,
@@ -128,6 +148,25 @@ func Test_Routes(t *testing.T) {
 				assert.Equal(t, "test-model:latest", modelList.Models[0].Name)
 			},
 		},
+		{
+			Name:   "openai list models with tags",
+			Method: http.MethodGet,
+			Path:   "/v1/models",
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				assert.Equal(t, "application/json", contentType)
+				body, err := io.ReadAll(resp.Body)
+				require.NoError(t, err)
+
+				var modelList openai.ListCompletion
+				err = json.Unmarshal(body, &modelList)
+				require.NoError(t, err)
+
+				assert.Len(t, modelList.Data, 1)
+				assert.Equal(t, "test-model:latest", modelList.Data[0].Id)
+				assert.Equal(t, "library", modelList.Data[0].OwnedBy)
+			},
+		},
 		{
 			Name:   "Create Model Handler",
 			Method: http.MethodPost,
@@ -213,6 +252,25 @@ func Test_Routes(t *testing.T) {
 					"top_p 0.9",
 				}
 				assert.Equal(t, expectedParams, params)
+				assert.InDelta(t, 0, showResp.ModelInfo["general.parameter_count"], 1e-9, "Parameter count should be 0")
+			},
+		},
+		{
+			Name:   "openai retrieve model handler",
+			Method: http.MethodGet,
+			Path:   "/v1/models/show-model",
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				assert.Equal(t, "application/json", contentType)
+				body, err := io.ReadAll(resp.Body)
+				require.NoError(t, err)
+
+				var retrieveResp api.RetrieveModelResponse
+				err = json.Unmarshal(body, &retrieveResp)
+				require.NoError(t, err)
+
+				assert.Equal(t, "show-model", retrieveResp.Id)
+				assert.Equal(t, "library", retrieveResp.OwnedBy)
 			},
 		},
 	}
@@ -327,6 +385,43 @@ func TestCase(t *testing.T) {
 	}
 }
 
+func TestShow(t *testing.T) {
+	t.Setenv("OLLAMA_MODELS", t.TempDir())
+	envconfig.LoadConfig()
+
+	var s Server
+
+	createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Name: "show-model",
+		Modelfile: fmt.Sprintf(
+			"FROM %s\nFROM %s",
+			createBinFile(t, llm.KV{"general.architecture": "test"}, nil),
+			createBinFile(t, llm.KV{"general.architecture": "clip"}, nil),
+		),
+	})
+
+	w := createRequest(t, s.ShowModelHandler, api.ShowRequest{
+		Name: "show-model",
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	}
+
+	var resp api.ShowResponse
+	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+		t.Fatal(err)
+	}
+
+	if resp.ModelInfo["general.architecture"] != "test" {
+		t.Fatal("Expected model architecture to be 'test', but got", resp.ModelInfo["general.architecture"])
+	}
+
+	if resp.ProjectorInfo["general.architecture"] != "clip" {
+		t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"])
+	}
+}
+
 func TestNormalize(t *testing.T) {
 	type testCase struct {
 		input []float32
@@ -359,5 +454,5 @@ func TestNormalize(t *testing.T) {
 				t.Errorf("Vector %v is not normalized", tc.input)
 				t.Errorf("Vector %v is not normalized", tc.input)
 			}
 			}
 		})
 		})
-	}
+  }
 }
 }

+ 101 - 25
server/sched.go

@@ -23,6 +23,7 @@ type LlmRequest struct {
 	ctx             context.Context //nolint:containedctx
 	model           *Model
 	opts            api.Options
+	origNumCtx      int // Track the initial ctx request
 	sessionDuration time.Duration
 	successCh       chan *runnerRef
 	errCh           chan error
@@ -38,13 +39,23 @@ type Scheduler struct {
 	loaded   map[string]*runnerRef
 	loadedMu sync.Mutex
 
-	loadFn       func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
-	newServerFn  func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error)
+	loadFn       func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int)
+	newServerFn  func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error)
 	getGpuFn     func() gpu.GpuInfoList
 	getCpuFn     func() gpu.GpuInfoList
 	reschedDelay time.Duration
 }
 
+// Default automatic value for number of models we allow per GPU
+// Model will still need to fit in VRAM, but loading many small models
+// on a large GPU can cause stalling
+var defaultModelsPerGPU = 3
+
+// Default automatic value for parallel setting
+// Model will still need to fit in VRAM. If this setting won't fit
+// we'll back off down to 1 to try to get it to fit
+var defaultParallel = 4
+
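Both defaults only apply when nothing was configured explicitly: the scheduling loop below overrides envconfig.MaxRunners only while it is <= 0, and defaultParallel is only attempted when no parallel setting was forced, with pickBestFitGPUs falling back to 1 if the larger value does not fit.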
 var ErrMaxQueue = fmt.Errorf("server busy, please try again.  maximum pending requests exceeded")
 
 func InitScheduler(ctx context.Context) *Scheduler {
@@ -65,13 +76,10 @@ func InitScheduler(ctx context.Context) *Scheduler {
 
 // context must be canceled to decrement ref count and release the runner
 func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
-	// allocate a large enough kv cache for all parallel requests
 	if opts.NumCtx < 4 {
 		opts.NumCtx = 4
 	}
 
-	opts.NumCtx *= envconfig.NumParallel
-
 	req := &LlmRequest{
 		ctx:             c,
 		model:           model,
@@ -110,11 +118,25 @@ func (s *Scheduler) processPending(ctx context.Context) {
 		case pending := <-s.pendingReqCh:
 			// Block other requests until we get this pending request running
 			pending.schedAttempts++
+			if pending.origNumCtx == 0 {
+				pending.origNumCtx = pending.opts.NumCtx
+			}
 
 			if pending.ctx.Err() != nil {
 				slog.Debug("pending request cancelled or timed out, skipping scheduling")
 				continue
 			}
+			numParallel := envconfig.NumParallel
+			// TODO (jmorganca): multimodal models don't support parallel yet
+			// see https://github.com/ollama/ollama/issues/4165
+			if len(pending.model.ProjectorPaths) > 0 && numParallel != 1 {
+				numParallel = 1
+				slog.Warn("multimodal models don't support parallel requests yet")
+			}
+			// Keep NumCtx and numParallel in sync
+			if numParallel > 1 {
+				pending.opts.NumCtx = pending.origNumCtx * numParallel
+			}
 
 			for {
 				var runnerToExpire *runnerRef
@@ -143,8 +165,28 @@ func (s *Scheduler) processPending(ctx context.Context) {
 						gpus = s.getGpuFn()
 					}
 
+					if envconfig.MaxRunners <= 0 {
+						// No user specified MaxRunners, so figure out what automatic setting to use
+						// If all GPUs have reliable free memory reporting, defaultModelsPerGPU * the number of GPUs
+						// if any GPU has unreliable free memory reporting, 1x the number of GPUs
+						allReliable := true
+						for _, gpu := range gpus {
+							if gpu.UnreliableFreeMemory {
+								allReliable = false
+								break
+							}
+						}
+						if allReliable {
+							envconfig.MaxRunners = defaultModelsPerGPU * len(gpus)
+							slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", envconfig.MaxRunners, "gpu_count", len(gpus))
+						} else {
+							slog.Info("one or more GPUs detected that are unable to accurately report free memory - disabling default concurrency")
+							envconfig.MaxRunners = len(gpus)
+						}
+					}
+
 					// Load model for fitting
-					ggml, err := llm.LoadModel(pending.model.ModelPath)
+					ggml, err := llm.LoadModel(pending.model.ModelPath, 0)
 					if err != nil {
 						pending.errCh <- err
 						break
@@ -152,26 +194,32 @@ func (s *Scheduler) processPending(ctx context.Context) {
 
 					// Evaluate if the model will fit in the available system memory, or if we should unload a model first
 					if len(gpus) == 1 && gpus[0].Library == "cpu" {
+						// simplifying assumption of defaultParallel when in CPU mode
+						if numParallel <= 0 {
+							numParallel = defaultParallel
+							pending.opts.NumCtx = pending.origNumCtx * numParallel
+						}
+
 						if loadedCount == 0 {
 							slog.Debug("cpu mode with first model, loading")
-							s.loadFn(pending, ggml, gpus)
+							s.loadFn(pending, ggml, gpus, numParallel)
 							break
 						}
 						runnerToExpire = s.maybeFindCPURunnerToUnload(pending, ggml, gpus)
 						if runnerToExpire == nil {
 							slog.Debug("cpu mode with available system memory or first model, loading")
-							s.loadFn(pending, ggml, gpus)
+							s.loadFn(pending, ggml, gpus, numParallel)
 							break
 						}
 						// else we need to expire a runner
 					} else if loadedCount == 0 {
 						// No models loaded. Load the model but prefer the best fit.
 						slog.Debug("loading first model", "model", pending.model.ModelPath)
-						g := pickBestFitGPUs(pending, ggml, gpus)
+						g := pickBestFitGPUs(pending, ggml, gpus, &numParallel)
 						if g != nil {
 							gpus = g
 						}
-						s.loadFn(pending, ggml, gpus)
+						s.loadFn(pending, ggml, gpus, numParallel)
 						break
 					}
 
@@ -186,10 +234,10 @@ func (s *Scheduler) processPending(ctx context.Context) {
 
 						// Update free memory from currently loaded models
 						s.updateFreeSpace(availGpus)
-						fitGpus := pickBestFitGPUs(pending, ggml, availGpus)
+						fitGpus := pickBestFitGPUs(pending, ggml, availGpus, &numParallel)
 						if fitGpus != nil {
 							slog.Debug("new model fits with existing models, loading")
-							s.loadFn(pending, ggml, fitGpus)
+							s.loadFn(pending, ggml, fitGpus, numParallel)
 							break
 						}
 
@@ -350,8 +398,11 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
 	}()
 }
 
-func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) {
-	llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
+func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
+	if numParallel < 1 {
+		numParallel = 1
+	}
+	llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
 	if err != nil {
 		// some older models are not compatible with newer versions of llama.cpp
 		// show a generalized compatibility error until there is a better way to
@@ -375,6 +426,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
 		loading:         true,
 		refCount:        1,
 	}
+	runner.numParallel = numParallel
 	runner.refMu.Lock()
 
 	s.loadedMu.Lock()
@@ -483,8 +535,9 @@ type runnerRef struct {
 	expireTimer     *time.Timer
 	expiresAt       time.Time
 
-	model     *Model
-	modelPath string
+	model       *Model
+	modelPath   string
+	numParallel int
 	*api.Options
 }
 
@@ -525,6 +578,9 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
 		optsNew.NumGPU = -1
 	}
 
+	// Normalize the NumCtx for parallelism
+	optsExisting.NumCtx = optsExisting.NumCtx / runner.numParallel
+
 	ctx, cancel := context.WithTimeout(ctx, timeout)
 	defer cancel()
 	if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
@@ -611,22 +667,38 @@ func (a ByDuration) Less(i, j int) bool {
 
 // pickBestFitGPUs will try to find the optimal placement of the model in the available GPUs where the model fully fits
 // If the model can not be fit fully within the available GPU(s) nil is returned
-func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) gpu.GpuInfoList {
+// If numParallel is <= 0, this will attempt to optimize parallelism based on available VRAM, and adjust
+// opts.NumCtx accordingly
+func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel *int) gpu.GpuInfoList {
 	var estimatedVRAM uint64
+
+	var numParallelToTry []int
+	if *numParallel <= 0 {
+		// If no specific parallel setting was provided, try larger then smaller, always end with 1
+		numParallelToTry = append(numParallelToTry, defaultParallel, 1)
+	} else {
+		numParallelToTry = []int{*numParallel}
+	}
+
 	for _, gl := range gpus.ByLibrary() {
 		var ok bool
 		sgl := append(make(gpu.GpuInfoList, 0, len(gl)), gl...)
 
 		// TODO - potentially sort by performance capability, existing models loaded, etc.
+		// TODO - Eliminate any GPUs that already have envconfig.MaxRunners loaded on them
 		// Note: at present, this will favor more VRAM over faster GPU speed in mixed setups
 		sort.Sort(sort.Reverse(gpu.ByFreeMemory(sgl)))
 
 		// First attempt to fit the model into a single GPU
-		if !envconfig.SchedSpread {
-			for _, g := range sgl {
-				if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
-					slog.Debug("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
-					return []gpu.GpuInfo{g}
+		for _, p := range numParallelToTry {
+			req.opts.NumCtx = req.origNumCtx * p
+			if !envconfig.SchedSpread {
+				for _, g := range sgl {
+					if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
+						slog.Info("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "parallel", p, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
+						*numParallel = p
+						return []gpu.GpuInfo{g}
+					}
 				}
 			}
 		}
@@ -636,9 +708,13 @@ func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) gpu.
 		// - try subsets of GPUs instead of just falling back to 1 or all in a family
 
 		// Now try all the GPUs
-		if ok, estimatedVRAM = llm.PredictServerFit(sgl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
-			slog.Debug("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "required", format.HumanBytes2(estimatedVRAM))
-			return sgl
+		for _, p := range numParallelToTry {
+			req.opts.NumCtx = req.origNumCtx * p
+			if ok, estimatedVRAM = llm.PredictServerFit(sgl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
+				slog.Info("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "parallel", p, "required", format.HumanBytes2(estimatedVRAM))
+				*numParallel = p
+				return sgl
+			}
 		}
 		}
 	}
 	return nil
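Concretely, for a request with a 2048-token context and no explicit parallel setting, pickBestFitGPUs now tries numParallel=4 first, sizing the KV cache for NumCtx = 2048 * 4 = 8192, and falls back to numParallel=1 (NumCtx = 2048) if that estimate does not fit in the available VRAM; the chosen value is written back through the numParallel pointer so load() starts the runner with the same setting.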

+ 54 - 32
server/sched_test.go

@@ -47,11 +47,11 @@ func TestLoad(t *testing.T) {
 		sessionDuration: 2,
 	}
 	// Fail to load model first
-	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
+	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
 		return nil, fmt.Errorf("something failed to load model blah")
 	}
 	gpus := gpu.GpuInfoList{}
-	s.load(req, ggml, gpus)
+	s.load(req, ggml, gpus, 0)
 	require.Empty(t, req.successCh)
 	require.Len(t, req.errCh, 1)
 	s.loadedMu.Lock()
@@ -61,10 +61,10 @@ func TestLoad(t *testing.T) {
 	require.Contains(t, err.Error(), "this model may be incompatible")
 
 	server := &mockLlm{estimatedVRAM: 10, estimatedVRAMByGPU: map[string]uint64{}}
-	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
+	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
 		return server, nil
 	}
-	s.load(req, ggml, gpus)
+	s.load(req, ggml, gpus, 0)
 	select {
 	case err := <-req.errCh:
 		require.NoError(t, err)
@@ -78,12 +78,12 @@ func TestLoad(t *testing.T) {
 
 	req.model.ModelPath = "dummy_model_path"
 	server.waitResp = fmt.Errorf("wait failure")
-	s.load(req, ggml, gpus)
+	s.load(req, ggml, gpus, 0)
 	select {
 	case err := <-req.errCh:
 		require.Contains(t, err.Error(), "wait failure")
 	case resp := <-req.successCh:
-		t.Errorf("unexpected success %v", resp)
+		t.Fatalf("unexpected success %v", resp)
 	}
 	s.loadedMu.Lock()
 	runner := s.loaded["dummy_model_path"]
@@ -102,7 +102,7 @@ type bundle struct {
 	ggml    *llm.GGML
 }
 
-func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
+func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
 	return scenario.srv, nil
 }
 
@@ -128,14 +128,14 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
 		"tokenizer.ggml.scores":         []float32{0},
 		"tokenizer.ggml.token_type":     []int32{0},
 	}, []llm.Tensor{
-		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
-		{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
+		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
+		{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
 	})
 	require.NoError(t, err)
 
 	fname := f.Name()
 	model := &Model{Name: modelName, ModelPath: fname}
-	scenario.ggml, err = llm.LoadModel(model.ModelPath)
+	scenario.ggml, err = llm.LoadModel(model.ModelPath, 0)
 	require.NoError(t, err)
 
 	scenario.req = &LlmRequest{
@@ -200,7 +200,7 @@ func TestRequests(t *testing.T) {
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario1a.req.errCh)
 	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
 	}
 
 	// Same runner as first request due to not needing a reload
@@ -213,7 +213,7 @@ func TestRequests(t *testing.T) {
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario1b.req.errCh)
 	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
 	}
 
 	// Trigger a reload
@@ -231,7 +231,7 @@ func TestRequests(t *testing.T) {
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario2a.req.errCh)
 	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
 	}
 
 	envconfig.MaxRunners = 1
@@ -247,7 +247,7 @@ func TestRequests(t *testing.T) {
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario3a.req.errCh)
 	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
 	}
 	s.loadedMu.Lock()
 	require.Len(t, s.loaded, 1)
@@ -263,7 +263,7 @@ func TestRequests(t *testing.T) {
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario3b.req.errCh)
 	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
 	}
 	s.loadedMu.Lock()
 	require.Len(t, s.loaded, 2)
@@ -279,7 +279,7 @@ func TestRequests(t *testing.T) {
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario3c.req.errCh)
 	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
 	}
 	s.loadedMu.Lock()
 	require.Len(t, s.loaded, 3)
@@ -306,7 +306,7 @@ func TestRequests(t *testing.T) {
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, scenario3d.req.errCh)
 	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
 	}
 	s.loadedMu.Lock()
 	require.Len(t, s.loaded, 2)
@@ -349,7 +349,7 @@ func TestGetRunner(t *testing.T) {
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, errCh1a)
 	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
 	}
 	scenario1a.ctxDone()
 	s.loadedMu.Lock()
@@ -400,7 +400,7 @@ func TestPrematureExpired(t *testing.T) {
 		slog.Info("sending premature expired event now")
 		s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
 	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
 	}
 	time.Sleep(scenario1a.req.sessionDuration)
 	scenario1a.ctxDone()
@@ -427,7 +427,7 @@ func TestUseLoadedRunner(t *testing.T) {
 	}
 	finished := make(chan *LlmRequest)
 	llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
-	r1 := &runnerRef{llama: llm1, sessionDuration: 1}
+	r1 := &runnerRef{llama: llm1, sessionDuration: 1, numParallel: 1}
 	req.useLoadedRunner(r1, finished)
 	require.Equal(t, uint(1), r1.refCount)
 	require.Equal(t, time.Duration(2), r1.sessionDuration)
@@ -435,7 +435,7 @@ func TestUseLoadedRunner(t *testing.T) {
 	case success := <-req.successCh:
 		require.Equal(t, r1, success)
 	case <-ctx.Done():
-		t.Errorf("timeout")
+		t.Fatal("timeout")
 	}
 	done()
 	fin := <-finished
@@ -461,8 +461,8 @@ func TestUpdateFreeSpace(t *testing.T) {
 	gpus[1].FreeMemory = 1900
 	llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{"1": 50, "2": 50}}
 	llm2 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{"1": 125, "2": 75}}
-	r1 := &runnerRef{llama: llm1, gpus: gpus}
-	r2 := &runnerRef{llama: llm2, gpus: gpus}
+	r1 := &runnerRef{llama: llm1, gpus: gpus, numParallel: 1}
+	r2 := &runnerRef{llama: llm2, gpus: gpus, numParallel: 1}
 
 	s := InitScheduler(ctx)
 	s.loadedMu.Lock()
@@ -513,8 +513,8 @@ func TestFindRunnerToUnload(t *testing.T) {
 	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
 	defer done()
 
-	r1 := &runnerRef{refCount: 1, sessionDuration: 1}
-	r2 := &runnerRef{sessionDuration: 2}
+	r1 := &runnerRef{refCount: 1, sessionDuration: 1, numParallel: 1}
+	r2 := &runnerRef{sessionDuration: 2, numParallel: 1}
 
 	s := InitScheduler(ctx)
 	s.loadedMu.Lock()
@@ -536,9 +536,13 @@ func TestNeedsReload(t *testing.T) {
 	llm := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
 	do := api.DefaultOptions()
 	runner := &runnerRef{
-		model:   &Model{AdapterPaths: []string{"adapter1"}, ProjectorPaths: []string{"projector1"}},
-		Options: &do,
-		llama:   llm,
+		model: &Model{
+			AdapterPaths:   []string{"adapter1"},
+			ProjectorPaths: []string{"projector1"},
+		},
+		Options:     &do,
+		llama:       llm,
+		numParallel: 1,
 	}
 	req := &LlmRequest{
 		model: &Model{
@@ -581,8 +585,8 @@ func TestUnloadAllRunners(t *testing.T) {
 	s := InitScheduler(ctx)
 	s.unloadAllRunners()
 
-	r1 := &runnerRef{llama: llm1}
-	r2 := &runnerRef{llama: llm2}
+	r1 := &runnerRef{llama: llm1, numParallel: 1}
+	r2 := &runnerRef{llama: llm2, numParallel: 1}
 
 	s.loadedMu.Lock()
 	s.loaded["a"] = r1
@@ -596,14 +600,32 @@ func TestUnloadAllRunners(t *testing.T) {
 
 func TestUnload(t *testing.T) {
 	llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
-	r1 := &runnerRef{llama: llm1}
-	r2 := &runnerRef{model: &Model{AdapterPaths: []string{"A"}}}
+	r1 := &runnerRef{llama: llm1, numParallel: 1}
+	r2 := &runnerRef{model: &Model{AdapterPaths: []string{"A"}}, numParallel: 1}
 	r1.unload()
 	require.True(t, llm1.closeCalled)
 	r2.unload()
 	require.Nil(t, r2.model)
 }
 
+func TestAlreadyCanceled(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
+	defer done()
+	dctx, done2 := context.WithCancel(ctx)
+	done2()
+	scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
+	scenario1a.req.sessionDuration = 0
+	s := InitScheduler(ctx)
+	slog.Info("scenario1a")
+	s.pendingReqCh <- scenario1a.req
+	require.Len(t, s.pendingReqCh, 1)
+	s.Run(ctx)
+	time.Sleep(5 * time.Millisecond)
+	require.Empty(t, s.pendingReqCh)
+	require.Empty(t, scenario1a.req.errCh)
+	require.Empty(t, scenario1a.req.successCh)
+}
+
 type mockLlm struct {
 	pingResp           error
 	waitResp           error

+ 0 - 0
templates/alfred.gotmpl → template/alfred.gotmpl


+ 0 - 0
templates/alpaca.gotmpl → template/alpaca.gotmpl


+ 0 - 0
templates/chatml.gotmpl → template/chatml.gotmpl


+ 0 - 0
templates/chatqa.gotmpl → template/chatqa.gotmpl


+ 0 - 0
templates/codellama-70b-instruct.gotmpl → template/codellama-70b-instruct.gotmpl


+ 0 - 0
templates/falcon-instruct.gotmpl → template/falcon-instruct.gotmpl


+ 0 - 0
templates/gemma-instruct.gotmpl → template/gemma-instruct.gotmpl


+ 0 - 0
templates/granite-instruct.gotmpl → template/granite-instruct.gotmpl


+ 0 - 0
templates/index.json → template/index.json


+ 0 - 0
templates/llama2-chat.gotmpl → template/llama2-chat.gotmpl


+ 0 - 0
templates/llama3-instruct.gotmpl → template/llama3-instruct.gotmpl


+ 0 - 0
templates/magicoder.gotmpl → template/magicoder.gotmpl


+ 0 - 0
templates/mistral-instruct.gotmpl → template/mistral-instruct.gotmpl


+ 0 - 0
templates/openchat.gotmpl → template/openchat.gotmpl


+ 0 - 0
templates/phi-3.gotmpl → template/phi-3.gotmpl


+ 0 - 0
templates/solar-instruct.gotmpl → template/solar-instruct.gotmpl


+ 0 - 0
templates/starcoder2-instruct.gotmpl → template/starcoder2-instruct.gotmpl


+ 158 - 0
template/template.go

@@ -0,0 +1,158 @@
+package template
+
+import (
+	"bytes"
+	"embed"
+	"encoding/json"
+	"errors"
+	"io"
+	"math"
+	"slices"
+	"strings"
+	"sync"
+	"text/template"
+	"text/template/parse"
+
+	"github.com/agnivade/levenshtein"
+	"golang.org/x/exp/maps"
+)
+
+//go:embed index.json
+var indexBytes []byte
+
+//go:embed *.gotmpl
+var templatesFS embed.FS
+
+var templatesOnce = sync.OnceValues(func() ([]*named, error) {
+	var templates []*named
+	if err := json.Unmarshal(indexBytes, &templates); err != nil {
+		return nil, err
+	}
+
+	for _, t := range templates {
+		bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
+		if err != nil {
+			return nil, err
+		}
+
+		// normalize line endings
+		t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
+	}
+
+	return templates, nil
+})
+
+type named struct {
+	Name     string `json:"name"`
+	Template string `json:"template"`
+	Bytes    []byte
+}
+
+func (t named) Reader() io.Reader {
+	return bytes.NewReader(t.Bytes)
+}
+
+func Named(s string) (*named, error) {
+	templates, err := templatesOnce()
+	if err != nil {
+		return nil, err
+	}
+
+	var template *named
+	score := math.MaxInt
+	for _, t := range templates {
+		if s := levenshtein.ComputeDistance(s, t.Template); s < score {
+			score = s
+			template = t
+		}
+	}
+
+	if score < 100 {
+		return template, nil
+	}
+
+	return nil, errors.New("no matching template found")
+}
+
+type Template struct {
+	*template.Template
+	raw string
+}
+
+func (t *Template) String() string {
+	return t.raw
+}
+
+var DefaultTemplate, _ = Parse("{{ .Prompt }}")
+
+func Parse(s string) (*Template, error) {
+	t, err := template.New("").Option("missingkey=zero").Parse(s)
+	if err != nil {
+		return nil, err
+	}
+
+	return &Template{Template: t, raw: s}, nil
+}
+
+func (t *Template) Vars() []string {
+	var vars []string
+	for _, n := range t.Tree.Root.Nodes {
+		vars = append(vars, parseNode(n)...)
+	}
+
+	set := make(map[string]struct{})
+	for _, n := range vars {
+		set[strings.ToLower(n)] = struct{}{}
+	}
+
+	vars = maps.Keys(set)
+	slices.Sort(vars)
+	return vars
+}
+
+func parseNode(n parse.Node) []string {
+	switch n := n.(type) {
+	case *parse.ActionNode:
+		return parseNode(n.Pipe)
+	case *parse.IfNode:
+		names := parseNode(n.Pipe)
+		names = append(names, parseNode(n.List)...)
+		if n.ElseList != nil {
+			names = append(names, parseNode(n.ElseList)...)
+		}
+		return names
+	case *parse.RangeNode:
+		names := parseNode(n.Pipe)
+		names = append(names, parseNode(n.List)...)
+		if n.ElseList != nil {
+			names = append(names, parseNode(n.ElseList)...)
+		}
+		return names
+	case *parse.WithNode:
+		names := parseNode(n.Pipe)
+		names = append(names, parseNode(n.List)...)
+		if n.ElseList != nil {
+			names = append(names, parseNode(n.ElseList)...)
+		}
+		return names
+	case *parse.PipeNode:
+		var names []string
+		for _, c := range n.Cmds {
+			for _, a := range c.Args {
+				names = append(names, parseNode(a)...)
+			}
+		}
+		return names
+	case *parse.ListNode:
+		var names []string
+		for _, n := range n.Nodes {
+			names = append(names, parseNode(n)...)
+		}
+
+		return names
+	case *parse.FieldNode:
+		return n.Ident
+	}
+
+	return nil
+}
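A small usage sketch of the new package's exported surface (Parse and Vars are shown; Named fuzzy-matches a GGUF chat-template string against the embedded library via Levenshtein distance). The import path assumes this repository's module name:

package main

import (
	"fmt"
	"os"

	"github.com/ollama/ollama/template"
)

func main() {
	tmpl, err := template.Parse("{{ .System }} {{ .Prompt }}")
	if err != nil {
		panic(err)
	}

	fmt.Println(tmpl.Vars()) // [prompt system]

	// Template embeds *text/template.Template, so Execute is available directly.
	_ = tmpl.Execute(os.Stdout, map[string]string{
		"System": "You are a helpful assistant.",
		"Prompt": "Hello",
	})
}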

+ 89 - 0
template/template_test.go

@@ -0,0 +1,89 @@
+package template
+
+import (
+	"bufio"
+	"bytes"
+	"encoding/json"
+	"io"
+	"os"
+	"path/filepath"
+	"slices"
+	"testing"
+	"text/template"
+
+	"github.com/ollama/ollama/llm"
+)
+
+func TestNamed(t *testing.T) {
+	f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer f.Close()
+
+	scanner := bufio.NewScanner(f)
+	for scanner.Scan() {
+		var ss map[string]string
+		if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
+			t.Fatal(err)
+		}
+
+		for k, v := range ss {
+			t.Run(k, func(t *testing.T) {
+				kv := llm.KV{"tokenizer.chat_template": v}
+				s := kv.ChatTemplate()
+				r, err := Named(s)
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				if r.Name != k {
+					t.Errorf("expected %q, got %q", k, r.Name)
+				}
+
+				var b bytes.Buffer
+				if _, err := io.Copy(&b, r.Reader()); err != nil {
+					t.Fatal(err)
+				}
+
+				tmpl, err := template.New(s).Parse(b.String())
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				if tmpl.Tree.Root.String() == "" {
+					t.Errorf("empty %s template", k)
+				}
+			})
+		}
+	}
+}
+
+func TestParse(t *testing.T) {
+	cases := []struct {
+		template string
+		vars     []string
+	}{
+		{"{{ .Prompt }}", []string{"prompt"}},
+		{"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}},
+		{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
+		{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "system", "tools"}},
+		{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
+		{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
+		{"{{ .Prompt }} {{ .Suffix }}", []string{"prompt", "suffix"}},
+	}
+
+	for _, tt := range cases {
+		t.Run("", func(t *testing.T) {
+			tmpl, err := Parse(tt.template)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			vars := tmpl.Vars()
+			if !slices.Equal(tt.vars, vars) {
+				t.Errorf("expected %v, got %v", tt.vars, vars)
+			}
+		})
+	}
+}

+ 0 - 0
templates/testdata/templates.jsonl → template/testdata/templates.jsonl


+ 0 - 0
templates/vicuna.gotmpl → template/vicuna.gotmpl


+ 0 - 0
templates/zephyr.gotmpl → template/zephyr.gotmpl


+ 0 - 70
templates/template.go

@@ -1,70 +0,0 @@
-package templates
-
-import (
-	"bytes"
-	"embed"
-	"encoding/json"
-	"errors"
-	"io"
-	"math"
-	"sync"
-
-	"github.com/agnivade/levenshtein"
-)
-
-//go:embed index.json
-var indexBytes []byte
-
-//go:embed *.gotmpl
-var templatesFS embed.FS
-
-var templatesOnce = sync.OnceValues(func() ([]*Template, error) {
-	var templates []*Template
-	if err := json.Unmarshal(indexBytes, &templates); err != nil {
-		return nil, err
-	}
-
-	for _, t := range templates {
-		bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
-		if err != nil {
-			return nil, err
-		}
-
-		// normalize line endings
-		t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
-	}
-
-	return templates, nil
-})
-
-type Template struct {
-	Name     string `json:"name"`
-	Template string `json:"template"`
-	Bytes []byte
-}
-
-func (t Template) Reader() io.Reader {
-	return bytes.NewReader(t.Bytes)
-}
-
-func NamedTemplate(s string) (*Template, error) {
-	templates, err := templatesOnce()
-	if err != nil {
-		return nil, err
-	}
-
-	var template *Template
-	score := math.MaxInt
-	for _, t := range templates {
-		if s := levenshtein.ComputeDistance(s, t.Template); s < score {
-			score = s
-			template = t
-		}
-	}
-
-	if score < 100 {
-		return template, nil
-	}
-
-	return nil, errors.New("no matching template found")
-}

+ 0 - 59
templates/template_test.go

@@ -1,59 +0,0 @@
-package templates
-
-import (
-	"bufio"
-	"bytes"
-	"encoding/json"
-	"io"
-	"os"
-	"path/filepath"
-	"testing"
-	"text/template"
-
-	"github.com/ollama/ollama/llm"
-)
-
-func TestKVChatTemplate(t *testing.T) {
-	f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer f.Close()
-
-	scanner := bufio.NewScanner(f)
-	for scanner.Scan() {
-		var ss map[string]string
-		if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
-			t.Fatal(err)
-		}
-
-		for k, v := range ss {
-			t.Run(k, func(t *testing.T) {
-				kv := llm.KV{"tokenizer.chat_template": v}
-				s := kv.ChatTemplate()
-				r, err := NamedTemplate(s)
-				if err != nil {
-					t.Fatal(err)
-				}
-
-				if r.Name != k {
-					t.Errorf("expected %q, got %q", k, r.Name)
-				}
-
-				var b bytes.Buffer
-				if _, err := io.Copy(&b, r.Reader()); err != nil {
-					t.Fatal(err)
-				}
-
-				tmpl, err := template.New(s).Parse(b.String())
-				if err != nil {
-					t.Fatal(err)
-				}
-
-				if tmpl.Tree.Root.String() == "" {
-					t.Errorf("empty %s template", k)
-				}
-			})
-		}
-	}
-}

+ 0 - 55
types/model/name.go

@@ -4,7 +4,6 @@ package model
 
 import (
 	"cmp"
-	"encoding/hex"
 	"errors"
 	"fmt"
 	"log/slog"
@@ -371,57 +370,3 @@ func cutPromised(s, sep string) (before, after string, ok bool) {
 	}
 	return cmp.Or(before, MissingPart), cmp.Or(after, MissingPart), true
 }
-
-type DigestType byte
-
-const (
-	DigestTypeInvalid DigestType = iota
-	DigestTypeSHA256
-)
-
-func (t DigestType) String() string {
-	switch t {
-	case DigestTypeSHA256:
-		return "sha256"
-	default:
-		return "invalid"
-	}
-}
-
-type Digest struct {
-	Type DigestType
-	Sum  [32]byte
-}
-
-func ParseDigest(s string) (Digest, error) {
-	i := strings.IndexAny(s, "-:")
-	if i < 0 {
-		return Digest{}, fmt.Errorf("invalid digest %q", s)
-	}
-	typ, encSum := s[:i], s[i+1:]
-	if typ != "sha256" {
-		return Digest{}, fmt.Errorf("unsupported digest type %q", typ)
-	}
-	d := Digest{
-		Type: DigestTypeSHA256,
-	}
-	n, err := hex.Decode(d.Sum[:], []byte(encSum))
-	if err != nil {
-		return Digest{}, err
-	}
-	if n != 32 {
-		return Digest{}, fmt.Errorf("digest %q decoded to %d bytes; want 32", encSum, n)
-	}
-	return d, nil
-}
-
-func (d Digest) String() string {
-	if d.Type == DigestTypeInvalid {
-		return ""
-	}
-	return fmt.Sprintf("sha256-%x", d.Sum)
-}
-
-func (d Digest) IsValid() bool {
-	return d.Type != DigestTypeInvalid
-}

+ 0 - 34
types/model/name_test.go

@@ -284,40 +284,6 @@ func TestFilepathAllocs(t *testing.T) {
 	}
 }
 
-const (
-	validSha256    = "sha256-1000000000000000000000000000000000000000000000000000000000000000"
-	validSha256Old = "sha256:1000000000000000000000000000000000000000000000000000000000000000"
-)
-
-func TestParseDigest(t *testing.T) {
-	cases := []struct {
-		in   string
-		want string
-	}{
-		{"", ""},           // empty
-		{"sha123-12", ""},  // invalid type
-		{"sha256-", ""},    // invalid sum
-		{"sha256-123", ""}, // invalid odd length sum
-
-		{validSha256, validSha256},
-		{validSha256Old, validSha256},
-	}
-	for _, tt := range cases {
-		t.Run(tt.in, func(t *testing.T) {
-			got, err := ParseDigest(tt.in)
-			if err != nil {
-				if tt.want != "" {
-					t.Errorf("parseDigest(%q) = %v; want %v", tt.in, err, tt.want)
-				}
-				return
-			}
-			if got.String() != tt.want {
-				t.Errorf("parseDigest(%q).String() = %q; want %q", tt.in, got, tt.want)
-			}
-		})
-	}
-}
-
 func TestParseNameFromFilepath(t *testing.T) {
 	cases := map[string]Name{
 		filepath.Join("host", "namespace", "model", "tag"):      {Host: "host", Namespace: "namespace", Model: "model", Tag: "tag"},

+ 34 - 0
util/bufioutil/buffer_seeker.go

@@ -0,0 +1,34 @@
+package bufioutil
+
+import (
+	"bufio"
+	"io"
+)
+
+type BufferedSeeker struct {
+	rs io.ReadSeeker
+	br *bufio.Reader
+}
+
+func NewBufferedSeeker(rs io.ReadSeeker, size int) *BufferedSeeker {
+	return &BufferedSeeker{
+		rs: rs,
+		br: bufio.NewReaderSize(rs, size),
+	}
+}
+
+func (b *BufferedSeeker) Read(p []byte) (int, error) {
+	return b.br.Read(p)
+}
+
+func (b *BufferedSeeker) Seek(offset int64, whence int) (int64, error) {
+	if whence == io.SeekCurrent {
+		offset -= int64(b.br.Buffered())
+	}
+	n, err := b.rs.Seek(offset, whence)
+	if err != nil {
+		return 0, err
+	}
+	b.br.Reset(b.rs)
+	return n, nil
+}

+ 64 - 0
util/bufioutil/buffer_seeker_test.go

@@ -0,0 +1,64 @@
+package bufioutil
+
+import (
+	"bytes"
+	"io"
+	"strings"
+	"testing"
+)
+
+func TestBufferedSeeker(t *testing.T) {
+	const alphabet = "abcdefghijklmnopqrstuvwxyz"
+
+	bs := NewBufferedSeeker(strings.NewReader(alphabet), 0) // minReadBufferSize = 16
+
+	checkRead := func(buf []byte, expected string) {
+		t.Helper()
+		_, err := bs.Read(buf)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if !bytes.Equal(buf, []byte(expected)) {
+			t.Fatalf("expected %s, got %s", expected, buf)
+		}
+	}
+
+	// Read the first 5 bytes
+	buf := make([]byte, 5)
+
+	checkRead(buf, "abcde")
+
+	// Seek back to the beginning
+	_, err := bs.Seek(0, io.SeekStart)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// read 'a'
+	checkRead(buf[:1], "a")
+
+	if bs.br.Buffered() == 0 {
+		t.Fatalf("totally unexpected sanity check failed")
+	}
+
+	// Seek past 'b'
+	_, err = bs.Seek(1, io.SeekCurrent)
+	if err != nil {
+		t.Fatal(err)
+	}
+	checkRead(buf, "cdefg")
+
+	// Seek back to the beginning
+	_, err = bs.Seek(0, io.SeekStart)
+	if err != nil {
+		t.Fatal(err)
+	}
+	checkRead(buf, "abcde")
+
+	// Seek to the end
+	_, err = bs.Seek(-5, io.SeekEnd)
+	if err != nil {
+		t.Fatal(err)
+	}
+	checkRead(buf, "vwxyz")
+}