
Merge branch 'ollama:main' into main

mraiser 1 year ago
parent
commit
4c4c730a0a

+ 77 - 21
.github/workflows/test.yaml

@@ -23,29 +23,72 @@ jobs:
        with:
          go-version: '1.21'
          cache: true
-      - if: ${{ startsWith(matrix.os, 'windows-') }}
-        shell: pwsh
-        run: |
-          $path = vswhere -latest -products * -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath
-          if ($path) {
-              $path = join-path $path 'Common7\Tools\vsdevcmd.bat'
-              if (test-path $path) {
-                  cmd /s /c """$path"" $args && set" | where { $_ -match '(\w+)=(.*)' } | foreach {
-                      echo "$($Matches[1])=$($Matches[2])" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
-                  }
-              }
-          }
-
-          echo "C:\Program Files\Git\usr\bin" | Out-File -FilePath $Env:GITHUB_PATH -Encoding utf8 -Append
      - run: go get ./...
      - run: go generate -x ./...
      - uses: actions/upload-artifact@v4
        with:
          name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
-          path: |
-            llm/llama.cpp/build/**/lib/*
+          path: llm/llama.cpp/build/**/lib/*
+  generate-cuda:
+    strategy:
+      matrix:
+        cuda-version:
+          - '11.8.0'
+    runs-on: ubuntu-latest
+    container: nvidia/cuda:${{ matrix.cuda-version }}-devel-ubuntu20.04
+    steps:
+      - run: |
+          apt-get update && apt-get install -y git build-essential curl
+          curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-linux-x86_64.tar.gz \
+            | tar -zx -C /usr --strip-components 1
+        env:
+          DEBIAN_FRONTEND: noninteractive
+      - uses: actions/checkout@v4
+      - uses: actions/setup-go@v4
+        with:
+          go-version: '1.21'
+          cache: true
+      - run: go get ./...
+      - run: |
+          git config --global --add safe.directory /__w/ollama/ollama
+          go generate -x ./...
+        env:
+          OLLAMA_SKIP_CPU_GENERATE: '1'
+      - uses: actions/upload-artifact@v4
+        with:
+          name: cuda-${{ matrix.cuda-version }}-libraries
+          path: llm/llama.cpp/build/**/lib/*
+  generate-rocm:
+    strategy:
+      matrix:
+        rocm-version:
+          - '5.7.1'
+          - '6.0'
+    runs-on: ubuntu-latest
+    container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
+    steps:
+      - run: |
+          apt-get update && apt-get install -y git build-essential curl rocm-libs
+          curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-linux-x86_64.tar.gz \
+            | tar -zx -C /usr --strip-components 1
+        env:
+          DEBIAN_FRONTEND: noninteractive
+      - uses: actions/checkout@v4
+      - uses: actions/setup-go@v4
+        with:
+          go-version: '1.21'
+          cache: true
+      - run: go get ./...
+      - run: |
+          git config --global --add safe.directory /__w/ollama/ollama
+          go generate -x ./...
+        env:
+          OLLAMA_SKIP_CPU_GENERATE: '1'
+      - uses: actions/upload-artifact@v4
+        with:
+          name: rocm-${{ matrix.rocm-version }}-libraries
+          path: llm/llama.cpp/build/**/lib/*
  lint:
-    needs: generate
    strategy:
      matrix:
        os: [ubuntu-latest, macos-latest, windows-latest]
@@ -69,10 +112,19 @@ jobs:
        with:
          go-version: '1.21'
          cache: false
-      - uses: actions/download-artifact@v4
-        with:
-          name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
-          path: llm/llama.cpp/build
+      - run: |
+          mkdir -p llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/
+          touch llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/stub.so
+        if: ${{ startsWith(matrix.os, 'ubuntu-') }}
+      - run: |
+          mkdir -p llm/llama.cpp/build/darwin/${{ matrix.arch }}/stub/lib/
+          touch llm/llama.cpp/build/darwin/${{ matrix.arch }}/stub/lib/stub.dylib
+          touch llm/llama.cpp/ggml-metal.metal
+        if: ${{ startsWith(matrix.os, 'macos-') }}
+      - run: |
+          mkdir -p llm/llama.cpp/build/windows/${{ matrix.arch }}/stub/lib/
+          touch llm/llama.cpp/build/windows/${{ matrix.arch }}/stub/lib/stub.dll
+        if: ${{ startsWith(matrix.os, 'windows-') }}
      - uses: golangci/golangci-lint-action@v3
  test:
    needs: generate
@@ -104,3 +156,7 @@ jobs:
          path: llm/llama.cpp/build
      - run: go build
      - run: go test -v ./...
+      - uses: actions/upload-artifact@v4
+        with:
+          name: ${{ matrix.os }}-binaries
+          path: ollama

+ 13 - 2
Dockerfile

@@ -109,17 +109,28 @@ ARG CGO_CFLAGS
 RUN go build .
 
 # Runtime stages
-FROM --platform=linux/amd64 rocm/dev-centos-7:6.0-complete as runtime-amd64
+FROM --platform=linux/amd64 ubuntu:22.04 as runtime-amd64
+RUN apt-get update && apt-get install -y ca-certificates
 COPY --from=build-amd64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
 FROM --platform=linux/arm64 ubuntu:22.04 as runtime-arm64
 RUN apt-get update && apt-get install -y ca-certificates
 COPY --from=build-arm64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
 
+# Radeon images are much larger so we keep it distinct from the CPU/CUDA image
+FROM --platform=linux/amd64 rocm/dev-centos-7:5.7.1-complete as runtime-rocm
+RUN update-pciids
+COPY --from=build-amd64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
+EXPOSE 11434
+ENV OLLAMA_HOST 0.0.0.0
+
+ENTRYPOINT ["/bin/ollama"]
+CMD ["serve"]
+
 FROM runtime-$TARGETARCH
 EXPOSE 11434
 ENV OLLAMA_HOST 0.0.0.0
 ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
-ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64:/opt/rocm/lib:
+ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
 ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
 
 ENTRYPOINT ["/bin/ollama"]

+ 27 - 17
api/types.go

@@ -34,24 +34,26 @@ func (e StatusError) Error() string {
 type ImageData []byte
 
 type GenerateRequest struct {
-	Model    string      `json:"model"`
-	Prompt   string      `json:"prompt"`
-	System   string      `json:"system"`
-	Template string      `json:"template"`
-	Context  []int       `json:"context,omitempty"`
-	Stream   *bool       `json:"stream,omitempty"`
-	Raw      bool        `json:"raw,omitempty"`
-	Format   string      `json:"format"`
-	Images   []ImageData `json:"images,omitempty"`
+	Model     string      `json:"model"`
+	Prompt    string      `json:"prompt"`
+	System    string      `json:"system"`
+	Template  string      `json:"template"`
+	Context   []int       `json:"context,omitempty"`
+	Stream    *bool       `json:"stream,omitempty"`
+	Raw       bool        `json:"raw,omitempty"`
+	Format    string      `json:"format"`
+	KeepAlive *Duration   `json:"keep_alive,omitempty"`
+	Images    []ImageData `json:"images,omitempty"`
 
 	Options map[string]interface{} `json:"options"`
 }
 
 type ChatRequest struct {
-	Model    string    `json:"model"`
-	Messages []Message `json:"messages"`
-	Stream   *bool     `json:"stream,omitempty"`
-	Format   string    `json:"format"`
+	Model     string    `json:"model"`
+	Messages  []Message `json:"messages"`
+	Stream    *bool     `json:"stream,omitempty"`
+	Format    string    `json:"format"`
+	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
 	Options map[string]interface{} `json:"options"`
 }
@@ -126,8 +128,9 @@ type Runner struct {
 }
 
 type EmbeddingRequest struct {
-	Model  string `json:"model"`
-	Prompt string `json:"prompt"`
+	Model     string    `json:"model"`
+	Prompt    string    `json:"prompt"`
+	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
 	Options map[string]interface{} `json:"options"`
 }
@@ -171,6 +174,7 @@ type ShowResponse struct {
 	Template   string       `json:"template,omitempty"`
 	System     string       `json:"system,omitempty"`
 	Details    ModelDetails `json:"details,omitempty"`
+	Messages   []Message    `json:"messages,omitempty"`
 }
 
 type CopyRequest struct {
@@ -236,6 +240,7 @@ type GenerateResponse struct {
 }
 
 type ModelDetails struct {
+	ParentModel       string   `json:"parent_model"`
 	Format            string   `json:"format"`
 	Family            string   `json:"family"`
 	Families          []string `json:"families"`
@@ -411,14 +416,19 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
 	case float64:
 		if t < 0 {
 			t = math.MaxFloat64
+			d.Duration = time.Duration(t)
+		} else {
+			d.Duration = time.Duration(t * float64(time.Second))
 		}
-
-		d.Duration = time.Duration(t)
 	case string:
 		d.Duration, err = time.ParseDuration(t)
 		if err != nil {
 			return err
 		}
+		if d.Duration < 0 {
+			mf := math.MaxFloat64
+			d.Duration = time.Duration(mf)
+		}
 	}
 
 	return nil
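
Note on the new KeepAlive field: per the Duration unmarshaling above, keep_alive accepts either a number of seconds or a Go duration string, and a negative value keeps the model loaded indefinitely. A minimal sketch of exercising it against a running server (the model name "llama2" is only an assumption for illustration):

    curl http://localhost:11434/api/generate -d '{"model": "llama2", "prompt": "Why is the sky blue?", "keep_alive": "10m"}'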

+ 11 - 9
cmd/cmd.go

@@ -458,15 +458,17 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 type generateContextKey string
 
 type runOptions struct {
-	Model    string
-	Prompt   string
-	Messages []api.Message
-	WordWrap bool
-	Format   string
-	System   string
-	Template string
-	Images   []api.ImageData
-	Options  map[string]interface{}
+	Model       string
+	ParentModel string
+	Prompt      string
+	Messages    []api.Message
+	WordWrap    bool
+	Format      string
+	System      string
+	Template    string
+	Images      []api.ImageData
+	Options     map[string]interface{}
+	MultiModal  bool
 }
 }
 
 type displayResponseState struct {
+ 127 - 24
cmd/interactive.go

@@ -7,12 +7,14 @@ import (
 	"net/http"
 	"net/http"
 	"os"
 	"os"
 	"regexp"
 	"regexp"
+	"sort"
 	"strings"
 	"strings"
 
 
 	"github.com/spf13/cobra"
 	"github.com/spf13/cobra"
 	"golang.org/x/exp/slices"
 	"golang.org/x/exp/slices"
 
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/progress"
 	"github.com/jmorganca/ollama/readline"
 	"github.com/jmorganca/ollama/readline"
 )
 )
 
 
@@ -25,43 +27,75 @@ const (
 	MultilineTemplate
 )
 
-func modelIsMultiModal(cmd *cobra.Command, name string) bool {
-	// get model details
+func loadModel(cmd *cobra.Command, opts *runOptions) error {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
-		fmt.Println("error: couldn't connect to ollama server")
-		return false
+		return err
 	}
 
-	req := api.ShowRequest{Name: name}
-	resp, err := client.Show(cmd.Context(), &req)
+	p := progress.NewProgress(os.Stderr)
+	defer p.StopAndClear()
+
+	spinner := progress.NewSpinner("")
+	p.Add("", spinner)
+
+	showReq := api.ShowRequest{Name: opts.Model}
+	showResp, err := client.Show(cmd.Context(), &showReq)
 	if err != nil {
-		return false
+		return err
 	}
+	opts.MultiModal = slices.Contains(showResp.Details.Families, "clip")
+	opts.ParentModel = showResp.Details.ParentModel
 
 
-	return slices.Contains(resp.Details.Families, "clip")
-}
-
-func generateInteractive(cmd *cobra.Command, opts runOptions) error {
-	multiModal := modelIsMultiModal(cmd, opts.Model)
+	if len(showResp.Messages) > 0 {
+		opts.Messages = append(opts.Messages, showResp.Messages...)
+	}
 
 
-	// load the model
-	loadOpts := runOptions{
+	chatReq := &api.ChatRequest{
 		Model:    opts.Model,
-		Prompt:   "",
 		Messages: []api.Message{},
 	}
-	if _, err := chat(cmd, loadOpts); err != nil {
+	err = client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error {
+		p.StopAndClear()
+		if len(opts.Messages) > 0 {
+			for _, msg := range opts.Messages {
+				switch msg.Role {
+				case "user":
+					fmt.Printf(">>> %s\n", msg.Content)
+				case "assistant":
+					state := &displayResponseState{}
+					displayResponse(msg.Content, opts.WordWrap, state)
+					fmt.Println()
+					fmt.Println()
+				}
+			}
+		}
+		return nil
+	})
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func generateInteractive(cmd *cobra.Command, opts runOptions) error {
+	opts.Messages = make([]api.Message, 0)
+
+	err := loadModel(cmd, &opts)
+	if err != nil {
 		return err
 	}
 
 	usage := func() {
 		fmt.Fprintln(os.Stderr, "Available Commands:")
-		fmt.Fprintln(os.Stderr, "  /set          Set session variables")
-		fmt.Fprintln(os.Stderr, "  /show         Show model information")
-		fmt.Fprintln(os.Stderr, "  /bye          Exit")
-		fmt.Fprintln(os.Stderr, "  /?, /help     Help for a command")
-		fmt.Fprintln(os.Stderr, "  /? shortcuts  Help for keyboard shortcuts")
+		fmt.Fprintln(os.Stderr, "  /set            Set session variables")
+		fmt.Fprintln(os.Stderr, "  /show           Show model information")
+		fmt.Fprintln(os.Stderr, "  /load <model>   Load a session or model")
+		fmt.Fprintln(os.Stderr, "  /save <model>   Save your current session")
+		fmt.Fprintln(os.Stderr, "  /bye            Exit")
+		fmt.Fprintln(os.Stderr, "  /?, /help       Help for a command")
+		fmt.Fprintln(os.Stderr, "  /? shortcuts    Help for keyboard shortcuts")
 		fmt.Fprintln(os.Stderr, "")
 		fmt.Fprintln(os.Stderr, "")
 		fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
 		fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
 		fmt.Fprintln(os.Stderr, "")
 		fmt.Fprintln(os.Stderr, "")
@@ -140,7 +174,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 
 
 	var sb strings.Builder
 	var multiline MultilineState
-	opts.Messages = make([]api.Message, 0)
 
 
 	for {
 		line, err := scanner.Readline()
@@ -203,6 +236,44 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 			if err := ListHandler(cmd, args[1:]); err != nil {
 				return err
 			}
+		case strings.HasPrefix(line, "/load"):
+			args := strings.Fields(line)
+			if len(args) != 2 {
+				fmt.Println("Usage:\n  /load <modelname>")
+				continue
+			}
+			opts.Model = args[1]
+			opts.Messages = []api.Message{}
+			fmt.Printf("Loading model '%s'\n", opts.Model)
+			if err := loadModel(cmd, &opts); err != nil {
+				return err
+			}
+			continue
+		case strings.HasPrefix(line, "/save"):
+			args := strings.Fields(line)
+			if len(args) != 2 {
+				fmt.Println("Usage:\n  /save <modelname>")
+				continue
+			}
+
+			client, err := api.ClientFromEnvironment()
+			if err != nil {
+				fmt.Println("error: couldn't connect to ollama server")
+				return err
+			}
+
+			req := &api.CreateRequest{
+				Name:      args[1],
+				Modelfile: buildModelfile(opts),
+			}
+			fn := func(resp api.ProgressResponse) error { return nil }
+			err = client.Create(cmd.Context(), req, fn)
+			if err != nil {
+				fmt.Println("error: couldn't save model")
+				return err
+			}
+			fmt.Printf("Created new model '%s'\n", args[1])
+			continue
 		case strings.HasPrefix(line, "/set"):
 		case strings.HasPrefix(line, "/set"):
 			args := strings.Fields(line)
 			args := strings.Fields(line)
 			if len(args) > 1 {
 			if len(args) > 1 {
@@ -389,7 +460,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 			args := strings.Fields(line)
 			isFile := false
 
-			if multiModal {
+			if opts.MultiModal {
 				for _, f := range extractFileNames(line) {
 					if strings.HasPrefix(f, args[0]) {
 						isFile = true
@@ -411,7 +482,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 		if sb.Len() > 0 && multiline == MultilineNone {
 			newMessage := api.Message{Role: "user", Content: sb.String()}
 
-			if multiModal {
+			if opts.MultiModal {
 				msg, images, err := extractFileData(sb.String())
 				if err != nil {
 					return err
@@ -454,6 +525,38 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 	}
 }
 
+func buildModelfile(opts runOptions) string {
+	var mf strings.Builder
+	model := opts.ParentModel
+	if model == "" {
+		model = opts.Model
+	}
+	fmt.Fprintf(&mf, "FROM %s\n", model)
+	if opts.System != "" {
+		fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System)
+	}
+
+	if opts.Template != "" {
+		fmt.Fprintf(&mf, "TEMPLATE \"\"\"%s\"\"\"\n", opts.Template)
+	}
+
+	keys := make([]string, 0)
+	for k := range opts.Options {
+		keys = append(keys, k)
+	}
+	sort.Strings(keys)
+	for _, k := range keys {
+		fmt.Fprintf(&mf, "PARAMETER %s %v\n", k, opts.Options[k])
+	}
+	fmt.Fprintln(&mf)
+
+	for _, msg := range opts.Messages {
+		fmt.Fprintf(&mf, "MESSAGE %s \"\"\"%s\"\"\"\n", msg.Role, msg.Content)
+	}
+
+	return mf.String()
+}
+
 func normalizeFilePath(fp string) string {
 	// Define a map of escaped characters and their replacements
 	replacements := map[string]string{

+ 65 - 0
cmd/interactive_test.go

@@ -1,9 +1,13 @@
 package cmd
 
 import (
+	"bytes"
 	"testing"
 	"testing"
+	"text/template"
 
 
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
+
+	"github.com/jmorganca/ollama/api"
 )
 
 func TestExtractFilenames(t *testing.T) {
@@ -49,3 +53,64 @@ d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8
 	assert.Contains(t, res[9], "ten.svg")
 	assert.Contains(t, res[9], "ten.svg")
 	assert.Contains(t, res[9], "E:")
 	assert.Contains(t, res[9], "E:")
 }
 }
+
+func TestModelfileBuilder(t *testing.T) {
+	opts := runOptions{
+		Model:    "hork",
+		System:   "You are part horse and part shark, but all hork. Do horklike things",
+		Template: "This is a template.",
+		Messages: []api.Message{
+			{Role: "user", Content: "Hey there hork!"},
+			{Role: "assistant", Content: "Yes it is true, I am half horse, half shark."},
+		},
+		Options: map[string]interface{}{},
+	}
+
+	opts.Options["temperature"] = 0.9
+	opts.Options["seed"] = 42
+	opts.Options["penalize_newline"] = false
+	opts.Options["stop"] = []string{"hi", "there"}
+
+	mf := buildModelfile(opts)
+	expectedModelfile := `FROM {{.Model}}
+SYSTEM """{{.System}}"""
+TEMPLATE """{{.Template}}"""
+PARAMETER penalize_newline false
+PARAMETER seed 42
+PARAMETER stop [hi there]
+PARAMETER temperature 0.9
+
+MESSAGE user """Hey there hork!"""
+MESSAGE assistant """Yes it is true, I am half horse, half shark."""
+`
+
+	tmpl, err := template.New("").Parse(expectedModelfile)
+	assert.Nil(t, err)
+
+	var buf bytes.Buffer
+	err = tmpl.Execute(&buf, opts)
+	assert.Nil(t, err)
+	assert.Equal(t, buf.String(), mf)
+
+	opts.ParentModel = "horseshark"
+	mf = buildModelfile(opts)
+	expectedModelfile = `FROM {{.ParentModel}}
+SYSTEM """{{.System}}"""
+TEMPLATE """{{.Template}}"""
+PARAMETER penalize_newline false
+PARAMETER seed 42
+PARAMETER stop [hi there]
+PARAMETER temperature 0.9
+
+MESSAGE user """Hey there hork!"""
+MESSAGE assistant """Yes it is true, I am half horse, half shark."""
+`
+
+	tmpl, err = template.New("").Parse(expectedModelfile)
+	assert.Nil(t, err)
+
+	var parentBuf bytes.Buffer
+	err = tmpl.Execute(&parentBuf, opts)
+	assert.Nil(t, err)
+	assert.Equal(t, parentBuf.String(), mf)
+}

+ 2 - 1
docs/development.md

@@ -50,7 +50,8 @@ development and runtime packages.
 Typically the build scripts will auto-detect CUDA, however, if your Linux distro
 or installation approach uses unusual paths, you can specify the location by
 specifying an environment variable `CUDA_LIB_DIR` to the location of the shared
-libraries, and `CUDACXX` to the location of the nvcc compiler.
+libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize
+the set of target CUDA architectures by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70").
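
For example (a sketch, assuming a working local CUDA toolchain), the variable can simply be set for the generate step described below:

    CMAKE_CUDA_ARCHITECTURES="50;60;70" go generate ./...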
 
 
 Then generate dependencies:
 

+ 15 - 0
docs/modelfile.md

@@ -19,6 +19,7 @@ A model file is the blueprint to create and share models with Ollama.
   - [SYSTEM](#system)
   - [ADAPTER](#adapter)
   - [LICENSE](#license)
+  - [MESSAGE](#message)
 - [Notes](#notes)
 
 ## Format
@@ -38,6 +39,7 @@ INSTRUCTION arguments
 | [`SYSTEM`](#system)                 | Specifies the system message that will be set in the template. |
 | [`ADAPTER`](#adapter)               | Defines the (Q)LoRA adapters to apply to the model.            |
 | [`LICENSE`](#license)               | Specifies the legal license.                                   |
+| [`MESSAGE`](#message)               | Specify message history.                                       |
 
 
 ## Examples
 
@@ -205,6 +207,19 @@ LICENSE """
 """
 """
 ```
 ```
 
 
+### MESSAGE
+
+The `MESSAGE` instruction allows you to specify a message history for the model to use when responding:
+
+```modelfile
+MESSAGE user Is Toronto in Canada?
+MESSAGE assistant yes
+MESSAGE user Is Sacramento in Canada?
+MESSAGE assistant no
+MESSAGE user Is Ontario in Canada?
+MESSAGE assistant yes
+```
+
 ## Notes
 
 - the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments.

+ 37 - 7
gpu/gpu.go

@@ -16,6 +16,7 @@ import (
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
 	"runtime"
 	"runtime"
+	"strconv"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 	"unsafe"
 	"unsafe"
@@ -29,8 +30,8 @@ type handles struct {
 var gpuMutex sync.Mutex
 var gpuHandles *handles = nil
 
-// With our current CUDA compile flags, 5.2 and older will not work properly
-const CudaComputeMajorMin = 6
+// With our current CUDA compile flags, older than 5.0 will not work properly
+var CudaComputeMin = [2]C.int{5, 0}
 
 
 // Possible locations for the nvidia-ml library
 var CudaLinuxGlobs = []string{
@@ -121,9 +122,15 @@ func GetGPUInfo() GpuInfo {
 		initGPUHandles()
 	}
 
+	// All our GPU builds have AVX enabled, so fallback to CPU if we don't detect at least AVX
+	cpuVariant := GetCPUVariant()
+	if cpuVariant == "" {
+		slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
+	}
+
 	var memInfo C.mem_info_t
 	resp := GpuInfo{}
-	if gpuHandles.cuda != nil {
+	if gpuHandles.cuda != nil && cpuVariant != "" {
 		C.cuda_check_vram(*gpuHandles.cuda, &memInfo)
 		if memInfo.err != nil {
 			slog.Info(fmt.Sprintf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err)))
@@ -135,19 +142,40 @@ func GetGPUInfo() GpuInfo {
 			if cc.err != nil {
 				slog.Info(fmt.Sprintf("error looking up CUDA GPU compute capability: %s", C.GoString(cc.err)))
 				C.free(unsafe.Pointer(cc.err))
-			} else if cc.major >= CudaComputeMajorMin {
+			} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
 				slog.Info(fmt.Sprintf("CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
 				slog.Info(fmt.Sprintf("CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
 				resp.Library = "cuda"
 				resp.Library = "cuda"
 			} else {
 			} else {
 				slog.Info(fmt.Sprintf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
 				slog.Info(fmt.Sprintf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
 			}
 			}
 		}
 		}
-	} else if gpuHandles.rocm != nil {
+	} else if gpuHandles.rocm != nil && cpuVariant != "" {
 		C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
 		if memInfo.err != nil {
 			slog.Info(fmt.Sprintf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err)))
 			C.free(unsafe.Pointer(memInfo.err))
+		} else if memInfo.igpu_index >= 0 && memInfo.count == 1 {
+			// Only one GPU detected and it appears to be an integrated GPU - skip it
+			slog.Info("ROCm unsupported integrated GPU detected")
 		} else {
+			if memInfo.igpu_index >= 0 {
+				// We have multiple GPUs reported, and one of them is an integrated GPU
+				// so we have to set the env var to bypass it
+				// If the user has specified their own ROCR_VISIBLE_DEVICES, don't clobber it
+				val := os.Getenv("ROCR_VISIBLE_DEVICES")
+				if val == "" {
+					devices := []string{}
+					for i := 0; i < int(memInfo.count); i++ {
+						if i == int(memInfo.igpu_index) {
+							continue
+						}
+						devices = append(devices, strconv.Itoa(i))
+					}
+					val = strings.Join(devices, ",")
+					os.Setenv("ROCR_VISIBLE_DEVICES", val)
+				}
+				slog.Info(fmt.Sprintf("ROCm integrated GPU detected - ROCR_VISIBLE_DEVICES=%s", val))
+			}
 			resp.Library = "rocm"
 			resp.Library = "rocm"
 			var version C.rocm_version_resp_t
 			var version C.rocm_version_resp_t
 			C.rocm_get_version(*gpuHandles.rocm, &version)
 			C.rocm_get_version(*gpuHandles.rocm, &version)
@@ -163,7 +191,7 @@ func GetGPUInfo() GpuInfo {
 	if resp.Library == "" {
 	if resp.Library == "" {
 		C.cpu_check_ram(&memInfo)
 		C.cpu_check_ram(&memInfo)
 		resp.Library = "cpu"
 		resp.Library = "cpu"
-		resp.Variant = GetCPUVariant()
+		resp.Variant = cpuVariant
 	}
 	if memInfo.err != nil {
 		slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
@@ -199,7 +227,9 @@ func CheckVRAM() (int64, error) {
 		if overhead < gpus*1024*1024*1024 {
 			overhead = gpus * 1024 * 1024 * 1024
 		}
-		return int64(gpuInfo.FreeMemory - overhead), nil
+		avail := int64(gpuInfo.FreeMemory - overhead)
+		slog.Debug(fmt.Sprintf("%s detected %d devices with %dM available memory", gpuInfo.Library, gpuInfo.DeviceCount, avail/1024/1024))
+		return avail, nil
 	}
 
 	return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation

+ 1 - 0
gpu/gpu_info.h

@@ -42,6 +42,7 @@ typedef struct mem_info {
   uint64_t total;
   uint64_t free;
   unsigned int count;
+  int igpu_index; // If >= 0, we detected an integrated GPU to ignore
   char *err;  // If non-nill, caller responsible for freeing
 } mem_info_t;
 

+ 1 - 0
gpu/gpu_info_cuda.c

@@ -70,6 +70,7 @@ void cuda_init(char *cuda_lib_path, cuda_init_resp_t *resp) {
     resp->ch.handle = NULL;
     snprintf(buf, buflen, "nvml vram init failure: %d", ret);
     resp->err = strdup(buf);
+    return;
   }
 
   // Report driver version if we're in verbose mode, ignore errors

+ 11 - 4
gpu/gpu_info_rocm.c

@@ -77,6 +77,7 @@ void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) {
 
 
 void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) {
   resp->err = NULL;
+  resp->igpu_index = -1;
   uint64_t totalMem = 0;
   uint64_t usedMem = 0;
   rsmi_status_t ret;
@@ -162,8 +163,14 @@ void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) {
     }
     LOG(h.verbose, "[%d] ROCm totalMem %ld\n", i, totalMem);
     LOG(h.verbose, "[%d] ROCm usedMem %ld\n", i, usedMem);
-    resp->total += totalMem;
-    resp->free += totalMem - usedMem;
+    if (totalMem < 1024 * 1024 * 1024) {
+      // Do not add up integrated GPU memory capacity, it's a bogus 512M, and actually uses system memory
+      LOG(h.verbose, "[%d] ROCm integrated GPU\n", i);
+      resp->igpu_index = i;
+    } else {
+      resp->total += totalMem;
+      resp->free += totalMem - usedMem;
+    }
   }
 }
 
@@ -171,7 +178,7 @@ void rocm_get_version(rocm_handle_t h, rocm_version_resp_t *resp) {
   const int buflen = 256;
   char buf[buflen + 1];
   if (h.handle == NULL) {
-    resp->str = strdup("nvml handle not initialized");
+    resp->str = strdup("rocm handle not initialized");
     resp->status = 1;
     return;
   }
@@ -188,4 +195,4 @@ void rocm_get_version(rocm_handle_t h, rocm_version_resp_t *resp) {
   resp->str = strdup(buf);
 }
 
-#endif  // __APPLE__
+#endif  // __APPLE__

+ 1 - 0
llm/dyn_ext_server.go

@@ -190,6 +190,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
 		"seed":              predict.Options.Seed,
 		"seed":              predict.Options.Seed,
 		"stop":              predict.Options.Stop,
 		"stop":              predict.Options.Stop,
 		"image_data":        imageData,
 		"image_data":        imageData,
+		"cache_prompt":      true,
 	}
 
 	if predict.Format == "json" {

+ 14 - 0
llm/generate/gen_common.sh

@@ -39,6 +39,9 @@ init_vars() {
     *)
         ;;
     esac
+    if [ -z "${CMAKE_CUDA_ARCHITECTURES}" ] ; then 
+        CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
+    fi
 }
 
 git_module_setup() {
@@ -61,6 +64,17 @@ apply_patches() {
     if ! grep ollama ${LLAMACPP_DIR}/examples/server/CMakeLists.txt; then
         echo 'include (../../../ext_server/CMakeLists.txt) # ollama' >>${LLAMACPP_DIR}/examples/server/CMakeLists.txt
     fi
+
+    # apply temporary patches until fix is upstream
+    for patch in ../patches/*.diff; do
+        for file in $(grep "^+++ " ${patch} | cut -f2 -d' ' | cut -f2- -d/); do
+            (cd ${LLAMACPP_DIR}; git checkout ${file})
+        done
+    done
+    for patch in ../patches/*.diff; do
+        (cd ${LLAMACPP_DIR} && git apply ${patch})
+    done
+
     # Avoid duplicate main symbols when we link into the cgo binary
     sed -e 's/int main(/int __main(/g' <${LLAMACPP_DIR}/examples/server/server.cpp >${LLAMACPP_DIR}/examples/server/server.cpp.tmp &&
         mv ${LLAMACPP_DIR}/examples/server/server.cpp.tmp ${LLAMACPP_DIR}/examples/server/server.cpp

+ 1 - 1
llm/generate/gen_linux.sh

@@ -140,7 +140,7 @@ if [ -d "${CUDA_LIB_DIR}" ]; then
     if [ -n "${CUDA_MAJOR}" ]; then
     if [ -n "${CUDA_MAJOR}" ]; then
         CUDA_VARIANT=_v${CUDA_MAJOR}
         CUDA_VARIANT=_v${CUDA_MAJOR}
     fi
     fi
-    CMAKE_DEFS="-DLLAMA_CUBLAS=on ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS}"
+    CMAKE_DEFS="-DLLAMA_CUBLAS=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS}"
     BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cuda${CUDA_VARIANT}"
     EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda"
     build

+ 29 - 1
llm/generate/gen_windows.ps1

@@ -25,6 +25,11 @@ function init_vars {
     }
     $script:GZIP=(get-command -ea 'silentlycontinue' gzip).path
     $script:DUMPBIN=(get-command -ea 'silentlycontinue' dumpbin).path
+    if ($null -eq $env:CMAKE_CUDA_ARCHITECTURES) {
+        $script:CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
+    } else {
+        $script:CMAKE_CUDA_ARCHITECTURES=$env:CMAKE_CUDA_ARCHITECTURES
+    }
 }
 
 function git_module_setup {
@@ -40,6 +45,29 @@ function apply_patches {
     if (!(Select-String -Path "${script:llamacppDir}/examples/server/CMakeLists.txt" -Pattern 'ollama')) {
         Add-Content -Path "${script:llamacppDir}/examples/server/CMakeLists.txt" -Value 'include (../../../ext_server/CMakeLists.txt) # ollama'
     }
+
+    # Apply temporary patches until fix is upstream
+    $patches = Get-ChildItem "../patches/*.diff"
+    foreach ($patch in $patches) {
+        # Extract file paths from the patch file
+        $filePaths = Get-Content $patch.FullName | Where-Object { $_ -match '^\+\+\+ ' } | ForEach-Object {
+            $parts = $_ -split ' '
+            ($parts[1] -split '/', 2)[1]
+        }
+
+        # Checkout each file
+        foreach ($file in $filePaths) {
+            Set-Location -Path ${script:llamacppDir}
+            git checkout $file
+        }
+    }
+
+    # Apply each patch
+    foreach ($patch in $patches) {
+        Set-Location -Path ${script:llamacppDir}
+        git apply $patch.FullName
+    }
+
     # Avoid duplicate main symbols when we link into the cgo binary
     $content = Get-Content -Path "${script:llamacppDir}/examples/server/server.cpp"
     $content = $content -replace 'int main\(', 'int __main('
@@ -128,7 +156,7 @@ if ($null -ne $script:CUDA_LIB_DIR) {
     }
     init_vars
     $script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
-    $script:cmakeDefs += @("-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on")
+    $script:cmakeDefs += @("-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
     build
     install
     cp "${script:CUDA_LIB_DIR}/cudart64_*.dll" "${script:buildDir}/lib"

+ 61 - 54
llm/gguf.go

@@ -69,12 +69,65 @@ type tensor struct {
 	name   string
 	kind   uint32
 	offset uint64
-	size   uint64
 
 
 	// shape is the number of elements in each dimension
 	shape [4]uint64
 }
 
 
+func (t tensor) blockSize() uint64 {
+	switch {
+	case t.kind < 2:
+		return 1
+	case t.kind < 10:
+		return 32
+	default:
+		return 256
+	}
+}
+
+func (t tensor) typeSize() uint64 {
+	blockSize := t.blockSize()
+
+	switch t.kind {
+	case 0: // FP32
+		return 4
+	case 1: // FP16
+		return 2
+	case 2: // Q4_0
+		return 2 + blockSize/2
+	case 3: // Q4_1
+		return 2 + 2 + blockSize/2
+	case 6: // Q5_0
+		return 2 + 4 + blockSize/2
+	case 7: // Q5_1
+		return 2 + 2 + 4 + blockSize/2
+	case 8: // Q8_0
+		return 2 + blockSize
+	case 9: // Q8_1
+		return 4 + 4 + blockSize
+	case 10: // Q2_K
+		return blockSize/16 + blockSize/4 + 2 + 2
+	case 11: // Q3_K
+		return blockSize/8 + blockSize/4 + 12 + 2
+	case 12: // Q4_K
+		return 2 + 2 + 12 + blockSize/2
+	case 13: // Q5_K
+		return 2 + 2 + 12 + blockSize/8 + blockSize/2
+	case 14: // Q6_K
+		return blockSize/2 + blockSize/4 + blockSize/16 + 2
+	default:
+		return 0
+	}
+}
+
+func (t tensor) parameters() uint64 {
+	return t.shape[0] * t.shape[1] * t.shape[2] * t.shape[3]
+}
+
+func (t tensor) size() uint64 {
+	return t.parameters() * t.typeSize() / t.blockSize()
+}
+
 type ggufModel struct {
 	*containerGGUF
 
@@ -201,61 +254,15 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error {
 			shape[i] = llm.readU64(rso)
 		}
 
-		kind := llm.readU32(rso)
-		offset := llm.readU64(rso)
-
-		var blockSize uint64
-		switch {
-		case kind < 2:
-			blockSize = 1
-		case kind < 10:
-			blockSize = 32
-		default:
-			blockSize = 256
-		}
-
-		var typeSize uint64
-		switch kind {
-		case 0: // FP32
-			typeSize = 4
-		case 1: // FP16
-			typeSize = 2
-		case 2: // Q4_0
-			typeSize = 2 + blockSize/2
-		case 3: // Q4_1
-			typeSize = 2 + 2 + blockSize/2
-		case 6: // Q5_0
-			typeSize = 2 + 4 + blockSize/2
-		case 7: // Q5_1
-			typeSize = 2 + 2 + 4 + blockSize/2
-		case 8: // Q8_0
-			typeSize = 2 + blockSize
-		case 9: // Q8_1
-			typeSize = 4 + 4 + blockSize
-		case 10: // Q2_K
-			typeSize = blockSize/16 + blockSize/4 + 2 + 2
-		case 11: // Q3_K
-			typeSize = blockSize/8 + blockSize/4 + 12 + 2
-		case 12: // Q4_K
-			typeSize = 2 + 2 + 12 + blockSize/2
-		case 13: // Q5_K
-			typeSize = 2 + 2 + 12 + blockSize/8 + blockSize/2
-		case 14: // Q6_K
-			typeSize = blockSize/2 + blockSize/4 + blockSize/16 + 2
-		}
-
-		parameters := shape[0] * shape[1] * shape[2] * shape[3]
-		size := parameters * typeSize / blockSize
-
-		llm.tensors = append(llm.tensors, tensor{
+		tensor := tensor{
 			name:   name,
-			kind:   kind,
-			offset: offset,
-			size:   size,
+			kind:   llm.readU32(rso),
+			offset: llm.readU64(rso),
 			shape:  shape,
-		})
+		}
 
 
-		llm.parameters += parameters
+		llm.tensors = append(llm.tensors, tensor)
+		llm.parameters += tensor.parameters()
 	}
 
 	alignment, ok := llm.kv["general.alignment"].(uint32)
@@ -265,7 +272,7 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error {
 
 
 	rso.Seek(int64(alignment)-rso.offset%int64(alignment), io.SeekCurrent)
 	for _, tensor := range llm.tensors {
-		padded := (int64(tensor.size) + int64(alignment) - 1) & ^(int64(alignment) - 1)
+		padded := (int64(tensor.size()) + int64(alignment) - 1) & ^(int64(alignment) - 1)
 		rso.Seek(padded, io.SeekCurrent)
 	}
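
As a worked example of the refactored calculation (size = parameters × typeSize / blockSize): a hypothetical 4096×4096 Q4_0 tensor has blockSize 32 and typeSize 2 + 32/2 = 18 bytes, giving 4096 × 4096 × 18 / 32 = 9,437,184 bytes, the same result the old inline code produced.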
 
 

+ 1 - 1
llm/llama.cpp

@@ -1 +1 @@
-Subproject commit 011e8ec577fd135cbc02993d3ea9840c516d6a1c
+Subproject commit cd4fddb29f81d6a1f6d51a0c016bc6b486d68def

+ 30 - 0
llm/patches/01-cache.diff

@@ -0,0 +1,30 @@
+diff --git a/examples/server/server.cpp b/examples/server/server.cpp
+index 0462fbd2..4fa7b57f 100644
+--- a/examples/server/server.cpp
++++ b/examples/server/server.cpp
+@@ -1857,12 +1857,6 @@ struct llama_server_context
+                         LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
+                     }
+ 
+-                    LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past);
+-
+-                    llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1);
+-
+-                    slot.cache_tokens = prompt_tokens;
+-
+                     if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0)
+                     {
+                         // we have to evaluate at least 1 token to generate logits.
+@@ -1870,6 +1864,12 @@ struct llama_server_context
+                         slot.n_past--;
+                     }
+ 
++                    LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past);
++
++                    llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1);
++
++                    slot.cache_tokens = prompt_tokens;
++
+                     LOG_VERBOSE("prompt ingested", {
+                                                     {"n_past", slot.n_past},
+                                                     {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},

+ 11 - 0
parser/parser.go

@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"log/slog"
 	"log/slog"
+	"slices"
 )
 
 type Command struct {
@@ -56,6 +57,16 @@ func Parse(reader io.Reader) ([]Command, error) {
 			command.Args = string(bytes.TrimSpace(fields[1]))
 		case "EMBED":
 			return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
+		case "MESSAGE":
+			command.Name = string(bytes.ToLower(fields[0]))
+			fields = bytes.SplitN(fields[1], []byte(" "), 2)
+			if len(fields) < 2 {
+				return nil, fmt.Errorf("should be in the format <role> <message>")
+			}
+			if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
+				return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
+			}
+			command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
 		default:
 			if !bytes.HasPrefix(fields[0], []byte("#")) {
 				// log a warning for unknown commands

+ 35 - 0
parser/parser_test.go

@@ -61,3 +61,38 @@ PARAMETER param1
 	assert.ErrorContains(t, err, "missing value for [param1]")
 	assert.ErrorContains(t, err, "missing value for [param1]")
 
 
 }
 }
+
+func Test_Parser_Messages(t *testing.T) {
+
+	input := `
+FROM foo
+MESSAGE system You are a Parser. Always Parse things.
+MESSAGE user Hey there!
+MESSAGE assistant Hello, I want to parse all the things!
+`
+
+	reader := strings.NewReader(input)
+	commands, err := Parse(reader)
+	assert.Nil(t, err)
+
+	expectedCommands := []Command{
+		{Name: "model", Args: "foo"},
+		{Name: "message", Args: "system: You are a Parser. Always Parse things."},
+		{Name: "message", Args: "user: Hey there!"},
+		{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
+	}
+
+	assert.Equal(t, expectedCommands, commands)
+}
+
+func Test_Parser_Messages_BadRole(t *testing.T) {
+
+	input := `
+FROM foo
+MESSAGE badguy I'm a bad guy!
+`
+
+	reader := strings.NewReader(input)
+	_, err := Parse(reader)
+	assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"")
+}

+ 10 - 0
scripts/build_docker.sh

@@ -13,3 +13,13 @@ docker build \
     -f Dockerfile \
     -t ollama/ollama:$VERSION \
     .
+
+docker build \
+    --load \
+    --platform=linux/amd64 \
+    --build-arg=VERSION \
+    --build-arg=GOFLAGS \
+    --target runtime-rocm \
+    -f Dockerfile \
+    -t ollama/ollama:$VERSION-rocm \
+    .
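
The extra build produces a separate ollama/ollama:$VERSION-rocm tag (kept apart because the ROCm layers are much larger). A hedged usage sketch, using the standard ROCm container device flags rather than anything this script configures:

    docker run -d --device /dev/kfd --device /dev/dri -v ollama:/root/.ollama -p 11434:11434 ollama/ollama:$VERSION-rocm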

+ 74 - 43
server/download.go

@@ -25,6 +25,11 @@ import (
 	"github.com/jmorganca/ollama/format"
 	"github.com/jmorganca/ollama/format"
 )
 )
 
 
+const maxRetries = 6
+
+var errMaxRetriesExceeded = errors.New("max retries exceeded")
+var errPartStalled = errors.New("part stalled")
+
 var blobDownloadManager sync.Map
 
 type blobDownload struct {
@@ -44,10 +49,11 @@ type blobDownload struct {
 }
 
 type blobDownloadPart struct {
-	N         int
-	Offset    int64
-	Size      int64
-	Completed int64
+	N           int
+	Offset      int64
+	Size        int64
+	Completed   int64
+	lastUpdated time.Time
 
 
 	*blobDownload `json:"-"`
 	*blobDownload `json:"-"`
 }
 }
@@ -72,6 +78,13 @@ func (p *blobDownloadPart) StopsAt() int64 {
 	return p.Offset + p.Size
 }
 
+func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
+	n = len(b)
+	p.blobDownload.Completed.Add(int64(n))
+	p.lastUpdated = time.Now()
+	return n, nil
+}
+
 func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
 	partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
 	if err != nil {
@@ -157,6 +170,9 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
 				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
 					// return immediately if the context is canceled or the device is out of space
 					return err
+				case errors.Is(err, errPartStalled):
+					try--
+					continue
 				case err != nil:
 					sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
 					slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
@@ -195,28 +211,54 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
 }
 
 func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
-	headers := make(http.Header)
-	headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
-	resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
-	if err != nil {
-		return err
-	}
-	defer resp.Body.Close()
+	g, ctx := errgroup.WithContext(ctx)
+	g.Go(func() error {
+		headers := make(http.Header)
+		headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
+		resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
+		if err != nil {
+			return err
+		}
+		defer resp.Body.Close()
 
 
-	n, err := io.Copy(w, io.TeeReader(resp.Body, b))
-	if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
-		// rollback progress
-		b.Completed.Add(-n)
-		return err
-	}
+		n, err := io.Copy(w, io.TeeReader(resp.Body, part))
+		if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
+			// rollback progress
+			b.Completed.Add(-n)
+			return err
+		}
 
 
-	part.Completed += n
-	if err := b.writePart(part.Name(), part); err != nil {
+		part.Completed += n
+		if err := b.writePart(part.Name(), part); err != nil {
+			return err
+		}
+
+		// return nil or context.Canceled or UnexpectedEOF (resumable)
 		return err
 		return err
-	}
+	})
+
+	g.Go(func() error {
+		ticker := time.NewTicker(time.Second)
+		for {
+			select {
+			case <-ticker.C:
+				if part.Completed >= part.Size {
+					return nil
+				}
+
+				if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second {
+					slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
+					// reset last updated
+					part.lastUpdated = time.Time{}
+					return errPartStalled
+				}
+			case <-ctx.Done():
+				return ctx.Err()
+			}
+		}
+	})
 
 
-	// return nil or context.Canceled or UnexpectedEOF (resumable)
-	return err
+	return g.Wait()
 }
 
 func (b *blobDownload) newPart(offset, size int64) error {
@@ -255,12 +297,6 @@ func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error
 	return json.NewEncoder(partFile).Encode(part)
 }
 
-func (b *blobDownload) Write(p []byte) (n int, err error) {
-	n = len(p)
-	b.Completed.Add(int64(n))
-	return n, nil
-}
-
 func (b *blobDownload) acquire() {
 	b.references.Add(1)
 }
@@ -279,20 +315,19 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
 	for {
 		select {
 		case <-ticker.C:
+			fn(api.ProgressResponse{
+				Status:    fmt.Sprintf("pulling %s", b.Digest[7:19]),
+				Digest:    b.Digest,
+				Total:     b.Total,
+				Completed: b.Completed.Load(),
+			})
+
+			if b.done || b.err != nil {
+				return b.err
+			}
 		case <-ctx.Done():
 			return ctx.Err()
 		}
-
-		fn(api.ProgressResponse{
-			Status:    fmt.Sprintf("pulling %s", b.Digest[7:19]),
-			Digest:    b.Digest,
-			Total:     b.Total,
-			Completed: b.Completed.Load(),
-		})
-
-		if b.done || b.err != nil {
-			return b.err
-		}
 	}
 }
 
@@ -303,10 +338,6 @@ type downloadOpts struct {
 	fn      func(api.ProgressResponse)
 }
 
-const maxRetries = 6
-
-var errMaxRetriesExceeded = errors.New("max retries exceeded")
-
 // downloadBlob downloads a blob from the registry and stores it in the blobs directory
 func downloadBlob(ctx context.Context, opts downloadOpts) error {
 	fp, err := GetBlobsPath(opts.digest)
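
In short: each part now tracks lastUpdated via its own Write method, a watchdog goroutine in downloadChunk returns errPartStalled when no bytes arrive for over five seconds, and the retry loop decrements try for stalls so they do not count against maxRetries.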

+ 47 - 5
server/images.go

@@ -41,7 +41,7 @@ type Model struct {
 	Config         ConfigV2
 	ShortName      string
 	ModelPath      string
-	OriginalModel  string
+	ParentModel    string
 	AdapterPaths   []string
 	ProjectorPaths []string
 	Template       string
@@ -50,6 +50,12 @@ type Model struct {
 	Digest         string
 	Size           int64
 	Options        map[string]interface{}
+	Messages       []Message
+}
+
+type Message struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
 }
 }
 
 type PromptVars struct {
 		switch layer.MediaType {
 		switch layer.MediaType {
 		case "application/vnd.ollama.image.model":
 			model.ModelPath = filename
+			model.ParentModel = layer.From
 		case "application/vnd.ollama.image.embed":
 		case "application/vnd.ollama.image.embed":
 			// Deprecated in versions  > 0.1.2
 			// Deprecated in versions  > 0.1.2
 			// TODO: remove this warning in a future version
 			// TODO: remove this warning in a future version
@@ -374,6 +380,16 @@ func GetModel(name string) (*Model, error) {
 			if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
 			if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
 				return nil, err
 			}
+			msgs, err := os.Open(filename)
+			if err != nil {
+				return nil, err
+			}
+			defer msgs.Close()
+
+			if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
+				return nil, err
+			}
 		case "application/vnd.ollama.image.license":
 		case "application/vnd.ollama.image.license":
 			bts, err := os.ReadFile(filename)
 			bts, err := os.ReadFile(filename)
 			if err != nil {
 			if err != nil {
@@ -428,12 +444,12 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 	}
 	}
 
 	var layers Layers
 
 
 	params := make(map[string][]string)
 	params := make(map[string][]string)
 	fromParams := make(map[string]any)
 
 	for _, c := range commands {
 	for _, c := range commands {
 		mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
 		mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
 
 		switch c.Name {
 			}
 			}
 
 			layers.Replace(layer)
+			messages = append(messages, c.Args)
 		default:
 		default:
 			params[c.Name] = append(params[c.Name], c.Args)
 		}
 	}
 
+		fn(api.ProgressResponse{Status: "creating parameters layer"})
+
+		msgs := make([]api.Message, 0)
+
+		for _, m := range messages {
+			// todo: handle images
+			msg := strings.SplitN(m, ": ", 2)
+			msgs = append(msgs, api.Message{Role: msg[0], Content: msg[1]})
+		}
+
+		var b bytes.Buffer
+		if err := json.NewEncoder(&b).Encode(msgs); err != nil {
+			return err
+		}
+
+		layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
+		if err != nil {
+			return err
+		}
+
+		layers.Replace(layer)
+	}
+
 	if len(params) > 0 {
 		fn(api.ProgressResponse{Status: "creating parameters layer"})
 
@@ -908,8 +950,8 @@ func ShowModelfile(model *Model) (string, error) {
 	mt.Model = model
 	mt.From = model.ModelPath
 
-	if model.OriginalModel != "" {
-		mt.From = model.OriginalModel
+	if model.ParentModel != "" {
+		mt.From = model.ParentModel
 	}
 
 	modelFile := `# Modelfile generated by "ollama show"

+ 37 - 4
server/routes.go

@@ -186,7 +186,13 @@ func GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	sessionDuration := defaultSessionDuration
+	var sessionDuration time.Duration
+	if req.KeepAlive == nil {
+		sessionDuration = defaultSessionDuration
+	} else {
+		sessionDuration = req.KeepAlive.Duration
+	}
+
 	if err := load(c, model, opts, sessionDuration); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -378,7 +384,14 @@ func EmbeddingHandler(c *gin.Context) {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
-	sessionDuration := defaultSessionDuration
+
+	var sessionDuration time.Duration
+	if req.KeepAlive == nil {
+		sessionDuration = defaultSessionDuration
+	} else {
+		sessionDuration = req.KeepAlive.Duration
+	}
+
 	if err := load(c, model, opts, sessionDuration); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -659,6 +672,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	}
 
 	modelDetails := api.ModelDetails{
+		ParentModel:       model.ParentModel,
 		Format:            model.Config.ModelFormat,
 		Family:            model.Config.ModelFamily,
 		Families:          model.Config.ModelFamilies,
@@ -674,11 +688,17 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 		model.Template = req.Template
 	}
 
+	msgs := make([]api.Message, 0)
+	for _, msg := range model.Messages {
+		msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
+	}
+
 	resp := &api.ShowResponse{
 		License:  strings.Join(model.License, "\n"),
 		System:   model.System,
 		Template: model.Template,
 		Details:  modelDetails,
+		Messages: msgs,
 	}
 
 	var params []string
@@ -1067,7 +1087,14 @@ func ChatHandler(c *gin.Context) {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
-	sessionDuration := defaultSessionDuration
+
+	var sessionDuration time.Duration
+	if req.KeepAlive == nil {
+		sessionDuration = defaultSessionDuration
+	} else {
+		sessionDuration = req.KeepAlive.Duration
+	}
+
 	if err := load(c, model, opts, sessionDuration); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -1075,7 +1102,13 @@ func ChatHandler(c *gin.Context) {
 
 
 	// an empty request loads the model
 	if len(req.Messages) == 0 {
-		c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true, Message: api.Message{Role: "assistant"}})
+		resp := api.ChatResponse{
+			CreatedAt: time.Now().UTC(),
+			Model:     req.Model,
+			Done:      true,
+			Message:   api.Message{Role: "assistant"},
+		}
+		c.JSON(http.StatusOK, resp)
 		return
 		return
 	}