
Merge remote-tracking branch 'upstream/main' into pr3702

Daniel Hiltgen committed 1 year ago · commit 920a4b0794

100 changed files with 3810 additions and 2222 deletions
  1. 10 10
      .github/workflows/release.yaml
  2. 20 14
      .github/workflows/test.yaml
  3. 2 1
      .gitignore
  4. 7 7
      Dockerfile
  5. 41 28
      README.md
  6. 76 10
      api/client.go
  7. 43 1
      api/client_test.go
  8. 82 15
      api/types.go
  9. 57 0
      api/types_test.go
  10. 3 1
      app/lifecycle/logging.go
  11. 37 27
      app/lifecycle/server.go
  12. 1 4
      app/lifecycle/updater_windows.go
  13. 5 8
      app/ollama.iss
  14. 71 71
      app/tray/wintray/menus.go
  15. 33 3
      auth/auth.go
  16. 193 103
      cmd/cmd.go
  17. 6 1
      cmd/interactive.go
  18. 19 14
      convert/convert.go
  19. 2 13
      convert/gemma.go
  20. 2 16
      convert/llama.go
  21. 2 13
      convert/mistral.go
  22. 85 0
      convert/mixtral.go
  23. 27 14
      convert/safetensors.go
  24. 1 1
      convert/torch.go
  25. 31 31
      docs/api.md
  26. 2 2
      docs/development.md
  27. 16 6
      docs/faq.md
  28. 3 1
      docs/import.md
  29. 1 1
      docs/linux.md
  30. 17 25
      docs/modelfile.md
  31. 5 5
      docs/openai.md
  32. 3 3
      docs/tutorials/langchainjs.md
  33. 7 5
      docs/tutorials/langchainpy.md
  34. 6 29
      docs/tutorials/nvidia-jetson.md
  35. 61 47
      docs/windows.md
  36. 1 1
      examples/bash-comparemodels/README.md
  37. 1 0
      examples/flyio/.gitignore
  38. 67 0
      examples/flyio/README.md
  39. 1 1
      examples/go-chat/main.go
  40. 1 1
      examples/go-http-generate/main.go
  41. 14 12
      examples/kubernetes/README.md
  42. 5 5
      examples/langchain-python-rag-document/main.py
  43. 4 4
      examples/langchain-python-rag-websummary/main.py
  44. 2 3
      examples/langchain-python-simple/README.md
  45. 1 1
      examples/langchain-python-simple/main.py
  46. 1 1
      examples/modelfile-mario/Modelfile
  47. 3 3
      examples/modelfile-mario/readme.md
  48. 7 7
      examples/python-json-datagenerator/predefinedschema.py
  49. 1 1
      examples/python-json-datagenerator/randomaddresses.py
  50. 2 2
      examples/python-json-datagenerator/readme.md
  51. 1 1
      examples/python-simplechat/client.py
  52. 2 2
      examples/python-simplechat/readme.md
  53. 2 2
      examples/typescript-mentors/README.md
  54. 2 2
      examples/typescript-mentors/character-generator.ts
  55. 2 2
      examples/typescript-simplechat/client.ts
  56. 3 0
      format/bytes.go
  57. 14 6
      format/format.go
  58. 34 0
      format/format_test.go
  59. 58 14
      gpu/amd_common.go
  60. 2 2
      gpu/amd_hip_windows.go
  61. 176 287
      gpu/amd_linux.go
  62. 85 76
      gpu/amd_windows.go
  63. 15 4
      gpu/assets.go
  64. 22 0
      gpu/cuda_common.go
  65. 154 142
      gpu/gpu.go
  66. 28 33
      gpu/gpu_darwin.go
  67. 9 4
      gpu/gpu_info.h
  68. 6 2
      gpu/gpu_info_cpu.c
  69. 63 82
      gpu/gpu_info_cudart.c
  70. 97 11
      gpu/gpu_info_cudart.h
  71. 203 0
      gpu/gpu_info_nvcuda.c
  72. 71 0
      gpu/gpu_info_nvcuda.h
  73. 0 221
      gpu/gpu_info_nvml.c
  74. 0 57
      gpu/gpu_info_nvml.h
  75. 6 13
      gpu/gpu_test.go
  76. 43 6
      gpu/types.go
  77. 44 2
      integration/basic_test.go
  78. 225 0
      integration/concurrency_test.go
  79. 1 2
      integration/context_test.go
  80. 3 3
      integration/llm_image_test.go
  81. 1 23
      integration/llm_test.go
  82. 117 0
      integration/max_queue_test.go
  83. 172 99
      integration/utils_test.go
  84. 9 8
      llm/ext_server/server.cpp
  85. 140 0
      llm/filetype.go
  86. 1 1
      llm/generate/gen_common.sh
  87. 30 16
      llm/generate/gen_linux.sh
  88. 190 112
      llm/generate/gen_windows.ps1
  89. 28 79
      llm/ggml.go
  90. 10 6
      llm/gguf.go
  91. 1 1
      llm/llama.cpp
  92. 5 52
      llm/llm.go
  93. 1 1
      llm/llm_windows.go
  94. 185 0
      llm/memory.go
  95. 12 0
      llm/patches/02-clip-log.diff
  96. 45 0
      llm/patches/04-metal.diff
  97. 24 0
      llm/patches/05-clip-fix.diff
  98. 27 7
      llm/payload.go
  99. 352 268
      llm/server.go
  100. 1 1
      macapp/src/app.tsx

+ 10 - 10
.github/workflows/release.yaml

@@ -103,6 +103,7 @@ jobs:
          path: |
            llm/build/**/bin/*
            llm/build/**/*.a
+            dist/windows-amd64/**

   # ROCm generation step
   generate-windows-rocm:
@@ -173,7 +174,9 @@ jobs:
       - uses: actions/upload-artifact@v4
         with:
           name: generate-windows-rocm
-          path: llm/build/**/bin/*
+          path: |
+            llm/build/**/bin/*
+            dist/windows-amd64/**
       - uses: actions/upload-artifact@v4
         with:
           name: windows-rocm-deps
@@ -253,7 +256,9 @@ jobs:
       - uses: actions/upload-artifact@v4
         with:
           name: generate-windows-cuda
-          path: llm/build/**/bin/*
+          path: |
+            llm/build/**/bin/*
+            dist/windows-amd64/**
       - uses: actions/upload-artifact@v4
         with:
           name: windows-cuda-deps
@@ -306,23 +311,18 @@ jobs:
       - uses: actions/download-artifact@v4
         with:
           name: generate-windows-cpu
-          path: llm/build
       - uses: actions/download-artifact@v4
         with:
           name: generate-windows-cuda
-          path: llm/build
       - uses: actions/download-artifact@v4
         with:
           name: windows-cuda-deps
-          path: dist/deps
       - uses: actions/download-artifact@v4
         with:
           name: windows-rocm-deps
-          path: dist/deps
       - uses: actions/download-artifact@v4
         with:
           name: generate-windows-rocm
-          path: llm/build
       - run: dir llm/build
       - run: |
           $gopath=(get-command go).source | split-path -parent
@@ -331,13 +331,13 @@ jobs:
           $env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
           $env:PATH="$gopath;$env:PATH"
           $env:OLLAMA_SKIP_GENERATE="1"
-          $env:NVIDIA_DIR=$(resolve-path ".\dist\deps")
-          $env:HIP_PATH=$(resolve-path ".\dist\deps")
           & .\scripts\build_windows.ps1
       - uses: actions/upload-artifact@v4
         with:
           name: dist-windows
-          path: dist/*.exe
+          path: |
+            dist/OllamaSetup.exe
+            dist/ollama-windows-*.zip

   # Linux x86 assets built using the container based build
   build-linux-amd64:

+ 20 - 14
.github/workflows/test.yaml

@@ -1,5 +1,15 @@
 name: test

+concurrency:
+  # For PRs, later CI runs preempt previous ones. e.g. a force push on a PR
+  # cancels running CI jobs and starts all new ones.
+  #
+  # For non-PR pushes, concurrency.group needs to be unique for every distinct
+  # CI run we want to have happen. Use run_id, which in practice means all
+  # non-PR CI runs will be allowed to run without preempting each other.
+  group: ${{ github.workflow }}-$${{ github.pull_request.number || github.run_id }}
+  cancel-in-progress: true
+
 on:
   pull_request:
     paths:
@@ -21,7 +31,9 @@ jobs:
       - id: changes
         run: |
           changed() {
-            git diff-tree -r --no-commit-id --name-only ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }} \
+            git diff-tree -r --no-commit-id --name-only \
+              $(git merge-base ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }}) \
+              ${{ github.event.pull_request.head.sha }} \
               | xargs python3 -c "import sys; print(any([x.startswith('$1') for x in sys.argv[1:]]))"
           }

@@ -103,7 +115,9 @@ jobs:
       - uses: actions/upload-artifact@v4
         with:
           name: cuda-${{ matrix.cuda-version }}-libraries
-          path: llm/build/**/bin/*
+          path: |
+            llm/build/**/bin/*
+            dist/windows-amd64/**
   generate-rocm:
     needs: [changes]
     if: ${{ needs.changes.outputs.GENERATE_ROCM == 'True' }}
@@ -134,7 +148,9 @@ jobs:
       - uses: actions/upload-artifact@v4
         with:
           name: rocm-${{ matrix.rocm-version }}-libraries
-          path: llm/build/**/bin/*
+          path: |
+            llm/build/**/bin/*
+            dist/windows-amd64/**

   # ROCm generation step
   generate-windows-rocm:
@@ -253,14 +269,9 @@ jobs:
           mkdir -p llm/build/darwin/$ARCH/stub/bin
           touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
         if: ${{ startsWith(matrix.os, 'macos-') }}
-      - run: |
-          mkdir -p llm/build/windows/$ARCH/stub/bin
-          touch llm/build/windows/$ARCH/stub/bin/ollama_llama_server
-        if: ${{ startsWith(matrix.os, 'windows-') }}
-        shell: bash
       - uses: golangci/golangci-lint-action@v4
         with:
-          args: --timeout 8m0s
+          args: --timeout 8m0s -v
   test:
     strategy:
       matrix:
@@ -284,7 +295,6 @@ jobs:
         with:
           go-version-file: go.mod
           cache: true
-      - run: go get
       - run: |
           case ${{ matrix.arch }} in
             amd64) echo ARCH=x86_64 ;;
@@ -299,10 +309,6 @@ jobs:
           mkdir -p llm/build/darwin/$ARCH/stub/bin
           touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
         if: ${{ startsWith(matrix.os, 'macos-') }}
-      - run: |
-          mkdir -p llm/build/windows/$ARCH/stub/bin
-          touch llm/build/windows/$ARCH/stub/bin/ollama_llama_server
-        if: ${{ startsWith(matrix.os, 'windows-') }}
         shell: bash
       - run: go generate ./...
       - run: go build

+ 2 - 1
.gitignore

@@ -11,4 +11,5 @@ ggml-metal.metal
 .idea
 test_data
 *.crt
-llm/build
+llm/build
+__debug_bin*

+ 7 - 7
Dockerfile

@@ -18,7 +18,7 @@ ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
 COPY --from=llm-code / /go/src/github.com/ollama/ollama/
 WORKDIR /go/src/github.com/ollama/ollama/llm/generate
 ARG CGO_CFLAGS
-RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
+RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh

 FROM --platform=linux/arm64 nvidia/cuda:$CUDA_VERSION-devel-rockylinux8 AS cuda-build-arm64
 ARG CMAKE_VERSION
@@ -28,7 +28,7 @@ ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH
 COPY --from=llm-code / /go/src/github.com/ollama/ollama/
 WORKDIR /go/src/github.com/ollama/ollama/llm/generate
 ARG CGO_CFLAGS
-RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
+RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh

 FROM --platform=linux/amd64 rocm/dev-centos-7:${ROCM_VERSION}-complete AS rocm-build-amd64
 ARG CMAKE_VERSION
@@ -40,7 +40,7 @@ COPY --from=llm-code / /go/src/github.com/ollama/ollama/
 WORKDIR /go/src/github.com/ollama/ollama/llm/generate
 ARG CGO_CFLAGS
 ARG AMDGPU_TARGETS
-RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
+RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
 RUN mkdir /tmp/scratch && \
     for dep in $(zcat /go/src/github.com/ollama/ollama/llm/build/linux/x86_64/rocm*/bin/deps.txt.gz) ; do \
         cp ${dep} /tmp/scratch/ || exit 1 ; \
@@ -64,11 +64,11 @@ WORKDIR /go/src/github.com/ollama/ollama/llm/generate
 FROM --platform=linux/amd64 cpu-builder-amd64 AS static-build-amd64
 RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh
 FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu-build-amd64
-RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
+RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
 FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx-build-amd64
-RUN OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
+RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
 FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
-RUN OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
+RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh

 FROM --platform=linux/arm64 centos:7 AS cpu-builder-arm64
 ARG CMAKE_VERSION
@@ -84,7 +84,7 @@ WORKDIR /go/src/github.com/ollama/ollama/llm/generate
 FROM --platform=linux/arm64 cpu-builder-arm64 AS static-build-arm64
 RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh
 FROM --platform=linux/arm64 cpu-builder-arm64 AS cpu-build-arm64
-RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
+RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh


 # Intermediate stage used for ./scripts/build_linux.sh

+ 41 - 28
README.md

@@ -1,5 +1,5 @@
 <div align="center">
-  <img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
+ <img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
 </div>

 # Ollama
@@ -35,10 +35,10 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla

 ## Quickstart

-To run and chat with [Llama 2](https://ollama.com/library/llama2):
+To run and chat with [Llama 3](https://ollama.com/library/llama3):

 ```
-ollama run llama2
+ollama run llama3
 ```

 ## Model library
@@ -49,17 +49,14 @@ Here are some example models that can be downloaded:

 | Model              | Parameters | Size  | Download                       |
 | ------------------ | ---------- | ----- | ------------------------------ |
-| Llama 2            | 7B         | 3.8GB | `ollama run llama2`            |
+| Llama 3            | 8B         | 4.7GB | `ollama run llama3`            |
+| Llama 3            | 70B        | 40GB  | `ollama run llama3:70b`        |
+| Phi-3              | 3.8B       | 2.3GB | `ollama run phi3`              |
 | Mistral            | 7B         | 4.1GB | `ollama run mistral`           |
-| Dolphin Phi        | 2.7B       | 1.6GB | `ollama run dolphin-phi`       |
-| Phi-2              | 2.7B       | 1.7GB | `ollama run phi`               |
 | Neural Chat        | 7B         | 4.1GB | `ollama run neural-chat`       |
 | Starling           | 7B         | 4.1GB | `ollama run starling-lm`       |
 | Code Llama         | 7B         | 3.8GB | `ollama run codellama`         |
 | Llama 2 Uncensored | 7B         | 3.8GB | `ollama run llama2-uncensored` |
-| Llama 2 13B        | 13B        | 7.3GB | `ollama run llama2:13b`        |
-| Llama 2 70B        | 70B        | 39GB  | `ollama run llama2:70b`        |
-| Orca Mini          | 3B         | 1.9GB | `ollama run orca-mini`         |
 | LLaVA              | 7B         | 4.5GB | `ollama run llava`             |
 | Gemma              | 2B         | 1.4GB | `ollama run gemma:2b`          |
 | Gemma              | 7B         | 4.8GB | `ollama run gemma:7b`          |
@@ -97,16 +94,16 @@ See the [guide](docs/import.md) on importing models for more information.

 ### Customize a prompt

-Models from the Ollama library can be customized with a prompt. For example, to customize the `llama2` model:
+Models from the Ollama library can be customized with a prompt. For example, to customize the `llama3` model:

 ```
-ollama pull llama2
+ollama pull llama3
 ```

 Create a `Modelfile`:

 ```
-FROM llama2
+FROM llama3

 # set the temperature to 1 [higher is more creative, lower is more coherent]
 PARAMETER temperature 1
@@ -141,7 +138,7 @@ ollama create mymodel -f ./Modelfile
 ### Pull a model

 ```
-ollama pull llama2
+ollama pull llama3
 ```

 > This command can also be used to update a local model. Only the diff will be pulled.
@@ -149,13 +146,13 @@ ollama pull llama2
 ### Remove a model

 ```
-ollama rm llama2
+ollama rm llama3
 ```

 ### Copy a model

 ```
-ollama cp llama2 my-llama2
+ollama cp llama3 my-model
 ```

 ### Multiline input
@@ -176,10 +173,10 @@ I'm a basic program that prints the famous "Hello, world!" message to the consol
 The image features a yellow smiley face, which is likely the central focus of the picture.
 ```

-### Pass in prompt as arguments
+### Pass the prompt as an argument

 ```
-$ ollama run llama2 "Summarize this file: $(cat README.md)"
+$ ollama run llama3 "Summarize this file: $(cat README.md)"
  Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
 ```

@@ -226,7 +223,7 @@ Next, start the server:
 Finally, in a separate shell, run a model:

 ```
-./ollama run llama2
+./ollama run llama3
 ```

 ## REST API
@@ -237,7 +234,7 @@ Ollama has a REST API for running and managing models.

 ```
 curl http://localhost:11434/api/generate -d '{
-  "model": "llama2",
+  "model": "llama3",
   "prompt":"Why is the sky blue?"
 }'
 ```
@@ -246,7 +243,7 @@ curl http://localhost:11434/api/generate -d '{

 ```
 curl http://localhost:11434/api/chat -d '{
-  "model": "mistral",
+  "model": "llama3",
   "messages": [
     { "role": "user", "content": "why is the sky blue?" }
   ]
@@ -259,16 +256,18 @@ See the [API documentation](./docs/api.md) for all endpoints.

 ### Web & Desktop

+- [Open WebUI](https://github.com/open-webui/open-webui)
+- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
+- [Hollama](https://github.com/fmaclen/hollama)
 - [Lollms-Webui](https://github.com/ParisNeo/lollms-webui)
 - [LibreChat](https://github.com/danny-avila/LibreChat)
 - [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
-- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
 - [HTML UI](https://github.com/rtcfirefly/ollama-ui)
 - [Saddle](https://github.com/jikkuatwork/saddle)
 - [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
+- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
 - [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
 - [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
-- [Open WebUI](https://github.com/open-webui/open-webui)
 - [Ollamac](https://github.com/kevinhermawan/Ollamac)
 - [big-AGI](https://github.com/enricoros/big-AGI/blob/main/docs/config-local-ollama.md)
 - [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
@@ -286,13 +285,20 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [OllamaGUI](https://github.com/enoch1118/ollamaGUI)
 - [OpenAOE](https://github.com/InternLM/OpenAOE)
 - [Odin Runes](https://github.com/leonid20000/OdinRunes)
-- [LLM-X: Progressive Web App](https://github.com/mrdjohnson/llm-x)
+- [LLM-X](https://github.com/mrdjohnson/llm-x) (Progressive Web App)
 - [AnythingLLM (Docker + MacOs/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm)
 - [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
 - [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
-- [ChatOllama: Open Source Chatbot based on Ollama with Knowledge Bases](https://github.com/sugarforever/chat-ollama)
-- [CRAG Ollama Chat: Simple Web Search with Corrective RAG](https://github.com/Nagi-ovo/CRAG-Ollama-Chat)
-- [RAGFlow: Open-source Retrieval-Augmented Generation engine based on deep document understanding](https://github.com/infiniflow/ragflow)
+- [QA-Pilot](https://github.com/reid41/QA-Pilot) (Chat with Code Repository)
+- [ChatOllama](https://github.com/sugarforever/chat-ollama) (Open Source Chatbot based on Ollama with Knowledge Bases)
+- [CRAG Ollama Chat](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) (Simple Web Search with Corrective RAG)
+- [RAGFlow](https://github.com/infiniflow/ragflow) (Open-source Retrieval-Augmented Generation engine based on deep document understanding)
+- [StreamDeploy](https://github.com/StreamDeploy-DevRel/streamdeploy-llm-app-scaffold) (LLM Application Scaffold)
+- [chat](https://github.com/swuecho/chat) (chat web app for teams)
+- [Lobe Chat](https://github.com/lobehub/lobe-chat) with [Integrating Doc](https://lobehub.com/docs/self-hosting/examples/ollama)
+- [Ollama RAG Chatbot](https://github.com/datvodinh/rag-chatbot.git) (Local Chat with multiple PDFs using Ollama and RAG)
+- [BrainSoup](https://www.nurgo-software.com/products/brainsoup) (Flexible native client with RAG & multi-agent automation)
+- [macai](https://github.com/Renset/macai) (macOS client for Ollama, ChatGPT, and other compatible API back-ends)

 ### Terminal

@@ -308,11 +314,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Oatmeal](https://github.com/dustinblackman/oatmeal)
 - [cmdh](https://github.com/pgibler/cmdh)
 - [ooo](https://github.com/npahlfer/ooo)
+- [shell-pilot](https://github.com/reid41/shell-pilot)
 - [tenere](https://github.com/pythops/tenere)
 - [llm-ollama](https://github.com/taketwo/llm-ollama) for [Datasette's LLM CLI](https://llm.datasette.io/en/stable/).
 - [typechat-cli](https://github.com/anaisbetts/typechat-cli)
 - [ShellOracle](https://github.com/djcopley/ShellOracle)
 - [tlm](https://github.com/yusufcanb/tlm)
+- [podman-ollama](https://github.com/ericcurtin/podman-ollama)

 ### Database

@@ -344,9 +352,11 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Haystack](https://github.com/deepset-ai/haystack-integrations/blob/main/integrations/ollama.md)
 - [Elixir LangChain](https://github.com/brainlid/langchain)
 - [Ollama for R - rollama](https://github.com/JBGruber/rollama)
+- [Ollama for R - ollama-r](https://github.com/hauselin/ollama-r)
 - [Ollama-ex for Elixir](https://github.com/lebrunel/ollama-ex)
 - [Ollama Connector for SAP ABAP](https://github.com/b-tocs/abap_btocs_ollama)
 - [Testcontainers](https://testcontainers.com/modules/ollama/)
+- [Portkey](https://portkey.ai/docs/welcome/integration-guides/ollama)

 ### Mobile

@@ -366,17 +376,20 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Ollama Telegram Bot](https://github.com/ruecat/ollama-telegram)
 - [Hass Ollama Conversation](https://github.com/ej52/hass-ollama-conversation)
 - [Rivet plugin](https://github.com/abrenneke/rivet-plugin-ollama)
-- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
 - [Obsidian BMO Chatbot plugin](https://github.com/longy2k/obsidian-bmo-chatbot)
 - [Cliobot](https://github.com/herval/cliobot) (Telegram bot with Ollama support)
 - [Copilot for Obsidian plugin](https://github.com/logancyang/obsidian-copilot)
 - [Obsidian Local GPT plugin](https://github.com/pfrankov/obsidian-local-gpt)
 - [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
+- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
+- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use ollama as a copilot like Github copilot)
 - [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
 - [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and HuggingFace)
 - [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension)
 - [AI Telegram Bot](https://github.com/tusharhero/aitelegrambot) (Telegram bot using Ollama in backend)
 - [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support)
+- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)

 ### Supported backends
-- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov. 
+- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
+

+ 76 - 10
api/client.go

@@ -1,9 +1,16 @@
 // Package api implements the client-side API for code wishing to interact
 // with the ollama service. The methods of the [Client] type correspond to
-// the ollama REST API as described in https://github.com/ollama/ollama/blob/main/docs/api.md
-//
+// the ollama REST API as described in [the API documentation].
 // The ollama command-line client itself uses this package to interact with
 // the backend service.
+//
+// # Examples
+//
+// Several examples of using this package are available [in the GitHub
+// repository].
+//
+// [the API documentation]: https://github.com/ollama/ollama/blob/main/docs/api.md
+// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/examples
 package api

 import (
@@ -18,6 +25,7 @@ import (
 	"net/url"
 	"os"
 	"runtime"
+	"strconv"
 	"strings"

 	"github.com/ollama/ollama/format"
@@ -57,12 +65,36 @@ func checkError(resp *http.Response, body []byte) error {
 // If the variable is not specified, a default ollama host and port will be
 // used.
 func ClientFromEnvironment() (*Client, error) {
+	ollamaHost, err := GetOllamaHost()
+	if err != nil {
+		return nil, err
+	}
+
+	return &Client{
+		base: &url.URL{
+			Scheme: ollamaHost.Scheme,
+			Host:   net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
+		},
+		http: http.DefaultClient,
+	}, nil
+}
+
+type OllamaHost struct {
+	Scheme string
+	Host   string
+	Port   string
+}
+
+func GetOllamaHost() (OllamaHost, error) {
 	defaultPort := "11434"

-	scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
+	hostVar := os.Getenv("OLLAMA_HOST")
+	hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
+
+	scheme, hostport, ok := strings.Cut(hostVar, "://")
 	switch {
 	case !ok:
-		scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
+		scheme, hostport = "http", hostVar
 	case scheme == "http":
 		defaultPort = "80"
 	case scheme == "https":
@@ -82,15 +114,24 @@ func ClientFromEnvironment() (*Client, error) {
 		}
 	}

-	return &Client{
-		base: &url.URL{
-			Scheme: scheme,
-			Host:   net.JoinHostPort(host, port),
-		},
-		http: http.DefaultClient,
+	if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
+		return OllamaHost{}, ErrInvalidHostPort
+	}
+
+	return OllamaHost{
+		Scheme: scheme,
+		Host:   host,
+		Port:   port,
 	}, nil
 }

+func NewClient(base *url.URL, http *http.Client) *Client {
+	return &Client{
+		base: base,
+		http: http,
+	}
+}
+
 func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
 	var reqBody io.Reader
 	var data []byte
@@ -265,8 +306,14 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc
 	})
 }

+// PushProgressFunc is a function that [Client.Push] invokes when progress is
+// made.
+// It's similar to other progress function types like [PullProgressFunc].
 type PushProgressFunc func(ProgressResponse) error

+// Push uploads a model to the model library; requires registering for ollama.ai
+// and adding a public key first. fn is called each time progress is made on
+// the request and can be used to display a progress bar, etc.
 func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
 	return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error {
 		var resp ProgressResponse
@@ -278,8 +325,15 @@ func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc
 	})
 }

+// CreateProgressFunc is a function that [Client.Create] invokes when progress
+// is made.
+// It's similar to other progress function types like [PullProgressFunc].
 type CreateProgressFunc func(ProgressResponse) error

+// Create creates a model from a [Modelfile]. fn is a progress function that
+// behaves similarly to other methods (see [Client.Pull]).
+//
+// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md
 func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
 	return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
 		var resp ProgressResponse
@@ -291,6 +345,7 @@ func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgre
 	})
 }

+// List lists models that are available locally.
 func (c *Client) List(ctx context.Context) (*ListResponse, error) {
 	var lr ListResponse
 	if err := c.do(ctx, http.MethodGet, "/api/tags", nil, &lr); err != nil {
@@ -299,6 +354,8 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
 	return &lr, nil
 }

+// Copy copies a model - creating a model with another name from an existing
+// model.
 func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
 	if err := c.do(ctx, http.MethodPost, "/api/copy", req, nil); err != nil {
 		return err
@@ -306,6 +363,7 @@ func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
 	return nil
 }

+// Delete deletes a model and its data.
 func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
 	if err := c.do(ctx, http.MethodDelete, "/api/delete", req, nil); err != nil {
 		return err
@@ -313,6 +371,7 @@ func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
 	return nil
 }

+// Show obtains model information, including details, modelfile, license etc.
 func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, error) {
 	var resp ShowResponse
 	if err := c.do(ctx, http.MethodPost, "/api/show", req, &resp); err != nil {
@@ -321,12 +380,16 @@ func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, err
 	return &resp, nil
 }

+// Hearbeat checks if the server has started and is responsive; if yes, it
+// returns nil, otherwise an error.
 func (c *Client) Heartbeat(ctx context.Context) error {
 	if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil {
 		return err
 	}
 	return nil
 }
+
+// Embeddings generates embeddings from a model.
 func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
 	var resp EmbeddingResponse
 	if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
@@ -335,10 +398,13 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*Embedd
 	return &resp, nil
 }

+// CreateBlob creates a blob from a file on the server. digest is the
+// expected SHA256 digest of the file, and r represents the file.
 func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
 	return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
 }

+// Version returns the Ollama server version as a string.
 func (c *Client) Version(ctx context.Context) (string, error) {
 	var version struct {
 		Version string `json:"version"`

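The client changes above introduce `GetOllamaHost`, `NewClient`, and stricter `OLLAMA_HOST` handling (quote/whitespace trimming and port validation). A minimal sketch of how the updated package might be used, assuming only what the hunks above show (`ClientFromEnvironment`, `Heartbeat`, and `Version` on `github.com/ollama/ollama/api`):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	// ClientFromEnvironment reads OLLAMA_HOST (now trimmed of quotes and
	// whitespace, with the port validated) and falls back to 127.0.0.1:11434.
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err) // e.g. the new ErrInvalidHostPort for OLLAMA_HOST=":66000"
	}

	// Heartbeat returns nil when the server is up and responsive.
	if err := client.Heartbeat(context.Background()); err != nil {
		log.Fatal(err)
	}

	version, err := client.Version(context.Background())
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("connected to ollama", version)
}
```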
+ 43 - 1
api/client_test.go

@@ -1,6 +1,12 @@
 package api

-import "testing"
+import (
+	"fmt"
+	"net"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)

 func TestClientFromEnvironment(t *testing.T) {
 	type testCase struct {
@@ -40,4 +46,40 @@ func TestClientFromEnvironment(t *testing.T) {
 			}
 		})
 	}
+
+	hostTestCases := map[string]*testCase{
+		"empty":               {value: "", expect: "127.0.0.1:11434"},
+		"only address":        {value: "1.2.3.4", expect: "1.2.3.4:11434"},
+		"only port":           {value: ":1234", expect: ":1234"},
+		"address and port":    {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"},
+		"hostname":            {value: "example.com", expect: "example.com:11434"},
+		"hostname and port":   {value: "example.com:1234", expect: "example.com:1234"},
+		"zero port":           {value: ":0", expect: ":0"},
+		"too large port":      {value: ":66000", err: ErrInvalidHostPort},
+		"too small port":      {value: ":-1", err: ErrInvalidHostPort},
+		"ipv6 localhost":      {value: "[::1]", expect: "[::1]:11434"},
+		"ipv6 world open":     {value: "[::]", expect: "[::]:11434"},
+		"ipv6 no brackets":    {value: "::1", expect: "[::1]:11434"},
+		"ipv6 + port":         {value: "[::1]:1337", expect: "[::1]:1337"},
+		"extra space":         {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"},
+		"extra quotes":        {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"},
+		"extra space+quotes":  {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"},
+		"extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"},
+	}
+
+	for k, v := range hostTestCases {
+		t.Run(k, func(t *testing.T) {
+			t.Setenv("OLLAMA_HOST", v.value)
+
+			oh, err := GetOllamaHost()
+			if err != v.err {
+				t.Fatalf("expected %s, got %s", v.err, err)
+			}
+
+			if err == nil {
+				host := net.JoinHostPort(oh.Host, oh.Port)
+				assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host))
+			}
+		})
+	}
 }

+ 82 - 15
api/types.go

@@ -2,6 +2,7 @@ package api
 
 
 import (
 import (
 	"encoding/json"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"fmt"
 	"math"
 	"math"
 	"os"
 	"os"
@@ -11,6 +12,7 @@ import (
 	"time"
 	"time"
 )
 )
 
 
+// StatusError is an error with and HTTP status code.
 type StatusError struct {
 type StatusError struct {
 	StatusCode   int
 	Status       string
 	}
 	}
 }

 type ImageData []byte
 type ImageData []byte

 // GenerateRequest describes a request sent by [Client.Generate]. While you
 	Options map[string]interface{} `json:"options"`
 	Options map[string]interface{} `json:"options"`
 }

 type ChatRequest struct {
 type ChatRequest struct {
-	Model     string    `json:"model"`
-	Messages  []Message `json:"messages"`
-	Stream    *bool     `json:"stream,omitempty"`
-	Format    string    `json:"format"`
+	// Model is the model name, as in [GenerateRequest].
+	Model string `json:"model"`
+
+	// Messages is the messages of the chat - can be used to keep a chat memory.
+	Messages []Message `json:"messages"`
+
+	// Stream enable streaming of returned response; true by default.
+	Stream *bool `json:"stream,omitempty"`
+
+	// Format is the format to return the response in (e.g. "json").
+	Format string `json:"format"`
+
+	// KeepAlive controls how long the model will stay loaded into memory
+	// followin the request.
 	KeepAlive *Duration `json:"keep_alive,omitempty"`
 	KeepAlive *Duration `json:"keep_alive,omitempty"`

 	Options map[string]interface{} `json:"options"`
 	Options map[string]interface{} `json:"options"`
 }

+// role ("system", "user", or "assistant"), the content and an optional list
+// of images.
 type Message struct {
 type Message struct {
-	Role    string      `json:"role"` // one of ["system", "user", "assistant"]
+	Role    string      `json:"role"`
 	Content string      `json:"content"`
 	Content string      `json:"content"`
 	Images  []ImageData `json:"images,omitempty"`
 }

+// similar to [GenerateResponse].
 type ChatResponse struct {
 type ChatResponse struct {
 	Model     string    `json:"model"`
 	CreatedAt time.Time `json:"created_at"`
 	EvalDuration       time.Duration `json:"eval_duration,omitempty"`
 	EvalDuration       time.Duration `json:"eval_duration,omitempty"`
 }

+// Options specified in [GenerateRequest], if you add a new option here add it
+// to the API docs also.
 type Options struct {
 type Options struct {
 	Runner

 	RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
 	RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
 }

 type EmbeddingRequest struct {
 type EmbeddingRequest struct {
-	Model     string    `json:"model"`
-	Prompt    string    `json:"prompt"`
+	// Model is the model name.
+	Model string `json:"model"`
+
+	// Prompt is the textual prompt to embed.
+	Prompt string `json:"prompt"`
+
+	// KeepAlive controls how long the model will stay loaded in memory following
+	// this request.
 	KeepAlive *Duration `json:"keep_alive,omitempty"`
 	KeepAlive *Duration `json:"keep_alive,omitempty"`

 	Options map[string]interface{} `json:"options"`
 	Options map[string]interface{} `json:"options"`
 }

 type EmbeddingResponse struct {
 type EmbeddingResponse struct {
 	Embedding []float64 `json:"embedding"`
 }

 type CreateRequest struct {
 type CreateRequest struct {
 	Model        string `json:"model"`
 	Path         string `json:"path"`
 	Name string `json:"name"`
 	Name string `json:"name"`
 }

 type DeleteRequest struct {
 type DeleteRequest struct {
 	Model string `json:"model"`

 	Name string `json:"name"`
 	Name string `json:"name"`
 }

 type ShowRequest struct {
 type ShowRequest struct {
 	Model    string `json:"model"`
 	System   string `json:"system"`
 	Name string `json:"name"`
 	Name string `json:"name"`
 }

 type ShowResponse struct {
 type ShowResponse struct {
 	License    string       `json:"license,omitempty"`
 	Modelfile  string       `json:"modelfile,omitempty"`
 	Messages   []Message    `json:"messages,omitempty"`
 	Messages   []Message    `json:"messages,omitempty"`
 }

 type CopyRequest struct {
 type CopyRequest struct {
 	Source      string `json:"source"`
 	Destination string `json:"destination"`
 }

 type PullRequest struct {
 type PullRequest struct {
 	Model    string `json:"model"`
 	Insecure bool   `json:"insecure,omitempty"`
 	Name string `json:"name"`
 	Name string `json:"name"`
 }

+// [PullProgressFunc] and [PushProgressFunc].
 type ProgressResponse struct {
 type ProgressResponse struct {
 	Status    string `json:"status"`
 	Digest    string `json:"digest,omitempty"`
 	Completed int64  `json:"completed,omitempty"`
 	Completed int64  `json:"completed,omitempty"`
 }

 type PushRequest struct {
 type PushRequest struct {
 	Model    string `json:"model"`
 	Insecure bool   `json:"insecure,omitempty"`
 	Name string `json:"name"`
 	Name string `json:"name"`
 }

 type ListResponse struct {
 type ListResponse struct {
 	Models []ModelResponse `json:"models"`
 }

 type ModelResponse struct {
 type ModelResponse struct {
 	Name       string       `json:"name"`
 	Model      string       `json:"model"`
 	Token string `json:"token"`
 	Token string `json:"token"`
 }

 type GenerateResponse struct {
 type GenerateResponse struct {
-	Model     string    `json:"model"`
+	// Model is the model name that generated the response.
+	Model string `json:"model"`
+
+	//CreatedAt is the timestamp of the response.
 	CreatedAt time.Time `json:"created_at"`
 	CreatedAt time.Time `json:"created_at"`
-	Response  string    `json:"response"`
 
 
+	// Response is the textual response itself.
+	Response string `json:"response"`
+
+	// Done specifies if the response is complete.
+	Done bool `json:"done"`
+
+	// Context is an encoding of the conversation used in this response; this
+	// can be sent in the next request to keep a conversational memory.
 	Context []int `json:"context,omitempty"`
 	Context []int `json:"context,omitempty"`

 	Metrics
 }

 type ModelDetails struct {
 type ModelDetails struct {
 	ParentModel       string   `json:"parent_model"`
 	Format            string   `json:"format"`
 	}
 	}
 }

+// ErrInvalidOpts is returned when invalid options are passed to the client.
+var ErrInvalidOpts = errors.New("invalid options")
+var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
 

 func (opts *Options) FromMap(m map[string]interface{}) error {
 	valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
 	return nil
 	return nil
 }

+// values are used unless the user specifies other values explicitly.
 func DefaultOptions() Options {
 func DefaultOptions() Options {
 	return Options{
 		// options set on request to runner
-		NumKeep:          0,
+		NumPredict: -1,
+
+		// set a minimal num_keep to avoid issues on context shifts
+		NumKeep:          4,
 		Temperature:      0.8,
 		Temperature:      0.8,
 		TopK:             40,
 		TopP:             0.9,
 	time.Duration
 	time.Duration
 }

+	if d.Duration < 0 {
+		return []byte("-1"), nil
+	}
+	return []byte("\"" + d.Duration.String() + "\""), nil
+}
+
 func (d *Duration) UnmarshalJSON(b []byte) (err error) {
 	var v any
 	if err := json.Unmarshal(b, &v); err != nil {
@@ -445,7 +510,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
 		if t < 0 {
 			d.Duration = time.Duration(math.MaxInt64)
 		} else {
-			d.Duration = time.Duration(t * float64(time.Second))
+			d.Duration = time.Duration(int(t) * int(time.Second))
 		}
 	case string:
 		d.Duration, err = time.ParseDuration(t)
@@ -455,6 +520,8 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
 		if d.Duration < 0 {
 			d.Duration = time.Duration(math.MaxInt64)
 		}
+	default:
+		return fmt.Errorf("Unsupported type: '%s'", reflect.TypeOf(v))
 	}

 	return nil

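The `Duration` changes above add a `MarshalJSON` and make numeric `keep_alive` values truncate to whole seconds. A small sketch of the resulting round-trip behaviour, assuming only the exported `api.Duration` type shown in the hunks:

```go
package main

import (
	"encoding/json"
	"fmt"
	"time"

	"github.com/ollama/ollama/api"
)

func main() {
	// Negative durations marshal as -1; everything else as a duration string.
	b, _ := json.Marshal(api.Duration{Duration: 5 * time.Minute})
	fmt.Println(string(b)) // "5m0s"

	// Numeric values unmarshal as seconds (floats are now truncated), strings
	// go through time.ParseDuration, and negatives become math.MaxInt64.
	var d api.Duration
	_ = json.Unmarshal([]byte(`42.5`), &d)
	fmt.Println(d.Duration) // 42s
}
```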
+ 57 - 0
api/types_test.go

@@ -21,6 +21,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
 			req:  `{ "keep_alive": 42 }`,
 			exp:  &Duration{42 * time.Second},
 		},
+		{
+			name: "Positive Float",
+			req:  `{ "keep_alive": 42.5 }`,
+			exp:  &Duration{42 * time.Second},
+		},
 		{
 			name: "Positive Integer String",
 			req:  `{ "keep_alive": "42m" }`,
@@ -31,6 +36,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
 			req:  `{ "keep_alive": -1 }`,
 			exp:  &Duration{math.MaxInt64},
 		},
+		{
+			name: "Negative Float",
+			req:  `{ "keep_alive": -3.14 }`,
+			exp:  &Duration{math.MaxInt64},
+		},
 		{
 			name: "Negative Integer String",
 			req:  `{ "keep_alive": "-1m" }`,
@@ -48,3 +58,50 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
 		})
 	}
 }
+
+func TestDurationMarshalUnmarshal(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    time.Duration
+		expected time.Duration
+	}{
+		{
+			"negative duration",
+			time.Duration(-1),
+			time.Duration(math.MaxInt64),
+		},
+		{
+			"positive duration",
+			time.Duration(42 * time.Second),
+			time.Duration(42 * time.Second),
+		},
+		{
+			"another positive duration",
+			time.Duration(42 * time.Minute),
+			time.Duration(42 * time.Minute),
+		},
+		{
+			"zero duration",
+			time.Duration(0),
+			time.Duration(0),
+		},
+		{
+			"max duration",
+			time.Duration(math.MaxInt64),
+			time.Duration(math.MaxInt64),
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			b, err := json.Marshal(Duration{test.input})
+			require.NoError(t, err)
+
+			var d Duration
+			err = json.Unmarshal(b, &d)
+			require.NoError(t, err)
+
+			assert.Equal(t, test.expected, d.Duration, "input %v, marshalled %v, got %v", test.input, string(b), d.Duration)
+		})
+	}
+}
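The `MarshalJSON`/`UnmarshalJSON` pair added above makes `keep_alive` values round-trippable. A minimal sketch of the resulting behaviour, assuming the `api` package layout shown in this diff:

```go
package main

import (
	"encoding/json"
	"fmt"
	"time"

	"github.com/ollama/ollama/api"
)

func main() {
	// Negative durations marshal to -1, which the API uses for an indefinite keep-alive.
	b, _ := json.Marshal(api.Duration{Duration: -1 * time.Second})
	fmt.Println(string(b)) // -1

	// Non-negative durations marshal to a Go duration string.
	b, _ = json.Marshal(api.Duration{Duration: 42 * time.Second})
	fmt.Println(string(b)) // "42s"

	// Float keep_alive values in requests are truncated to whole seconds on unmarshal.
	var d api.Duration
	_ = json.Unmarshal([]byte(`42.5`), &d)
	fmt.Println(d.Duration) // 42s
}
```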

+ 3 - 1
app/lifecycle/logging.go

@@ -5,12 +5,14 @@ import (
 	"log/slog"
 	"os"
 	"path/filepath"
+
+	"github.com/ollama/ollama/server/envconfig"
 )
 
 func InitLogging() {
 	level := slog.LevelInfo
 
-	if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
+	if envconfig.Debug {
 		level = slog.LevelDebug
 	}
 

+ 37 - 27
app/lifecycle/server.go

@@ -43,37 +43,36 @@ func getCLIFullPath(command string) string {
 	return command
 }
 
-func SpawnServer(ctx context.Context, command string) (chan int, error) {
-	done := make(chan int)
-
-	logDir := filepath.Dir(ServerLogFile)
-	_, err := os.Stat(logDir)
-	if errors.Is(err, os.ErrNotExist) {
-		if err := os.MkdirAll(logDir, 0o755); err != nil {
-			return done, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
-		}
-	}
-
+func start(ctx context.Context, command string) (*exec.Cmd, error) {
 	cmd := getCmd(ctx, getCLIFullPath(command))
-	// send stdout and stderr to a file
 	stdout, err := cmd.StdoutPipe()
 	if err != nil {
-		return done, fmt.Errorf("failed to spawn server stdout pipe %s", err)
+		return nil, fmt.Errorf("failed to spawn server stdout pipe: %w", err)
 	}
 	stderr, err := cmd.StderrPipe()
 	if err != nil {
-		return done, fmt.Errorf("failed to spawn server stderr pipe %s", err)
-	}
-	stdin, err := cmd.StdinPipe()
-	if err != nil {
-		return done, fmt.Errorf("failed to spawn server stdin pipe %s", err)
+		return nil, fmt.Errorf("failed to spawn server stderr pipe: %w", err)
 	}
 
 	// TODO - rotation
 	logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
 	if err != nil {
-		return done, fmt.Errorf("failed to create server log %w", err)
+		return nil, fmt.Errorf("failed to create server log: %w", err)
 	}
+
+	logDir := filepath.Dir(ServerLogFile)
+	_, err = os.Stat(logDir)
+	if err != nil {
+		if !errors.Is(err, os.ErrNotExist) {
+			return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err)
+
+		}
+
+		if err := os.MkdirAll(logDir, 0o755); err != nil {
+			return nil, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
+		}
+	}
+
 	go func() {
 		defer logFile.Close()
 		io.Copy(logFile, stdout) //nolint:errcheck
@@ -117,19 +116,33 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
 
 	// run the command and wait for it to finish
 	if err := cmd.Start(); err != nil {
-		return done, fmt.Errorf("failed to start server %w", err)
+		return nil, fmt.Errorf("failed to start server %w", err)
 	}
 	if cmd.Process != nil {
 		slog.Info(fmt.Sprintf("started ollama server with pid %d", cmd.Process.Pid))
 	}
 	slog.Info(fmt.Sprintf("ollama server logs %s", ServerLogFile))
 
+	return cmd, nil
+}
+
+func SpawnServer(ctx context.Context, command string) (chan int, error) {
+	done := make(chan int)
+
 	go func() {
 		// Keep the server running unless we're shuttind down the app
 		crashCount := 0
 		for {
+			slog.Info("starting server...")
+			cmd, err := start(ctx, command)
+			if err != nil {
+				crashCount++
+				slog.Error(fmt.Sprintf("failed to start server %s", err))
+				time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
+				continue
+			}
+
 			cmd.Wait() //nolint:errcheck
-			stdin.Close()
 			var code int
 			if cmd.ProcessState != nil {
 				code = cmd.ProcessState.ExitCode()
@@ -143,15 +156,12 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
 			default:
 				crashCount++
 				slog.Warn(fmt.Sprintf("server crash %d - exit code %d - respawning", crashCount, code))
-				time.Sleep(500 * time.Millisecond)
-				if err := cmd.Start(); err != nil {
-					slog.Error(fmt.Sprintf("failed to restart server %s", err))
-					// Keep trying, but back off if we keep failing
-					time.Sleep(time.Duration(crashCount) * time.Second)
-				}
+				time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
+				break
 			}
 		}
 	}()
+
 	return done, nil
 }
 

+ 1 - 4
app/lifecycle/updater_windows.go

@@ -31,16 +31,13 @@ func DoUpgrade(cancel context.CancelFunc, done chan int) error {
 		"/LOG=" + filepath.Base(UpgradeLogFile), // Only relative seems reliable, so set pwd
 		"/FORCECLOSEAPPLICATIONS",               // Force close the tray app - might be needed
 	}
-	// When we're not in debug mode, make the upgrade as quiet as possible (no GUI, no prompts)
-	// TODO - temporarily disable since we're pinning in debug mode for the preview
-	// if debug := os.Getenv("OLLAMA_DEBUG"); debug == "" {
+	// make the upgrade as quiet as possible (no GUI, no prompts)
 	installArgs = append(installArgs,
 		"/SP", // Skip the "This will install... Do you wish to continue" prompt
 		"/SUPPRESSMSGBOXES",
 		"/SILENT",
 		"/VERYSILENT",
 	)
-	// }
 
 	// Safeguard in case we have requests in flight that need to drain...
 	slog.Info("Waiting for server to shutdown")

+ 5 - 8
app/ollama.iss

@@ -88,15 +88,12 @@ DialogFontSize=12
 [Files]
 Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit
 Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit
-Source: "..\dist\windeps\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
+Source: "..\dist\windows-{#ARCH}\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
+Source: "..\dist\windows-{#ARCH}\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs
 Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion
 Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion
-; Assumes v5.7, may need adjustments for v6
-#if GetEnv("HIP_PATH") != ""
-  Source: "{#GetEnv('HIP_PATH')}\bin\hipblas.dll"; DestDir: "{app}\rocm\"; Flags: ignoreversion
-  Source: "{#GetEnv('HIP_PATH')}\bin\rocblas.dll"; DestDir: "{app}\rocm\"; Flags: ignoreversion
-  ; amdhip64.dll dependency comes from the driver and must be installed already
-  Source: "{#GetEnv('HIP_PATH')}\bin\rocblas\library\*"; DestDir: "{app}\rocm\rocblas\library\"; Flags: ignoreversion
+#if DirExists("..\dist\windows-amd64\rocm")
+  Source: "..\dist\windows-amd64\rocm\*"; DestDir: "{app}\rocm\"; Flags: ignoreversion recursesubdirs
 #endif
 
 
@@ -132,7 +129,7 @@ SetupAppRunningError=Another Ollama installer is running.%n%nPlease cancel or fi
 
 
 ;FinishedHeadingLabel=Run your first model
-;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n    ollama run llama2
+;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n    ollama run llama3
 ;ClickFinish=%n
 
 [Registry]

+ 71 - 71
app/tray/wintray/menus.go

@@ -1,71 +1,71 @@
-//go:build windows
-
-package wintray
-
-import (
-	"fmt"
-	"log/slog"
-	"unsafe"
-
-	"golang.org/x/sys/windows"
-)
-
-const (
-	updatAvailableMenuID = 1
-	updateMenuID         = updatAvailableMenuID + 1
-	separatorMenuID      = updateMenuID + 1
-	diagLogsMenuID       = separatorMenuID + 1
-	diagSeparatorMenuID  = diagLogsMenuID + 1
-	quitMenuID           = diagSeparatorMenuID + 1
-)
-
-func (t *winTray) initMenus() error {
-	if err := t.addOrUpdateMenuItem(diagLogsMenuID, 0, diagLogsMenuTitle, false); err != nil {
-		return fmt.Errorf("unable to create menu entries %w\n", err)
-	}
-	if err := t.addSeparatorMenuItem(diagSeparatorMenuID, 0); err != nil {
-		return fmt.Errorf("unable to create menu entries %w", err)
-	}
-	if err := t.addOrUpdateMenuItem(quitMenuID, 0, quitMenuTitle, false); err != nil {
-		return fmt.Errorf("unable to create menu entries %w\n", err)
-	}
-	return nil
-}
-
-func (t *winTray) UpdateAvailable(ver string) error {
-	if !t.updateNotified {
-		slog.Debug("updating menu and sending notification for new update")
-		if err := t.addOrUpdateMenuItem(updatAvailableMenuID, 0, updateAvailableMenuTitle, true); err != nil {
-			return fmt.Errorf("unable to create menu entries %w", err)
-		}
-		if err := t.addOrUpdateMenuItem(updateMenuID, 0, updateMenutTitle, false); err != nil {
-			return fmt.Errorf("unable to create menu entries %w", err)
-		}
-		if err := t.addSeparatorMenuItem(separatorMenuID, 0); err != nil {
-			return fmt.Errorf("unable to create menu entries %w", err)
-		}
-		iconFilePath, err := iconBytesToFilePath(wt.updateIcon)
-		if err != nil {
-			return fmt.Errorf("unable to write icon data to temp file: %w", err)
-		}
-		if err := wt.setIcon(iconFilePath); err != nil {
-			return fmt.Errorf("unable to set icon: %w", err)
-		}
-		t.updateNotified = true
-
-		t.pendingUpdate = true
-		// Now pop up the notification
-		t.muNID.Lock()
-		defer t.muNID.Unlock()
-		copy(t.nid.InfoTitle[:], windows.StringToUTF16(updateTitle))
-		copy(t.nid.Info[:], windows.StringToUTF16(fmt.Sprintf(updateMessage, ver)))
-		t.nid.Flags |= NIF_INFO
-		t.nid.Timeout = 10
-		t.nid.Size = uint32(unsafe.Sizeof(*wt.nid))
-		err = t.nid.modify()
-		if err != nil {
-			return err
-		}
-	}
-	return nil
-}
+//go:build windows
+
+package wintray
+
+import (
+	"fmt"
+	"log/slog"
+	"unsafe"
+
+	"golang.org/x/sys/windows"
+)
+
+const (
+	updatAvailableMenuID = 1
+	updateMenuID         = updatAvailableMenuID + 1
+	separatorMenuID      = updateMenuID + 1
+	diagLogsMenuID       = separatorMenuID + 1
+	diagSeparatorMenuID  = diagLogsMenuID + 1
+	quitMenuID           = diagSeparatorMenuID + 1
+)
+
+func (t *winTray) initMenus() error {
+	if err := t.addOrUpdateMenuItem(diagLogsMenuID, 0, diagLogsMenuTitle, false); err != nil {
+		return fmt.Errorf("unable to create menu entries %w\n", err)
+	}
+	if err := t.addSeparatorMenuItem(diagSeparatorMenuID, 0); err != nil {
+		return fmt.Errorf("unable to create menu entries %w", err)
+	}
+	if err := t.addOrUpdateMenuItem(quitMenuID, 0, quitMenuTitle, false); err != nil {
+		return fmt.Errorf("unable to create menu entries %w\n", err)
+	}
+	return nil
+}
+
+func (t *winTray) UpdateAvailable(ver string) error {
+	if !t.updateNotified {
+		slog.Debug("updating menu and sending notification for new update")
+		if err := t.addOrUpdateMenuItem(updatAvailableMenuID, 0, updateAvailableMenuTitle, true); err != nil {
+			return fmt.Errorf("unable to create menu entries %w", err)
+		}
+		if err := t.addOrUpdateMenuItem(updateMenuID, 0, updateMenutTitle, false); err != nil {
+			return fmt.Errorf("unable to create menu entries %w", err)
+		}
+		if err := t.addSeparatorMenuItem(separatorMenuID, 0); err != nil {
+			return fmt.Errorf("unable to create menu entries %w", err)
+		}
+		iconFilePath, err := iconBytesToFilePath(wt.updateIcon)
+		if err != nil {
+			return fmt.Errorf("unable to write icon data to temp file: %w", err)
+		}
+		if err := wt.setIcon(iconFilePath); err != nil {
+			return fmt.Errorf("unable to set icon: %w", err)
+		}
+		t.updateNotified = true
+
+		t.pendingUpdate = true
+		// Now pop up the notification
+		t.muNID.Lock()
+		defer t.muNID.Unlock()
+		copy(t.nid.InfoTitle[:], windows.StringToUTF16(updateTitle))
+		copy(t.nid.Info[:], windows.StringToUTF16(fmt.Sprintf(updateMessage, ver)))
+		t.nid.Flags |= NIF_INFO
+		t.nid.Timeout = 10
+		t.nid.Size = uint32(unsafe.Sizeof(*wt.nid))
+		err = t.nid.modify()
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}

+ 33 - 3
auth/auth.go

@@ -10,12 +10,44 @@ import (
 	"log/slog"
 	"os"
 	"path/filepath"
+	"strings"
 
 	"golang.org/x/crypto/ssh"
 )
 
 const defaultPrivateKey = "id_ed25519"
 
+func keyPath() (string, error) {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return "", err
+	}
+
+	return filepath.Join(home, ".ollama", defaultPrivateKey), nil
+}
+
+func GetPublicKey() (string, error) {
+	keyPath, err := keyPath()
+	if err != nil {
+		return "", err
+	}
+
+	privateKeyFile, err := os.ReadFile(keyPath)
+	if err != nil {
+		slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
+		return "", err
+	}
+
+	privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
+	if err != nil {
+		return "", err
+	}
+
+	publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
+
+	return strings.TrimSpace(string(publicKey)), nil
+}
+
 func NewNonce(r io.Reader, length int) (string, error) {
 	nonce := make([]byte, length)
 	if _, err := io.ReadFull(r, nonce); err != nil {
@@ -26,13 +58,11 @@ func NewNonce(r io.Reader, length int) (string, error) {
 }
 
 func Sign(ctx context.Context, bts []byte) (string, error) {
-	home, err := os.UserHomeDir()
+	keyPath, err := keyPath()
 	if err != nil {
 		return "", err
 	}
 
-	keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
-
 	privateKeyFile, err := os.ReadFile(keyPath)
 	if err != nil {
 		slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
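The new `auth.GetPublicKey` helper reads the same `~/.ollama/id_ed25519` key that `Sign` uses and returns the public half in authorized_keys form. A small usage sketch (the program layout is illustrative, not from this diff):

```go
package main

import (
	"fmt"
	"log"

	"github.com/ollama/ollama/auth"
)

func main() {
	pub, err := auth.GetPublicKey()
	if err != nil {
		log.Fatalf("could not load the ollama key: %v", err)
	}

	// Prints something like "ssh-ed25519 AAAA...", the value to add at https://ollama.com/settings/keys.
	fmt.Println(pub)
}
```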

+ 193 - 103
cmd/cmd.go

@@ -17,6 +17,7 @@ import (
 	"os"
 	"os/signal"
 	"path/filepath"
+	"regexp"
 	"runtime"
 	"strings"
 	"syscall"
@@ -31,10 +32,12 @@ import (
 	"golang.org/x/term"
 
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/auth"
 	"github.com/ollama/ollama/format"
-	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/progress"
 	"github.com/ollama/ollama/server"
+	"github.com/ollama/ollama/types/errtypes"
+	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
 )
 
@@ -53,14 +56,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	p := progress.NewProgress(os.Stderr)
 	defer p.Stop()
 
-	bars := make(map[string]*progress.Bar)
-
-	modelfile, err := os.ReadFile(filename)
+	f, err := os.Open(filename)
 	if err != nil {
 		return err
 	}
+	defer f.Close()
 
-	commands, err := parser.Parse(bytes.NewReader(modelfile))
+	modelfile, err := model.ParseFile(f)
 	if err != nil {
 		return err
 	}
@@ -74,10 +76,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	spinner := progress.NewSpinner(status)
 	p.Add(status, spinner)
 
-	for _, c := range commands {
-		switch c.Name {
+	for i := range modelfile.Commands {
+		switch modelfile.Commands[i].Name {
 		case "model", "adapter":
-			path := c.Args
+			path := modelfile.Commands[i].Args
 			if path == "~" {
 				path = home
 			} else if strings.HasPrefix(path, "~/") {
@@ -89,101 +91,22 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 			}
 
 			fi, err := os.Stat(path)
-			if errors.Is(err, os.ErrNotExist) && c.Name == "model" {
+			if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" {
 				continue
 			} else if err != nil {
 				return err
 			}
 
-			// TODO make this work w/ adapters
 			if fi.IsDir() {
-				tf, err := os.CreateTemp("", "ollama-tf")
-				if err != nil {
-					return err
-				}
-				defer os.RemoveAll(tf.Name())
-
-				zf := zip.NewWriter(tf)
-
-				files := []string{}
-
-				tfiles, err := filepath.Glob(filepath.Join(path, "pytorch_model-*.bin"))
+				// this is likely a safetensors or pytorch directory
+				// TODO make this work w/ adapters
+				tempfile, err := tempZipFiles(path)
 				if err != nil {
 					return err
-				} else if len(tfiles) == 0 {
-					tfiles, err = filepath.Glob(filepath.Join(path, "model-*.safetensors"))
-					if err != nil {
-						return err
-					}
-				}
-
-				files = append(files, tfiles...)
-
-				if len(files) == 0 {
-					return fmt.Errorf("no models were found in '%s'", path)
 				}
+				defer os.RemoveAll(tempfile)
 
-				// add the safetensor/torch config file + tokenizer
-				files = append(files, filepath.Join(path, "config.json"))
-				files = append(files, filepath.Join(path, "params.json"))
-				files = append(files, filepath.Join(path, "added_tokens.json"))
-				files = append(files, filepath.Join(path, "tokenizer.model"))
-
-				for _, fn := range files {
-					f, err := os.Open(fn)
-
-					// just skip whatever files aren't there
-					if os.IsNotExist(err) {
-						if strings.HasSuffix(fn, "tokenizer.model") {
-							// try the parent dir before giving up
-							parentDir := filepath.Dir(path)
-							newFn := filepath.Join(parentDir, "tokenizer.model")
-							f, err = os.Open(newFn)
-							if os.IsNotExist(err) {
-								continue
-							} else if err != nil {
-								return err
-							}
-						} else {
-							continue
-						}
-					} else if err != nil {
-						return err
-					}
-
-					fi, err := f.Stat()
-					if err != nil {
-						return err
-					}
-
-					h, err := zip.FileInfoHeader(fi)
-					if err != nil {
-						return err
-					}
-
-					h.Name = filepath.Base(fn)
-					h.Method = zip.Store
-
-					w, err := zf.CreateHeader(h)
-					if err != nil {
-						return err
-					}
-
-					_, err = io.Copy(w, f)
-					if err != nil {
-						return err
-					}
-
-				}
-
-				if err := zf.Close(); err != nil {
-					return err
-				}
-
-				if err := tf.Close(); err != nil {
-					return err
-				}
-				path = tf.Name()
+				path = tempfile
 			}
 
 			digest, err := createBlob(cmd, client, path)
@@ -191,10 +114,11 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 				return err
 			}
 
-			modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte("@"+digest))
+			modelfile.Commands[i].Args = "@" + digest
 		}
 	}
 
+	bars := make(map[string]*progress.Bar)
 	fn := func(resp api.ProgressResponse) error {
 		if resp.Digest != "" {
 			spinner.Stop()
@@ -220,7 +144,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 
 	quantization, _ := cmd.Flags().GetString("quantization")
 
-	request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization}
+	request := api.CreateRequest{Name: args[0], Modelfile: modelfile.String(), Quantization: quantization}
 	if err := client.Create(cmd.Context(), &request, fn); err != nil {
 		return err
 	}
@@ -228,6 +152,114 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	return nil
 }
 
+func tempZipFiles(path string) (string, error) {
+	tempfile, err := os.CreateTemp("", "ollama-tf")
+	if err != nil {
+		return "", err
+	}
+	defer tempfile.Close()
+
+	zipfile := zip.NewWriter(tempfile)
+	defer zipfile.Close()
+
+	detectContentType := func(path string) (string, error) {
+		f, err := os.Open(path)
+		if err != nil {
+			return "", err
+		}
+		defer f.Close()
+
+		var b bytes.Buffer
+		b.Grow(512)
+
+		if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
+			return "", err
+		}
+
+		contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
+		return contentType, nil
+	}
+
+	glob := func(pattern, contentType string) ([]string, error) {
+		matches, err := filepath.Glob(pattern)
+		if err != nil {
+			return nil, err
+		}
+
+		for _, safetensor := range matches {
+			if ct, err := detectContentType(safetensor); err != nil {
+				return nil, err
+			} else if ct != contentType {
+				return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor)
+			}
+		}
+
+		return matches, nil
+	}
+
+	var files []string
+	if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
+		// safetensors files might be unresolved git lfs references; skip if they are
+		// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
+		files = append(files, st...)
+	} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
+		// pytorch files might also be unresolved git lfs references; skip if they are
+		// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
+		files = append(files, pt...)
+	} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/octet-stream"); len(pt) > 0 {
+		// pytorch files might also be unresolved git lfs references; skip if they are
+		// covers consolidated.x.pth, consolidated.pth
+		files = append(files, pt...)
+	} else {
+		return "", errors.New("no safetensors or torch files found")
+	}
+
+	// add configuration files, json files are detected as text/plain
+	js, err := glob(filepath.Join(path, "*.json"), "text/plain")
+	if err != nil {
+		return "", err
+	}
+	files = append(files, js...)
+
+	if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
+		// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
+		// tokenizer.model might be a unresolved git lfs reference; error if it is
+		files = append(files, tks...)
+	} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
+		// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
+		files = append(files, tks...)
+	}
+
+	for _, file := range files {
+		f, err := os.Open(file)
+		if err != nil {
+			return "", err
+		}
+		defer f.Close()
+
+		fi, err := f.Stat()
+		if err != nil {
+			return "", err
+		}
+
+		zfi, err := zip.FileInfoHeader(fi)
+		if err != nil {
+			return "", err
+		}
+
+		zf, err := zipfile.CreateHeader(zfi)
+		if err != nil {
+			return "", err
+		}
+
+		if _, err := io.Copy(zf, f); err != nil {
+			return "", err
+		}
+	}
+
+	return tempfile.Name(), nil
+}
+
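The content-type checks in `tempZipFiles` above exist because unresolved git-lfs pointers are small text files rather than real weights. A rough illustration of the distinction using only `net/http` sniffing (the byte values are made up for the example):

```go
package main

import (
	"fmt"
	"net/http"
)

func main() {
	// An unresolved git-lfs pointer is plain text and sniffs as text/plain.
	lfsPointer := []byte("version https://git-lfs.github.com/spec/v1\noid sha256:deadbeef\nsize 1234\n")
	fmt.Println(http.DetectContentType(lfsPointer)) // text/plain; charset=utf-8

	// Resolved safetensors weights contain binary bytes and sniff as application/octet-stream.
	weights := []byte{0x40, 0x23, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, '{', '"'}
	fmt.Println(http.DetectContentType(weights)) // application/octet-stream
}
```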
 func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
 	bin, err := os.Open(path)
 	if err != nil {
@@ -322,6 +354,47 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	return generateInteractive(cmd, opts)
 }
 
+func errFromUnknownKey(unknownKeyErr error) error {
+	// find SSH public key in the error message
+	sshKeyPattern := `ssh-\w+ [^\s"]+`
+	re := regexp.MustCompile(sshKeyPattern)
+	matches := re.FindStringSubmatch(unknownKeyErr.Error())
+
+	if len(matches) > 0 {
+		serverPubKey := matches[0]
+
+		localPubKey, err := auth.GetPublicKey()
+		if err != nil {
+			return unknownKeyErr
+		}
+
+		if runtime.GOOS == "linux" && serverPubKey != localPubKey {
+			// try the ollama service public key
+			svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
+			if err != nil {
+				return unknownKeyErr
+			}
+			localPubKey = strings.TrimSpace(string(svcPubKey))
+		}
+
+		// check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
+		if serverPubKey != localPubKey {
+			return unknownKeyErr
+		}
+
+		var msg strings.Builder
+		msg.WriteString(unknownKeyErr.Error())
+		msg.WriteString("\n\nYour ollama key is:\n")
+		msg.WriteString(localPubKey)
+		msg.WriteString("\nAdd your key at:\n")
+		msg.WriteString("https://ollama.com/settings/keys")
+
+		return errors.New(msg.String())
+	}
+
+	return unknownKeyErr
+}
+
 func PushHandler(cmd *cobra.Command, args []string) error {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
@@ -369,6 +442,20 @@ func PushHandler(cmd *cobra.Command, args []string) error {
 
 	request := api.PushRequest{Name: args[0], Insecure: insecure}
 	if err := client.Push(cmd.Context(), &request, fn); err != nil {
+		if spinner != nil {
+			spinner.Stop()
+		}
+		if strings.Contains(err.Error(), "access denied") {
+			return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
+		}
+		host := model.ParseName(args[0]).Host
+		isOllamaHost := strings.HasSuffix(host, ".ollama.ai") || strings.HasSuffix(host, ".ollama.com")
+		if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
+			// the user has not added their ollama key to ollama.com
+			// re-throw an error with a more user-friendly message
+			return errFromUnknownKey(err)
+		}
+
 		return err
 	}
 
@@ -796,24 +883,27 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 }
 
 func RunServer(cmd *cobra.Command, _ []string) error {
-	host, port, err := net.SplitHostPort(strings.Trim(os.Getenv("OLLAMA_HOST"), "\"'"))
+	// retrieve the OLLAMA_HOST environment variable
+	ollamaHost, err := api.GetOllamaHost()
 	if err != nil {
-		host, port = "127.0.0.1", "11434"
-		if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil {
-			host = ip.String()
-		}
+		return err
 	}
 
 	if err := initializeKeypair(); err != nil {
 		return err
 	}
 
-	ln, err := net.Listen("tcp", net.JoinHostPort(host, port))
+	ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port))
 	if err != nil {
 		return err
 	}
 
-	return server.Serve(ln)
+	err = server.Serve(ln)
+	if errors.Is(err, http.ErrServerClosed) {
+		return nil
+	}
+
+	return err
 }
 
 func initializeKeypair() error {
@@ -1034,7 +1124,7 @@ Environment Variables:
 		RunE:    ListHandler,
 	}
 	copyCmd := &cobra.Command{
-		Use:     "cp SOURCE TARGET",
+		Use:     "cp SOURCE DESTINATION",
 		Short:   "Copy a model",
 		Args:    cobra.ExactArgs(2),
 		PreRunE: checkServerHeartbeat,

+ 6 - 1
cmd/interactive.go

@@ -94,6 +94,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 		fmt.Fprintln(os.Stderr, "  /show           Show model information")
 		fmt.Fprintln(os.Stderr, "  /load <model>   Load a session or model")
 		fmt.Fprintln(os.Stderr, "  /save <model>   Save your current session")
+		fmt.Fprintln(os.Stderr, "  /clear          Clear session context")
 		fmt.Fprintln(os.Stderr, "  /bye            Exit")
 		fmt.Fprintln(os.Stderr, "  /?, /help       Help for a command")
 		fmt.Fprintln(os.Stderr, "  /? shortcuts    Help for keyboard shortcuts")
@@ -161,7 +162,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 		fmt.Fprintln(os.Stderr, "  /set parameter repeat_penalty <float> How strongly to penalize repetitions")
 		fmt.Fprintln(os.Stderr, "  /set parameter repeat_last_n <int>    Set how far back to look for repetitions")
 		fmt.Fprintln(os.Stderr, "  /set parameter num_gpu <int>          The number of layers to send to the GPU")
-		fmt.Fprintln(os.Stderr, "  /set parameter stop \"<string>\", ...   Set the stop parameters")
+		fmt.Fprintln(os.Stderr, "  /set parameter stop <string> <string> ...   Set the stop parameters")
 		fmt.Fprintln(os.Stderr, "")
 	}
 
@@ -280,6 +281,10 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 			}
 			fmt.Printf("Created new model '%s'\n", args[1])
 			continue
+		case strings.HasPrefix(line, "/clear"):
+			opts.Messages = []api.Message{}
+			fmt.Println("Cleared session context")
+			continue
 		case strings.HasPrefix(line, "/set"):
 			args := strings.Fields(line)
 			if len(args) > 1 {

+ 19 - 14
convert/convert.go

@@ -5,6 +5,7 @@ import (
 	"encoding/binary"
 	"encoding/json"
 	"fmt"
+	"io"
 	"log/slog"
 	"os"
 	"path/filepath"
@@ -18,19 +19,23 @@ import (
 )
 
 type Params struct {
-	Architectures    []string `json:"architectures"`
-	VocabSize        int      `json:"vocab_size"`
-	HiddenSize       int      `json:"hidden_size"`       // n_embd
-	HiddenLayers     int      `json:"num_hidden_layers"` // n_layer
-	ContextSize      int      `json:"max_position_embeddings"`
-	IntermediateSize int      `json:"intermediate_size"`
-	AttentionHeads   int      `json:"num_attention_heads"` // n_head
-	KeyValHeads      int      `json:"num_key_value_heads"`
-	NormEPS          float64  `json:"rms_norm_eps"`
-	BoSTokenID       int      `json:"bos_token_id"`
-	EoSTokenID       int      `json:"eos_token_id"`
-	HeadDimension    int      `json:"head_dim"`
-	PaddingTokenID   int      `json:"pad_token_id"`
+	Architectures     []string `json:"architectures"`
+	VocabSize         int      `json:"vocab_size"`
+	HiddenSize        int      `json:"hidden_size"`       // n_embd
+	HiddenLayers      int      `json:"num_hidden_layers"` // n_layer
+	ContextSize       int      `json:"max_position_embeddings"`
+	IntermediateSize  int      `json:"intermediate_size"`
+	AttentionHeads    int      `json:"num_attention_heads"` // n_head
+	KeyValHeads       int      `json:"num_key_value_heads"`
+	NormEPS           float64  `json:"rms_norm_eps"`
+	BoSTokenID        int      `json:"bos_token_id"`
+	EoSTokenID        int      `json:"eos_token_id"`
+	HeadDimension     int      `json:"head_dim"`
+	PaddingTokenID    int      `json:"pad_token_id"`
+	RopeFrequencyBase float64  `json:"rope_theta"`
+
+	Experts     int `json:"num_local_experts"`
+	ExpertsUsed int `json:"num_experts_per_tok"`
 
 	ByteOrder
 }
@@ -43,7 +48,7 @@ type ByteOrder interface {
 type ModelArch interface {
 	GetTensors() error
 	LoadVocab() error
-	WriteGGUF() (string, error)
+	WriteGGUF(io.WriteSeeker) error
 }
 
 type ModelFormat interface {
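With the `ModelArch` change above, `WriteGGUF` no longer creates its own temp file; the caller supplies the `io.WriteSeeker`. A hedged sketch of a caller under that interface (`writeTempGGUF` is a hypothetical helper, not part of this diff):

```go
package example

import (
	"os"

	"github.com/ollama/ollama/convert"
)

// writeTempGGUF encodes any convert.ModelArch (llama, gemma, mistral, mixtral, ...)
// into a temporary GGUF file and returns its path.
func writeTempGGUF(arch convert.ModelArch) (string, error) {
	// The caller now owns the output file; WriteGGUF only encodes into the writer.
	f, err := os.CreateTemp("", "ollama-gguf")
	if err != nil {
		return "", err
	}
	defer f.Close()

	if err := arch.WriteGGUF(f); err != nil {
		return "", err
	}

	return f.Name(), nil
}
```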

+ 2 - 13
convert/gemma.go

@@ -94,7 +94,7 @@ func (m *GemmaModel) LoadVocab() error {
 	return nil
 }
 
-func (m *GemmaModel) WriteGGUF() (string, error) {
+func (m *GemmaModel) WriteGGUF(ws io.WriteSeeker) error {
 	kv := llm.KV{
 		"general.architecture":                   "gemma",
 		"general.name":                           m.Name,
@@ -122,16 +122,5 @@ func (m *GemmaModel) WriteGGUF() (string, error) {
 		"tokenizer.ggml.add_eos_token":    false,
 	}
 
-	f, err := os.CreateTemp("", "ollama-gguf")
-	if err != nil {
-		return "", err
-	}
-	defer f.Close()
-
-	mod := llm.NewGGUFV3(m.Params.ByteOrder)
-	if err := mod.Encode(f, kv, m.Tensors); err != nil {
-		return "", err
-	}
-
-	return f.Name(), nil
+	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }

+ 2 - 16
convert/llama.go

@@ -5,7 +5,6 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
-	"os"
 	"regexp"
 	"strings"
 
@@ -132,7 +131,7 @@ func (m *LlamaModel) LoadVocab() error {
 	return nil
 }
 
-func (m *LlamaModel) WriteGGUF() (string, error) {
+func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
 	kv := llm.KV{
 		"general.architecture":                   "llama",
 		"general.name":                           m.Name,
@@ -159,18 +158,5 @@ func (m *LlamaModel) WriteGGUF() (string, error) {
 		"tokenizer.ggml.add_eos_token":    false,
 	}
 
-	f, err := os.CreateTemp("", "ollama-gguf")
-	if err != nil {
-		return "", err
-	}
-	defer f.Close()
-
-	mod := llm.NewGGUFV3(m.Params.ByteOrder)
-	if err := mod.Encode(f, kv, m.Tensors); err != nil {
-		return "", err
-	}
-
-	slog.Debug(fmt.Sprintf("gguf file = %s", f.Name()))
-
-	return f.Name(), nil
+	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }

+ 2 - 13
convert/mistral.go

@@ -132,7 +132,7 @@ func (m *MistralModel) LoadVocab() error {
 	return nil
 }
 
-func (m *MistralModel) WriteGGUF() (string, error) {
+func (m *MistralModel) WriteGGUF(ws io.WriteSeeker) error {
 	kv := llm.KV{
 		"general.architecture":                   "llama",
 		"general.name":                           m.Name,
@@ -158,16 +158,5 @@ func (m *MistralModel) WriteGGUF() (string, error) {
 		"tokenizer.ggml.unknown_token_id": uint32(0),
 	}
 
-	f, err := os.CreateTemp("", "ollama-gguf")
-	if err != nil {
-		return "", err
-	}
-	defer f.Close()
-
-	mod := llm.NewGGUFV3(m.Params.ByteOrder)
-	if err := mod.Encode(f, kv, m.Tensors); err != nil {
-		return "", err
-	}
-
-	return f.Name(), nil
+	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }

+ 85 - 0
convert/mixtral.go

@@ -0,0 +1,85 @@
+package convert
+
+import (
+	"io"
+	"regexp"
+
+	"github.com/ollama/ollama/llm"
+)
+
+type MixtralModel struct {
+	ModelData
+}
+
+func (m *MixtralModel) GetTensors() error {
+	t, err := m.Format.GetTensors(m.Path, m.Params)
+	if err != nil {
+		return err
+	}
+
+	m.Tensors = []llm.Tensor{}
+
+	pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
+	re, err := regexp.Compile(pattern)
+	if err != nil {
+		return err
+	}
+
+	for _, l := range t {
+		matches := re.FindAllStringSubmatch(l.Name, -1)
+		if len(matches) > 0 {
+			wt := l.WriterTo.(safetensorWriterTo)
+			wt.handler = mistralLayerHandler
+			l.WriterTo = wt
+		}
+		m.Tensors = append(m.Tensors, l)
+	}
+
+	return nil
+}
+
+func (m *MixtralModel) LoadVocab() error {
+	v, err := LoadSentencePieceTokens(m.Path, m.Params)
+	if err != nil {
+		return err
+	}
+	m.Vocab = v
+	return nil
+}
+
+func (m *MixtralModel) WriteGGUF(ws io.WriteSeeker) error {
+	kv := llm.KV{
+		"general.architecture":          "llama",
+		"general.name":                  m.Name,
+		"llama.block_count":             uint32(m.Params.HiddenLayers),
+		"llama.context_length":          uint32(m.Params.ContextSize),
+		"llama.embedding_length":        uint32(m.Params.HiddenSize),
+		"llama.feed_forward_length":     uint32(m.Params.IntermediateSize),
+		"llama.attention.head_count":    uint32(m.Params.AttentionHeads),
+		"llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
+
+		"llama.rope.freq_base":                   float32(m.Params.RopeFrequencyBase),
+		"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
+
+		"llama.expert_count":      uint32(m.Params.Experts),
+		"llama.expert_used_count": uint32(m.Params.ExpertsUsed),
+
+		"llama.vocab_size":           uint32(len(m.Vocab.Tokens)),
+		"llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
+
+		"general.file_type":    uint32(1),
+		"tokenizer.ggml.model": "llama",
+
+		"tokenizer.ggml.tokens":     m.Vocab.Tokens,
+		"tokenizer.ggml.scores":     m.Vocab.Scores,
+		"tokenizer.ggml.token_type": m.Vocab.Types,
+
+		"tokenizer.ggml.bos_token_id":     uint32(m.Params.BoSTokenID),
+		"tokenizer.ggml.eos_token_id":     uint32(m.Params.EoSTokenID),
+		"tokenizer.ggml.unknown_token_id": uint32(0),
+		"tokenizer.ggml.add_bos_token":    true,
+		"tokenizer.ggml.add_eos_token":    false,
+	}
+
+	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
+}

+ 27 - 14
convert/safetensors.go

@@ -53,7 +53,7 @@ func (m *SafetensorFormat) GetTensors(dirpath string, params *Params) ([]llm.Ten
 		var err error
 		t, offset, err = m.readTensors(f, offset, params)
 		if err != nil {
-			slog.Error("%v", err)
+			slog.Error(err.Error())
 			return nil, err
 		}
 		tensors = append(tensors, t...)
@@ -93,7 +93,6 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
 	}
 
 	slices.Sort(keys)
-
 	slog.Info("converting layers")
 
 	var tensors []llm.Tensor
@@ -105,7 +104,6 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
 			return nil, 0, err
 		}
 
-		slog.Debug(fmt.Sprintf("metadata = %#v", data))
 		var size uint64
 		var kind uint32
 		switch len(data.Shape) {
@@ -124,7 +122,7 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
 
 		ggufName, err := m.GetLayerName(k)
 		if err != nil {
-			slog.Error("%v", err)
+			slog.Error(err.Error())
 			return nil, 0, err
 		}
 
@@ -150,11 +148,13 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
 			padding:  8 + jsonSize,
 		}
 
-		tensors = append(tensors, t)
 		offset += size
+		tensors = append(tensors, t)
 	}
+
 	slog.Debug(fmt.Sprintf("total tensors for file = %d", len(tensors)))
 	slog.Debug(fmt.Sprintf("offset = %d", offset))
+
 	return tensors, offset, nil
 }
 
@@ -185,15 +185,19 @@ func (m *SafetensorFormat) GetLayerName(n string) (string, error) {
 	}
 
 	tMap := map[string]string{
-		"model.layers.(\\d+).input_layernorm.weight":          "blk.$1.attn_norm.weight",
-		"model.layers.(\\d+).mlp.down_proj.weight":            "blk.$1.ffn_down.weight",
-		"model.layers.(\\d+).mlp.gate_proj.weight":            "blk.$1.ffn_gate.weight",
-		"model.layers.(\\d+).mlp.up_proj.weight":              "blk.$1.ffn_up.weight",
-		"model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight",
-		"model.layers.(\\d+).self_attn.k_proj.weight":         "blk.$1.attn_k.weight",
-		"model.layers.(\\d+).self_attn.o_proj.weight":         "blk.$1.attn_output.weight",
-		"model.layers.(\\d+).self_attn.q_proj.weight":         "blk.$1.attn_q.weight",
-		"model.layers.(\\d+).self_attn.v_proj.weight":         "blk.$1.attn_v.weight",
+		"model.layers.(\\d+).input_layernorm.weight":                    "blk.$1.attn_norm.weight",
+		"model.layers.(\\d+).mlp.down_proj.weight":                      "blk.$1.ffn_down.weight",
+		"model.layers.(\\d+).mlp.gate_proj.weight":                      "blk.$1.ffn_gate.weight",
+		"model.layers.(\\d+).mlp.up_proj.weight":                        "blk.$1.ffn_up.weight",
+		"model.layers.(\\d+).post_attention_layernorm.weight":           "blk.$1.ffn_norm.weight",
+		"model.layers.(\\d+).self_attn.k_proj.weight":                   "blk.$1.attn_k.weight",
+		"model.layers.(\\d+).self_attn.o_proj.weight":                   "blk.$1.attn_output.weight",
+		"model.layers.(\\d+).self_attn.q_proj.weight":                   "blk.$1.attn_q.weight",
+		"model.layers.(\\d+).self_attn.v_proj.weight":                   "blk.$1.attn_v.weight",
+		"model.layers.(\\d+).block_sparse_moe.gate.weight":              "blk.$1.ffn_gate_inp.weight",
+		"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w1.weight": "blk.$1.ffn_gate.$2.weight",
+		"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w2.weight": "blk.$1.ffn_down.$2.weight",
+		"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w3.weight": "blk.$1.ffn_up.$2.weight",
 	}
 
 	v, ok := directMap[n]
@@ -286,6 +290,15 @@ func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (M
 					Format: m,
 				},
 			}, nil
+		case "MixtralForCausalLM":
+			return &MixtralModel{
+				ModelData{
+					Name:   name,
+					Path:   dirPath,
+					Params: params,
+					Format: m,
+				},
+			}, nil
 		case "GemmaForCausalLM":
 			return &GemmaModel{
 				ModelData{

+ 1 - 1
convert/torch.go

@@ -74,7 +74,7 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor,
 
 			ggufName, err := tf.GetLayerName(k.(string))
 			if err != nil {
-				slog.Error("%v", err)
+				slog.Error(err.Error())
 				return nil, err
 			}
 			slog.Debug(fmt.Sprintf("finding name for '%s' -> '%s'", k.(string), ggufName))

+ 31 - 31
docs/api.md

@@ -17,7 +17,7 @@
 
 ### Model names
 
-Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q4_1` and `llama2:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
+Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q4_1` and `llama3:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
 
 ### Durations
 
@@ -66,7 +66,7 @@ Enable JSON mode by setting the `format` parameter to `json`. This will structur
 
 ```shell
 curl http://localhost:11434/api/generate -d '{
-  "model": "llama2",
+  "model": "llama3",
   "prompt": "Why is the sky blue?"
 }'
 ```
@@ -77,7 +77,7 @@ A stream of JSON objects is returned:
 
 ```json
 {
-  "model": "llama2",
+  "model": "llama3",
   "created_at": "2023-08-04T08:52:19.385406455-07:00",
   "response": "The",
   "done": false
@@ -90,16 +90,16 @@ The final response in the stream also includes additional data about the generat
 - `load_duration`: time spent in nanoseconds loading the model
 - `prompt_eval_count`: number of tokens in the prompt
 - `prompt_eval_duration`: time spent in nanoseconds evaluating the prompt
-- `eval_count`: number of tokens the response
+- `eval_count`: number of tokens in the response
 - `eval_duration`: time in nanoseconds spent generating the response
 - `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory
 - `response`: empty if the response was streamed, if not streamed, this will contain the full response
 
-To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration`.
+To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration` * `10^9`.
 
 ```json
 {
-  "model": "llama2",
+  "model": "llama3",
   "created_at": "2023-08-04T19:22:45.499127Z",
   "response": "",
   "done": true,
@@ -121,7 +121,7 @@ A response can be received in one reply when streaming is off.
 
 ```shell
 curl http://localhost:11434/api/generate -d '{
-  "model": "llama2",
+  "model": "llama3",
   "prompt": "Why is the sky blue?",
   "stream": false
 }'
@@ -133,7 +133,7 @@ If `stream` is set to `false`, the response will be a single JSON object:
 
 ```json
 {
-  "model": "llama2",
+  "model": "llama3",
   "created_at": "2023-08-04T19:22:45.499127Z",
   "response": "The sky is blue because it is the color of the sky.",
   "done": true,
@@ -155,7 +155,7 @@ If `stream` is set to `false`, the response will be a single JSON object:
 
 ```shell
 curl http://localhost:11434/api/generate -d '{
-  "model": "llama2",
+  "model": "llama3",
   "prompt": "What color is the sky at different times of the day? Respond using JSON",
   "format": "json",
   "stream": false
@@ -166,7 +166,7 @@ curl http://localhost:11434/api/generate -d '{
 
 ```json
 {
-  "model": "llama2",
+  "model": "llama3",
   "created_at": "2023-11-09T21:07:55.186497Z",
   "response": "{\n\"morning\": {\n\"color\": \"blue\"\n},\n\"noon\": {\n\"color\": \"blue-gray\"\n},\n\"afternoon\": {\n\"color\": \"warm gray\"\n},\n\"evening\": {\n\"color\": \"orange\"\n}\n}\n",
   "done": true,
@@ -289,7 +289,7 @@ If you want to set custom options for the model at runtime rather than in the Mo
 
 ```shell
 curl http://localhost:11434/api/generate -d '{
-  "model": "llama2",
+  "model": "llama3",
   "prompt": "Why is the sky blue?",
   "stream": false,
   "options": {
@@ -332,7 +332,7 @@ curl http://localhost:11434/api/generate -d '{
 
 ```json
 {
-  "model": "llama2",
+  "model": "llama3",
   "created_at": "2023-08-04T19:22:45.499127Z",
   "response": "The sky is blue because it is the color of the sky.",
   "done": true,
@@ -354,7 +354,7 @@ If an empty prompt is provided, the model will be loaded into memory.
 
 ```shell
 curl http://localhost:11434/api/generate -d '{
-  "model": "llama2"
+  "model": "llama3"
 }'
 ```
 
@@ -364,7 +364,7 @@ A single JSON object is returned:
 
 ```json
 {
-  "model": "llama2",
+  "model": "llama3",
   "created_at": "2023-12-18T19:52:07.071755Z",
   "response": "",
   "done": true
@@ -407,7 +407,7 @@ Send a chat message with a streaming response.
 
 ```shell
 curl http://localhost:11434/api/chat -d '{
-  "model": "llama2",
+  "model": "llama3",
   "messages": [
     {
       "role": "user",
@@ -423,7 +423,7 @@ A stream of JSON objects is returned:
 
 ```json
 {
-  "model": "llama2",
+  "model": "llama3",
   "created_at": "2023-08-04T08:52:19.385406455-07:00",
   "message": {
     "role": "assistant",
@@ -438,7 +438,7 @@ Final response:
 
 ```json
 {
-  "model": "llama2",
+  "model": "llama3",
   "created_at": "2023-08-04T19:22:45.499127Z",
   "done": true,
   "total_duration": 4883583458,
@@ -456,7 +456,7 @@ Final response:
 
 ```shell
 curl http://localhost:11434/api/chat -d '{
-  "model": "llama2",
+  "model": "llama3",
   "messages": [
     {
       "role": "user",
@@ -471,7 +471,7 @@ curl http://localhost:11434/api/chat -d '{
 
 ```json
 {
-  "model": "registry.ollama.ai/library/llama2:latest",
+  "model": "registry.ollama.ai/library/llama3:latest",
   "created_at": "2023-12-12T14:13:43.416799Z",
   "message": {
     "role": "assistant",
@@ -495,7 +495,7 @@ Send a chat message with a conversation history. You can use this same approach
 
 ```shell
 curl http://localhost:11434/api/chat -d '{
-  "model": "llama2",
+  "model": "llama3",
   "messages": [
     {
       "role": "user",
@@ -519,7 +519,7 @@ A stream of JSON objects is returned:
 
 ```json
 {
-  "model": "llama2",
+  "model": "llama3",
   "created_at": "2023-08-04T08:52:19.385406455-07:00",
   "message": {
     "role": "assistant",
@@ -533,7 +533,7 @@ Final response:
 
 ```json
 {
-  "model": "llama2",
+  "model": "llama3",
   "created_at": "2023-08-04T19:22:45.499127Z",
   "done": true,
   "total_duration": 8113331500,
@@ -591,7 +591,7 @@ curl http://localhost:11434/api/chat -d '{
 
 ```shell
 curl http://localhost:11434/api/chat -d '{
-  "model": "llama2",
+  "model": "llama3",
   "messages": [
     {
       "role": "user",
@@ -609,7 +609,7 @@ curl http://localhost:11434/api/chat -d '{
 
 ```json
 {
-  "model": "registry.ollama.ai/library/llama2:latest",
+  "model": "registry.ollama.ai/library/llama3:latest",
   "created_at": "2023-12-12T14:13:43.416799Z",
   "message": {
     "role": "assistant",
@@ -651,7 +651,7 @@ Create a new model from a `Modelfile`.
 ```shell
 curl http://localhost:11434/api/create -d '{
   "name": "mario",
-  "modelfile": "FROM llama2\nSYSTEM You are mario from Super Mario Bros."
+  "modelfile": "FROM llama3\nSYSTEM You are mario from Super Mario Bros."
 }'
 ```
 
@@ -758,7 +758,7 @@ A single JSON object will be returned.
       }
     },
     {
-      "name": "llama2:latest",
+      "name": "llama3:latest",
       "modified_at": "2023-12-07T09:32:18.757212583-08:00",
       "size": 3825819519,
       "digest": "fe938a131f40e6f6d40083c9f0f430a515233eb2edaa6d72eb85c50d64f2300e",
@@ -792,7 +792,7 @@ Show information about a model including details, modelfile, template, parameter
 
 
 ```shell
 ```shell
 curl http://localhost:11434/api/show -d '{
 curl http://localhost:11434/api/show -d '{
-  "name": "llama2"
+  "name": "llama3"
 }'
 }'
 ```
 ```
 
 
@@ -827,8 +827,8 @@ Copy a model. Creates a model with another name from an existing model.
 
 
 ```shell
 ```shell
 curl http://localhost:11434/api/copy -d '{
 curl http://localhost:11434/api/copy -d '{
-  "source": "llama2",
-  "destination": "llama2-backup"
+  "source": "llama3",
+  "destination": "llama3-backup"
 }'
 }'
 ```
 ```
 
 
@@ -854,7 +854,7 @@ Delete a model and its data.
 
 
 ```shell
 ```shell
 curl -X DELETE http://localhost:11434/api/delete -d '{
 curl -X DELETE http://localhost:11434/api/delete -d '{
-  "name": "llama2:13b"
+  "name": "llama3:13b"
 }'
 }'
 ```
 ```
 
 
@@ -882,7 +882,7 @@ Download a model from the ollama library. Cancelled pulls are resumed from where
 
 
 ```shell
 ```shell
 curl http://localhost:11434/api/pull -d '{
 curl http://localhost:11434/api/pull -d '{
-  "name": "llama2"
+  "name": "llama3"
 }'
 }'
 ```
 ```
 
 

+ 2 - 2
docs/development.md

@@ -51,7 +51,7 @@ Typically the build scripts will auto-detect CUDA, however, if your Linux distro
 or installation approach uses unusual paths, you can specify the location by
 or installation approach uses unusual paths, you can specify the location by
 specifying an environment variable `CUDA_LIB_DIR` to the location of the shared
 specifying an environment variable `CUDA_LIB_DIR` to the location of the shared
 libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize
 libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize
-set set of target CUDA architectues by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70")
+a set of target CUDA architectures by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70")
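
For example, a minimal sketch of pointing the build at a non-standard CUDA install before generating dependencies (the paths below are illustrative placeholders, not canonical locations):

```shell
# Illustrative paths - substitute your system's actual CUDA install
export CUDA_LIB_DIR=/usr/local/cuda-12/lib64
export CUDACXX=/usr/local/cuda-12/bin/nvcc
export CMAKE_CUDA_ARCHITECTURES="50;60;70"
```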
 
 
 Then generate dependencies:
 Then generate dependencies:
 
 
@@ -142,4 +142,4 @@ In addition to the common Windows development tools described above, install AMD
 - [AMD HIP](https://www.amd.com/en/developer/resources/rocm-hub/hip-sdk.html)
 - [AMD HIP](https://www.amd.com/en/developer/resources/rocm-hub/hip-sdk.html)
 - [Strawberry Perl](https://strawberryperl.com/)
 - [Strawberry Perl](https://strawberryperl.com/)
 
 
-Lastly, add `ninja.exe` included with MSVC to the system path (e.g. `C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\Common7\IDE\CommonExtensions\Microsoft\CMake\Ninja`).
+Lastly, add `ninja.exe` included with MSVC to the system path (e.g. `C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\Common7\IDE\CommonExtensions\Microsoft\CMake\Ninja`).

+ 16 - 6
docs/faq.md

@@ -32,7 +32,7 @@ When using the API, specify the `num_ctx` parameter:
 
 
 ```
 ```
 curl http://localhost:11434/api/generate -d '{
 curl http://localhost:11434/api/generate -d '{
-  "model": "llama2",
+  "model": "llama3",
   "prompt": "Why is the sky blue?",
   "prompt": "Why is the sky blue?",
   "options": {
   "options": {
     "num_ctx": 4096
     "num_ctx": 4096
@@ -88,9 +88,9 @@ On windows, Ollama inherits your user and system environment variables.
 
 
 3. Edit or create New variable(s) for your user account for `OLLAMA_HOST`, `OLLAMA_MODELS`, etc.
 3. Edit or create New variable(s) for your user account for `OLLAMA_HOST`, `OLLAMA_MODELS`, etc.
 
 
-4. Click OK/Apply to save 
+4. Click OK/Apply to save
 
 
-5. Run `ollama` from a new terminal window 
+5. Run `ollama` from a new terminal window
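
For example (a sketch, not part of the official steps above), the same user variables can also be set from a terminal with the built-in `setx` command; the values shown are placeholders:

```shell
setx OLLAMA_HOST "0.0.0.0:11434"
setx OLLAMA_MODELS "D:\ollama\models"
```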
 
 
 
 
 ## How can I expose Ollama on my network?
 ## How can I expose Ollama on my network?
@@ -140,7 +140,7 @@ Refer to the section [above](#how-do-i-configure-ollama-server) for how to set e
 
 
 - macOS: `~/.ollama/models`
 - macOS: `~/.ollama/models`
 - Linux: `/usr/share/ollama/.ollama/models`
 - Linux: `/usr/share/ollama/.ollama/models`
-- Windows: `C:\Users\<username>\.ollama\models`
+- Windows: `C:\Users\%username%\.ollama\models`
 
 
 ### How do I set them to a different location?
 ### How do I set them to a different location?
 
 
@@ -221,10 +221,20 @@ The `keep_alive` parameter can be set to:
 
 
 For example, to preload a model and leave it in memory use:
 For example, to preload a model and leave it in memory use:
 ```shell
 ```shell
-curl http://localhost:11434/api/generate -d '{"model": "llama2", "keep_alive": -1}'
+curl http://localhost:11434/api/generate -d '{"model": "llama3", "keep_alive": -1}'
 ```
 ```
 
 
 To unload the model and free up memory use:
 To unload the model and free up memory use:
 ```shell
 ```shell
-curl http://localhost:11434/api/generate -d '{"model": "llama2", "keep_alive": 0}'
+curl http://localhost:11434/api/generate -d '{"model": "llama3", "keep_alive": 0}'
 ```
 ```
+
+Alternatively, you can change how long all models stay loaded in memory by setting the `OLLAMA_KEEP_ALIVE` environment variable when starting the Ollama server. The `OLLAMA_KEEP_ALIVE` variable accepts the same values as the `keep_alive` parameter described above. Refer to the section explaining [how to configure the Ollama server](#how-do-i-configure-ollama-server) to correctly set the environment variable.
+
+If you wish to override the `OLLAMA_KEEP_ALIVE` setting, use the `keep_alive` API parameter with the `/api/generate` or `/api/chat` API endpoints.
+
+## How do I manage the maximum number of requests the server can queue?
+
+If too many requests are sent to the server, it will respond with a 503 error
+indicating the server is overloaded. You can adjust how many requests may be
+queued by setting `OLLAMA_MAX_QUEUE`.
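
As a rough sketch of how these settings can be combined (the values below are illustrative, not recommended defaults):

```shell
# Keep models loaded for one hour and queue at most 128 pending requests
OLLAMA_KEEP_ALIVE=1h OLLAMA_MAX_QUEUE=128 ollama serve

# A per-request keep_alive still takes precedence over the server-wide setting
curl http://localhost:11434/api/generate -d '{"model": "llama3", "keep_alive": "10m"}'
```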

+ 3 - 1
docs/import.md

@@ -125,7 +125,7 @@ Publishing models is in early alpha. If you'd like to publish your model to shar
 
 
 1. Create [an account](https://ollama.com/signup)
 1. Create [an account](https://ollama.com/signup)
 2. Copy your Ollama public key:
 2. Copy your Ollama public key:
-  - macOS: `cat ~/.ollama/id_ed25519.pub`
+  - macOS: `cat ~/.ollama/id_ed25519.pub | pbcopy`
   - Windows: `type %USERPROFILE%\.ollama\id_ed25519.pub`
   - Windows: `type %USERPROFILE%\.ollama\id_ed25519.pub`
   - Linux: `cat /usr/share/ollama/.ollama/id_ed25519.pub`
   - Linux: `cat /usr/share/ollama/.ollama/id_ed25519.pub`
 3. Add your public key to your [Ollama account](https://ollama.com/settings/keys)
 3. Add your public key to your [Ollama account](https://ollama.com/settings/keys)
@@ -136,6 +136,8 @@ Next, copy your model to your username's namespace:
 ollama cp example <your username>/example
 ollama cp example <your username>/example
 ```
 ```
 
 
+> Note: model names may only contain lowercase letters, digits, and the characters `.`, `-`, and `_`.
+
 Then push the model:
 Then push the model:
 
 
 ```
 ```

+ 1 - 1
docs/linux.md

@@ -105,7 +105,7 @@ sudo chmod +x /usr/bin/ollama
 To view logs of Ollama running as a startup service, run:
 To view logs of Ollama running as a startup service, run:
 
 
 ```bash
 ```bash
-journalctl -u ollama
+journalctl -e -u ollama
 ```
 ```
 
 
 ## Uninstall
 ## Uninstall

+ 17 - 25
docs/modelfile.md

@@ -10,7 +10,7 @@ A model file is the blueprint to create and share models with Ollama.
 - [Examples](#examples)
 - [Examples](#examples)
 - [Instructions](#instructions)
 - [Instructions](#instructions)
   - [FROM (Required)](#from-required)
   - [FROM (Required)](#from-required)
-    - [Build from llama2](#build-from-llama2)
+    - [Build from llama3](#build-from-llama3)
     - [Build from a bin file](#build-from-a-bin-file)
     - [Build from a bin file](#build-from-a-bin-file)
   - [PARAMETER](#parameter)
   - [PARAMETER](#parameter)
     - [Valid Parameters and Values](#valid-parameters-and-values)
     - [Valid Parameters and Values](#valid-parameters-and-values)
@@ -48,7 +48,7 @@ INSTRUCTION arguments
 An example of a `Modelfile` creating a mario blueprint:
 An example of a `Modelfile` creating a mario blueprint:
 
 
 ```modelfile
 ```modelfile
-FROM llama2
+FROM llama3
 # sets the temperature to 1 [higher is more creative, lower is more coherent]
 # sets the temperature to 1 [higher is more creative, lower is more coherent]
 PARAMETER temperature 1
 PARAMETER temperature 1
 # sets the context window size to 4096, this controls how many tokens the LLM can use as context to generate the next token
 # sets the context window size to 4096, this controls how many tokens the LLM can use as context to generate the next token
@@ -67,33 +67,25 @@ To use this:
 
 
 More examples are available in the [examples directory](../examples).
 More examples are available in the [examples directory](../examples).
 
 
-### `Modelfile`s in [ollama.com/library][1]
-
-There are two ways to view `Modelfile`s underlying the models in [ollama.com/library][1]:
-
-- Option 1: view a details page from a model's tags page:
-  1.  Go to a particular model's tags (e.g. https://ollama.com/library/llama2/tags)
-  2.  Click on a tag (e.g. https://ollama.com/library/llama2:13b)
-  3.  Scroll down to "Layers"
-      - Note: if the [`FROM` instruction](#from-required) is not present,
-        it means the model was created from a local file
-- Option 2: use `ollama show` to print the `Modelfile` for any local models like so:
+To view the Modelfile of a given model, use the `ollama show --modelfile` command.
 
 
   ```bash
   ```bash
-  > ollama show --modelfile llama2:13b
+  > ollama show --modelfile llama3
   # Modelfile generated by "ollama show"
   # Modelfile generated by "ollama show"
   # To build a new Modelfile based on this one, replace the FROM line with:
   # To build a new Modelfile based on this one, replace the FROM line with:
-  # FROM llama2:13b
+  # FROM llama3:latest
+  FROM /Users/pdevine/.ollama/models/blobs/sha256-00e1317cbf74d901080d7100f57580ba8dd8de57203072dc6f668324ba545f29
+  TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
+
+  {{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
 
 
-  FROM /root/.ollama/models/blobs/sha256:123abc
-  TEMPLATE """[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>>
+  {{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
 
 
-  {{ end }}{{ .Prompt }} [/INST] """
-  SYSTEM """"""
-  PARAMETER stop [INST]
-  PARAMETER stop [/INST]
-  PARAMETER stop <<SYS>>
-  PARAMETER stop <</SYS>>
+  {{ .Response }}<|eot_id|>"""
+  PARAMETER stop "<|start_header_id|>"
+  PARAMETER stop "<|end_header_id|>"
+  PARAMETER stop "<|eot_id|>"
+  PARAMETER stop "<|reserved_special_token"
   ```
   ```
 
 
 ## Instructions
 ## Instructions
@@ -106,10 +98,10 @@ The `FROM` instruction defines the base model to use when creating a model.
 FROM <model name>:<tag>
 FROM <model name>:<tag>
 ```
 ```
 
 
-#### Build from llama2
+#### Build from llama3
 
 
 ```modelfile
 ```modelfile
-FROM llama2
+FROM llama3
 ```
 ```
 
 
 A list of available base models:
 A list of available base models:

+ 5 - 5
docs/openai.md

@@ -25,7 +25,7 @@ chat_completion = client.chat.completions.create(
             'content': 'Say this is a test',
             'content': 'Say this is a test',
         }
         }
     ],
     ],
-    model='llama2',
+    model='llama3',
 )
 )
 ```
 ```
 
 
@@ -43,7 +43,7 @@ const openai = new OpenAI({
 
 
 const chatCompletion = await openai.chat.completions.create({
 const chatCompletion = await openai.chat.completions.create({
   messages: [{ role: 'user', content: 'Say this is a test' }],
   messages: [{ role: 'user', content: 'Say this is a test' }],
-  model: 'llama2',
+  model: 'llama3',
 })
 })
 ```
 ```
 
 
@@ -53,7 +53,7 @@ const chatCompletion = await openai.chat.completions.create({
 curl http://localhost:11434/v1/chat/completions \
 curl http://localhost:11434/v1/chat/completions \
     -H "Content-Type: application/json" \
     -H "Content-Type: application/json" \
     -d '{
     -d '{
-        "model": "llama2",
+        "model": "llama3",
         "messages": [
         "messages": [
             {
             {
                 "role": "system",
                 "role": "system",
@@ -113,7 +113,7 @@ curl http://localhost:11434/v1/chat/completions \
 Before using a model, pull it locally `ollama pull`:
 Before using a model, pull it locally `ollama pull`:
 
 
 ```shell
 ```shell
-ollama pull llama2
+ollama pull llama3
 ```
 ```
 
 
 ### Default model names
 ### Default model names
@@ -121,7 +121,7 @@ ollama pull llama2
 For tooling that relies on default OpenAI model names such as `gpt-3.5-turbo`, use `ollama cp` to copy an existing model name to a temporary name:
 For tooling that relies on default OpenAI model names such as `gpt-3.5-turbo`, use `ollama cp` to copy an existing model name to a temporary name:
 
 
 ```
 ```
-ollama cp llama2 gpt-3.5-turbo
+ollama cp llama3 gpt-3.5-turbo
 ```
 ```
 
 
 Afterwards, this new model name can be specified the `model` field:
 Afterwards, this new model name can be specified the `model` field:

+ 3 - 3
docs/tutorials/langchainjs.md

@@ -15,7 +15,7 @@ import { Ollama } from "langchain/llms/ollama";
 
 
 const ollama = new Ollama({
 const ollama = new Ollama({
   baseUrl: "http://localhost:11434",
   baseUrl: "http://localhost:11434",
-  model: "llama2",
+  model: "llama3",
 });
 });
 
 
 const answer = await ollama.invoke(`why is the sky blue?`);
 const answer = await ollama.invoke(`why is the sky blue?`);
@@ -23,10 +23,10 @@ const answer = await ollama.invoke(`why is the sky blue?`);
 console.log(answer);
 console.log(answer);
 ```
 ```
 
 
-That will get us the same thing as if we ran `ollama run llama2 "why is the sky blue"` in the terminal. But we want to load a document from the web to ask a question against. **Cheerio** is a great library for ingesting a webpage, and **LangChain** uses it in their **CheerioWebBaseLoader**. So let's install **Cheerio** and build that part of the app.
+That will get us the same thing as if we ran `ollama run llama3 "why is the sky blue"` in the terminal. But we want to load a document from the web to ask a question against. **Cheerio** is a great library for ingesting a webpage, and **LangChain** uses it in their **CheerioWebBaseLoader**. So let's install **Cheerio** and build that part of the app.
 
 
 ```bash
 ```bash
-npm install cheerio 
+npm install cheerio
 ```
 ```
 
 
 ```javascript
 ```javascript

+ 7 - 5
docs/tutorials/langchainpy.md

@@ -12,15 +12,17 @@ So let's figure out how we can use **LangChain** with Ollama to ask our question
 
 
 Let's start by asking a simple question that we can get an answer to from the **Llama2** model using **Ollama**. First, we need to install the **LangChain** package:
 Let's start by asking a simple question that we can get an answer to from the **Llama2** model using **Ollama**. First, we need to install the **LangChain** package:
 
 
-`pip install langchain`
+`pip install langchain_community`
 
 
 Then we can create a model and ask the question:
 Then we can create a model and ask the question:
 
 
 ```python
 ```python
-from langchain.llms import Ollama
-ollama = Ollama(base_url='http://localhost:11434',
-model="llama2")
-print(ollama("why is the sky blue"))
+from langchain_community.llms import Ollama
+ollama = Ollama(
+    base_url='http://localhost:11434',
+    model="llama3"
+)
+print(ollama.invoke("why is the sky blue"))
 ```
 ```
 
 
 Notice that we are defining the model and the base URL for Ollama.
 Notice that we are defining the model and the base URL for Ollama.

+ 6 - 29
docs/tutorials/nvidia-jetson.md

@@ -1,38 +1,15 @@
 # Running Ollama on NVIDIA Jetson Devices
 # Running Ollama on NVIDIA Jetson Devices
 
 
-With some minor configuration, Ollama runs well on [NVIDIA Jetson Devices](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/). The following has been tested on [JetPack 5.1.2](https://developer.nvidia.com/embedded/jetpack).
+Ollama runs well on [NVIDIA Jetson Devices](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/) and should run out of the box with the standard installation instructions. 
 
 
-NVIDIA Jetson devices are Linux-based embedded AI computers that are purpose-built for AI applications.
-
-Jetsons have an integrated GPU that is wired directly to the memory controller of the machine. For this reason, the `nvidia-smi` command is unrecognized, and Ollama proceeds to operate in "CPU only"
-mode. This can be verified by using a monitoring tool like jtop.
-
-In order to address this, we simply pass the path to the Jetson's pre-installed CUDA libraries into `ollama serve` (while in a tmux session). We then hardcode the num_gpu parameters into a cloned
-version of our target model.
-
-Prerequisites:
-
-- curl
-- tmux
-
-Here are the steps:
+The following has been tested on [JetPack 5.1.2](https://developer.nvidia.com/embedded/jetpack), but should also work on JetPack 6.0.
 
 
 - Install Ollama via standard Linux command (ignore the 404 error): `curl https://ollama.com/install.sh | sh`
 - Install Ollama via standard Linux command (ignore the 404 error): `curl https://ollama.com/install.sh | sh`
-- Stop the Ollama service: `sudo systemctl stop ollama`
-- Start Ollama serve in a tmux session called ollama_jetson and reference the CUDA libraries path: `tmux has-session -t ollama_jetson 2>/dev/null || tmux new-session -d -s ollama_jetson 
-'LD_LIBRARY_PATH=/usr/local/cuda/lib64 ollama serve'`
 - Pull the model you want to use (e.g. mistral): `ollama pull mistral`
 - Pull the model you want to use (e.g. mistral): `ollama pull mistral`
-- Create a new Modelfile specifically for enabling GPU support on the Jetson: `touch ModelfileMistralJetson`
-- In the ModelfileMistralJetson file, specify the FROM model and the num_gpu PARAMETER as shown below:
-
-```
-FROM mistral
-PARAMETER num_gpu 999
-```
+- Start an interactive session: `ollama run mistral`
 
 
-- Create a new model from your Modelfile: `ollama create mistral-jetson -f ./ModelfileMistralJetson`
-- Run the new model: `ollama run mistral-jetson`
+And that's it!
 
 
-If you run a monitoring tool like jtop you should now see that Ollama is using the Jetson's integrated GPU.
+# Running Ollama in Docker
 
 
-And that's it!
+When running GPU-accelerated applications in Docker, it is highly recommended to use the [dusty-nv jetson-containers repo](https://github.com/dusty-nv/jetson-containers).
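
A rough sketch of such a container launch, assuming the NVIDIA container runtime is already configured on the Jetson as described in the jetson-containers documentation:

```shell
# Sketch only: exposes the API on port 11434 and persists models in a named volume
docker run -d --runtime nvidia -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama
```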

+ 61 - 47
docs/windows.md

@@ -1,47 +1,61 @@
-# Ollama Windows Preview
-
-Welcome to the Ollama Windows preview.
-
-No more WSL required!
-
-Ollama now runs as a native Windows application, including NVIDIA and AMD Radeon GPU support.
-After installing Ollama Windows Preview, Ollama will run in the background and
-the `ollama` command line is available in `cmd`, `powershell` or your favorite
-terminal application. As usual the Ollama [api](./api.md) will be served on
-`http://localhost:11434`.
-
-As this is a preview release, you should expect a few bugs here and there.  If
-you run into a problem you can reach out on
-[Discord](https://discord.gg/ollama), or file an 
-[issue](https://github.com/ollama/ollama/issues).
-Logs will often be helpful in dianosing the problem (see
-[Troubleshooting](#troubleshooting) below)
-
-## System Requirements
-
-* Windows 10 or newer, Home or Pro
-* NVIDIA 452.39 or newer Drivers if you have an NVIDIA card
-* AMD Radeon Driver https://www.amd.com/en/support if you have a Radeon card
-
-## API Access
-
-Here's a quick example showing API access from `powershell`
-```powershell
-(Invoke-WebRequest -method POST -Body '{"model":"llama2", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json
-```
-
-## Troubleshooting
-
-While we're in preview, `OLLAMA_DEBUG` is always enabled, which adds
-a "view logs" menu item to the app, and increses logging for the GUI app and
-server.
-
-Ollama on Windows stores files in a few different locations.  You can view them in
-the explorer window by hitting `<cmd>+R` and type in:
-- `explorer %LOCALAPPDATA%\Ollama` contains logs, and downloaded updates
-    - *app.log* contains logs from the GUI application
-    - *server.log* contains the server logs
-    - *upgrade.log* contains log output for upgrades
-- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
-- `explorer %HOMEPATH%\.ollama` contains models and configuration
-- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories
+# Ollama Windows Preview
+
+Welcome to the Ollama Windows preview.
+
+No more WSL required!
+
+Ollama now runs as a native Windows application, including NVIDIA and AMD Radeon GPU support.
+After installing Ollama Windows Preview, Ollama will run in the background and
+the `ollama` command line is available in `cmd`, `powershell` or your favorite
+terminal application. As usual the Ollama [api](./api.md) will be served on
+`http://localhost:11434`.
+
+As this is a preview release, you should expect a few bugs here and there.  If
+you run into a problem you can reach out on
+[Discord](https://discord.gg/ollama), or file an
+[issue](https://github.com/ollama/ollama/issues).
+Logs will often be helpful in diagnosing the problem (see
+[Troubleshooting](#troubleshooting) below)
+
+## System Requirements
+
+* Windows 10 or newer, Home or Pro
+* NVIDIA 452.39 or newer Drivers if you have an NVIDIA card
+* AMD Radeon Driver https://www.amd.com/en/support if you have a Radeon card
+
+## API Access
+
+Here's a quick example showing API access from `powershell`
+```powershell
+(Invoke-WebRequest -method POST -Body '{"model":"llama3", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json
+```
+
+## Troubleshooting
+
+While we're in preview, `OLLAMA_DEBUG` is always enabled, which adds
+a "view logs" menu item to the app, and increases logging for the GUI app and
+server.
+
+Ollama on Windows stores files in a few different locations.  You can view them in
+the explorer window by hitting `<cmd>+R` and typing in:
+- `explorer %LOCALAPPDATA%\Ollama` contains logs, and downloaded updates
+    - *app.log* contains logs from the GUI application
+    - *server.log* contains the server logs
+    - *upgrade.log* contains log output for upgrades
+- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
+- `explorer %HOMEPATH%\.ollama` contains models and configuration
+- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories
+
+
+## Standalone CLI
+
+The easiest way to install Ollama on Windows is to use the `OllamaSetup.exe`
+installer. It installs in your account without requiring Administrator rights.
+We update Ollama regularly to support the latest models, and this installer will
+help you keep up to date.
+
+If you'd like to install or integrate Ollama as a service, a standalone
+`ollama-windows-amd64.zip` file is available containing only the Ollama CLI
+and GPU library dependencies for Nvidia and AMD. This allows for embedding
+Ollama in existing applications, or running it as a system service via `ollama
+serve` with tools such as [NSSM](https://nssm.cc/).
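
As a hypothetical sketch (the service name and install path are assumptions, and NSSM must be obtained separately), registering the standalone CLI as a service could look like:

```shell
# Run from an elevated prompt; C:\ollama is wherever the zip was extracted
nssm install Ollama C:\ollama\ollama.exe serve
nssm start Ollama
```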

+ 1 - 1
examples/bash-comparemodels/README.md

@@ -2,7 +2,7 @@
 
 
 When calling `ollama`, you can pass it a file to run all the prompts in the file, one after the other:
 When calling `ollama`, you can pass it a file to run all the prompts in the file, one after the other:
 
 
-`ollama run llama2 < sourcequestions.txt`
+`ollama run llama3 < sourcequestions.txt`
 
 
 This concept is used in the following example.
 This concept is used in the following example.
 
 

+ 1 - 0
examples/flyio/.gitignore

@@ -0,0 +1 @@
+fly.toml

+ 67 - 0
examples/flyio/README.md

@@ -0,0 +1,67 @@
+# Deploy Ollama to Fly.io
+
+> Note: this example exposes a public endpoint and does not configure authentication. Use with care.
+
+## Prerequisites
+
+- Ollama: https://ollama.com/download
+- Fly.io account. Sign up for a free account: https://fly.io/app/sign-up
+
+## Steps
+
+1. Login to Fly.io
+
+    ```bash
+    fly auth login
+    ```
+
+1. Create a new Fly app
+
+    ```bash
+    fly launch --name <name> --image ollama/ollama --internal-port 11434 --vm-size shared-cpu-8x --now
+    ```
+
+1. Pull and run `orca-mini:3b`
+
+    ```bash
+    OLLAMA_HOST=https://<name>.fly.dev ollama run orca-mini:3b
+    ```
+
+`shared-cpu-8x` is a free-tier eligible machine type. For better performance, switch to a `performance` or `dedicated` machine type or attach a GPU for hardware acceleration (see below).
+
+## (Optional) Persistent Volume
+
+By default, Fly Machines use ephemeral storage, which is problematic if you want to use the same model across restarts without pulling it again. Create and attach a persistent volume to store the downloaded models:
+
+1. Create the Fly Volume
+
+    ```bash
+    fly volume create ollama
+    ```
+
+1. Update `fly.toml` and add `[mounts]`
+
+    ```toml
+    [mounts]
+      source = "ollama"
+      destination = "/mnt/ollama/models"
+    ```
+
+1. Update `fly.toml` and add `[env]`
+
+    ```toml
+    [env]
+      OLLAMA_MODELS = "/mnt/ollama/models"
+    ```
+
+1. Deploy your app
+
+    ```bash
+    fly deploy
+    ```
+
+## (Optional) Hardware Acceleration
+
+Fly.io GPU access is currently waitlist-only. Sign up for the waitlist: https://fly.io/gpu
+
+Once you've been accepted, create the app with the additional flags `--vm-gpu-kind a100-pcie-40gb` or `--vm-gpu-kind a100-pcie-80gb`.

+ 1 - 1
examples/go-chat/main.go

@@ -35,7 +35,7 @@ func main() {
 
 
 	ctx := context.Background()
 	ctx := context.Background()
 	req := &api.ChatRequest{
 	req := &api.ChatRequest{
-		Model:    "llama2",
+		Model:    "llama3",
 		Messages: messages,
 		Messages: messages,
 	}
 	}
 
 

+ 1 - 1
examples/go-http-generate/main.go

@@ -19,7 +19,7 @@ func main() {
 	}
 	}
 
 
 	defer resp.Body.Close()
 	defer resp.Body.Close()
-	
+
 	responseData, err := io.ReadAll(resp.Body)
 	responseData, err := io.ReadAll(resp.Body)
 	if err != nil {
 	if err != nil {
 		log.Fatal(err)
 		log.Fatal(err)

+ 14 - 12
examples/kubernetes/README.md

@@ -7,12 +7,24 @@
 
 
 ## Steps
 ## Steps
 
 
-1. Create the Ollama namespace, daemon set, and service
+1. Create the Ollama namespace, deployment, and service
 
 
    ```bash
    ```bash
    kubectl apply -f cpu.yaml
    kubectl apply -f cpu.yaml
    ```
    ```
 
 
+## (Optional) Hardware Acceleration
+
+Hardware acceleration in Kubernetes requires NVIDIA's [`k8s-device-plugin`](https://github.com/NVIDIA/k8s-device-plugin), which is deployed in Kubernetes as a daemonset. Follow the link for more details.
+
+Once configured, create a GPU enabled Ollama deployment.
+
+```bash
+kubectl apply -f gpu.yaml
+```
+
+## Test
+
 1. Port forward the Ollama service to connect and use it locally
 1. Port forward the Ollama service to connect and use it locally
 
 
    ```bash
    ```bash
@@ -23,14 +35,4 @@
 
 
    ```bash
    ```bash
    ollama run orca-mini:3b
    ollama run orca-mini:3b
-   ```
-
-## (Optional) Hardware Acceleration
-
-Hardware acceleration in Kubernetes requires NVIDIA's [`k8s-device-plugin`](https://github.com/NVIDIA/k8s-device-plugin). Follow the link for more details.
-
-Once configured, create a GPU enabled Ollama deployment.
-
-```bash
-kubectl apply -f gpu.yaml
-```
+   ```

+ 5 - 5
examples/langchain-python-rag-document/main.py

@@ -40,9 +40,9 @@ while True:
         continue
         continue
 
 
     # Prompt
     # Prompt
-    template = """Use the following pieces of context to answer the question at the end. 
-    If you don't know the answer, just say that you don't know, don't try to make up an answer. 
-    Use three sentences maximum and keep the answer as concise as possible. 
+    template = """Use the following pieces of context to answer the question at the end.
+    If you don't know the answer, just say that you don't know, don't try to make up an answer.
+    Use three sentences maximum and keep the answer as concise as possible.
     {context}
     {context}
     Question: {question}
     Question: {question}
     Helpful Answer:"""
     Helpful Answer:"""
@@ -51,11 +51,11 @@ while True:
         template=template,
         template=template,
     )
     )
 
 
-    llm = Ollama(model="llama2:13b", callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))
+    llm = Ollama(model="llama3:8b", callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))
     qa_chain = RetrievalQA.from_chain_type(
     qa_chain = RetrievalQA.from_chain_type(
         llm,
         llm,
         retriever=vectorstore.as_retriever(),
         retriever=vectorstore.as_retriever(),
         chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
         chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
     )
     )
 
 
-    result = qa_chain({"query": query})
+    result = qa_chain({"query": query})

+ 4 - 4
examples/langchain-python-rag-websummary/main.py

@@ -1,12 +1,12 @@
-from langchain.llms import Ollama
-from langchain.document_loaders import WebBaseLoader
+from langchain_community.llms import Ollama
+from langchain_community.document_loaders import WebBaseLoader
 from langchain.chains.summarize import load_summarize_chain
 from langchain.chains.summarize import load_summarize_chain
 
 
 loader = WebBaseLoader("https://ollama.com/blog/run-llama2-uncensored-locally")
 loader = WebBaseLoader("https://ollama.com/blog/run-llama2-uncensored-locally")
 docs = loader.load()
 docs = loader.load()
 
 
-llm = Ollama(model="llama2")
+llm = Ollama(model="llama3")
 chain = load_summarize_chain(llm, chain_type="stuff")
 chain = load_summarize_chain(llm, chain_type="stuff")
 
 
-result = chain.run(docs)
+result = chain.invoke(docs) 
 print(result)
 print(result)

+ 2 - 3
examples/langchain-python-simple/README.md

@@ -4,10 +4,10 @@ This example is a basic "hello world" of using LangChain with Ollama.
 
 
 ## Running the Example
 ## Running the Example
 
 
-1. Ensure you have the `llama2` model installed:
+1. Ensure you have the `llama3` model installed:
 
 
    ```bash
    ```bash
-   ollama pull llama2
+   ollama pull llama3
    ```
    ```
 
 
 2. Install the Python Requirements.
 2. Install the Python Requirements.
@@ -21,4 +21,3 @@ This example is a basic "hello world" of using LangChain with Ollama.
    ```bash
    ```bash
    python main.py
    python main.py
    ```
    ```
-  

+ 1 - 1
examples/langchain-python-simple/main.py

@@ -1,6 +1,6 @@
 from langchain.llms import Ollama
 from langchain.llms import Ollama
 
 
 input = input("What is your question?")
 input = input("What is your question?")
-llm = Ollama(model="llama2")
+llm = Ollama(model="llama3")
 res = llm.predict(input)
 res = llm.predict(input)
 print (res)
 print (res)

+ 1 - 1
examples/modelfile-mario/Modelfile

@@ -1,4 +1,4 @@
-FROM llama2
+FROM llama3
 PARAMETER temperature 1
 PARAMETER temperature 1
 SYSTEM """
 SYSTEM """
 You are Mario from super mario bros, acting as an assistant.
 You are Mario from super mario bros, acting as an assistant.

+ 3 - 3
examples/modelfile-mario/readme.md

@@ -2,12 +2,12 @@
 
 
 # Example character: Mario
 # Example character: Mario
 
 
-This example shows how to create a basic character using Llama2 as the base model.
+This example shows how to create a basic character using Llama3 as the base model.
 
 
 To run this example:
 To run this example:
 
 
 1. Download the Modelfile
 1. Download the Modelfile
-2. `ollama pull llama2` to get the base model used in the model file.
+2. `ollama pull llama3` to get the base model used in the model file.
 3. `ollama create NAME -f ./Modelfile`
 3. `ollama create NAME -f ./Modelfile`
 4. `ollama run NAME`
 4. `ollama run NAME`
 
 
@@ -18,7 +18,7 @@ Ask it some questions like "Who are you?" or "Is Peach in trouble again?"
 What the model file looks like:
 What the model file looks like:
 
 
 ```
 ```
-FROM llama2
+FROM llama3
 PARAMETER temperature 1
 PARAMETER temperature 1
 SYSTEM """
 SYSTEM """
 You are Mario from Super Mario Bros, acting as an assistant.
 You are Mario from Super Mario Bros, acting as an assistant.

+ 7 - 7
examples/python-json-datagenerator/predefinedschema.py

@@ -2,16 +2,16 @@ import requests
 import json
 import json
 import random
 import random
 
 
-model = "llama2"
+model = "llama3"
 template = {
 template = {
-  "firstName": "", 
-  "lastName": "", 
+  "firstName": "",
+  "lastName": "",
   "address": {
   "address": {
-    "street": "", 
-    "city": "", 
-    "state": "", 
+    "street": "",
+    "city": "",
+    "state": "",
     "zipCode": ""
     "zipCode": ""
-  }, 
+  },
   "phoneNumber": ""
   "phoneNumber": ""
 }
 }
 
 

+ 1 - 1
examples/python-json-datagenerator/randomaddresses.py

@@ -12,7 +12,7 @@ countries = [
     "France",
     "France",
 ]
 ]
 country = random.choice(countries)
 country = random.choice(countries)
-model = "llama2"
+model = "llama3"
 
 
 prompt = f"generate one realistically believable sample data set of a persons first name, last name, address in {country}, and phone number. Do not use common names. Respond using JSON. Key names should have no backslashes, values should use plain ascii with no special characters."
 prompt = f"generate one realistically believable sample data set of a persons first name, last name, address in {country}, and phone number. Do not use common names. Respond using JSON. Key names should have no backslashes, values should use plain ascii with no special characters."
 
 

+ 2 - 2
examples/python-json-datagenerator/readme.md

@@ -6,10 +6,10 @@ There are two python scripts in this example. `randomaddresses.py` generates ran
 
 
 ## Running the Example
 ## Running the Example
 
 
-1. Ensure you have the `llama2` model installed:
+1. Ensure you have the `llama3` model installed:
 
 
    ```bash
    ```bash
-   ollama pull llama2
+   ollama pull llama3
    ```
    ```
 
 
 2. Install the Python Requirements.
 2. Install the Python Requirements.

+ 1 - 1
examples/python-simplechat/client.py

@@ -2,7 +2,7 @@ import json
 import requests
 import requests
 
 
 # NOTE: ollama must be running for this to work, start the ollama app or run `ollama serve`
 # NOTE: ollama must be running for this to work, start the ollama app or run `ollama serve`
-model = "llama2"  # TODO: update this for whatever model you wish to use
+model = "llama3"  # TODO: update this for whatever model you wish to use
 
 
 
 
 def chat(messages):
 def chat(messages):

+ 2 - 2
examples/python-simplechat/readme.md

@@ -4,10 +4,10 @@ The **chat** endpoint is one of two ways to generate text from an LLM with Ollam
 
 
 ## Running the Example
 ## Running the Example
 
 
-1. Ensure you have the `llama2` model installed:
+1. Ensure you have the `llama3` model installed:
 
 
    ```bash
    ```bash
-   ollama pull llama2
+   ollama pull llama3
    ```
    ```
 
 
 2. Install the Python Requirements.
 2. Install the Python Requirements.

+ 2 - 2
examples/typescript-mentors/README.md

@@ -4,10 +4,10 @@ This example demonstrates how one would create a set of 'mentors' you can have a
 
 
 ## Usage
 ## Usage
 
 
-1. Add llama2 to have the mentors ask your questions:
+1. Add llama3 so you can ask the mentors your questions:
 
 
    ```bash
    ```bash
-   ollama pull llama2
+   ollama pull llama3
    ```
    ```
 
 
 2. Install prerequisites:
 2. Install prerequisites:

+ 2 - 2
examples/typescript-mentors/character-generator.ts

@@ -15,7 +15,7 @@ async function characterGenerator() {
   ollama.setModel("stablebeluga2:70b-q4_K_M");
   ollama.setModel("stablebeluga2:70b-q4_K_M");
   const bio = await ollama.generate(`create a bio of ${character} in a single long paragraph. Instead of saying '${character} is...' or '${character} was...' use language like 'You are...' or 'You were...'. Then create a paragraph describing the speaking mannerisms and style of ${character}. Don't include anything about how ${character} looked or what they sounded like, just focus on the words they said. Instead of saying '${character} would say...' use language like 'You should say...'. If you use quotes, always use single quotes instead of double quotes. If there are any specific words or phrases you used a lot, show how you used them. `);
   const bio = await ollama.generate(`create a bio of ${character} in a single long paragraph. Instead of saying '${character} is...' or '${character} was...' use language like 'You are...' or 'You were...'. Then create a paragraph describing the speaking mannerisms and style of ${character}. Don't include anything about how ${character} looked or what they sounded like, just focus on the words they said. Instead of saying '${character} would say...' use language like 'You should say...'. If you use quotes, always use single quotes instead of double quotes. If there are any specific words or phrases you used a lot, show how you used them. `);
 
 
-  const thecontents = `FROM llama2\nSYSTEM """\n${bio.response.replace(/(\r\n|\n|\r)/gm, " ").replace('would', 'should')} All answers to questions should be related back to what you are most known for.\n"""`;
+  const thecontents = `FROM llama3\nSYSTEM """\n${bio.response.replace(/(\r\n|\n|\r)/gm, " ").replace('would', 'should')} All answers to questions should be related back to what you are most known for.\n"""`;
 
 
   fs.writeFile(path.join(directory, 'Modelfile'), thecontents, (err: any) => {
   fs.writeFile(path.join(directory, 'Modelfile'), thecontents, (err: any) => {
     if (err) throw err;
     if (err) throw err;
@@ -23,4 +23,4 @@ async function characterGenerator() {
   });
   });
 }
 }
 
 
-characterGenerator();
+characterGenerator();

+ 2 - 2
examples/typescript-simplechat/client.ts

@@ -1,6 +1,6 @@
 import * as readline from "readline";
 import * as readline from "readline";
 
 
-const model = "llama2";
+const model = "llama3";
 type Message = {
 type Message = {
   role: "assistant" | "user" | "system";
   role: "assistant" | "user" | "system";
   content: string;
   content: string;
@@ -74,4 +74,4 @@ async function main() {
 
 
 }
 }
 
 
-main();
+main();

+ 3 - 0
format/bytes.go

@@ -15,6 +15,7 @@ const (
 
 
 	KibiByte = Byte * 1024
 	KibiByte = Byte * 1024
 	MebiByte = KibiByte * 1024
 	MebiByte = KibiByte * 1024
+	GibiByte = MebiByte * 1024
 )
 )
 
 
 func HumanBytes(b int64) string {
 func HumanBytes(b int64) string {
@@ -52,6 +53,8 @@ func HumanBytes(b int64) string {
 
 
 func HumanBytes2(b uint64) string {
 func HumanBytes2(b uint64) string {
 	switch {
 	switch {
+	case b >= GibiByte:
+		return fmt.Sprintf("%.1f GiB", float64(b)/GibiByte)
 	case b >= MebiByte:
 	case b >= MebiByte:
 		return fmt.Sprintf("%.1f MiB", float64(b)/MebiByte)
 		return fmt.Sprintf("%.1f MiB", float64(b)/MebiByte)
 	case b >= KibiByte:
 	case b >= KibiByte:

+ 14 - 6
format/format.go

@@ -13,12 +13,20 @@ const (
 
 
 func HumanNumber(b uint64) string {
 func HumanNumber(b uint64) string {
 	switch {
 	switch {
-	case b > Billion:
-		return fmt.Sprintf("%.0fB", math.Round(float64(b)/Billion))
-	case b > Million:
-		return fmt.Sprintf("%.0fM", math.Round(float64(b)/Million))
-	case b > Thousand:
-		return fmt.Sprintf("%.0fK", math.Round(float64(b)/Thousand))
+	case b >= Billion:
+		number := float64(b) / Billion
+		if number == math.Floor(number) {
+			return fmt.Sprintf("%.0fB", number) // no decimals if whole number
+		}
+		return fmt.Sprintf("%.1fB", number) // one decimal if not a whole number
+	case b >= Million:
+		number := float64(b) / Million
+		if number == math.Floor(number) {
+			return fmt.Sprintf("%.0fM", number) // no decimals if whole number
+		}
+		return fmt.Sprintf("%.2fM", number) // two decimals if not a whole number
+	case b >= Thousand:
+		return fmt.Sprintf("%.0fK", float64(b)/Thousand)
 	default:
 	default:
 		return fmt.Sprintf("%d", b)
 		return fmt.Sprintf("%d", b)
 	}
 	}

+ 34 - 0
format/format_test.go

@@ -0,0 +1,34 @@
+package format
+
+import (
+	"testing"
+)
+
+func TestHumanNumber(t *testing.T) {
+
+	type testCase struct {
+		input    uint64
+		expected string
+	}
+
+	testCases := []testCase{
+		{0, "0"},
+		{1000000, "1M"},
+		{125000000, "125M"},
+		{500500000, "500.50M"},
+		{500550000, "500.55M"},
+		{1000000000, "1B"},
+		{2800000000, "2.8B"},
+		{2850000000, "2.9B"},
+		{1000000000000, "1000B"},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.expected, func(t *testing.T) {
+			result := HumanNumber(tc.input)
+			if result != tc.expected {
+				t.Errorf("Expected %s, got %s", tc.expected, result)
+			}
+		})
+	}
+}

+ 58 - 14
gpu/amd_common.go

@@ -7,7 +7,7 @@ import (
 	"log/slog"
 	"log/slog"
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
-	"strconv"
+	"runtime"
 	"strings"
 	"strings"
 )
 )
 
 
@@ -35,22 +35,66 @@ func GetSupportedGFX(libDir string) ([]string, error) {
 	return ret, nil
 	return ret, nil
 }
 }
 
 
-func amdSetVisibleDevices(ids []int, skip map[int]interface{}) {
-	// Set the visible devices if not already set
-	// TODO - does sort order matter?
-	devices := []string{}
-	for i := range ids {
-		if _, skipped := skip[i]; skipped {
+func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
+	ids := []string{}
+	for _, info := range gpuInfo {
+		if info.Library != "rocm" {
+			// TODO shouldn't happen if things are wired correctly...
+			slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
 			continue
 			continue
 		}
 		}
-		devices = append(devices, strconv.Itoa(i))
+		ids = append(ids, info.ID)
 	}
 	}
+	return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",")
+}
 
 
-	val := strings.Join(devices, ",")
-	err := os.Setenv("HIP_VISIBLE_DEVICES", val)
-	if err != nil {
-		slog.Warn(fmt.Sprintf("failed to set env: %s", err))
-	} else {
-		slog.Info("Setting HIP_VISIBLE_DEVICES=" + val)
+func commonAMDValidateLibDir() (string, error) {
+	// We try to favor system paths first, so that we can wire up the subprocess to use
+	// the system version.  Only use our bundled version if the system version doesn't work
+	// This gives users more recovery options if versions have subtle problems at runtime
+
+	// Prefer explicit HIP env var
+	hipPath := os.Getenv("HIP_PATH")
+	if hipPath != "" {
+		hipLibDir := filepath.Join(hipPath, "bin")
+		if rocmLibUsable(hipLibDir) {
+			slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
+			return hipLibDir, nil
+		}
+	}
+
+	// Scan the LD_LIBRARY_PATH or PATH
+	pathEnv := "LD_LIBRARY_PATH"
+	if runtime.GOOS == "windows" {
+		pathEnv = "PATH"
+	}
+
+	paths := os.Getenv(pathEnv)
+	for _, path := range filepath.SplitList(paths) {
+		d, err := filepath.Abs(path)
+		if err != nil {
+			continue
+		}
+		if rocmLibUsable(d) {
+			return d, nil
+		}
+	}
+
+	// Well known location(s)
+	for _, path := range RocmStandardLocations {
+		if rocmLibUsable(path) {
+			return path, nil
+		}
+	}
+
+	// Installer payload location if we're running the installed binary
+	exe, err := os.Executable()
+	if err == nil {
+		rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
+		if rocmLibUsable(rocmTargetDir) {
+			slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
+			return rocmTargetDir, nil
+		}
 	}
 	}
+	return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
 }
 }
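
For illustration, this search order means a Linux user could steer Ollama toward a specific system ROCm install purely via the environment (the path is an example):

```shell
# Example only: make a system ROCm library directory discoverable ahead of the bundled copy
LD_LIBRARY_PATH=/opt/rocm/lib ollama serve
```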

+ 2 - 2
gpu/amd_hip_windows.go

@@ -69,7 +69,7 @@ func NewHipLib() (*HipLib, error) {
 func (hl *HipLib) Release() {
 func (hl *HipLib) Release() {
 	err := windows.FreeLibrary(hl.dll)
 	err := windows.FreeLibrary(hl.dll)
 	if err != nil {
 	if err != nil {
-		slog.Warn(fmt.Sprintf("failed to unload amdhip64.dll: %s", err))
+		slog.Warn("failed to unload amdhip64.dll", "error", err)
 	}
 	}
 	hl.dll = 0
 	hl.dll = 0
 }
 }
@@ -98,7 +98,7 @@ func (hl *HipLib) HipGetDeviceCount() int {
 		return 0
 		return 0
 	}
 	}
 	if status != hipSuccess {
 	if status != hipSuccess {
-		slog.Warn(fmt.Sprintf("failed call to hipGetDeviceCount: %d %s", status, err))
+		slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err)
 	}
 	}
 	return count
 	return count
 }
 }

+ 176 - 287
gpu/amd_linux.go

@@ -11,6 +11,8 @@ import (
 	"slices"
 	"slices"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
+
+	"github.com/ollama/ollama/format"
 )
 )
 
 
 // Discovery logic for AMD/ROCm GPUs
 // Discovery logic for AMD/ROCm GPUs
@@ -23,26 +25,20 @@ const (
 	// Prefix with the node dir
 	// Prefix with the node dir
 	GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
 	GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
 	GPUUsedMemoryFileGlob  = "mem_banks/*/used_memory"
 	GPUUsedMemoryFileGlob  = "mem_banks/*/used_memory"
-	RocmStandardLocation   = "/opt/rocm/lib"
-
-	// TODO find a better way to detect iGPU instead of minimum memory
-	IGPUMemLimit = 1024 * 1024 * 1024 // 512G is what they typically report, so anything less than 1G must be iGPU
 )
 )
 
 
 var (
 var (
 	// Used to validate if the given ROCm lib is usable
 	// Used to validate if the given ROCm lib is usable
-	ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here...
+	ROCmLibGlobs          = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here...
+	RocmStandardLocations = []string{"/opt/rocm/lib", "/usr/lib64"}
 )
 )
 
 
 // Gather GPU information from the amdgpu driver if any supported GPUs are detected
 // Gather GPU information from the amdgpu driver if any supported GPUs are detected
-// HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices
-// and the user hasn't already set this variable
-func AMDGetGPUInfo(resp *GpuInfo) {
-	// TODO - DRY this out with windows
+func AMDGetGPUInfo() []GpuInfo {
+	resp := []GpuInfo{}
 	if !AMDDetected() {
 	if !AMDDetected() {
-		return
+		return resp
 	}
 	}
-	skip := map[int]interface{}{}
 
 
 	// Opportunistic logging of driver version to aid in troubleshooting
 	// Opportunistic logging of driver version to aid in troubleshooting
 	ver, err := AMDDriverVersion()
 	ver, err := AMDDriverVersion()
@@ -50,160 +46,117 @@ func AMDGetGPUInfo(resp *GpuInfo) {
 		slog.Info("AMD Driver: " + ver)
 		slog.Info("AMD Driver: " + ver)
 	} else {
 	} else {
 		// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
 		// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
-		slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err))
+		slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
 	}
 	}
 
 
-	// If the user has specified exactly which GPUs to use, look up their memory
-	visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES")
-	if visibleDevices != "" {
-		ids := []int{}
-		for _, idStr := range strings.Split(visibleDevices, ",") {
-			id, err := strconv.Atoi(idStr)
-			if err != nil {
-				slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err))
-			} else {
-				ids = append(ids, id)
-			}
-		}
-		amdProcMemLookup(resp, nil, ids)
-		return
+	// Determine if the user has already pre-selected which GPUs to look at, then ignore the others
+	var visibleDevices []string
+	hipVD := os.Getenv("HIP_VISIBLE_DEVICES")   // zero based index only
+	rocrVD := os.Getenv("ROCR_VISIBLE_DEVICES") // zero based index or UUID, but consumer cards seem to not support UUID
+	gpuDO := os.Getenv("GPU_DEVICE_ORDINAL")    // zero based index
+	switch {
+	// TODO is this priority order right?
+	case hipVD != "":
+		visibleDevices = strings.Split(hipVD, ",")
+	case rocrVD != "":
+		visibleDevices = strings.Split(rocrVD, ",")
+		// TODO - since we don't yet support UUIDs, consider detecting and reporting here
+		// all our test systems show GPU-XX indicating UUID is not supported
+	case gpuDO != "":
+		visibleDevices = strings.Split(gpuDO, ",")
 	}
 	}
 
 
-	// Gather GFX version information from all detected cards
-	gfx := AMDGFXVersions()
-	verStrings := []string{}
-	for i, v := range gfx {
-		verStrings = append(verStrings, v.ToGFXString())
-		if v.Major == 0 {
-			// Silently skip CPUs
-			skip[i] = struct{}{}
+	gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
+	var supported []string
+	libDir := ""
+
+	// The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
+	// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
+	matches, _ := filepath.Glob(GPUPropertiesFileGlob)
+	cpuCount := 0
+	for _, match := range matches {
+		slog.Debug("evaluating amdgpu node " + match)
+		fp, err := os.Open(match)
+		if err != nil {
+			slog.Debug("failed to open sysfs node", "file", match, "error", err)
 			continue
 			continue
 		}
 		}
-		if v.Major < 9 {
-			// TODO consider this a build-time setting if we can support 8xx family GPUs
-			slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString()))
-			skip[i] = struct{}{}
+		defer fp.Close()
+		nodeID, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
+		if err != nil {
+			slog.Debug("failed to parse node ID", "error", err)
+			continue
 		}
 		}
-	}
-	slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings))
 
 
-	// Abort if all GPUs are skipped
-	if len(skip) >= len(gfx) {
-		slog.Info("all detected amdgpus are skipped, falling back to CPU")
-		return
-	}
+		scanner := bufio.NewScanner(fp)
+		isCPU := false
+		var major, minor, patch uint64
+		for scanner.Scan() {
+			line := strings.TrimSpace(scanner.Text())
+			// Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
+			if strings.HasPrefix(line, "gfx_target_version") {
+				ver := strings.Fields(line)
 
 
-	// If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib
-	libDir, err := AMDValidateLibDir()
-	if err != nil {
-		slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
-		return
-	}
+				// Detect CPUs
+				if len(ver) == 2 && ver[1] == "0" {
+					slog.Debug("detected CPU " + match)
+					isCPU = true
+					break
+				}
 
 
-	updateLibPath(libDir)
+				if len(ver) != 2 || len(ver[1]) < 5 {
+					slog.Warn("malformed "+match, "gfx_target_version", line)
+					// If this winds up being a CPU, our offsets may be wrong
+					continue
+				}
+				l := len(ver[1])
+				var err1, err2, err3 error
+				patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32)
+				minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
+				major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32)
+				if err1 != nil || err2 != nil || err3 != nil {
+					slog.Debug("malformed int " + line)
+					continue
+				}
+			}
 
 
-	gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
-	if gfxOverride == "" {
-		supported, err := GetSupportedGFX(libDir)
-		if err != nil {
-			slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
-			return
+			// TODO - any other properties we want to extract and record?
+			// vendor_id + device_id -> pci lookup for "Name"
+			// Other metrics that may help us understand relative performance between multiple GPUs
 		}
 		}
-		slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported))
 
 
-		for i, v := range gfx {
-			if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
-				slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
-				// TODO - consider discrete markdown just for ROCM troubleshooting?
-				slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
-				skip[i] = struct{}{}
-			} else {
-				slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
-			}
+		if isCPU {
+			cpuCount++
+			continue
 		}
 		}
-	} else {
-		slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
-	}
 
 
-	if len(skip) >= len(gfx) {
-		slog.Info("all detected amdgpus are skipped, falling back to CPU")
-		return
-	}
+		// CPUs are always first in the list
+		gpuID := nodeID - cpuCount
 
 
-	ids := make([]int, len(gfx))
-	i := 0
-	for k := range gfx {
-		ids[i] = k
-		i++
-	}
-	amdProcMemLookup(resp, skip, ids)
-	if resp.memInfo.DeviceCount == 0 {
-		return
-	}
-	if len(skip) > 0 {
-		amdSetVisibleDevices(ids, skip)
-	}
-}
-
-func updateLibPath(libDir string) {
-	ldPaths := []string{}
-	if val, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
-		ldPaths = strings.Split(val, ":")
-	}
-	for _, d := range ldPaths {
-		if d == libDir {
-			return
-		}
-	}
-	val := strings.Join(append(ldPaths, libDir), ":")
-	slog.Debug("updated lib path", "LD_LIBRARY_PATH", val)
-	os.Setenv("LD_LIBRARY_PATH", val)
-}
-
-// Walk the sysfs nodes for the available GPUs and gather information from them
-// skipping over any devices in the skip map
-func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
-	resp.memInfo.DeviceCount = 0
-	resp.memInfo.TotalMemory = 0
-	resp.memInfo.FreeMemory = 0
-	slog.Debug("discovering VRAM for amdgpu devices")
-	if len(ids) == 0 {
-		entries, err := os.ReadDir(AMDNodesSysfsDir)
-		if err != nil {
-			slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err))
-			return
-		}
-		for _, node := range entries {
-			if !node.IsDir() {
-				continue
-			}
-			id, err := strconv.Atoi(node.Name())
-			if err != nil {
-				slog.Warn("malformed amdgpu sysfs node id " + node.Name())
-				continue
-			}
-			ids = append(ids, id)
+		// Shouldn't happen, but just in case...
+		if gpuID < 0 {
+			slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue")
+			return []GpuInfo{}
 		}
 		}
-	}
-	slog.Debug(fmt.Sprintf("amdgpu devices %v", ids))
 
 
-	for _, id := range ids {
-		if _, skipped := skip[id]; skipped {
+		if int(major) < RocmComputeMin {
+			slog.Warn(fmt.Sprintf("amdgpu too old gfx%d%d%x", major, minor, patch), "gpu", gpuID)
 			continue
 			continue
 		}
 		}
+
+		// Look up the memory for the current node
 		totalMemory := uint64(0)
 		totalMemory := uint64(0)
 		usedMemory := uint64(0)
 		usedMemory := uint64(0)
-		// Adjust for sysfs vs HIP ids
-		propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id+1), GPUTotalMemoryFileGlob)
+		propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUTotalMemoryFileGlob)
 		propFiles, err := filepath.Glob(propGlob)
 		propFiles, err := filepath.Glob(propGlob)
 		if err != nil {
 		if err != nil {
-			slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err))
+			slog.Warn("error looking up total GPU memory", "glob", propGlob, "error", err)
 		}
 		}
 		// 1 or more memory banks - sum the values of all of them
 		// 1 or more memory banks - sum the values of all of them
 		for _, propFile := range propFiles {
 		for _, propFile := range propFiles {
 			fp, err := os.Open(propFile)
 			fp, err := os.Open(propFile)
 			if err != nil {
 			if err != nil {
-				slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err))
+				slog.Warn("failed to open sysfs node", "file", propFile, "error", err)
 				continue
 				continue
 			}
 			}
 			defer fp.Close()
 			defer fp.Close()
@@ -226,49 +179,113 @@ func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
 			}
 			}
 		}
 		}
 		if totalMemory == 0 {
-			slog.Warn(fmt.Sprintf("amdgpu [%d] reports zero total memory, skipping", id))
-			skip[id] = struct{}{}
+			slog.Warn("amdgpu reports zero total memory", "gpu", gpuID)
 			continue
 		}
-		if totalMemory < IGPUMemLimit {
-			slog.Info(fmt.Sprintf("amdgpu [%d] appears to be an iGPU with %dM reported total memory, skipping", id, totalMemory/1024/1024))
-			skip[id] = struct{}{}
-			continue
-		}
-		usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob)
+		usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUUsedMemoryFileGlob)
 		usedFiles, err := filepath.Glob(usedGlob)
 		if err != nil {
-			slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err))
+			slog.Warn("error looking up used GPU memory", "glob", usedGlob, "error", err)
 			continue
 		}
 		for _, usedFile := range usedFiles {
 			fp, err := os.Open(usedFile)
 			if err != nil {
-				slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err))
+				slog.Warn("failed to open sysfs node", "file", usedFile, "error", err)
				continue
 			}
 			defer fp.Close()
 			data, err := io.ReadAll(fp)
 			if err != nil {
-				slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err))
+				slog.Warn("failed to read sysfs node", "file", usedFile, "error", err)
 				continue
 			}
 			used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
 			if err != nil {
-				slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err))
+				slog.Warn("malformed used memory", "data", string(data), "error", err)
 				continue
 			}
 			usedMemory += used
 		}
-		slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %dM", id, totalMemory/1024/1024))
-		slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory  %dM", id, (totalMemory-usedMemory)/1024/1024))
-		resp.memInfo.DeviceCount++
-		resp.memInfo.TotalMemory += totalMemory
-		resp.memInfo.FreeMemory += (totalMemory - usedMemory)
+
+		// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
+		if totalMemory < IGPUMemLimit {
+			slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
+			continue
+		}
+
+		slog.Info("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
+		slog.Info("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
+		gpuInfo := GpuInfo{
+			Library: "rocm",
+			memInfo: memInfo{
+				TotalMemory: totalMemory,
+				FreeMemory:  (totalMemory - usedMemory),
+			},
+			ID: fmt.Sprintf("%d", gpuID),
+			// Name: not exposed in sysfs directly, would require pci device id lookup
+			Major:         int(major),
+			Minor:         int(minor),
+			Patch:         int(patch),
+			MinimumMemory: rocmMinimumMemory,
+		}
+
+		// If the user wants to filter to a subset of devices, filter out if we aren't a match
+		if len(visibleDevices) > 0 {
+			include := false
+			for _, visible := range visibleDevices {
+				if visible == gpuInfo.ID {
+					include = true
+					break
+				}
+			}
+			if !include {
+				slog.Info("filtering out device per user request", "id", gpuInfo.ID, "visible_devices", visibleDevices)
+				continue
+			}
+		}
+
+		// Final validation is gfx compatibility - load the library if we haven't already loaded it
+		// even if the user overrides, we still need to validate the library
+		if libDir == "" {
+			libDir, err = AMDValidateLibDir()
+			if err != nil {
+				slog.Warn("unable to verify rocm library, will use cpu", "error", err)
+				return []GpuInfo{}
+			}
+		}
+		gpuInfo.DependencyPath = libDir
+
+		if gfxOverride == "" {
+			// Only load supported list once
+			if len(supported) == 0 {
+				supported, err = GetSupportedGFX(libDir)
+				if err != nil {
+					slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
+					return []GpuInfo{}
+				}
+				slog.Debug("rocm supported GPUs", "types", supported)
+			}
+			gfx := fmt.Sprintf("gfx%d%d%x", gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch)
+			if !slices.Contains[[]string, string](supported, gfx) {
+				slog.Warn("amdgpu is not supported", "gpu", gpuInfo.ID, "gpu_type", gfx, "library", libDir, "supported_types", supported)
+				// TODO - consider discrete markdown just for ROCM troubleshooting?
+				slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
+				continue
+			} else {
+				slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx)
+			}
+		} else {
+			slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
+		}
+
+		// The GPU has passed all the verification steps and is supported
+		resp = append(resp, gpuInfo)
 	}
-	if resp.memInfo.DeviceCount > 0 {
-		resp.Library = "rocm"
+	if len(resp) == 0 {
+		slog.Info("no compatible amdgpu devices detected")
 	}
+	return resp
 }
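
For reference, a minimal standalone Go sketch (not part of the patch) of how the gfx target string above is built and matched; the %x verb prints the patch level in hex, so Major/Minor/Patch of 9/0/10 yields "gfx90a". The supported list here is hypothetical.

package main

import (
	"fmt"
	"slices"
)

func main() {
	// Hypothetical values standing in for GpuInfo.Major/Minor/Patch.
	major, minor, patch := 9, 0, 10

	// Same formatting the discovery code uses: %x renders the patch level in hex.
	gfx := fmt.Sprintf("gfx%d%d%x", major, minor, patch) // "gfx90a"

	// Hypothetical result of GetSupportedGFX(libDir): the gfx targets the rocm library ships kernels for.
	supported := []string{"gfx900", "gfx906", "gfx908", "gfx90a", "gfx1030"}

	if slices.Contains(supported, gfx) {
		fmt.Println(gfx, "is supported")
	} else {
		fmt.Println(gfx, "is not supported; HSA_OVERRIDE_GFX_VERSION may be needed")
	}
}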
 
 
 // Quick check for AMD driver so we can skip amdgpu discovery if not present
 // Quick check for AMD driver so we can skip amdgpu discovery if not present
@@ -280,87 +297,24 @@ func AMDDetected() bool {
 		slog.Debug("amdgpu driver not detected " + sysfsDir)
 		slog.Debug("amdgpu driver not detected " + sysfsDir)
 		return false
 		return false
 	} else if err != nil {
 	} else if err != nil {
-		slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err))
+		slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err)
 		return false
 		return false
 	}
 	}
 	return true
 	return true
 }
 }
 
 
-func setupLink(source, target string) error {
-	if err := os.RemoveAll(target); err != nil {
-		return fmt.Errorf("failed to remove old rocm directory %s %w", target, err)
-	}
-	if err := os.Symlink(source, target); err != nil {
-		return fmt.Errorf("failed to create link %s => %s %w", source, target, err)
-	}
-	slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target))
-	return nil
-}
-
-// Ensure the AMD rocm lib dir is wired up
 // Prefer to use host installed ROCm, as long as it meets our minimum requirements
 // Prefer to use host installed ROCm, as long as it meets our minimum requirements
 // failing that, tell the user how to download it on their own
 // failing that, tell the user how to download it on their own
 func AMDValidateLibDir() (string, error) {
 func AMDValidateLibDir() (string, error) {
-	// We rely on the rpath compiled into our library to find rocm
-	// so we establish a symlink to wherever we find it on the system
-	// to <payloads>/rocm
-	payloadsDir, err := PayloadsDir()
-	if err != nil {
-		return "", err
-	}
-
-	// If we already have a rocm dependency wired, nothing more to do
-	rocmTargetDir := filepath.Clean(filepath.Join(payloadsDir, "..", "rocm"))
-	if rocmLibUsable(rocmTargetDir) {
-		return rocmTargetDir, nil
-	}
-
-	// next to the running binary
-	exe, err := os.Executable()
+	libDir, err := commonAMDValidateLibDir()
 	if err == nil {
 	if err == nil {
-		peerDir := filepath.Dir(exe)
-		if rocmLibUsable(peerDir) {
-			slog.Debug("detected ROCM next to ollama executable " + peerDir)
-			return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
-		}
-		peerDir = filepath.Join(filepath.Dir(exe), "rocm")
-		if rocmLibUsable(peerDir) {
-			slog.Debug("detected ROCM next to ollama executable " + peerDir)
-			return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
-		}
+		return libDir, nil
 	}
 	}
 
 
 	// Well known ollama installer path
 	// Well known ollama installer path
 	installedRocmDir := "/usr/share/ollama/lib/rocm"
 	installedRocmDir := "/usr/share/ollama/lib/rocm"
 	if rocmLibUsable(installedRocmDir) {
 	if rocmLibUsable(installedRocmDir) {
-		return rocmTargetDir, setupLink(installedRocmDir, rocmTargetDir)
-	}
-
-	// Prefer explicit HIP env var
-	hipPath := os.Getenv("HIP_PATH")
-	if hipPath != "" {
-		hipLibDir := filepath.Join(hipPath, "lib")
-		if rocmLibUsable(hipLibDir) {
-			slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
-			return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir)
-		}
-	}
-
-	// Scan the library path for potential matches
-	ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
-	for _, ldPath := range ldPaths {
-		d, err := filepath.Abs(ldPath)
-		if err != nil {
-			continue
-		}
-		if rocmLibUsable(d) {
-			return rocmTargetDir, setupLink(d, rocmTargetDir)
-		}
-	}
-
-	// Well known location(s)
-	if rocmLibUsable("/opt/rocm/lib") {
-		return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir)
+		return installedRocmDir, nil
 	}
 	}
 
 
 	// If we still haven't found a usable rocm, the user will have to install it on their own
 	// If we still haven't found a usable rocm, the user will have to install it on their own
@@ -384,68 +338,3 @@ func AMDDriverVersion() (string, error) {
 	}
 	}
 	return strings.TrimSpace(string(verString)), nil
 	return strings.TrimSpace(string(verString)), nil
 }
 }
-
-func AMDGFXVersions() map[int]Version {
-	// The amdgpu driver always exposes the host CPU as node 0, but we have to skip that and subtract one
-	// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
-	res := map[int]Version{}
-	matches, _ := filepath.Glob(GPUPropertiesFileGlob)
-	for _, match := range matches {
-		fp, err := os.Open(match)
-		if err != nil {
-			slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
-			continue
-		}
-		defer fp.Close()
-		i, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
-		if err != nil {
-			slog.Debug(fmt.Sprintf("failed to parse node ID %s", err))
-			continue
-		}
-
-		if i == 0 {
-			// Skipping the CPU
-			continue
-		}
-		// Align with HIP IDs (zero is first GPU, not CPU)
-		i -= 1
-
-		scanner := bufio.NewScanner(fp)
-		for scanner.Scan() {
-			line := strings.TrimSpace(scanner.Text())
-			if strings.HasPrefix(line, "gfx_target_version") {
-				ver := strings.Fields(line)
-				if len(ver) != 2 || len(ver[1]) < 5 {
-					if ver[1] != "0" {
-						slog.Debug("malformed " + line)
-					}
-					res[i] = Version{
-						Major: 0,
-						Minor: 0,
-						Patch: 0,
-					}
-					continue
-				}
-				l := len(ver[1])
-				patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
-				minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
-				major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
-				if err1 != nil || err2 != nil || err3 != nil {
-					slog.Debug("malformed int " + line)
-					continue
-				}
-
-				res[i] = Version{
-					Major: uint(major),
-					Minor: uint(minor),
-					Patch: uint(patch),
-				}
-			}
-		}
-	}
-	return res
-}
-
-func (v Version) ToGFXString() string {
-	return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
-}

+ 85 - 76
gpu/amd_windows.go

@@ -7,11 +7,13 @@ import (
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
 	"slices"
 	"slices"
+	"strconv"
 	"strings"
 	"strings"
+
+	"github.com/ollama/ollama/format"
 )
 )
 
 
 const (
 const (
-	RocmStandardLocation = "C:\\Program Files\\AMD\\ROCm\\5.7\\bin" // TODO glob?
 
 
 	// TODO  We're looking for this exact name to detect iGPUs since hipGetDeviceProperties never reports integrated==true
 	// TODO  We're lookinng for this exact name to detect iGPUs since hipGetDeviceProperties never reports integrated==true
 	iGPUName = "AMD Radeon(TM) Graphics"
 	iGPUName = "AMD Radeon(TM) Graphics"
@@ -19,39 +21,36 @@ const (
 
 
 var (
 var (
 	// Used to validate if the given ROCm lib is usable
 	// Used to validate if the given ROCm lib is usable
-	ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
+	ROCmLibGlobs          = []string{"hipblas.dll", "rocblas"}                 // TODO - probably include more coverage of files here...
+	RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
 )
 )
 
 
-func AMDGetGPUInfo(resp *GpuInfo) {
+func AMDGetGPUInfo() []GpuInfo {
+	resp := []GpuInfo{}
 	hl, err := NewHipLib()
 	hl, err := NewHipLib()
 	if err != nil {
 	if err != nil {
 		slog.Debug(err.Error())
 		slog.Debug(err.Error())
-		return
+		return nil
 	}
 	}
 	defer hl.Release()
 	defer hl.Release()
-	skip := map[int]interface{}{}
-	ids := []int{}
-	resp.memInfo.DeviceCount = 0
-	resp.memInfo.TotalMemory = 0
-	resp.memInfo.FreeMemory = 0
 
 
 	ver, err := hl.AMDDriverVersion()
 	ver, err := hl.AMDDriverVersion()
 	if err == nil {
 	if err == nil {
 		slog.Info("AMD Driver: " + ver)
 		slog.Info("AMD Driver: " + ver)
 	} else {
 	} else {
 		// For now this is benign, but we may eventually need to fail compatibility checks
 		// For now this is benign, but we may eventually need to fail compatibility checks
-		slog.Debug(fmt.Sprintf("error looking up amd driver version: %s", err))
+		slog.Debug("error looking up amd driver version", "error", err)
 	}
 	}
 
 
-	// Note: the HIP library automatically handles HIP_VISIBLE_DEVICES
+	// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
 	count := hl.HipGetDeviceCount()
 	count := hl.HipGetDeviceCount()
 	if count == 0 {
 	if count == 0 {
-		return
+		return nil
 	}
 	}
 	libDir, err := AMDValidateLibDir()
 	libDir, err := AMDValidateLibDir()
 	if err != nil {
 	if err != nil {
-		slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
-		return
+		slog.Warn("unable to verify rocm library, will use cpu", "error", err)
+		return nil
 	}
 	}
 
 
 	var supported []string
 	var supported []string
@@ -59,95 +58,120 @@ func AMDGetGPUInfo(resp *GpuInfo) {
 	if gfxOverride == "" {
 	if gfxOverride == "" {
 		supported, err = GetSupportedGFX(libDir)
 		supported, err = GetSupportedGFX(libDir)
 		if err != nil {
 		if err != nil {
-			slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
-			return
+			slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
+			return nil
 		}
 		}
 	} else {
 	} else {
 		slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
 		slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
 	}
 	}
 
 
-	slog.Info(fmt.Sprintf("detected %d hip devices", count))
+	slog.Info("detected hip devices", "count", count)
+	// TODO how to determine the underlying device ID when visible devices is causing this to subset?
 	for i := 0; i < count; i++ {
 	for i := 0; i < count; i++ {
-		ids = append(ids, i)
 		err = hl.HipSetDevice(i)
 		err = hl.HipSetDevice(i)
 		if err != nil {
 		if err != nil {
-			slog.Warn(fmt.Sprintf("[%d] %s", i, err))
-			skip[i] = struct{}{}
+			slog.Warn("set device", "id", i, "error", err)
 			continue
 			continue
 		}
 		}
 
 
 		props, err := hl.HipGetDeviceProperties(i)
 		props, err := hl.HipGetDeviceProperties(i)
 		if err != nil {
 		if err != nil {
-			slog.Warn(fmt.Sprintf("[%d] %s", i, err))
-			skip[i] = struct{}{}
+			slog.Warn("get properties", "id", i, "error", err)
 			continue
 			continue
 		}
 		}
 		n := bytes.IndexByte(props.Name[:], 0)
 		n := bytes.IndexByte(props.Name[:], 0)
 		name := string(props.Name[:n])
 		name := string(props.Name[:n])
-		slog.Info(fmt.Sprintf("[%d] Name: %s", i, name))
+		// TODO is UUID actually populated on windows?
+		// Can luid be used on windows for setting visible devices (and is it actually set?)
 		n = bytes.IndexByte(props.GcnArchName[:], 0)
 		n = bytes.IndexByte(props.GcnArchName[:], 0)
 		gfx := string(props.GcnArchName[:n])
 		gfx := string(props.GcnArchName[:n])
-		slog.Info(fmt.Sprintf("[%d] GcnArchName: %s", i, gfx))
+		slog.Info("hip device", "id", i, "name", name, "gfx", gfx)
+		var major, minor, patch string
+		switch len(gfx) {
+		case 6:
+			major, minor, patch = gfx[3:4], gfx[4:5], gfx[5:]
+		case 7:
+			major, minor, patch = gfx[3:5], gfx[5:6], gfx[6:]
+		}
 		//slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY!  Always 0
 		//slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY!  Always 0
 		// TODO  Why isn't props.iGPU accurate!?
 		// TODO  Why isn't props.iGPU accurate!?
 		if strings.EqualFold(name, iGPUName) {
 		if strings.EqualFold(name, iGPUName) {
-			slog.Info(fmt.Sprintf("iGPU detected [%d] skipping", i))
-			skip[i] = struct{}{}
+			slog.Info("iGPU detected skipping", "id", i)
 			continue
 			continue
 		}
 		}
 		if gfxOverride == "" {
 		if gfxOverride == "" {
 			if !slices.Contains[[]string, string](supported, gfx) {
 			if !slices.Contains[[]string, string](supported, gfx) {
-				slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, gfx, libDir, supported))
+				slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
 				// TODO - consider discrete markdown just for ROCM troubleshooting?
 				// TODO - consider discrete markdown just for ROCM troubleshooting?
 				slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
 				slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
-				skip[i] = struct{}{}
 				continue
 				continue
 			} else {
 			} else {
-				slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, gfx))
+				slog.Info("amdgpu is supported", "gpu", i, "gpu_type", gfx)
 			}
 			}
 		}
 		}
 
 
-		totalMemory, freeMemory, err := hl.HipMemGetInfo()
+		freeMemory, totalMemory, err := hl.HipMemGetInfo()
 		if err != nil {
 		if err != nil {
-			slog.Warn(fmt.Sprintf("[%d] %s", i, err))
+			slog.Warn("get mem info", "id", i, "error", err)
 			continue
 			continue
 		}
 		}
 
 
-		// TODO according to docs, freeMem may lie on windows!
-		slog.Info(fmt.Sprintf("[%d] Total Mem: %d", i, totalMemory))
-		slog.Info(fmt.Sprintf("[%d] Free Mem:  %d", i, freeMemory))
-		resp.memInfo.DeviceCount++
-		resp.memInfo.TotalMemory += totalMemory
-		resp.memInfo.FreeMemory += freeMemory
-	}
-	if resp.memInfo.DeviceCount > 0 {
-		resp.Library = "rocm"
-	}
-	// Abort if all GPUs are skipped
-	if len(skip) >= count {
-		slog.Info("all detected amdgpus are skipped, falling back to CPU")
-		return
-	}
-	if len(skip) > 0 {
-		amdSetVisibleDevices(ids, skip)
+		// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
+		if totalMemory < IGPUMemLimit {
+			slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", i, "total", format.HumanBytes2(totalMemory))
+			continue
+		}
+
+		// TODO revisit this once ROCm v6 is available on windows.
+		// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable
+		slog.Info("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory))
+		slog.Info("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory))
+		gpuInfo := GpuInfo{
+			Library: "rocm",
+			memInfo: memInfo{
+				TotalMemory: totalMemory,
+				FreeMemory:  freeMemory,
+			},
+			ID:             fmt.Sprintf("%d", i), // TODO this is probably wrong if we specify visible devices
+			DependencyPath: libDir,
+			MinimumMemory:  rocmMinimumMemory,
+		}
+		if major != "" {
+			gpuInfo.Major, err = strconv.Atoi(major)
+			if err != nil {
+				slog.Info("failed to parse version", "version", gfx, "error", err)
+			}
+		}
+		if minor != "" {
+			gpuInfo.Minor, err = strconv.Atoi(minor)
+			if err != nil {
+				slog.Info("failed to parse version", "version", gfx, "error", err)
+			}
+		}
+		if patch != "" {
+			// Patch rev is hex; e.g. gfx90a
+			p, err := strconv.ParseInt(patch, 16, 0)
+			if err != nil {
+				slog.Info("failed to parse version", "version", gfx, "error", err)
+			} else {
+				gpuInfo.Patch = int(p)
+			}
+		}
+		if gpuInfo.Major < RocmComputeMin {
+			slog.Warn(fmt.Sprintf("amdgpu [%s] too old gfx%d%d%x", gpuInfo.ID, gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch))
+			continue
+		}
+
+		resp = append(resp, gpuInfo)
 	}
 	}
-	UpdatePath(libDir)
+
+	return resp
 }
 }
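
A small self-contained sketch of the GcnArchName slicing used above, showing how 6- and 7-character gfx names split into major/minor/patch and how the hex patch digit is parsed; the sample names are illustrative.

package main

import (
	"fmt"
	"strconv"
)

// parseGFX mirrors the slicing logic above: "gfx90a" -> 9, 0, 0xa and "gfx1030" -> 10, 3, 0.
func parseGFX(gfx string) (major, minor, patch int, err error) {
	var ma, mi, pa string
	switch len(gfx) {
	case 6:
		ma, mi, pa = gfx[3:4], gfx[4:5], gfx[5:]
	case 7:
		ma, mi, pa = gfx[3:5], gfx[5:6], gfx[6:]
	default:
		return 0, 0, 0, fmt.Errorf("unexpected gfx name %q", gfx)
	}
	if major, err = strconv.Atoi(ma); err != nil {
		return 0, 0, 0, err
	}
	if minor, err = strconv.Atoi(mi); err != nil {
		return 0, 0, 0, err
	}
	p, err := strconv.ParseInt(pa, 16, 0) // patch revision is hex, e.g. the "a" in gfx90a
	if err != nil {
		return 0, 0, 0, err
	}
	return major, minor, int(p), nil
}

func main() {
	for _, name := range []string{"gfx90a", "gfx1030"} {
		ma, mi, pa, err := parseGFX(name)
		fmt.Println(name, ma, mi, pa, err)
	}
}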
 
 
 func AMDValidateLibDir() (string, error) {
 func AMDValidateLibDir() (string, error) {
-	// On windows non-admins typically can't create links
-	// so instead of trying to rely on rpath and a link in
-	// $LibDir/rocm, we instead rely on setting PATH to point
-	// to the location of the ROCm library
-
-	// Installer payload location if we're running the installed binary
-	exe, err := os.Executable()
+	libDir, err := commonAMDValidateLibDir()
 	if err == nil {
 	if err == nil {
-		rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
-		if rocmLibUsable(rocmTargetDir) {
-			slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
-			return rocmTargetDir, nil
-		}
+		return libDir, nil
 	}
 	}
 
 
 	// Installer payload (if we're running from some other location)
 	// Installer payload (if we're running from some other location)
@@ -159,21 +183,6 @@ func AMDValidateLibDir() (string, error) {
 		return rocmTargetDir, nil
 		return rocmTargetDir, nil
 	}
 	}
 
 
-	// Prefer explicit HIP env var
-	hipPath := os.Getenv("HIP_PATH")
-	if hipPath != "" {
-		hipLibDir := filepath.Join(hipPath, "bin")
-		if rocmLibUsable(hipLibDir) {
-			slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
-			return hipLibDir, nil
-		}
-	}
-
-	// Well known location(s)
-	if rocmLibUsable(RocmStandardLocation) {
-		return RocmStandardLocation, nil
-	}
-
 	// Should not happen on windows since we include it in the installer, but stand-alone binary might hit this
 	// Should not happen on windows since we include it in the installer, but stand-alone binary might hit this
 	slog.Warn("amdgpu detected, but no compatible rocm library found.  Please install ROCm")
 	slog.Warn("amdgpu detected, but no compatible rocm library found.  Please install ROCm")
 	return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
 	return "", fmt.Errorf("no suitable rocm found, falling back to CPU")

+ 15 - 4
gpu/assets.go

@@ -12,6 +12,8 @@ import (
 	"sync"
 	"sync"
 	"syscall"
 	"syscall"
 	"time"
 	"time"
+
+	"github.com/ollama/ollama/server/envconfig"
 )
 )
 
 
 var (
 var (
@@ -24,8 +26,16 @@ func PayloadsDir() (string, error) {
 	defer lock.Unlock()
 	defer lock.Unlock()
 	var err error
 	var err error
 	if payloadsDir == "" {
 	if payloadsDir == "" {
+		runnersDir := envconfig.RunnersDir
+
+		if runnersDir != "" {
+			payloadsDir = runnersDir
+			return payloadsDir, nil
+		}
+
+		// The remainder only applies on non-windows where we still carry payloads in the main executable
 		cleanupTmpDirs()
 		cleanupTmpDirs()
-		tmpDir := os.Getenv("OLLAMA_TMPDIR")
+		tmpDir := envconfig.TmpDir
 		if tmpDir == "" {
 		if tmpDir == "" {
 			tmpDir, err = os.MkdirTemp("", "ollama")
 			tmpDir, err = os.MkdirTemp("", "ollama")
 			if err != nil {
 			if err != nil {
@@ -80,7 +90,7 @@ func cleanupTmpDirs() {
 		}
 		}
 		err = os.RemoveAll(d)
 		err = os.RemoveAll(d)
 		if err != nil {
 		if err != nil {
-			slog.Debug(fmt.Sprintf("unable to cleanup stale tmpdir %s: %s", d, err))
+			slog.Debug("unable to cleanup stale tmpdir", "path", d, "error", err)
 		}
 		}
 	}
 	}
 }
 }
@@ -88,7 +98,8 @@ func cleanupTmpDirs() {
 func Cleanup() {
 func Cleanup() {
 	lock.Lock()
 	lock.Lock()
 	defer lock.Unlock()
 	defer lock.Unlock()
-	if payloadsDir != "" {
+	runnersDir := envconfig.RunnersDir
+	if payloadsDir != "" && runnersDir == "" && runtime.GOOS != "windows" {
 		// We want to fully clean up the tmpdir parent of the payloads dir
 		// We want to fully clean up the tmpdir parent of the payloads dir
 		tmpDir := filepath.Clean(filepath.Join(payloadsDir, ".."))
 		tmpDir := filepath.Clean(filepath.Join(payloadsDir, ".."))
 		slog.Debug("cleaning up", "dir", tmpDir)
 		slog.Debug("cleaning up", "dir", tmpDir)
@@ -120,7 +131,7 @@ func UpdatePath(dir string) {
 			}
 			}
 		}
 		}
 		newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
 		newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
-		slog.Info(fmt.Sprintf("Updating PATH to %s", newPath))
+		slog.Info("updating", "PATH", newPath)
 		os.Setenv("PATH", newPath)
 		os.Setenv("PATH", newPath)
 	}
 	}
 	// linux and darwin rely on rpath
 	// linux and darwin rely on rpath
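
A condensed sketch of the payload-directory precedence above, assuming only what the hunk shows: a configured runners directory wins outright, otherwise a scratch directory is created. The helper name is hypothetical.

package main

import (
	"fmt"
	"os"
)

// resolvePayloadsDir mirrors the precedence above: a configured runners
// directory (envconfig.RunnersDir in the real code) short-circuits everything
// else; with no configured tmp dir, a fresh scratch dir is created.
func resolvePayloadsDir(runnersDir, tmpDirSetting string) (string, error) {
	if runnersDir != "" {
		return runnersDir, nil
	}
	if tmpDirSetting == "" {
		return os.MkdirTemp("", "ollama")
	}
	// Assumption: the remainder of the real function (not shown in this hunk)
	// prepares the payloads location under the configured tmp dir.
	return tmpDirSetting, nil
}

func main() {
	dir, err := resolvePayloadsDir("", "")
	fmt.Println("payloads dir:", dir, err)
}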

+ 22 - 0
gpu/cuda_common.go

@@ -0,0 +1,22 @@
+//go:build linux || windows
+
+package gpu
+
+import (
+	"log/slog"
+	"strings"
+)
+
+func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
+	ids := []string{}
+	for _, info := range gpuInfo {
+		if info.Library != "cuda" {
+			// TODO shouldn't happen if things are wired correctly...
+			slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
+			continue
+		}
+		ids = append(ids, info.ID)
+	}
+	return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
+
+}
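
Illustrative usage of the key/value pair this helper returns when scoping a runner subprocess to selected GPUs; the reduced gpuInfo struct and the child process here are hypothetical stand-ins.

package main

import (
	"fmt"
	"os"
	"os/exec"
	"strings"
)

// Reduced stand-in for gpu.GpuInfo: only the fields this helper consults.
type gpuInfo struct {
	Library string
	ID      string
}

// Same shape as cudaGetVisibleDevicesEnv above: non-cuda entries are skipped
// and the remaining IDs are joined for CUDA_VISIBLE_DEVICES.
func cudaVisibleDevicesEnv(gpus []gpuInfo) (string, string) {
	var ids []string
	for _, g := range gpus {
		if g.Library != "cuda" {
			continue
		}
		ids = append(ids, g.ID)
	}
	return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
}

func main() {
	gpus := []gpuInfo{
		{Library: "cuda", ID: "GPU-d110a105-ac29-1d54-7b49-9c90440f215b"},
		{Library: "rocm", ID: "0"}, // filtered out
	}
	key, val := cudaVisibleDevicesEnv(gpus)

	// A runner launched with this environment only sees the selected devices.
	cmd := exec.Command("./runner") // hypothetical child process; not started in this sketch
	cmd.Env = append(os.Environ(), fmt.Sprintf("%s=%s", key, val))
	fmt.Println(key + "=" + val)
}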

+ 154 - 142
gpu/gpu.go

@@ -16,22 +16,23 @@ import (
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
 	"runtime"
 	"runtime"
-	"strconv"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 	"unsafe"
 	"unsafe"
 
 
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/format"
+	"github.com/ollama/ollama/server/envconfig"
 )
 )
 
 
 type handles struct {
 type handles struct {
-	nvml   *C.nvml_handle_t
-	cudart *C.cudart_handle_t
+	deviceCount int
+	cudart      *C.cudart_handle_t
+	nvcuda      *C.nvcuda_handle_t
 }
 }
 
 
 const (
 const (
-	cudaMinimumMemory = 457 * format.MebiByte
-	rocmMinimumMemory = 457 * format.MebiByte
+	cudaMinimumMemory = 256 * format.MebiByte
+	rocmMinimumMemory = 256 * format.MebiByte
 )
 )
 
 
 var gpuMutex sync.Mutex
 var gpuMutex sync.Mutex
@@ -39,26 +40,10 @@ var gpuMutex sync.Mutex
 // With our current CUDA compile flags, older than 5.0 will not work properly
 // With our current CUDA compile flags, older than 5.0 will not work properly
 var CudaComputeMin = [2]C.int{5, 0}
 var CudaComputeMin = [2]C.int{5, 0}
 
 
-// Possible locations for the nvidia-ml library
-var NvmlLinuxGlobs = []string{
-	"/usr/local/cuda/lib64/libnvidia-ml.so*",
-	"/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*",
-	"/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*",
-	"/usr/lib/wsl/lib/libnvidia-ml.so*",
-	"/usr/lib/wsl/drivers/*/libnvidia-ml.so*",
-	"/opt/cuda/lib64/libnvidia-ml.so*",
-	"/usr/lib*/libnvidia-ml.so*",
-	"/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*",
-	"/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*",
-	"/usr/local/lib*/libnvidia-ml.so*",
-
-	// TODO: are these stubs ever valid?
-	"/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*",
-}
+var RocmComputeMin = 9
 
 
-var NvmlWindowsGlobs = []string{
-	"c:\\Windows\\System32\\nvml.dll",
-}
+// TODO find a better way to detect iGPU instead of minimum memory
+const IGPUMemLimit = 1 * format.GibiByte // 512M is what they typically report, so anything less than 1G must be iGPU
 
 
 var CudartLinuxGlobs = []string{
 var CudartLinuxGlobs = []string{
 	"/usr/local/cuda/lib64/libcudart.so*",
 	"/usr/local/cuda/lib64/libcudart.so*",
@@ -79,6 +64,22 @@ var CudartWindowsGlobs = []string{
 	"c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll",
 	"c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll",
 }
 }
 
 
+var NvcudaLinuxGlobs = []string{
+	"/usr/local/cuda*/targets/*/lib/libcuda.so*",
+	"/usr/lib/*-linux-gnu/nvidia/current/libcuda.so*",
+	"/usr/lib/*-linux-gnu/libcuda.so*",
+	"/usr/lib/wsl/lib/libcuda.so*",
+	"/usr/lib/wsl/drivers/*/libcuda.so*",
+	"/opt/cuda/lib*/libcuda.so*",
+	"/usr/local/cuda/lib*/libcuda.so*",
+	"/usr/lib*/libcuda.so*",
+	"/usr/local/lib*/libcuda.so*",
+}
+
+var NvcudaWindowsGlobs = []string{
+	"c:\\windows\\system*\\nvcuda.dll",
+}
+
 // Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed.
 // Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed.
 // Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
 // Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
 var CudaTegra string = os.Getenv("JETSON_JETPACK")
 var CudaTegra string = os.Getenv("JETSON_JETPACK")
@@ -88,61 +89,62 @@ func initGPUHandles() *handles {
 
 
 	// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
 	// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
 
 
-	gpuHandles := &handles{nil, nil}
-	var nvmlMgmtName string
-	var nvmlMgmtPatterns []string
+	gpuHandles := &handles{}
 	var cudartMgmtName string
 	var cudartMgmtName string
 	var cudartMgmtPatterns []string
 	var cudartMgmtPatterns []string
+	var nvcudaMgmtName string
+	var nvcudaMgmtPatterns []string
 
 
 	tmpDir, _ := PayloadsDir()
 	tmpDir, _ := PayloadsDir()
 	switch runtime.GOOS {
 	switch runtime.GOOS {
 	case "windows":
 	case "windows":
-		nvmlMgmtName = "nvml.dll"
-		nvmlMgmtPatterns = make([]string, len(NvmlWindowsGlobs))
-		copy(nvmlMgmtPatterns, NvmlWindowsGlobs)
 		cudartMgmtName = "cudart64_*.dll"
 		cudartMgmtName = "cudart64_*.dll"
 		localAppData := os.Getenv("LOCALAPPDATA")
 		localAppData := os.Getenv("LOCALAPPDATA")
 		cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)}
 		cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)}
 		cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...)
 		cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...)
+		// Aligned with driver, we can't carry as payloads
+		nvcudaMgmtName = "nvcuda.dll"
+		nvcudaMgmtPatterns = NvcudaWindowsGlobs
 	case "linux":
 	case "linux":
-		nvmlMgmtName = "libnvidia-ml.so"
-		nvmlMgmtPatterns = make([]string, len(NvmlLinuxGlobs))
-		copy(nvmlMgmtPatterns, NvmlLinuxGlobs)
 		cudartMgmtName = "libcudart.so*"
 		cudartMgmtName = "libcudart.so*"
 		if tmpDir != "" {
 		if tmpDir != "" {
 			// TODO - add "payloads" for subprocess
 			// TODO - add "payloads" for subprocess
 			cudartMgmtPatterns = []string{filepath.Join(tmpDir, "cuda*", cudartMgmtName)}
 			cudartMgmtPatterns = []string{filepath.Join(tmpDir, "cuda*", cudartMgmtName)}
 		}
 		}
 		cudartMgmtPatterns = append(cudartMgmtPatterns, CudartLinuxGlobs...)
 		cudartMgmtPatterns = append(cudartMgmtPatterns, CudartLinuxGlobs...)
+		// Aligned with driver, we can't carry as payloads
+		nvcudaMgmtName = "libcuda.so*"
+		nvcudaMgmtPatterns = NvcudaLinuxGlobs
 	default:
 	default:
 		return gpuHandles
 		return gpuHandles
 	}
 	}
 
 
-	slog.Info("Detecting GPU type")
-	cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns)
-	if len(cudartLibPaths) > 0 {
-		cudart := LoadCUDARTMgmt(cudartLibPaths)
-		if cudart != nil {
-			slog.Info("Nvidia GPU detected via cudart")
-			gpuHandles.cudart = cudart
+	slog.Info("Detecting GPUs")
+	nvcudaLibPaths := FindGPULibs(nvcudaMgmtName, nvcudaMgmtPatterns)
+	if len(nvcudaLibPaths) > 0 {
+		deviceCount, nvcuda, libPath := LoadNVCUDAMgmt(nvcudaLibPaths)
+		if nvcuda != nil {
+			slog.Info("detected GPUs", "count", deviceCount, "library", libPath)
+			gpuHandles.nvcuda = nvcuda
+			gpuHandles.deviceCount = deviceCount
 			return gpuHandles
 			return gpuHandles
 		}
 		}
 	}
 	}
 
 
-	// TODO once we build confidence, remove this and the gpu_info_nvml.[ch] files
-	nvmlLibPaths := FindGPULibs(nvmlMgmtName, nvmlMgmtPatterns)
-	if len(nvmlLibPaths) > 0 {
-		nvml := LoadNVMLMgmt(nvmlLibPaths)
-		if nvml != nil {
-			slog.Info("Nvidia GPU detected via nvidia-ml")
-			gpuHandles.nvml = nvml
+	cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns)
+	if len(cudartLibPaths) > 0 {
+		deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths)
+		if cudart != nil {
+			slog.Info("detected GPUs", "library", libPath, "count", deviceCount)
+			gpuHandles.cudart = cudart
+			gpuHandles.deviceCount = deviceCount
 			return gpuHandles
 			return gpuHandles
 		}
 		}
 	}
 	}
 	return gpuHandles
 	return gpuHandles
 }
 }
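
A rough sketch of the loader preference the function above implements: nvcuda (the driver API) is tried before cudart, and the first library that loads wins. The loader closures are hypothetical.

package main

import (
	"errors"
	"fmt"
)

// loader is a stand-in for LoadNVCUDAMgmt / LoadCUDARTMgmt: each returns a
// device count, an opaque handle, and the library path that worked.
type loader func() (int, any, string, error)

// firstWorkingLoader mirrors the preference order above: try each loader in
// turn and keep the first one that produces a usable handle.
func firstWorkingLoader(loaders []loader) (int, any, string, error) {
	for _, load := range loaders {
		if count, handle, path, err := load(); err == nil && handle != nil {
			return count, handle, path, nil
		}
	}
	return 0, nil, "", errors.New("no usable CUDA library found")
}

func main() {
	nvcuda := func() (int, any, string, error) { return 0, nil, "", errors.New("libcuda.so not found") }
	cudart := func() (int, any, string, error) { return 1, struct{}{}, "/usr/local/cuda/lib64/libcudart.so.12", nil }
	count, _, path, err := firstWorkingLoader([]loader{nvcuda, cudart})
	fmt.Println(count, path, err)
}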
 
 
-func GetGPUInfo() GpuInfo {
+func GetGPUInfo() GpuInfoList {
 	// TODO - consider exploring lspci (and equivalent on windows) to check for
 	// TODO - consider exploring lspci (and equivalent on windows) to check for
 	// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
 	// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
 	gpuMutex.Lock()
 	gpuMutex.Lock()
@@ -150,12 +152,12 @@ func GetGPUInfo() GpuInfo {
 
 
 	gpuHandles := initGPUHandles()
 	gpuHandles := initGPUHandles()
 	defer func() {
 	defer func() {
-		if gpuHandles.nvml != nil {
-			C.nvml_release(*gpuHandles.nvml)
-		}
 		if gpuHandles.cudart != nil {
 		if gpuHandles.cudart != nil {
 			C.cudart_release(*gpuHandles.cudart)
 			C.cudart_release(*gpuHandles.cudart)
 		}
 		}
+		if gpuHandles.nvcuda != nil {
+			C.nvcuda_release(*gpuHandles.nvcuda)
+		}
 	}()
 	}()
 
 
 	// All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX
 	// All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX
@@ -164,73 +166,75 @@ func GetGPUInfo() GpuInfo {
 		slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
 		slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
 	}
 	}
 
 
+	// On windows we bundle the nvidia library one level above the runner dir
+	depPath := ""
+	if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
+		depPath = filepath.Dir(envconfig.RunnersDir)
+	}
+
 	var memInfo C.mem_info_t
 	var memInfo C.mem_info_t
-	resp := GpuInfo{}
-	if gpuHandles.nvml != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
-		C.nvml_check_vram(*gpuHandles.nvml, &memInfo)
-		if memInfo.err != nil {
-			slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU memory: %s", C.GoString(memInfo.err)))
-			C.free(unsafe.Pointer(memInfo.err))
-		} else if memInfo.count > 0 {
-			// Verify minimum compute capability
-			var cc C.nvml_compute_capability_t
-			C.nvml_compute_capability(*gpuHandles.nvml, &cc)
-			if cc.err != nil {
-				slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU compute capability: %s", C.GoString(cc.err)))
-				C.free(unsafe.Pointer(cc.err))
-			} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
-				slog.Info(fmt.Sprintf("[nvidia-ml] NVML CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
-				resp.Library = "cuda"
-				resp.MinimumMemory = cudaMinimumMemory
-			} else {
-				slog.Info(fmt.Sprintf("[nvidia-ml] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
-			}
+	resp := []GpuInfo{}
+
+	// NVIDIA first
+	for i := 0; i < gpuHandles.deviceCount; i++ {
+		// TODO once we support CPU compilation variants of GPU libraries refine this...
+		if cpuVariant == "" && runtime.GOARCH == "amd64" {
+			continue
+		}
+		gpuInfo := GpuInfo{
+			Library: "cuda",
+		}
+		if gpuHandles.cudart != nil {
+			C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo)
+		} else {
+			C.nvcuda_check_vram(*gpuHandles.nvcuda, C.int(i), &memInfo)
 		}
 		}
-	} else if gpuHandles.cudart != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
-		C.cudart_check_vram(*gpuHandles.cudart, &memInfo)
 		if memInfo.err != nil {
 		if memInfo.err != nil {
-			slog.Info(fmt.Sprintf("[cudart] error looking up CUDART GPU memory: %s", C.GoString(memInfo.err)))
+			slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
 			C.free(unsafe.Pointer(memInfo.err))
 			C.free(unsafe.Pointer(memInfo.err))
-		} else if memInfo.count > 0 {
-			// Verify minimum compute capability
-			var cc C.cudart_compute_capability_t
-			C.cudart_compute_capability(*gpuHandles.cudart, &cc)
-			if cc.err != nil {
-				slog.Info(fmt.Sprintf("[cudart] error looking up CUDA compute capability: %s", C.GoString(cc.err)))
-				C.free(unsafe.Pointer(cc.err))
-			} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
-				slog.Info(fmt.Sprintf("[cudart] CUDART CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
-				resp.Library = "cuda"
-				resp.MinimumMemory = cudaMinimumMemory
-			} else {
-				slog.Info(fmt.Sprintf("[cudart] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
-			}
+			continue
 		}
 		}
-	} else {
-		AMDGetGPUInfo(&resp)
-		if resp.Library != "" {
-			resp.MinimumMemory = rocmMinimumMemory
-			return resp
+		if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
+			slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
+			continue
 		}
 		}
+		gpuInfo.TotalMemory = uint64(memInfo.total)
+		gpuInfo.FreeMemory = uint64(memInfo.free)
+		gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
+		gpuInfo.Major = int(memInfo.major)
+		gpuInfo.Minor = int(memInfo.minor)
+		gpuInfo.MinimumMemory = cudaMinimumMemory
+		gpuInfo.DependencyPath = depPath
+
+		// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
+		resp = append(resp, gpuInfo)
 	}
 	}
-	if resp.Library == "" {
+
+	// Then AMD
+	resp = append(resp, AMDGetGPUInfo()...)
+
+	if len(resp) == 0 {
 		C.cpu_check_ram(&memInfo)
 		C.cpu_check_ram(&memInfo)
-		resp.Library = "cpu"
-		resp.Variant = cpuVariant
-	}
-	if memInfo.err != nil {
-		slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
-		C.free(unsafe.Pointer(memInfo.err))
-		return resp
+		if memInfo.err != nil {
+			slog.Info("error looking up CPU memory", "error", C.GoString(memInfo.err))
+			C.free(unsafe.Pointer(memInfo.err))
+			return resp
+		}
+		gpuInfo := GpuInfo{
+			Library: "cpu",
+			Variant: cpuVariant,
+		}
+		gpuInfo.TotalMemory = uint64(memInfo.total)
+		gpuInfo.FreeMemory = uint64(memInfo.free)
+		gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
+
+		resp = append(resp, gpuInfo)
 	}
 	}
 
 
-	resp.DeviceCount = uint32(memInfo.count)
-	resp.FreeMemory = uint64(memInfo.free)
-	resp.TotalMemory = uint64(memInfo.total)
 	return resp
 	return resp
 }
 }
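
As a rough illustration of consuming the new list-shaped result, a sketch that sums usable VRAM while skipping the CPU fallback entry; the field names follow GpuInfo above, and the arithmetic is only an example.

package main

import "fmt"

// Reduced stand-in for gpu.GpuInfo with the fields used below.
type gpuInfo struct {
	Library       string
	FreeMemory    uint64
	MinimumMemory uint64
}

func main() {
	// Hypothetical discovery result: two CUDA devices, as GetGPUInfo might now return.
	gpus := []gpuInfo{
		{Library: "cuda", FreeMemory: 8 << 30, MinimumMemory: 256 << 20},
		{Library: "cuda", FreeMemory: 4 << 30, MinimumMemory: 256 << 20},
	}

	// Sum the VRAM actually usable for model layers, reserving each GPU's minimum overhead.
	var usable uint64
	for _, g := range gpus {
		if g.Library == "cpu" {
			continue // the CPU fallback entry carries system RAM, not VRAM
		}
		if g.FreeMemory > g.MinimumMemory {
			usable += g.FreeMemory - g.MinimumMemory
		}
	}
	fmt.Printf("usable VRAM across %d GPUs: %d bytes\n", len(gpus), usable)
}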
 
 
-func getCPUMem() (memInfo, error) {
+func GetCPUMem() (memInfo, error) {
 	var ret memInfo
 	var ret memInfo
 	var info C.mem_info_t
 	var info C.mem_info_t
 	C.cpu_check_ram(&info)
 	C.cpu_check_ram(&info)
@@ -243,29 +247,12 @@ func getCPUMem() (memInfo, error) {
 	return ret, nil
 	return ret, nil
 }
 }
 
 
-func CheckVRAM() (uint64, error) {
-	userLimit := os.Getenv("OLLAMA_MAX_VRAM")
-	if userLimit != "" {
-		avail, err := strconv.ParseInt(userLimit, 10, 64)
-		if err != nil {
-			return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
-		}
-		slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
-		return uint64(avail), nil
-	}
-	gpuInfo := GetGPUInfo()
-	if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
-		return gpuInfo.FreeMemory, nil
-	}
-
-	return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
-}
-
-func FindGPULibs(baseLibName string, patterns []string) []string {
+func FindGPULibs(baseLibName string, defaultPatterns []string) []string {
 	// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
 	// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
 	var ldPaths []string
 	var ldPaths []string
+	var patterns []string
 	gpuLibPaths := []string{}
 	gpuLibPaths := []string{}
-	slog.Info(fmt.Sprintf("Searching for GPU management library %s", baseLibName))
+	slog.Debug("Searching for GPU library", "name", baseLibName)
 
 
 	switch runtime.GOOS {
 	switch runtime.GOOS {
 	case "windows":
 	case "windows":
@@ -283,8 +270,14 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
 		}
 		}
 		patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
 		patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
 	}
 	}
-	slog.Debug(fmt.Sprintf("gpu management search paths: %v", patterns))
+	patterns = append(patterns, defaultPatterns...)
+	slog.Debug("gpu library search", "globs", patterns)
 	for _, pattern := range patterns {
 	for _, pattern := range patterns {
+
+		// Nvidia PhysX known to return bogus results
+		if strings.Contains(pattern, "PhysX") {
+			slog.Debug("skipping PhysX cuda library path", "path", pattern)
+		}
 		// Ignore glob discovery errors
 		// Ignore glob discovery errors
 		matches, _ := filepath.Glob(pattern)
 		matches, _ := filepath.Glob(pattern)
 		for _, match := range matches {
 		for _, match := range matches {
@@ -311,47 +304,66 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
 			}
 			}
 		}
 		}
 	}
 	}
-	slog.Info(fmt.Sprintf("Discovered GPU libraries: %v", gpuLibPaths))
+	slog.Debug("discovered GPU libraries", "paths", gpuLibPaths)
 	return gpuLibPaths
 	return gpuLibPaths
 }
 }
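
A standalone sketch of the pattern-based search above: directories from the loader path (LD_LIBRARY_PATH on Linux) are globbed for the base library name ahead of the default globs. The paths shown are examples only.

package main

import (
	"fmt"
	"os"
	"path/filepath"
	"strings"
)

func main() {
	baseLibName := "libcuda.so" // as with nvcudaMgmtName above (which carries a trailing glob)
	defaultPatterns := []string{"/usr/lib/wsl/lib/libcuda.so*", "/usr/lib*/libcuda.so*"}

	// Loader-path directories are searched first, then the default patterns.
	var patterns []string
	for _, d := range strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":") {
		if d == "" {
			continue
		}
		if abs, err := filepath.Abs(d); err == nil {
			patterns = append(patterns, filepath.Join(abs, baseLibName+"*"))
		}
	}
	patterns = append(patterns, defaultPatterns...)

	for _, pattern := range patterns {
		matches, _ := filepath.Glob(pattern) // glob errors are ignored, as in FindGPULibs
		for _, m := range matches {
			fmt.Println("candidate GPU library:", m)
		}
	}
}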
 
 
-func LoadNVMLMgmt(nvmlLibPaths []string) *C.nvml_handle_t {
-	var resp C.nvml_init_resp_t
+func LoadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string) {
+	var resp C.cudart_init_resp_t
 	resp.ch.verbose = getVerboseState()
 	resp.ch.verbose = getVerboseState()
-	for _, libPath := range nvmlLibPaths {
+	for _, libPath := range cudartLibPaths {
 		lib := C.CString(libPath)
 		lib := C.CString(libPath)
 		defer C.free(unsafe.Pointer(lib))
 		defer C.free(unsafe.Pointer(lib))
-		C.nvml_init(lib, &resp)
+		C.cudart_init(lib, &resp)
 		if resp.err != nil {
 		if resp.err != nil {
-			slog.Info(fmt.Sprintf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err)))
+			slog.Debug("Unable to load cudart", "library", libPath, "error", C.GoString(resp.err))
 			C.free(unsafe.Pointer(resp.err))
 			C.free(unsafe.Pointer(resp.err))
 		} else {
 		} else {
-			return &resp.ch
+			return int(resp.num_devices), &resp.ch, libPath
 		}
 		}
 	}
 	}
-	return nil
+	return 0, nil, ""
 }
 }
 
 
-func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t {
-	var resp C.cudart_init_resp_t
+func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) {
+	var resp C.nvcuda_init_resp_t
 	resp.ch.verbose = getVerboseState()
 	resp.ch.verbose = getVerboseState()
-	for _, libPath := range cudartLibPaths {
+	for _, libPath := range nvcudaLibPaths {
 		lib := C.CString(libPath)
 		lib := C.CString(libPath)
 		defer C.free(unsafe.Pointer(lib))
 		defer C.free(unsafe.Pointer(lib))
-		C.cudart_init(lib, &resp)
+		C.nvcuda_init(lib, &resp)
 		if resp.err != nil {
 		if resp.err != nil {
-			slog.Info(fmt.Sprintf("Unable to load cudart CUDA management library %s: %s", libPath, C.GoString(resp.err)))
+			slog.Debug("Unable to load nvcuda", "library", libPath, "error", C.GoString(resp.err))
 			C.free(unsafe.Pointer(resp.err))
 			C.free(unsafe.Pointer(resp.err))
 		} else {
 		} else {
-			return &resp.ch
+			return int(resp.num_devices), &resp.ch, libPath
 		}
 		}
 	}
 	}
-	return nil
+	return 0, nil, ""
 }
 }
 
 
 func getVerboseState() C.uint16_t {
 func getVerboseState() C.uint16_t {
-	if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
+	if envconfig.Debug {
 		return C.uint16_t(1)
 		return C.uint16_t(1)
 	}
 	}
 	return C.uint16_t(0)
 	return C.uint16_t(0)
 }
 }
+
+// Given the list of GPUs this instantiation is targeted for,
+// figure out the visible devices environment variable
+//
+// If different libraries are detected, the first one is what we use
+func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
+	if len(l) == 0 {
+		return "", ""
+	}
+	switch l[0].Library {
+	case "cuda":
+		return cudaGetVisibleDevicesEnv(l)
+	case "rocm":
+		return rocmGetVisibleDevicesEnv(l)
+	default:
+		slog.Debug("no filter required for library " + l[0].Library)
+		return "", ""
+	}
+}

+ 28 - 33
gpu/gpu_darwin.go

@@ -9,52 +9,47 @@ package gpu
 */
 */
 import "C"
 import "C"
 import (
 import (
-	"fmt"
-	"log/slog"
-	"os"
 	"runtime"
 	"runtime"
-	"strconv"
-)
-
-// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
-func CheckVRAM() (uint64, error) {
-	userLimit := os.Getenv("OLLAMA_MAX_VRAM")
-	if userLimit != "" {
-		avail, err := strconv.ParseInt(userLimit, 10, 64)
-		if err != nil {
-			return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
-		}
-		slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
-		return uint64(avail), nil
-	}
 
 
-	if runtime.GOARCH == "amd64" {
-		// gpu not supported, this may not be metal
-		return 0, nil
-	}
+	"github.com/ollama/ollama/format"
+)
 
 
-	return uint64(C.getRecommendedMaxVRAM()), nil
-}
+const (
+	metalMinimumMemory = 384 * format.MebiByte
+)
 
 
-func GetGPUInfo() GpuInfo {
-	mem, _ := getCPUMem()
+func GetGPUInfo() GpuInfoList {
+	mem, _ := GetCPUMem()
 	if runtime.GOARCH == "amd64" {
 	if runtime.GOARCH == "amd64" {
-		return GpuInfo{
-			Library: "cpu",
-			Variant: GetCPUVariant(),
-			memInfo: mem,
+		return []GpuInfo{
+			{
+				Library: "cpu",
+				Variant: GetCPUVariant(),
+				memInfo: mem,
+			},
 		}
 		}
 	}
 	}
-	return GpuInfo{
+	info := GpuInfo{
 		Library: "metal",
 		Library: "metal",
-		memInfo: mem,
+		ID:      "0",
 	}
 	}
+	info.TotalMemory = uint64(C.getRecommendedMaxVRAM())
+
+	// TODO is there a way to gather actual allocated video memory? (currentAllocatedSize doesn't work)
+	info.FreeMemory = info.TotalMemory
+
+	info.MinimumMemory = metalMinimumMemory
+	return []GpuInfo{info}
 }
 }
 
 
-func getCPUMem() (memInfo, error) {
+func GetCPUMem() (memInfo, error) {
 	return memInfo{
 	return memInfo{
 		TotalMemory: uint64(C.getPhysicalMemory()),
 		TotalMemory: uint64(C.getPhysicalMemory()),
 		FreeMemory:  0,
 		FreeMemory:  0,
-		DeviceCount: 1,
 	}, nil
 	}, nil
 }
 }
+
+func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
+	// No-op on darwin
+	return "", ""
+}

+ 9 - 4
gpu/gpu_info.h

@@ -38,12 +38,17 @@
 extern "C" {
 extern "C" {
 #endif
 #endif
 
 
+#define GPU_ID_LEN 64
+
 typedef struct mem_info {
 typedef struct mem_info {
+  char *err;  // If non-null, caller responsible for freeing
+  char gpu_id[GPU_ID_LEN];
   uint64_t total;
   uint64_t total;
   uint64_t free;
   uint64_t free;
-  unsigned int count;
-  int igpu_index; // If >= 0, we detected an integrated GPU to ignore
-  char *err;  // If non-nill, caller responsible for freeing
+
+  // Compute Capability
+  int major; 
+  int minor;
 } mem_info_t;
 } mem_info_t;
 
 
 void cpu_check_ram(mem_info_t *resp);
 void cpu_check_ram(mem_info_t *resp);
@@ -52,8 +57,8 @@ void cpu_check_ram(mem_info_t *resp);
 }
 }
 #endif
 #endif
 
 
-#include "gpu_info_nvml.h"
 #include "gpu_info_cudart.h"
 #include "gpu_info_cudart.h"
+#include "gpu_info_nvcuda.h"
 
 
 #endif  // __GPU_INFO_H__
 #endif  // __GPU_INFO_H__
 #endif  // __APPLE__
 #endif  // __APPLE__

+ 6 - 2
gpu/gpu_info_cpu.c

@@ -8,9 +8,11 @@ void cpu_check_ram(mem_info_t *resp) {
   MEMORYSTATUSEX info;
   MEMORYSTATUSEX info;
   info.dwLength = sizeof(info);
   info.dwLength = sizeof(info);
   if (GlobalMemoryStatusEx(&info) != 0) {
   if (GlobalMemoryStatusEx(&info) != 0) {
-    resp->count = 1;
     resp->total = info.ullTotalPhys;
     resp->total = info.ullTotalPhys;
     resp->free = info.ullAvailPhys;
     resp->free = info.ullAvailPhys;
+    resp->major = 0;
+    resp->minor = 0;
+    snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
   } else {
   } else {
     resp->err = LOAD_ERR();
     resp->err = LOAD_ERR();
   }
   }
@@ -27,9 +29,11 @@ void cpu_check_ram(mem_info_t *resp) {
   if (sysinfo(&info) != 0) {
   if (sysinfo(&info) != 0) {
     resp->err = strdup(strerror(errno));
     resp->err = strdup(strerror(errno));
   } else {
   } else {
-    resp->count = 1;
     resp->total = info.totalram * info.mem_unit;
     resp->total = info.totalram * info.mem_unit;
     resp->free = info.freeram * info.mem_unit;
     resp->free = info.freeram * info.mem_unit;
+    resp->major = 0;
+    resp->minor = 0;
+    snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
   }
   }
   return;
   return;
 }
 }

+ 63 - 82
gpu/gpu_info_cudart.c

@@ -6,6 +6,7 @@
 void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
 void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
   cudartReturn_t ret;
   cudartReturn_t ret;
   resp->err = NULL;
   resp->err = NULL;
+  resp->num_devices = 0;
   const int buflen = 256;
   const int buflen = 256;
   char buf[buflen + 1];
   char buf[buflen + 1];
   int i;
   int i;
@@ -21,6 +22,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
       {"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount},
       {"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount},
       {"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute},
       {"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute},
       {"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion},
       {"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion},
+      {"cudaGetDeviceProperties", (void *)&resp->ch.cudaGetDeviceProperties},
       {NULL, NULL},
       {NULL, NULL},
   };
   };
 
 
@@ -36,13 +38,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
     return;
     return;
   }
   }
 
 
-  // TODO once we've squashed the remaining corner cases remove this log
-  LOG(resp->ch.verbose, "wiring cudart library functions in %s\n", cudart_lib_path);
-  
   for (i = 0; l[i].s != NULL; i++) {
   for (i = 0; l[i].s != NULL; i++) {
-    // TODO once we've squashed the remaining corner cases remove this log
-    LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
-
     *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
     *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
     if (!l[i].p) {
     if (!l[i].p) {
       char *msg = LOAD_ERR();
       char *msg = LOAD_ERR();
@@ -63,7 +59,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
     UNLOAD_LIBRARY(resp->ch.handle);
     UNLOAD_LIBRARY(resp->ch.handle);
     resp->ch.handle = NULL;
     resp->ch.handle = NULL;
     if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
     if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
-      resp->err = strdup("your nvidia driver is too old or missing, please upgrade to run ollama");
+      resp->err = strdup("your nvidia driver is too old or missing.  If you have a CUDA GPU please upgrade to run ollama");
       return;
       return;
     }
     }
     snprintf(buf, buflen, "cudart init failure: %d", ret);
     snprintf(buf, buflen, "cudart init failure: %d", ret);
@@ -85,110 +81,95 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
     driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
     driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
     LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
     LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
   }
   }
+
+  ret = (*resp->ch.cudaGetDeviceCount)(&resp->num_devices);
+  if (ret != CUDART_SUCCESS) {
+    LOG(resp->ch.verbose, "cudaGetDeviceCount err: %d\n", ret);
+    UNLOAD_LIBRARY(resp->ch.handle);
+    resp->ch.handle = NULL;
+    snprintf(buf, buflen, "unable to get device count: %d", ret);
+    resp->err = strdup(buf);
+    return;
+  }
 }
 }
 
 
 
 
-void cudart_check_vram(cudart_handle_t h, mem_info_t *resp) {
+void cudart_check_vram(cudart_handle_t h, int i, mem_info_t *resp) {
   resp->err = NULL;
   resp->err = NULL;
   cudartMemory_t memInfo = {0,0,0};
   cudartMemory_t memInfo = {0,0,0};
   cudartReturn_t ret;
   cudartReturn_t ret;
   const int buflen = 256;
   const int buflen = 256;
   char buf[buflen + 1];
   char buf[buflen + 1];
-  int i;
 
 
   if (h.handle == NULL) {
   if (h.handle == NULL) {
     resp->err = strdup("cudart handle isn't initialized");
     resp->err = strdup("cudart handle isn't initialized");
     return;
     return;
   }
   }
 
 
-  // cudaGetDeviceCount takes int type, resp-> count is uint
-  int deviceCount;
-  ret = (*h.cudaGetDeviceCount)(&deviceCount);
+  ret = (*h.cudaSetDevice)(i);
   if (ret != CUDART_SUCCESS) {
   if (ret != CUDART_SUCCESS) {
-    snprintf(buf, buflen, "unable to get device count: %d", ret);
+    snprintf(buf, buflen, "cudart device failed to initialize");
     resp->err = strdup(buf);
     resp->err = strdup(buf);
     return;
     return;
-  } else {
-    resp->count = (unsigned int)deviceCount;
   }
   }
 
 
-  resp->total = 0;
-  resp->free = 0;
-  for (i = 0; i < resp-> count; i++) {  
-    ret = (*h.cudaSetDevice)(i);
-    if (ret != CUDART_SUCCESS) {
-      snprintf(buf, buflen, "cudart device failed to initialize");
-      resp->err = strdup(buf);
-      return;
+  cudaDeviceProp_t props;
+  ret = (*h.cudaGetDeviceProperties)(&props, i);
+  if (ret != CUDART_SUCCESS) {
+    LOG(h.verbose, "[%d] device properties lookup failure: %d\n", i, ret);
+    snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
+    resp->major = 0;
+    resp->minor = 0;
+  } else {
+    int allNull = 1;
+    for (int j = 0; j < 16; j++) {
+      if (props.uuid.bytes[j] != 0) {
+        allNull = 0;
+        break;
+      }
     }
     }
-    ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
-    if (ret != CUDART_SUCCESS) {
-      snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
-      resp->err = strdup(buf);
-      return;
+    if (allNull != 0) {
+      snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
+    } else {
+      // GPU-d110a105-ac29-1d54-7b49-9c90440f215b
+      snprintf(&resp->gpu_id[0], GPU_ID_LEN,
+          "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
+          props.uuid.bytes[0],
+          props.uuid.bytes[1],
+          props.uuid.bytes[2],
+          props.uuid.bytes[3],
+          props.uuid.bytes[4],
+          props.uuid.bytes[5],
+          props.uuid.bytes[6],
+          props.uuid.bytes[7],
+          props.uuid.bytes[8],
+          props.uuid.bytes[9],
+          props.uuid.bytes[10],
+          props.uuid.bytes[11],
+          props.uuid.bytes[12],
+          props.uuid.bytes[13],
+          props.uuid.bytes[14],
+          props.uuid.bytes[15]
+        );
     }
     }
+    resp->major = props.major;
+    resp->minor = props.minor;
 
 
-    LOG(h.verbose, "[%d] CUDA totalMem %lu\n", i, memInfo.total);
-    LOG(h.verbose, "[%d] CUDA freeMem %lu\n", i, memInfo.free);
-
-    resp->total += memInfo.total;
-    resp->free += memInfo.free;
-  }
-}
-
-void cudart_compute_capability(cudart_handle_t h, cudart_compute_capability_t *resp) {
-  resp->err = NULL;
-  resp->major = 0;
-  resp->minor = 0;
-  int major = 0;
-  int minor = 0;
-  cudartReturn_t ret;
-  const int buflen = 256;
-  char buf[buflen + 1];
-  int i;
-
-  if (h.handle == NULL) {
-    resp->err = strdup("cudart handle not initialized");
-    return;
+    // TODO add other useful properties from props
   }
   }
-
-  int devices;
-  ret = (*h.cudaGetDeviceCount)(&devices);
+  ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
   if (ret != CUDART_SUCCESS) {
   if (ret != CUDART_SUCCESS) {
-    snprintf(buf, buflen, "unable to get cudart device count: %d", ret);
+    snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
     resp->err = strdup(buf);
     resp->err = strdup(buf);
     return;
     return;
   }
   }
 
 
-  for (i = 0; i < devices; i++) {
-    ret = (*h.cudaSetDevice)(i);
-    if (ret != CUDART_SUCCESS) {
-      snprintf(buf, buflen, "cudart device failed to initialize");
-      resp->err = strdup(buf);
-      return;
-    }
+  resp->total = memInfo.total;
+  resp->free = memInfo.free;
 
 
-    ret = (*h.cudaDeviceGetAttribute)(&major, cudartDevAttrComputeCapabilityMajor, i);
-    if (ret != CUDART_SUCCESS) {
-      snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-    ret = (*h.cudaDeviceGetAttribute)(&minor, cudartDevAttrComputeCapabilityMinor, i);
-    if (ret != CUDART_SUCCESS) {
-      snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-      
-    // Report the lowest major.minor we detect as that limits our compatibility
-    if (resp->major == 0 || resp->major > major ) {
-      resp->major = major;
-      resp->minor = minor;
-    } else if ( resp->major == major && resp->minor > minor ) {
-      resp->minor = minor;
-    }
-  }
+  LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total);
+  LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free);
+  LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
 }
 }
 
 
 void cudart_release(cudart_handle_t h) {
 void cudart_release(cudart_handle_t h) {
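
The same GPU ID construction, sketched in Go for readability; this is a hypothetical helper, the shipped code does it in C as shown above.

package main

import "fmt"

// formatGPUID mirrors the snprintf above: a 16-byte CUDA device UUID becomes
// the canonical "GPU-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" identifier, while an
// all-zero UUID falls back to the plain device index.
func formatGPUID(index int, uuid [16]byte) string {
	allZero := true
	for _, b := range uuid {
		if b != 0 {
			allZero = false
			break
		}
	}
	if allZero {
		return fmt.Sprintf("%d", index)
	}
	return fmt.Sprintf("GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
		uuid[0], uuid[1], uuid[2], uuid[3], uuid[4], uuid[5], uuid[6], uuid[7],
		uuid[8], uuid[9], uuid[10], uuid[11], uuid[12], uuid[13], uuid[14], uuid[15])
}

func main() {
	uuid := [16]byte{0xd1, 0x10, 0xa1, 0x05, 0xac, 0x29, 0x1d, 0x54, 0x7b, 0x49, 0x9c, 0x90, 0x44, 0x0f, 0x21, 0x5b}
	fmt.Println(formatGPUID(0, uuid)) // GPU-d110a105-ac29-1d54-7b49-9c90440f215b
}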

+ 97 - 11
gpu/gpu_info_cudart.h

@@ -6,14 +6,20 @@
 // Just enough typedef's to dlopen/dlsym for memory information
 typedef enum cudartReturn_enum {
   CUDART_SUCCESS = 0,
-  CUDART_UNSUPPORTED = 1,
-  CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
+  CUDART_ERROR_INVALID_VALUE = 1,
+  CUDART_ERROR_MEMORY_ALLOCATION = 2,
+  CUDART_ERROR_INSUFFICIENT_DRIVER = 35,
   // Other values omitted for now...
 } cudartReturn_t;
 
 typedef enum cudartDeviceAttr_enum {
   cudartDevAttrComputeCapabilityMajor = 75,
   cudartDevAttrComputeCapabilityMinor = 76,
+
+  // TODO - not yet wired up but may be useful for Jetson or other
+  // integrated GPU scenarios with shared memory
+  cudaDevAttrIntegrated = 18
+
 } cudartDeviceAttr_t;
 
 typedef void *cudartDevice_t;  // Opaque is sufficient
@@ -28,6 +34,92 @@ typedef struct cudartDriverVersion {
   int minor;
 } cudartDriverVersion_t;
 
 
+typedef struct cudaUUID {
+    unsigned char bytes[16];
+} cudaUUID_t;
+typedef struct cudaDeviceProp {
+    char         name[256];                  /**< ASCII string identifying device */
+    cudaUUID_t   uuid;                       /**< 16-byte unique identifier */
+    char         luid[8];                    /**< 8-byte locally unique identifier. Value is undefined on TCC and non-Windows platforms */
+    unsigned int luidDeviceNodeMask;         /**< LUID device node mask. Value is undefined on TCC and non-Windows platforms */
+    size_t       totalGlobalMem;             /**< Global memory available on device in bytes */
+    size_t       sharedMemPerBlock;          /**< Shared memory available per block in bytes */
+    int          regsPerBlock;               /**< 32-bit registers available per block */
+    int          warpSize;                   /**< Warp size in threads */
+    size_t       memPitch;                   /**< Maximum pitch in bytes allowed by memory copies */
+    int          maxThreadsPerBlock;         /**< Maximum number of threads per block */
+    int          maxThreadsDim[3];           /**< Maximum size of each dimension of a block */
+    int          maxGridSize[3];             /**< Maximum size of each dimension of a grid */
+    int          clockRate;                  /**< Clock frequency in kilohertz */
+    size_t       totalConstMem;              /**< Constant memory available on device in bytes */
+    int          major;                      /**< Major compute capability */
+    int          minor;                      /**< Minor compute capability */
+    size_t       textureAlignment;           /**< Alignment requirement for textures */
+    size_t       texturePitchAlignment;      /**< Pitch alignment requirement for texture references bound to pitched memory */
+    int          deviceOverlap;              /**< Device can concurrently copy memory and execute a kernel. Deprecated. Use instead asyncEngineCount. */
+    int          multiProcessorCount;        /**< Number of multiprocessors on device */
+    int          kernelExecTimeoutEnabled;   /**< Specified whether there is a run time limit on kernels */
+    int          integrated;                 /**< Device is integrated as opposed to discrete */
+    int          canMapHostMemory;           /**< Device can map host memory with cudaHostAlloc/cudaHostGetDevicePointer */
+    int          computeMode;                /**< Compute mode (See ::cudaComputeMode) */
+    int          maxTexture1D;               /**< Maximum 1D texture size */
+    int          maxTexture1DMipmap;         /**< Maximum 1D mipmapped texture size */
+    int          maxTexture1DLinear;         /**< Deprecated, do not use. Use cudaDeviceGetTexture1DLinearMaxWidth() or cuDeviceGetTexture1DLinearMaxWidth() instead. */
+    int          maxTexture2D[2];            /**< Maximum 2D texture dimensions */
+    int          maxTexture2DMipmap[2];      /**< Maximum 2D mipmapped texture dimensions */
+    int          maxTexture2DLinear[3];      /**< Maximum dimensions (width, height, pitch) for 2D textures bound to pitched memory */
+    int          maxTexture2DGather[2];      /**< Maximum 2D texture dimensions if texture gather operations have to be performed */
+    int          maxTexture3D[3];            /**< Maximum 3D texture dimensions */
+    int          maxTexture3DAlt[3];         /**< Maximum alternate 3D texture dimensions */
+    int          maxTextureCubemap;          /**< Maximum Cubemap texture dimensions */
+    int          maxTexture1DLayered[2];     /**< Maximum 1D layered texture dimensions */
+    int          maxTexture2DLayered[3];     /**< Maximum 2D layered texture dimensions */
+    int          maxTextureCubemapLayered[2];/**< Maximum Cubemap layered texture dimensions */
+    int          maxSurface1D;               /**< Maximum 1D surface size */
+    int          maxSurface2D[2];            /**< Maximum 2D surface dimensions */
+    int          maxSurface3D[3];            /**< Maximum 3D surface dimensions */
+    int          maxSurface1DLayered[2];     /**< Maximum 1D layered surface dimensions */
+    int          maxSurface2DLayered[3];     /**< Maximum 2D layered surface dimensions */
+    int          maxSurfaceCubemap;          /**< Maximum Cubemap surface dimensions */
+    int          maxSurfaceCubemapLayered[2];/**< Maximum Cubemap layered surface dimensions */
+    size_t       surfaceAlignment;           /**< Alignment requirements for surfaces */
+    int          concurrentKernels;          /**< Device can possibly execute multiple kernels concurrently */
+    int          ECCEnabled;                 /**< Device has ECC support enabled */
+    int          pciBusID;                   /**< PCI bus ID of the device */
+    int          pciDeviceID;                /**< PCI device ID of the device */
+    int          pciDomainID;                /**< PCI domain ID of the device */
+    int          tccDriver;                  /**< 1 if device is a Tesla device using TCC driver, 0 otherwise */
+    int          asyncEngineCount;           /**< Number of asynchronous engines */
+    int          unifiedAddressing;          /**< Device shares a unified address space with the host */
+    int          memoryClockRate;            /**< Peak memory clock frequency in kilohertz */
+    int          memoryBusWidth;             /**< Global memory bus width in bits */
+    int          l2CacheSize;                /**< Size of L2 cache in bytes */
+    int          persistingL2CacheMaxSize;   /**< Device's maximum l2 persisting lines capacity setting in bytes */
+    int          maxThreadsPerMultiProcessor;/**< Maximum resident threads per multiprocessor */
+    int          streamPrioritiesSupported;  /**< Device supports stream priorities */
+    int          globalL1CacheSupported;     /**< Device supports caching globals in L1 */
+    int          localL1CacheSupported;      /**< Device supports caching locals in L1 */
+    size_t       sharedMemPerMultiprocessor; /**< Shared memory available per multiprocessor in bytes */
+    int          regsPerMultiprocessor;      /**< 32-bit registers available per multiprocessor */
+    int          managedMemory;              /**< Device supports allocating managed memory on this system */
+    int          isMultiGpuBoard;            /**< Device is on a multi-GPU board */
+    int          multiGpuBoardGroupID;       /**< Unique identifier for a group of devices on the same multi-GPU board */
+    int          hostNativeAtomicSupported;  /**< Link between the device and the host supports native atomic operations */
+    int          singleToDoublePrecisionPerfRatio; /**< Ratio of single precision performance (in floating-point operations per second) to double precision performance */
+    int          pageableMemoryAccess;       /**< Device supports coherently accessing pageable memory without calling cudaHostRegister on it */
+    int          concurrentManagedAccess;    /**< Device can coherently access managed memory concurrently with the CPU */
+    int          computePreemptionSupported; /**< Device supports Compute Preemption */
+    int          canUseHostPointerForRegisteredMem; /**< Device can access host registered memory at the same virtual address as the CPU */
+    int          cooperativeLaunch;          /**< Device supports launching cooperative kernels via ::cudaLaunchCooperativeKernel */
+    int          cooperativeMultiDeviceLaunch; /**< Deprecated, cudaLaunchCooperativeKernelMultiDevice is deprecated. */
+    size_t       sharedMemPerBlockOptin;     /**< Per device maximum shared memory per block usable by special opt in */
+    int          pageableMemoryAccessUsesHostPageTables; /**< Device accesses pageable memory via the host's page tables */
+    int          directManagedMemAccessFromHost; /**< Host can directly access managed memory on the device without migration. */
+    int          maxBlocksPerMultiProcessor; /**< Maximum number of resident blocks per multiprocessor */
+    int          accessPolicyMaxWindowSize;  /**< The maximum value of ::cudaAccessPolicyWindow::num_bytes. */
+    size_t       reservedSharedMemPerBlock;  /**< Shared memory reserved by CUDA driver per block in bytes */
+  } cudaDeviceProp_t;
+
 typedef struct cudart_handle {
   void *handle;
   uint16_t verbose;
@@ -38,23 +130,17 @@ typedef struct cudart_handle {
   cudartReturn_t (*cudaGetDeviceCount)(int *);
   cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device);
   cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion);
+  cudartReturn_t (*cudaGetDeviceProperties) (cudaDeviceProp_t* prop, int device);
 } cudart_handle_t;
 
 typedef struct cudart_init_resp {
   char *err;  // If err is non-null handle is invalid
   cudart_handle_t ch;
+  int num_devices;
 } cudart_init_resp_t;
 
 
-typedef struct cudart_compute_capability {
-  char *err;
-  int major;
-  int minor;
-} cudart_compute_capability_t;
-
-
 void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp);
-void cudart_check_vram(cudart_handle_t ch, mem_info_t *resp);
-void cudart_compute_capability(cudart_handle_t th, cudart_compute_capability_t *cc);
+void cudart_check_vram(cudart_handle_t ch, int device_id, mem_info_t *resp);
 void cudart_release(cudart_handle_t ch);
 
 #endif  // __GPU_INFO_CUDART_H__

+ 203 - 0
gpu/gpu_info_nvcuda.c

@@ -0,0 +1,203 @@
+#ifndef __APPLE__  // TODO - maybe consider nvidia support on intel macs?
+
+#include <string.h>
+#include "gpu_info_nvcuda.h"
+
+void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
+  CUresult ret;
+  resp->err = NULL;
+  resp->num_devices = 0;
+  const int buflen = 256;
+  char buf[buflen + 1];
+  int i;
+
+  struct lookup {
+    char *s;
+    void **p;
+  } l[] = {
+   
+      {"cuInit", (void *)&resp->ch.cuInit},
+      {"cuDriverGetVersion", (void *)&resp->ch.cuDriverGetVersion},
+      {"cuDeviceGetCount", (void *)&resp->ch.cuDeviceGetCount},
+      {"cuDeviceGet", (void *)&resp->ch.cuDeviceGet},
+      {"cuDeviceGetAttribute", (void *)&resp->ch.cuDeviceGetAttribute},
+      {"cuDeviceGetUuid", (void *)&resp->ch.cuDeviceGetUuid},
+      {"cuCtxCreate_v3", (void *)&resp->ch.cuCtxCreate_v3},
+      {"cuMemGetInfo_v2", (void *)&resp->ch.cuMemGetInfo_v2},
+      {"cuCtxDestroy", (void *)&resp->ch.cuCtxDestroy},
+      {NULL, NULL},
+  };
+
+  resp->ch.handle = LOAD_LIBRARY(nvcuda_lib_path, RTLD_LAZY);
+  if (!resp->ch.handle) {
+    char *msg = LOAD_ERR();
+    LOG(resp->ch.verbose, "library %s load err: %s\n", nvcuda_lib_path, msg);
+    snprintf(buf, buflen,
+            "Unable to load %s library to query for Nvidia GPUs: %s",
+            nvcuda_lib_path, msg);
+    free(msg);
+    resp->err = strdup(buf);
+    return;
+  }
+
+  for (i = 0; l[i].s != NULL; i++) {
+    *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
+    if (!*l[i].p) {
+      char *msg = LOAD_ERR();
+      LOG(resp->ch.verbose, "dlerr: %s\n", msg);
+      UNLOAD_LIBRARY(resp->ch.handle);
+      resp->ch.handle = NULL;
+      snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
+              msg);
+      free(msg);
+      resp->err = strdup(buf);
+      return;
+    }
+  }
+
+  ret = (*resp->ch.cuInit)(0);
+  if (ret != CUDA_SUCCESS) {
+    LOG(resp->ch.verbose, "cuInit err: %d\n", ret);
+    UNLOAD_LIBRARY(resp->ch.handle);
+    resp->ch.handle = NULL;
+    if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
+      resp->err = strdup("your NVIDIA driver is too old or missing. If you have a CUDA GPU, please upgrade it to run Ollama");
+      return;
+    }
+    snprintf(buf, buflen, "nvcuda init failure: %d", ret);
+    resp->err = strdup(buf);
+    return;
+  }
+
+  int version = 0;
+  nvcudaDriverVersion_t driverVersion;
+  driverVersion.major = 0;
+  driverVersion.minor = 0;
+
+  // Report driver version if we're in verbose mode, ignore errors
+  ret = (*resp->ch.cuDriverGetVersion)(&version);
+  if (ret != CUDA_SUCCESS) {
+    LOG(resp->ch.verbose, "cuDriverGetVersion failed: %d\n", ret);
+  } else {
+    driverVersion.major = version / 1000;
+    driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
+    LOG(resp->ch.verbose, "CUDA driver version: %d.%d\n", driverVersion.major, driverVersion.minor);
+  }
+
+  ret = (*resp->ch.cuDeviceGetCount)(&resp->num_devices);
+  if (ret != CUDA_SUCCESS) {
+    LOG(resp->ch.verbose, "cuDeviceGetCount err: %d\n", ret);
+    UNLOAD_LIBRARY(resp->ch.handle);
+    resp->ch.handle = NULL;
+    snprintf(buf, buflen, "unable to get device count: %d", ret);
+    resp->err = strdup(buf);
+    return;
+  }
+}
+
+const int buflen = 256;
+void nvcuda_check_vram(nvcuda_handle_t h, int i, mem_info_t *resp) {
+  resp->err = NULL;
+  nvcudaMemory_t memInfo = {0,0};
+  CUresult ret;
+  CUdevice device = -1;
+  CUcontext ctx = NULL;
+  char buf[buflen + 1];
+  CUuuid uuid = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
+
+  if (h.handle == NULL) {
+    resp->err = strdup("nvcuda handle isn't initialized");
+    return;
+  }
+
+  ret = (*h.cuDeviceGet)(&device, i);
+  if (ret != CUDA_SUCCESS) {
+    snprintf(buf, buflen, "nvcuda device failed to initialize");
+    resp->err = strdup(buf);
+    return;
+  }
+
+  resp->major = 0;
+  resp->minor = 0;
+  int major = 0;
+  int minor = 0;
+  ret = (*h.cuDeviceGetAttribute)(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device);
+  if (ret != CUDA_SUCCESS) {
+    LOG(h.verbose, "[%d] device major lookup failure: %d\n", i, ret);
+  } else {
+    ret = (*h.cuDeviceGetAttribute)(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device);
+    if (ret != CUDA_SUCCESS) {
+      LOG(h.verbose, "[%d] device minor lookup failure: %d\n", i, ret);
+    } else {
+      resp->minor = minor;  
+      resp->major = major;  
+    }
+  }
+
+  ret = (*h.cuDeviceGetUuid)(&uuid, device);
+  if (ret != CUDA_SUCCESS) {
+    LOG(h.verbose, "[%d] device uuid lookup failure: %d\n", i, ret);
+    snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
+  } else {
+    // GPU-d110a105-ac29-1d54-7b49-9c90440f215b
+    snprintf(&resp->gpu_id[0], GPU_ID_LEN,
+        "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
+        uuid.bytes[0],
+        uuid.bytes[1],
+        uuid.bytes[2],
+        uuid.bytes[3],
+        uuid.bytes[4],
+        uuid.bytes[5],
+        uuid.bytes[6],
+        uuid.bytes[7],
+        uuid.bytes[8],
+        uuid.bytes[9],
+        uuid.bytes[10],
+        uuid.bytes[11],
+        uuid.bytes[12],
+        uuid.bytes[13],
+        uuid.bytes[14],
+        uuid.bytes[15]
+      );
+  }
+
+  // To get memory we have to set (and release) a context
+  ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
+  if (ret != CUDA_SUCCESS) {
+    snprintf(buf, buflen, "nvcuda failed to get primary device context %d", ret);
+    resp->err = strdup(buf);
+    return;
+  }
+
+  ret = (*h.cuMemGetInfo_v2)(&memInfo.free, &memInfo.total);
+  if (ret != CUDA_SUCCESS) {
+    snprintf(buf, buflen, "nvcuda device memory info lookup failure %d", ret);
+    resp->err = strdup(buf);
+    // Best effort on failure...
+    (*h.cuCtxDestroy)(ctx);
+    return;
+  }
+
+  resp->total = memInfo.total;
+  resp->free = memInfo.free;
+
+  LOG(h.verbose, "[%s] CUDA totalMem %lu mb\n", resp->gpu_id, resp->total / 1024 / 1024);
+  LOG(h.verbose, "[%s] CUDA freeMem %lu mb\n", resp->gpu_id, resp->free / 1024 / 1024);
+  LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
+
+  
+
+  ret = (*h.cuCtxDestroy)(ctx);
+  if (ret != CUDA_SUCCESS) {
+    LOG(1, "nvcuda failed to release primary device context %d", ret);
+  }
+}
+
+void nvcuda_release(nvcuda_handle_t h) {
+  LOG(h.verbose, "releasing nvcuda library\n");
+  UNLOAD_LIBRARY(h.handle);
+  // TODO and other context release logic?
+  h.handle = NULL;
+}
+
+#endif  // __APPLE__

+ 71 - 0
gpu/gpu_info_nvcuda.h

@@ -0,0 +1,71 @@
+#ifndef __APPLE__
+#ifndef __GPU_INFO_NVCUDA_H__
+#define __GPU_INFO_NVCUDA_H__
+#include "gpu_info.h"
+
+// Just enough typedef's to dlopen/dlsym for memory information
+typedef enum cudaError_enum {
+  CUDA_SUCCESS = 0,
+  CUDA_ERROR_INVALID_VALUE = 1,
+  CUDA_ERROR_MEMORY_ALLOCATION = 2,
+  CUDA_ERROR_NOT_INITIALIZED = 3,
+  CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
+  // Other values omitted for now...
+} CUresult;
+
+typedef enum CUdevice_attribute_enum {
+  CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75,
+  CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76,
+
+  // TODO - not yet wired up but may be useful for Jetson or other
+  // integrated GPU scenarios with shared memory
+  CU_DEVICE_ATTRIBUTE_INTEGRATED = 18
+
+} CUdevice_attribute;
+
+typedef void *nvcudaDevice_t;  // Opaque is sufficient
+typedef struct nvcudaMemory_st {
+  uint64_t total;
+  uint64_t free;
+} nvcudaMemory_t;
+
+typedef struct nvcudaDriverVersion {
+  int major;
+  int minor;
+} nvcudaDriverVersion_t;
+
+typedef struct CUuuid_st {
+    unsigned char bytes[16];
+} CUuuid;
+
+typedef int CUdevice;
+typedef void* CUcontext;
+
+typedef struct nvcuda_handle {
+  void *handle;
+  uint16_t verbose;
+  CUresult (*cuInit)(unsigned int Flags);
+  CUresult (*cuDriverGetVersion)(int *driverVersion);
+  CUresult (*cuDeviceGetCount)(int *);
+  CUresult (*cuDeviceGet)(CUdevice* device, int ordinal);
+  CUresult (*cuDeviceGetAttribute)(int* pi, CUdevice_attribute attrib, CUdevice dev);
+  CUresult (*cuDeviceGetUuid)(CUuuid* uuid, CUdevice dev); // signature compatible with cuDeviceGetUuid_v2
+
+  // Context specific aspects
+  CUresult (*cuCtxCreate_v3)(CUcontext* pctx, void *params, int len, unsigned int flags, CUdevice dev);
+  CUresult (*cuMemGetInfo_v2)(uint64_t* free, uint64_t* total);
+  CUresult (*cuCtxDestroy)(CUcontext ctx);
+} nvcuda_handle_t;
+
+typedef struct nvcuda_init_resp {
+  char *err;  // If err is non-null handle is invalid
+  nvcuda_handle_t ch;
+  int num_devices;
+} nvcuda_init_resp_t;
+
+void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp);
+void nvcuda_check_vram(nvcuda_handle_t ch, int device_id, mem_info_t *resp);
+void nvcuda_release(nvcuda_handle_t ch);
+
+#endif  // __GPU_INFO_NVCUDA_H__
+#endif  // __APPLE__

+ 0 - 221
gpu/gpu_info_nvml.c

@@ -1,221 +0,0 @@
-#ifndef __APPLE__  // TODO - maybe consider nvidia support on intel macs?
-
-#include <string.h>
-
-#include "gpu_info_nvml.h"
-
-void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) {
-  nvmlReturn_t ret;
-  resp->err = NULL;
-  const int buflen = 256;
-  char buf[buflen + 1];
-  int i;
-
-  struct lookup {
-    char *s;
-    void **p;
-  } l[] = {
-      {"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2},
-      {"nvmlShutdown", (void *)&resp->ch.nvmlShutdown},
-      {"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.nvmlDeviceGetHandleByIndex},
-      {"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo},
-      {"nvmlDeviceGetCount_v2", (void *)&resp->ch.nvmlDeviceGetCount_v2},
-      {"nvmlDeviceGetCudaComputeCapability", (void *)&resp->ch.nvmlDeviceGetCudaComputeCapability},
-      {"nvmlSystemGetDriverVersion", (void *)&resp->ch.nvmlSystemGetDriverVersion},
-      {"nvmlDeviceGetName", (void *)&resp->ch.nvmlDeviceGetName},
-      {"nvmlDeviceGetSerial", (void *)&resp->ch.nvmlDeviceGetSerial},
-      {"nvmlDeviceGetVbiosVersion", (void *)&resp->ch.nvmlDeviceGetVbiosVersion},
-      {"nvmlDeviceGetBoardPartNumber", (void *)&resp->ch.nvmlDeviceGetBoardPartNumber},
-      {"nvmlDeviceGetBrand", (void *)&resp->ch.nvmlDeviceGetBrand},
-      {NULL, NULL},
-  };
-
-  resp->ch.handle = LOAD_LIBRARY(nvml_lib_path, RTLD_LAZY);
-  if (!resp->ch.handle) {
-    char *msg = LOAD_ERR();
-    LOG(resp->ch.verbose, "library %s load err: %s\n", nvml_lib_path, msg);
-    snprintf(buf, buflen,
-             "Unable to load %s library to query for Nvidia GPUs: %s",
-             nvml_lib_path, msg);
-    free(msg);
-    resp->err = strdup(buf);
-    return;
-  }
-
-  // TODO once we've squashed the remaining corner cases remove this log
-  LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", nvml_lib_path);
-  
-  for (i = 0; l[i].s != NULL; i++) {
-    // TODO once we've squashed the remaining corner cases remove this log
-    LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
-
-    *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
-    if (!l[i].p) {
-      resp->ch.handle = NULL;
-      char *msg = LOAD_ERR();
-      LOG(resp->ch.verbose, "dlerr: %s\n", msg);
-      UNLOAD_LIBRARY(resp->ch.handle);
-      snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
-               msg);
-      free(msg);
-      resp->err = strdup(buf);
-      return;
-    }
-  }
-
-  ret = (*resp->ch.nvmlInit_v2)();
-  if (ret != NVML_SUCCESS) {
-    LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret);
-    UNLOAD_LIBRARY(resp->ch.handle);
-    resp->ch.handle = NULL;
-    snprintf(buf, buflen, "nvml vram init failure: %d", ret);
-    resp->err = strdup(buf);
-    return;
-  }
-
-  // Report driver version if we're in verbose mode, ignore errors
-  ret = (*resp->ch.nvmlSystemGetDriverVersion)(buf, buflen);
-  if (ret != NVML_SUCCESS) {
-    LOG(resp->ch.verbose, "nvmlSystemGetDriverVersion failed: %d\n", ret);
-  } else {
-    LOG(resp->ch.verbose, "CUDA driver version: %s\n", buf);
-  }
-}
-
-void nvml_check_vram(nvml_handle_t h, mem_info_t *resp) {
-  resp->err = NULL;
-  nvmlDevice_t device;
-  nvmlMemory_t memInfo = {0};
-  nvmlReturn_t ret;
-  const int buflen = 256;
-  char buf[buflen + 1];
-  int i;
-
-  if (h.handle == NULL) {
-    resp->err = strdup("nvml handle isn't initialized");
-    return;
-  }
-
-  ret = (*h.nvmlDeviceGetCount_v2)(&resp->count);
-  if (ret != NVML_SUCCESS) {
-    snprintf(buf, buflen, "unable to get device count: %d", ret);
-    resp->err = strdup(buf);
-    return;
-  }
-
-  resp->total = 0;
-  resp->free = 0;
-  for (i = 0; i < resp->count; i++) {
-    ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
-    if (ret != NVML_SUCCESS) {
-      snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-
-    ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo);
-    if (ret != NVML_SUCCESS) {
-      snprintf(buf, buflen, "device memory info lookup failure %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-    if (h.verbose) {
-      nvmlBrandType_t brand = 0;
-      // When in verbose mode, report more information about
-      // the card we discover, but don't fail on error
-      ret = (*h.nvmlDeviceGetName)(device, buf, buflen);
-      if (ret != NVML_SUCCESS) {
-        LOG(h.verbose, "nvmlDeviceGetName failed: %d\n", ret);
-      } else {
-        LOG(h.verbose, "[%d] CUDA device name: %s\n", i, buf);
-      }
-      ret = (*h.nvmlDeviceGetBoardPartNumber)(device, buf, buflen);
-      if (ret != NVML_SUCCESS) {
-        LOG(h.verbose, "nvmlDeviceGetBoardPartNumber failed: %d\n", ret);
-      } else {
-        LOG(h.verbose, "[%d] CUDA part number: %s\n", i, buf);
-      }
-      ret = (*h.nvmlDeviceGetSerial)(device, buf, buflen);
-      if (ret != NVML_SUCCESS) {
-        LOG(h.verbose, "nvmlDeviceGetSerial failed: %d\n", ret);
-      } else {
-        LOG(h.verbose, "[%d] CUDA S/N: %s\n", i, buf);
-      }
-      ret = (*h.nvmlDeviceGetVbiosVersion)(device, buf, buflen);
-      if (ret != NVML_SUCCESS) {
-        LOG(h.verbose, "nvmlDeviceGetVbiosVersion failed: %d\n", ret);
-      } else {
-        LOG(h.verbose, "[%d] CUDA vbios version: %s\n", i, buf);
-      }
-      ret = (*h.nvmlDeviceGetBrand)(device, &brand);
-      if (ret != NVML_SUCCESS) {
-        LOG(h.verbose, "nvmlDeviceGetBrand failed: %d\n", ret);
-      } else {
-        LOG(h.verbose, "[%d] CUDA brand: %d\n", i, brand);
-      }
-    }
-
-    LOG(h.verbose, "[%d] CUDA totalMem %ld\n", i, memInfo.total);
-    LOG(h.verbose, "[%d] CUDA freeMem %ld\n", i, memInfo.free);
-
-    resp->total += memInfo.total;
-    resp->free += memInfo.free;
-  }
-}
-
-void nvml_compute_capability(nvml_handle_t h, nvml_compute_capability_t *resp) {
-  resp->err = NULL;
-  resp->major = 0;
-  resp->minor = 0;
-  nvmlDevice_t device;
-  int major = 0;
-  int minor = 0;
-  nvmlReturn_t ret;
-  const int buflen = 256;
-  char buf[buflen + 1];
-  int i;
-
-  if (h.handle == NULL) {
-    resp->err = strdup("nvml handle not initialized");
-    return;
-  }
-
-  unsigned int devices;
-  ret = (*h.nvmlDeviceGetCount_v2)(&devices);
-  if (ret != NVML_SUCCESS) {
-    snprintf(buf, buflen, "unable to get device count: %d", ret);
-    resp->err = strdup(buf);
-    return;
-  }
-
-  for (i = 0; i < devices; i++) {
-    ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
-    if (ret != NVML_SUCCESS) {
-      snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-
-    ret = (*h.nvmlDeviceGetCudaComputeCapability)(device, &major, &minor);
-    if (ret != NVML_SUCCESS) {
-      snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-    // Report the lowest major.minor we detect as that limits our compatibility
-    if (resp->major == 0 || resp->major > major ) {
-      resp->major = major;
-      resp->minor = minor;
-    } else if ( resp->major == major && resp->minor > minor ) {
-      resp->minor = minor;
-    }
-  }
-}
-
-void nvml_release(nvml_handle_t h) {
-  LOG(h.verbose, "releasing nvml library\n");
-  UNLOAD_LIBRARY(h.handle);
-  h.handle = NULL;
-}
-
-#endif  // __APPLE__

+ 0 - 57
gpu/gpu_info_nvml.h

@@ -1,57 +0,0 @@
-#ifndef __APPLE__
-#ifndef __GPU_INFO_NVML_H__
-#define __GPU_INFO_NVML_H__
-#include "gpu_info.h"
-
-// Just enough typedef's to dlopen/dlsym for memory information
-typedef enum nvmlReturn_enum {
-  NVML_SUCCESS = 0,
-  // Other values omitted for now...
-} nvmlReturn_t;
-typedef void *nvmlDevice_t;  // Opaque is sufficient
-typedef struct nvmlMemory_st {
-  unsigned long long total;
-  unsigned long long free;
-  unsigned long long used;
-} nvmlMemory_t;
-
-typedef enum nvmlBrandType_enum
-{
-    NVML_BRAND_UNKNOWN          = 0,
-} nvmlBrandType_t;
-
-typedef struct nvml_handle {
-  void *handle;
-  uint16_t verbose;
-  nvmlReturn_t (*nvmlInit_v2)(void);
-  nvmlReturn_t (*nvmlShutdown)(void);
-  nvmlReturn_t (*nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t *);
-  nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
-  nvmlReturn_t (*nvmlDeviceGetCount_v2)(unsigned int *);
-  nvmlReturn_t (*nvmlDeviceGetCudaComputeCapability)(nvmlDevice_t, int* major, int* minor);
-  nvmlReturn_t (*nvmlSystemGetDriverVersion) (char* version, unsigned int  length);
-  nvmlReturn_t (*nvmlDeviceGetName) (nvmlDevice_t device, char* name, unsigned int  length);
-  nvmlReturn_t (*nvmlDeviceGetSerial) (nvmlDevice_t device, char* serial, unsigned int  length);
-  nvmlReturn_t (*nvmlDeviceGetVbiosVersion) (nvmlDevice_t device, char* version, unsigned int  length);
-  nvmlReturn_t (*nvmlDeviceGetBoardPartNumber) (nvmlDevice_t device, char* partNumber, unsigned int  length);
-  nvmlReturn_t (*nvmlDeviceGetBrand) (nvmlDevice_t device, nvmlBrandType_t* type);
-} nvml_handle_t;
-
-typedef struct nvml_init_resp {
-  char *err;  // If err is non-null handle is invalid
-  nvml_handle_t ch;
-} nvml_init_resp_t;
-
-typedef struct nvml_compute_capability {
-  char *err;
-  int major;
-  int minor;
-} nvml_compute_capability_t;
-
-void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp);
-void nvml_check_vram(nvml_handle_t ch, mem_info_t *resp);
-void nvml_compute_capability(nvml_handle_t ch, nvml_compute_capability_t *cc);
-void nvml_release(nvml_handle_t ch);
-
-#endif  // __GPU_INFO_NVML_H__
-#endif  // __APPLE__

+ 6 - 13
gpu/gpu_test.go

@@ -9,23 +9,16 @@ import (
 
 
 func TestBasicGetGPUInfo(t *testing.T) {
 	info := GetGPUInfo()
-	assert.Contains(t, "cuda rocm cpu metal", info.Library)
-
-	switch runtime.GOOS {
-	case "darwin":
-		// TODO - remove this once MacOS returns some size for CPU
-		return
-	case "linux", "windows":
-		assert.Greater(t, info.TotalMemory, uint64(0))
-		assert.Greater(t, info.FreeMemory, uint64(0))
-		assert.Greater(t, info.DeviceCount, uint32(0))
-	default:
-		return
+	assert.Greater(t, len(info), 0)
+	assert.Contains(t, "cuda rocm cpu metal", info[0].Library)
+	if info[0].Library != "cpu" {
+		assert.Greater(t, info[0].TotalMemory, uint64(0))
+		assert.Greater(t, info[0].FreeMemory, uint64(0))
 	}
 }
 
 func TestCPUMemInfo(t *testing.T) {
-	info, err := getCPUMem()
+	info, err := GetCPUMem()
 	assert.NoError(t, err)
 	switch runtime.GOOS {
 	case "darwin":

+ 43 - 6
gpu/types.go

@@ -3,7 +3,6 @@ package gpu
 type memInfo struct {
 	TotalMemory uint64 `json:"total_memory,omitempty"`
 	FreeMemory  uint64 `json:"free_memory,omitempty"`
-	DeviceCount uint32 `json:"device_count,omitempty"`
 }
 
 // Beginning of an `ollama info` command
@@ -17,11 +16,49 @@ type GpuInfo struct {
 	// MinimumMemory represents the minimum memory required to use the GPU
 	MinimumMemory uint64 `json:"-"`
 
 
-	// TODO add other useful attributes about the card here for discovery information
+	// Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
+	DependencyPath string `json:"lib_path,omitempty"`
+
+	// GPU information
+	ID    string `json:"gpu_id"`          // string to use for selection of this specific GPU
+	Name  string `json:"name"`            // user friendly name if available
+	Major int    `json:"major,omitempty"` // Major compatibility version (CC or gfx)
+	Minor int    `json:"minor,omitempty"` // Minor compatibility version (CC or gfx)
+	Patch int    `json:"patch,omitempty"` // Patch compatibility only matters on AMD
+
+	// TODO other performance capability info to help in scheduling decisions
 }
 
 
-type Version struct {
-	Major uint
-	Minor uint
-	Patch uint
+type GpuInfoList []GpuInfo
+
+// Split up the set of GPU infos by Library and variant
+func (l GpuInfoList) ByLibrary() []GpuInfoList {
+	resp := []GpuInfoList{}
+	libs := []string{}
+	for _, info := range l {
+		found := false
+		requested := info.Library
+		if info.Variant != "" {
+			requested += "_" + info.Variant
+		}
+		for i, lib := range libs {
+			if lib == requested {
+				resp[i] = append(resp[i], info)
+				found = true
+				break
+			}
+		}
+		if !found {
+			libs = append(libs, requested)
+			resp = append(resp, []GpuInfo{info})
+		}
+	}
+	return resp
 }
+
+// Sort by Free Space
+type ByFreeMemory []GpuInfo
+
+func (a ByFreeMemory) Len() int           { return len(a) }
+func (a ByFreeMemory) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
+func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory }

+ 44 - 2
integration/basic_test.go

@@ -4,11 +4,14 @@ package integration
 
 
 import (
 	"context"
-	"net/http"
+	"log/slog"
+	"os"
+	"runtime"
 	"testing"
 	"time"
 
 	"github.com/ollama/ollama/api"
+	"github.com/stretchr/testify/require"
 )
 
 func TestOrcaMiniBlueSky(t *testing.T) {
@@ -24,5 +27,44 @@ func TestOrcaMiniBlueSky(t *testing.T) {
 			"seed":        123,
 		},
 	}
-	GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh", "scattering"})
+	GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
+}
+
+func TestUnicodeModelDir(t *testing.T) {
+	// This is only useful for Windows with utf-16 characters, so skip this test for other platforms
+	if runtime.GOOS != "windows" {
+		t.Skip("Unicode test only applicable to windows")
+	}
+	// Only works for local testing
+	if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
+		t.Skip("TestUnicodeModelDir only works for local testing, skipping")
+	}
+
+	modelDir, err := os.MkdirTemp("", "ollama_埃")
+	require.NoError(t, err)
+	defer os.RemoveAll(modelDir)
+	slog.Info("unicode", "OLLAMA_MODELS", modelDir)
+
+	oldModelsDir := os.Getenv("OLLAMA_MODELS")
+	if oldModelsDir == "" {
+		defer os.Unsetenv("OLLAMA_MODELS")
+	} else {
+		defer os.Setenv("OLLAMA_MODELS", oldModelsDir)
+	}
+	err = os.Setenv("OLLAMA_MODELS", modelDir)
+	require.NoError(t, err)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+	defer cancel()
+
+	req := api.GenerateRequest{
+		Model:  "orca-mini",
+		Prompt: "why is the sky blue?",
+		Stream: &stream,
+		Options: map[string]interface{}{
+			"temperature": 0,
+			"seed":        123,
+		},
+	}
+	GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
 }

+ 225 - 0
integration/concurrency_test.go

@@ -0,0 +1,225 @@
+//go:build integration
+
+package integration
+
+import (
+	"context"
+	"log/slog"
+	"os"
+	"strconv"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/ollama/ollama/api"
+	"github.com/stretchr/testify/require"
+)
+
+func TestMultiModelConcurrency(t *testing.T) {
+	var (
+		req = [2]api.GenerateRequest{
+			{
+				Model:  "orca-mini",
+				Prompt: "why is the ocean blue?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "tinydolphin",
+				Prompt: "what is the origin of the us thanksgiving holiday?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			},
+		}
+		resp = [2][]string{
+			[]string{"sunlight"},
+			[]string{"england", "english", "massachusetts", "pilgrims"},
+		}
+	)
+	var wg sync.WaitGroup
+	wg.Add(len(req))
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
+	defer cancel()
+	for i := 0; i < len(req); i++ {
+		go func(i int) {
+			defer wg.Done()
+			GenerateTestHelper(ctx, t, req[i], resp[i])
+		}(i)
+	}
+	wg.Wait()
+}
+
+func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) // GTX 750 2G card takes ~9 minutes
+	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	req, resp := GenerateRequests()
+	// Get the server running (if applicable) warm the model up with a single initial request
+	DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 5*time.Second)
+
+	var wg sync.WaitGroup
+	wg.Add(len(req))
+	for i := 0; i < len(req); i++ {
+		go func(i int) {
+			defer wg.Done()
+			for j := 0; j < 5; j++ {
+				slog.Info("Starting", "req", i, "iter", j)
+				// On slower GPUs it can take a while to process the 4 concurrent requests
+				// so we allow a much longer initial timeout
+				DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
+			}
+		}(i)
+	}
+	wg.Wait()
+}
+
+// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
+func TestMultiModelStress(t *testing.T) {
+	vram := os.Getenv("OLLAMA_MAX_VRAM")
+	if vram == "" {
+		t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
+	}
+	max, err := strconv.ParseUint(vram, 10, 64)
+	require.NoError(t, err)
+	const MB = uint64(1024 * 1024)
+	type model struct {
+		name string
+		size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM
+	}
+
+	smallModels := []model{
+		{
+			name: "orca-mini",
+			size: 2992 * MB,
+		},
+		{
+			name: "phi",
+			size: 2616 * MB,
+		},
+		{
+			name: "gemma:2b",
+			size: 2364 * MB,
+		},
+		{
+			name: "stable-code:3b",
+			size: 2608 * MB,
+		},
+		{
+			name: "starcoder2:3b",
+			size: 2166 * MB,
+		},
+	}
+	mediumModels := []model{
+		{
+			name: "llama2",
+			size: 5118 * MB,
+		},
+		{
+			name: "mistral",
+			size: 4620 * MB,
+		},
+		{
+			name: "orca-mini:7b",
+			size: 5118 * MB,
+		},
+		{
+			name: "dolphin-mistral",
+			size: 4620 * MB,
+		},
+		{
+			name: "gemma:7b",
+			size: 5000 * MB,
+		},
+		// TODO - uncomment this once #3565 is merged and this is rebased on it
+		// {
+		// 	name: "codellama:7b",
+		// 	size: 5118 * MB,
+		// },
+	}
+
+	// These seem to be too slow to be useful...
+	// largeModels := []model{
+	// 	{
+	// 		name: "llama2:13b",
+	// 		size: 7400 * MB,
+	// 	},
+	// 	{
+	// 		name: "codellama:13b",
+	// 		size: 7400 * MB,
+	// 	},
+	// 	{
+	// 		name: "orca-mini:13b",
+	// 		size: 7400 * MB,
+	// 	},
+	// 	{
+	// 		name: "gemma:7b",
+	// 		size: 5000 * MB,
+	// 	},
+	// 	{
+	// 		name: "starcoder2:15b",
+	// 		size: 9100 * MB,
+	// 	},
+	// }
+
+	var chosenModels []model
+	switch {
+	case max < 10000*MB:
+		slog.Info("selecting small models")
+		chosenModels = smallModels
+	// case max < 30000*MB:
+	default:
+		slog.Info("selecting medium models")
+		chosenModels = mediumModels
+		// default:
+		// 	slog.Info("selecting large models")
+		// 	chosenModels = largeModels
+	}
+
+	req, resp := GenerateRequests()
+
+	for i := range req {
+		if i >= len(chosenModels) {
+			break
+		}
+		req[i].Model = chosenModels[i].name
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
+	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	// Make sure all the models are pulled before we get started
+	for _, r := range req {
+		require.NoError(t, PullIfMissing(ctx, client, r.Model))
+	}
+
+	var wg sync.WaitGroup
+	consumed := uint64(256 * MB) // Assume some baseline usage
+	for i := 0; i < len(req); i++ {
+		// Always get at least 2 models, but don't overshoot VRAM too much or we'll take too long
+		if i > 1 && consumed > max {
+			slog.Info("achieved target vram exhaustion", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
+			break
+		}
+		consumed += chosenModels[i].size
+		slog.Info("target vram", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
+
+		wg.Add(1)
+		go func(i int) {
+			defer wg.Done()
+			for j := 0; j < 3; j++ {
+				slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
+				DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
+			}
+		}(i)
+	}
+	wg.Wait()
+}

+ 1 - 2
integration/context_test.go

@@ -4,7 +4,6 @@ package integration
 
 
 import (
 	"context"
-	"net/http"
 	"testing"
 	"time"
 
 
@@ -25,5 +24,5 @@ func TestContextExhaustion(t *testing.T) {
 			"num_ctx":     128,
 		},
 	}
-	GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"once", "upon", "lived"})
+	GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
 }

+ 3 - 3
integration/llm_image_test.go

@@ -5,7 +5,6 @@ package integration
 import (
 	"context"
 	"encoding/base64"
-	"net/http"
 	"testing"
 	"time"
 
 
@@ -29,10 +28,11 @@ func TestIntegrationMultimodal(t *testing.T) {
 		},
 	}
 
 
-	resp := "the ollamas"
+	// Note: sometimes it returns "the ollamas" sometimes "the ollams"
+	resp := "the ollam"
 	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
 	defer cancel()
-	GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp})
+	GenerateTestHelper(ctx, t, req, []string{resp})
 }
 
 
 const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb

+ 1 - 23
integration/llm_test.go

@@ -4,8 +4,6 @@ package integration
 
 
 import (
 	"context"
-	"net/http"
-	"sync"
 	"testing"
 	"time"
 
 
@@ -45,25 +43,5 @@ var (
 func TestIntegrationSimpleOrcaMini(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
 	defer cancel()
-	GenerateTestHelper(ctx, t, &http.Client{}, req[0], resp[0])
+	GenerateTestHelper(ctx, t, req[0], resp[0])
 }
-
-// TODO
-// The server always loads a new runner and closes the old one, which forces serial execution
-// At present this test case fails with concurrency problems.  Eventually we should try to
-// get true concurrency working with n_parallel support in the backend
-func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
-	var wg sync.WaitGroup
-	wg.Add(len(req))
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
-	defer cancel()
-	for i := 0; i < len(req); i++ {
-		go func(i int) {
-			defer wg.Done()
-			GenerateTestHelper(ctx, t, &http.Client{}, req[i], resp[i])
-		}(i)
-	}
-	wg.Wait()
-}
-
-// TODO - create a parallel test with 2 different models once we support concurrency

+ 117 - 0
integration/max_queue_test.go

@@ -0,0 +1,117 @@
+//go:build integration
+
+package integration
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"log/slog"
+	"os"
+	"strconv"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/ollama/ollama/api"
+	"github.com/stretchr/testify/require"
+)
+
+func TestMaxQueue(t *testing.T) {
+	// Note: This test can be quite slow when running in CPU mode, so keep the threadCount low unless you're on a GPU
+	// Also note that by default Darwin can't sustain > ~128 connections without adjusting limits
+	threadCount := 32
+	mq := os.Getenv("OLLAMA_MAX_QUEUE")
+	if mq != "" {
+		var err error
+		threadCount, err = strconv.Atoi(mq)
+		require.NoError(t, err)
+	} else {
+		os.Setenv("OLLAMA_MAX_QUEUE", fmt.Sprintf("%d", threadCount))
+	}
+
+	req := api.GenerateRequest{
+		Model:  "orca-mini",
+		Prompt: "write a long historical fiction story about christopher columbus.  use at least 10 facts from his actual journey",
+		Options: map[string]interface{}{
+			"seed":        42,
+			"temperature": 0.0,
+		},
+	}
+	resp := []string{"explore", "discover", "ocean"}
+
+	// CPU mode takes much longer at the limit with a large queue setting
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	require.NoError(t, PullIfMissing(ctx, client, req.Model))
+
+	// Context for the worker threads so we can shut them down
+	// embedCtx, embedCancel := context.WithCancel(ctx)
+	embedCtx := ctx
+
+	var genwg sync.WaitGroup
+	genwg.Add(1)
+	go func() {
+		defer genwg.Done()
+		slog.Info("Starting generate request")
+		DoGenerate(ctx, t, client, req, resp, 45*time.Second, 5*time.Second)
+		slog.Info("generate completed")
+	}()
+
+	// Give the generate a chance to get started before we start hammering on embed requests
+	time.Sleep(5 * time.Millisecond)
+
+	threadCount += 10 // Add a few extra to ensure we push the queue past its limit
+	busyCount := 0
+	resetByPeerCount := 0
+	canceledCount := 0
+	succesCount := 0
+	counterMu := sync.Mutex{}
+	var embedwg sync.WaitGroup
+	for i := 0; i < threadCount; i++ {
+		embedwg.Add(1)
+		go func(i int) {
+			defer embedwg.Done()
+			slog.Info("embed started", "id", i)
+			embedReq := api.EmbeddingRequest{
+				Model:   req.Model,
+				Prompt:  req.Prompt,
+				Options: req.Options,
+			}
+			// Fresh client for every request
+			client, _ := GetTestEndpoint()
+
+			resp, genErr := client.Embeddings(embedCtx, &embedReq)
+			counterMu.Lock()
+			defer counterMu.Unlock()
+			switch {
+			case genErr == nil:
+				succesCount++
+				require.Greater(t, len(resp.Embedding), 5) // somewhat arbitrary, but sufficient to be reasonable
+			case errors.Is(genErr, context.Canceled):
+				canceledCount++
+			case strings.Contains(genErr.Error(), "busy"):
+				busyCount++
+			case strings.Contains(genErr.Error(), "connection reset by peer"):
+				resetByPeerCount++
+			default:
+				require.NoError(t, genErr, "%d request failed", i)
+			}
+
+			slog.Info("embed finished", "id", i)
+		}(i)
+	}
+	genwg.Wait()
+	slog.Info("generate done, waiting for embeds")
+	embedwg.Wait()
+
+	require.Equal(t, resetByPeerCount, 0, "Connections reset by peer, have you updated your fd and socket limits?")
+	require.True(t, busyCount > 0, "no requests hit busy error but some should have")
+	require.True(t, canceledCount == 0, "no requests should have been canceled due to timeout")
+
+	slog.Info("embeds completed", "success", succesCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
+}

+ 172 - 99
integration/utils_test.go

@@ -5,13 +5,14 @@ package integration
 import (
 	"bytes"
 	"context"
-	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"log/slog"
 	"math/rand"
 	"net"
 	"net/http"
+	"net/url"
 	"os"
 	"path/filepath"
 	"runtime"
@@ -23,9 +24,13 @@ import (
 
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/app/lifecycle"
 	"github.com/ollama/ollama/app/lifecycle"
-	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 
+func Init() {
+	lifecycle.InitLogging()
+}
+
 func FindPort() string {
 	port := 0
 	if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
@@ -41,7 +46,7 @@ func FindPort() string {
 	return strconv.Itoa(port)
 }
 
 
-func GetTestEndpoint() (string, string) {
+func GetTestEndpoint() (*api.Client, string) {
 	defaultPort := "11434"
 	ollamaHost := os.Getenv("OLLAMA_HOST")
 
 
@@ -67,16 +72,20 @@ func GetTestEndpoint() (string, string) {
 		port = FindPort()
 	}
 
 
-	url := fmt.Sprintf("%s:%s", host, port)
-	slog.Info("server connection", "url", url)
-	return scheme, url
+	slog.Info("server connection", "host", host, "port", port)
+
+	return api.NewClient(
+		&url.URL{
+			Scheme: scheme,
+			Host:   net.JoinHostPort(host, port),
+		},
+		http.DefaultClient), fmt.Sprintf("%s:%s", host, port)
 }
 
 
-// TODO make fanicier, grab logs, etc.
 var serverMutex sync.Mutex
 var serverReady bool
 
 
-func StartServer(ctx context.Context, ollamaHost string) error {
+func startServer(ctx context.Context, ollamaHost string) error {
 	// Make sure the server has been built
 	CLIName, err := filepath.Abs("../ollama")
 	if err != nil {
@@ -98,7 +107,7 @@ func StartServer(ctx context.Context, ollamaHost string) error {
 
 
 	if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost {
 		slog.Info("setting env", "OLLAMA_HOST", ollamaHost)
-		os.Setenv("OLLAMA_HOST", ollamaHost)
+		t.Setenv("OLLAMA_HOST", ollamaHost)
 	}
 
 
 	slog.Info("starting server", "url", ollamaHost)
@@ -125,67 +134,76 @@ func StartServer(ctx context.Context, ollamaHost string) error {
 	return nil
 }
 
 
-func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error {
+func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
 	slog.Info("checking status of model", "model", modelName)
 	showReq := &api.ShowRequest{Name: modelName}
-	requestJSON, err := json.Marshal(showReq)
-	if err != nil {
-		return err
-	}
-
-	req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/show", bytes.NewReader(requestJSON))
-	if err != nil {
-		return err
-	}
 
 
-	// Make the request with the HTTP client
-	response, err := client.Do(req.WithContext(ctx))
-	if err != nil {
+	showCtx, cancel := context.WithDeadlineCause(
+		ctx,
+		time.Now().Add(5*time.Second),
+		fmt.Errorf("show for existing model %s took too long", modelName),
+	)
+	defer cancel()
+	_, err := client.Show(showCtx, showReq)
+	var statusError api.StatusError
+	switch {
+	case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
+		break
+	case err != nil:
 		return err
-	}
-	defer response.Body.Close()
-	if response.StatusCode == 200 {
+	default:
 		slog.Info("model already present", "model", modelName)
 		return nil
 	}
-	slog.Info("model missing", "status", response.StatusCode)
+	slog.Info("model missing", "model", modelName)
+
+	stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
+	stallTimer := time.NewTimer(stallDuration)
+	fn := func(resp api.ProgressResponse) error {
+		// fmt.Print(".")
+		if !stallTimer.Reset(stallDuration) {
+			return fmt.Errorf("stall was detected, aborting status reporting")
+		}
+		return nil
+	}
 
 
+	stream := true
 	pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
-	requestJSON, err = json.Marshal(pullReq)
-	if err != nil {
-		return err
-	}
 
 
-	req, err = http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/pull", bytes.NewReader(requestJSON))
-	if err != nil {
-		return err
-	}
-	slog.Info("pulling", "model", modelName)
+	var pullError error
 
 
-	response, err = client.Do(req.WithContext(ctx))
-	if err != nil {
-		return err
-	}
-	defer response.Body.Close()
-	if response.StatusCode != 200 {
-		return fmt.Errorf("failed to pull model") // TODO more details perhaps
+	done := make(chan int)
+	go func() {
+		pullError = client.Pull(ctx, pullReq, fn)
+		done <- 0
+	}()
+
+	select {
+	case <-stallTimer.C:
+		return fmt.Errorf("download stalled")
+	case <-done:
+		return pullError
 	}
-	slog.Info("model pulled", "model", modelName)
-	return nil
 }
 
 
 var serverProcMutex sync.Mutex
 
 
-func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
-
-	// TODO maybe stuff in an init routine?
-	lifecycle.InitLogging()
-
-	requestJSON, err := json.Marshal(genReq)
-	if err != nil {
-		t.Fatalf("Error serializing request: %v", err)
+// Returns a Client, the testEndpoint, and a cleanup function; fails the test on errors
+// Starts the server if needed
+func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, string, func()) {
+	client, testEndpoint := GetTestEndpoint()
+	if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
+		serverProcMutex.Lock()
+		fp, err := os.CreateTemp("", "ollama-server-*.log")
+		if err != nil {
+			t.Fatalf("failed to generate log file: %s", err)
+		}
+		lifecycle.ServerLogFile = fp.Name()
+		fp.Close()
+		require.NoError(t, startServer(ctx, testEndpoint))
 	}
-	defer func() {
+
+	return client, testEndpoint, func() {
 		if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
 			defer serverProcMutex.Unlock()
 			if t.Failed() {
@@ -203,63 +221,118 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
 				os.Stderr.Write(data)
 				slog.Warn("END OF SERVER")
 			}
-			err = os.Remove(lifecycle.ServerLogFile)
+			err := os.Remove(lifecycle.ServerLogFile)
 			if err != nil && !os.IsNotExist(err) {
 				slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
 			}
 		}
-	}()
-	scheme, testEndpoint := GetTestEndpoint()
-
-	if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
-		serverProcMutex.Lock()
-		fp, err := os.CreateTemp("", "ollama-server-*.log")
-		if err != nil {
-			t.Fatalf("failed to generate log file: %s", err)
-		}
-		lifecycle.ServerLogFile = fp.Name()
-		fp.Close()
-		assert.NoError(t, StartServer(ctx, testEndpoint))
 	}
+}
 
 
-	err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model)
-	if err != nil {
-		t.Fatalf("Error pulling model: %v", err)
-	}
+func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+	require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
+	DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
+}
 
 
-	// Make the request and get the response
-	req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
-	if err != nil {
-		t.Fatalf("Error creating request: %v", err)
+func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
+	stallTimer := time.NewTimer(initialTimeout)
+	var buf bytes.Buffer
+	fn := func(response api.GenerateResponse) error {
+		// fmt.Print(".")
+		buf.Write([]byte(response.Response))
+		if !stallTimer.Reset(streamTimeout) {
+			return fmt.Errorf("stall was detected while streaming response, aborting")
+		}
+		return nil
 	}
 
 
-	// Set the content type for the request
-	req.Header.Set("Content-Type", "application/json")
+	stream := true
+	genReq.Stream = &stream
+	done := make(chan int)
+	var genErr error
+	go func() {
+		genErr = client.Generate(ctx, &genReq, fn)
+		done <- 0
+	}()
 
 
-	// Make the request with the HTTP client
-	response, err := client.Do(req.WithContext(ctx))
-	if err != nil {
-		t.Fatalf("Error making request: %v", err)
-	}
-	defer response.Body.Close()
-	body, err := io.ReadAll(response.Body)
-	assert.NoError(t, err)
-	assert.Equal(t, response.StatusCode, 200, string(body))
-
-	// Verify the response is valid JSON
-	var payload api.GenerateResponse
-	err = json.Unmarshal(body, &payload)
-	if err != nil {
-		assert.NoError(t, err, body)
+	select {
+	case <-stallTimer.C:
+		if buf.Len() == 0 {
+			t.Errorf("generate never started. Timed out after %s", initialTimeout.String())
+		} else {
+			t.Errorf("generate stalled. Response so far: %s", buf.String())
+		}
+	case <-done:
+		require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
+		// Verify the response contains the expected data
+		response := buf.String()
+		atLeastOne := false
+		for _, resp := range anyResp {
+			if strings.Contains(strings.ToLower(response), resp) {
+				atLeastOne = true
+				break
+			}
+		}
+		require.True(t, atLeastOne, "none of %v found in %s", anyResp, response)
+		slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
+	case <-ctx.Done():
+		t.Error("outer test context done while waiting for generate")
 	}
+}
 
-	// Verify the response contains the expected data
-	atLeastOne := false
-	for _, resp := range anyResp {
-		if strings.Contains(strings.ToLower(payload.Response), resp) {
-			atLeastOne = true
-			break
+// Generate a set of requests
+// By default each request uses orca-mini as the model
+func GenerateRequests() ([]api.GenerateRequest, [][]string) {
+	return []api.GenerateRequest{
+			{
+				Model:  "orca-mini",
+				Prompt: "why is the ocean blue?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "orca-mini",
+				Prompt: "why is the color of dirt brown?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "orca-mini",
+				Prompt: "what is the origin of the us thanksgiving holiday?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "orca-mini",
+				Prompt: "what is the origin of independence day?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "orca-mini",
+				Prompt: "what is the composition of air?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			},
+		},
+		[][]string{
+			[]string{"sunlight"},
+			[]string{"soil", "organic", "earth", "black", "tan"},
+			[]string{"england", "english", "massachusetts", "pilgrims"},
+			[]string{"fourth", "july", "declaration", "independence"},
+			[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
 		}
-	}
-	assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response)
 }
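
The refactored helpers above change the integration-test calling convention: GenerateTestHelper now takes an api.GenerateRequest plus a list of acceptable answer fragments and drives the streaming client itself, instead of hand-building HTTP requests. As a rough sketch only (the test name, prompt, expected keywords, and the integration build tag are assumptions for illustration, not taken from this diff), a caller would look something like:

//go:build integration

package integration

import (
	"context"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)

// TestExampleGenerate is a hypothetical example exercising the refactored helpers.
func TestExampleGenerate(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	req := api.GenerateRequest{
		Model:  "orca-mini",
		Prompt: "why is the sky blue?",
		Options: map[string]interface{}{
			"seed":        42,
			"temperature": 0.0,
		},
	}
	// The test passes if any of these lowercase fragments appears in the streamed output.
	GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scatter", "atmosphere"})
}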

+ 9 - 8
llm/ext_server/server.cpp

@@ -1032,7 +1032,7 @@ struct llama_server_context
             slot.has_next_token = false;
         }
 
-        if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model))
+        if (!slot.cache_tokens.empty() && llama_token_is_eog(model, result.tok))
         {
             slot.stopped_eos = true;
             slot.has_next_token = false;
@@ -1144,12 +1144,15 @@ struct llama_server_context
 
         res.result_json = json
         {
-            {"content",    tkn.text_to_send},
             {"stop",       false},
             {"slot_id",    slot.id},
             {"multimodal", multimodal}
         };
 
+        if (!llama_token_is_eog(model, tkn.tok)) {
+            res.result_json["content"] = tkn.text_to_send;
+        }
+
         if (slot.sparams.n_probs > 0)
         {
             std::vector<completion_token_output> probs_output = {};
@@ -1183,8 +1186,6 @@ struct llama_server_context
             {"model",               params.model_alias},
             {"tokens_predicted",    slot.n_decoded},
             {"tokens_evaluated",    slot.n_prompt_tokens},
-            {"generation_settings", get_formated_generation(slot)},
-            {"prompt",              slot.prompt},
             {"truncated",           slot.truncated},
             {"stopped_eos",         slot.stopped_eos},
             {"stopped_word",        slot.stopped_word},
@@ -2644,18 +2645,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             if (strncmp(sep, "int:", 4) == 0) {
             if (strncmp(sep, "int:", 4) == 0) {
                 sep += 4;
                 sep += 4;
                 kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
                 kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
-                kvo.int_value = std::atol(sep);
+                kvo.val_i64 = std::atol(sep);
             } else if (strncmp(sep, "float:", 6) == 0) {
             } else if (strncmp(sep, "float:", 6) == 0) {
                 sep += 6;
                 sep += 6;
                 kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
                 kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
-                kvo.float_value = std::atof(sep);
+                kvo.val_f64 = std::atof(sep);
             } else if (strncmp(sep, "bool:", 5) == 0) {
             } else if (strncmp(sep, "bool:", 5) == 0) {
                 sep += 5;
                 sep += 5;
                 kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
                 kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
                 if (std::strcmp(sep, "true") == 0) {
                 if (std::strcmp(sep, "true") == 0) {
-                    kvo.bool_value = true;
+                    kvo.val_bool = true;
                 } else if (std::strcmp(sep, "false") == 0) {
                 } else if (std::strcmp(sep, "false") == 0) {
-                    kvo.bool_value = false;
+                    kvo.val_bool = false;
                 } else {
                 } else {
                     fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
                     fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
                     invalid_param = true;
                     invalid_param = true;

+ 140 - 0
llm/filetype.go

@@ -0,0 +1,140 @@
+package llm
+
+import "fmt"
+
+type fileType uint32
+
+const (
+	fileTypeF32 fileType = iota
+	fileTypeF16
+	fileTypeQ4_0
+	fileTypeQ4_1
+	fileTypeQ4_1_F16
+	fileTypeQ4_2 // unused
+	fileTypeQ4_3 // unused
+	fileTypeQ8_0
+	fileTypeQ5_0
+	fileTypeQ5_1
+	fileTypeQ2_K
+	fileTypeQ3_K_S
+	fileTypeQ3_K_M
+	fileTypeQ3_K_L
+	fileTypeQ4_K_S
+	fileTypeQ4_K_M
+	fileTypeQ5_K_S
+	fileTypeQ5_K_M
+	fileTypeQ6_K
+	fileTypeIQ2_XXS
+	fileTypeIQ2_XS
+	fileTypeQ2_K_S
+	fileTypeQ3_K_XS
+	fileTypeIQ3_XXS
+
+	fileTypeUnknown
+)
+
+func ParseFileType(s string) (fileType, error) {
+	switch s {
+	case "F32":
+		return fileTypeF32, nil
+	case "F16":
+		return fileTypeF16, nil
+	case "Q4_0":
+		return fileTypeQ4_0, nil
+	case "Q4_1":
+		return fileTypeQ4_1, nil
+	case "Q4_1_F16":
+		return fileTypeQ4_1_F16, nil
+	case "Q8_0":
+		return fileTypeQ8_0, nil
+	case "Q5_0":
+		return fileTypeQ5_0, nil
+	case "Q5_1":
+		return fileTypeQ5_1, nil
+	case "Q2_K":
+		return fileTypeQ2_K, nil
+	case "Q3_K_S":
+		return fileTypeQ3_K_S, nil
+	case "Q3_K_M":
+		return fileTypeQ3_K_M, nil
+	case "Q3_K_L":
+		return fileTypeQ3_K_L, nil
+	case "Q4_K_S":
+		return fileTypeQ4_K_S, nil
+	case "Q4_K_M":
+		return fileTypeQ4_K_M, nil
+	case "Q5_K_S":
+		return fileTypeQ5_K_S, nil
+	case "Q5_K_M":
+		return fileTypeQ5_K_M, nil
+	case "Q6_K":
+		return fileTypeQ6_K, nil
+	case "IQ2_XXS":
+		return fileTypeIQ2_XXS, nil
+	case "IQ2_XS":
+		return fileTypeIQ2_XS, nil
+	case "Q2_K_S":
+		return fileTypeQ2_K_S, nil
+	case "Q3_K_XS":
+		return fileTypeQ3_K_XS, nil
+	case "IQ3_XXS":
+		return fileTypeIQ3_XXS, nil
+	default:
+		return fileTypeUnknown, fmt.Errorf("unknown fileType: %s", s)
+	}
+}
+
+func (t fileType) String() string {
+	switch t {
+	case fileTypeF32:
+		return "F32"
+	case fileTypeF16:
+		return "F16"
+	case fileTypeQ4_0:
+		return "Q4_0"
+	case fileTypeQ4_1:
+		return "Q4_1"
+	case fileTypeQ4_1_F16:
+		return "Q4_1_F16"
+	case fileTypeQ8_0:
+		return "Q8_0"
+	case fileTypeQ5_0:
+		return "Q5_0"
+	case fileTypeQ5_1:
+		return "Q5_1"
+	case fileTypeQ2_K:
+		return "Q2_K"
+	case fileTypeQ3_K_S:
+		return "Q3_K_S"
+	case fileTypeQ3_K_M:
+		return "Q3_K_M"
+	case fileTypeQ3_K_L:
+		return "Q3_K_L"
+	case fileTypeQ4_K_S:
+		return "Q4_K_S"
+	case fileTypeQ4_K_M:
+		return "Q4_K_M"
+	case fileTypeQ5_K_S:
+		return "Q5_K_S"
+	case fileTypeQ5_K_M:
+		return "Q5_K_M"
+	case fileTypeQ6_K:
+		return "Q6_K"
+	case fileTypeIQ2_XXS:
+		return "IQ2_XXS"
+	case fileTypeIQ2_XS:
+		return "IQ2_XS"
+	case fileTypeQ2_K_S:
+		return "Q2_K_S"
+	case fileTypeQ3_K_XS:
+		return "Q3_K_XS"
+	case fileTypeIQ3_XXS:
+		return "IQ3_XXS"
+	default:
+		return "unknown"
+	}
+}
+
+func (t fileType) Value() uint32 {
+	return uint32(t)
+}
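
A quick sketch of the intended round trip through the new type (illustrative only, written as if it lived inside package llm; the quantization name is just an example, and the numeric value follows from the iota ordering above):

package llm

import "fmt"

// ExampleParseFileType is a sketch, not a test included in this diff.
func ExampleParseFileType() {
	ft, err := ParseFileType("Q4_K_M")
	if err != nil {
		// Unrecognized names come back as fileTypeUnknown plus an error.
		fmt.Println("unsupported:", err)
		return
	}
	fmt.Println(ft.String(), ft.Value())
	// Output: Q4_K_M 15
}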

+ 1 - 1
llm/generate/gen_common.sh

@@ -21,7 +21,7 @@ init_vars() {
         # TODO - add additional optimization flags...
         CMAKE_DEFS="-DCMAKE_BUILD_TYPE=Release -DLLAMA_SERVER_VERBOSE=off ${CMAKE_DEFS}"
     fi
-    case $(uname -s) in 
+    case $(uname -s) in
     "Darwin")
         LIB_EXT="dylib"
         WHOLE_ARCHIVE="-Wl,-force_load"

+ 30 - 16
llm/generate/gen_linux.sh

@@ -57,21 +57,21 @@ init_vars
 git_module_setup
 apply_patches
 
+init_vars
+if [ -z "${OLLAMA_SKIP_STATIC_GENERATE}" -o "${OLLAMA_CPU_TARGET}" = "static" ]; then
+    # Builds by default, allows skipping, forces build if OLLAMA_CPU_TARGET="static"
+    # Enables optimized Dockerfile builds using a blanket skip and targeted overrides
+    # Static build for linking into the Go binary
+    init_vars
+    CMAKE_TARGETS="--target llama --target ggml"
+    CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+    BUILD_DIR="../build/linux/${ARCH}_static"
+    echo "Building static library"
+    build
+fi
 
 init_vars
 if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
-
-    if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "static" ]; then
-        # Static build for linking into the Go binary
-        init_vars
-        CMAKE_TARGETS="--target llama --target ggml"
-        CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
-        BUILD_DIR="../build/linux/${ARCH}_static"
-        echo "Building static library"
-        build
-    fi
-
-
     # Users building from source can tune the exact flags we pass to cmake for configuring
     # llama.cpp, and we'll build only 1 CPU variant in that case as the default.
     if [ -n "${OLLAMA_CUSTOM_CPU_DEFS}" ]; then
@@ -165,14 +165,22 @@ if [ -d "${CUDA_LIB_DIR}" ]; then
     fi
     if [ "${ARCH}" == "arm64" ]; then
         echo "ARM CPU detected - disabling unsupported AVX instructions"
-        
+
         # ARM-based CPUs such as M1 and Tegra do not support AVX extensions.
         #
-        # CUDA compute < 6.0 lacks proper FP16 support on ARM. 
-        # Disabling has minimal performance effect while maintaining compatibility. 
+        # CUDA compute < 6.0 lacks proper FP16 support on ARM.
+        # Disabling has minimal performance effect while maintaining compatibility.
         ARM64_DEFS="-DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_CUDA_F16=off"
     fi
-    CMAKE_DEFS="-DLLAMA_CUDA=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS}"
+    # Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
+    if [ -n "${OLLAMA_CUSTOM_CUDA_DEFS}" ]; then
+        echo "OLLAMA_CUSTOM_CUDA_DEFS=\"${OLLAMA_CUSTOM_CUDA_DEFS}\""
+        CMAKE_CUDA_DEFS="-DLLAMA_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}"
+        echo "Building custom CUDA GPU"
+    else
+        CMAKE_CUDA_DEFS="-DLLAMA_CUDA=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}"
+    fi
+    CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS} ${CMAKE_CUDA_DEFS}"
     BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
     EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda"
     build
@@ -217,6 +225,12 @@ if [ -d "${ROCM_PATH}" ]; then
     fi
     init_vars
     CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
+    # Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
+    if [ -n "${OLLAMA_CUSTOM_ROCM_DEFS}" ]; then
+        echo "OLLAMA_CUSTOM_ROCM_DEFS=\"${OLLAMA_CUSTOM_ROCM_DEFS}\""
+        CMAKE_DEFS="${CMAKE_DEFS} ${OLLAMA_CUSTOM_ROCM_DEFS}"
+        echo "Building custom ROCM GPU"
+    fi
     BUILD_DIR="../build/linux/${ARCH}/rocm${ROCM_VARIANT}"
     EXTRA_LIBS="-L${ROCM_PATH}/lib -L/opt/amdgpu/lib/x86_64-linux-gnu/ -Wl,-rpath,\$ORIGIN/../../rocm/ -lhipblas -lrocblas -lamdhip64 -lrocsolver -lamd_comgr -lhsa-runtime64 -lrocsparse -ldrm -ldrm_amdgpu"
     build

+ 190 - 112
llm/generate/gen_windows.ps1

@@ -26,15 +26,25 @@ function amdGPUs {
     $GPU_LIST -join ';'
     $GPU_LIST -join ';'
 }
 }
 
 
+
 function init_vars {
 function init_vars {
-    $script:SRC_DIR = $(resolve-path "..\..\")
-    $script:llamacppDir = "../llama.cpp"
+    if (!$script:SRC_DIR) {
+        $script:SRC_DIR = $(resolve-path "..\..\")
+    }
+    if (!$script:llamacppDir) {
+        $script:llamacppDir = "../llama.cpp"
+    }
+    if (!$script:cmakeTargets) {
+        $script:cmakeTargets = @("ollama_llama_server")
+    }
     $script:cmakeDefs = @(
     $script:cmakeDefs = @(
         "-DBUILD_SHARED_LIBS=on",
         "-DBUILD_SHARED_LIBS=on",
         "-DLLAMA_NATIVE=off"
         "-DLLAMA_NATIVE=off"
         )
         )
-    $script:cmakeTargets = @("ollama_llama_server")
-    $script:ARCH = "amd64" # arm not yet supported.
+    $script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
+    $script:ARCH = $Env:PROCESSOR_ARCHITECTURE.ToLower()
+    $script:DIST_BASE = "${script:SRC_DIR}\dist\windows-${script:ARCH}\ollama_runners"
+    md "$script:DIST_BASE" -ea 0 > $null
     if ($env:CGO_CFLAGS -contains "-g") {
     if ($env:CGO_CFLAGS -contains "-g") {
         $script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on", "-DLLAMA_SERVER_VERBOSE=on", "-DCMAKE_BUILD_TYPE=RelWithDebInfo")
         $script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on", "-DLLAMA_SERVER_VERBOSE=on", "-DCMAKE_BUILD_TYPE=RelWithDebInfo")
         $script:config = "RelWithDebInfo"
         $script:config = "RelWithDebInfo"
@@ -55,7 +65,6 @@ function init_vars {
     } else {
     } else {
         $script:CUDA_LIB_DIR=$env:CUDA_LIB_DIR
         $script:CUDA_LIB_DIR=$env:CUDA_LIB_DIR
     }
     }
-    $script:GZIP=(get-command -ea 'silentlycontinue' gzip).path
     $script:DUMPBIN=(get-command -ea 'silentlycontinue' dumpbin).path
     $script:DUMPBIN=(get-command -ea 'silentlycontinue' dumpbin).path
     if ($null -eq $env:CMAKE_CUDA_ARCHITECTURES) {
     if ($null -eq $env:CMAKE_CUDA_ARCHITECTURES) {
         $script:CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
         $script:CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
@@ -134,21 +143,18 @@ function sign {
     }
     }
 }
 }
 
 
-function compress {
-    if ($script:GZIP -eq $null) {
-        write-host "gzip not installed, not compressing files"
-        return
-    }
-    write-host "Compressing binaries..."
+function install {
+    write-host "Installing binaries to dist dir ${script:distDir}"
+    mkdir ${script:distDir} -ErrorAction SilentlyContinue
     $binaries = dir "${script:buildDir}/bin/*.exe"
     $binaries = dir "${script:buildDir}/bin/*.exe"
     foreach ($file in $binaries) {
     foreach ($file in $binaries) {
-        & "$script:GZIP" --best -f $file
+        copy-item -Path $file -Destination ${script:distDir} -Force
     }
     }
 
 
-    write-host "Compressing dlls..."
+    write-host "Installing dlls to dist dir ${script:distDir}"
     $dlls = dir "${script:buildDir}/bin/*.dll"
     $dlls = dir "${script:buildDir}/bin/*.dll"
     foreach ($file in $dlls) {
     foreach ($file in $dlls) {
-        & "$script:GZIP" --best -f $file
+        copy-item -Path $file -Destination ${script:distDir} -Force
     }
     }
 }
 }
 
 
@@ -169,123 +175,195 @@ function cleanup {
     }
     }
 }
 }
 
 
-init_vars
-git_module_setup
-apply_patches
 
 
 # -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
 # -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
 # -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
 # -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
 # -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
 # -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
 
 
-$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
-
-if ($null -eq ${env:OLLAMA_SKIP_CPU_GENERATE}) {
-
-# GCC build for direct linking into the Go binary
-init_vars
-# cmake will silently fallback to msvc compilers if mingw isn't in the path, so detect and fail fast
-# as we need this to be compiled by gcc for golang to be able to link with itx
-write-host "Checking for MinGW..."
-# error action ensures we exit on failure
-get-command gcc
-get-command mingw32-make
-$script:cmakeTargets = @("llama", "ggml")
-$script:cmakeDefs = @(
-    "-G", "MinGW Makefiles"
-    "-DCMAKE_C_COMPILER=gcc.exe",
-    "-DCMAKE_CXX_COMPILER=g++.exe",
-    "-DBUILD_SHARED_LIBS=off",
-    "-DLLAMA_NATIVE=off",
-    "-DLLAMA_AVX=off",
-    "-DLLAMA_AVX2=off",
-    "-DLLAMA_AVX512=off",
-    "-DLLAMA_F16C=off",
-    "-DLLAMA_FMA=off")
-$script:buildDir="../build/windows/${script:ARCH}_static"
-write-host "Building static library"
-build
 
 
-# remaining llama.cpp builds use MSVC 
-    init_vars
-    $script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
-    $script:buildDir="../build/windows/${script:ARCH}/cpu"
-    write-host "Building LCD CPU"
-    build
-    sign
-    compress
+function build_static() {
+    if ((-not "${env:OLLAMA_SKIP_STATIC_GENERATE}") -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "static"))) {
+        # GCC build for direct linking into the Go binary
+        init_vars
+        # cmake will silently fallback to msvc compilers if mingw isn't in the path, so detect and fail fast
+        # as we need this to be compiled by gcc for golang to be able to link with it
+        write-host "Checking for MinGW..."
+        # error action ensures we exit on failure
+        get-command gcc
+        get-command mingw32-make
+        $oldTargets = $script:cmakeTargets
+        $script:cmakeTargets = @("llama", "ggml")
+        $script:cmakeDefs = @(
+            "-G", "MinGW Makefiles"
+            "-DCMAKE_C_COMPILER=gcc.exe",
+            "-DCMAKE_CXX_COMPILER=g++.exe",
+            "-DBUILD_SHARED_LIBS=off",
+            "-DLLAMA_NATIVE=off",
+            "-DLLAMA_AVX=off",
+            "-DLLAMA_AVX2=off",
+            "-DLLAMA_AVX512=off",
+            "-DLLAMA_F16C=off",
+            "-DLLAMA_FMA=off")
+        $script:buildDir="../build/windows/${script:ARCH}_static"
+        write-host "Building static library"
+        build
+        $script:cmakeTargets = $oldTargets
+    } else {
+        write-host "Skipping CPU generation step as requested"
+    }
+}
 
 
-    init_vars
-    $script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
-    $script:buildDir="../build/windows/${script:ARCH}/cpu_avx"
-    write-host "Building AVX CPU"
-    build
-    sign
-    compress
+function build_cpu($gen_arch) {
+    if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu"))) {
+        # remaining llama.cpp builds use MSVC 
+        init_vars
+        $script:cmakeDefs = $script:commonCpuDefs + @("-A", $gen_arch, "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
+        $script:buildDir="../build/windows/${script:ARCH}/cpu"
+        $script:distDir="$script:DIST_BASE\cpu"
+        write-host "Building LCD CPU"
+        build
+        sign
+        install
+    } else {
+        write-host "Skipping CPU generation step as requested"
+    }
+}
 
 
-    init_vars
-    $script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
-    $script:buildDir="../build/windows/${script:ARCH}/cpu_avx2"
-    write-host "Building AVX2 CPU"
-    build
-    sign
-    compress
-} else {
-    write-host "Skipping CPU generation step as requested"
+function build_cpu_avx() {
+    if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx"))) {
+        init_vars
+        $script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
+        $script:buildDir="../build/windows/${script:ARCH}/cpu_avx"
+        $script:distDir="$script:DIST_BASE\cpu_avx"
+        write-host "Building AVX CPU"
+        build
+        sign
+        install
+    } else {
+        write-host "Skipping CPU AVX generation step as requested"
+    }
 }
 }
 
 
-if ($null -ne $script:CUDA_LIB_DIR) {
-    # Then build cuda as a dynamically loaded library
-    $nvcc = "$script:CUDA_LIB_DIR\nvcc.exe"
-    $script:CUDA_VERSION=(get-item ($nvcc | split-path | split-path)).Basename
-    if ($null -ne $script:CUDA_VERSION) {
-        $script:CUDA_VARIANT="_"+$script:CUDA_VERSION
+function build_cpu_avx2() {
+    if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx2"))) {
+        init_vars
+        $script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
+        $script:buildDir="../build/windows/${script:ARCH}/cpu_avx2"
+        $script:distDir="$script:DIST_BASE\cpu_avx2"
+        write-host "Building AVX2 CPU"
+        build
+        sign
+        install
+    } else {
+        write-host "Skipping CPU AVX2 generation step as requested"
     }
     }
-    init_vars
-    $script:buildDir="../build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
-    $script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUDA=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
-    build
-    sign
-    compress
 }
 }
 
 
-if ($null -ne $env:HIP_PATH) {
-    $script:ROCM_VERSION=(get-item $env:HIP_PATH).Basename
-    if ($null -ne $script:ROCM_VERSION) {
-        $script:ROCM_VARIANT="_v"+$script:ROCM_VERSION
+function build_cuda() {
+    if ((-not "${env:OLLAMA_SKIP_CUDA_GENERATE}") -and ("${script:CUDA_LIB_DIR}")) {
+        # Then build cuda as a dynamically loaded library
+        $nvcc = "$script:CUDA_LIB_DIR\nvcc.exe"
+        $script:CUDA_VERSION=(get-item ($nvcc | split-path | split-path)).Basename
+        if ($null -ne $script:CUDA_VERSION) {
+            $script:CUDA_VARIANT="_"+$script:CUDA_VERSION
+        }
+        init_vars
+        $script:buildDir="../build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
+        $script:distDir="$script:DIST_BASE\cuda$script:CUDA_VARIANT"
+        $script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUDA=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
+        if ($null -ne $env:OLLAMA_CUSTOM_CUDA_DEFS) {
+            write-host "OLLAMA_CUSTOM_CUDA_DEFS=`"${env:OLLAMA_CUSTOM_CUDA_DEFS}`""
+            $script:cmakeDefs +=@("${env:OLLAMA_CUSTOM_CUDA_DEFS}")
+            write-host "building custom CUDA GPU"
+        }
+        build
+        sign
+        install
+
+        write-host "copying CUDA dependencies to ${script:SRC_DIR}\dist\windows-${script:ARCH}\"
+        cp "${script:CUDA_LIB_DIR}\cudart64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
+        cp "${script:CUDA_LIB_DIR}\cublas64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
+        cp "${script:CUDA_LIB_DIR}\cublasLt64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
+    } else {
+        write-host "Skipping CUDA generation step"
     }
     }
+}
 
 
-    init_vars
-    $script:buildDir="../build/windows/${script:ARCH}/rocm$script:ROCM_VARIANT"
-    $script:cmakeDefs += @(
-        "-G", "Ninja", 
-        "-DCMAKE_C_COMPILER=clang.exe",
-        "-DCMAKE_CXX_COMPILER=clang++.exe",
-        "-DLLAMA_HIPBLAS=on",
-        "-DHIP_PLATFORM=amd",
-        "-DLLAMA_AVX=on",
-        "-DLLAMA_AVX2=off",
-        "-DCMAKE_POSITION_INDEPENDENT_CODE=on",
-        "-DAMDGPU_TARGETS=$(amdGPUs)",
-        "-DGPU_TARGETS=$(amdGPUs)"
-        )
+function build_rocm() {
+    if ((-not "${env:OLLAMA_SKIP_ROCM_GENERATE}") -and ("${env:HIP_PATH}")) {
+        $script:ROCM_VERSION=(get-item $env:HIP_PATH).Basename
+        if ($null -ne $script:ROCM_VERSION) {
+            $script:ROCM_VARIANT="_v"+$script:ROCM_VERSION
+        }
+
+        init_vars
+        $script:buildDir="../build/windows/${script:ARCH}/rocm$script:ROCM_VARIANT"
+        $script:distDir="$script:DIST_BASE\rocm$script:ROCM_VARIANT"
+        $script:cmakeDefs += @(
+            "-G", "Ninja", 
+            "-DCMAKE_C_COMPILER=clang.exe",
+            "-DCMAKE_CXX_COMPILER=clang++.exe",
+            "-DLLAMA_HIPBLAS=on",
+            "-DHIP_PLATFORM=amd",
+            "-DLLAMA_AVX=on",
+            "-DLLAMA_AVX2=off",
+            "-DCMAKE_POSITION_INDEPENDENT_CODE=on",
+            "-DAMDGPU_TARGETS=$(amdGPUs)",
+            "-DGPU_TARGETS=$(amdGPUs)"
+            )
 
 
-    # Make sure the ROCm binary dir is first in the path
-    $env:PATH="$env:HIP_PATH\bin;$env:PATH"
+        # Make sure the ROCm binary dir is first in the path
+        $env:PATH="$env:HIP_PATH\bin;$env:PATH"
 
 
-    # We have to clobber the LIB var from the developer shell for clang to work properly
-    $env:LIB=""
+        # We have to clobber the LIB var from the developer shell for clang to work properly
+        $env:LIB=""
+        if ($null -ne $env:OLLAMA_CUSTOM_ROCM_DEFS) {
+            write-host "OLLAMA_CUSTOM_ROCM_DEFS=`"${env:OLLAMA_CUSTOM_ROCM_DEFS}`""
+            $script:cmakeDefs += @("${env:OLLAMA_CUSTOM_ROCM_DEFS}")
+            write-host "building custom ROCM GPU"
+        }
+        write-host "Building ROCm"
+        build
+        # Ninja doesn't prefix with config name
+        ${script:config}=""
+        if ($null -ne $script:DUMPBIN) {
+            & "$script:DUMPBIN" /dependents "${script:buildDir}/bin/ollama_llama_server.exe" | select-string ".dll"
+        }
+        sign
+        install
 
 
-    write-host "Building ROCm"
-    build
-    # Ninja doesn't prefix with config name
-    ${script:config}=""
-    if ($null -ne $script:DUMPBIN) {
-        & "$script:DUMPBIN" /dependents "${script:buildDir}/bin/ollama_llama_server.exe" | select-string ".dll"
+        # Assumes v5.7, may need adjustments for v6
+        rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
+        md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null
+        cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
+        cp "${env:HIP_PATH}\bin\rocblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
+        # amdhip64.dll dependency comes from the driver and must be installed on the host to use AMD GPUs
+        cp "${env:HIP_PATH}\bin\rocblas\library\*" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\"
+    } else {
+        write-host "Skipping ROCm generation step"
     }
     }
-    sign
-    compress
 }
 }
 
 
+init_vars
+if ($($args.count) -eq 0) {
+    git_module_setup
+    apply_patches
+    build_static
+    if ($script:ARCH -eq "arm64") {
+        build_cpu("ARM64")
+    } else { # amd64
+        build_cpu("x64")
+        build_cpu_avx
+        build_cpu_avx2
+        build_cuda
+        build_rocm
+    }
 
 
-cleanup
-write-host "`ngo generate completed.  LLM runners: $(get-childitem -path ${script:SRC_DIR}\llm\build\windows\${script:ARCH})"
+    cleanup
+    write-host "`ngo generate completed.  LLM runners: $(get-childitem -path $script:DIST_BASE)"
+} else {
+    for ( $i = 0; $i -lt $args.count; $i++ ) {
+        write-host "performing $($args[$i])"
+        & $($args[$i])
+    } 
+}

+ 28 - 79
llm/ggml.go

@@ -13,82 +13,6 @@ type GGML struct {
 	model
 	model
 }
 }
 
 
-const (
-	fileTypeF32 uint32 = iota
-	fileTypeF16
-	fileTypeQ4_0
-	fileTypeQ4_1
-	fileTypeQ4_1_F16
-	fileTypeQ8_0 uint32 = iota + 2
-	fileTypeQ5_0
-	fileTypeQ5_1
-	fileTypeQ2_K
-	fileTypeQ3_K_S
-	fileTypeQ3_K_M
-	fileTypeQ3_K_L
-	fileTypeQ4_K_S
-	fileTypeQ4_K_M
-	fileTypeQ5_K_S
-	fileTypeQ5_K_M
-	fileTypeQ6_K
-	fileTypeIQ2_XXS
-	fileTypeIQ2_XS
-	fileTypeQ2_K_S
-	fileTypeQ3_K_XS
-	fileTypeIQ3_XXS
-)
-
-func fileType(fileType uint32) string {
-	switch fileType {
-	case fileTypeF32:
-		return "F32"
-	case fileTypeF16:
-		return "F16"
-	case fileTypeQ4_0:
-		return "Q4_0"
-	case fileTypeQ4_1:
-		return "Q4_1"
-	case fileTypeQ4_1_F16:
-		return "Q4_1_F16"
-	case fileTypeQ8_0:
-		return "Q8_0"
-	case fileTypeQ5_0:
-		return "Q5_0"
-	case fileTypeQ5_1:
-		return "Q5_1"
-	case fileTypeQ2_K:
-		return "Q2_K"
-	case fileTypeQ3_K_S:
-		return "Q3_K_S"
-	case fileTypeQ3_K_M:
-		return "Q3_K_M"
-	case fileTypeQ3_K_L:
-		return "Q3_K_L"
-	case fileTypeQ4_K_S:
-		return "Q4_K_S"
-	case fileTypeQ4_K_M:
-		return "Q4_K_M"
-	case fileTypeQ5_K_S:
-		return "Q5_K_S"
-	case fileTypeQ5_K_M:
-		return "Q5_K_M"
-	case fileTypeQ6_K:
-		return "Q6_K"
-	case fileTypeIQ2_XXS:
-		return "IQ2_XXS"
-	case fileTypeIQ2_XS:
-		return "IQ2_XS"
-	case fileTypeQ2_K_S:
-		return "Q2_K_S"
-	case fileTypeQ3_K_XS:
-		return "Q3_K_XS"
-	case fileTypeIQ3_XXS:
-		return "IQ3_XXS"
-	default:
-		return "unknown"
-	}
-}
-
 type model interface {
 type model interface {
 	KV() KV
 	KV() KV
 	Tensors() Tensors
 	Tensors() Tensors
@@ -121,12 +45,12 @@ func (kv KV) ParameterCount() uint64 {
 	return kv.u64("general.parameter_count")
 	return kv.u64("general.parameter_count")
 }
 }
 
 
-func (kv KV) FileType() string {
+func (kv KV) FileType() fileType {
 	if u64 := kv.u64("general.file_type"); u64 > 0 {
 	if u64 := kv.u64("general.file_type"); u64 > 0 {
 		return fileType(uint32(u64))
 		return fileType(uint32(u64))
 	}
 	}
 
 
-	return "unknown"
+	return fileTypeUnknown
 }
 }
 
 
 func (kv KV) BlockCount() uint64 {
 func (kv KV) BlockCount() uint64 {
@@ -286,6 +210,23 @@ const (
 
 
 var ErrUnsupportedFormat = errors.New("unsupported model format")
 var ErrUnsupportedFormat = errors.New("unsupported model format")
 
 
+func DetectGGMLType(b []byte) string {
+	switch binary.LittleEndian.Uint32(b[:4]) {
+	case FILE_MAGIC_GGML:
+		return "ggml"
+	case FILE_MAGIC_GGMF:
+		return "ggmf"
+	case FILE_MAGIC_GGJT:
+		return "ggjt"
+	case FILE_MAGIC_GGLA:
+		return "ggla"
+	case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE:
+		return "gguf"
+	default:
+		return ""
+	}
+}
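
Since DetectGGMLType only inspects the first four bytes, callers can sniff the container format without decoding anything. A minimal sketch (the path, program structure, and error handling are placeholders, not code from this PR):

package main

import (
	"fmt"
	"io"
	"log"
	"os"

	"github.com/ollama/ollama/llm"
)

func main() {
	f, err := os.Open("model.gguf") // illustrative path
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// Read just the 4-byte magic that DetectGGMLType switches on.
	magic := make([]byte, 4)
	if _, err := io.ReadFull(f, magic); err != nil {
		log.Fatal(err)
	}
	if kind := llm.DetectGGMLType(magic); kind != "" {
		fmt.Println("container format:", kind) // e.g. "gguf"
	} else {
		fmt.Println("not a GGML-family file")
	}
}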
+
 func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
 	var magic uint32
 	if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
@@ -343,7 +284,15 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
 		)
 		)
 
 
-		if ffnGateWeight, ok := layers["0"]["ffn_gate.0.weight"]; ok {
+		if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
+			// mixtral 8x22b
+			ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
+			partialOffload = max(
+				3*ffnGateExpsWeight.size()+4*batch*(2*ff+headsKV+embedding+context+embedding/heads*headsKV),
+				4*(context*batch*heads+context*embedding/heads*headsKV+batch*1024+embedding/heads*headsKV*batch),
+			)
+		} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
+			// mixtral 8x7b
 			ffnGateWeight1 := ffnGateWeight.Shape[1]
 			ffnGateWeight1 := ffnGateWeight.Shape[1]
 			fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
 			fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
 			partialOffload = max(
 			partialOffload = max(

+ 10 - 6
llm/gguf.go

@@ -190,8 +190,6 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
 		llm.kv[k] = v
 	}
 
-	slog.Debug(fmt.Sprintf("general.architecture = %s", llm.kv["general.architecture"]))
-
 	// decode tensors
 	for i := 0; uint64(i) < llm.numTensor(); i++ {
 		name, err := readGGUFString(llm, rs)
@@ -465,11 +463,13 @@ var ggufKVOrder = map[string][]string{
 		"llama.embedding_length",
 		"llama.embedding_length",
 		"llama.block_count",
 		"llama.block_count",
 		"llama.feed_forward_length",
 		"llama.feed_forward_length",
-		"llama.rope.dimension_count",
 		"llama.attention.head_count",
 		"llama.attention.head_count",
 		"llama.attention.head_count_kv",
 		"llama.attention.head_count_kv",
 		"llama.attention.layer_norm_rms_epsilon",
 		"llama.attention.layer_norm_rms_epsilon",
 		"llama.rope.freq_base",
 		"llama.rope.freq_base",
+		"llama.rope.dimension_count",
+		"llama.expert_count",
+		"llama.expert_used_count",
 		"gemma.context_length",
 		"gemma.context_length",
 		"gemma.embedding_length",
 		"gemma.embedding_length",
 		"gemma.block_count",
 		"gemma.block_count",
@@ -577,6 +577,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
 					return err
 				}
 			}
+		default:
+			return fmt.Errorf("improper type for '%s'", k)
 		}
 		if err != nil {
 			return err
@@ -598,9 +600,11 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
 			return err
 		}
 
-		dims := 1
-		if tensor.Shape[1] > 0 {
-			dims = 2
+		dims := 0
+		for cnt := 0; cnt < len(tensor.Shape); cnt++ {
+			if tensor.Shape[cnt] > 0 {
+				dims++
+			}
 		}
 
 		if err := binary.Write(ws, llm.ByteOrder, uint32(dims)); err != nil {

+ 1 - 1
llm/llama.cpp

@@ -1 +1 @@
-Subproject commit 7593639ce335e8d7f89aa9a54d616951f273af60
+Subproject commit 952d03dbead16e4dbdd1d3458486340673cc2465

+ 5 - 52
llm/llm.go

@@ -4,6 +4,7 @@ package llm
 // #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/build/darwin/arm64_static/libllama.a -lstdc++
 // #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/build/darwin/x86_64_static/libllama.a -lstdc++
 // #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++
+// #cgo windows,arm64 LDFLAGS: ${SRCDIR}/build/windows/arm64_static/libllama.a -static -lstdc++
 // #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++
 // #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++
 // #include <stdlib.h>
@@ -19,7 +20,7 @@ func SystemInfo() string {
 	return C.GoString(C.llama_print_system_info())
 }
 
-func Quantize(infile, outfile, filetype string) error {
+func Quantize(infile, outfile string, ftype fileType) error {
 	cinfile := C.CString(infile)
 	defer C.free(unsafe.Pointer(cinfile))
 
@@ -28,58 +29,10 @@ func Quantize(infile, outfile, filetype string) error {
 
 	params := C.llama_model_quantize_default_params()
 	params.nthread = -1
+	params.ftype = ftype.Value()
 
-	switch filetype {
-	case "F32":
-		params.ftype = fileTypeF32
-	case "F16":
-		params.ftype = fileTypeF16
-	case "Q4_0":
-		params.ftype = fileTypeQ4_0
-	case "Q4_1":
-		params.ftype = fileTypeQ4_1
-	case "Q4_1_F16":
-		params.ftype = fileTypeQ4_1_F16
-	case "Q8_0":
-		params.ftype = fileTypeQ8_0
-	case "Q5_0":
-		params.ftype = fileTypeQ5_0
-	case "Q5_1":
-		params.ftype = fileTypeQ5_1
-	case "Q2_K":
-		params.ftype = fileTypeQ2_K
-	case "Q3_K_S":
-		params.ftype = fileTypeQ3_K_S
-	case "Q3_K_M":
-		params.ftype = fileTypeQ3_K_M
-	case "Q3_K_L":
-		params.ftype = fileTypeQ3_K_L
-	case "Q4_K_S":
-		params.ftype = fileTypeQ4_K_S
-	case "Q4_K_M":
-		params.ftype = fileTypeQ4_K_M
-	case "Q5_K_S":
-		params.ftype = fileTypeQ5_K_S
-	case "Q5_K_M":
-		params.ftype = fileTypeQ5_K_M
-	case "Q6_K":
-		params.ftype = fileTypeQ6_K
-	case "IQ2_XXS":
-		params.ftype = fileTypeIQ2_XXS
-	case "IQ2_XS":
-		params.ftype = fileTypeIQ2_XS
-	case "Q2_K_S":
-		params.ftype = fileTypeQ2_K_S
-	case "Q3_K_XS":
-		params.ftype = fileTypeQ3_K_XS
-	case "IQ3_XXS":
-		params.ftype = fileTypeIQ3_XXS
-	default:
-		return fmt.Errorf("unknown filetype: %s", filetype)
-	}
-
-	if retval := C.llama_model_quantize(cinfile, coutfile, &params); retval != 0 {
-		return fmt.Errorf("llama_model_quantize: %d", retval)
+	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
+		return fmt.Errorf("llama_model_quantize: %d", rc)
 	}
 
 	return nil
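
With the per-name switch removed, callers are expected to parse the quantization name once and hand the typed value straight to Quantize. A sketch of the call pattern (the helper name and paths are invented; written as if inside package llm, since fileType is unexported):

// quantizeToQ4KM is a hypothetical helper, not part of this diff.
func quantizeToQ4KM(infile, outfile string) error {
	ft, err := ParseFileType("Q4_K_M")
	if err != nil {
		return err
	}
	return Quantize(infile, outfile, ft)
}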

+ 1 - 1
llm/llm_windows.go

@@ -2,5 +2,5 @@ package llm
 
 import "embed"
 
-//go:embed build/windows/*/*/bin/*
+// unused on windows
 var libEmbed embed.FS

+ 185 - 0
llm/memory.go

@@ -0,0 +1,185 @@
+package llm
+
+import (
+	"fmt"
+	"log/slog"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/format"
+	"github.com/ollama/ollama/gpu"
+	"github.com/ollama/ollama/server/envconfig"
+)
+
+// This algorithm looks for a complete fit to determine if we need to unload other models
+func PredictServerFit(allGpus gpu.GpuInfoList, ggml *GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
+	var estimatedVRAM uint64
+	if opts.NumCtx > int(ggml.KV().ContextLength()) {
+		slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
+		opts.NumCtx = int(ggml.KV().ContextLength())
+	}
+
+	if opts.NumCtx < 4 {
+		opts.NumCtx = 4
+	}
+
+	// Split up the GPUs by type and try them
+	for _, gpus := range allGpus.ByLibrary() {
+		var layerCount int
+		layerCount, estimatedVRAM, _ = EstimateGPULayers(gpus, ggml, projectors, opts)
+		if opts.NumGPU < 0 {
+			if layerCount > 0 && layerCount >= int(ggml.KV().BlockCount()+1) {
+				return true, estimatedVRAM
+			}
+		} else {
+			if layerCount > 0 && layerCount >= opts.NumGPU {
+				return true, estimatedVRAM
+			}
+		}
+	}
+	return false, estimatedVRAM
+}
+
+// Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
+// The GPUs provided must all be the same Library
+func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts api.Options) (int, uint64, uint64) {
+	var memoryAvailable uint64
+	for _, info := range gpus {
+		memoryAvailable += info.FreeMemory
+	}
+	if envconfig.MaxVRAM > 0 {
+		memoryAvailable = envconfig.MaxVRAM
+	}
+
+	slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", format.HumanBytes2(memoryAvailable))
+
+	// TODO - this is probably wrong, first GPU vs secondaries will have different overheads
+	memoryMinimum := gpus[0].MinimumMemory
+
+	for _, projector := range projectors {
+		memoryMinimum += projectorMemoryRequirements(projector)
+
+		// multimodal models require at least 2048 context
+		opts.NumCtx = max(opts.NumCtx, 2048)
+	}
+
+	// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
+	var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
+
+	graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
+	if graphPartialOffload == 0 {
+		graphPartialOffload = ggml.KV().GQA() * kv / 6
+	}
+
+	if graphFullOffload == 0 {
+		graphFullOffload = graphPartialOffload
+	}
+
+	graphFullOffload *= uint64(len(gpus))
+	graphPartialOffload *= uint64(len(gpus))
+
+	// on metal there's no partial offload overhead
+	if gpus[0].Library == "metal" {
+		graphPartialOffload = graphFullOffload
+	}
+
+	layers := ggml.Tensors().Layers()
+
+	// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
+	memoryRequiredTotal := memoryMinimum + graphFullOffload + layers["blk.0"].size()
+
+	// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
+	memoryRequiredPartial := memoryMinimum + graphPartialOffload + layers["blk.0"].size()
+
+	var memoryLayerOutput uint64
+	if layer, ok := layers["output_norm"]; ok {
+		memoryLayerOutput += layer.size()
+	}
+
+	if layer, ok := layers["output"]; ok {
+		memoryLayerOutput += layer.size()
+	} else if layer, ok := layers["token_embd"]; ok {
+		memoryLayerOutput += layer.size()
+	}
+
+	if gpus[0].Library == "metal" && opts.UseMMap {
+		// memory is preallocated for output tensors
+		memoryRequiredTotal += memoryLayerOutput
+		memoryRequiredPartial += memoryLayerOutput
+	}
+
+	var layerCount int
+	for i := 0; i < int(ggml.KV().BlockCount()); i++ {
+		memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
+
+		// KV is proportional to the number of layers
+		memoryLayer += kv / ggml.KV().BlockCount()
+
+		memoryRequiredTotal += memoryLayer
+		if memoryAvailable > memoryRequiredPartial+memoryLayer {
+			memoryRequiredPartial += memoryLayer
+			layerCount++
+		}
+	}
+
+	if gpus[0].Library != "metal" || !opts.UseMMap {
+		// memory was not preallocated for output tensors
+		memoryRequiredTotal += memoryLayerOutput
+	}
+
+	if memoryAvailable > memoryRequiredTotal {
+		layerCount = int(ggml.KV().BlockCount()) + 1
+		memoryRequiredPartial = memoryRequiredTotal
+	}
+
+	memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
+
+	slog.Info(
+		"offload to gpu",
+		slog.Group(
+			"layers",
+			// actual number of layers offloaded
+			"real", opts.NumGPU,
+			// estimated number of layers that can be offloaded
+			"estimate", layerCount,
+		),
+		slog.Group(
+			"memory",
+			// memory available for offloading
+			"available", format.HumanBytes2(memoryAvailable),
+			slog.Group(
+				"required",
+				// memory required for full offloading
+				"full", format.HumanBytes2(memoryRequiredTotal),
+				// memory required to offload layers.estimate layers
+				"partial", format.HumanBytes2(memoryRequiredPartial),
+				// memory of KV cache
+				"kv", format.HumanBytes2(kv),
+			),
+			slog.Group(
+				"weights",
+				// memory of the weights
+				"total", format.HumanBytes2(memoryWeights),
+				// memory of repeating layers
+				"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
+				// memory of non-repeating layers
+				"nonrepeating", format.HumanBytes2(memoryLayerOutput),
+			),
+			slog.Group(
+				"graph",
+				// memory of graph when fully offloaded
+				"full", format.HumanBytes2(graphFullOffload),
+				// memory of graph when not fully offloaded
+				"partial", format.HumanBytes2(graphPartialOffload),
+			),
+		),
+	)
+	if gpus[0].Library == "cpu" {
+		return 0, 0, memoryRequiredTotal
+	}
+	if memoryRequiredPartial > memoryAvailable {
+		slog.Debug("insufficient VRAM to load any model layers")
+		return 0, 0, memoryRequiredTotal
+	}
+
+	return layerCount, memoryRequiredPartial, memoryRequiredTotal
+}
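
To make the fp16 KV-cache expression above concrete, here is a back-of-the-envelope check with 7B-class hyperparameters (the shape numbers are assumptions for illustration, not read from any model in this PR):

package main

import "fmt"

func main() {
	// Assumed llama-7B-like shape: 2048-token context, 32 layers, 4096 embedding, 32 heads, 32 KV heads.
	var (
		numCtx     uint64 = 2048
		blockCount uint64 = 32
		embedding  uint64 = 4096
		heads      uint64 = 32
		headsKV    uint64 = 32
	)
	// Mirrors the EstimateGPULayers formula: 2 tensors (k and v) * 2 bytes per fp16 element.
	kv := 2 * 2 * numCtx * blockCount * embedding / heads * headsKV
	fmt.Printf("%d bytes (~%d MiB)\n", kv, kv>>20) // 1073741824 bytes (~1024 MiB)
}

Spread over the 32 layers, that is the per-layer kv/BlockCount term the loop above adds to each block when deciding how many layers fit in the available VRAM.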

+ 12 - 0
llm/patches/02-clip-log.diff

@@ -0,0 +1,12 @@
+diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
+index e431c7f7..f077e688 100644
+--- a/examples/llava/clip.cpp
++++ b/examples/llava/clip.cpp
+@@ -3,6 +3,7 @@
+ // I'll gradually clean and extend it
+ // Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
+ #include "clip.h"
++#include "common.h"
+ #include "log.h"
+ #include "ggml.h"
+ #include "ggml-alloc.h"

+ 45 - 0
llm/patches/04-metal.diff

@@ -0,0 +1,45 @@
+diff --git a/ggml-metal.m b/ggml-metal.m
+index 0207b787..b5e9884b 100644
+--- a/ggml-metal.m
++++ b/ggml-metal.m
+@@ -1396,27 +1396,23 @@ static enum ggml_status ggml_metal_graph_compute(
+                         // to the matrix-vector kernel
+                         int ne11_mm_min = 1;
+ 
+-#if 0
+                         // the numbers below are measured on M2 Ultra for 7B and 13B models
+                         // these numbers do not translate to other devices or model sizes
+                         // TODO: need to find a better approach
+-                        if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
+-                            switch (src0t) {
+-                                case GGML_TYPE_F16:  ne11_mm_min = 2;  break;
+-                                case GGML_TYPE_Q8_0: ne11_mm_min = 7;  break;
+-                                case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
+-                                case GGML_TYPE_Q3_K: ne11_mm_min = 7;  break;
+-                                case GGML_TYPE_Q4_0:
+-                                case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
+-                                case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
+-                                case GGML_TYPE_Q5_0:                          // not tested yet
+-                                case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
+-                                case GGML_TYPE_Q5_K: ne11_mm_min = 7;  break;
+-                                case GGML_TYPE_Q6_K: ne11_mm_min = 7;  break;
+-                                default:             ne11_mm_min = 1;  break;
+-                            }
++                        switch (src0t) {
++                            case GGML_TYPE_F16:  ne11_mm_min = 2;  break;
++                            case GGML_TYPE_Q8_0: ne11_mm_min = 7;  break;
++                            case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
++                            case GGML_TYPE_Q3_K: ne11_mm_min = 7;  break;
++                            case GGML_TYPE_Q4_0:
++                            case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
++                            case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
++                            case GGML_TYPE_Q5_0:                          // not tested yet
++                            case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
++                            case GGML_TYPE_Q5_K: ne11_mm_min = 7;  break;
++                            case GGML_TYPE_Q6_K: ne11_mm_min = 7;  break;
++                            default:             ne11_mm_min = 1;  break;
+                         }
+-#endif
+ 
+                         // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+                         // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel

+ 24 - 0
llm/patches/05-clip-fix.diff

@@ -0,0 +1,24 @@
+diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
+index e3c9bcd4..b43f892d 100644
+--- a/examples/llava/clip.cpp
++++ b/examples/llava/clip.cpp
+@@ -573,14 +573,16 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
+     struct ggml_tensor * embeddings = inp;
+     if (ctx->has_class_embedding) {
+         embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
++    }
++    ggml_set_name(embeddings, "embeddings");
++    ggml_set_input(embeddings);
++
++    if (ctx->has_class_embedding) {
+         embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
+                 embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
+         embeddings = ggml_acc(ctx0, embeddings, inp,
+                 embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
+     }
+-    ggml_set_name(embeddings, "embeddings");
+-    ggml_set_input(embeddings);
+-
+ 
+     struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
+     ggml_set_name(positions, "positions");

+ 27 - 7
llm/payload.go

@@ -9,6 +9,7 @@ import (
 	"log/slog"
 	"log/slog"
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
+	"runtime"
 	"strings"
 	"strings"
 
 
 	"golang.org/x/exp/slices"
 	"golang.org/x/exp/slices"
@@ -17,7 +18,7 @@ import (
 	"github.com/ollama/ollama/gpu"
 	"github.com/ollama/ollama/gpu"
 )
 )
 
 
-var errPayloadMissing = fmt.Errorf("expected payloads not included in this build of ollama")
+var errPayloadMissing = errors.New("expected payloads not included in this build of ollama")
 
 
 func Init() error {
 func Init() error {
 	payloadsDir, err := gpu.PayloadsDir()
 	payloadsDir, err := gpu.PayloadsDir()
@@ -25,13 +26,15 @@ func Init() error {
 		return err
 		return err
 	}
 	}
 
 
-	slog.Info("extracting embedded files", "dir", payloadsDir)
-	binGlob := "build/*/*/*/bin/*"
+	if runtime.GOOS != "windows" {
+		slog.Info("extracting embedded files", "dir", payloadsDir)
+		binGlob := "build/*/*/*/bin/*"
 
 
-	// extract server libraries
-	err = extractFiles(payloadsDir, binGlob)
-	if err != nil {
-		return fmt.Errorf("extract binaries: %v", err)
+		// extract server libraries
+		err = extractFiles(payloadsDir, binGlob)
+		if err != nil {
+			return fmt.Errorf("extract binaries: %v", err)
+		}
 	}
 	}
 
 
 	var variants []string
 	var variants []string
@@ -138,6 +141,23 @@ func serversForGpu(info gpu.GpuInfo) []string {
 	return servers
 	return servers
 }
 }
 
 
+// Return the optimal server for this CPU architecture
+func serverForCpu() string {
+	if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
+		return "metal"
+	}
+	variant := gpu.GetCPUVariant()
+	availableServers := availableServers()
+	if variant != "" {
+		for cmp := range availableServers {
+			if cmp == "cpu_"+variant {
+				return cmp
+			}
+		}
+	}
+	return "cpu"
+}
+
 // extract extracts the embedded files to the target directory
 // extract extracts the embedded files to the target directory
 func extractFiles(targetDir string, glob string) error {
 func extractFiles(targetDir string, glob string) error {
 	files, err := fs.Glob(libEmbed, glob)
 	files, err := fs.Glob(libEmbed, glob)

+ 352 - 268
llm/server.go

@@ -21,21 +21,47 @@ import (
 	"strings"
 	"strings"
 	"time"
 	"time"
 
 
+	"golang.org/x/sync/semaphore"
+
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/gpu"
 	"github.com/ollama/ollama/gpu"
+	"github.com/ollama/ollama/server/envconfig"
 )
 )
 
 
-// LlamaServer is an instance of the llama.cpp server
-type LlamaServer struct {
+type LlamaServer interface {
+	Ping(ctx context.Context) error
+	WaitUntilRunning(ctx context.Context) error
+	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
+	Embedding(ctx context.Context, prompt string) ([]float64, error)
+	Tokenize(ctx context.Context, content string) ([]int, error)
+	Detokenize(ctx context.Context, tokens []int) (string, error)
+	Close() error
+	EstimatedVRAM() uint64
+}
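
The concrete runner is now hidden behind this interface, so callers depend only on the contract above. A hedged sketch of a consumer (the helper and package names are invented; it only touches methods declared on the interface):

package sketch // hypothetical consumer, not part of this diff

import (
	"context"
	"fmt"
	"log/slog"

	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/llm"
)

// readyOrFail is illustrative only: verify the runner responds, wait for it to
// finish loading, then log its estimated VRAM footprint.
func readyOrFail(ctx context.Context, s llm.LlamaServer) error {
	if err := s.Ping(ctx); err != nil {
		return fmt.Errorf("runner not responding: %w", err)
	}
	if err := s.WaitUntilRunning(ctx); err != nil {
		return err
	}
	slog.Info("runner ready", "estimated_vram", format.HumanBytes2(s.EstimatedVRAM()))
	return nil
}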
+
+// llmServer is an instance of the llama.cpp server
+type llmServer struct {
 	port    int
 	port    int
 	cmd     *exec.Cmd
 	cmd     *exec.Cmd
 	done    chan error // Channel to signal when the process exits
 	done    chan error // Channel to signal when the process exits
 	status  *StatusWriter
 	status  *StatusWriter
 	options api.Options
 	options api.Options
+
+	// TODO - this should be broken down by GPU
+	estimatedVRAM  uint64 // Estimated usage of VRAM by the loaded model
+	estimatedTotal uint64 // Total size of model
+	totalLayers    uint64
+	gpuCount       int
+
+	sem *semaphore.Weighted
 }
 }
 
 
-func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) {
+func LoadModel(model string) (*GGML, error) {
+	if _, err := os.Stat(model); err != nil {
+		return nil, err
+	}
+
 	f, err := os.Open(model)
 	f, err := os.Open(model)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
@@ -43,144 +69,69 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 	defer f.Close()
 	defer f.Close()
 
 
 	ggml, _, err := DecodeGGML(f)
 	ggml, _, err := DecodeGGML(f)
-	if err != nil {
-		return nil, err
-	}
+	return ggml, err
+}
 
 
+// NewLlamaServer will run a server for the given GPUs
+// The gpu list must be a single family.
+func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
+	var err error
 	if opts.NumCtx > int(ggml.KV().ContextLength()) {
 	if opts.NumCtx > int(ggml.KV().ContextLength()) {
-		slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
-		opts.NumCtx = int(ggml.KV().ContextLength())
+		slog.Warn("requested context length is greater than the model's training context window size", "requested", opts.NumCtx, "training size", ggml.KV().ContextLength())
 	}
 	}
 
 
 	if opts.NumCtx < 4 {
 	if opts.NumCtx < 4 {
 		opts.NumCtx = 4
 		opts.NumCtx = 4
 	}
 	}
 
 
-	memoryAvailable, _ := gpu.CheckVRAM()
-	info := gpu.GetGPUInfo()
+	cpuRunner := ""
+	var estimatedVRAM uint64
+	var estimatedTotal uint64
+	var systemMemory uint64
+	gpuCount := len(gpus)
+	if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 {
 
 
-	memoryMinimum := info.MinimumMemory
-	for _, projector := range projectors {
-		memoryMinimum += projectorMemoryRequirements(projector)
-
-		// multimodal models require at least 2048 context
-		opts.NumCtx = max(opts.NumCtx, 2048)
-	}
+		// TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner

-	// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
-	var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
-
-	graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
-	if graphPartialOffload == 0 {
-		graphPartialOffload = ggml.KV().GQA() * kv / 6
-	}
-
-	if graphFullOffload == 0 {
-		graphFullOffload = graphPartialOffload
-	}
-
-	graphFullOffload *= uint64(info.DeviceCount)
-	graphPartialOffload *= uint64(info.DeviceCount)
-
-	// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
-	memoryRequiredTotal := memoryMinimum + graphFullOffload
-
-	// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
-	memoryRequiredPartial := memoryMinimum + graphPartialOffload
-
-	if info.Library != "metal" {
-		if memoryRequiredPartial > memoryAvailable {
-			info.Library = "cpu"
-		}
-	}
-
-	var layerCount int
-	layers := ggml.Tensors().Layers()
-	for i := 0; i < int(ggml.KV().BlockCount()); i++ {
-		memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
-
-		// KV is proportional to the number of layers
-		memoryLayer += kv / ggml.KV().BlockCount()
-
-		memoryRequiredTotal += memoryLayer
-		if memoryAvailable > memoryRequiredPartial+memoryLayer {
-			memoryRequiredPartial += memoryLayer
-			layerCount++
+		cpuRunner = serverForCpu()
+		gpuCount = 0
+	} else {
+		if gpus[0].Library == "metal" {
+			memInfo, err := gpu.GetCPUMem()
+			if err != nil {
+				slog.Error("failed to lookup system memory", "error", err)
+			} else {
+				systemMemory = memInfo.TotalMemory
+				slog.Debug("system memory", "total", format.HumanBytes2(systemMemory))
+			}
 		}
-	}
-
-	var memoryLayerOutput uint64
-	for k, v := range layers {
-		if !strings.HasPrefix(k, "blk.") {
-			memoryLayerOutput += v.size()
+		var layers int
+		layers, estimatedVRAM, estimatedTotal = EstimateGPULayers(gpus, ggml, projectors, opts)
+
+		if gpus[0].Library == "metal" && estimatedVRAM > systemMemory {
+			// disable partial offloading when model is greater than total system memory as this
+			// can lead to locking up the system
+			opts.NumGPU = 0
+		} else if opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu" {
+			opts.NumGPU = layers
 		}
 	}

-	memoryRequiredTotal += memoryLayerOutput
-
-	if info.Library == "metal" && memoryRequiredTotal > info.TotalMemory {
-		// disable partial offloading when model is greater than total system memory
-		opts.NumGPU = 0
-	} else if memoryAvailable > memoryRequiredTotal {
-		layerCount = int(ggml.KV().BlockCount()) + 1
-		memoryRequiredPartial = memoryRequiredTotal
-	}
-
-	if opts.NumGPU < 0 {
-		opts.NumGPU = layerCount
-	}
-
-	memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
-
-	slog.Info(
-		"offload to gpu",
-		slog.Group(
-			"layers",
-			// actual number of layers offloaded
-			"real", opts.NumGPU,
-			// estimated number of layers that can be offloaded
-			"estimate", layerCount,
-		),
-		slog.Group(
-			"memory",
-			// memory available for offloading
-			"available", format.HumanBytes2(memoryAvailable),
-			slog.Group(
-				"required",
-				// memory required for full offloading
-				"full", format.HumanBytes2(memoryRequiredTotal),
-				// memory required to offload layers.estimate layers
-				"partial", format.HumanBytes2(memoryRequiredPartial),
-				// memory of KV cache
-				"kv", format.HumanBytes2(kv),
-			),
-			slog.Group(
-				"weights",
-				// memory of the weights
-				"total", format.HumanBytes2(memoryWeights),
-				// memory of repeating layers
-				"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
-				// memory of non-repeating layers
-				"nonrepeating", format.HumanBytes2(memoryLayerOutput),
-			),
-			slog.Group(
-				"graph",
-				// memory of graph when fully offloaded
-				"full", format.HumanBytes2(graphFullOffload),
-				// memory of graph when not fully offloaded
-				"partial", format.HumanBytes2(graphPartialOffload),
-			),
-		),
-	)
+	// Loop through potential servers
+	finalErr := fmt.Errorf("no suitable llama servers found")

 	if len(adapters) > 1 {
 		return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
 	}

 	availableServers := availableServers()
-	servers := serversForGpu(info)
-
-	demandLib := os.Getenv("OLLAMA_LLM_LIBRARY")
+	var servers []string
+	if cpuRunner != "" {
+		servers = []string{cpuRunner}
+	} else {
+		servers = serversForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant
+	}
+	demandLib := envconfig.LLMLibrary
 	if demandLib != "" {
 		serverPath := availableServers[demandLib]
 		if serverPath == "" {
@@ -188,11 +139,15 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 		} else {
 			slog.Info("user override", "OLLAMA_LLM_LIBRARY", demandLib, "path", serverPath)
 			servers = []string{demandLib}
+			if strings.HasPrefix(demandLib, "cpu") {
+				// Omit the GPU flag to silence the warning
+				opts.NumGPU = -1
+			}
 		}
 	}

 	if len(servers) == 0 {
-		return nil, fmt.Errorf("no servers found for %v", info)
+		return nil, fmt.Errorf("no servers found for %v", gpus)
 	}

 	params := []string{
@@ -201,7 +156,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 		"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
 		"--embedding",
 	}
-	if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
+	if envconfig.Debug {
 		params = append(params, "--log-format", "json")
 	} else {
 		params = append(params, "--log-disable")
@@ -211,7 +166,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 		params = append(params, "--n-gpu-layers", fmt.Sprintf("%d", opts.NumGPU))
 	}

-	if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
+	if envconfig.Debug {
 		params = append(params, "--verbose")
 	}

@@ -249,10 +204,30 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 		params = append(params, "--numa")
 	}

-	// Loop through potential servers
-	var finalErr error
+	numParallel := envconfig.NumParallel
+
+	// TODO (jmorganca): multimodal models don't support parallel yet
+	// see https://github.com/ollama/ollama/issues/4165
+	if len(projectors) > 0 {
+		numParallel = 1
+		slog.Warn("multimodal models don't support parallel requests yet")
+	}
+
+	params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
+
 	for i := 0; i < len(servers); i++ {
 		dir := availableServers[servers[i]]
+		if dir == "" {
+			// Shouldn't happen
+			finalErr = fmt.Errorf("[%d] server %s not listed in available servers %v", i, servers[i], availableServers)
+			slog.Error("server list inconsistent", "error", finalErr)
+			continue
+		}
+
+		if strings.HasPrefix(servers[i], "cpu") {
+			// TODO if we tried a gpu runner first, and it failed, record the error and bubble that back up
+			gpuCount = 0
+		}

 		// Find an available port, retrying on each iteration in case the failure was a port conflict race
 		port := 0
@@ -273,12 +248,21 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 		if runtime.GOOS == "windows" {
 			pathEnv = "PATH"
 		}
-		// append the server directory to LD_LIBRARY_PATH/PATH
+		// prepend the server directory to LD_LIBRARY_PATH/PATH
 		libraryPaths := []string{dir}
+
 		if libraryPath, ok := os.LookupEnv(pathEnv); ok {
 			// Append our runner directory to the path
 			// This will favor system libraries over our bundled library dependencies
-			libraryPaths = append(filepath.SplitList(libraryPath), libraryPaths...)
+			libraryPaths = append(libraryPaths, filepath.SplitList(libraryPath)...)
+		}
+
+		// Note: we always put the dependency path first
+		// since this was the exact version we verified for AMD GPUs
+		// and we favor what the user had in their path
+		if gpus[0].DependencyPath != "" {
+			// TODO refine for multi-gpu support
+			libraryPaths = append([]string{gpus[0].DependencyPath}, libraryPaths...)
 		}

 		server := filepath.Join(dir, "ollama_llama_server")
@@ -286,21 +270,66 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 			server = server + ".exe"
 		}

-		s := &LlamaServer{
-			port:    port,
-			cmd:     exec.Command(server, finalParams...),
-			status:  NewStatusWriter(os.Stderr),
-			options: opts,
+		// Detect tmp cleaners wiping out the file
+		_, err := os.Stat(server)
+		if errors.Is(err, os.ErrNotExist) {
+			slog.Warn("llama server disappeared, reinitializing payloads", "path", server, "error", err)
+			err = Init()
+			if err != nil {
+				slog.Warn("failed to reinitialize payloads", "error", err)
+				return nil, err
+			}
 		}
-		libEnv := fmt.Sprintf("%s=%s", pathEnv, strings.Join(libraryPaths, string(filepath.ListSeparator)))
-		slog.Debug(libEnv)
-		s.cmd.Env = append(os.Environ(), libEnv)
+
+		s := &llmServer{
+			port:           port,
+			cmd:            exec.Command(server, finalParams...),
+			status:         NewStatusWriter(os.Stderr),
+			options:        opts,
+			estimatedVRAM:  estimatedVRAM,
+			estimatedTotal: estimatedTotal,
+			sem:            semaphore.NewWeighted(int64(numParallel)),
+			totalLayers:    ggml.KV().BlockCount() + 1,
+			gpuCount:       gpuCount,
+		}
+
+		s.cmd.Env = os.Environ()
 		s.cmd.Stdout = os.Stdout
 		s.cmd.Stderr = s.status

+		visibleDevicesEnv, visibleDevicesEnvVal := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv()
+		pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
+
+		// Update or add the path and visible devices variable with our adjusted version
+		pathNeeded := true
+		devicesNeeded := visibleDevicesEnv != ""
+		for i := range s.cmd.Env {
+			cmp := strings.SplitN(s.cmd.Env[i], "=", 2)
+			if strings.EqualFold(cmp[0], pathEnv) {
+				s.cmd.Env[i] = pathEnv + "=" + pathEnvVal
+				pathNeeded = false
+			} else if devicesNeeded && strings.EqualFold(cmp[0], visibleDevicesEnv) {
+				s.cmd.Env[i] = visibleDevicesEnv + "=" + visibleDevicesEnvVal
+				devicesNeeded = false
+			}
+		}
+		if pathNeeded {
+			s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal)
+		}
+		if devicesNeeded {
+			s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
+		}
+
 		slog.Info("starting llama server", "cmd", s.cmd.String())
+		// Log at debug as the environment is inherited and might contain sensitive information
+		slog.Debug("subprocess", "environment", s.cmd.Env)

 		if err = s.cmd.Start(); err != nil {
+			// Detect permission denied and augment the message about noexec
+			if errors.Is(err, os.ErrPermission) {
+				finalErr = fmt.Errorf("unable to start server %w.  %s may have noexec set.  Set OLLAMA_TMPDIR for server to a writable executable directory", err, dir)
+				continue
+			}
 			msg := ""
 			if s.status != nil && s.status.LastErrMsg != "" {
 				msg = s.status.LastErrMsg
@@ -310,12 +339,6 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 			continue
 		}

-		// reap subprocess when it exits
-		go func() {
-			// Exit status managed via getServerStatus
-			_ = s.cmd.Wait()
-		}()
-
 		return s, nil
 	}

@@ -347,12 +370,27 @@ type ServerStatus int

 const ( // iota is reset to 0
 	ServerStatusReady ServerStatus = iota
-	ServerStatusNoSlotsAvaialble
+	ServerStatusNoSlotsAvailable
 	ServerStatusLoadingModel
 	ServerStatusNotResponding
 	ServerStatusError
 )

+func (s ServerStatus) ToString() string {
+	switch s {
+	case ServerStatusReady:
+		return "llm server ready"
+	case ServerStatusNoSlotsAvailable:
+		return "llm busy - no slots available"
+	case ServerStatusLoadingModel:
+		return "llm server loading model"
+	case ServerStatusNotResponding:
+		return "llm server not responding"
+	default:
+		return "llm server error"
+	}
+}
+
 type ServerStatusResp struct {
 	Status          string `json:"status"`
 	SlotsIdle       int    `json:"slots_idle"`
@@ -360,13 +398,17 @@ type ServerStatusResp struct {
 	Error           string `json:"error"`
 }

-func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
+func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 	// Fail fast if it's exited
 	if s.cmd.ProcessState != nil {
 		msg := ""
 		if s.status != nil && s.status.LastErrMsg != "" {
 			msg = s.status.LastErrMsg
 		}
+		if s.cmd.ProcessState.ExitCode() == -1 {
+			// Most likely a signal killed it, log some more details to try to help troubleshoot
+			slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String())
+		}
 		return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
 	}

@@ -399,7 +441,7 @@ func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error)
 	case "ok":
 		return ServerStatusReady, nil
 	case "no slot available":
-		return ServerStatusNoSlotsAvaialble, nil
+		return ServerStatusNoSlotsAvailable, nil
 	case "loading model":
 		return ServerStatusLoadingModel, nil
 	default:
@@ -407,7 +449,30 @@ func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error)
 	}
 }

-func (s *LlamaServer) Ping(ctx context.Context) error {
+// getServerStatusRetry will retry if ServerStatusNoSlotsAvailable is received
+func (s *llmServer) getServerStatusRetry(ctx context.Context) (ServerStatus, error) {
+	var retries int
+	for {
+		status, err := s.getServerStatus(ctx)
+		if err != nil {
+			return status, err
+		}
+
+		if status == ServerStatusNoSlotsAvailable {
+			if retries >= 10 {
+				return status, fmt.Errorf("no slots available after %d retries", retries)
+			}
+
+			time.Sleep(5 * time.Millisecond)
+			retries++
+			continue
+		}
+
+		return status, nil
+	}
+}
+
+func (s *llmServer) Ping(ctx context.Context) error {
 	_, err := s.getServerStatus(ctx)
 	if err != nil {
 		slog.Debug("server unhealthy", "error", err)
@@ -416,13 +481,25 @@ func (s *LlamaServer) Ping(ctx context.Context) error {
 	return nil
 }

-func (s *LlamaServer) WaitUntilRunning() error {
+func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 	start := time.Now()
 	expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load

 	slog.Info("waiting for llama runner to start responding")

 	for {
+		select {
+		case <-ctx.Done():
+			slog.Info("context expired before server started")
+			return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
+		case err := <-s.done:
+			msg := ""
+			if s.status != nil && s.status.LastErrMsg != "" {
+				msg = s.status.LastErrMsg
+			}
+			return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
+		default:
+		}
 		ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
 		defer cancel()
 		status, err := s.getServerStatus(ctx)
@@ -487,7 +564,6 @@ ws ::= ([ \t\n] ws)?
 `

 const maxBufferSize = 512 * format.KiloByte
-const maxRetries = 3

 type ImageData struct {
 	Data []byte `json:"data"`
@@ -524,7 +600,19 @@ type CompletionResponse struct {
 	EvalDuration       time.Duration
 }

-func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
+func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
+	if err := s.sem.Acquire(ctx, 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return err
+	}
+	defer s.sem.Release(1)
+
+	// only allow maximum 10 "context shifts" to avoid infinite generation
+	if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
+		req.Options.NumPredict = 10 * s.options.NumCtx
+		slog.Debug("setting token limit to 10x num_ctx", "num_ctx", s.options.NumCtx, "num_predict", req.Options.NumPredict)
+	}
+
 	request := map[string]any{
 		"prompt":            req.Prompt,
 		"stream":            true,
@@ -551,11 +639,11 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
 	}

 	// Make sure the server is ready
-	status, err := s.getServerStatus(ctx)
+	status, err := s.getServerStatusRetry(ctx)
 	if err != nil {
 		return err
 	} else if status != ServerStatusReady {
-		return fmt.Errorf("unexpected server status: %d", status)
+		return fmt.Errorf("unexpected server status: %s", status.ToString())
 	}

 	if req.Format == "json" {
@@ -565,133 +653,113 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
 		}
 	}

-	retryDelay := 100 * time.Microsecond
-	for retries := 0; retries < maxRetries; retries++ {
-		if retries > 0 {
-			time.Sleep(retryDelay) // wait before retrying
-			retryDelay *= 2        // exponential backoff
-		}
+	// Handling JSON marshaling with special characters unescaped.
+	buffer := &bytes.Buffer{}
+	enc := json.NewEncoder(buffer)
+	enc.SetEscapeHTML(false)

-		// Handling JSON marshaling with special characters unescaped.
-		buffer := &bytes.Buffer{}
-		enc := json.NewEncoder(buffer)
-		enc.SetEscapeHTML(false)
+	if err := enc.Encode(request); err != nil {
+		return fmt.Errorf("failed to marshal data: %v", err)
+	}

-		if err := enc.Encode(request); err != nil {
-			return fmt.Errorf("failed to marshal data: %v", err)
-		}
+	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
+	serverReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
+	if err != nil {
+		return fmt.Errorf("error creating POST request: %v", err)
+	}
+	serverReq.Header.Set("Content-Type", "application/json")

-		endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
-		req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
-		if err != nil {
-			return fmt.Errorf("error creating POST request: %v", err)
-		}
-		req.Header.Set("Content-Type", "application/json")
+	res, err := http.DefaultClient.Do(serverReq)
+	if err != nil {
+		return fmt.Errorf("POST predict: %v", err)
+	}
+	defer res.Body.Close()

-		resp, err := http.DefaultClient.Do(req)
+	if res.StatusCode >= 400 {
+		bodyBytes, err := io.ReadAll(res.Body)
 		if err != nil {
-			return fmt.Errorf("POST predict: %v", err)
+			return fmt.Errorf("failed reading llm error response: %w", err)
 		}
-		defer resp.Body.Close()
+		log.Printf("llm predict error: %s", bodyBytes)
+		return fmt.Errorf("%s", bodyBytes)
+	}

-		if resp.StatusCode >= 400 {
-			bodyBytes, err := io.ReadAll(resp.Body)
-			if err != nil {
-				return fmt.Errorf("failed reading llm error response: %w", err)
+	scanner := bufio.NewScanner(res.Body)
+	buf := make([]byte, 0, maxBufferSize)
+	scanner.Buffer(buf, maxBufferSize)
+
+	// keep track of the last token generated, this is used to abort if the model starts looping
+	var lastToken string
+	var tokenRepeat int
+
+	for scanner.Scan() {
+		select {
+		case <-ctx.Done():
+			// This handles the request cancellation
+			return ctx.Err()
+		default:
+			line := scanner.Bytes()
+			if len(line) == 0 {
+				continue
 			}
-			log.Printf("llm predict error: %s", bodyBytes)
-			return fmt.Errorf("%s", bodyBytes)
-		}

-		scanner := bufio.NewScanner(resp.Body)
-		buf := make([]byte, 0, maxBufferSize)
-		scanner.Buffer(buf, maxBufferSize)
+			evt, ok := bytes.CutPrefix(line, []byte("data: "))
+			if !ok {
+				return fmt.Errorf("error parsing llm response stream: %s", line)
+			}

-		retryNeeded := false
-		// keep track of the last token generated, this is used to abort if the model starts looping
-		var lastToken string
-		var tokenRepeat int
+			var c completion
+			if err := json.Unmarshal(evt, &c); err != nil {
+				return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
+			}

-		for scanner.Scan() {
-			select {
-			case <-ctx.Done():
-				// This handles the request cancellation
-				return ctx.Err()
+			switch {
+			case strings.TrimSpace(c.Content) == lastToken:
+				tokenRepeat++
 			default:
-				line := scanner.Bytes()
-				if len(line) == 0 {
-					continue
-				}
-
-				// try again on slot unavailable
-				if bytes.Contains(line, []byte("slot unavailable")) {
-					retryNeeded = true
-					break
-				}
-
-				evt, ok := bytes.CutPrefix(line, []byte("data: "))
-				if !ok {
-					return fmt.Errorf("error parsing llm response stream: %s", line)
-				}
-
-				var c completion
-				if err := json.Unmarshal(evt, &c); err != nil {
-					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
-				}
-
-				switch {
-				case strings.TrimSpace(c.Content) == lastToken:
-					tokenRepeat++
-				default:
-					lastToken = strings.TrimSpace(c.Content)
-					tokenRepeat = 0
-				}
-
-				// 30 picked as an arbitrary max token repeat limit, modify as needed
-				if tokenRepeat > 30 {
-					slog.Debug("prediction aborted, token repeat limit reached")
-					return ctx.Err()
-				}
-
-				if c.Content != "" {
-					fn(CompletionResponse{
-						Content: c.Content,
-					})
-				}
-
-				if c.Stop {
-					fn(CompletionResponse{
-						Done:               true,
-						PromptEvalCount:    c.Timings.PromptN,
-						PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
-						EvalCount:          c.Timings.PredictedN,
-						EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
-					})
-					return nil
-				}
+				lastToken = strings.TrimSpace(c.Content)
+				tokenRepeat = 0
+			}
+
+			// 30 picked as an arbitrary max token repeat limit, modify as needed
+			if tokenRepeat > 30 {
+				slog.Debug("prediction aborted, token repeat limit reached")
+				return ctx.Err()
 			}
-		}

-		if err := scanner.Err(); err != nil {
-			if strings.Contains(err.Error(), "unexpected EOF") {
-				s.Close()
-				msg := ""
-				if s.status != nil && s.status.LastErrMsg != "" {
-					msg = s.status.LastErrMsg
-				}
+			if c.Content != "" {
+				fn(CompletionResponse{
+					Content: c.Content,
+				})
+			}

-				return fmt.Errorf("an unknown error was encountered while running the model %s", msg)
+			if c.Stop {
+				fn(CompletionResponse{
+					Done:               true,
+					PromptEvalCount:    c.Timings.PromptN,
+					PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
+					EvalCount:          c.Timings.PredictedN,
+					EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
+				})
+				return nil
 			}
-			return fmt.Errorf("error reading llm response: %v", err)
 		}
+	}

-		if !retryNeeded {
-			return nil // success
+	if err := scanner.Err(); err != nil {
+		if strings.Contains(err.Error(), "unexpected EOF") {
+			s.Close()
+			msg := ""
+			if s.status != nil && s.status.LastErrMsg != "" {
+				msg = s.status.LastErrMsg
+			}
+			return fmt.Errorf("an unknown error was encountered while running the model %s", msg)
 		}
+
+		return fmt.Errorf("error reading llm response: %v", err)
 	}

-	// should never reach here ideally
-	return fmt.Errorf("max retries exceeded")
+	return nil
 }

 type EmbeddingRequest struct {
@@ -702,13 +770,19 @@ type EmbeddingResponse struct {
 	Embedding []float64 `json:"embedding"`
 }

-func (s *LlamaServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+	if err := s.sem.Acquire(ctx, 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return nil, err
+	}
+	defer s.sem.Release(1)
+
 	// Make sure the server is ready
-	status, err := s.getServerStatus(ctx)
+	status, err := s.getServerStatusRetry(ctx)
 	if err != nil {
 		return nil, err
 	} else if status != ServerStatusReady {
-		return nil, fmt.Errorf("unexpected server status: %d", status)
+		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}

 	data, err := json.Marshal(TokenizeRequest{Content: prompt})
@@ -754,13 +828,13 @@ type TokenizeResponse struct {
 	Tokens []int `json:"tokens"`
 }

-func (s *LlamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
+func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
 	// Make sure the server is ready
 	status, err := s.getServerStatus(ctx)
 	if err != nil {
 		return nil, err
-	} else if status != ServerStatusReady {
-		return nil, fmt.Errorf("unexpected server status: %d", status)
+	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
+		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}

 	data, err := json.Marshal(TokenizeRequest{Content: content})
@@ -806,13 +880,13 @@ type DetokenizeResponse struct {
 	Content string `json:"content"`
 }

-func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
+func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
 	// Make sure the server is ready
 	status, err := s.getServerStatus(ctx)
 	if err != nil {
 		return "", err
-	} else if status != ServerStatusReady {
-		return "", fmt.Errorf("unexpected server status: %d", status)
+	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
+		return "", fmt.Errorf("unexpected server status: %s", status.ToString())
 	}

 	data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
@@ -850,15 +924,25 @@ func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, err
 	return decoded.Content, nil
 }

-func (s *LlamaServer) Close() error {
+func (s *llmServer) Close() error {
 	if s.cmd != nil {
 		slog.Debug("stopping llama server")
-		return s.cmd.Process.Kill()
+		if err := s.cmd.Process.Kill(); err != nil {
+			return err
+		}
+
+		_ = s.cmd.Wait()
+
+		slog.Debug("llama server stopped")
 	}

 	return nil
 }

+func (s *llmServer) EstimatedVRAM() uint64 {
+	return s.estimatedVRAM
+}
+
 func parseDurationMs(ms float64) time.Duration {
 	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
 	if err != nil {
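
An aside on the concurrency change in this file: the new sem field is a semaphore.Weighted sized to the runner's --parallel slot count, acquired at the top of Completion and Embedding and released when the request finishes, replacing the old "slot unavailable" retry loop. Below is a minimal, self-contained sketch of that pattern only; the numParallel value and the fake workload are illustrative and not taken from the patch.

package main

import (
	"context"
	"fmt"
	"sync"
	"time"

	"golang.org/x/sync/semaphore"
)

// numParallel stands in for envconfig.NumParallel; the value here is illustrative.
const numParallel = 2

func main() {
	sem := semaphore.NewWeighted(int64(numParallel))
	var wg sync.WaitGroup

	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()

			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
			defer cancel()

			// Block until a slot frees up, or give up when the context expires,
			// mirroring s.sem.Acquire(ctx, 1) / defer s.sem.Release(1) above.
			if err := sem.Acquire(ctx, 1); err != nil {
				fmt.Println("request", id, "gave up:", err)
				return
			}
			defer sem.Release(1)

			time.Sleep(100 * time.Millisecond) // stand-in for the actual completion call
			fmt.Println("request", id, "done")
		}(i)
	}
	wg.Wait()
}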

+ 1 - 1
macapp/src/app.tsx

@@ -19,7 +19,7 @@ export default function () {
   const [step, setStep] = useState<Step>(Step.WELCOME)
   const [commandCopied, setCommandCopied] = useState<boolean>(false)

-  const command = 'ollama run llama2'
+  const command = 'ollama run llama3'

   return (
     <div className='drag'>
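
Back in llm/server.go, the subprocess environment is now built by copying os.Environ() and then overwriting PATH/LD_LIBRARY_PATH and the GPU visible-devices variable in place when they already exist (matched case-insensitively, which matters for PATH on Windows), appending them otherwise. A small self-contained sketch of that upsert pattern follows; setEnv and the sample values are illustrative names, not part of the patch.

package main

import (
	"fmt"
	"os"
	"strings"
)

// setEnv updates key in env if an entry already exists (case-insensitive),
// otherwise appends it -- the same upsert done for pathEnv and the
// visible-devices variable when building s.cmd.Env.
func setEnv(env []string, key, val string) []string {
	for i := range env {
		kv := strings.SplitN(env[i], "=", 2)
		if strings.EqualFold(kv[0], key) {
			env[i] = key + "=" + val
			return env
		}
	}
	return append(env, key+"="+val)
}

func main() {
	env := os.Environ()
	env = setEnv(env, "LD_LIBRARY_PATH", "/tmp/ollama/runners/cpu_avx2")
	env = setEnv(env, "CUDA_VISIBLE_DEVICES", "0")
	for _, kv := range env {
		if strings.HasPrefix(kv, "LD_LIBRARY_PATH=") || strings.HasPrefix(kv, "CUDA_VISIBLE_DEVICES=") {
			fmt.Println(kv)
		}
	}
}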

Some files were not shown because too many files changed in this diff
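
One more sketch relating to the llm/server.go changes: getServerStatusRetry gives a busy runner a short grace period instead of failing the request outright, polling the health check and only giving up after ten consecutive "no slot available" responses. The code below mirrors that shape with a stubbed status function standing in for the real HTTP health endpoint; all names here are illustrative.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

type serverStatus int

const (
	statusReady serverStatus = iota
	statusNoSlotsAvailable
)

// calls simulates a runner that stays busy for the first few health checks.
var calls int

func getStatus(ctx context.Context) (serverStatus, error) {
	if err := ctx.Err(); err != nil {
		return statusNoSlotsAvailable, err
	}
	calls++
	if calls < 4 {
		return statusNoSlotsAvailable, nil
	}
	return statusReady, nil
}

// getStatusRetry mirrors the retry shape: keep polling while the server is
// merely busy, sleep briefly between attempts, and give up after 10 retries.
func getStatusRetry(ctx context.Context) (serverStatus, error) {
	var retries int
	for {
		status, err := getStatus(ctx)
		if err != nil {
			return status, err
		}
		if status == statusNoSlotsAvailable {
			if retries >= 10 {
				return status, errors.New("no slots available after 10 retries")
			}
			time.Sleep(5 * time.Millisecond)
			retries++
			continue
		}
		return status, nil
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	status, err := getStatusRetry(ctx)
	fmt.Println(status, err)
}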