Forráskód Böngészése

Merge branch 'ollama:main' into main

mraiser 1 éve
szülő
commit
4c4c730a0a

+ 77 - 21
.github/workflows/test.yaml

@@ -23,29 +23,72 @@ jobs:
         with:
           go-version: '1.21'
           cache: true
-      - if: ${{ startsWith(matrix.os, 'windows-') }}
-        shell: pwsh
-        run: |
-          $path = vswhere -latest -products * -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath
-          if ($path) {
-              $path = join-path $path 'Common7\Tools\vsdevcmd.bat'
-              if (test-path $path) {
-                  cmd /s /c """$path"" $args && set" | where { $_ -match '(\w+)=(.*)' } | foreach {
-                      echo "$($Matches[1])=$($Matches[2])" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
-                  }
-              }
-          }
-
-          echo "C:\Program Files\Git\usr\bin" | Out-File -FilePath $Env:GITHUB_PATH -Encoding utf8 -Append
       - run: go get ./...
       - run: go generate -x ./...
       - uses: actions/upload-artifact@v4
         with:
           name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
-          path: |
-            llm/llama.cpp/build/**/lib/*
+          path: llm/llama.cpp/build/**/lib/*
+  generate-cuda:
+    strategy:
+      matrix:
+        cuda-version:
+          - '11.8.0'
+    runs-on: ubuntu-latest
+    container: nvidia/cuda:${{ matrix.cuda-version }}-devel-ubuntu20.04
+    steps:
+      - run: |
+          apt-get update && apt-get install -y git build-essential curl
+          curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-linux-x86_64.tar.gz \
+            | tar -zx -C /usr --strip-components 1
+        env:
+          DEBIAN_FRONTEND: noninteractive
+      - uses: actions/checkout@v4
+      - uses: actions/setup-go@v4
+        with:
+          go-version: '1.21'
+          cache: true
+      - run: go get ./...
+      - run: |
+          git config --global --add safe.directory /__w/ollama/ollama
+          go generate -x ./...
+        env:
+          OLLAMA_SKIP_CPU_GENERATE: '1'
+      - uses: actions/upload-artifact@v4
+        with:
+          name: cuda-${{ matrix.cuda-version }}-libraries
+          path: llm/llama.cpp/build/**/lib/*
+  generate-rocm:
+    strategy:
+      matrix:
+        rocm-version:
+          - '5.7.1'
+          - '6.0'
+    runs-on: ubuntu-latest
+    container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
+    steps:
+      - run: |
+          apt-get update && apt-get install -y git build-essential curl rocm-libs
+          curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-linux-x86_64.tar.gz \
+            | tar -zx -C /usr --strip-components 1
+        env:
+          DEBIAN_FRONTEND: noninteractive
+      - uses: actions/checkout@v4
+      - uses: actions/setup-go@v4
+        with:
+          go-version: '1.21'
+          cache: true
+      - run: go get ./...
+      - run: |
+          git config --global --add safe.directory /__w/ollama/ollama
+          go generate -x ./...
+        env:
+          OLLAMA_SKIP_CPU_GENERATE: '1'
+      - uses: actions/upload-artifact@v4
+        with:
+          name: rocm-${{ matrix.rocm-version }}-libraries
+          path: llm/llama.cpp/build/**/lib/*
   lint:
-    needs: generate
     strategy:
       matrix:
         os: [ubuntu-latest, macos-latest, windows-latest]
@@ -69,10 +112,19 @@ jobs:
         with:
           go-version: '1.21'
           cache: false
-      - uses: actions/download-artifact@v4
-        with:
-          name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
-          path: llm/llama.cpp/build
+      - run: |
+          mkdir -p llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/
+          touch llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/stub.so
+        if: ${{ startsWith(matrix.os, 'ubuntu-') }}
+      - run: |
+          mkdir -p llm/llama.cpp/build/darwin/${{ matrix.arch }}/stub/lib/
+          touch llm/llama.cpp/build/darwin/${{ matrix.arch }}/stub/lib/stub.dylib
+          touch llm/llama.cpp/ggml-metal.metal
+        if: ${{ startsWith(matrix.os, 'macos-') }}
+      - run: |
+          mkdir -p llm/llama.cpp/build/windows/${{ matrix.arch }}/stub/lib/
+          touch llm/llama.cpp/build/windows/${{ matrix.arch }}/stub/lib/stub.dll
+        if: ${{ startsWith(matrix.os, 'windows-') }}
       - uses: golangci/golangci-lint-action@v3
   test:
     needs: generate
@@ -104,3 +156,7 @@ jobs:
           path: llm/llama.cpp/build
       - run: go build
       - run: go test -v ./...
+      - uses: actions/upload-artifact@v4
+        with:
+          name: ${{ matrix.os }}-binaries
+          path: ollama

+ 13 - 2
Dockerfile

@@ -109,17 +109,28 @@ ARG CGO_CFLAGS
 RUN go build .
 
 # Runtime stages
-FROM --platform=linux/amd64 rocm/dev-centos-7:6.0-complete as runtime-amd64
+FROM --platform=linux/amd64 ubuntu:22.04 as runtime-amd64
+RUN apt-get update && apt-get install -y ca-certificates
 COPY --from=build-amd64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
 FROM --platform=linux/arm64 ubuntu:22.04 as runtime-arm64
 RUN apt-get update && apt-get install -y ca-certificates
 COPY --from=build-arm64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
 
+# Radeon images are much larger so we keep it distinct from the CPU/CUDA image
+FROM --platform=linux/amd64 rocm/dev-centos-7:5.7.1-complete as runtime-rocm
+RUN update-pciids
+COPY --from=build-amd64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
+EXPOSE 11434
+ENV OLLAMA_HOST 0.0.0.0
+
+ENTRYPOINT ["/bin/ollama"]
+CMD ["serve"]
+
 FROM runtime-$TARGETARCH
 EXPOSE 11434
 ENV OLLAMA_HOST 0.0.0.0
 ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
-ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64:/opt/rocm/lib:
+ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
 ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
 
 ENTRYPOINT ["/bin/ollama"]

+ 27 - 17
api/types.go

@@ -34,24 +34,26 @@ func (e StatusError) Error() string {
 type ImageData []byte
 
 type GenerateRequest struct {
-	Model    string      `json:"model"`
-	Prompt   string      `json:"prompt"`
-	System   string      `json:"system"`
-	Template string      `json:"template"`
-	Context  []int       `json:"context,omitempty"`
-	Stream   *bool       `json:"stream,omitempty"`
-	Raw      bool        `json:"raw,omitempty"`
-	Format   string      `json:"format"`
-	Images   []ImageData `json:"images,omitempty"`
+	Model     string      `json:"model"`
+	Prompt    string      `json:"prompt"`
+	System    string      `json:"system"`
+	Template  string      `json:"template"`
+	Context   []int       `json:"context,omitempty"`
+	Stream    *bool       `json:"stream,omitempty"`
+	Raw       bool        `json:"raw,omitempty"`
+	Format    string      `json:"format"`
+	KeepAlive *Duration   `json:"keep_alive,omitempty"`
+	Images    []ImageData `json:"images,omitempty"`
 
 	Options map[string]interface{} `json:"options"`
 }
 
 type ChatRequest struct {
-	Model    string    `json:"model"`
-	Messages []Message `json:"messages"`
-	Stream   *bool     `json:"stream,omitempty"`
-	Format   string    `json:"format"`
+	Model     string    `json:"model"`
+	Messages  []Message `json:"messages"`
+	Stream    *bool     `json:"stream,omitempty"`
+	Format    string    `json:"format"`
+	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
 	Options map[string]interface{} `json:"options"`
 }
@@ -126,8 +128,9 @@ type Runner struct {
 }
 
 type EmbeddingRequest struct {
-	Model  string `json:"model"`
-	Prompt string `json:"prompt"`
+	Model     string    `json:"model"`
+	Prompt    string    `json:"prompt"`
+	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
 	Options map[string]interface{} `json:"options"`
 }
@@ -171,6 +174,7 @@ type ShowResponse struct {
 	Template   string       `json:"template,omitempty"`
 	System     string       `json:"system,omitempty"`
 	Details    ModelDetails `json:"details,omitempty"`
+	Messages   []Message    `json:"messages,omitempty"`
 }
 
 type CopyRequest struct {
@@ -236,6 +240,7 @@ type GenerateResponse struct {
 }
 
 type ModelDetails struct {
+	ParentModel       string   `json:"parent_model"`
 	Format            string   `json:"format"`
 	Family            string   `json:"family"`
 	Families          []string `json:"families"`
@@ -411,14 +416,19 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
 	case float64:
 		if t < 0 {
 			t = math.MaxFloat64
+			d.Duration = time.Duration(t)
+		} else {
+			d.Duration = time.Duration(t * float64(time.Second))
 		}
-
-		d.Duration = time.Duration(t)
 	case string:
 		d.Duration, err = time.ParseDuration(t)
 		if err != nil {
 			return err
 		}
+		if d.Duration < 0 {
+			mf := math.MaxFloat64
+			d.Duration = time.Duration(mf)
+		}
 	}
 
 	return nil

+ 11 - 9
cmd/cmd.go

@@ -458,15 +458,17 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 type generateContextKey string
 
 type runOptions struct {
-	Model    string
-	Prompt   string
-	Messages []api.Message
-	WordWrap bool
-	Format   string
-	System   string
-	Template string
-	Images   []api.ImageData
-	Options  map[string]interface{}
+	Model       string
+	ParentModel string
+	Prompt      string
+	Messages    []api.Message
+	WordWrap    bool
+	Format      string
+	System      string
+	Template    string
+	Images      []api.ImageData
+	Options     map[string]interface{}
+	MultiModal  bool
 }
 
 type displayResponseState struct {

+ 127 - 24
cmd/interactive.go

@@ -7,12 +7,14 @@ import (
 	"net/http"
 	"os"
 	"regexp"
+	"sort"
 	"strings"
 
 	"github.com/spf13/cobra"
 	"golang.org/x/exp/slices"
 
 	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/progress"
 	"github.com/jmorganca/ollama/readline"
 )
 
@@ -25,43 +27,75 @@ const (
 	MultilineTemplate
 )
 
-func modelIsMultiModal(cmd *cobra.Command, name string) bool {
-	// get model details
+func loadModel(cmd *cobra.Command, opts *runOptions) error {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
-		fmt.Println("error: couldn't connect to ollama server")
-		return false
+		return err
 	}
 
-	req := api.ShowRequest{Name: name}
-	resp, err := client.Show(cmd.Context(), &req)
+	p := progress.NewProgress(os.Stderr)
+	defer p.StopAndClear()
+
+	spinner := progress.NewSpinner("")
+	p.Add("", spinner)
+
+	showReq := api.ShowRequest{Name: opts.Model}
+	showResp, err := client.Show(cmd.Context(), &showReq)
 	if err != nil {
-		return false
+		return err
 	}
+	opts.MultiModal = slices.Contains(showResp.Details.Families, "clip")
+	opts.ParentModel = showResp.Details.ParentModel
 
-	return slices.Contains(resp.Details.Families, "clip")
-}
-
-func generateInteractive(cmd *cobra.Command, opts runOptions) error {
-	multiModal := modelIsMultiModal(cmd, opts.Model)
+	if len(showResp.Messages) > 0 {
+		opts.Messages = append(opts.Messages, showResp.Messages...)
+	}
 
-	// load the model
-	loadOpts := runOptions{
+	chatReq := &api.ChatRequest{
 		Model:    opts.Model,
-		Prompt:   "",
 		Messages: []api.Message{},
 	}
-	if _, err := chat(cmd, loadOpts); err != nil {
+	err = client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error {
+		p.StopAndClear()
+		if len(opts.Messages) > 0 {
+			for _, msg := range opts.Messages {
+				switch msg.Role {
+				case "user":
+					fmt.Printf(">>> %s\n", msg.Content)
+				case "assistant":
+					state := &displayResponseState{}
+					displayResponse(msg.Content, opts.WordWrap, state)
+					fmt.Println()
+					fmt.Println()
+				}
+			}
+		}
+		return nil
+	})
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func generateInteractive(cmd *cobra.Command, opts runOptions) error {
+	opts.Messages = make([]api.Message, 0)
+
+	err := loadModel(cmd, &opts)
+	if err != nil {
 		return err
 	}
 
 	usage := func() {
 		fmt.Fprintln(os.Stderr, "Available Commands:")
-		fmt.Fprintln(os.Stderr, "  /set          Set session variables")
-		fmt.Fprintln(os.Stderr, "  /show         Show model information")
-		fmt.Fprintln(os.Stderr, "  /bye          Exit")
-		fmt.Fprintln(os.Stderr, "  /?, /help     Help for a command")
-		fmt.Fprintln(os.Stderr, "  /? shortcuts  Help for keyboard shortcuts")
+		fmt.Fprintln(os.Stderr, "  /set            Set session variables")
+		fmt.Fprintln(os.Stderr, "  /show           Show model information")
+		fmt.Fprintln(os.Stderr, "  /load <model>   Load a session or model")
+		fmt.Fprintln(os.Stderr, "  /save <model>   Save your current session")
+		fmt.Fprintln(os.Stderr, "  /bye            Exit")
+		fmt.Fprintln(os.Stderr, "  /?, /help       Help for a command")
+		fmt.Fprintln(os.Stderr, "  /? shortcuts    Help for keyboard shortcuts")
 		fmt.Fprintln(os.Stderr, "")
 		fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
 		fmt.Fprintln(os.Stderr, "")
@@ -140,7 +174,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 
 	var sb strings.Builder
 	var multiline MultilineState
-	opts.Messages = make([]api.Message, 0)
 
 	for {
 		line, err := scanner.Readline()
@@ -203,6 +236,44 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 			if err := ListHandler(cmd, args[1:]); err != nil {
 				return err
 			}
+		case strings.HasPrefix(line, "/load"):
+			args := strings.Fields(line)
+			if len(args) != 2 {
+				fmt.Println("Usage:\n  /load <modelname>")
+				continue
+			}
+			opts.Model = args[1]
+			opts.Messages = []api.Message{}
+			fmt.Printf("Loading model '%s'\n", opts.Model)
+			if err := loadModel(cmd, &opts); err != nil {
+				return err
+			}
+			continue
+		case strings.HasPrefix(line, "/save"):
+			args := strings.Fields(line)
+			if len(args) != 2 {
+				fmt.Println("Usage:\n  /save <modelname>")
+				continue
+			}
+
+			client, err := api.ClientFromEnvironment()
+			if err != nil {
+				fmt.Println("error: couldn't connect to ollama server")
+				return err
+			}
+
+			req := &api.CreateRequest{
+				Name:      args[1],
+				Modelfile: buildModelfile(opts),
+			}
+			fn := func(resp api.ProgressResponse) error { return nil }
+			err = client.Create(cmd.Context(), req, fn)
+			if err != nil {
+				fmt.Println("error: couldn't save model")
+				return err
+			}
+			fmt.Printf("Created new model '%s'\n", args[1])
+			continue
 		case strings.HasPrefix(line, "/set"):
 			args := strings.Fields(line)
 			if len(args) > 1 {
@@ -389,7 +460,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 			args := strings.Fields(line)
 			isFile := false
 
-			if multiModal {
+			if opts.MultiModal {
 				for _, f := range extractFileNames(line) {
 					if strings.HasPrefix(f, args[0]) {
 						isFile = true
@@ -411,7 +482,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 		if sb.Len() > 0 && multiline == MultilineNone {
 			newMessage := api.Message{Role: "user", Content: sb.String()}
 
-			if multiModal {
+			if opts.MultiModal {
 				msg, images, err := extractFileData(sb.String())
 				if err != nil {
 					return err
@@ -454,6 +525,38 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 	}
 }
 
+func buildModelfile(opts runOptions) string {
+	var mf strings.Builder
+	model := opts.ParentModel
+	if model == "" {
+		model = opts.Model
+	}
+	fmt.Fprintf(&mf, "FROM %s\n", model)
+	if opts.System != "" {
+		fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System)
+	}
+
+	if opts.Template != "" {
+		fmt.Fprintf(&mf, "TEMPLATE \"\"\"%s\"\"\"\n", opts.Template)
+	}
+
+	keys := make([]string, 0)
+	for k := range opts.Options {
+		keys = append(keys, k)
+	}
+	sort.Strings(keys)
+	for _, k := range keys {
+		fmt.Fprintf(&mf, "PARAMETER %s %v\n", k, opts.Options[k])
+	}
+	fmt.Fprintln(&mf)
+
+	for _, msg := range opts.Messages {
+		fmt.Fprintf(&mf, "MESSAGE %s \"\"\"%s\"\"\"\n", msg.Role, msg.Content)
+	}
+
+	return mf.String()
+}
+
 func normalizeFilePath(fp string) string {
 	// Define a map of escaped characters and their replacements
 	replacements := map[string]string{

+ 65 - 0
cmd/interactive_test.go

@@ -1,9 +1,13 @@
 package cmd
 
 import (
+	"bytes"
 	"testing"
+	"text/template"
 
 	"github.com/stretchr/testify/assert"
+
+	"github.com/jmorganca/ollama/api"
 )
 
 func TestExtractFilenames(t *testing.T) {
@@ -49,3 +53,64 @@ d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8
 	assert.Contains(t, res[9], "ten.svg")
 	assert.Contains(t, res[9], "E:")
 }
+
+func TestModelfileBuilder(t *testing.T) {
+	opts := runOptions{
+		Model:    "hork",
+		System:   "You are part horse and part shark, but all hork. Do horklike things",
+		Template: "This is a template.",
+		Messages: []api.Message{
+			{Role: "user", Content: "Hey there hork!"},
+			{Role: "assistant", Content: "Yes it is true, I am half horse, half shark."},
+		},
+		Options: map[string]interface{}{},
+	}
+
+	opts.Options["temperature"] = 0.9
+	opts.Options["seed"] = 42
+	opts.Options["penalize_newline"] = false
+	opts.Options["stop"] = []string{"hi", "there"}
+
+	mf := buildModelfile(opts)
+	expectedModelfile := `FROM {{.Model}}
+SYSTEM """{{.System}}"""
+TEMPLATE """{{.Template}}"""
+PARAMETER penalize_newline false
+PARAMETER seed 42
+PARAMETER stop [hi there]
+PARAMETER temperature 0.9
+
+MESSAGE user """Hey there hork!"""
+MESSAGE assistant """Yes it is true, I am half horse, half shark."""
+`
+
+	tmpl, err := template.New("").Parse(expectedModelfile)
+	assert.Nil(t, err)
+
+	var buf bytes.Buffer
+	err = tmpl.Execute(&buf, opts)
+	assert.Nil(t, err)
+	assert.Equal(t, buf.String(), mf)
+
+	opts.ParentModel = "horseshark"
+	mf = buildModelfile(opts)
+	expectedModelfile = `FROM {{.ParentModel}}
+SYSTEM """{{.System}}"""
+TEMPLATE """{{.Template}}"""
+PARAMETER penalize_newline false
+PARAMETER seed 42
+PARAMETER stop [hi there]
+PARAMETER temperature 0.9
+
+MESSAGE user """Hey there hork!"""
+MESSAGE assistant """Yes it is true, I am half horse, half shark."""
+`
+
+	tmpl, err = template.New("").Parse(expectedModelfile)
+	assert.Nil(t, err)
+
+	var parentBuf bytes.Buffer
+	err = tmpl.Execute(&parentBuf, opts)
+	assert.Nil(t, err)
+	assert.Equal(t, parentBuf.String(), mf)
+}

+ 2 - 1
docs/development.md

@@ -50,7 +50,8 @@ development and runtime packages.
 Typically the build scripts will auto-detect CUDA, however, if your Linux distro
 or installation approach uses unusual paths, you can specify the location by
 specifying an environment variable `CUDA_LIB_DIR` to the location of the shared
-libraries, and `CUDACXX` to the location of the nvcc compiler.
+libraries, and `CUDACXX` to the location of the nvcc compiler.  You can customize
+set set of target CUDA architectues by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70")
 
 Then generate dependencies:
 

+ 15 - 0
docs/modelfile.md

@@ -19,6 +19,7 @@ A model file is the blueprint to create and share models with Ollama.
   - [SYSTEM](#system)
   - [ADAPTER](#adapter)
   - [LICENSE](#license)
+  - [MESSAGE](#message)
 - [Notes](#notes)
 
 ## Format
@@ -38,6 +39,7 @@ INSTRUCTION arguments
 | [`SYSTEM`](#system)                 | Specifies the system message that will be set in the template. |
 | [`ADAPTER`](#adapter)               | Defines the (Q)LoRA adapters to apply to the model.            |
 | [`LICENSE`](#license)               | Specifies the legal license.                                   |
+| [`MESSAGE`](#message)               | Specify message history.                                       |
 
 ## Examples
 
@@ -205,6 +207,19 @@ LICENSE """
 """
 ```
 
+### MESSAGE
+
+The `MESSAGE` instruction allows you to specify a message history for the model to use when responding:
+
+```modelfile
+MESSAGE user Is Toronto in Canada?
+MESSAGE assistant yes
+MESSAGE user Is Sacramento in Canada?
+MESSAGE assistant no
+MESSAGE user Is Ontario in Canada?
+MESSAGE assistant yes
+```
+
 ## Notes
 
 - the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments.

+ 37 - 7
gpu/gpu.go

@@ -16,6 +16,7 @@ import (
 	"os"
 	"path/filepath"
 	"runtime"
+	"strconv"
 	"strings"
 	"sync"
 	"unsafe"
@@ -29,8 +30,8 @@ type handles struct {
 var gpuMutex sync.Mutex
 var gpuHandles *handles = nil
 
-// With our current CUDA compile flags, 5.2 and older will not work properly
-const CudaComputeMajorMin = 6
+// With our current CUDA compile flags, older than 5.0 will not work properly
+var CudaComputeMin = [2]C.int{5, 0}
 
 // Possible locations for the nvidia-ml library
 var CudaLinuxGlobs = []string{
@@ -121,9 +122,15 @@ func GetGPUInfo() GpuInfo {
 		initGPUHandles()
 	}
 
+	// All our GPU builds have AVX enabled, so fallback to CPU if we don't detect at least AVX
+	cpuVariant := GetCPUVariant()
+	if cpuVariant == "" {
+		slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
+	}
+
 	var memInfo C.mem_info_t
 	resp := GpuInfo{}
-	if gpuHandles.cuda != nil {
+	if gpuHandles.cuda != nil && cpuVariant != "" {
 		C.cuda_check_vram(*gpuHandles.cuda, &memInfo)
 		if memInfo.err != nil {
 			slog.Info(fmt.Sprintf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err)))
@@ -135,19 +142,40 @@ func GetGPUInfo() GpuInfo {
 			if cc.err != nil {
 				slog.Info(fmt.Sprintf("error looking up CUDA GPU compute capability: %s", C.GoString(cc.err)))
 				C.free(unsafe.Pointer(cc.err))
-			} else if cc.major >= CudaComputeMajorMin {
+			} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
 				slog.Info(fmt.Sprintf("CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
 				resp.Library = "cuda"
 			} else {
 				slog.Info(fmt.Sprintf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
 			}
 		}
-	} else if gpuHandles.rocm != nil {
+	} else if gpuHandles.rocm != nil && cpuVariant != "" {
 		C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
 		if memInfo.err != nil {
 			slog.Info(fmt.Sprintf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err)))
 			C.free(unsafe.Pointer(memInfo.err))
+		} else if memInfo.igpu_index >= 0 && memInfo.count == 1 {
+			// Only one GPU detected and it appears to be an integrated GPU - skip it
+			slog.Info("ROCm unsupported integrated GPU detected")
 		} else {
+			if memInfo.igpu_index >= 0 {
+				// We have multiple GPUs reported, and one of them is an integrated GPU
+				// so we have to set the env var to bypass it
+				// If the user has specified their own ROCR_VISIBLE_DEVICES, don't clobber it
+				val := os.Getenv("ROCR_VISIBLE_DEVICES")
+				if val == "" {
+					devices := []string{}
+					for i := 0; i < int(memInfo.count); i++ {
+						if i == int(memInfo.igpu_index) {
+							continue
+						}
+						devices = append(devices, strconv.Itoa(i))
+					}
+					val = strings.Join(devices, ",")
+					os.Setenv("ROCR_VISIBLE_DEVICES", val)
+				}
+				slog.Info(fmt.Sprintf("ROCm integrated GPU detected - ROCR_VISIBLE_DEVICES=%s", val))
+			}
 			resp.Library = "rocm"
 			var version C.rocm_version_resp_t
 			C.rocm_get_version(*gpuHandles.rocm, &version)
@@ -163,7 +191,7 @@ func GetGPUInfo() GpuInfo {
 	if resp.Library == "" {
 		C.cpu_check_ram(&memInfo)
 		resp.Library = "cpu"
-		resp.Variant = GetCPUVariant()
+		resp.Variant = cpuVariant
 	}
 	if memInfo.err != nil {
 		slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
@@ -199,7 +227,9 @@ func CheckVRAM() (int64, error) {
 		if overhead < gpus*1024*1024*1024 {
 			overhead = gpus * 1024 * 1024 * 1024
 		}
-		return int64(gpuInfo.FreeMemory - overhead), nil
+		avail := int64(gpuInfo.FreeMemory - overhead)
+		slog.Debug(fmt.Sprintf("%s detected %d devices with %dM available memory", gpuInfo.Library, gpuInfo.DeviceCount, avail/1024/1024))
+		return avail, nil
 	}
 
 	return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation

+ 1 - 0
gpu/gpu_info.h

@@ -42,6 +42,7 @@ typedef struct mem_info {
   uint64_t total;
   uint64_t free;
   unsigned int count;
+  int igpu_index; // If >= 0, we detected an integrated GPU to ignore
   char *err;  // If non-nill, caller responsible for freeing
 } mem_info_t;
 

+ 1 - 0
gpu/gpu_info_cuda.c

@@ -70,6 +70,7 @@ void cuda_init(char *cuda_lib_path, cuda_init_resp_t *resp) {
     resp->ch.handle = NULL;
     snprintf(buf, buflen, "nvml vram init failure: %d", ret);
     resp->err = strdup(buf);
+    return;
   }
 
   // Report driver version if we're in verbose mode, ignore errors

+ 11 - 4
gpu/gpu_info_rocm.c

@@ -77,6 +77,7 @@ void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) {
 
 void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) {
   resp->err = NULL;
+  resp->igpu_index = -1;
   uint64_t totalMem = 0;
   uint64_t usedMem = 0;
   rsmi_status_t ret;
@@ -162,8 +163,14 @@ void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) {
     }
     LOG(h.verbose, "[%d] ROCm totalMem %ld\n", i, totalMem);
     LOG(h.verbose, "[%d] ROCm usedMem %ld\n", i, usedMem);
-    resp->total += totalMem;
-    resp->free += totalMem - usedMem;
+    if (totalMem < 1024 * 1024 * 1024) {
+      // Do not add up integrated GPU memory capacity, it's a bogus 512M, and actually uses system memory
+      LOG(h.verbose, "[%d] ROCm integrated GPU\n", i);
+      resp->igpu_index = i;
+    } else {
+      resp->total += totalMem;
+      resp->free += totalMem - usedMem;
+    }
   }
 }
 
@@ -171,7 +178,7 @@ void rocm_get_version(rocm_handle_t h, rocm_version_resp_t *resp) {
   const int buflen = 256;
   char buf[buflen + 1];
   if (h.handle == NULL) {
-    resp->str = strdup("nvml handle not initialized");
+    resp->str = strdup("rocm handle not initialized");
     resp->status = 1;
     return;
   }
@@ -188,4 +195,4 @@ void rocm_get_version(rocm_handle_t h, rocm_version_resp_t *resp) {
   resp->str = strdup(buf);
 }
 
-#endif  // __APPLE__
+#endif  // __APPLE__

+ 1 - 0
llm/dyn_ext_server.go

@@ -190,6 +190,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
 		"seed":              predict.Options.Seed,
 		"stop":              predict.Options.Stop,
 		"image_data":        imageData,
+		"cache_prompt":      true,
 	}
 
 	if predict.Format == "json" {

+ 14 - 0
llm/generate/gen_common.sh

@@ -39,6 +39,9 @@ init_vars() {
     *)
         ;;
     esac
+    if [ -z "${CMAKE_CUDA_ARCHITECTURES}" ] ; then 
+        CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
+    fi
 }
 
 git_module_setup() {
@@ -61,6 +64,17 @@ apply_patches() {
     if ! grep ollama ${LLAMACPP_DIR}/examples/server/CMakeLists.txt; then
         echo 'include (../../../ext_server/CMakeLists.txt) # ollama' >>${LLAMACPP_DIR}/examples/server/CMakeLists.txt
     fi
+
+    # apply temporary patches until fix is upstream
+    for patch in ../patches/*.diff; do
+        for file in $(grep "^+++ " ${patch} | cut -f2 -d' ' | cut -f2- -d/); do
+            (cd ${LLAMACPP_DIR}; git checkout ${file})
+        done
+    done
+    for patch in ../patches/*.diff; do
+        (cd ${LLAMACPP_DIR} && git apply ${patch})
+    done
+
     # Avoid duplicate main symbols when we link into the cgo binary
     sed -e 's/int main(/int __main(/g' <${LLAMACPP_DIR}/examples/server/server.cpp >${LLAMACPP_DIR}/examples/server/server.cpp.tmp &&
         mv ${LLAMACPP_DIR}/examples/server/server.cpp.tmp ${LLAMACPP_DIR}/examples/server/server.cpp

+ 1 - 1
llm/generate/gen_linux.sh

@@ -140,7 +140,7 @@ if [ -d "${CUDA_LIB_DIR}" ]; then
     if [ -n "${CUDA_MAJOR}" ]; then
         CUDA_VARIANT=_v${CUDA_MAJOR}
     fi
-    CMAKE_DEFS="-DLLAMA_CUBLAS=on ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS}"
+    CMAKE_DEFS="-DLLAMA_CUBLAS=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS}"
     BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cuda${CUDA_VARIANT}"
     EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda"
     build

+ 29 - 1
llm/generate/gen_windows.ps1

@@ -25,6 +25,11 @@ function init_vars {
     }
     $script:GZIP=(get-command -ea 'silentlycontinue' gzip).path
     $script:DUMPBIN=(get-command -ea 'silentlycontinue' dumpbin).path
+    if ($null -eq $env:CMAKE_CUDA_ARCHITECTURES) {
+        $script:CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
+    } else {
+        $script:CMAKE_CUDA_ARCHITECTURES=$env:CMAKE_CUDA_ARCHITECTURES
+    }
 }
 
 function git_module_setup {
@@ -40,6 +45,29 @@ function apply_patches {
     if (!(Select-String -Path "${script:llamacppDir}/examples/server/CMakeLists.txt" -Pattern 'ollama')) {
         Add-Content -Path "${script:llamacppDir}/examples/server/CMakeLists.txt" -Value 'include (../../../ext_server/CMakeLists.txt) # ollama'
     }
+
+    # Apply temporary patches until fix is upstream
+    $patches = Get-ChildItem "../patches/*.diff"
+    foreach ($patch in $patches) {
+        # Extract file paths from the patch file
+        $filePaths = Get-Content $patch.FullName | Where-Object { $_ -match '^\+\+\+ ' } | ForEach-Object {
+            $parts = $_ -split ' '
+            ($parts[1] -split '/', 2)[1]
+        }
+
+        # Checkout each file
+        foreach ($file in $filePaths) {
+            Set-Location -Path ${script:llamacppDir}
+            git checkout $file
+        }
+    }
+
+    # Apply each patch
+    foreach ($patch in $patches) {
+        Set-Location -Path ${script:llamacppDir}
+        git apply $patch.FullName
+    }
+
     # Avoid duplicate main symbols when we link into the cgo binary
     $content = Get-Content -Path "${script:llamacppDir}/examples/server/server.cpp"
     $content = $content -replace 'int main\(', 'int __main('
@@ -128,7 +156,7 @@ if ($null -ne $script:CUDA_LIB_DIR) {
     }
     init_vars
     $script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
-    $script:cmakeDefs += @("-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on")
+    $script:cmakeDefs += @("-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
     build
     install
     cp "${script:CUDA_LIB_DIR}/cudart64_*.dll" "${script:buildDir}/lib"

+ 61 - 54
llm/gguf.go

@@ -69,12 +69,65 @@ type tensor struct {
 	name   string
 	kind   uint32
 	offset uint64
-	size   uint64
 
 	// shape is the number of elements in each dimension
 	shape [4]uint64
 }
 
+func (t tensor) blockSize() uint64 {
+	switch {
+	case t.kind < 2:
+		return 1
+	case t.kind < 10:
+		return 32
+	default:
+		return 256
+	}
+}
+
+func (t tensor) typeSize() uint64 {
+	blockSize := t.blockSize()
+
+	switch t.kind {
+	case 0: // FP32
+		return 4
+	case 1: // FP16
+		return 2
+	case 2: // Q4_0
+		return 2 + blockSize/2
+	case 3: // Q4_1
+		return 2 + 2 + blockSize/2
+	case 6: // Q5_0
+		return 2 + 4 + blockSize/2
+	case 7: // Q5_1
+		return 2 + 2 + 4 + blockSize/2
+	case 8: // Q8_0
+		return 2 + blockSize
+	case 9: // Q8_1
+		return 4 + 4 + blockSize
+	case 10: // Q2_K
+		return blockSize/16 + blockSize/4 + 2 + 2
+	case 11: // Q3_K
+		return blockSize/8 + blockSize/4 + 12 + 2
+	case 12: // Q4_K
+		return 2 + 2 + 12 + blockSize/2
+	case 13: // Q5_K
+		return 2 + 2 + 12 + blockSize/8 + blockSize/2
+	case 14: // Q6_K
+		return blockSize/2 + blockSize/4 + blockSize/16 + 2
+	default:
+		return 0
+	}
+}
+
+func (t tensor) parameters() uint64 {
+	return t.shape[0] * t.shape[1] * t.shape[2] * t.shape[3]
+}
+
+func (t tensor) size() uint64 {
+	return t.parameters() * t.typeSize() / t.blockSize()
+}
+
 type ggufModel struct {
 	*containerGGUF
 
@@ -201,61 +254,15 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error {
 			shape[i] = llm.readU64(rso)
 		}
 
-		kind := llm.readU32(rso)
-		offset := llm.readU64(rso)
-
-		var blockSize uint64
-		switch {
-		case kind < 2:
-			blockSize = 1
-		case kind < 10:
-			blockSize = 32
-		default:
-			blockSize = 256
-		}
-
-		var typeSize uint64
-		switch kind {
-		case 0: // FP32
-			typeSize = 4
-		case 1: // FP16
-			typeSize = 2
-		case 2: // Q4_0
-			typeSize = 2 + blockSize/2
-		case 3: // Q4_1
-			typeSize = 2 + 2 + blockSize/2
-		case 6: // Q5_0
-			typeSize = 2 + 4 + blockSize/2
-		case 7: // Q5_1
-			typeSize = 2 + 2 + 4 + blockSize/2
-		case 8: // Q8_0
-			typeSize = 2 + blockSize
-		case 9: // Q8_1
-			typeSize = 4 + 4 + blockSize
-		case 10: // Q2_K
-			typeSize = blockSize/16 + blockSize/4 + 2 + 2
-		case 11: // Q3_K
-			typeSize = blockSize/8 + blockSize/4 + 12 + 2
-		case 12: // Q4_K
-			typeSize = 2 + 2 + 12 + blockSize/2
-		case 13: // Q5_K
-			typeSize = 2 + 2 + 12 + blockSize/8 + blockSize/2
-		case 14: // Q6_K
-			typeSize = blockSize/2 + blockSize/4 + blockSize/16 + 2
-		}
-
-		parameters := shape[0] * shape[1] * shape[2] * shape[3]
-		size := parameters * typeSize / blockSize
-
-		llm.tensors = append(llm.tensors, tensor{
+		tensor := tensor{
 			name:   name,
-			kind:   kind,
-			offset: offset,
-			size:   size,
+			kind:   llm.readU32(rso),
+			offset: llm.readU64(rso),
 			shape:  shape,
-		})
+		}
 
-		llm.parameters += parameters
+		llm.tensors = append(llm.tensors, tensor)
+		llm.parameters += tensor.parameters()
 	}
 
 	alignment, ok := llm.kv["general.alignment"].(uint32)
@@ -265,7 +272,7 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error {
 
 	rso.Seek(int64(alignment)-rso.offset%int64(alignment), io.SeekCurrent)
 	for _, tensor := range llm.tensors {
-		padded := (int64(tensor.size) + int64(alignment) - 1) & ^(int64(alignment) - 1)
+		padded := (int64(tensor.size()) + int64(alignment) - 1) & ^(int64(alignment) - 1)
 		rso.Seek(padded, io.SeekCurrent)
 	}
 

+ 1 - 1
llm/llama.cpp

@@ -1 +1 @@
-Subproject commit 011e8ec577fd135cbc02993d3ea9840c516d6a1c
+Subproject commit cd4fddb29f81d6a1f6d51a0c016bc6b486d68def

+ 30 - 0
llm/patches/01-cache.diff

@@ -0,0 +1,30 @@
+diff --git a/examples/server/server.cpp b/examples/server/server.cpp
+index 0462fbd2..4fa7b57f 100644
+--- a/examples/server/server.cpp
++++ b/examples/server/server.cpp
+@@ -1857,12 +1857,6 @@ struct llama_server_context
+                         LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
+                     }
+ 
+-                    LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past);
+-
+-                    llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1);
+-
+-                    slot.cache_tokens = prompt_tokens;
+-
+                     if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0)
+                     {
+                         // we have to evaluate at least 1 token to generate logits.
+@@ -1870,6 +1864,12 @@ struct llama_server_context
+                         slot.n_past--;
+                     }
+ 
++                    LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past);
++
++                    llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1);
++
++                    slot.cache_tokens = prompt_tokens;
++
+                     LOG_VERBOSE("prompt ingested", {
+                                                     {"n_past", slot.n_past},
+                                                     {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},

+ 11 - 0
parser/parser.go

@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
+	"slices"
 )
 
 type Command struct {
@@ -56,6 +57,16 @@ func Parse(reader io.Reader) ([]Command, error) {
 			command.Args = string(bytes.TrimSpace(fields[1]))
 		case "EMBED":
 			return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
+		case "MESSAGE":
+			command.Name = string(bytes.ToLower(fields[0]))
+			fields = bytes.SplitN(fields[1], []byte(" "), 2)
+			if len(fields) < 2 {
+				return nil, fmt.Errorf("should be in the format <role> <message>")
+			}
+			if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
+				return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
+			}
+			command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
 		default:
 			if !bytes.HasPrefix(fields[0], []byte("#")) {
 				// log a warning for unknown commands

+ 35 - 0
parser/parser_test.go

@@ -61,3 +61,38 @@ PARAMETER param1
 	assert.ErrorContains(t, err, "missing value for [param1]")
 
 }
+
+func Test_Parser_Messages(t *testing.T) {
+
+	input := `
+FROM foo
+MESSAGE system You are a Parser. Always Parse things.
+MESSAGE user Hey there!
+MESSAGE assistant Hello, I want to parse all the things!
+`
+
+	reader := strings.NewReader(input)
+	commands, err := Parse(reader)
+	assert.Nil(t, err)
+
+	expectedCommands := []Command{
+		{Name: "model", Args: "foo"},
+		{Name: "message", Args: "system: You are a Parser. Always Parse things."},
+		{Name: "message", Args: "user: Hey there!"},
+		{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
+	}
+
+	assert.Equal(t, expectedCommands, commands)
+}
+
+func Test_Parser_Messages_BadRole(t *testing.T) {
+
+	input := `
+FROM foo
+MESSAGE badguy I'm a bad guy!
+`
+
+	reader := strings.NewReader(input)
+	_, err := Parse(reader)
+	assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"")
+}

+ 10 - 0
scripts/build_docker.sh

@@ -13,3 +13,13 @@ docker build \
     -f Dockerfile \
     -t ollama/ollama:$VERSION \
     .
+
+docker build \
+    --load \
+    --platform=linux/amd64 \
+    --build-arg=VERSION \
+    --build-arg=GOFLAGS \
+    --target runtime-rocm \
+    -f Dockerfile \
+    -t ollama/ollama:$VERSION-rocm \
+    .

+ 74 - 43
server/download.go

@@ -25,6 +25,11 @@ import (
 	"github.com/jmorganca/ollama/format"
 )
 
+const maxRetries = 6
+
+var errMaxRetriesExceeded = errors.New("max retries exceeded")
+var errPartStalled = errors.New("part stalled")
+
 var blobDownloadManager sync.Map
 
 type blobDownload struct {
@@ -44,10 +49,11 @@ type blobDownload struct {
 }
 
 type blobDownloadPart struct {
-	N         int
-	Offset    int64
-	Size      int64
-	Completed int64
+	N           int
+	Offset      int64
+	Size        int64
+	Completed   int64
+	lastUpdated time.Time
 
 	*blobDownload `json:"-"`
 }
@@ -72,6 +78,13 @@ func (p *blobDownloadPart) StopsAt() int64 {
 	return p.Offset + p.Size
 }
 
+func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
+	n = len(b)
+	p.blobDownload.Completed.Add(int64(n))
+	p.lastUpdated = time.Now()
+	return n, nil
+}
+
 func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
 	partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
 	if err != nil {
@@ -157,6 +170,9 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
 				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
 					// return immediately if the context is canceled or the device is out of space
 					return err
+				case errors.Is(err, errPartStalled):
+					try--
+					continue
 				case err != nil:
 					sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
 					slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
@@ -195,28 +211,54 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
 }
 
 func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
-	headers := make(http.Header)
-	headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
-	resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
-	if err != nil {
-		return err
-	}
-	defer resp.Body.Close()
+	g, ctx := errgroup.WithContext(ctx)
+	g.Go(func() error {
+		headers := make(http.Header)
+		headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
+		resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
+		if err != nil {
+			return err
+		}
+		defer resp.Body.Close()
 
-	n, err := io.Copy(w, io.TeeReader(resp.Body, b))
-	if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
-		// rollback progress
-		b.Completed.Add(-n)
-		return err
-	}
+		n, err := io.Copy(w, io.TeeReader(resp.Body, part))
+		if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
+			// rollback progress
+			b.Completed.Add(-n)
+			return err
+		}
 
-	part.Completed += n
-	if err := b.writePart(part.Name(), part); err != nil {
+		part.Completed += n
+		if err := b.writePart(part.Name(), part); err != nil {
+			return err
+		}
+
+		// return nil or context.Canceled or UnexpectedEOF (resumable)
 		return err
-	}
+	})
+
+	g.Go(func() error {
+		ticker := time.NewTicker(time.Second)
+		for {
+			select {
+			case <-ticker.C:
+				if part.Completed >= part.Size {
+					return nil
+				}
+
+				if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second {
+					slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
+					// reset last updated
+					part.lastUpdated = time.Time{}
+					return errPartStalled
+				}
+			case <-ctx.Done():
+				return ctx.Err()
+			}
+		}
+	})
 
-	// return nil or context.Canceled or UnexpectedEOF (resumable)
-	return err
+	return g.Wait()
 }
 
 func (b *blobDownload) newPart(offset, size int64) error {
@@ -255,12 +297,6 @@ func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error
 	return json.NewEncoder(partFile).Encode(part)
 }
 
-func (b *blobDownload) Write(p []byte) (n int, err error) {
-	n = len(p)
-	b.Completed.Add(int64(n))
-	return n, nil
-}
-
 func (b *blobDownload) acquire() {
 	b.references.Add(1)
 }
@@ -279,20 +315,19 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
 	for {
 		select {
 		case <-ticker.C:
+			fn(api.ProgressResponse{
+				Status:    fmt.Sprintf("pulling %s", b.Digest[7:19]),
+				Digest:    b.Digest,
+				Total:     b.Total,
+				Completed: b.Completed.Load(),
+			})
+
+			if b.done || b.err != nil {
+				return b.err
+			}
 		case <-ctx.Done():
 			return ctx.Err()
 		}
-
-		fn(api.ProgressResponse{
-			Status:    fmt.Sprintf("pulling %s", b.Digest[7:19]),
-			Digest:    b.Digest,
-			Total:     b.Total,
-			Completed: b.Completed.Load(),
-		})
-
-		if b.done || b.err != nil {
-			return b.err
-		}
 	}
 }
 
@@ -303,10 +338,6 @@ type downloadOpts struct {
 	fn      func(api.ProgressResponse)
 }
 
-const maxRetries = 6
-
-var errMaxRetriesExceeded = errors.New("max retries exceeded")
-
 // downloadBlob downloads a blob from the registry and stores it in the blobs directory
 func downloadBlob(ctx context.Context, opts downloadOpts) error {
 	fp, err := GetBlobsPath(opts.digest)

+ 47 - 5
server/images.go

@@ -41,7 +41,7 @@ type Model struct {
 	Config         ConfigV2
 	ShortName      string
 	ModelPath      string
-	OriginalModel  string
+	ParentModel    string
 	AdapterPaths   []string
 	ProjectorPaths []string
 	Template       string
@@ -50,6 +50,12 @@ type Model struct {
 	Digest         string
 	Size           int64
 	Options        map[string]interface{}
+	Messages       []Message
+}
+
+type Message struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
 }
 
 type PromptVars struct {
@@ -333,7 +339,7 @@ func GetModel(name string) (*Model, error) {
 		switch layer.MediaType {
 		case "application/vnd.ollama.image.model":
 			model.ModelPath = filename
-			model.OriginalModel = layer.From
+			model.ParentModel = layer.From
 		case "application/vnd.ollama.image.embed":
 			// Deprecated in versions  > 0.1.2
 			// TODO: remove this warning in a future version
@@ -374,6 +380,16 @@ func GetModel(name string) (*Model, error) {
 			if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
 				return nil, err
 			}
+		case "application/vnd.ollama.image.messages":
+			msgs, err := os.Open(filename)
+			if err != nil {
+				return nil, err
+			}
+			defer msgs.Close()
+
+			if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
+				return nil, err
+			}
 		case "application/vnd.ollama.image.license":
 			bts, err := os.ReadFile(filename)
 			if err != nil {
@@ -428,12 +444,12 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 	}
 
 	var layers Layers
+	messages := []string{}
 
 	params := make(map[string][]string)
 	fromParams := make(map[string]any)
 
 	for _, c := range commands {
-		slog.Info(fmt.Sprintf("[%s] - %s", c.Name, c.Args))
 		mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
 
 		switch c.Name {
@@ -607,11 +623,37 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 			}
 
 			layers.Replace(layer)
+		case "message":
+			messages = append(messages, c.Args)
 		default:
 			params[c.Name] = append(params[c.Name], c.Args)
 		}
 	}
 
+	if len(messages) > 0 {
+		fn(api.ProgressResponse{Status: "creating parameters layer"})
+
+		msgs := make([]api.Message, 0)
+
+		for _, m := range messages {
+			// todo: handle images
+			msg := strings.SplitN(m, ": ", 2)
+			msgs = append(msgs, api.Message{Role: msg[0], Content: msg[1]})
+		}
+
+		var b bytes.Buffer
+		if err := json.NewEncoder(&b).Encode(msgs); err != nil {
+			return err
+		}
+
+		layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
+		if err != nil {
+			return err
+		}
+
+		layers.Replace(layer)
+	}
+
 	if len(params) > 0 {
 		fn(api.ProgressResponse{Status: "creating parameters layer"})
 
@@ -908,8 +950,8 @@ func ShowModelfile(model *Model) (string, error) {
 	mt.Model = model
 	mt.From = model.ModelPath
 
-	if model.OriginalModel != "" {
-		mt.From = model.OriginalModel
+	if model.ParentModel != "" {
+		mt.From = model.ParentModel
 	}
 
 	modelFile := `# Modelfile generated by "ollama show"

+ 37 - 4
server/routes.go

@@ -186,7 +186,13 @@ func GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	sessionDuration := defaultSessionDuration
+	var sessionDuration time.Duration
+	if req.KeepAlive == nil {
+		sessionDuration = defaultSessionDuration
+	} else {
+		sessionDuration = req.KeepAlive.Duration
+	}
+
 	if err := load(c, model, opts, sessionDuration); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -378,7 +384,14 @@ func EmbeddingHandler(c *gin.Context) {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
-	sessionDuration := defaultSessionDuration
+
+	var sessionDuration time.Duration
+	if req.KeepAlive == nil {
+		sessionDuration = defaultSessionDuration
+	} else {
+		sessionDuration = req.KeepAlive.Duration
+	}
+
 	if err := load(c, model, opts, sessionDuration); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -659,6 +672,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	}
 
 	modelDetails := api.ModelDetails{
+		ParentModel:       model.ParentModel,
 		Format:            model.Config.ModelFormat,
 		Family:            model.Config.ModelFamily,
 		Families:          model.Config.ModelFamilies,
@@ -674,11 +688,17 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 		model.Template = req.Template
 	}
 
+	msgs := make([]api.Message, 0)
+	for _, msg := range model.Messages {
+		msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
+	}
+
 	resp := &api.ShowResponse{
 		License:  strings.Join(model.License, "\n"),
 		System:   model.System,
 		Template: model.Template,
 		Details:  modelDetails,
+		Messages: msgs,
 	}
 
 	var params []string
@@ -1067,7 +1087,14 @@ func ChatHandler(c *gin.Context) {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
-	sessionDuration := defaultSessionDuration
+
+	var sessionDuration time.Duration
+	if req.KeepAlive == nil {
+		sessionDuration = defaultSessionDuration
+	} else {
+		sessionDuration = req.KeepAlive.Duration
+	}
+
 	if err := load(c, model, opts, sessionDuration); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -1075,7 +1102,13 @@ func ChatHandler(c *gin.Context) {
 
 	// an empty request loads the model
 	if len(req.Messages) == 0 {
-		c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true, Message: api.Message{Role: "assistant"}})
+		resp := api.ChatResponse{
+			CreatedAt: time.Now().UTC(),
+			Model:     req.Model,
+			Done:      true,
+			Message:   api.Message{Role: "assistant"},
+		}
+		c.JSON(http.StatusOK, resp)
 		return
 	}