Explorar o código

Merge remote-tracking branch 'upstream/main' into pr3702

Daniel Hiltgen hai 1 ano
pai
achega
920a4b0794
Modificáronse 100 ficheiros con 3810 adicións e 2222 borrados
  1. 10 10
      .github/workflows/release.yaml
  2. 20 14
      .github/workflows/test.yaml
  3. 2 1
      .gitignore
  4. 7 7
      Dockerfile
  5. 41 28
      README.md
  6. 76 10
      api/client.go
  7. 43 1
      api/client_test.go
  8. 82 15
      api/types.go
  9. 57 0
      api/types_test.go
  10. 3 1
      app/lifecycle/logging.go
  11. 37 27
      app/lifecycle/server.go
  12. 1 4
      app/lifecycle/updater_windows.go
  13. 5 8
      app/ollama.iss
  14. 71 71
      app/tray/wintray/menus.go
  15. 33 3
      auth/auth.go
  16. 193 103
      cmd/cmd.go
  17. 6 1
      cmd/interactive.go
  18. 19 14
      convert/convert.go
  19. 2 13
      convert/gemma.go
  20. 2 16
      convert/llama.go
  21. 2 13
      convert/mistral.go
  22. 85 0
      convert/mixtral.go
  23. 27 14
      convert/safetensors.go
  24. 1 1
      convert/torch.go
  25. 31 31
      docs/api.md
  26. 2 2
      docs/development.md
  27. 16 6
      docs/faq.md
  28. 3 1
      docs/import.md
  29. 1 1
      docs/linux.md
  30. 17 25
      docs/modelfile.md
  31. 5 5
      docs/openai.md
  32. 3 3
      docs/tutorials/langchainjs.md
  33. 7 5
      docs/tutorials/langchainpy.md
  34. 6 29
      docs/tutorials/nvidia-jetson.md
  35. 61 47
      docs/windows.md
  36. 1 1
      examples/bash-comparemodels/README.md
  37. 1 0
      examples/flyio/.gitignore
  38. 67 0
      examples/flyio/README.md
  39. 1 1
      examples/go-chat/main.go
  40. 1 1
      examples/go-http-generate/main.go
  41. 14 12
      examples/kubernetes/README.md
  42. 5 5
      examples/langchain-python-rag-document/main.py
  43. 4 4
      examples/langchain-python-rag-websummary/main.py
  44. 2 3
      examples/langchain-python-simple/README.md
  45. 1 1
      examples/langchain-python-simple/main.py
  46. 1 1
      examples/modelfile-mario/Modelfile
  47. 3 3
      examples/modelfile-mario/readme.md
  48. 7 7
      examples/python-json-datagenerator/predefinedschema.py
  49. 1 1
      examples/python-json-datagenerator/randomaddresses.py
  50. 2 2
      examples/python-json-datagenerator/readme.md
  51. 1 1
      examples/python-simplechat/client.py
  52. 2 2
      examples/python-simplechat/readme.md
  53. 2 2
      examples/typescript-mentors/README.md
  54. 2 2
      examples/typescript-mentors/character-generator.ts
  55. 2 2
      examples/typescript-simplechat/client.ts
  56. 3 0
      format/bytes.go
  57. 14 6
      format/format.go
  58. 34 0
      format/format_test.go
  59. 58 14
      gpu/amd_common.go
  60. 2 2
      gpu/amd_hip_windows.go
  61. 176 287
      gpu/amd_linux.go
  62. 85 76
      gpu/amd_windows.go
  63. 15 4
      gpu/assets.go
  64. 22 0
      gpu/cuda_common.go
  65. 154 142
      gpu/gpu.go
  66. 28 33
      gpu/gpu_darwin.go
  67. 9 4
      gpu/gpu_info.h
  68. 6 2
      gpu/gpu_info_cpu.c
  69. 63 82
      gpu/gpu_info_cudart.c
  70. 97 11
      gpu/gpu_info_cudart.h
  71. 203 0
      gpu/gpu_info_nvcuda.c
  72. 71 0
      gpu/gpu_info_nvcuda.h
  73. 0 221
      gpu/gpu_info_nvml.c
  74. 0 57
      gpu/gpu_info_nvml.h
  75. 6 13
      gpu/gpu_test.go
  76. 43 6
      gpu/types.go
  77. 44 2
      integration/basic_test.go
  78. 225 0
      integration/concurrency_test.go
  79. 1 2
      integration/context_test.go
  80. 3 3
      integration/llm_image_test.go
  81. 1 23
      integration/llm_test.go
  82. 117 0
      integration/max_queue_test.go
  83. 172 99
      integration/utils_test.go
  84. 9 8
      llm/ext_server/server.cpp
  85. 140 0
      llm/filetype.go
  86. 1 1
      llm/generate/gen_common.sh
  87. 30 16
      llm/generate/gen_linux.sh
  88. 190 112
      llm/generate/gen_windows.ps1
  89. 28 79
      llm/ggml.go
  90. 10 6
      llm/gguf.go
  91. 1 1
      llm/llama.cpp
  92. 5 52
      llm/llm.go
  93. 1 1
      llm/llm_windows.go
  94. 185 0
      llm/memory.go
  95. 12 0
      llm/patches/02-clip-log.diff
  96. 45 0
      llm/patches/04-metal.diff
  97. 24 0
      llm/patches/05-clip-fix.diff
  98. 27 7
      llm/payload.go
  99. 352 268
      llm/server.go
  100. 1 1
      macapp/src/app.tsx

+ 10 - 10
.github/workflows/release.yaml

@@ -103,6 +103,7 @@ jobs:
           path: |
             llm/build/**/bin/*
             llm/build/**/*.a
+            dist/windows-amd64/**
 
   # ROCm generation step
   generate-windows-rocm:
@@ -173,7 +174,9 @@ jobs:
       - uses: actions/upload-artifact@v4
         with:
           name: generate-windows-rocm
-          path: llm/build/**/bin/*
+          path: |
+            llm/build/**/bin/*
+            dist/windows-amd64/**
       - uses: actions/upload-artifact@v4
         with:
           name: windows-rocm-deps
@@ -253,7 +256,9 @@ jobs:
       - uses: actions/upload-artifact@v4
         with:
           name: generate-windows-cuda
-          path: llm/build/**/bin/*
+          path: |
+            llm/build/**/bin/*
+            dist/windows-amd64/**
       - uses: actions/upload-artifact@v4
         with:
           name: windows-cuda-deps
@@ -306,23 +311,18 @@ jobs:
       - uses: actions/download-artifact@v4
         with:
           name: generate-windows-cpu
-          path: llm/build
       - uses: actions/download-artifact@v4
         with:
           name: generate-windows-cuda
-          path: llm/build
       - uses: actions/download-artifact@v4
         with:
           name: windows-cuda-deps
-          path: dist/deps
       - uses: actions/download-artifact@v4
         with:
           name: windows-rocm-deps
-          path: dist/deps
       - uses: actions/download-artifact@v4
         with:
           name: generate-windows-rocm
-          path: llm/build
       - run: dir llm/build
       - run: |
           $gopath=(get-command go).source | split-path -parent
@@ -331,13 +331,13 @@ jobs:
           $env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
           $env:PATH="$gopath;$env:PATH"
           $env:OLLAMA_SKIP_GENERATE="1"
-          $env:NVIDIA_DIR=$(resolve-path ".\dist\deps")
-          $env:HIP_PATH=$(resolve-path ".\dist\deps")
           & .\scripts\build_windows.ps1
       - uses: actions/upload-artifact@v4
         with:
           name: dist-windows
-          path: dist/*.exe
+          path: |
+            dist/OllamaSetup.exe
+            dist/ollama-windows-*.zip
 
   # Linux x86 assets built using the container based build
   build-linux-amd64:

+ 20 - 14
.github/workflows/test.yaml

@@ -1,5 +1,15 @@
 name: test
 
+concurrency:
+  # For PRs, later CI runs preempt previous ones. e.g. a force push on a PR
+  # cancels running CI jobs and starts all new ones.
+  #
+  # For non-PR pushes, concurrency.group needs to be unique for every distinct
+  # CI run we want to have happen. Use run_id, which in practice means all
+  # non-PR CI runs will be allowed to run without preempting each other.
+  group: ${{ github.workflow }}-$${{ github.pull_request.number || github.run_id }}
+  cancel-in-progress: true
+
 on:
   pull_request:
     paths:
@@ -21,7 +31,9 @@ jobs:
       - id: changes
         run: |
           changed() {
-            git diff-tree -r --no-commit-id --name-only ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }} \
+            git diff-tree -r --no-commit-id --name-only \
+              $(git merge-base ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }}) \
+              ${{ github.event.pull_request.head.sha }} \
               | xargs python3 -c "import sys; print(any([x.startswith('$1') for x in sys.argv[1:]]))"
           }
 
@@ -103,7 +115,9 @@ jobs:
       - uses: actions/upload-artifact@v4
         with:
           name: cuda-${{ matrix.cuda-version }}-libraries
-          path: llm/build/**/bin/*
+          path: |
+            llm/build/**/bin/*
+            dist/windows-amd64/**
   generate-rocm:
     needs: [changes]
     if: ${{ needs.changes.outputs.GENERATE_ROCM == 'True' }}
@@ -134,7 +148,9 @@ jobs:
       - uses: actions/upload-artifact@v4
         with:
           name: rocm-${{ matrix.rocm-version }}-libraries
-          path: llm/build/**/bin/*
+          path: |
+            llm/build/**/bin/*
+            dist/windows-amd64/**
 
   # ROCm generation step
   generate-windows-rocm:
@@ -253,14 +269,9 @@ jobs:
           mkdir -p llm/build/darwin/$ARCH/stub/bin
           touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
         if: ${{ startsWith(matrix.os, 'macos-') }}
-      - run: |
-          mkdir -p llm/build/windows/$ARCH/stub/bin
-          touch llm/build/windows/$ARCH/stub/bin/ollama_llama_server
-        if: ${{ startsWith(matrix.os, 'windows-') }}
-        shell: bash
       - uses: golangci/golangci-lint-action@v4
         with:
-          args: --timeout 8m0s
+          args: --timeout 8m0s -v
   test:
     strategy:
       matrix:
@@ -284,7 +295,6 @@ jobs:
         with:
           go-version-file: go.mod
           cache: true
-      - run: go get
       - run: |
           case ${{ matrix.arch }} in
             amd64) echo ARCH=x86_64 ;;
@@ -299,10 +309,6 @@ jobs:
           mkdir -p llm/build/darwin/$ARCH/stub/bin
           touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
         if: ${{ startsWith(matrix.os, 'macos-') }}
-      - run: |
-          mkdir -p llm/build/windows/$ARCH/stub/bin
-          touch llm/build/windows/$ARCH/stub/bin/ollama_llama_server
-        if: ${{ startsWith(matrix.os, 'windows-') }}
         shell: bash
       - run: go generate ./...
       - run: go build

+ 2 - 1
.gitignore

@@ -11,4 +11,5 @@ ggml-metal.metal
 .idea
 test_data
 *.crt
-llm/build
+llm/build
+__debug_bin*

+ 7 - 7
Dockerfile

@@ -18,7 +18,7 @@ ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
 COPY --from=llm-code / /go/src/github.com/ollama/ollama/
 WORKDIR /go/src/github.com/ollama/ollama/llm/generate
 ARG CGO_CFLAGS
-RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
+RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
 
 FROM --platform=linux/arm64 nvidia/cuda:$CUDA_VERSION-devel-rockylinux8 AS cuda-build-arm64
 ARG CMAKE_VERSION
@@ -28,7 +28,7 @@ ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH
 COPY --from=llm-code / /go/src/github.com/ollama/ollama/
 WORKDIR /go/src/github.com/ollama/ollama/llm/generate
 ARG CGO_CFLAGS
-RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
+RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
 
 FROM --platform=linux/amd64 rocm/dev-centos-7:${ROCM_VERSION}-complete AS rocm-build-amd64
 ARG CMAKE_VERSION
@@ -40,7 +40,7 @@ COPY --from=llm-code / /go/src/github.com/ollama/ollama/
 WORKDIR /go/src/github.com/ollama/ollama/llm/generate
 ARG CGO_CFLAGS
 ARG AMDGPU_TARGETS
-RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
+RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
 RUN mkdir /tmp/scratch && \
     for dep in $(zcat /go/src/github.com/ollama/ollama/llm/build/linux/x86_64/rocm*/bin/deps.txt.gz) ; do \
         cp ${dep} /tmp/scratch/ || exit 1 ; \
@@ -64,11 +64,11 @@ WORKDIR /go/src/github.com/ollama/ollama/llm/generate
 FROM --platform=linux/amd64 cpu-builder-amd64 AS static-build-amd64
 RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh
 FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu-build-amd64
-RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
+RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
 FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx-build-amd64
-RUN OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
+RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
 FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
-RUN OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
+RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
 
 FROM --platform=linux/arm64 centos:7 AS cpu-builder-arm64
 ARG CMAKE_VERSION
@@ -84,7 +84,7 @@ WORKDIR /go/src/github.com/ollama/ollama/llm/generate
 FROM --platform=linux/arm64 cpu-builder-arm64 AS static-build-arm64
 RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh
 FROM --platform=linux/arm64 cpu-builder-arm64 AS cpu-build-arm64
-RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
+RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
 
 
 # Intermediate stage used for ./scripts/build_linux.sh

+ 41 - 28
README.md

@@ -1,5 +1,5 @@
 <div align="center">
-  <img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
+ <img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
 </div>
 
 # Ollama
@@ -35,10 +35,10 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
 
 ## Quickstart
 
-To run and chat with [Llama 2](https://ollama.com/library/llama2):
+To run and chat with [Llama 3](https://ollama.com/library/llama3):
 
 ```
-ollama run llama2
+ollama run llama3
 ```
 
 ## Model library
@@ -49,17 +49,14 @@ Here are some example models that can be downloaded:
 
 | Model              | Parameters | Size  | Download                       |
 | ------------------ | ---------- | ----- | ------------------------------ |
-| Llama 2            | 7B         | 3.8GB | `ollama run llama2`            |
+| Llama 3            | 8B         | 4.7GB | `ollama run llama3`            |
+| Llama 3            | 70B        | 40GB  | `ollama run llama3:70b`        |
+| Phi-3              | 3.8B       | 2.3GB | `ollama run phi3`              |
 | Mistral            | 7B         | 4.1GB | `ollama run mistral`           |
-| Dolphin Phi        | 2.7B       | 1.6GB | `ollama run dolphin-phi`       |
-| Phi-2              | 2.7B       | 1.7GB | `ollama run phi`               |
 | Neural Chat        | 7B         | 4.1GB | `ollama run neural-chat`       |
 | Starling           | 7B         | 4.1GB | `ollama run starling-lm`       |
 | Code Llama         | 7B         | 3.8GB | `ollama run codellama`         |
 | Llama 2 Uncensored | 7B         | 3.8GB | `ollama run llama2-uncensored` |
-| Llama 2 13B        | 13B        | 7.3GB | `ollama run llama2:13b`        |
-| Llama 2 70B        | 70B        | 39GB  | `ollama run llama2:70b`        |
-| Orca Mini          | 3B         | 1.9GB | `ollama run orca-mini`         |
 | LLaVA              | 7B         | 4.5GB | `ollama run llava`             |
 | Gemma              | 2B         | 1.4GB | `ollama run gemma:2b`          |
 | Gemma              | 7B         | 4.8GB | `ollama run gemma:7b`          |
@@ -97,16 +94,16 @@ See the [guide](docs/import.md) on importing models for more information.
 
 ### Customize a prompt
 
-Models from the Ollama library can be customized with a prompt. For example, to customize the `llama2` model:
+Models from the Ollama library can be customized with a prompt. For example, to customize the `llama3` model:
 
 ```
-ollama pull llama2
+ollama pull llama3
 ```
 
 Create a `Modelfile`:
 
 ```
-FROM llama2
+FROM llama3
 
 # set the temperature to 1 [higher is more creative, lower is more coherent]
 PARAMETER temperature 1
@@ -141,7 +138,7 @@ ollama create mymodel -f ./Modelfile
 ### Pull a model
 
 ```
-ollama pull llama2
+ollama pull llama3
 ```
 
 > This command can also be used to update a local model. Only the diff will be pulled.
@@ -149,13 +146,13 @@ ollama pull llama2
 ### Remove a model
 
 ```
-ollama rm llama2
+ollama rm llama3
 ```
 
 ### Copy a model
 
 ```
-ollama cp llama2 my-llama2
+ollama cp llama3 my-model
 ```
 
 ### Multiline input
@@ -176,10 +173,10 @@ I'm a basic program that prints the famous "Hello, world!" message to the consol
 The image features a yellow smiley face, which is likely the central focus of the picture.
 ```
 
-### Pass in prompt as arguments
+### Pass the prompt as an argument
 
 ```
-$ ollama run llama2 "Summarize this file: $(cat README.md)"
+$ ollama run llama3 "Summarize this file: $(cat README.md)"
  Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
 ```
 
@@ -226,7 +223,7 @@ Next, start the server:
 Finally, in a separate shell, run a model:
 
 ```
-./ollama run llama2
+./ollama run llama3
 ```
 
 ## REST API
@@ -237,7 +234,7 @@ Ollama has a REST API for running and managing models.
 
 ```
 curl http://localhost:11434/api/generate -d '{
-  "model": "llama2",
+  "model": "llama3",
   "prompt":"Why is the sky blue?"
 }'
 ```
@@ -246,7 +243,7 @@ curl http://localhost:11434/api/generate -d '{
 
 ```
 curl http://localhost:11434/api/chat -d '{
-  "model": "mistral",
+  "model": "llama3",
   "messages": [
     { "role": "user", "content": "why is the sky blue?" }
   ]
@@ -259,16 +256,18 @@ See the [API documentation](./docs/api.md) for all endpoints.
 
 ### Web & Desktop
 
+- [Open WebUI](https://github.com/open-webui/open-webui)
+- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
+- [Hollama](https://github.com/fmaclen/hollama)
 - [Lollms-Webui](https://github.com/ParisNeo/lollms-webui)
 - [LibreChat](https://github.com/danny-avila/LibreChat)
 - [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
-- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
 - [HTML UI](https://github.com/rtcfirefly/ollama-ui)
 - [Saddle](https://github.com/jikkuatwork/saddle)
 - [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
+- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
 - [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
 - [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
-- [Open WebUI](https://github.com/open-webui/open-webui)
 - [Ollamac](https://github.com/kevinhermawan/Ollamac)
 - [big-AGI](https://github.com/enricoros/big-AGI/blob/main/docs/config-local-ollama.md)
 - [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
@@ -286,13 +285,20 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [OllamaGUI](https://github.com/enoch1118/ollamaGUI)
 - [OpenAOE](https://github.com/InternLM/OpenAOE)
 - [Odin Runes](https://github.com/leonid20000/OdinRunes)
-- [LLM-X: Progressive Web App](https://github.com/mrdjohnson/llm-x)
+- [LLM-X](https://github.com/mrdjohnson/llm-x) (Progressive Web App)
 - [AnythingLLM (Docker + MacOs/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm)
 - [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
 - [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
-- [ChatOllama: Open Source Chatbot based on Ollama with Knowledge Bases](https://github.com/sugarforever/chat-ollama)
-- [CRAG Ollama Chat: Simple Web Search with Corrective RAG](https://github.com/Nagi-ovo/CRAG-Ollama-Chat)
-- [RAGFlow: Open-source Retrieval-Augmented Generation engine based on deep document understanding](https://github.com/infiniflow/ragflow)
+- [QA-Pilot](https://github.com/reid41/QA-Pilot) (Chat with Code Repository)
+- [ChatOllama](https://github.com/sugarforever/chat-ollama) (Open Source Chatbot based on Ollama with Knowledge Bases)
+- [CRAG Ollama Chat](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) (Simple Web Search with Corrective RAG)
+- [RAGFlow](https://github.com/infiniflow/ragflow) (Open-source Retrieval-Augmented Generation engine based on deep document understanding)
+- [StreamDeploy](https://github.com/StreamDeploy-DevRel/streamdeploy-llm-app-scaffold) (LLM Application Scaffold)
+- [chat](https://github.com/swuecho/chat) (chat web app for teams)
+- [Lobe Chat](https://github.com/lobehub/lobe-chat) with [Integrating Doc](https://lobehub.com/docs/self-hosting/examples/ollama)
+- [Ollama RAG Chatbot](https://github.com/datvodinh/rag-chatbot.git) (Local Chat with multiple PDFs using Ollama and RAG)
+- [BrainSoup](https://www.nurgo-software.com/products/brainsoup) (Flexible native client with RAG & multi-agent automation)
+- [macai](https://github.com/Renset/macai) (macOS client for Ollama, ChatGPT, and other compatible API back-ends)
 
 ### Terminal
 
@@ -308,11 +314,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Oatmeal](https://github.com/dustinblackman/oatmeal)
 - [cmdh](https://github.com/pgibler/cmdh)
 - [ooo](https://github.com/npahlfer/ooo)
+- [shell-pilot](https://github.com/reid41/shell-pilot)
 - [tenere](https://github.com/pythops/tenere)
 - [llm-ollama](https://github.com/taketwo/llm-ollama) for [Datasette's LLM CLI](https://llm.datasette.io/en/stable/).
 - [typechat-cli](https://github.com/anaisbetts/typechat-cli)
 - [ShellOracle](https://github.com/djcopley/ShellOracle)
 - [tlm](https://github.com/yusufcanb/tlm)
+- [podman-ollama](https://github.com/ericcurtin/podman-ollama)
 
 ### Database
 
@@ -344,9 +352,11 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Haystack](https://github.com/deepset-ai/haystack-integrations/blob/main/integrations/ollama.md)
 - [Elixir LangChain](https://github.com/brainlid/langchain)
 - [Ollama for R - rollama](https://github.com/JBGruber/rollama)
+- [Ollama for R - ollama-r](https://github.com/hauselin/ollama-r)
 - [Ollama-ex for Elixir](https://github.com/lebrunel/ollama-ex)
 - [Ollama Connector for SAP ABAP](https://github.com/b-tocs/abap_btocs_ollama)
 - [Testcontainers](https://testcontainers.com/modules/ollama/)
+- [Portkey](https://portkey.ai/docs/welcome/integration-guides/ollama)
 
 ### Mobile
 
@@ -366,17 +376,20 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Ollama Telegram Bot](https://github.com/ruecat/ollama-telegram)
 - [Hass Ollama Conversation](https://github.com/ej52/hass-ollama-conversation)
 - [Rivet plugin](https://github.com/abrenneke/rivet-plugin-ollama)
-- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
 - [Obsidian BMO Chatbot plugin](https://github.com/longy2k/obsidian-bmo-chatbot)
 - [Cliobot](https://github.com/herval/cliobot) (Telegram bot with Ollama support)
 - [Copilot for Obsidian plugin](https://github.com/logancyang/obsidian-copilot)
 - [Obsidian Local GPT plugin](https://github.com/pfrankov/obsidian-local-gpt)
 - [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
+- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
+- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use ollama as a copilot like Github copilot)
 - [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
 - [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and HuggingFace)
 - [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension)
 - [AI Telegram Bot](https://github.com/tusharhero/aitelegrambot) (Telegram bot using Ollama in backend)
 - [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support)
+- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
 
 ### Supported backends 
-- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov. 
+- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
+

+ 76 - 10
api/client.go

@@ -1,9 +1,16 @@
 // Package api implements the client-side API for code wishing to interact
 // with the ollama service. The methods of the [Client] type correspond to
-// the ollama REST API as described in https://github.com/ollama/ollama/blob/main/docs/api.md
-//
+// the ollama REST API as described in [the API documentation].
 // The ollama command-line client itself uses this package to interact with
 // the backend service.
+//
+// # Examples
+//
+// Several examples of using this package are available [in the GitHub
+// repository].
+//
+// [the API documentation]: https://github.com/ollama/ollama/blob/main/docs/api.md
+// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/examples
 package api
 
 import (
@@ -18,6 +25,7 @@ import (
 	"net/url"
 	"os"
 	"runtime"
+	"strconv"
 	"strings"
 
 	"github.com/ollama/ollama/format"
@@ -57,12 +65,36 @@ func checkError(resp *http.Response, body []byte) error {
 // If the variable is not specified, a default ollama host and port will be
 // used.
 func ClientFromEnvironment() (*Client, error) {
+	ollamaHost, err := GetOllamaHost()
+	if err != nil {
+		return nil, err
+	}
+
+	return &Client{
+		base: &url.URL{
+			Scheme: ollamaHost.Scheme,
+			Host:   net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
+		},
+		http: http.DefaultClient,
+	}, nil
+}
+
+type OllamaHost struct {
+	Scheme string
+	Host   string
+	Port   string
+}
+
+func GetOllamaHost() (OllamaHost, error) {
 	defaultPort := "11434"
 
-	scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
+	hostVar := os.Getenv("OLLAMA_HOST")
+	hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
+
+	scheme, hostport, ok := strings.Cut(hostVar, "://")
 	switch {
 	case !ok:
-		scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
+		scheme, hostport = "http", hostVar
 	case scheme == "http":
 		defaultPort = "80"
 	case scheme == "https":
@@ -82,15 +114,24 @@ func ClientFromEnvironment() (*Client, error) {
 		}
 	}
 
-	return &Client{
-		base: &url.URL{
-			Scheme: scheme,
-			Host:   net.JoinHostPort(host, port),
-		},
-		http: http.DefaultClient,
+	if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
+		return OllamaHost{}, ErrInvalidHostPort
+	}
+
+	return OllamaHost{
+		Scheme: scheme,
+		Host:   host,
+		Port:   port,
 	}, nil
 }
 
+func NewClient(base *url.URL, http *http.Client) *Client {
+	return &Client{
+		base: base,
+		http: http,
+	}
+}
+
 func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
 	var reqBody io.Reader
 	var data []byte
@@ -265,8 +306,14 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc
 	})
 }
 
+// PushProgressFunc is a function that [Client.Push] invokes when progress is
+// made.
+// It's similar to other progress function types like [PullProgressFunc].
 type PushProgressFunc func(ProgressResponse) error
 
+// Push uploads a model to the model library; requires registering for ollama.ai
+// and adding a public key first. fn is called each time progress is made on
+// the request and can be used to display a progress bar, etc.
 func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
 	return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error {
 		var resp ProgressResponse
@@ -278,8 +325,15 @@ func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc
 	})
 }
 
+// CreateProgressFunc is a function that [Client.Create] invokes when progress
+// is made.
+// It's similar to other progress function types like [PullProgressFunc].
 type CreateProgressFunc func(ProgressResponse) error
 
+// Create creates a model from a [Modelfile]. fn is a progress function that
+// behaves similarly to other methods (see [Client.Pull]).
+//
+// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md
 func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
 	return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
 		var resp ProgressResponse
@@ -291,6 +345,7 @@ func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgre
 	})
 }
 
+// List lists models that are available locally.
 func (c *Client) List(ctx context.Context) (*ListResponse, error) {
 	var lr ListResponse
 	if err := c.do(ctx, http.MethodGet, "/api/tags", nil, &lr); err != nil {
@@ -299,6 +354,8 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
 	return &lr, nil
 }
 
+// Copy copies a model - creating a model with another name from an existing
+// model.
 func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
 	if err := c.do(ctx, http.MethodPost, "/api/copy", req, nil); err != nil {
 		return err
@@ -306,6 +363,7 @@ func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
 	return nil
 }
 
+// Delete deletes a model and its data.
 func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
 	if err := c.do(ctx, http.MethodDelete, "/api/delete", req, nil); err != nil {
 		return err
@@ -313,6 +371,7 @@ func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
 	return nil
 }
 
+// Show obtains model information, including details, modelfile, license etc.
 func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, error) {
 	var resp ShowResponse
 	if err := c.do(ctx, http.MethodPost, "/api/show", req, &resp); err != nil {
@@ -321,12 +380,16 @@ func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, err
 	return &resp, nil
 }
 
+// Hearbeat checks if the server has started and is responsive; if yes, it
+// returns nil, otherwise an error.
 func (c *Client) Heartbeat(ctx context.Context) error {
 	if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil {
 		return err
 	}
 	return nil
 }
+
+// Embeddings generates embeddings from a model.
 func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
 	var resp EmbeddingResponse
 	if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
@@ -335,10 +398,13 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*Embedd
 	return &resp, nil
 }
 
+// CreateBlob creates a blob from a file on the server. digest is the
+// expected SHA256 digest of the file, and r represents the file.
 func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
 	return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
 }
 
+// Version returns the Ollama server version as a string.
 func (c *Client) Version(ctx context.Context) (string, error) {
 	var version struct {
 		Version string `json:"version"`

+ 43 - 1
api/client_test.go

@@ -1,6 +1,12 @@
 package api
 
-import "testing"
+import (
+	"fmt"
+	"net"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
 
 func TestClientFromEnvironment(t *testing.T) {
 	type testCase struct {
@@ -40,4 +46,40 @@ func TestClientFromEnvironment(t *testing.T) {
 			}
 		})
 	}
+
+	hostTestCases := map[string]*testCase{
+		"empty":               {value: "", expect: "127.0.0.1:11434"},
+		"only address":        {value: "1.2.3.4", expect: "1.2.3.4:11434"},
+		"only port":           {value: ":1234", expect: ":1234"},
+		"address and port":    {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"},
+		"hostname":            {value: "example.com", expect: "example.com:11434"},
+		"hostname and port":   {value: "example.com:1234", expect: "example.com:1234"},
+		"zero port":           {value: ":0", expect: ":0"},
+		"too large port":      {value: ":66000", err: ErrInvalidHostPort},
+		"too small port":      {value: ":-1", err: ErrInvalidHostPort},
+		"ipv6 localhost":      {value: "[::1]", expect: "[::1]:11434"},
+		"ipv6 world open":     {value: "[::]", expect: "[::]:11434"},
+		"ipv6 no brackets":    {value: "::1", expect: "[::1]:11434"},
+		"ipv6 + port":         {value: "[::1]:1337", expect: "[::1]:1337"},
+		"extra space":         {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"},
+		"extra quotes":        {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"},
+		"extra space+quotes":  {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"},
+		"extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"},
+	}
+
+	for k, v := range hostTestCases {
+		t.Run(k, func(t *testing.T) {
+			t.Setenv("OLLAMA_HOST", v.value)
+
+			oh, err := GetOllamaHost()
+			if err != v.err {
+				t.Fatalf("expected %s, got %s", v.err, err)
+			}
+
+			if err == nil {
+				host := net.JoinHostPort(oh.Host, oh.Port)
+				assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host))
+			}
+		})
+	}
 }

+ 82 - 15
api/types.go

@@ -2,6 +2,7 @@ package api
 
 import (
 	"encoding/json"
+	"errors"
 	"fmt"
 	"math"
 	"os"
@@ -11,6 +12,7 @@ import (
 	"time"
 )
 
+// StatusError is an error with and HTTP status code.
 type StatusError struct {
 	StatusCode   int
 	Status       string
@@ -31,6 +33,7 @@ func (e StatusError) Error() string {
 	}
 }
 
+// ImageData represents the raw binary data of an image file.
 type ImageData []byte
 
 // GenerateRequest describes a request sent by [Client.Generate]. While you
@@ -76,22 +79,39 @@ type GenerateRequest struct {
 	Options map[string]interface{} `json:"options"`
 }
 
+// ChatRequest describes a request sent by [Client.Chat].
 type ChatRequest struct {
-	Model     string    `json:"model"`
-	Messages  []Message `json:"messages"`
-	Stream    *bool     `json:"stream,omitempty"`
-	Format    string    `json:"format"`
+	// Model is the model name, as in [GenerateRequest].
+	Model string `json:"model"`
+
+	// Messages is the messages of the chat - can be used to keep a chat memory.
+	Messages []Message `json:"messages"`
+
+	// Stream enable streaming of returned response; true by default.
+	Stream *bool `json:"stream,omitempty"`
+
+	// Format is the format to return the response in (e.g. "json").
+	Format string `json:"format"`
+
+	// KeepAlive controls how long the model will stay loaded into memory
+	// followin the request.
 	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
+	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 }
 
+// Message is a single message in a chat sequence. The message contains the
+// role ("system", "user", or "assistant"), the content and an optional list
+// of images.
 type Message struct {
-	Role    string      `json:"role"` // one of ["system", "user", "assistant"]
+	Role    string      `json:"role"`
 	Content string      `json:"content"`
 	Images  []ImageData `json:"images,omitempty"`
 }
 
+// ChatResponse is the response returned by [Client.Chat]. Its fields are
+// similar to [GenerateResponse].
 type ChatResponse struct {
 	Model     string    `json:"model"`
 	CreatedAt time.Time `json:"created_at"`
@@ -111,7 +131,8 @@ type Metrics struct {
 	EvalDuration       time.Duration `json:"eval_duration,omitempty"`
 }
 
-// Options specified in GenerateRequest, if you add a new option here add it to the API docs also
+// Options specified in [GenerateRequest], if you add a new option here add it
+// to the API docs also.
 type Options struct {
 	Runner
 
@@ -157,18 +178,28 @@ type Runner struct {
 	RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
 }
 
+// EmbeddingRequest is the request passed to [Client.Embeddings].
 type EmbeddingRequest struct {
-	Model     string    `json:"model"`
-	Prompt    string    `json:"prompt"`
+	// Model is the model name.
+	Model string `json:"model"`
+
+	// Prompt is the textual prompt to embed.
+	Prompt string `json:"prompt"`
+
+	// KeepAlive controls how long the model will stay loaded in memory following
+	// this request.
 	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
+	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 }
 
+// EmbeddingResponse is the response from [Client.Embeddings].
 type EmbeddingResponse struct {
 	Embedding []float64 `json:"embedding"`
 }
 
+// CreateRequest is the request passed to [Client.Create].
 type CreateRequest struct {
 	Model        string `json:"model"`
 	Path         string `json:"path"`
@@ -180,6 +211,7 @@ type CreateRequest struct {
 	Name string `json:"name"`
 }
 
+// DeleteRequest is the request passed to [Client.Delete].
 type DeleteRequest struct {
 	Model string `json:"model"`
 
@@ -187,6 +219,7 @@ type DeleteRequest struct {
 	Name string `json:"name"`
 }
 
+// ShowRequest is the request passed to [Client.Show].
 type ShowRequest struct {
 	Model    string `json:"model"`
 	System   string `json:"system"`
@@ -198,6 +231,7 @@ type ShowRequest struct {
 	Name string `json:"name"`
 }
 
+// ShowResponse is the response returned from [Client.Show].
 type ShowResponse struct {
 	License    string       `json:"license,omitempty"`
 	Modelfile  string       `json:"modelfile,omitempty"`
@@ -208,11 +242,13 @@ type ShowResponse struct {
 	Messages   []Message    `json:"messages,omitempty"`
 }
 
+// CopyRequest is the request passed to [Client.Copy].
 type CopyRequest struct {
 	Source      string `json:"source"`
 	Destination string `json:"destination"`
 }
 
+// PullRequest is the request passed to [Client.Pull].
 type PullRequest struct {
 	Model    string `json:"model"`
 	Insecure bool   `json:"insecure,omitempty"`
@@ -224,6 +260,8 @@ type PullRequest struct {
 	Name string `json:"name"`
 }
 
+// ProgressResponse is the response passed to progress functions like
+// [PullProgressFunc] and [PushProgressFunc].
 type ProgressResponse struct {
 	Status    string `json:"status"`
 	Digest    string `json:"digest,omitempty"`
@@ -231,6 +269,7 @@ type ProgressResponse struct {
 	Completed int64  `json:"completed,omitempty"`
 }
 
+// PushRequest is the request passed to [Client.Push].
 type PushRequest struct {
 	Model    string `json:"model"`
 	Insecure bool   `json:"insecure,omitempty"`
@@ -242,10 +281,12 @@ type PushRequest struct {
 	Name string `json:"name"`
 }
 
+// ListResponse is the response from [Client.List].
 type ListResponse struct {
 	Models []ModelResponse `json:"models"`
 }
 
+// ModelResponse is a single model description in [ListResponse].
 type ModelResponse struct {
 	Name       string       `json:"name"`
 	Model      string       `json:"model"`
@@ -259,17 +300,28 @@ type TokenResponse struct {
 	Token string `json:"token"`
 }
 
+// GenerateResponse is the response passed into [GenerateResponseFunc].
 type GenerateResponse struct {
-	Model     string    `json:"model"`
+	// Model is the model name that generated the response.
+	Model string `json:"model"`
+
+	//CreatedAt is the timestamp of the response.
 	CreatedAt time.Time `json:"created_at"`
-	Response  string    `json:"response"`
 
-	Done    bool  `json:"done"`
+	// Response is the textual response itself.
+	Response string `json:"response"`
+
+	// Done specifies if the response is complete.
+	Done bool `json:"done"`
+
+	// Context is an encoding of the conversation used in this response; this
+	// can be sent in the next request to keep a conversational memory.
 	Context []int `json:"context,omitempty"`
 
 	Metrics
 }
 
+// ModelDetails provides details about a model.
 type ModelDetails struct {
 	ParentModel       string   `json:"parent_model"`
 	Format            string   `json:"format"`
@@ -307,7 +359,9 @@ func (m *Metrics) Summary() {
 	}
 }
 
-var ErrInvalidOpts = fmt.Errorf("invalid options")
+// ErrInvalidOpts is returned when invalid options are passed to the client.
+var ErrInvalidOpts = errors.New("invalid options")
+var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
 
 func (opts *Options) FromMap(m map[string]interface{}) error {
 	valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
@@ -392,11 +446,15 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
 	return nil
 }
 
+// DefaultOptions is the default set of options for [GenerateRequest]; these
+// values are used unless the user specifies other values explicitly.
 func DefaultOptions() Options {
 	return Options{
 		// options set on request to runner
-		NumPredict:       -1,
-		NumKeep:          0,
+		NumPredict: -1,
+
+		// set a minimal num_keep to avoid issues on context shifts
+		NumKeep:          4,
 		Temperature:      0.8,
 		TopK:             40,
 		TopP:             0.9,
@@ -432,6 +490,13 @@ type Duration struct {
 	time.Duration
 }
 
+func (d Duration) MarshalJSON() ([]byte, error) {
+	if d.Duration < 0 {
+		return []byte("-1"), nil
+	}
+	return []byte("\"" + d.Duration.String() + "\""), nil
+}
+
 func (d *Duration) UnmarshalJSON(b []byte) (err error) {
 	var v any
 	if err := json.Unmarshal(b, &v); err != nil {
@@ -445,7 +510,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
 		if t < 0 {
 			d.Duration = time.Duration(math.MaxInt64)
 		} else {
-			d.Duration = time.Duration(t * float64(time.Second))
+			d.Duration = time.Duration(int(t) * int(time.Second))
 		}
 	case string:
 		d.Duration, err = time.ParseDuration(t)
@@ -455,6 +520,8 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
 		if d.Duration < 0 {
 			d.Duration = time.Duration(math.MaxInt64)
 		}
+	default:
+		return fmt.Errorf("Unsupported type: '%s'", reflect.TypeOf(v))
 	}
 
 	return nil

+ 57 - 0
api/types_test.go

@@ -21,6 +21,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
 			req:  `{ "keep_alive": 42 }`,
 			exp:  &Duration{42 * time.Second},
 		},
+		{
+			name: "Positive Float",
+			req:  `{ "keep_alive": 42.5 }`,
+			exp:  &Duration{42 * time.Second},
+		},
 		{
 			name: "Positive Integer String",
 			req:  `{ "keep_alive": "42m" }`,
@@ -31,6 +36,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
 			req:  `{ "keep_alive": -1 }`,
 			exp:  &Duration{math.MaxInt64},
 		},
+		{
+			name: "Negative Float",
+			req:  `{ "keep_alive": -3.14 }`,
+			exp:  &Duration{math.MaxInt64},
+		},
 		{
 			name: "Negative Integer String",
 			req:  `{ "keep_alive": "-1m" }`,
@@ -48,3 +58,50 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
 		})
 	}
 }
+
+func TestDurationMarshalUnmarshal(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    time.Duration
+		expected time.Duration
+	}{
+		{
+			"negative duration",
+			time.Duration(-1),
+			time.Duration(math.MaxInt64),
+		},
+		{
+			"positive duration",
+			time.Duration(42 * time.Second),
+			time.Duration(42 * time.Second),
+		},
+		{
+			"another positive duration",
+			time.Duration(42 * time.Minute),
+			time.Duration(42 * time.Minute),
+		},
+		{
+			"zero duration",
+			time.Duration(0),
+			time.Duration(0),
+		},
+		{
+			"max duration",
+			time.Duration(math.MaxInt64),
+			time.Duration(math.MaxInt64),
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			b, err := json.Marshal(Duration{test.input})
+			require.NoError(t, err)
+
+			var d Duration
+			err = json.Unmarshal(b, &d)
+			require.NoError(t, err)
+
+			assert.Equal(t, test.expected, d.Duration, "input %v, marshalled %v, got %v", test.input, string(b), d.Duration)
+		})
+	}
+}

+ 3 - 1
app/lifecycle/logging.go

@@ -5,12 +5,14 @@ import (
 	"log/slog"
 	"os"
 	"path/filepath"
+
+	"github.com/ollama/ollama/server/envconfig"
 )
 
 func InitLogging() {
 	level := slog.LevelInfo
 
-	if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
+	if envconfig.Debug {
 		level = slog.LevelDebug
 	}
 

+ 37 - 27
app/lifecycle/server.go

@@ -43,37 +43,36 @@ func getCLIFullPath(command string) string {
 	return command
 }
 
-func SpawnServer(ctx context.Context, command string) (chan int, error) {
-	done := make(chan int)
-
-	logDir := filepath.Dir(ServerLogFile)
-	_, err := os.Stat(logDir)
-	if errors.Is(err, os.ErrNotExist) {
-		if err := os.MkdirAll(logDir, 0o755); err != nil {
-			return done, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
-		}
-	}
-
+func start(ctx context.Context, command string) (*exec.Cmd, error) {
 	cmd := getCmd(ctx, getCLIFullPath(command))
-	// send stdout and stderr to a file
 	stdout, err := cmd.StdoutPipe()
 	if err != nil {
-		return done, fmt.Errorf("failed to spawn server stdout pipe %s", err)
+		return nil, fmt.Errorf("failed to spawn server stdout pipe: %w", err)
 	}
 	stderr, err := cmd.StderrPipe()
 	if err != nil {
-		return done, fmt.Errorf("failed to spawn server stderr pipe %s", err)
-	}
-	stdin, err := cmd.StdinPipe()
-	if err != nil {
-		return done, fmt.Errorf("failed to spawn server stdin pipe %s", err)
+		return nil, fmt.Errorf("failed to spawn server stderr pipe: %w", err)
 	}
 
 	// TODO - rotation
 	logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
 	if err != nil {
-		return done, fmt.Errorf("failed to create server log %w", err)
+		return nil, fmt.Errorf("failed to create server log: %w", err)
 	}
+
+	logDir := filepath.Dir(ServerLogFile)
+	_, err = os.Stat(logDir)
+	if err != nil {
+		if !errors.Is(err, os.ErrNotExist) {
+			return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err)
+
+		}
+
+		if err := os.MkdirAll(logDir, 0o755); err != nil {
+			return nil, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
+		}
+	}
+
 	go func() {
 		defer logFile.Close()
 		io.Copy(logFile, stdout) //nolint:errcheck
@@ -117,19 +116,33 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
 
 	// run the command and wait for it to finish
 	if err := cmd.Start(); err != nil {
-		return done, fmt.Errorf("failed to start server %w", err)
+		return nil, fmt.Errorf("failed to start server %w", err)
 	}
 	if cmd.Process != nil {
 		slog.Info(fmt.Sprintf("started ollama server with pid %d", cmd.Process.Pid))
 	}
 	slog.Info(fmt.Sprintf("ollama server logs %s", ServerLogFile))
 
+	return cmd, nil
+}
+
+func SpawnServer(ctx context.Context, command string) (chan int, error) {
+	done := make(chan int)
+
 	go func() {
 		// Keep the server running unless we're shuttind down the app
 		crashCount := 0
 		for {
+			slog.Info("starting server...")
+			cmd, err := start(ctx, command)
+			if err != nil {
+				crashCount++
+				slog.Error(fmt.Sprintf("failed to start server %s", err))
+				time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
+				continue
+			}
+
 			cmd.Wait() //nolint:errcheck
-			stdin.Close()
 			var code int
 			if cmd.ProcessState != nil {
 				code = cmd.ProcessState.ExitCode()
@@ -143,15 +156,12 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
 			default:
 				crashCount++
 				slog.Warn(fmt.Sprintf("server crash %d - exit code %d - respawning", crashCount, code))
-				time.Sleep(500 * time.Millisecond)
-				if err := cmd.Start(); err != nil {
-					slog.Error(fmt.Sprintf("failed to restart server %s", err))
-					// Keep trying, but back off if we keep failing
-					time.Sleep(time.Duration(crashCount) * time.Second)
-				}
+				time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
+				break
 			}
 		}
 	}()
+
 	return done, nil
 }
 

+ 1 - 4
app/lifecycle/updater_windows.go

@@ -31,16 +31,13 @@ func DoUpgrade(cancel context.CancelFunc, done chan int) error {
 		"/LOG=" + filepath.Base(UpgradeLogFile), // Only relative seems reliable, so set pwd
 		"/FORCECLOSEAPPLICATIONS",               // Force close the tray app - might be needed
 	}
-	// When we're not in debug mode, make the upgrade as quiet as possible (no GUI, no prompts)
-	// TODO - temporarily disable since we're pinning in debug mode for the preview
-	// if debug := os.Getenv("OLLAMA_DEBUG"); debug == "" {
+	// make the upgrade as quiet as possible (no GUI, no prompts)
 	installArgs = append(installArgs,
 		"/SP", // Skip the "This will install... Do you wish to continue" prompt
 		"/SUPPRESSMSGBOXES",
 		"/SILENT",
 		"/VERYSILENT",
 	)
-	// }
 
 	// Safeguard in case we have requests in flight that need to drain...
 	slog.Info("Waiting for server to shutdown")

+ 5 - 8
app/ollama.iss

@@ -88,15 +88,12 @@ DialogFontSize=12
 [Files]
 Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit
 Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit
-Source: "..\dist\windeps\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
+Source: "..\dist\windows-{#ARCH}\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
+Source: "..\dist\windows-{#ARCH}\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs
 Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion
 Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion
-; Assumes v5.7, may need adjustments for v6
-#if GetEnv("HIP_PATH") != ""
-  Source: "{#GetEnv('HIP_PATH')}\bin\hipblas.dll"; DestDir: "{app}\rocm\"; Flags: ignoreversion
-  Source: "{#GetEnv('HIP_PATH')}\bin\rocblas.dll"; DestDir: "{app}\rocm\"; Flags: ignoreversion
-  ; amdhip64.dll dependency comes from the driver and must be installed already
-  Source: "{#GetEnv('HIP_PATH')}\bin\rocblas\library\*"; DestDir: "{app}\rocm\rocblas\library\"; Flags: ignoreversion
+#if DirExists("..\dist\windows-amd64\rocm")
+  Source: "..\dist\windows-amd64\rocm\*"; DestDir: "{app}\rocm\"; Flags: ignoreversion recursesubdirs
 #endif
 
 
@@ -132,7 +129,7 @@ SetupAppRunningError=Another Ollama installer is running.%n%nPlease cancel or fi
 
 
 ;FinishedHeadingLabel=Run your first model
-;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n    ollama run llama2
+;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n    ollama run llama3
 ;ClickFinish=%n
 
 [Registry]

+ 71 - 71
app/tray/wintray/menus.go

@@ -1,71 +1,71 @@
-//go:build windows
-
-package wintray
-
-import (
-	"fmt"
-	"log/slog"
-	"unsafe"
-
-	"golang.org/x/sys/windows"
-)
-
-const (
-	updatAvailableMenuID = 1
-	updateMenuID         = updatAvailableMenuID + 1
-	separatorMenuID      = updateMenuID + 1
-	diagLogsMenuID       = separatorMenuID + 1
-	diagSeparatorMenuID  = diagLogsMenuID + 1
-	quitMenuID           = diagSeparatorMenuID + 1
-)
-
-func (t *winTray) initMenus() error {
-	if err := t.addOrUpdateMenuItem(diagLogsMenuID, 0, diagLogsMenuTitle, false); err != nil {
-		return fmt.Errorf("unable to create menu entries %w\n", err)
-	}
-	if err := t.addSeparatorMenuItem(diagSeparatorMenuID, 0); err != nil {
-		return fmt.Errorf("unable to create menu entries %w", err)
-	}
-	if err := t.addOrUpdateMenuItem(quitMenuID, 0, quitMenuTitle, false); err != nil {
-		return fmt.Errorf("unable to create menu entries %w\n", err)
-	}
-	return nil
-}
-
-func (t *winTray) UpdateAvailable(ver string) error {
-	if !t.updateNotified {
-		slog.Debug("updating menu and sending notification for new update")
-		if err := t.addOrUpdateMenuItem(updatAvailableMenuID, 0, updateAvailableMenuTitle, true); err != nil {
-			return fmt.Errorf("unable to create menu entries %w", err)
-		}
-		if err := t.addOrUpdateMenuItem(updateMenuID, 0, updateMenutTitle, false); err != nil {
-			return fmt.Errorf("unable to create menu entries %w", err)
-		}
-		if err := t.addSeparatorMenuItem(separatorMenuID, 0); err != nil {
-			return fmt.Errorf("unable to create menu entries %w", err)
-		}
-		iconFilePath, err := iconBytesToFilePath(wt.updateIcon)
-		if err != nil {
-			return fmt.Errorf("unable to write icon data to temp file: %w", err)
-		}
-		if err := wt.setIcon(iconFilePath); err != nil {
-			return fmt.Errorf("unable to set icon: %w", err)
-		}
-		t.updateNotified = true
-
-		t.pendingUpdate = true
-		// Now pop up the notification
-		t.muNID.Lock()
-		defer t.muNID.Unlock()
-		copy(t.nid.InfoTitle[:], windows.StringToUTF16(updateTitle))
-		copy(t.nid.Info[:], windows.StringToUTF16(fmt.Sprintf(updateMessage, ver)))
-		t.nid.Flags |= NIF_INFO
-		t.nid.Timeout = 10
-		t.nid.Size = uint32(unsafe.Sizeof(*wt.nid))
-		err = t.nid.modify()
-		if err != nil {
-			return err
-		}
-	}
-	return nil
-}
+//go:build windows
+
+package wintray
+
+import (
+	"fmt"
+	"log/slog"
+	"unsafe"
+
+	"golang.org/x/sys/windows"
+)
+
+const (
+	updatAvailableMenuID = 1
+	updateMenuID         = updatAvailableMenuID + 1
+	separatorMenuID      = updateMenuID + 1
+	diagLogsMenuID       = separatorMenuID + 1
+	diagSeparatorMenuID  = diagLogsMenuID + 1
+	quitMenuID           = diagSeparatorMenuID + 1
+)
+
+func (t *winTray) initMenus() error {
+	if err := t.addOrUpdateMenuItem(diagLogsMenuID, 0, diagLogsMenuTitle, false); err != nil {
+		return fmt.Errorf("unable to create menu entries %w\n", err)
+	}
+	if err := t.addSeparatorMenuItem(diagSeparatorMenuID, 0); err != nil {
+		return fmt.Errorf("unable to create menu entries %w", err)
+	}
+	if err := t.addOrUpdateMenuItem(quitMenuID, 0, quitMenuTitle, false); err != nil {
+		return fmt.Errorf("unable to create menu entries %w\n", err)
+	}
+	return nil
+}
+
+func (t *winTray) UpdateAvailable(ver string) error {
+	if !t.updateNotified {
+		slog.Debug("updating menu and sending notification for new update")
+		if err := t.addOrUpdateMenuItem(updatAvailableMenuID, 0, updateAvailableMenuTitle, true); err != nil {
+			return fmt.Errorf("unable to create menu entries %w", err)
+		}
+		if err := t.addOrUpdateMenuItem(updateMenuID, 0, updateMenutTitle, false); err != nil {
+			return fmt.Errorf("unable to create menu entries %w", err)
+		}
+		if err := t.addSeparatorMenuItem(separatorMenuID, 0); err != nil {
+			return fmt.Errorf("unable to create menu entries %w", err)
+		}
+		iconFilePath, err := iconBytesToFilePath(wt.updateIcon)
+		if err != nil {
+			return fmt.Errorf("unable to write icon data to temp file: %w", err)
+		}
+		if err := wt.setIcon(iconFilePath); err != nil {
+			return fmt.Errorf("unable to set icon: %w", err)
+		}
+		t.updateNotified = true
+
+		t.pendingUpdate = true
+		// Now pop up the notification
+		t.muNID.Lock()
+		defer t.muNID.Unlock()
+		copy(t.nid.InfoTitle[:], windows.StringToUTF16(updateTitle))
+		copy(t.nid.Info[:], windows.StringToUTF16(fmt.Sprintf(updateMessage, ver)))
+		t.nid.Flags |= NIF_INFO
+		t.nid.Timeout = 10
+		t.nid.Size = uint32(unsafe.Sizeof(*wt.nid))
+		err = t.nid.modify()
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}

+ 33 - 3
auth/auth.go

@@ -10,12 +10,44 @@ import (
 	"log/slog"
 	"os"
 	"path/filepath"
+	"strings"
 
 	"golang.org/x/crypto/ssh"
 )
 
 const defaultPrivateKey = "id_ed25519"
 
+func keyPath() (string, error) {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return "", err
+	}
+
+	return filepath.Join(home, ".ollama", defaultPrivateKey), nil
+}
+
+func GetPublicKey() (string, error) {
+	keyPath, err := keyPath()
+	if err != nil {
+		return "", err
+	}
+
+	privateKeyFile, err := os.ReadFile(keyPath)
+	if err != nil {
+		slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
+		return "", err
+	}
+
+	privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
+	if err != nil {
+		return "", err
+	}
+
+	publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
+
+	return strings.TrimSpace(string(publicKey)), nil
+}
+
 func NewNonce(r io.Reader, length int) (string, error) {
 	nonce := make([]byte, length)
 	if _, err := io.ReadFull(r, nonce); err != nil {
@@ -26,13 +58,11 @@ func NewNonce(r io.Reader, length int) (string, error) {
 }
 
 func Sign(ctx context.Context, bts []byte) (string, error) {
-	home, err := os.UserHomeDir()
+	keyPath, err := keyPath()
 	if err != nil {
 		return "", err
 	}
 
-	keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
-
 	privateKeyFile, err := os.ReadFile(keyPath)
 	if err != nil {
 		slog.Info(fmt.Sprintf("Failed to load private key: %v", err))

+ 193 - 103
cmd/cmd.go

@@ -17,6 +17,7 @@ import (
 	"os"
 	"os/signal"
 	"path/filepath"
+	"regexp"
 	"runtime"
 	"strings"
 	"syscall"
@@ -31,10 +32,12 @@ import (
 	"golang.org/x/term"
 
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/auth"
 	"github.com/ollama/ollama/format"
-	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/progress"
 	"github.com/ollama/ollama/server"
+	"github.com/ollama/ollama/types/errtypes"
+	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
 )
 
@@ -53,14 +56,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	p := progress.NewProgress(os.Stderr)
 	defer p.Stop()
 
-	bars := make(map[string]*progress.Bar)
-
-	modelfile, err := os.ReadFile(filename)
+	f, err := os.Open(filename)
 	if err != nil {
 		return err
 	}
+	defer f.Close()
 
-	commands, err := parser.Parse(bytes.NewReader(modelfile))
+	modelfile, err := model.ParseFile(f)
 	if err != nil {
 		return err
 	}
@@ -74,10 +76,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	spinner := progress.NewSpinner(status)
 	p.Add(status, spinner)
 
-	for _, c := range commands {
-		switch c.Name {
+	for i := range modelfile.Commands {
+		switch modelfile.Commands[i].Name {
 		case "model", "adapter":
-			path := c.Args
+			path := modelfile.Commands[i].Args
 			if path == "~" {
 				path = home
 			} else if strings.HasPrefix(path, "~/") {
@@ -89,101 +91,22 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 			}
 
 			fi, err := os.Stat(path)
-			if errors.Is(err, os.ErrNotExist) && c.Name == "model" {
+			if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" {
 				continue
 			} else if err != nil {
 				return err
 			}
 
-			// TODO make this work w/ adapters
 			if fi.IsDir() {
-				tf, err := os.CreateTemp("", "ollama-tf")
-				if err != nil {
-					return err
-				}
-				defer os.RemoveAll(tf.Name())
-
-				zf := zip.NewWriter(tf)
-
-				files := []string{}
-
-				tfiles, err := filepath.Glob(filepath.Join(path, "pytorch_model-*.bin"))
+				// this is likely a safetensors or pytorch directory
+				// TODO make this work w/ adapters
+				tempfile, err := tempZipFiles(path)
 				if err != nil {
 					return err
-				} else if len(tfiles) == 0 {
-					tfiles, err = filepath.Glob(filepath.Join(path, "model-*.safetensors"))
-					if err != nil {
-						return err
-					}
-				}
-
-				files = append(files, tfiles...)
-
-				if len(files) == 0 {
-					return fmt.Errorf("no models were found in '%s'", path)
 				}
+				defer os.RemoveAll(tempfile)
 
-				// add the safetensor/torch config file + tokenizer
-				files = append(files, filepath.Join(path, "config.json"))
-				files = append(files, filepath.Join(path, "params.json"))
-				files = append(files, filepath.Join(path, "added_tokens.json"))
-				files = append(files, filepath.Join(path, "tokenizer.model"))
-
-				for _, fn := range files {
-					f, err := os.Open(fn)
-
-					// just skip whatever files aren't there
-					if os.IsNotExist(err) {
-						if strings.HasSuffix(fn, "tokenizer.model") {
-							// try the parent dir before giving up
-							parentDir := filepath.Dir(path)
-							newFn := filepath.Join(parentDir, "tokenizer.model")
-							f, err = os.Open(newFn)
-							if os.IsNotExist(err) {
-								continue
-							} else if err != nil {
-								return err
-							}
-						} else {
-							continue
-						}
-					} else if err != nil {
-						return err
-					}
-
-					fi, err := f.Stat()
-					if err != nil {
-						return err
-					}
-
-					h, err := zip.FileInfoHeader(fi)
-					if err != nil {
-						return err
-					}
-
-					h.Name = filepath.Base(fn)
-					h.Method = zip.Store
-
-					w, err := zf.CreateHeader(h)
-					if err != nil {
-						return err
-					}
-
-					_, err = io.Copy(w, f)
-					if err != nil {
-						return err
-					}
-
-				}
-
-				if err := zf.Close(); err != nil {
-					return err
-				}
-
-				if err := tf.Close(); err != nil {
-					return err
-				}
-				path = tf.Name()
+				path = tempfile
 			}
 
 			digest, err := createBlob(cmd, client, path)
@@ -191,10 +114,11 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 				return err
 			}
 
-			modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte("@"+digest))
+			modelfile.Commands[i].Args = "@" + digest
 		}
 	}
 
+	bars := make(map[string]*progress.Bar)
 	fn := func(resp api.ProgressResponse) error {
 		if resp.Digest != "" {
 			spinner.Stop()
@@ -220,7 +144,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 
 	quantization, _ := cmd.Flags().GetString("quantization")
 
-	request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization}
+	request := api.CreateRequest{Name: args[0], Modelfile: modelfile.String(), Quantization: quantization}
 	if err := client.Create(cmd.Context(), &request, fn); err != nil {
 		return err
 	}
@@ -228,6 +152,114 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	return nil
 }
 
+func tempZipFiles(path string) (string, error) {
+	tempfile, err := os.CreateTemp("", "ollama-tf")
+	if err != nil {
+		return "", err
+	}
+	defer tempfile.Close()
+
+	zipfile := zip.NewWriter(tempfile)
+	defer zipfile.Close()
+
+	detectContentType := func(path string) (string, error) {
+		f, err := os.Open(path)
+		if err != nil {
+			return "", err
+		}
+		defer f.Close()
+
+		var b bytes.Buffer
+		b.Grow(512)
+
+		if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
+			return "", err
+		}
+
+		contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
+		return contentType, nil
+	}
+
+	glob := func(pattern, contentType string) ([]string, error) {
+		matches, err := filepath.Glob(pattern)
+		if err != nil {
+			return nil, err
+		}
+
+		for _, safetensor := range matches {
+			if ct, err := detectContentType(safetensor); err != nil {
+				return nil, err
+			} else if ct != contentType {
+				return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor)
+			}
+		}
+
+		return matches, nil
+	}
+
+	var files []string
+	if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
+		// safetensors files might be unresolved git lfs references; skip if they are
+		// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
+		files = append(files, st...)
+	} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
+		// pytorch files might also be unresolved git lfs references; skip if they are
+		// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
+		files = append(files, pt...)
+	} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/octet-stream"); len(pt) > 0 {
+		// pytorch files might also be unresolved git lfs references; skip if they are
+		// covers consolidated.x.pth, consolidated.pth
+		files = append(files, pt...)
+	} else {
+		return "", errors.New("no safetensors or torch files found")
+	}
+
+	// add configuration files, json files are detected as text/plain
+	js, err := glob(filepath.Join(path, "*.json"), "text/plain")
+	if err != nil {
+		return "", err
+	}
+	files = append(files, js...)
+
+	if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
+		// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
+		// tokenizer.model might be a unresolved git lfs reference; error if it is
+		files = append(files, tks...)
+	} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
+		// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
+		files = append(files, tks...)
+	}
+
+	for _, file := range files {
+		f, err := os.Open(file)
+		if err != nil {
+			return "", err
+		}
+		defer f.Close()
+
+		fi, err := f.Stat()
+		if err != nil {
+			return "", err
+		}
+
+		zfi, err := zip.FileInfoHeader(fi)
+		if err != nil {
+			return "", err
+		}
+
+		zf, err := zipfile.CreateHeader(zfi)
+		if err != nil {
+			return "", err
+		}
+
+		if _, err := io.Copy(zf, f); err != nil {
+			return "", err
+		}
+	}
+
+	return tempfile.Name(), nil
+}
+
 func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
 	bin, err := os.Open(path)
 	if err != nil {
@@ -322,6 +354,47 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	return generateInteractive(cmd, opts)
 }
 
+func errFromUnknownKey(unknownKeyErr error) error {
+	// find SSH public key in the error message
+	sshKeyPattern := `ssh-\w+ [^\s"]+`
+	re := regexp.MustCompile(sshKeyPattern)
+	matches := re.FindStringSubmatch(unknownKeyErr.Error())
+
+	if len(matches) > 0 {
+		serverPubKey := matches[0]
+
+		localPubKey, err := auth.GetPublicKey()
+		if err != nil {
+			return unknownKeyErr
+		}
+
+		if runtime.GOOS == "linux" && serverPubKey != localPubKey {
+			// try the ollama service public key
+			svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
+			if err != nil {
+				return unknownKeyErr
+			}
+			localPubKey = strings.TrimSpace(string(svcPubKey))
+		}
+
+		// check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
+		if serverPubKey != localPubKey {
+			return unknownKeyErr
+		}
+
+		var msg strings.Builder
+		msg.WriteString(unknownKeyErr.Error())
+		msg.WriteString("\n\nYour ollama key is:\n")
+		msg.WriteString(localPubKey)
+		msg.WriteString("\nAdd your key at:\n")
+		msg.WriteString("https://ollama.com/settings/keys")
+
+		return errors.New(msg.String())
+	}
+
+	return unknownKeyErr
+}
+
 func PushHandler(cmd *cobra.Command, args []string) error {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
@@ -369,6 +442,20 @@ func PushHandler(cmd *cobra.Command, args []string) error {
 
 	request := api.PushRequest{Name: args[0], Insecure: insecure}
 	if err := client.Push(cmd.Context(), &request, fn); err != nil {
+		if spinner != nil {
+			spinner.Stop()
+		}
+		if strings.Contains(err.Error(), "access denied") {
+			return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
+		}
+		host := model.ParseName(args[0]).Host
+		isOllamaHost := strings.HasSuffix(host, ".ollama.ai") || strings.HasSuffix(host, ".ollama.com")
+		if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
+			// the user has not added their ollama key to ollama.com
+			// re-throw an error with a more user-friendly message
+			return errFromUnknownKey(err)
+		}
+
 		return err
 	}
 
@@ -796,24 +883,27 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 }
 
 func RunServer(cmd *cobra.Command, _ []string) error {
-	host, port, err := net.SplitHostPort(strings.Trim(os.Getenv("OLLAMA_HOST"), "\"'"))
+	// retrieve the OLLAMA_HOST environment variable
+	ollamaHost, err := api.GetOllamaHost()
 	if err != nil {
-		host, port = "127.0.0.1", "11434"
-		if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil {
-			host = ip.String()
-		}
+		return err
 	}
 
 	if err := initializeKeypair(); err != nil {
 		return err
 	}
 
-	ln, err := net.Listen("tcp", net.JoinHostPort(host, port))
+	ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port))
 	if err != nil {
 		return err
 	}
 
-	return server.Serve(ln)
+	err = server.Serve(ln)
+	if errors.Is(err, http.ErrServerClosed) {
+		return nil
+	}
+
+	return err
 }
 
 func initializeKeypair() error {
@@ -1034,7 +1124,7 @@ Environment Variables:
 		RunE:    ListHandler,
 	}
 	copyCmd := &cobra.Command{
-		Use:     "cp SOURCE TARGET",
+		Use:     "cp SOURCE DESTINATION",
 		Short:   "Copy a model",
 		Args:    cobra.ExactArgs(2),
 		PreRunE: checkServerHeartbeat,

+ 6 - 1
cmd/interactive.go

@@ -94,6 +94,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 		fmt.Fprintln(os.Stderr, "  /show           Show model information")
 		fmt.Fprintln(os.Stderr, "  /load <model>   Load a session or model")
 		fmt.Fprintln(os.Stderr, "  /save <model>   Save your current session")
+		fmt.Fprintln(os.Stderr, "  /clear          Clear session context")
 		fmt.Fprintln(os.Stderr, "  /bye            Exit")
 		fmt.Fprintln(os.Stderr, "  /?, /help       Help for a command")
 		fmt.Fprintln(os.Stderr, "  /? shortcuts    Help for keyboard shortcuts")
@@ -161,7 +162,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 		fmt.Fprintln(os.Stderr, "  /set parameter repeat_penalty <float> How strongly to penalize repetitions")
 		fmt.Fprintln(os.Stderr, "  /set parameter repeat_last_n <int>    Set how far back to look for repetitions")
 		fmt.Fprintln(os.Stderr, "  /set parameter num_gpu <int>          The number of layers to send to the GPU")
-		fmt.Fprintln(os.Stderr, "  /set parameter stop \"<string>\", ...   Set the stop parameters")
+		fmt.Fprintln(os.Stderr, "  /set parameter stop <string> <string> ...   Set the stop parameters")
 		fmt.Fprintln(os.Stderr, "")
 	}
 
@@ -280,6 +281,10 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 			}
 			fmt.Printf("Created new model '%s'\n", args[1])
 			continue
+		case strings.HasPrefix(line, "/clear"):
+			opts.Messages = []api.Message{}
+			fmt.Println("Cleared session context")
+			continue
 		case strings.HasPrefix(line, "/set"):
 			args := strings.Fields(line)
 			if len(args) > 1 {

+ 19 - 14
convert/convert.go

@@ -5,6 +5,7 @@ import (
 	"encoding/binary"
 	"encoding/json"
 	"fmt"
+	"io"
 	"log/slog"
 	"os"
 	"path/filepath"
@@ -18,19 +19,23 @@ import (
 )
 
 type Params struct {
-	Architectures    []string `json:"architectures"`
-	VocabSize        int      `json:"vocab_size"`
-	HiddenSize       int      `json:"hidden_size"`       // n_embd
-	HiddenLayers     int      `json:"num_hidden_layers"` // n_layer
-	ContextSize      int      `json:"max_position_embeddings"`
-	IntermediateSize int      `json:"intermediate_size"`
-	AttentionHeads   int      `json:"num_attention_heads"` // n_head
-	KeyValHeads      int      `json:"num_key_value_heads"`
-	NormEPS          float64  `json:"rms_norm_eps"`
-	BoSTokenID       int      `json:"bos_token_id"`
-	EoSTokenID       int      `json:"eos_token_id"`
-	HeadDimension    int      `json:"head_dim"`
-	PaddingTokenID   int      `json:"pad_token_id"`
+	Architectures     []string `json:"architectures"`
+	VocabSize         int      `json:"vocab_size"`
+	HiddenSize        int      `json:"hidden_size"`       // n_embd
+	HiddenLayers      int      `json:"num_hidden_layers"` // n_layer
+	ContextSize       int      `json:"max_position_embeddings"`
+	IntermediateSize  int      `json:"intermediate_size"`
+	AttentionHeads    int      `json:"num_attention_heads"` // n_head
+	KeyValHeads       int      `json:"num_key_value_heads"`
+	NormEPS           float64  `json:"rms_norm_eps"`
+	BoSTokenID        int      `json:"bos_token_id"`
+	EoSTokenID        int      `json:"eos_token_id"`
+	HeadDimension     int      `json:"head_dim"`
+	PaddingTokenID    int      `json:"pad_token_id"`
+	RopeFrequencyBase float64  `json:"rope_theta"`
+
+	Experts     int `json:"num_local_experts"`
+	ExpertsUsed int `json:"num_experts_per_tok"`
 
 	ByteOrder
 }
@@ -43,7 +48,7 @@ type ByteOrder interface {
 type ModelArch interface {
 	GetTensors() error
 	LoadVocab() error
-	WriteGGUF() (string, error)
+	WriteGGUF(io.WriteSeeker) error
 }
 
 type ModelFormat interface {

+ 2 - 13
convert/gemma.go

@@ -94,7 +94,7 @@ func (m *GemmaModel) LoadVocab() error {
 	return nil
 }
 
-func (m *GemmaModel) WriteGGUF() (string, error) {
+func (m *GemmaModel) WriteGGUF(ws io.WriteSeeker) error {
 	kv := llm.KV{
 		"general.architecture":                   "gemma",
 		"general.name":                           m.Name,
@@ -122,16 +122,5 @@ func (m *GemmaModel) WriteGGUF() (string, error) {
 		"tokenizer.ggml.add_eos_token":    false,
 	}
 
-	f, err := os.CreateTemp("", "ollama-gguf")
-	if err != nil {
-		return "", err
-	}
-	defer f.Close()
-
-	mod := llm.NewGGUFV3(m.Params.ByteOrder)
-	if err := mod.Encode(f, kv, m.Tensors); err != nil {
-		return "", err
-	}
-
-	return f.Name(), nil
+	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }

+ 2 - 16
convert/llama.go

@@ -5,7 +5,6 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
-	"os"
 	"regexp"
 	"strings"
 
@@ -132,7 +131,7 @@ func (m *LlamaModel) LoadVocab() error {
 	return nil
 }
 
-func (m *LlamaModel) WriteGGUF() (string, error) {
+func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
 	kv := llm.KV{
 		"general.architecture":                   "llama",
 		"general.name":                           m.Name,
@@ -159,18 +158,5 @@ func (m *LlamaModel) WriteGGUF() (string, error) {
 		"tokenizer.ggml.add_eos_token":    false,
 	}
 
-	f, err := os.CreateTemp("", "ollama-gguf")
-	if err != nil {
-		return "", err
-	}
-	defer f.Close()
-
-	mod := llm.NewGGUFV3(m.Params.ByteOrder)
-	if err := mod.Encode(f, kv, m.Tensors); err != nil {
-		return "", err
-	}
-
-	slog.Debug(fmt.Sprintf("gguf file = %s", f.Name()))
-
-	return f.Name(), nil
+	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }

+ 2 - 13
convert/mistral.go

@@ -132,7 +132,7 @@ func (m *MistralModel) LoadVocab() error {
 	return nil
 }
 
-func (m *MistralModel) WriteGGUF() (string, error) {
+func (m *MistralModel) WriteGGUF(ws io.WriteSeeker) error {
 	kv := llm.KV{
 		"general.architecture":                   "llama",
 		"general.name":                           m.Name,
@@ -158,16 +158,5 @@ func (m *MistralModel) WriteGGUF() (string, error) {
 		"tokenizer.ggml.unknown_token_id": uint32(0),
 	}
 
-	f, err := os.CreateTemp("", "ollama-gguf")
-	if err != nil {
-		return "", err
-	}
-	defer f.Close()
-
-	mod := llm.NewGGUFV3(m.Params.ByteOrder)
-	if err := mod.Encode(f, kv, m.Tensors); err != nil {
-		return "", err
-	}
-
-	return f.Name(), nil
+	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }

+ 85 - 0
convert/mixtral.go

@@ -0,0 +1,85 @@
+package convert
+
+import (
+	"io"
+	"regexp"
+
+	"github.com/ollama/ollama/llm"
+)
+
+type MixtralModel struct {
+	ModelData
+}
+
+func (m *MixtralModel) GetTensors() error {
+	t, err := m.Format.GetTensors(m.Path, m.Params)
+	if err != nil {
+		return err
+	}
+
+	m.Tensors = []llm.Tensor{}
+
+	pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
+	re, err := regexp.Compile(pattern)
+	if err != nil {
+		return err
+	}
+
+	for _, l := range t {
+		matches := re.FindAllStringSubmatch(l.Name, -1)
+		if len(matches) > 0 {
+			wt := l.WriterTo.(safetensorWriterTo)
+			wt.handler = mistralLayerHandler
+			l.WriterTo = wt
+		}
+		m.Tensors = append(m.Tensors, l)
+	}
+
+	return nil
+}
+
+func (m *MixtralModel) LoadVocab() error {
+	v, err := LoadSentencePieceTokens(m.Path, m.Params)
+	if err != nil {
+		return err
+	}
+	m.Vocab = v
+	return nil
+}
+
+func (m *MixtralModel) WriteGGUF(ws io.WriteSeeker) error {
+	kv := llm.KV{
+		"general.architecture":          "llama",
+		"general.name":                  m.Name,
+		"llama.block_count":             uint32(m.Params.HiddenLayers),
+		"llama.context_length":          uint32(m.Params.ContextSize),
+		"llama.embedding_length":        uint32(m.Params.HiddenSize),
+		"llama.feed_forward_length":     uint32(m.Params.IntermediateSize),
+		"llama.attention.head_count":    uint32(m.Params.AttentionHeads),
+		"llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
+
+		"llama.rope.freq_base":                   float32(m.Params.RopeFrequencyBase),
+		"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
+
+		"llama.expert_count":      uint32(m.Params.Experts),
+		"llama.expert_used_count": uint32(m.Params.ExpertsUsed),
+
+		"llama.vocab_size":           uint32(len(m.Vocab.Tokens)),
+		"llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
+
+		"general.file_type":    uint32(1),
+		"tokenizer.ggml.model": "llama",
+
+		"tokenizer.ggml.tokens":     m.Vocab.Tokens,
+		"tokenizer.ggml.scores":     m.Vocab.Scores,
+		"tokenizer.ggml.token_type": m.Vocab.Types,
+
+		"tokenizer.ggml.bos_token_id":     uint32(m.Params.BoSTokenID),
+		"tokenizer.ggml.eos_token_id":     uint32(m.Params.EoSTokenID),
+		"tokenizer.ggml.unknown_token_id": uint32(0),
+		"tokenizer.ggml.add_bos_token":    true,
+		"tokenizer.ggml.add_eos_token":    false,
+	}
+
+	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
+}

+ 27 - 14
convert/safetensors.go

@@ -53,7 +53,7 @@ func (m *SafetensorFormat) GetTensors(dirpath string, params *Params) ([]llm.Ten
 		var err error
 		t, offset, err = m.readTensors(f, offset, params)
 		if err != nil {
-			slog.Error("%v", err)
+			slog.Error(err.Error())
 			return nil, err
 		}
 		tensors = append(tensors, t...)
@@ -93,7 +93,6 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
 	}
 
 	slices.Sort(keys)
-
 	slog.Info("converting layers")
 
 	var tensors []llm.Tensor
@@ -105,7 +104,6 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
 			return nil, 0, err
 		}
 
-		slog.Debug(fmt.Sprintf("metadata = %#v", data))
 		var size uint64
 		var kind uint32
 		switch len(data.Shape) {
@@ -124,7 +122,7 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
 
 		ggufName, err := m.GetLayerName(k)
 		if err != nil {
-			slog.Error("%v", err)
+			slog.Error(err.Error())
 			return nil, 0, err
 		}
 
@@ -150,11 +148,13 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
 			padding:  8 + jsonSize,
 		}
 
-		tensors = append(tensors, t)
 		offset += size
+		tensors = append(tensors, t)
 	}
+
 	slog.Debug(fmt.Sprintf("total tensors for file = %d", len(tensors)))
 	slog.Debug(fmt.Sprintf("offset = %d", offset))
+
 	return tensors, offset, nil
 }
 
@@ -185,15 +185,19 @@ func (m *SafetensorFormat) GetLayerName(n string) (string, error) {
 	}
 
 	tMap := map[string]string{
-		"model.layers.(\\d+).input_layernorm.weight":          "blk.$1.attn_norm.weight",
-		"model.layers.(\\d+).mlp.down_proj.weight":            "blk.$1.ffn_down.weight",
-		"model.layers.(\\d+).mlp.gate_proj.weight":            "blk.$1.ffn_gate.weight",
-		"model.layers.(\\d+).mlp.up_proj.weight":              "blk.$1.ffn_up.weight",
-		"model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight",
-		"model.layers.(\\d+).self_attn.k_proj.weight":         "blk.$1.attn_k.weight",
-		"model.layers.(\\d+).self_attn.o_proj.weight":         "blk.$1.attn_output.weight",
-		"model.layers.(\\d+).self_attn.q_proj.weight":         "blk.$1.attn_q.weight",
-		"model.layers.(\\d+).self_attn.v_proj.weight":         "blk.$1.attn_v.weight",
+		"model.layers.(\\d+).input_layernorm.weight":                    "blk.$1.attn_norm.weight",
+		"model.layers.(\\d+).mlp.down_proj.weight":                      "blk.$1.ffn_down.weight",
+		"model.layers.(\\d+).mlp.gate_proj.weight":                      "blk.$1.ffn_gate.weight",
+		"model.layers.(\\d+).mlp.up_proj.weight":                        "blk.$1.ffn_up.weight",
+		"model.layers.(\\d+).post_attention_layernorm.weight":           "blk.$1.ffn_norm.weight",
+		"model.layers.(\\d+).self_attn.k_proj.weight":                   "blk.$1.attn_k.weight",
+		"model.layers.(\\d+).self_attn.o_proj.weight":                   "blk.$1.attn_output.weight",
+		"model.layers.(\\d+).self_attn.q_proj.weight":                   "blk.$1.attn_q.weight",
+		"model.layers.(\\d+).self_attn.v_proj.weight":                   "blk.$1.attn_v.weight",
+		"model.layers.(\\d+).block_sparse_moe.gate.weight":              "blk.$1.ffn_gate_inp.weight",
+		"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w1.weight": "blk.$1.ffn_gate.$2.weight",
+		"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w2.weight": "blk.$1.ffn_down.$2.weight",
+		"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w3.weight": "blk.$1.ffn_up.$2.weight",
 	}
 
 	v, ok := directMap[n]
@@ -286,6 +290,15 @@ func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (M
 					Format: m,
 				},
 			}, nil
+		case "MixtralForCausalLM":
+			return &MixtralModel{
+				ModelData{
+					Name:   name,
+					Path:   dirPath,
+					Params: params,
+					Format: m,
+				},
+			}, nil
 		case "GemmaForCausalLM":
 			return &GemmaModel{
 				ModelData{

+ 1 - 1
convert/torch.go

@@ -74,7 +74,7 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor,
 
 			ggufName, err := tf.GetLayerName(k.(string))
 			if err != nil {
-				slog.Error("%v", err)
+				slog.Error(err.Error())
 				return nil, err
 			}
 			slog.Debug(fmt.Sprintf("finding name for '%s' -> '%s'", k.(string), ggufName))

+ 31 - 31
docs/api.md

@@ -17,7 +17,7 @@
 
 ### Model names
 
-Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q4_1` and `llama2:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
+Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q4_1` and `llama3:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
 
 ### Durations
 
@@ -66,7 +66,7 @@ Enable JSON mode by setting the `format` parameter to `json`. This will structur
 
 ```shell
 curl http://localhost:11434/api/generate -d '{
-  "model": "llama2",
+  "model": "llama3",
   "prompt": "Why is the sky blue?"
 }'
 ```
@@ -77,7 +77,7 @@ A stream of JSON objects is returned:
 
 ```json
 {
-  "model": "llama2",
+  "model": "llama3",
   "created_at": "2023-08-04T08:52:19.385406455-07:00",
   "response": "The",
   "done": false
@@ -90,16 +90,16 @@ The final response in the stream also includes additional data about the generat
 - `load_duration`: time spent in nanoseconds loading the model
 - `prompt_eval_count`: number of tokens in the prompt
 - `prompt_eval_duration`: time spent in nanoseconds evaluating the prompt
-- `eval_count`: number of tokens the response
+- `eval_count`: number of tokens in the response
 - `eval_duration`: time in nanoseconds spent generating the response
 - `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory
 - `response`: empty if the response was streamed, if not streamed, this will contain the full response
 
-To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration`.
+To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration` * `10^9`.
 
 ```json
 {
-  "model": "llama2",
+  "model": "llama3",
   "created_at": "2023-08-04T19:22:45.499127Z",
   "response": "",
   "done": true,
@@ -121,7 +121,7 @@ A response can be received in one reply when streaming is off.
 
 ```shell
 curl http://localhost:11434/api/generate -d '{
-  "model": "llama2",
+  "model": "llama3",
   "prompt": "Why is the sky blue?",
   "stream": false
 }'
@@ -133,7 +133,7 @@ If `stream` is set to `false`, the response will be a single JSON object:
 
 ```json
 {
-  "model": "llama2",
+  "model": "llama3",
   "created_at": "2023-08-04T19:22:45.499127Z",
   "response": "The sky is blue because it is the color of the sky.",
   "done": true,
@@ -155,7 +155,7 @@ If `stream` is set to `false`, the response will be a single JSON object:
 
 ```shell
 curl http://localhost:11434/api/generate -d '{
-  "model": "llama2",
+  "model": "llama3",
   "prompt": "What color is the sky at different times of the day? Respond using JSON",
   "format": "json",
   "stream": false
@@ -166,7 +166,7 @@ curl http://localhost:11434/api/generate -d '{
 
 ```json
 {
-  "model": "llama2",
+  "model": "llama3",
   "created_at": "2023-11-09T21:07:55.186497Z",
   "response": "{\n\"morning\": {\n\"color\": \"blue\"\n},\n\"noon\": {\n\"color\": \"blue-gray\"\n},\n\"afternoon\": {\n\"color\": \"warm gray\"\n},\n\"evening\": {\n\"color\": \"orange\"\n}\n}\n",
   "done": true,
@@ -289,7 +289,7 @@ If you want to set custom options for the model at runtime rather than in the Mo
 
 ```shell
 curl http://localhost:11434/api/generate -d '{
-  "model": "llama2",
+  "model": "llama3",
   "prompt": "Why is the sky blue?",
   "stream": false,
   "options": {
@@ -332,7 +332,7 @@ curl http://localhost:11434/api/generate -d '{
 
 ```json
 {
-  "model": "llama2",
+  "model": "llama3",
   "created_at": "2023-08-04T19:22:45.499127Z",
   "response": "The sky is blue because it is the color of the sky.",
   "done": true,
@@ -354,7 +354,7 @@ If an empty prompt is provided, the model will be loaded into memory.
 
 ```shell
 curl http://localhost:11434/api/generate -d '{
-  "model": "llama2"
+  "model": "llama3"
 }'
 ```
 
@@ -364,7 +364,7 @@ A single JSON object is returned:
 
 ```json
 {
-  "model": "llama2",
+  "model": "llama3",
   "created_at": "2023-12-18T19:52:07.071755Z",
   "response": "",
   "done": true
@@ -407,7 +407,7 @@ Send a chat message with a streaming response.
 
 ```shell
 curl http://localhost:11434/api/chat -d '{
-  "model": "llama2",
+  "model": "llama3",
   "messages": [
     {
       "role": "user",
@@ -423,7 +423,7 @@ A stream of JSON objects is returned:
 
 ```json
 {
-  "model": "llama2",
+  "model": "llama3",
   "created_at": "2023-08-04T08:52:19.385406455-07:00",
   "message": {
     "role": "assistant",
@@ -438,7 +438,7 @@ Final response:
 
 ```json
 {
-  "model": "llama2",
+  "model": "llama3",
   "created_at": "2023-08-04T19:22:45.499127Z",
   "done": true,
   "total_duration": 4883583458,
@@ -456,7 +456,7 @@ Final response:
 
 ```shell
 curl http://localhost:11434/api/chat -d '{
-  "model": "llama2",
+  "model": "llama3",
   "messages": [
     {
       "role": "user",
@@ -471,7 +471,7 @@ curl http://localhost:11434/api/chat -d '{
 
 ```json
 {
-  "model": "registry.ollama.ai/library/llama2:latest",
+  "model": "registry.ollama.ai/library/llama3:latest",
   "created_at": "2023-12-12T14:13:43.416799Z",
   "message": {
     "role": "assistant",
@@ -495,7 +495,7 @@ Send a chat message with a conversation history. You can use this same approach
 
 ```shell
 curl http://localhost:11434/api/chat -d '{
-  "model": "llama2",
+  "model": "llama3",
   "messages": [
     {
       "role": "user",
@@ -519,7 +519,7 @@ A stream of JSON objects is returned:
 
 ```json
 {
-  "model": "llama2",
+  "model": "llama3",
   "created_at": "2023-08-04T08:52:19.385406455-07:00",
   "message": {
     "role": "assistant",
@@ -533,7 +533,7 @@ Final response:
 
 ```json
 {
-  "model": "llama2",
+  "model": "llama3",
   "created_at": "2023-08-04T19:22:45.499127Z",
   "done": true,
   "total_duration": 8113331500,
@@ -591,7 +591,7 @@ curl http://localhost:11434/api/chat -d '{
 
 ```shell
 curl http://localhost:11434/api/chat -d '{
-  "model": "llama2",
+  "model": "llama3",
   "messages": [
     {
       "role": "user",
@@ -609,7 +609,7 @@ curl http://localhost:11434/api/chat -d '{
 
 ```json
 {
-  "model": "registry.ollama.ai/library/llama2:latest",
+  "model": "registry.ollama.ai/library/llama3:latest",
   "created_at": "2023-12-12T14:13:43.416799Z",
   "message": {
     "role": "assistant",
@@ -651,7 +651,7 @@ Create a new model from a `Modelfile`.
 ```shell
 curl http://localhost:11434/api/create -d '{
   "name": "mario",
-  "modelfile": "FROM llama2\nSYSTEM You are mario from Super Mario Bros."
+  "modelfile": "FROM llama3\nSYSTEM You are mario from Super Mario Bros."
 }'
 ```
 
@@ -758,7 +758,7 @@ A single JSON object will be returned.
       }
     },
     {
-      "name": "llama2:latest",
+      "name": "llama3:latest",
       "modified_at": "2023-12-07T09:32:18.757212583-08:00",
       "size": 3825819519,
       "digest": "fe938a131f40e6f6d40083c9f0f430a515233eb2edaa6d72eb85c50d64f2300e",
@@ -792,7 +792,7 @@ Show information about a model including details, modelfile, template, parameter
 
 ```shell
 curl http://localhost:11434/api/show -d '{
-  "name": "llama2"
+  "name": "llama3"
 }'
 ```
 
@@ -827,8 +827,8 @@ Copy a model. Creates a model with another name from an existing model.
 
 ```shell
 curl http://localhost:11434/api/copy -d '{
-  "source": "llama2",
-  "destination": "llama2-backup"
+  "source": "llama3",
+  "destination": "llama3-backup"
 }'
 ```
 
@@ -854,7 +854,7 @@ Delete a model and its data.
 
 ```shell
 curl -X DELETE http://localhost:11434/api/delete -d '{
-  "name": "llama2:13b"
+  "name": "llama3:13b"
 }'
 ```
 
@@ -882,7 +882,7 @@ Download a model from the ollama library. Cancelled pulls are resumed from where
 
 ```shell
 curl http://localhost:11434/api/pull -d '{
-  "name": "llama2"
+  "name": "llama3"
 }'
 ```
 

+ 2 - 2
docs/development.md

@@ -51,7 +51,7 @@ Typically the build scripts will auto-detect CUDA, however, if your Linux distro
 or installation approach uses unusual paths, you can specify the location by
 specifying an environment variable `CUDA_LIB_DIR` to the location of the shared
 libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize
-set set of target CUDA architectues by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70")
+a set of target CUDA architectures by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70")
 
 Then generate dependencies:
 
@@ -142,4 +142,4 @@ In addition to the common Windows development tools described above, install AMD
 - [AMD HIP](https://www.amd.com/en/developer/resources/rocm-hub/hip-sdk.html)
 - [Strawberry Perl](https://strawberryperl.com/)
 
-Lastly, add `ninja.exe` included with MSVC to the system path (e.g. `C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\Common7\IDE\CommonExtensions\Microsoft\CMake\Ninja`).
+Lastly, add `ninja.exe` included with MSVC to the system path (e.g. `C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\Common7\IDE\CommonExtensions\Microsoft\CMake\Ninja`).

+ 16 - 6
docs/faq.md

@@ -32,7 +32,7 @@ When using the API, specify the `num_ctx` parameter:
 
 ```
 curl http://localhost:11434/api/generate -d '{
-  "model": "llama2",
+  "model": "llama3",
   "prompt": "Why is the sky blue?",
   "options": {
     "num_ctx": 4096
@@ -88,9 +88,9 @@ On windows, Ollama inherits your user and system environment variables.
 
 3. Edit or create New variable(s) for your user account for `OLLAMA_HOST`, `OLLAMA_MODELS`, etc.
 
-4. Click OK/Apply to save 
+4. Click OK/Apply to save
 
-5. Run `ollama` from a new terminal window 
+5. Run `ollama` from a new terminal window
 
 
 ## How can I expose Ollama on my network?
@@ -140,7 +140,7 @@ Refer to the section [above](#how-do-i-configure-ollama-server) for how to set e
 
 - macOS: `~/.ollama/models`
 - Linux: `/usr/share/ollama/.ollama/models`
-- Windows: `C:\Users\<username>\.ollama\models`
+- Windows: `C:\Users\%username%\.ollama\models`
 
 ### How do I set them to a different location?
 
@@ -221,10 +221,20 @@ The `keep_alive` parameter can be set to:
 
 For example, to preload a model and leave it in memory use:
 ```shell
-curl http://localhost:11434/api/generate -d '{"model": "llama2", "keep_alive": -1}'
+curl http://localhost:11434/api/generate -d '{"model": "llama3", "keep_alive": -1}'
 ```
 
 To unload the model and free up memory use:
 ```shell
-curl http://localhost:11434/api/generate -d '{"model": "llama2", "keep_alive": 0}'
+curl http://localhost:11434/api/generate -d '{"model": "llama3", "keep_alive": 0}'
 ```
+
+Alternatively, you can change the amount of time all models are loaded into memory by setting the `OLLAMA_KEEP_ALIVE` environment variable when starting the Ollama server. The `OLLAMA_KEEP_ALIVE` variable uses the same parameter types as the `keep_alive` parameter types mentioned above. Refer to section explaining [how to configure the Ollama server](#how-do-i-configure-ollama-server) to correctly set the environment variable.
+
+If you wish to override the `OLLAMA_KEEP_ALIVE` setting, use the `keep_alive` API parameter with the `/api/generate` or `/api/chat` API endpoints.
+
+## How do I manage the maximum number of requests the server can queue
+
+If too many requests are sent to the server, it will respond with a 503 error
+indicating the server is overloaded.  You can adjust how many requests may be
+queue by setting `OLLAMA_MAX_QUEUE`

+ 3 - 1
docs/import.md

@@ -125,7 +125,7 @@ Publishing models is in early alpha. If you'd like to publish your model to shar
 
 1. Create [an account](https://ollama.com/signup)
 2. Copy your Ollama public key:
-  - macOS: `cat ~/.ollama/id_ed25519.pub`
+  - macOS: `cat ~/.ollama/id_ed25519.pub | pbcopy`
   - Windows: `type %USERPROFILE%\.ollama\id_ed25519.pub`
   - Linux: `cat /usr/share/ollama/.ollama/id_ed25519.pub`
 3. Add your public key to your [Ollama account](https://ollama.com/settings/keys)
@@ -136,6 +136,8 @@ Next, copy your model to your username's namespace:
 ollama cp example <your username>/example
 ```
 
+> Note: model names may only contain lowercase letters, digits, and the characters `.`, `-`, and `_`.
+
 Then push the model:
 
 ```

+ 1 - 1
docs/linux.md

@@ -105,7 +105,7 @@ sudo chmod +x /usr/bin/ollama
 To view logs of Ollama running as a startup service, run:
 
 ```bash
-journalctl -u ollama
+journalctl -e -u ollama
 ```
 
 ## Uninstall

+ 17 - 25
docs/modelfile.md

@@ -10,7 +10,7 @@ A model file is the blueprint to create and share models with Ollama.
 - [Examples](#examples)
 - [Instructions](#instructions)
   - [FROM (Required)](#from-required)
-    - [Build from llama2](#build-from-llama2)
+    - [Build from llama3](#build-from-llama3)
     - [Build from a bin file](#build-from-a-bin-file)
   - [PARAMETER](#parameter)
     - [Valid Parameters and Values](#valid-parameters-and-values)
@@ -48,7 +48,7 @@ INSTRUCTION arguments
 An example of a `Modelfile` creating a mario blueprint:
 
 ```modelfile
-FROM llama2
+FROM llama3
 # sets the temperature to 1 [higher is more creative, lower is more coherent]
 PARAMETER temperature 1
 # sets the context window size to 4096, this controls how many tokens the LLM can use as context to generate the next token
@@ -67,33 +67,25 @@ To use this:
 
 More examples are available in the [examples directory](../examples).
 
-### `Modelfile`s in [ollama.com/library][1]
-
-There are two ways to view `Modelfile`s underlying the models in [ollama.com/library][1]:
-
-- Option 1: view a details page from a model's tags page:
-  1.  Go to a particular model's tags (e.g. https://ollama.com/library/llama2/tags)
-  2.  Click on a tag (e.g. https://ollama.com/library/llama2:13b)
-  3.  Scroll down to "Layers"
-      - Note: if the [`FROM` instruction](#from-required) is not present,
-        it means the model was created from a local file
-- Option 2: use `ollama show` to print the `Modelfile` for any local models like so:
+To view the Modelfile of a given model, use the `ollama show --modelfile` command.
 
   ```bash
-  > ollama show --modelfile llama2:13b
+  > ollama show --modelfile llama3
   # Modelfile generated by "ollama show"
   # To build a new Modelfile based on this one, replace the FROM line with:
-  # FROM llama2:13b
+  # FROM llama3:latest
+  FROM /Users/pdevine/.ollama/models/blobs/sha256-00e1317cbf74d901080d7100f57580ba8dd8de57203072dc6f668324ba545f29
+  TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
+
+  {{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
 
-  FROM /root/.ollama/models/blobs/sha256:123abc
-  TEMPLATE """[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>>
+  {{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
 
-  {{ end }}{{ .Prompt }} [/INST] """
-  SYSTEM """"""
-  PARAMETER stop [INST]
-  PARAMETER stop [/INST]
-  PARAMETER stop <<SYS>>
-  PARAMETER stop <</SYS>>
+  {{ .Response }}<|eot_id|>"""
+  PARAMETER stop "<|start_header_id|>"
+  PARAMETER stop "<|end_header_id|>"
+  PARAMETER stop "<|eot_id|>"
+  PARAMETER stop "<|reserved_special_token"
   ```
 
 ## Instructions
@@ -106,10 +98,10 @@ The `FROM` instruction defines the base model to use when creating a model.
 FROM <model name>:<tag>
 ```
 
-#### Build from llama2
+#### Build from llama3
 
 ```modelfile
-FROM llama2
+FROM llama3
 ```
 
 A list of available base models:

+ 5 - 5
docs/openai.md

@@ -25,7 +25,7 @@ chat_completion = client.chat.completions.create(
             'content': 'Say this is a test',
         }
     ],
-    model='llama2',
+    model='llama3',
 )
 ```
 
@@ -43,7 +43,7 @@ const openai = new OpenAI({
 
 const chatCompletion = await openai.chat.completions.create({
   messages: [{ role: 'user', content: 'Say this is a test' }],
-  model: 'llama2',
+  model: 'llama3',
 })
 ```
 
@@ -53,7 +53,7 @@ const chatCompletion = await openai.chat.completions.create({
 curl http://localhost:11434/v1/chat/completions \
     -H "Content-Type: application/json" \
     -d '{
-        "model": "llama2",
+        "model": "llama3",
         "messages": [
             {
                 "role": "system",
@@ -113,7 +113,7 @@ curl http://localhost:11434/v1/chat/completions \
 Before using a model, pull it locally `ollama pull`:
 
 ```shell
-ollama pull llama2
+ollama pull llama3
 ```
 
 ### Default model names
@@ -121,7 +121,7 @@ ollama pull llama2
 For tooling that relies on default OpenAI model names such as `gpt-3.5-turbo`, use `ollama cp` to copy an existing model name to a temporary name:
 
 ```
-ollama cp llama2 gpt-3.5-turbo
+ollama cp llama3 gpt-3.5-turbo
 ```
 
 Afterwards, this new model name can be specified the `model` field:

+ 3 - 3
docs/tutorials/langchainjs.md

@@ -15,7 +15,7 @@ import { Ollama } from "langchain/llms/ollama";
 
 const ollama = new Ollama({
   baseUrl: "http://localhost:11434",
-  model: "llama2",
+  model: "llama3",
 });
 
 const answer = await ollama.invoke(`why is the sky blue?`);
@@ -23,10 +23,10 @@ const answer = await ollama.invoke(`why is the sky blue?`);
 console.log(answer);
 ```
 
-That will get us the same thing as if we ran `ollama run llama2 "why is the sky blue"` in the terminal. But we want to load a document from the web to ask a question against. **Cheerio** is a great library for ingesting a webpage, and **LangChain** uses it in their **CheerioWebBaseLoader**. So let's install **Cheerio** and build that part of the app.
+That will get us the same thing as if we ran `ollama run llama3 "why is the sky blue"` in the terminal. But we want to load a document from the web to ask a question against. **Cheerio** is a great library for ingesting a webpage, and **LangChain** uses it in their **CheerioWebBaseLoader**. So let's install **Cheerio** and build that part of the app.
 
 ```bash
-npm install cheerio 
+npm install cheerio
 ```
 
 ```javascript

+ 7 - 5
docs/tutorials/langchainpy.md

@@ -12,15 +12,17 @@ So let's figure out how we can use **LangChain** with Ollama to ask our question
 
 Let's start by asking a simple question that we can get an answer to from the **Llama2** model using **Ollama**. First, we need to install the **LangChain** package:
 
-`pip install langchain`
+`pip install langchain_community`
 
 Then we can create a model and ask the question:
 
 ```python
-from langchain.llms import Ollama
-ollama = Ollama(base_url='http://localhost:11434',
-model="llama2")
-print(ollama("why is the sky blue"))
+from langchain_community.llms import Ollama
+ollama = Ollama(
+    base_url='http://localhost:11434',
+    model="llama3"
+)
+print(ollama.invoke("why is the sky blue"))
 ```
 
 Notice that we are defining the model and the base URL for Ollama.

+ 6 - 29
docs/tutorials/nvidia-jetson.md

@@ -1,38 +1,15 @@
 # Running Ollama on NVIDIA Jetson Devices
 
-With some minor configuration, Ollama runs well on [NVIDIA Jetson Devices](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/). The following has been tested on [JetPack 5.1.2](https://developer.nvidia.com/embedded/jetpack).
+Ollama runs well on [NVIDIA Jetson Devices](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/) and should run out of the box with the standard installation instructions. 
 
-NVIDIA Jetson devices are Linux-based embedded AI computers that are purpose-built for AI applications.
-
-Jetsons have an integrated GPU that is wired directly to the memory controller of the machine. For this reason, the `nvidia-smi` command is unrecognized, and Ollama proceeds to operate in "CPU only"
-mode. This can be verified by using a monitoring tool like jtop.
-
-In order to address this, we simply pass the path to the Jetson's pre-installed CUDA libraries into `ollama serve` (while in a tmux session). We then hardcode the num_gpu parameters into a cloned
-version of our target model.
-
-Prerequisites:
-
-- curl
-- tmux
-
-Here are the steps:
+The following has been tested on [JetPack 5.1.2](https://developer.nvidia.com/embedded/jetpack), but should also work on JetPack 6.0.
 
 - Install Ollama via standard Linux command (ignore the 404 error): `curl https://ollama.com/install.sh | sh`
-- Stop the Ollama service: `sudo systemctl stop ollama`
-- Start Ollama serve in a tmux session called ollama_jetson and reference the CUDA libraries path: `tmux has-session -t ollama_jetson 2>/dev/null || tmux new-session -d -s ollama_jetson 
-'LD_LIBRARY_PATH=/usr/local/cuda/lib64 ollama serve'`
 - Pull the model you want to use (e.g. mistral): `ollama pull mistral`
-- Create a new Modelfile specifically for enabling GPU support on the Jetson: `touch ModelfileMistralJetson`
-- In the ModelfileMistralJetson file, specify the FROM model and the num_gpu PARAMETER as shown below:
-
-```
-FROM mistral
-PARAMETER num_gpu 999
-```
+- Start an interactive session: `ollama run mistral`
 
-- Create a new model from your Modelfile: `ollama create mistral-jetson -f ./ModelfileMistralJetson`
-- Run the new model: `ollama run mistral-jetson`
+And that's it!
 
-If you run a monitoring tool like jtop you should now see that Ollama is using the Jetson's integrated GPU.
+# Running Ollama in Docker
 
-And that's it!
+When running GPU accelerated applications in Docker, it is highly recommended to use [dusty-nv jetson-containers repo](https://github.com/dusty-nv/jetson-containers).

+ 61 - 47
docs/windows.md

@@ -1,47 +1,61 @@
-# Ollama Windows Preview
-
-Welcome to the Ollama Windows preview.
-
-No more WSL required!
-
-Ollama now runs as a native Windows application, including NVIDIA and AMD Radeon GPU support.
-After installing Ollama Windows Preview, Ollama will run in the background and
-the `ollama` command line is available in `cmd`, `powershell` or your favorite
-terminal application. As usual the Ollama [api](./api.md) will be served on
-`http://localhost:11434`.
-
-As this is a preview release, you should expect a few bugs here and there.  If
-you run into a problem you can reach out on
-[Discord](https://discord.gg/ollama), or file an 
-[issue](https://github.com/ollama/ollama/issues).
-Logs will often be helpful in dianosing the problem (see
-[Troubleshooting](#troubleshooting) below)
-
-## System Requirements
-
-* Windows 10 or newer, Home or Pro
-* NVIDIA 452.39 or newer Drivers if you have an NVIDIA card
-* AMD Radeon Driver https://www.amd.com/en/support if you have a Radeon card
-
-## API Access
-
-Here's a quick example showing API access from `powershell`
-```powershell
-(Invoke-WebRequest -method POST -Body '{"model":"llama2", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json
-```
-
-## Troubleshooting
-
-While we're in preview, `OLLAMA_DEBUG` is always enabled, which adds
-a "view logs" menu item to the app, and increses logging for the GUI app and
-server.
-
-Ollama on Windows stores files in a few different locations.  You can view them in
-the explorer window by hitting `<cmd>+R` and type in:
-- `explorer %LOCALAPPDATA%\Ollama` contains logs, and downloaded updates
-    - *app.log* contains logs from the GUI application
-    - *server.log* contains the server logs
-    - *upgrade.log* contains log output for upgrades
-- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
-- `explorer %HOMEPATH%\.ollama` contains models and configuration
-- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories
+# Ollama Windows Preview
+
+Welcome to the Ollama Windows preview.
+
+No more WSL required!
+
+Ollama now runs as a native Windows application, including NVIDIA and AMD Radeon GPU support.
+After installing Ollama Windows Preview, Ollama will run in the background and
+the `ollama` command line is available in `cmd`, `powershell` or your favorite
+terminal application. As usual the Ollama [api](./api.md) will be served on
+`http://localhost:11434`.
+
+As this is a preview release, you should expect a few bugs here and there.  If
+you run into a problem you can reach out on
+[Discord](https://discord.gg/ollama), or file an
+[issue](https://github.com/ollama/ollama/issues).
+Logs will often be helpful in diagnosing the problem (see
+[Troubleshooting](#troubleshooting) below)
+
+## System Requirements
+
+* Windows 10 or newer, Home or Pro
+* NVIDIA 452.39 or newer Drivers if you have an NVIDIA card
+* AMD Radeon Driver https://www.amd.com/en/support if you have a Radeon card
+
+## API Access
+
+Here's a quick example showing API access from `powershell`
+```powershell
+(Invoke-WebRequest -method POST -Body '{"model":"llama3", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json
+```
+
+## Troubleshooting
+
+While we're in preview, `OLLAMA_DEBUG` is always enabled, which adds
+a "view logs" menu item to the app, and increses logging for the GUI app and
+server.
+
+Ollama on Windows stores files in a few different locations.  You can view them in
+the explorer window by hitting `<cmd>+R` and type in:
+- `explorer %LOCALAPPDATA%\Ollama` contains logs, and downloaded updates
+    - *app.log* contains logs from the GUI application
+    - *server.log* contains the server logs
+    - *upgrade.log* contains log output for upgrades
+- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
+- `explorer %HOMEPATH%\.ollama` contains models and configuration
+- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories
+
+
+## Standalone CLI
+
+The easiest way to install Ollama on Windows is to use the `OllamaSetup.exe`
+installer. It installs in your account without requiring Administrator rights.
+We update Ollama regularly to support the latest models, and this installer will
+help you keep up to date.
+
+If you'd like to install or integrate Ollama as a service, a standalone
+`ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
+and GPU library dependencies for Nvidia and AMD. This allows for embedding
+Ollama in existing applications, or running it as a system service via `ollama
+serve` with tools such as [NSSM](https://nssm.cc/).

+ 1 - 1
examples/bash-comparemodels/README.md

@@ -2,7 +2,7 @@
 
 When calling `ollama`, you can pass it a file to run all the prompts in the file, one after the other:
 
-`ollama run llama2 < sourcequestions.txt`
+`ollama run llama3 < sourcequestions.txt`
 
 This concept is used in the following example.
 

+ 1 - 0
examples/flyio/.gitignore

@@ -0,0 +1 @@
+fly.toml

+ 67 - 0
examples/flyio/README.md

@@ -0,0 +1,67 @@
+# Deploy Ollama to Fly.io
+
+> Note: this example exposes a public endpoint and does not configure authentication. Use with care.
+
+## Prerequisites
+
+- Ollama: https://ollama.com/download
+- Fly.io account. Sign up for a free account: https://fly.io/app/sign-up
+
+## Steps
+
+1. Login to Fly.io
+
+    ```bash
+    fly auth login
+    ```
+
+1. Create a new Fly app
+
+    ```bash
+    fly launch --name <name> --image ollama/ollama --internal-port 11434 --vm-size shared-cpu-8x --now
+    ```
+
+1. Pull and run `orca-mini:3b`
+
+    ```bash
+    OLLAMA_HOST=https://<name>.fly.dev ollama run orca-mini:3b
+    ```
+
+`shared-cpu-8x` is a free-tier eligible machine type. For better performance, switch to a `performance` or `dedicated` machine type or attach a GPU for hardware acceleration (see below).
+
+## (Optional) Persistent Volume
+
+By default Fly Machines use ephemeral storage which is problematic if you want to use the same model across restarts without pulling it again. Create and attach a persistent volume to store the downloaded models:
+
+1. Create the Fly Volume
+
+    ```bash
+    fly volume create ollama
+    ```
+
+1. Update `fly.toml` and add `[mounts]`
+
+    ```toml
+    [mounts]
+      source = "ollama"
+      destination = "/mnt/ollama/models"
+    ```
+
+1. Update `fly.toml` and add `[env]`
+
+    ```toml
+    [env]
+      OLLAMA_MODELS = "/mnt/ollama/models"
+    ```
+
+1. Deploy your app
+
+    ```bash
+    fly deploy
+    ```
+
+## (Optional) Hardware Acceleration
+
+Fly.io GPU is currently in waitlist. Sign up for the waitlist: https://fly.io/gpu
+
+Once you've been accepted, create the app with the additional flags `--vm-gpu-kind a100-pcie-40gb` or `--vm-gpu-kind a100-pcie-80gb`.

+ 1 - 1
examples/go-chat/main.go

@@ -35,7 +35,7 @@ func main() {
 
 	ctx := context.Background()
 	req := &api.ChatRequest{
-		Model:    "llama2",
+		Model:    "llama3",
 		Messages: messages,
 	}
 

+ 1 - 1
examples/go-http-generate/main.go

@@ -19,7 +19,7 @@ func main() {
 	}
 
 	defer resp.Body.Close()
-	
+
 	responseData, err := io.ReadAll(resp.Body)
 	if err != nil {
 		log.Fatal(err)

+ 14 - 12
examples/kubernetes/README.md

@@ -7,12 +7,24 @@
 
 ## Steps
 
-1. Create the Ollama namespace, daemon set, and service
+1. Create the Ollama namespace, deployment, and service
 
    ```bash
    kubectl apply -f cpu.yaml
    ```
 
+## (Optional) Hardware Acceleration
+
+Hardware acceleration in Kubernetes requires NVIDIA's [`k8s-device-plugin`](https://github.com/NVIDIA/k8s-device-plugin) which is deployed in Kubernetes in form of daemonset. Follow the link for more details.
+
+Once configured, create a GPU enabled Ollama deployment.
+
+```bash
+kubectl apply -f gpu.yaml
+```
+
+## Test
+
 1. Port forward the Ollama service to connect and use it locally
 
    ```bash
@@ -23,14 +35,4 @@
 
    ```bash
    ollama run orca-mini:3b
-   ```
-
-## (Optional) Hardware Acceleration
-
-Hardware acceleration in Kubernetes requires NVIDIA's [`k8s-device-plugin`](https://github.com/NVIDIA/k8s-device-plugin). Follow the link for more details.
-
-Once configured, create a GPU enabled Ollama deployment.
-
-```bash
-kubectl apply -f gpu.yaml
-```
+   ```

+ 5 - 5
examples/langchain-python-rag-document/main.py

@@ -40,9 +40,9 @@ while True:
         continue
 
     # Prompt
-    template = """Use the following pieces of context to answer the question at the end. 
-    If you don't know the answer, just say that you don't know, don't try to make up an answer. 
-    Use three sentences maximum and keep the answer as concise as possible. 
+    template = """Use the following pieces of context to answer the question at the end.
+    If you don't know the answer, just say that you don't know, don't try to make up an answer.
+    Use three sentences maximum and keep the answer as concise as possible.
     {context}
     Question: {question}
     Helpful Answer:"""
@@ -51,11 +51,11 @@ while True:
         template=template,
     )
 
-    llm = Ollama(model="llama2:13b", callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))
+    llm = Ollama(model="llama3:8b", callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))
     qa_chain = RetrievalQA.from_chain_type(
         llm,
         retriever=vectorstore.as_retriever(),
         chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
     )
 
-    result = qa_chain({"query": query})
+    result = qa_chain({"query": query})

+ 4 - 4
examples/langchain-python-rag-websummary/main.py

@@ -1,12 +1,12 @@
-from langchain.llms import Ollama
-from langchain.document_loaders import WebBaseLoader
+from langchain_community.llms import Ollama
+from langchain_community.document_loaders import WebBaseLoader
 from langchain.chains.summarize import load_summarize_chain
 
 loader = WebBaseLoader("https://ollama.com/blog/run-llama2-uncensored-locally")
 docs = loader.load()
 
-llm = Ollama(model="llama2")
+llm = Ollama(model="llama3")
 chain = load_summarize_chain(llm, chain_type="stuff")
 
-result = chain.run(docs)
+result = chain.invoke(docs) 
 print(result)

+ 2 - 3
examples/langchain-python-simple/README.md

@@ -4,10 +4,10 @@ This example is a basic "hello world" of using LangChain with Ollama.
 
 ## Running the Example
 
-1. Ensure you have the `llama2` model installed:
+1. Ensure you have the `llama3` model installed:
 
    ```bash
-   ollama pull llama2
+   ollama pull llama3
    ```
 
 2. Install the Python Requirements.
@@ -21,4 +21,3 @@ This example is a basic "hello world" of using LangChain with Ollama.
    ```bash
    python main.py
    ```
-  

+ 1 - 1
examples/langchain-python-simple/main.py

@@ -1,6 +1,6 @@
 from langchain.llms import Ollama
 
 input = input("What is your question?")
-llm = Ollama(model="llama2")
+llm = Ollama(model="llama3")
 res = llm.predict(input)
 print (res)

+ 1 - 1
examples/modelfile-mario/Modelfile

@@ -1,4 +1,4 @@
-FROM llama2
+FROM llama3
 PARAMETER temperature 1
 SYSTEM """
 You are Mario from super mario bros, acting as an assistant.

+ 3 - 3
examples/modelfile-mario/readme.md

@@ -2,12 +2,12 @@
 
 # Example character: Mario
 
-This example shows how to create a basic character using Llama2 as the base model.
+This example shows how to create a basic character using Llama3 as the base model.
 
 To run this example:
 
 1. Download the Modelfile
-2. `ollama pull llama2` to get the base model used in the model file.
+2. `ollama pull llama3` to get the base model used in the model file.
 3. `ollama create NAME -f ./Modelfile`
 4. `ollama run NAME`
 
@@ -18,7 +18,7 @@ Ask it some questions like "Who are you?" or "Is Peach in trouble again?"
 What the model file looks like:
 
 ```
-FROM llama2
+FROM llama3
 PARAMETER temperature 1
 SYSTEM """
 You are Mario from Super Mario Bros, acting as an assistant.

+ 7 - 7
examples/python-json-datagenerator/predefinedschema.py

@@ -2,16 +2,16 @@ import requests
 import json
 import random
 
-model = "llama2"
+model = "llama3"
 template = {
-  "firstName": "", 
-  "lastName": "", 
+  "firstName": "",
+  "lastName": "",
   "address": {
-    "street": "", 
-    "city": "", 
-    "state": "", 
+    "street": "",
+    "city": "",
+    "state": "",
     "zipCode": ""
-  }, 
+  },
   "phoneNumber": ""
 }
 

+ 1 - 1
examples/python-json-datagenerator/randomaddresses.py

@@ -12,7 +12,7 @@ countries = [
     "France",
 ]
 country = random.choice(countries)
-model = "llama2"
+model = "llama3"
 
 prompt = f"generate one realistically believable sample data set of a persons first name, last name, address in {country}, and phone number. Do not use common names. Respond using JSON. Key names should have no backslashes, values should use plain ascii with no special characters."
 

+ 2 - 2
examples/python-json-datagenerator/readme.md

@@ -6,10 +6,10 @@ There are two python scripts in this example. `randomaddresses.py` generates ran
 
 ## Running the Example
 
-1. Ensure you have the `llama2` model installed:
+1. Ensure you have the `llama3` model installed:
 
    ```bash
-   ollama pull llama2
+   ollama pull llama3
    ```
 
 2. Install the Python Requirements.

+ 1 - 1
examples/python-simplechat/client.py

@@ -2,7 +2,7 @@ import json
 import requests
 
 # NOTE: ollama must be running for this to work, start the ollama app or run `ollama serve`
-model = "llama2"  # TODO: update this for whatever model you wish to use
+model = "llama3"  # TODO: update this for whatever model you wish to use
 
 
 def chat(messages):

+ 2 - 2
examples/python-simplechat/readme.md

@@ -4,10 +4,10 @@ The **chat** endpoint is one of two ways to generate text from an LLM with Ollam
 
 ## Running the Example
 
-1. Ensure you have the `llama2` model installed:
+1. Ensure you have the `llama3` model installed:
 
    ```bash
-   ollama pull llama2
+   ollama pull llama3
    ```
 
 2. Install the Python Requirements.

+ 2 - 2
examples/typescript-mentors/README.md

@@ -4,10 +4,10 @@ This example demonstrates how one would create a set of 'mentors' you can have a
 
 ## Usage
 
-1. Add llama2 to have the mentors ask your questions:
+1. Add llama3 to have the mentors ask your questions:
 
    ```bash
-   ollama pull llama2
+   ollama pull llama3
    ```
 
 2. Install prerequisites:

+ 2 - 2
examples/typescript-mentors/character-generator.ts

@@ -15,7 +15,7 @@ async function characterGenerator() {
   ollama.setModel("stablebeluga2:70b-q4_K_M");
   const bio = await ollama.generate(`create a bio of ${character} in a single long paragraph. Instead of saying '${character} is...' or '${character} was...' use language like 'You are...' or 'You were...'. Then create a paragraph describing the speaking mannerisms and style of ${character}. Don't include anything about how ${character} looked or what they sounded like, just focus on the words they said. Instead of saying '${character} would say...' use language like 'You should say...'. If you use quotes, always use single quotes instead of double quotes. If there are any specific words or phrases you used a lot, show how you used them. `);
 
-  const thecontents = `FROM llama2\nSYSTEM """\n${bio.response.replace(/(\r\n|\n|\r)/gm, " ").replace('would', 'should')} All answers to questions should be related back to what you are most known for.\n"""`;
+  const thecontents = `FROM llama3\nSYSTEM """\n${bio.response.replace(/(\r\n|\n|\r)/gm, " ").replace('would', 'should')} All answers to questions should be related back to what you are most known for.\n"""`;
 
   fs.writeFile(path.join(directory, 'Modelfile'), thecontents, (err: any) => {
     if (err) throw err;
@@ -23,4 +23,4 @@ async function characterGenerator() {
   });
 }
 
-characterGenerator();
+characterGenerator();

+ 2 - 2
examples/typescript-simplechat/client.ts

@@ -1,6 +1,6 @@
 import * as readline from "readline";
 
-const model = "llama2";
+const model = "llama3";
 type Message = {
   role: "assistant" | "user" | "system";
   content: string;
@@ -74,4 +74,4 @@ async function main() {
 
 }
 
-main();
+main();

+ 3 - 0
format/bytes.go

@@ -15,6 +15,7 @@ const (
 
 	KibiByte = Byte * 1024
 	MebiByte = KibiByte * 1024
+	GibiByte = MebiByte * 1024
 )
 
 func HumanBytes(b int64) string {
@@ -52,6 +53,8 @@ func HumanBytes(b int64) string {
 
 func HumanBytes2(b uint64) string {
 	switch {
+	case b >= GibiByte:
+		return fmt.Sprintf("%.1f GiB", float64(b)/GibiByte)
 	case b >= MebiByte:
 		return fmt.Sprintf("%.1f MiB", float64(b)/MebiByte)
 	case b >= KibiByte:

+ 14 - 6
format/format.go

@@ -13,12 +13,20 @@ const (
 
 func HumanNumber(b uint64) string {
 	switch {
-	case b > Billion:
-		return fmt.Sprintf("%.0fB", math.Round(float64(b)/Billion))
-	case b > Million:
-		return fmt.Sprintf("%.0fM", math.Round(float64(b)/Million))
-	case b > Thousand:
-		return fmt.Sprintf("%.0fK", math.Round(float64(b)/Thousand))
+	case b >= Billion:
+		number := float64(b) / Billion
+		if number == math.Floor(number) {
+			return fmt.Sprintf("%.0fB", number) // no decimals if whole number
+		}
+		return fmt.Sprintf("%.1fB", number) // one decimal if not a whole number
+	case b >= Million:
+		number := float64(b) / Million
+		if number == math.Floor(number) {
+			return fmt.Sprintf("%.0fM", number) // no decimals if whole number
+		}
+		return fmt.Sprintf("%.2fM", number) // two decimals if not a whole number
+	case b >= Thousand:
+		return fmt.Sprintf("%.0fK", float64(b)/Thousand)
 	default:
 		return fmt.Sprintf("%d", b)
 	}

+ 34 - 0
format/format_test.go

@@ -0,0 +1,34 @@
+package format
+
+import (
+	"testing"
+)
+
+func TestHumanNumber(t *testing.T) {
+
+	type testCase struct {
+		input    uint64
+		expected string
+	}
+
+	testCases := []testCase{
+		{0, "0"},
+		{1000000, "1M"},
+		{125000000, "125M"},
+		{500500000, "500.50M"},
+		{500550000, "500.55M"},
+		{1000000000, "1B"},
+		{2800000000, "2.8B"},
+		{2850000000, "2.9B"},
+		{1000000000000, "1000B"},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.expected, func(t *testing.T) {
+			result := HumanNumber(tc.input)
+			if result != tc.expected {
+				t.Errorf("Expected %s, got %s", tc.expected, result)
+			}
+		})
+	}
+}

+ 58 - 14
gpu/amd_common.go

@@ -7,7 +7,7 @@ import (
 	"log/slog"
 	"os"
 	"path/filepath"
-	"strconv"
+	"runtime"
 	"strings"
 )
 
@@ -35,22 +35,66 @@ func GetSupportedGFX(libDir string) ([]string, error) {
 	return ret, nil
 }
 
-func amdSetVisibleDevices(ids []int, skip map[int]interface{}) {
-	// Set the visible devices if not already set
-	// TODO - does sort order matter?
-	devices := []string{}
-	for i := range ids {
-		if _, skipped := skip[i]; skipped {
+func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
+	ids := []string{}
+	for _, info := range gpuInfo {
+		if info.Library != "rocm" {
+			// TODO shouldn't happen if things are wired correctly...
+			slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
 			continue
 		}
-		devices = append(devices, strconv.Itoa(i))
+		ids = append(ids, info.ID)
 	}
+	return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",")
+}
 
-	val := strings.Join(devices, ",")
-	err := os.Setenv("HIP_VISIBLE_DEVICES", val)
-	if err != nil {
-		slog.Warn(fmt.Sprintf("failed to set env: %s", err))
-	} else {
-		slog.Info("Setting HIP_VISIBLE_DEVICES=" + val)
+func commonAMDValidateLibDir() (string, error) {
+	// We try to favor system paths first, so that we can wire up the subprocess to use
+	// the system version.  Only use our bundled version if the system version doesn't work
+	// This gives users a more recovery options if versions have subtle problems at runtime
+
+	// Prefer explicit HIP env var
+	hipPath := os.Getenv("HIP_PATH")
+	if hipPath != "" {
+		hipLibDir := filepath.Join(hipPath, "bin")
+		if rocmLibUsable(hipLibDir) {
+			slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
+			return hipLibDir, nil
+		}
+	}
+
+	// Scan the LD_LIBRARY_PATH or PATH
+	pathEnv := "LD_LIBRARY_PATH"
+	if runtime.GOOS == "windows" {
+		pathEnv = "PATH"
+	}
+
+	paths := os.Getenv(pathEnv)
+	for _, path := range filepath.SplitList(paths) {
+		d, err := filepath.Abs(path)
+		if err != nil {
+			continue
+		}
+		if rocmLibUsable(d) {
+			return d, nil
+		}
+	}
+
+	// Well known location(s)
+	for _, path := range RocmStandardLocations {
+		if rocmLibUsable(path) {
+			return path, nil
+		}
+	}
+
+	// Installer payload location if we're running the installed binary
+	exe, err := os.Executable()
+	if err == nil {
+		rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
+		if rocmLibUsable(rocmTargetDir) {
+			slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
+			return rocmTargetDir, nil
+		}
 	}
+	return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
 }

+ 2 - 2
gpu/amd_hip_windows.go

@@ -69,7 +69,7 @@ func NewHipLib() (*HipLib, error) {
 func (hl *HipLib) Release() {
 	err := windows.FreeLibrary(hl.dll)
 	if err != nil {
-		slog.Warn(fmt.Sprintf("failed to unload amdhip64.dll: %s", err))
+		slog.Warn("failed to unload amdhip64.dll", "error", err)
 	}
 	hl.dll = 0
 }
@@ -98,7 +98,7 @@ func (hl *HipLib) HipGetDeviceCount() int {
 		return 0
 	}
 	if status != hipSuccess {
-		slog.Warn(fmt.Sprintf("failed call to hipGetDeviceCount: %d %s", status, err))
+		slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err)
 	}
 	return count
 }

+ 176 - 287
gpu/amd_linux.go

@@ -11,6 +11,8 @@ import (
 	"slices"
 	"strconv"
 	"strings"
+
+	"github.com/ollama/ollama/format"
 )
 
 // Discovery logic for AMD/ROCm GPUs
@@ -23,26 +25,20 @@ const (
 	// Prefix with the node dir
 	GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
 	GPUUsedMemoryFileGlob  = "mem_banks/*/used_memory"
-	RocmStandardLocation   = "/opt/rocm/lib"
-
-	// TODO find a better way to detect iGPU instead of minimum memory
-	IGPUMemLimit = 1024 * 1024 * 1024 // 512G is what they typically report, so anything less than 1G must be iGPU
 )
 
 var (
 	// Used to validate if the given ROCm lib is usable
-	ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here...
+	ROCmLibGlobs          = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here...
+	RocmStandardLocations = []string{"/opt/rocm/lib", "/usr/lib64"}
 )
 
 // Gather GPU information from the amdgpu driver if any supported GPUs are detected
-// HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices
-// and the user hasn't already set this variable
-func AMDGetGPUInfo(resp *GpuInfo) {
-	// TODO - DRY this out with windows
+func AMDGetGPUInfo() []GpuInfo {
+	resp := []GpuInfo{}
 	if !AMDDetected() {
-		return
+		return resp
 	}
-	skip := map[int]interface{}{}
 
 	// Opportunistic logging of driver version to aid in troubleshooting
 	ver, err := AMDDriverVersion()
@@ -50,160 +46,117 @@ func AMDGetGPUInfo(resp *GpuInfo) {
 		slog.Info("AMD Driver: " + ver)
 	} else {
 		// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
-		slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err))
+		slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
 	}
 
-	// If the user has specified exactly which GPUs to use, look up their memory
-	visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES")
-	if visibleDevices != "" {
-		ids := []int{}
-		for _, idStr := range strings.Split(visibleDevices, ",") {
-			id, err := strconv.Atoi(idStr)
-			if err != nil {
-				slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err))
-			} else {
-				ids = append(ids, id)
-			}
-		}
-		amdProcMemLookup(resp, nil, ids)
-		return
+	// Determine if the user has already pre-selected which GPUs to look at, then ignore the others
+	var visibleDevices []string
+	hipVD := os.Getenv("HIP_VISIBLE_DEVICES")   // zero based index only
+	rocrVD := os.Getenv("ROCR_VISIBLE_DEVICES") // zero based index or UUID, but consumer cards seem to not support UUID
+	gpuDO := os.Getenv("GPU_DEVICE_ORDINAL")    // zero based index
+	switch {
+	// TODO is this priorty order right?
+	case hipVD != "":
+		visibleDevices = strings.Split(hipVD, ",")
+	case rocrVD != "":
+		visibleDevices = strings.Split(rocrVD, ",")
+		// TODO - since we don't yet support UUIDs, consider detecting and reporting here
+		// all our test systems show GPU-XX indicating UUID is not supported
+	case gpuDO != "":
+		visibleDevices = strings.Split(gpuDO, ",")
 	}
 
-	// Gather GFX version information from all detected cards
-	gfx := AMDGFXVersions()
-	verStrings := []string{}
-	for i, v := range gfx {
-		verStrings = append(verStrings, v.ToGFXString())
-		if v.Major == 0 {
-			// Silently skip CPUs
-			skip[i] = struct{}{}
+	gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
+	var supported []string
+	libDir := ""
+
+	// The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
+	// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
+	matches, _ := filepath.Glob(GPUPropertiesFileGlob)
+	cpuCount := 0
+	for _, match := range matches {
+		slog.Debug("evaluating amdgpu node " + match)
+		fp, err := os.Open(match)
+		if err != nil {
+			slog.Debug("failed to open sysfs node", "file", match, "error", err)
 			continue
 		}
-		if v.Major < 9 {
-			// TODO consider this a build-time setting if we can support 8xx family GPUs
-			slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString()))
-			skip[i] = struct{}{}
+		defer fp.Close()
+		nodeID, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
+		if err != nil {
+			slog.Debug("failed to parse node ID", "error", err)
+			continue
 		}
-	}
-	slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings))
 
-	// Abort if all GPUs are skipped
-	if len(skip) >= len(gfx) {
-		slog.Info("all detected amdgpus are skipped, falling back to CPU")
-		return
-	}
+		scanner := bufio.NewScanner(fp)
+		isCPU := false
+		var major, minor, patch uint64
+		for scanner.Scan() {
+			line := strings.TrimSpace(scanner.Text())
+			// Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
+			if strings.HasPrefix(line, "gfx_target_version") {
+				ver := strings.Fields(line)
 
-	// If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib
-	libDir, err := AMDValidateLibDir()
-	if err != nil {
-		slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
-		return
-	}
+				// Detect CPUs
+				if len(ver) == 2 && ver[1] == "0" {
+					slog.Debug("detected CPU " + match)
+					isCPU = true
+					break
+				}
 
-	updateLibPath(libDir)
+				if len(ver) != 2 || len(ver[1]) < 5 {
+					slog.Warn("malformed "+match, "gfx_target_version", line)
+					// If this winds up being a CPU, our offsets may be wrong
+					continue
+				}
+				l := len(ver[1])
+				var err1, err2, err3 error
+				patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32)
+				minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
+				major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32)
+				if err1 != nil || err2 != nil || err3 != nil {
+					slog.Debug("malformed int " + line)
+					continue
+				}
+			}
 
-	gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
-	if gfxOverride == "" {
-		supported, err := GetSupportedGFX(libDir)
-		if err != nil {
-			slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
-			return
+			// TODO - any other properties we want to extract and record?
+			// vendor_id + device_id -> pci lookup for "Name"
+			// Other metrics that may help us understand relative performance between multiple GPUs
 		}
-		slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported))
 
-		for i, v := range gfx {
-			if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
-				slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
-				// TODO - consider discrete markdown just for ROCM troubleshooting?
-				slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
-				skip[i] = struct{}{}
-			} else {
-				slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
-			}
+		if isCPU {
+			cpuCount++
+			continue
 		}
-	} else {
-		slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
-	}
 
-	if len(skip) >= len(gfx) {
-		slog.Info("all detected amdgpus are skipped, falling back to CPU")
-		return
-	}
+		// CPUs are always first in the list
+		gpuID := nodeID - cpuCount
 
-	ids := make([]int, len(gfx))
-	i := 0
-	for k := range gfx {
-		ids[i] = k
-		i++
-	}
-	amdProcMemLookup(resp, skip, ids)
-	if resp.memInfo.DeviceCount == 0 {
-		return
-	}
-	if len(skip) > 0 {
-		amdSetVisibleDevices(ids, skip)
-	}
-}
-
-func updateLibPath(libDir string) {
-	ldPaths := []string{}
-	if val, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
-		ldPaths = strings.Split(val, ":")
-	}
-	for _, d := range ldPaths {
-		if d == libDir {
-			return
-		}
-	}
-	val := strings.Join(append(ldPaths, libDir), ":")
-	slog.Debug("updated lib path", "LD_LIBRARY_PATH", val)
-	os.Setenv("LD_LIBRARY_PATH", val)
-}
-
-// Walk the sysfs nodes for the available GPUs and gather information from them
-// skipping over any devices in the skip map
-func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
-	resp.memInfo.DeviceCount = 0
-	resp.memInfo.TotalMemory = 0
-	resp.memInfo.FreeMemory = 0
-	slog.Debug("discovering VRAM for amdgpu devices")
-	if len(ids) == 0 {
-		entries, err := os.ReadDir(AMDNodesSysfsDir)
-		if err != nil {
-			slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err))
-			return
-		}
-		for _, node := range entries {
-			if !node.IsDir() {
-				continue
-			}
-			id, err := strconv.Atoi(node.Name())
-			if err != nil {
-				slog.Warn("malformed amdgpu sysfs node id " + node.Name())
-				continue
-			}
-			ids = append(ids, id)
+		// Shouldn't happen, but just in case...
+		if gpuID < 0 {
+			slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue")
+			return []GpuInfo{}
 		}
-	}
-	slog.Debug(fmt.Sprintf("amdgpu devices %v", ids))
 
-	for _, id := range ids {
-		if _, skipped := skip[id]; skipped {
+		if int(major) < RocmComputeMin {
+			slog.Warn(fmt.Sprintf("amdgpu too old gfx%d%d%x", major, minor, patch), "gpu", gpuID)
 			continue
 		}
+
+		// Look up the memory for the current node
 		totalMemory := uint64(0)
 		usedMemory := uint64(0)
-		// Adjust for sysfs vs HIP ids
-		propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id+1), GPUTotalMemoryFileGlob)
+		propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUTotalMemoryFileGlob)
 		propFiles, err := filepath.Glob(propGlob)
 		if err != nil {
-			slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err))
+			slog.Warn("error looking up total GPU memory", "glob", propGlob, "error", err)
 		}
 		// 1 or more memory banks - sum the values of all of them
 		for _, propFile := range propFiles {
 			fp, err := os.Open(propFile)
 			if err != nil {
-				slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err))
+				slog.Warn("failed to open sysfs node", "file", propFile, "erroir", err)
 				continue
 			}
 			defer fp.Close()
@@ -226,49 +179,113 @@ func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
 			}
 		}
 		if totalMemory == 0 {
-			slog.Warn(fmt.Sprintf("amdgpu [%d] reports zero total memory, skipping", id))
-			skip[id] = struct{}{}
+			slog.Warn("amdgpu reports zero total memory", "gpu", gpuID)
 			continue
 		}
-		if totalMemory < IGPUMemLimit {
-			slog.Info(fmt.Sprintf("amdgpu [%d] appears to be an iGPU with %dM reported total memory, skipping", id, totalMemory/1024/1024))
-			skip[id] = struct{}{}
-			continue
-		}
-		usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob)
+		usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUUsedMemoryFileGlob)
 		usedFiles, err := filepath.Glob(usedGlob)
 		if err != nil {
-			slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err))
+			slog.Warn("error looking up used GPU memory", "glob", usedGlob, "error", err)
 			continue
 		}
 		for _, usedFile := range usedFiles {
 			fp, err := os.Open(usedFile)
 			if err != nil {
-				slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err))
+				slog.Warn("failed to open sysfs node", "file", usedFile, "error", err)
 				continue
 			}
 			defer fp.Close()
 			data, err := io.ReadAll(fp)
 			if err != nil {
-				slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err))
+				slog.Warn("failed to read sysfs node", "file", usedFile, "error", err)
 				continue
 			}
 			used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
 			if err != nil {
-				slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err))
+				slog.Warn("malformed used memory", "data", string(data), "error", err)
 				continue
 			}
 			usedMemory += used
 		}
-		slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %dM", id, totalMemory/1024/1024))
-		slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory  %dM", id, (totalMemory-usedMemory)/1024/1024))
-		resp.memInfo.DeviceCount++
-		resp.memInfo.TotalMemory += totalMemory
-		resp.memInfo.FreeMemory += (totalMemory - usedMemory)
+
+		// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
+		if totalMemory < IGPUMemLimit {
+			slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
+			continue
+		}
+
+		slog.Info("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
+		slog.Info("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
+		gpuInfo := GpuInfo{
+			Library: "rocm",
+			memInfo: memInfo{
+				TotalMemory: totalMemory,
+				FreeMemory:  (totalMemory - usedMemory),
+			},
+			ID: fmt.Sprintf("%d", gpuID),
+			// Name: not exposed in sysfs directly, would require pci device id lookup
+			Major:         int(major),
+			Minor:         int(minor),
+			Patch:         int(patch),
+			MinimumMemory: rocmMinimumMemory,
+		}
+
+		// If the user wants to filter to a subset of devices, filter out if we aren't a match
+		if len(visibleDevices) > 0 {
+			include := false
+			for _, visible := range visibleDevices {
+				if visible == gpuInfo.ID {
+					include = true
+					break
+				}
+			}
+			if !include {
+				slog.Info("filtering out device per user request", "id", gpuInfo.ID, "visible_devices", visibleDevices)
+				continue
+			}
+		}
+
+		// Final validation is gfx compatibility - load the library if we haven't already loaded it
+		// even if the user overrides, we still need to validate the library
+		if libDir == "" {
+			libDir, err = AMDValidateLibDir()
+			if err != nil {
+				slog.Warn("unable to verify rocm library, will use cpu", "error", err)
+				return []GpuInfo{}
+			}
+		}
+		gpuInfo.DependencyPath = libDir
+
+		if gfxOverride == "" {
+			// Only load supported list once
+			if len(supported) == 0 {
+				supported, err = GetSupportedGFX(libDir)
+				if err != nil {
+					slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
+					return []GpuInfo{}
+				}
+				slog.Debug("rocm supported GPUs", "types", supported)
+			}
+			gfx := fmt.Sprintf("gfx%d%d%x", gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch)
+			if !slices.Contains[[]string, string](supported, gfx) {
+				slog.Warn("amdgpu is not supported", "gpu", gpuInfo.ID, "gpu_type", gfx, "library", libDir, "supported_types", supported)
+				// TODO - consider discrete markdown just for ROCM troubleshooting?
+				slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
+				continue
+			} else {
+				slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx)
+			}
+		} else {
+			slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
+		}
+
+		// The GPU has passed all the verification steps and is supported
+		resp = append(resp, gpuInfo)
 	}
-	if resp.memInfo.DeviceCount > 0 {
-		resp.Library = "rocm"
+	if len(resp) == 0 {
+		slog.Info("no compatible amdgpu devices detected")
 	}
+	return resp
 }
 
 // Quick check for AMD driver so we can skip amdgpu discovery if not present
@@ -280,87 +297,24 @@ func AMDDetected() bool {
 		slog.Debug("amdgpu driver not detected " + sysfsDir)
 		return false
 	} else if err != nil {
-		slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err))
+		slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err)
 		return false
 	}
 	return true
 }
 
-func setupLink(source, target string) error {
-	if err := os.RemoveAll(target); err != nil {
-		return fmt.Errorf("failed to remove old rocm directory %s %w", target, err)
-	}
-	if err := os.Symlink(source, target); err != nil {
-		return fmt.Errorf("failed to create link %s => %s %w", source, target, err)
-	}
-	slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target))
-	return nil
-}
-
-// Ensure the AMD rocm lib dir is wired up
 // Prefer to use host installed ROCm, as long as it meets our minimum requirements
 // failing that, tell the user how to download it on their own
 func AMDValidateLibDir() (string, error) {
-	// We rely on the rpath compiled into our library to find rocm
-	// so we establish a symlink to wherever we find it on the system
-	// to <payloads>/rocm
-	payloadsDir, err := PayloadsDir()
-	if err != nil {
-		return "", err
-	}
-
-	// If we already have a rocm dependency wired, nothing more to do
-	rocmTargetDir := filepath.Clean(filepath.Join(payloadsDir, "..", "rocm"))
-	if rocmLibUsable(rocmTargetDir) {
-		return rocmTargetDir, nil
-	}
-
-	// next to the running binary
-	exe, err := os.Executable()
+	libDir, err := commonAMDValidateLibDir()
 	if err == nil {
-		peerDir := filepath.Dir(exe)
-		if rocmLibUsable(peerDir) {
-			slog.Debug("detected ROCM next to ollama executable " + peerDir)
-			return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
-		}
-		peerDir = filepath.Join(filepath.Dir(exe), "rocm")
-		if rocmLibUsable(peerDir) {
-			slog.Debug("detected ROCM next to ollama executable " + peerDir)
-			return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
-		}
+		return libDir, nil
 	}
 
 	// Well known ollama installer path
 	installedRocmDir := "/usr/share/ollama/lib/rocm"
 	if rocmLibUsable(installedRocmDir) {
-		return rocmTargetDir, setupLink(installedRocmDir, rocmTargetDir)
-	}
-
-	// Prefer explicit HIP env var
-	hipPath := os.Getenv("HIP_PATH")
-	if hipPath != "" {
-		hipLibDir := filepath.Join(hipPath, "lib")
-		if rocmLibUsable(hipLibDir) {
-			slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
-			return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir)
-		}
-	}
-
-	// Scan the library path for potential matches
-	ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
-	for _, ldPath := range ldPaths {
-		d, err := filepath.Abs(ldPath)
-		if err != nil {
-			continue
-		}
-		if rocmLibUsable(d) {
-			return rocmTargetDir, setupLink(d, rocmTargetDir)
-		}
-	}
-
-	// Well known location(s)
-	if rocmLibUsable("/opt/rocm/lib") {
-		return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir)
+		return installedRocmDir, nil
 	}
 
 	// If we still haven't found a usable rocm, the user will have to install it on their own
@@ -384,68 +338,3 @@ func AMDDriverVersion() (string, error) {
 	}
 	return strings.TrimSpace(string(verString)), nil
 }
-
-func AMDGFXVersions() map[int]Version {
-	// The amdgpu driver always exposes the host CPU as node 0, but we have to skip that and subtract one
-	// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
-	res := map[int]Version{}
-	matches, _ := filepath.Glob(GPUPropertiesFileGlob)
-	for _, match := range matches {
-		fp, err := os.Open(match)
-		if err != nil {
-			slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
-			continue
-		}
-		defer fp.Close()
-		i, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
-		if err != nil {
-			slog.Debug(fmt.Sprintf("failed to parse node ID %s", err))
-			continue
-		}
-
-		if i == 0 {
-			// Skipping the CPU
-			continue
-		}
-		// Align with HIP IDs (zero is first GPU, not CPU)
-		i -= 1
-
-		scanner := bufio.NewScanner(fp)
-		for scanner.Scan() {
-			line := strings.TrimSpace(scanner.Text())
-			if strings.HasPrefix(line, "gfx_target_version") {
-				ver := strings.Fields(line)
-				if len(ver) != 2 || len(ver[1]) < 5 {
-					if ver[1] != "0" {
-						slog.Debug("malformed " + line)
-					}
-					res[i] = Version{
-						Major: 0,
-						Minor: 0,
-						Patch: 0,
-					}
-					continue
-				}
-				l := len(ver[1])
-				patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
-				minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
-				major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
-				if err1 != nil || err2 != nil || err3 != nil {
-					slog.Debug("malformed int " + line)
-					continue
-				}
-
-				res[i] = Version{
-					Major: uint(major),
-					Minor: uint(minor),
-					Patch: uint(patch),
-				}
-			}
-		}
-	}
-	return res
-}
-
-func (v Version) ToGFXString() string {
-	return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
-}

+ 85 - 76
gpu/amd_windows.go

@@ -7,11 +7,13 @@ import (
 	"os"
 	"path/filepath"
 	"slices"
+	"strconv"
 	"strings"
+
+	"github.com/ollama/ollama/format"
 )
 
 const (
-	RocmStandardLocation = "C:\\Program Files\\AMD\\ROCm\\5.7\\bin" // TODO glob?
 
 	// TODO  We're lookinng for this exact name to detect iGPUs since hipGetDeviceProperties never reports integrated==true
 	iGPUName = "AMD Radeon(TM) Graphics"
@@ -19,39 +21,36 @@ const (
 
 var (
 	// Used to validate if the given ROCm lib is usable
-	ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
+	ROCmLibGlobs          = []string{"hipblas.dll", "rocblas"}                 // TODO - probably include more coverage of files here...
+	RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
 )
 
-func AMDGetGPUInfo(resp *GpuInfo) {
+func AMDGetGPUInfo() []GpuInfo {
+	resp := []GpuInfo{}
 	hl, err := NewHipLib()
 	if err != nil {
 		slog.Debug(err.Error())
-		return
+		return nil
 	}
 	defer hl.Release()
-	skip := map[int]interface{}{}
-	ids := []int{}
-	resp.memInfo.DeviceCount = 0
-	resp.memInfo.TotalMemory = 0
-	resp.memInfo.FreeMemory = 0
 
 	ver, err := hl.AMDDriverVersion()
 	if err == nil {
 		slog.Info("AMD Driver: " + ver)
 	} else {
 		// For now this is benign, but we may eventually need to fail compatibility checks
-		slog.Debug(fmt.Sprintf("error looking up amd driver version: %s", err))
+		slog.Debug("error looking up amd driver version", "error", err)
 	}
 
-	// Note: the HIP library automatically handles HIP_VISIBLE_DEVICES
+	// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
 	count := hl.HipGetDeviceCount()
 	if count == 0 {
-		return
+		return nil
 	}
 	libDir, err := AMDValidateLibDir()
 	if err != nil {
-		slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
-		return
+		slog.Warn("unable to verify rocm library, will use cpu", "error", err)
+		return nil
 	}
 
 	var supported []string
@@ -59,95 +58,120 @@ func AMDGetGPUInfo(resp *GpuInfo) {
 	if gfxOverride == "" {
 		supported, err = GetSupportedGFX(libDir)
 		if err != nil {
-			slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
-			return
+			slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
+			return nil
 		}
 	} else {
 		slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
 	}
 
-	slog.Info(fmt.Sprintf("detected %d hip devices", count))
+	slog.Info("detected hip devices", "count", count)
+	// TODO how to determine the underlying device ID when visible devices is causing this to subset?
 	for i := 0; i < count; i++ {
-		ids = append(ids, i)
 		err = hl.HipSetDevice(i)
 		if err != nil {
-			slog.Warn(fmt.Sprintf("[%d] %s", i, err))
-			skip[i] = struct{}{}
+			slog.Warn("set device", "id", i, "error", err)
 			continue
 		}
 
 		props, err := hl.HipGetDeviceProperties(i)
 		if err != nil {
-			slog.Warn(fmt.Sprintf("[%d] %s", i, err))
-			skip[i] = struct{}{}
+			slog.Warn("get properties", "id", i, "error", err)
 			continue
 		}
 		n := bytes.IndexByte(props.Name[:], 0)
 		name := string(props.Name[:n])
-		slog.Info(fmt.Sprintf("[%d] Name: %s", i, name))
+		// TODO is UUID actually populated on windows?
+		// Can luid be used on windows for setting visible devices (and is it actually set?)
 		n = bytes.IndexByte(props.GcnArchName[:], 0)
 		gfx := string(props.GcnArchName[:n])
-		slog.Info(fmt.Sprintf("[%d] GcnArchName: %s", i, gfx))
+		slog.Info("hip device", "id", i, "name", name, "gfx", gfx)
+		var major, minor, patch string
+		switch len(gfx) {
+		case 6:
+			major, minor, patch = gfx[3:4], gfx[4:5], gfx[5:]
+		case 7:
+			major, minor, patch = gfx[3:5], gfx[5:6], gfx[6:]
+		}
 		//slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY!  Always 0
 		// TODO  Why isn't props.iGPU accurate!?
 		if strings.EqualFold(name, iGPUName) {
-			slog.Info(fmt.Sprintf("iGPU detected [%d] skipping", i))
-			skip[i] = struct{}{}
+			slog.Info("iGPU detected skipping", "id", i)
 			continue
 		}
 		if gfxOverride == "" {
 			if !slices.Contains[[]string, string](supported, gfx) {
-				slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, gfx, libDir, supported))
+				slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
 				// TODO - consider discrete markdown just for ROCM troubleshooting?
 				slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
-				skip[i] = struct{}{}
 				continue
 			} else {
-				slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, gfx))
+				slog.Info("amdgpu is supported", "gpu", i, "gpu_type", gfx)
 			}
 		}
 
-		totalMemory, freeMemory, err := hl.HipMemGetInfo()
+		freeMemory, totalMemory, err := hl.HipMemGetInfo()
 		if err != nil {
-			slog.Warn(fmt.Sprintf("[%d] %s", i, err))
+			slog.Warn("get mem info", "id", i, "error", err)
 			continue
 		}
 
-		// TODO according to docs, freeMem may lie on windows!
-		slog.Info(fmt.Sprintf("[%d] Total Mem: %d", i, totalMemory))
-		slog.Info(fmt.Sprintf("[%d] Free Mem:  %d", i, freeMemory))
-		resp.memInfo.DeviceCount++
-		resp.memInfo.TotalMemory += totalMemory
-		resp.memInfo.FreeMemory += freeMemory
-	}
-	if resp.memInfo.DeviceCount > 0 {
-		resp.Library = "rocm"
-	}
-	// Abort if all GPUs are skipped
-	if len(skip) >= count {
-		slog.Info("all detected amdgpus are skipped, falling back to CPU")
-		return
-	}
-	if len(skip) > 0 {
-		amdSetVisibleDevices(ids, skip)
+		// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
+		if totalMemory < IGPUMemLimit {
+			slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", i, "total", format.HumanBytes2(totalMemory))
+			continue
+		}
+
+		// TODO revisit this once ROCm v6 is available on windows.
+		// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable
+		slog.Info("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory))
+		slog.Info("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory))
+		gpuInfo := GpuInfo{
+			Library: "rocm",
+			memInfo: memInfo{
+				TotalMemory: totalMemory,
+				FreeMemory:  freeMemory,
+			},
+			ID:             fmt.Sprintf("%d", i), // TODO this is probably wrong if we specify visible devices
+			DependencyPath: libDir,
+			MinimumMemory:  rocmMinimumMemory,
+		}
+		if major != "" {
+			gpuInfo.Major, err = strconv.Atoi(major)
+			if err != nil {
+				slog.Info("failed to parse version", "version", gfx, "error", err)
+			}
+		}
+		if minor != "" {
+			gpuInfo.Minor, err = strconv.Atoi(minor)
+			if err != nil {
+				slog.Info("failed to parse version", "version", gfx, "error", err)
+			}
+		}
+		if patch != "" {
+			// Patch rev is hex; e.g. gfx90a
+			p, err := strconv.ParseInt(patch, 16, 0)
+			if err != nil {
+				slog.Info("failed to parse version", "version", gfx, "error", err)
+			} else {
+				gpuInfo.Patch = int(p)
+			}
+		}
+		if gpuInfo.Major < RocmComputeMin {
+			slog.Warn(fmt.Sprintf("amdgpu [%s] too old gfx%d%d%x", gpuInfo.ID, gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch))
+			continue
+		}
+
+		resp = append(resp, gpuInfo)
 	}
-	UpdatePath(libDir)
+
+	return resp
 }
 
 func AMDValidateLibDir() (string, error) {
-	// On windows non-admins typically can't create links
-	// so instead of trying to rely on rpath and a link in
-	// $LibDir/rocm, we instead rely on setting PATH to point
-	// to the location of the ROCm library
-
-	// Installer payload location if we're running the installed binary
-	exe, err := os.Executable()
+	libDir, err := commonAMDValidateLibDir()
 	if err == nil {
-		rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
-		if rocmLibUsable(rocmTargetDir) {
-			slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
-			return rocmTargetDir, nil
-		}
+		return libDir, nil
 	}
 
 	// Installer payload (if we're running from some other location)
@@ -159,21 +183,6 @@ func AMDValidateLibDir() (string, error) {
 		return rocmTargetDir, nil
 	}
 
-	// Prefer explicit HIP env var
-	hipPath := os.Getenv("HIP_PATH")
-	if hipPath != "" {
-		hipLibDir := filepath.Join(hipPath, "bin")
-		if rocmLibUsable(hipLibDir) {
-			slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
-			return hipLibDir, nil
-		}
-	}
-
-	// Well known location(s)
-	if rocmLibUsable(RocmStandardLocation) {
-		return RocmStandardLocation, nil
-	}
-
 	// Should not happen on windows since we include it in the installer, but stand-alone binary might hit this
 	slog.Warn("amdgpu detected, but no compatible rocm library found.  Please install ROCm")
 	return "", fmt.Errorf("no suitable rocm found, falling back to CPU")

+ 15 - 4
gpu/assets.go

@@ -12,6 +12,8 @@ import (
 	"sync"
 	"syscall"
 	"time"
+
+	"github.com/ollama/ollama/server/envconfig"
 )
 
 var (
@@ -24,8 +26,16 @@ func PayloadsDir() (string, error) {
 	defer lock.Unlock()
 	var err error
 	if payloadsDir == "" {
+		runnersDir := envconfig.RunnersDir
+
+		if runnersDir != "" {
+			payloadsDir = runnersDir
+			return payloadsDir, nil
+		}
+
+		// The remainder only applies on non-windows where we still carry payloads in the main executable
 		cleanupTmpDirs()
-		tmpDir := os.Getenv("OLLAMA_TMPDIR")
+		tmpDir := envconfig.TmpDir
 		if tmpDir == "" {
 			tmpDir, err = os.MkdirTemp("", "ollama")
 			if err != nil {
@@ -80,7 +90,7 @@ func cleanupTmpDirs() {
 		}
 		err = os.RemoveAll(d)
 		if err != nil {
-			slog.Debug(fmt.Sprintf("unable to cleanup stale tmpdir %s: %s", d, err))
+			slog.Debug("unable to cleanup stale tmpdir", "path", d, "error", err)
 		}
 	}
 }
@@ -88,7 +98,8 @@ func cleanupTmpDirs() {
 func Cleanup() {
 	lock.Lock()
 	defer lock.Unlock()
-	if payloadsDir != "" {
+	runnersDir := envconfig.RunnersDir
+	if payloadsDir != "" && runnersDir == "" && runtime.GOOS != "windows" {
 		// We want to fully clean up the tmpdir parent of the payloads dir
 		tmpDir := filepath.Clean(filepath.Join(payloadsDir, ".."))
 		slog.Debug("cleaning up", "dir", tmpDir)
@@ -120,7 +131,7 @@ func UpdatePath(dir string) {
 			}
 		}
 		newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
-		slog.Info(fmt.Sprintf("Updating PATH to %s", newPath))
+		slog.Info("updating", "PATH", newPath)
 		os.Setenv("PATH", newPath)
 	}
 	// linux and darwin rely on rpath

+ 22 - 0
gpu/cuda_common.go

@@ -0,0 +1,22 @@
+//go:build linux || windows
+
+package gpu
+
+import (
+	"log/slog"
+	"strings"
+)
+
+func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
+	ids := []string{}
+	for _, info := range gpuInfo {
+		if info.Library != "cuda" {
+			// TODO shouldn't happen if things are wired correctly...
+			slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
+			continue
+		}
+		ids = append(ids, info.ID)
+	}
+	return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
+
+}

+ 154 - 142
gpu/gpu.go

@@ -16,22 +16,23 @@ import (
 	"os"
 	"path/filepath"
 	"runtime"
-	"strconv"
 	"strings"
 	"sync"
 	"unsafe"
 
 	"github.com/ollama/ollama/format"
+	"github.com/ollama/ollama/server/envconfig"
 )
 
 type handles struct {
-	nvml   *C.nvml_handle_t
-	cudart *C.cudart_handle_t
+	deviceCount int
+	cudart      *C.cudart_handle_t
+	nvcuda      *C.nvcuda_handle_t
 }
 
 const (
-	cudaMinimumMemory = 457 * format.MebiByte
-	rocmMinimumMemory = 457 * format.MebiByte
+	cudaMinimumMemory = 256 * format.MebiByte
+	rocmMinimumMemory = 256 * format.MebiByte
 )
 
 var gpuMutex sync.Mutex
@@ -39,26 +40,10 @@ var gpuMutex sync.Mutex
 // With our current CUDA compile flags, older than 5.0 will not work properly
 var CudaComputeMin = [2]C.int{5, 0}
 
-// Possible locations for the nvidia-ml library
-var NvmlLinuxGlobs = []string{
-	"/usr/local/cuda/lib64/libnvidia-ml.so*",
-	"/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*",
-	"/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*",
-	"/usr/lib/wsl/lib/libnvidia-ml.so*",
-	"/usr/lib/wsl/drivers/*/libnvidia-ml.so*",
-	"/opt/cuda/lib64/libnvidia-ml.so*",
-	"/usr/lib*/libnvidia-ml.so*",
-	"/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*",
-	"/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*",
-	"/usr/local/lib*/libnvidia-ml.so*",
-
-	// TODO: are these stubs ever valid?
-	"/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*",
-}
+var RocmComputeMin = 9
 
-var NvmlWindowsGlobs = []string{
-	"c:\\Windows\\System32\\nvml.dll",
-}
+// TODO find a better way to detect iGPU instead of minimum memory
+const IGPUMemLimit = 1 * format.GibiByte // 512G is what they typically report, so anything less than 1G must be iGPU
 
 var CudartLinuxGlobs = []string{
 	"/usr/local/cuda/lib64/libcudart.so*",
@@ -79,6 +64,22 @@ var CudartWindowsGlobs = []string{
 	"c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll",
 }
 
+var NvcudaLinuxGlobs = []string{
+	"/usr/local/cuda*/targets/*/lib/libcuda.so*",
+	"/usr/lib/*-linux-gnu/nvidia/current/libcuda.so*",
+	"/usr/lib/*-linux-gnu/libcuda.so*",
+	"/usr/lib/wsl/lib/libcuda.so*",
+	"/usr/lib/wsl/drivers/*/libcuda.so*",
+	"/opt/cuda/lib*/libcuda.so*",
+	"/usr/local/cuda/lib*/libcuda.so*",
+	"/usr/lib*/libcuda.so*",
+	"/usr/local/lib*/libcuda.so*",
+}
+
+var NvcudaWindowsGlobs = []string{
+	"c:\\windows\\system*\\nvcuda.dll",
+}
+
 // Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed.
 // Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
 var CudaTegra string = os.Getenv("JETSON_JETPACK")
@@ -88,61 +89,62 @@ func initGPUHandles() *handles {
 
 	// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
 
-	gpuHandles := &handles{nil, nil}
-	var nvmlMgmtName string
-	var nvmlMgmtPatterns []string
+	gpuHandles := &handles{}
 	var cudartMgmtName string
 	var cudartMgmtPatterns []string
+	var nvcudaMgmtName string
+	var nvcudaMgmtPatterns []string
 
 	tmpDir, _ := PayloadsDir()
 	switch runtime.GOOS {
 	case "windows":
-		nvmlMgmtName = "nvml.dll"
-		nvmlMgmtPatterns = make([]string, len(NvmlWindowsGlobs))
-		copy(nvmlMgmtPatterns, NvmlWindowsGlobs)
 		cudartMgmtName = "cudart64_*.dll"
 		localAppData := os.Getenv("LOCALAPPDATA")
 		cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)}
 		cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...)
+		// Aligned with driver, we can't carry as payloads
+		nvcudaMgmtName = "nvcuda.dll"
+		nvcudaMgmtPatterns = NvcudaWindowsGlobs
 	case "linux":
-		nvmlMgmtName = "libnvidia-ml.so"
-		nvmlMgmtPatterns = make([]string, len(NvmlLinuxGlobs))
-		copy(nvmlMgmtPatterns, NvmlLinuxGlobs)
 		cudartMgmtName = "libcudart.so*"
 		if tmpDir != "" {
 			// TODO - add "payloads" for subprocess
 			cudartMgmtPatterns = []string{filepath.Join(tmpDir, "cuda*", cudartMgmtName)}
 		}
 		cudartMgmtPatterns = append(cudartMgmtPatterns, CudartLinuxGlobs...)
+		// Aligned with driver, we can't carry as payloads
+		nvcudaMgmtName = "libcuda.so*"
+		nvcudaMgmtPatterns = NvcudaLinuxGlobs
 	default:
 		return gpuHandles
 	}
 
-	slog.Info("Detecting GPU type")
-	cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns)
-	if len(cudartLibPaths) > 0 {
-		cudart := LoadCUDARTMgmt(cudartLibPaths)
-		if cudart != nil {
-			slog.Info("Nvidia GPU detected via cudart")
-			gpuHandles.cudart = cudart
+	slog.Info("Detecting GPUs")
+	nvcudaLibPaths := FindGPULibs(nvcudaMgmtName, nvcudaMgmtPatterns)
+	if len(nvcudaLibPaths) > 0 {
+		deviceCount, nvcuda, libPath := LoadNVCUDAMgmt(nvcudaLibPaths)
+		if nvcuda != nil {
+			slog.Info("detected GPUs", "count", deviceCount, "library", libPath)
+			gpuHandles.nvcuda = nvcuda
+			gpuHandles.deviceCount = deviceCount
 			return gpuHandles
 		}
 	}
 
-	// TODO once we build confidence, remove this and the gpu_info_nvml.[ch] files
-	nvmlLibPaths := FindGPULibs(nvmlMgmtName, nvmlMgmtPatterns)
-	if len(nvmlLibPaths) > 0 {
-		nvml := LoadNVMLMgmt(nvmlLibPaths)
-		if nvml != nil {
-			slog.Info("Nvidia GPU detected via nvidia-ml")
-			gpuHandles.nvml = nvml
+	cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns)
+	if len(cudartLibPaths) > 0 {
+		deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths)
+		if cudart != nil {
+			slog.Info("detected GPUs", "library", libPath, "count", deviceCount)
+			gpuHandles.cudart = cudart
+			gpuHandles.deviceCount = deviceCount
 			return gpuHandles
 		}
 	}
 	return gpuHandles
 }
 
-func GetGPUInfo() GpuInfo {
+func GetGPUInfo() GpuInfoList {
 	// TODO - consider exploring lspci (and equivalent on windows) to check for
 	// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
 	gpuMutex.Lock()
@@ -150,12 +152,12 @@ func GetGPUInfo() GpuInfo {
 
 	gpuHandles := initGPUHandles()
 	defer func() {
-		if gpuHandles.nvml != nil {
-			C.nvml_release(*gpuHandles.nvml)
-		}
 		if gpuHandles.cudart != nil {
 			C.cudart_release(*gpuHandles.cudart)
 		}
+		if gpuHandles.nvcuda != nil {
+			C.nvcuda_release(*gpuHandles.nvcuda)
+		}
 	}()
 
 	// All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX
@@ -164,73 +166,75 @@ func GetGPUInfo() GpuInfo {
 		slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
 	}
 
+	// On windows we bundle the nvidia library one level above the runner dir
+	depPath := ""
+	if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
+		depPath = filepath.Dir(envconfig.RunnersDir)
+	}
+
 	var memInfo C.mem_info_t
-	resp := GpuInfo{}
-	if gpuHandles.nvml != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
-		C.nvml_check_vram(*gpuHandles.nvml, &memInfo)
-		if memInfo.err != nil {
-			slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU memory: %s", C.GoString(memInfo.err)))
-			C.free(unsafe.Pointer(memInfo.err))
-		} else if memInfo.count > 0 {
-			// Verify minimum compute capability
-			var cc C.nvml_compute_capability_t
-			C.nvml_compute_capability(*gpuHandles.nvml, &cc)
-			if cc.err != nil {
-				slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU compute capability: %s", C.GoString(cc.err)))
-				C.free(unsafe.Pointer(cc.err))
-			} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
-				slog.Info(fmt.Sprintf("[nvidia-ml] NVML CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
-				resp.Library = "cuda"
-				resp.MinimumMemory = cudaMinimumMemory
-			} else {
-				slog.Info(fmt.Sprintf("[nvidia-ml] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
-			}
+	resp := []GpuInfo{}
+
+	// NVIDIA first
+	for i := 0; i < gpuHandles.deviceCount; i++ {
+		// TODO once we support CPU compilation variants of GPU libraries refine this...
+		if cpuVariant == "" && runtime.GOARCH == "amd64" {
+			continue
+		}
+		gpuInfo := GpuInfo{
+			Library: "cuda",
+		}
+		if gpuHandles.cudart != nil {
+			C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo)
+		} else {
+			C.nvcuda_check_vram(*gpuHandles.nvcuda, C.int(i), &memInfo)
 		}
-	} else if gpuHandles.cudart != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
-		C.cudart_check_vram(*gpuHandles.cudart, &memInfo)
 		if memInfo.err != nil {
-			slog.Info(fmt.Sprintf("[cudart] error looking up CUDART GPU memory: %s", C.GoString(memInfo.err)))
+			slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
 			C.free(unsafe.Pointer(memInfo.err))
-		} else if memInfo.count > 0 {
-			// Verify minimum compute capability
-			var cc C.cudart_compute_capability_t
-			C.cudart_compute_capability(*gpuHandles.cudart, &cc)
-			if cc.err != nil {
-				slog.Info(fmt.Sprintf("[cudart] error looking up CUDA compute capability: %s", C.GoString(cc.err)))
-				C.free(unsafe.Pointer(cc.err))
-			} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
-				slog.Info(fmt.Sprintf("[cudart] CUDART CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
-				resp.Library = "cuda"
-				resp.MinimumMemory = cudaMinimumMemory
-			} else {
-				slog.Info(fmt.Sprintf("[cudart] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
-			}
+			continue
 		}
-	} else {
-		AMDGetGPUInfo(&resp)
-		if resp.Library != "" {
-			resp.MinimumMemory = rocmMinimumMemory
-			return resp
+		if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
+			slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
+			continue
 		}
+		gpuInfo.TotalMemory = uint64(memInfo.total)
+		gpuInfo.FreeMemory = uint64(memInfo.free)
+		gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
+		gpuInfo.Major = int(memInfo.major)
+		gpuInfo.Minor = int(memInfo.minor)
+		gpuInfo.MinimumMemory = cudaMinimumMemory
+		gpuInfo.DependencyPath = depPath
+
+		// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
+		resp = append(resp, gpuInfo)
 	}
-	if resp.Library == "" {
+
+	// Then AMD
+	resp = append(resp, AMDGetGPUInfo()...)
+
+	if len(resp) == 0 {
 		C.cpu_check_ram(&memInfo)
-		resp.Library = "cpu"
-		resp.Variant = cpuVariant
-	}
-	if memInfo.err != nil {
-		slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
-		C.free(unsafe.Pointer(memInfo.err))
-		return resp
+		if memInfo.err != nil {
+			slog.Info("error looking up CPU memory", "error", C.GoString(memInfo.err))
+			C.free(unsafe.Pointer(memInfo.err))
+			return resp
+		}
+		gpuInfo := GpuInfo{
+			Library: "cpu",
+			Variant: cpuVariant,
+		}
+		gpuInfo.TotalMemory = uint64(memInfo.total)
+		gpuInfo.FreeMemory = uint64(memInfo.free)
+		gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
+
+		resp = append(resp, gpuInfo)
 	}
 
-	resp.DeviceCount = uint32(memInfo.count)
-	resp.FreeMemory = uint64(memInfo.free)
-	resp.TotalMemory = uint64(memInfo.total)
 	return resp
 }
 
-func getCPUMem() (memInfo, error) {
+func GetCPUMem() (memInfo, error) {
 	var ret memInfo
 	var info C.mem_info_t
 	C.cpu_check_ram(&info)
@@ -243,29 +247,12 @@ func getCPUMem() (memInfo, error) {
 	return ret, nil
 }
 
-func CheckVRAM() (uint64, error) {
-	userLimit := os.Getenv("OLLAMA_MAX_VRAM")
-	if userLimit != "" {
-		avail, err := strconv.ParseInt(userLimit, 10, 64)
-		if err != nil {
-			return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
-		}
-		slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
-		return uint64(avail), nil
-	}
-	gpuInfo := GetGPUInfo()
-	if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
-		return gpuInfo.FreeMemory, nil
-	}
-
-	return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
-}
-
-func FindGPULibs(baseLibName string, patterns []string) []string {
+func FindGPULibs(baseLibName string, defaultPatterns []string) []string {
 	// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
 	var ldPaths []string
+	var patterns []string
 	gpuLibPaths := []string{}
-	slog.Info(fmt.Sprintf("Searching for GPU management library %s", baseLibName))
+	slog.Debug("Searching for GPU library", "name", baseLibName)
 
 	switch runtime.GOOS {
 	case "windows":
@@ -283,8 +270,14 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
 		}
 		patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
 	}
-	slog.Debug(fmt.Sprintf("gpu management search paths: %v", patterns))
+	patterns = append(patterns, defaultPatterns...)
+	slog.Debug("gpu library search", "globs", patterns)
 	for _, pattern := range patterns {
+
+		// Nvidia PhysX known to return bogus results
+		if strings.Contains(pattern, "PhysX") {
+			slog.Debug("skipping PhysX cuda library path", "path", pattern)
+		}
 		// Ignore glob discovery errors
 		matches, _ := filepath.Glob(pattern)
 		for _, match := range matches {
@@ -311,47 +304,66 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
 			}
 		}
 	}
-	slog.Info(fmt.Sprintf("Discovered GPU libraries: %v", gpuLibPaths))
+	slog.Debug("discovered GPU libraries", "paths", gpuLibPaths)
 	return gpuLibPaths
 }
 
-func LoadNVMLMgmt(nvmlLibPaths []string) *C.nvml_handle_t {
-	var resp C.nvml_init_resp_t
+func LoadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string) {
+	var resp C.cudart_init_resp_t
 	resp.ch.verbose = getVerboseState()
-	for _, libPath := range nvmlLibPaths {
+	for _, libPath := range cudartLibPaths {
 		lib := C.CString(libPath)
 		defer C.free(unsafe.Pointer(lib))
-		C.nvml_init(lib, &resp)
+		C.cudart_init(lib, &resp)
 		if resp.err != nil {
-			slog.Info(fmt.Sprintf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err)))
+			slog.Debug("Unable to load cudart", "library", libPath, "error", C.GoString(resp.err))
 			C.free(unsafe.Pointer(resp.err))
 		} else {
-			return &resp.ch
+			return int(resp.num_devices), &resp.ch, libPath
 		}
 	}
-	return nil
+	return 0, nil, ""
 }
 
-func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t {
-	var resp C.cudart_init_resp_t
+func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) {
+	var resp C.nvcuda_init_resp_t
 	resp.ch.verbose = getVerboseState()
-	for _, libPath := range cudartLibPaths {
+	for _, libPath := range nvcudaLibPaths {
 		lib := C.CString(libPath)
 		defer C.free(unsafe.Pointer(lib))
-		C.cudart_init(lib, &resp)
+		C.nvcuda_init(lib, &resp)
 		if resp.err != nil {
-			slog.Info(fmt.Sprintf("Unable to load cudart CUDA management library %s: %s", libPath, C.GoString(resp.err)))
+			slog.Debug("Unable to load nvcuda", "library", libPath, "error", C.GoString(resp.err))
 			C.free(unsafe.Pointer(resp.err))
 		} else {
-			return &resp.ch
+			return int(resp.num_devices), &resp.ch, libPath
 		}
 	}
-	return nil
+	return 0, nil, ""
 }
 
 func getVerboseState() C.uint16_t {
-	if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
+	if envconfig.Debug {
 		return C.uint16_t(1)
 	}
 	return C.uint16_t(0)
 }
+
+// Given the list of GPUs this instantiation is targeted for,
+// figure out the visible devices environment variable
+//
+// If different libraries are detected, the first one is what we use
+func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
+	if len(l) == 0 {
+		return "", ""
+	}
+	switch l[0].Library {
+	case "cuda":
+		return cudaGetVisibleDevicesEnv(l)
+	case "rocm":
+		return rocmGetVisibleDevicesEnv(l)
+	default:
+		slog.Debug("no filter required for library " + l[0].Library)
+		return "", ""
+	}
+}

+ 28 - 33
gpu/gpu_darwin.go

@@ -9,52 +9,47 @@ package gpu
 */
 import "C"
 import (
-	"fmt"
-	"log/slog"
-	"os"
 	"runtime"
-	"strconv"
-)
-
-// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
-func CheckVRAM() (uint64, error) {
-	userLimit := os.Getenv("OLLAMA_MAX_VRAM")
-	if userLimit != "" {
-		avail, err := strconv.ParseInt(userLimit, 10, 64)
-		if err != nil {
-			return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
-		}
-		slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
-		return uint64(avail), nil
-	}
 
-	if runtime.GOARCH == "amd64" {
-		// gpu not supported, this may not be metal
-		return 0, nil
-	}
+	"github.com/ollama/ollama/format"
+)
 
-	return uint64(C.getRecommendedMaxVRAM()), nil
-}
+const (
+	metalMinimumMemory = 384 * format.MebiByte
+)
 
-func GetGPUInfo() GpuInfo {
-	mem, _ := getCPUMem()
+func GetGPUInfo() GpuInfoList {
+	mem, _ := GetCPUMem()
 	if runtime.GOARCH == "amd64" {
-		return GpuInfo{
-			Library: "cpu",
-			Variant: GetCPUVariant(),
-			memInfo: mem,
+		return []GpuInfo{
+			{
+				Library: "cpu",
+				Variant: GetCPUVariant(),
+				memInfo: mem,
+			},
 		}
 	}
-	return GpuInfo{
+	info := GpuInfo{
 		Library: "metal",
-		memInfo: mem,
+		ID:      "0",
 	}
+	info.TotalMemory = uint64(C.getRecommendedMaxVRAM())
+
+	// TODO is there a way to gather actual allocated video memory? (currentAllocatedSize doesn't work)
+	info.FreeMemory = info.TotalMemory
+
+	info.MinimumMemory = metalMinimumMemory
+	return []GpuInfo{info}
 }
 
-func getCPUMem() (memInfo, error) {
+func GetCPUMem() (memInfo, error) {
 	return memInfo{
 		TotalMemory: uint64(C.getPhysicalMemory()),
 		FreeMemory:  0,
-		DeviceCount: 1,
 	}, nil
 }
+
+func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
+	// No-op on darwin
+	return "", ""
+}

+ 9 - 4
gpu/gpu_info.h

@@ -38,12 +38,17 @@
 extern "C" {
 #endif
 
+#define GPU_ID_LEN 64
+
 typedef struct mem_info {
+  char *err;  // If non-nill, caller responsible for freeing
+  char gpu_id[GPU_ID_LEN];
   uint64_t total;
   uint64_t free;
-  unsigned int count;
-  int igpu_index; // If >= 0, we detected an integrated GPU to ignore
-  char *err;  // If non-nill, caller responsible for freeing
+
+  // Compute Capability
+  int major; 
+  int minor;
 } mem_info_t;
 
 void cpu_check_ram(mem_info_t *resp);
@@ -52,8 +57,8 @@ void cpu_check_ram(mem_info_t *resp);
 }
 #endif
 
-#include "gpu_info_nvml.h"
 #include "gpu_info_cudart.h"
+#include "gpu_info_nvcuda.h"
 
 #endif  // __GPU_INFO_H__
 #endif  // __APPLE__

+ 6 - 2
gpu/gpu_info_cpu.c

@@ -8,9 +8,11 @@ void cpu_check_ram(mem_info_t *resp) {
   MEMORYSTATUSEX info;
   info.dwLength = sizeof(info);
   if (GlobalMemoryStatusEx(&info) != 0) {
-    resp->count = 1;
     resp->total = info.ullTotalPhys;
     resp->free = info.ullAvailPhys;
+    resp->major = 0;
+    resp->minor = 0;
+    snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
   } else {
     resp->err = LOAD_ERR();
   }
@@ -27,9 +29,11 @@ void cpu_check_ram(mem_info_t *resp) {
   if (sysinfo(&info) != 0) {
     resp->err = strdup(strerror(errno));
   } else {
-    resp->count = 1;
     resp->total = info.totalram * info.mem_unit;
     resp->free = info.freeram * info.mem_unit;
+    resp->major = 0;
+    resp->minor = 0;
+    snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
   }
   return;
 }

+ 63 - 82
gpu/gpu_info_cudart.c

@@ -6,6 +6,7 @@
 void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
   cudartReturn_t ret;
   resp->err = NULL;
+  resp->num_devices = 0;
   const int buflen = 256;
   char buf[buflen + 1];
   int i;
@@ -21,6 +22,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
       {"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount},
       {"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute},
       {"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion},
+      {"cudaGetDeviceProperties", (void *)&resp->ch.cudaGetDeviceProperties},
       {NULL, NULL},
   };
 
@@ -36,13 +38,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
     return;
   }
 
-  // TODO once we've squashed the remaining corner cases remove this log
-  LOG(resp->ch.verbose, "wiring cudart library functions in %s\n", cudart_lib_path);
-  
   for (i = 0; l[i].s != NULL; i++) {
-    // TODO once we've squashed the remaining corner cases remove this log
-    LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
-
     *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
     if (!l[i].p) {
       char *msg = LOAD_ERR();
@@ -63,7 +59,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
     UNLOAD_LIBRARY(resp->ch.handle);
     resp->ch.handle = NULL;
     if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
-      resp->err = strdup("your nvidia driver is too old or missing, please upgrade to run ollama");
+      resp->err = strdup("your nvidia driver is too old or missing.  If you have a CUDA GPU please upgrade to run ollama");
       return;
     }
     snprintf(buf, buflen, "cudart init failure: %d", ret);
@@ -85,110 +81,95 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
     driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
     LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
   }
+
+  ret = (*resp->ch.cudaGetDeviceCount)(&resp->num_devices);
+  if (ret != CUDART_SUCCESS) {
+    LOG(resp->ch.verbose, "cudaGetDeviceCount err: %d\n", ret);
+    UNLOAD_LIBRARY(resp->ch.handle);
+    resp->ch.handle = NULL;
+    snprintf(buf, buflen, "unable to get device count: %d", ret);
+    resp->err = strdup(buf);
+    return;
+  }
 }
 
 
-void cudart_check_vram(cudart_handle_t h, mem_info_t *resp) {
+void cudart_check_vram(cudart_handle_t h, int i, mem_info_t *resp) {
   resp->err = NULL;
   cudartMemory_t memInfo = {0,0,0};
   cudartReturn_t ret;
   const int buflen = 256;
   char buf[buflen + 1];
-  int i;
 
   if (h.handle == NULL) {
     resp->err = strdup("cudart handle isn't initialized");
     return;
   }
 
-  // cudaGetDeviceCount takes int type, resp-> count is uint
-  int deviceCount;
-  ret = (*h.cudaGetDeviceCount)(&deviceCount);
+  ret = (*h.cudaSetDevice)(i);
   if (ret != CUDART_SUCCESS) {
-    snprintf(buf, buflen, "unable to get device count: %d", ret);
+    snprintf(buf, buflen, "cudart device failed to initialize");
     resp->err = strdup(buf);
     return;
-  } else {
-    resp->count = (unsigned int)deviceCount;
   }
 
-  resp->total = 0;
-  resp->free = 0;
-  for (i = 0; i < resp-> count; i++) {  
-    ret = (*h.cudaSetDevice)(i);
-    if (ret != CUDART_SUCCESS) {
-      snprintf(buf, buflen, "cudart device failed to initialize");
-      resp->err = strdup(buf);
-      return;
+  cudaDeviceProp_t props;
+  ret = (*h.cudaGetDeviceProperties)(&props, i);
+  if (ret != CUDART_SUCCESS) {
+    LOG(h.verbose, "[%d] device properties lookup failure: %d\n", i, ret);
+    snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
+    resp->major = 0;
+    resp->minor = 0;
+  } else {
+    int allNull = 1;
+    for (int j = 0; j < 16; j++) {
+      if (props.uuid.bytes[j] != 0) {
+        allNull = 0;
+        break;
+      }
     }
-    ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
-    if (ret != CUDART_SUCCESS) {
-      snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
-      resp->err = strdup(buf);
-      return;
+    if (allNull != 0) {
+      snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
+    } else {
+      // GPU-d110a105-ac29-1d54-7b49-9c90440f215b
+      snprintf(&resp->gpu_id[0], GPU_ID_LEN,
+          "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
+          props.uuid.bytes[0],
+          props.uuid.bytes[1],
+          props.uuid.bytes[2],
+          props.uuid.bytes[3],
+          props.uuid.bytes[4],
+          props.uuid.bytes[5],
+          props.uuid.bytes[6],
+          props.uuid.bytes[7],
+          props.uuid.bytes[8],
+          props.uuid.bytes[9],
+          props.uuid.bytes[10],
+          props.uuid.bytes[11],
+          props.uuid.bytes[12],
+          props.uuid.bytes[13],
+          props.uuid.bytes[14],
+          props.uuid.bytes[15]
+        );
     }
+    resp->major = props.major;
+    resp->minor = props.minor;
 
-    LOG(h.verbose, "[%d] CUDA totalMem %lu\n", i, memInfo.total);
-    LOG(h.verbose, "[%d] CUDA freeMem %lu\n", i, memInfo.free);
-
-    resp->total += memInfo.total;
-    resp->free += memInfo.free;
-  }
-}
-
-void cudart_compute_capability(cudart_handle_t h, cudart_compute_capability_t *resp) {
-  resp->err = NULL;
-  resp->major = 0;
-  resp->minor = 0;
-  int major = 0;
-  int minor = 0;
-  cudartReturn_t ret;
-  const int buflen = 256;
-  char buf[buflen + 1];
-  int i;
-
-  if (h.handle == NULL) {
-    resp->err = strdup("cudart handle not initialized");
-    return;
+    // TODO add other useful properties from props
   }
-
-  int devices;
-  ret = (*h.cudaGetDeviceCount)(&devices);
+  ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
   if (ret != CUDART_SUCCESS) {
-    snprintf(buf, buflen, "unable to get cudart device count: %d", ret);
+    snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
     resp->err = strdup(buf);
     return;
   }
 
-  for (i = 0; i < devices; i++) {
-    ret = (*h.cudaSetDevice)(i);
-    if (ret != CUDART_SUCCESS) {
-      snprintf(buf, buflen, "cudart device failed to initialize");
-      resp->err = strdup(buf);
-      return;
-    }
+  resp->total = memInfo.total;
+  resp->free = memInfo.free;
 
-    ret = (*h.cudaDeviceGetAttribute)(&major, cudartDevAttrComputeCapabilityMajor, i);
-    if (ret != CUDART_SUCCESS) {
-      snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-    ret = (*h.cudaDeviceGetAttribute)(&minor, cudartDevAttrComputeCapabilityMinor, i);
-    if (ret != CUDART_SUCCESS) {
-      snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-      
-    // Report the lowest major.minor we detect as that limits our compatibility
-    if (resp->major == 0 || resp->major > major ) {
-      resp->major = major;
-      resp->minor = minor;
-    } else if ( resp->major == major && resp->minor > minor ) {
-      resp->minor = minor;
-    }
-  }
+  LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total);
+  LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free);
+  LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
 }
 
 void cudart_release(cudart_handle_t h) {

+ 97 - 11
gpu/gpu_info_cudart.h

@@ -6,14 +6,20 @@
 // Just enough typedef's to dlopen/dlsym for memory information
 typedef enum cudartReturn_enum {
   CUDART_SUCCESS = 0,
-  CUDART_UNSUPPORTED = 1,
-  CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
+  CUDART_ERROR_INVALID_VALUE = 1,
+  CUDART_ERROR_MEMORY_ALLOCATION = 2,
+  CUDART_ERROR_INSUFFICIENT_DRIVER = 35,
   // Other values omitted for now...
 } cudartReturn_t;
 
 typedef enum cudartDeviceAttr_enum {
   cudartDevAttrComputeCapabilityMajor = 75,
   cudartDevAttrComputeCapabilityMinor = 76,
+
+  // TODO - not yet wired up but may be useful for Jetson or other
+  // integrated GPU scenarios with shared memory
+  cudaDevAttrIntegrated = 18
+
 } cudartDeviceAttr_t;
 
 typedef void *cudartDevice_t;  // Opaque is sufficient
@@ -28,6 +34,92 @@ typedef struct cudartDriverVersion {
   int minor;
 } cudartDriverVersion_t;
 
+typedef struct cudaUUID {
+    unsigned char bytes[16];
+} cudaUUID_t;
+typedef struct cudaDeviceProp {
+    char         name[256];                  /**< ASCII string identifying device */
+    cudaUUID_t   uuid;                       /**< 16-byte unique identifier */
+    char         luid[8];                    /**< 8-byte locally unique identifier. Value is undefined on TCC and non-Windows platforms */
+    unsigned int luidDeviceNodeMask;         /**< LUID device node mask. Value is undefined on TCC and non-Windows platforms */
+    size_t       totalGlobalMem;             /**< Global memory available on device in bytes */
+    size_t       sharedMemPerBlock;          /**< Shared memory available per block in bytes */
+    int          regsPerBlock;               /**< 32-bit registers available per block */
+    int          warpSize;                   /**< Warp size in threads */
+    size_t       memPitch;                   /**< Maximum pitch in bytes allowed by memory copies */
+    int          maxThreadsPerBlock;         /**< Maximum number of threads per block */
+    int          maxThreadsDim[3];           /**< Maximum size of each dimension of a block */
+    int          maxGridSize[3];             /**< Maximum size of each dimension of a grid */
+    int          clockRate;                  /**< Clock frequency in kilohertz */
+    size_t       totalConstMem;              /**< Constant memory available on device in bytes */
+    int          major;                      /**< Major compute capability */
+    int          minor;                      /**< Minor compute capability */
+    size_t       textureAlignment;           /**< Alignment requirement for textures */
+    size_t       texturePitchAlignment;      /**< Pitch alignment requirement for texture references bound to pitched memory */
+    int          deviceOverlap;              /**< Device can concurrently copy memory and execute a kernel. Deprecated. Use instead asyncEngineCount. */
+    int          multiProcessorCount;        /**< Number of multiprocessors on device */
+    int          kernelExecTimeoutEnabled;   /**< Specified whether there is a run time limit on kernels */
+    int          integrated;                 /**< Device is integrated as opposed to discrete */
+    int          canMapHostMemory;           /**< Device can map host memory with cudaHostAlloc/cudaHostGetDevicePointer */
+    int          computeMode;                /**< Compute mode (See ::cudaComputeMode) */
+    int          maxTexture1D;               /**< Maximum 1D texture size */
+    int          maxTexture1DMipmap;         /**< Maximum 1D mipmapped texture size */
+    int          maxTexture1DLinear;         /**< Deprecated, do not use. Use cudaDeviceGetTexture1DLinearMaxWidth() or cuDeviceGetTexture1DLinearMaxWidth() instead. */
+    int          maxTexture2D[2];            /**< Maximum 2D texture dimensions */
+    int          maxTexture2DMipmap[2];      /**< Maximum 2D mipmapped texture dimensions */
+    int          maxTexture2DLinear[3];      /**< Maximum dimensions (width, height, pitch) for 2D textures bound to pitched memory */
+    int          maxTexture2DGather[2];      /**< Maximum 2D texture dimensions if texture gather operations have to be performed */
+    int          maxTexture3D[3];            /**< Maximum 3D texture dimensions */
+    int          maxTexture3DAlt[3];         /**< Maximum alternate 3D texture dimensions */
+    int          maxTextureCubemap;          /**< Maximum Cubemap texture dimensions */
+    int          maxTexture1DLayered[2];     /**< Maximum 1D layered texture dimensions */
+    int          maxTexture2DLayered[3];     /**< Maximum 2D layered texture dimensions */
+    int          maxTextureCubemapLayered[2];/**< Maximum Cubemap layered texture dimensions */
+    int          maxSurface1D;               /**< Maximum 1D surface size */
+    int          maxSurface2D[2];            /**< Maximum 2D surface dimensions */
+    int          maxSurface3D[3];            /**< Maximum 3D surface dimensions */
+    int          maxSurface1DLayered[2];     /**< Maximum 1D layered surface dimensions */
+    int          maxSurface2DLayered[3];     /**< Maximum 2D layered surface dimensions */
+    int          maxSurfaceCubemap;          /**< Maximum Cubemap surface dimensions */
+    int          maxSurfaceCubemapLayered[2];/**< Maximum Cubemap layered surface dimensions */
+    size_t       surfaceAlignment;           /**< Alignment requirements for surfaces */
+    int          concurrentKernels;          /**< Device can possibly execute multiple kernels concurrently */
+    int          ECCEnabled;                 /**< Device has ECC support enabled */
+    int          pciBusID;                   /**< PCI bus ID of the device */
+    int          pciDeviceID;                /**< PCI device ID of the device */
+    int          pciDomainID;                /**< PCI domain ID of the device */
+    int          tccDriver;                  /**< 1 if device is a Tesla device using TCC driver, 0 otherwise */
+    int          asyncEngineCount;           /**< Number of asynchronous engines */
+    int          unifiedAddressing;          /**< Device shares a unified address space with the host */
+    int          memoryClockRate;            /**< Peak memory clock frequency in kilohertz */
+    int          memoryBusWidth;             /**< Global memory bus width in bits */
+    int          l2CacheSize;                /**< Size of L2 cache in bytes */
+    int          persistingL2CacheMaxSize;   /**< Device's maximum l2 persisting lines capacity setting in bytes */
+    int          maxThreadsPerMultiProcessor;/**< Maximum resident threads per multiprocessor */
+    int          streamPrioritiesSupported;  /**< Device supports stream priorities */
+    int          globalL1CacheSupported;     /**< Device supports caching globals in L1 */
+    int          localL1CacheSupported;      /**< Device supports caching locals in L1 */
+    size_t       sharedMemPerMultiprocessor; /**< Shared memory available per multiprocessor in bytes */
+    int          regsPerMultiprocessor;      /**< 32-bit registers available per multiprocessor */
+    int          managedMemory;              /**< Device supports allocating managed memory on this system */
+    int          isMultiGpuBoard;            /**< Device is on a multi-GPU board */
+    int          multiGpuBoardGroupID;       /**< Unique identifier for a group of devices on the same multi-GPU board */
+    int          hostNativeAtomicSupported;  /**< Link between the device and the host supports native atomic operations */
+    int          singleToDoublePrecisionPerfRatio; /**< Ratio of single precision performance (in floating-point operations per second) to double precision performance */
+    int          pageableMemoryAccess;       /**< Device supports coherently accessing pageable memory without calling cudaHostRegister on it */
+    int          concurrentManagedAccess;    /**< Device can coherently access managed memory concurrently with the CPU */
+    int          computePreemptionSupported; /**< Device supports Compute Preemption */
+    int          canUseHostPointerForRegisteredMem; /**< Device can access host registered memory at the same virtual address as the CPU */
+    int          cooperativeLaunch;          /**< Device supports launching cooperative kernels via ::cudaLaunchCooperativeKernel */
+    int          cooperativeMultiDeviceLaunch; /**< Deprecated, cudaLaunchCooperativeKernelMultiDevice is deprecated. */
+    size_t       sharedMemPerBlockOptin;     /**< Per device maximum shared memory per block usable by special opt in */
+    int          pageableMemoryAccessUsesHostPageTables; /**< Device accesses pageable memory via the host's page tables */
+    int          directManagedMemAccessFromHost; /**< Host can directly access managed memory on the device without migration. */
+    int          maxBlocksPerMultiProcessor; /**< Maximum number of resident blocks per multiprocessor */
+    int          accessPolicyMaxWindowSize;  /**< The maximum value of ::cudaAccessPolicyWindow::num_bytes. */
+    size_t       reservedSharedMemPerBlock;  /**< Shared memory reserved by CUDA driver per block in bytes */
+  } cudaDeviceProp_t;
+
 typedef struct cudart_handle {
   void *handle;
   uint16_t verbose;
@@ -38,23 +130,17 @@ typedef struct cudart_handle {
   cudartReturn_t (*cudaGetDeviceCount)(int *);
   cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device);
   cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion);
+  cudartReturn_t (*cudaGetDeviceProperties) (cudaDeviceProp_t* prop, int device);
 } cudart_handle_t;
 
 typedef struct cudart_init_resp {
   char *err;  // If err is non-null handle is invalid
   cudart_handle_t ch;
+  int num_devices;
 } cudart_init_resp_t;
 
-typedef struct cudart_compute_capability {
-  char *err;
-  int major;
-  int minor;
-} cudart_compute_capability_t;
-
-
 void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp);
-void cudart_check_vram(cudart_handle_t ch, mem_info_t *resp);
-void cudart_compute_capability(cudart_handle_t th, cudart_compute_capability_t *cc);
+void cudart_check_vram(cudart_handle_t ch, int device_id, mem_info_t *resp);
 void cudart_release(cudart_handle_t ch);
 
 #endif  // __GPU_INFO_CUDART_H__

+ 203 - 0
gpu/gpu_info_nvcuda.c

@@ -0,0 +1,203 @@
+#ifndef __APPLE__  // TODO - maybe consider nvidia support on intel macs?
+
+#include <string.h>
+#include "gpu_info_nvcuda.h"
+
+void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
+  CUresult ret;
+  resp->err = NULL;
+  resp->num_devices = 0;
+  const int buflen = 256;
+  char buf[buflen + 1];
+  int i;
+
+  struct lookup {
+    char *s;
+    void **p;
+  } l[] = {
+   
+      {"cuInit", (void *)&resp->ch.cuInit},
+      {"cuDriverGetVersion", (void *)&resp->ch.cuDriverGetVersion},
+      {"cuDeviceGetCount", (void *)&resp->ch.cuDeviceGetCount},
+      {"cuDeviceGet", (void *)&resp->ch.cuDeviceGet},
+      {"cuDeviceGetAttribute", (void *)&resp->ch.cuDeviceGetAttribute},
+      {"cuDeviceGetUuid", (void *)&resp->ch.cuDeviceGetUuid},
+      {"cuCtxCreate_v3", (void *)&resp->ch.cuCtxCreate_v3},
+      {"cuMemGetInfo_v2", (void *)&resp->ch.cuMemGetInfo_v2},
+      {"cuCtxDestroy", (void *)&resp->ch.cuCtxDestroy},
+      {NULL, NULL},
+  };
+
+  resp->ch.handle = LOAD_LIBRARY(nvcuda_lib_path, RTLD_LAZY);
+  if (!resp->ch.handle) {
+    char *msg = LOAD_ERR();
+    LOG(resp->ch.verbose, "library %s load err: %s\n", nvcuda_lib_path, msg);
+    snprintf(buf, buflen,
+            "Unable to load %s library to query for Nvidia GPUs: %s",
+            nvcuda_lib_path, msg);
+    free(msg);
+    resp->err = strdup(buf);
+    return;
+  }
+
+  for (i = 0; l[i].s != NULL; i++) {
+    *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
+    if (!*l[i].p) {
+      char *msg = LOAD_ERR();
+      LOG(resp->ch.verbose, "dlerr: %s\n", msg);
+      UNLOAD_LIBRARY(resp->ch.handle);
+      resp->ch.handle = NULL;
+      snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
+              msg);
+      free(msg);
+      resp->err = strdup(buf);
+      return;
+    }
+  }
+
+  ret = (*resp->ch.cuInit)(0);
+  if (ret != CUDA_SUCCESS) {
+    LOG(resp->ch.verbose, "cuInit err: %d\n", ret);
+    UNLOAD_LIBRARY(resp->ch.handle);
+    resp->ch.handle = NULL;
+    if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
+      resp->err = strdup("your nvidia driver is too old or missing.  If you have a CUDA GPU please upgrade to run ollama");
+      return;
+    }
+    snprintf(buf, buflen, "nvcuda init failure: %d", ret);
+    resp->err = strdup(buf);
+    return;
+  }
+
+  int version = 0;
+  nvcudaDriverVersion_t driverVersion;
+  driverVersion.major = 0;
+  driverVersion.minor = 0;
+
+  // Report driver version if we're in verbose mode, ignore errors
+  ret = (*resp->ch.cuDriverGetVersion)(&version);
+  if (ret != CUDA_SUCCESS) {
+    LOG(resp->ch.verbose, "cuDriverGetVersion failed: %d\n", ret);
+  } else {
+    driverVersion.major = version / 1000;
+    driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
+    LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
+  }
+
+  ret = (*resp->ch.cuDeviceGetCount)(&resp->num_devices);
+  if (ret != CUDA_SUCCESS) {
+    LOG(resp->ch.verbose, "cuDeviceGetCount err: %d\n", ret);
+    UNLOAD_LIBRARY(resp->ch.handle);
+    resp->ch.handle = NULL;
+    snprintf(buf, buflen, "unable to get device count: %d", ret);
+    resp->err = strdup(buf);
+    return;
+  }
+}
+
+const int buflen = 256;
+void nvcuda_check_vram(nvcuda_handle_t h, int i, mem_info_t *resp) {
+  resp->err = NULL;
+  nvcudaMemory_t memInfo = {0,0};
+  CUresult ret;
+  CUdevice device = -1;
+  CUcontext ctx = NULL;
+  char buf[buflen + 1];
+  CUuuid uuid = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
+
+  if (h.handle == NULL) {
+    resp->err = strdup("nvcuda handle isn't initialized");
+    return;
+  }
+
+  ret = (*h.cuDeviceGet)(&device, i);
+  if (ret != CUDA_SUCCESS) {
+    snprintf(buf, buflen, "nvcuda device failed to initialize");
+    resp->err = strdup(buf);
+    return;
+  }
+
+  resp->major = 0;
+  resp->minor = 0;
+  int major = 0;
+  int minor = 0;
+  ret = (*h.cuDeviceGetAttribute)(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device);
+  if (ret != CUDA_SUCCESS) {
+    LOG(h.verbose, "[%d] device major lookup failure: %d\n", i, ret);
+  } else {
+    ret = (*h.cuDeviceGetAttribute)(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device);
+    if (ret != CUDA_SUCCESS) {
+      LOG(h.verbose, "[%d] device minor lookup failure: %d\n", i, ret);
+    } else {
+      resp->minor = minor;  
+      resp->major = major;  
+    }
+  }
+
+  ret = (*h.cuDeviceGetUuid)(&uuid, device);
+  if (ret != CUDA_SUCCESS) {
+    LOG(h.verbose, "[%d] device uuid lookup failure: %d\n", i, ret);
+    snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
+  } else {
+    // GPU-d110a105-ac29-1d54-7b49-9c90440f215b
+    snprintf(&resp->gpu_id[0], GPU_ID_LEN,
+        "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
+        uuid.bytes[0],
+        uuid.bytes[1],
+        uuid.bytes[2],
+        uuid.bytes[3],
+        uuid.bytes[4],
+        uuid.bytes[5],
+        uuid.bytes[6],
+        uuid.bytes[7],
+        uuid.bytes[8],
+        uuid.bytes[9],
+        uuid.bytes[10],
+        uuid.bytes[11],
+        uuid.bytes[12],
+        uuid.bytes[13],
+        uuid.bytes[14],
+        uuid.bytes[15]
+      );
+  }
+
+  // To get memory we have to set (and release) a context
+  ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
+  if (ret != CUDA_SUCCESS) {
+    snprintf(buf, buflen, "nvcuda failed to get primary device context %d", ret);
+    resp->err = strdup(buf);
+    return;
+  }
+
+  ret = (*h.cuMemGetInfo_v2)(&memInfo.free, &memInfo.total);
+  if (ret != CUDA_SUCCESS) {
+    snprintf(buf, buflen, "nvcuda device memory info lookup failure %d", ret);
+    resp->err = strdup(buf);
+    // Best effort on failure...
+    (*h.cuCtxDestroy)(ctx);
+    return;
+  }
+
+  resp->total = memInfo.total;
+  resp->free = memInfo.free;
+
+  LOG(h.verbose, "[%s] CUDA totalMem %lu mb\n", resp->gpu_id, resp->total / 1024 / 1024);
+  LOG(h.verbose, "[%s] CUDA freeMem %lu mb\n", resp->gpu_id, resp->free / 1024 / 1024);
+  LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
+
+  
+
+  ret = (*h.cuCtxDestroy)(ctx);
+  if (ret != CUDA_SUCCESS) {
+    LOG(1, "nvcuda failed to release primary device context %d", ret);
+  }
+}
+
+void nvcuda_release(nvcuda_handle_t h) {
+  LOG(h.verbose, "releasing nvcuda library\n");
+  UNLOAD_LIBRARY(h.handle);
+  // TODO and other context release logic?
+  h.handle = NULL;
+}
+
+#endif  // __APPLE__

+ 71 - 0
gpu/gpu_info_nvcuda.h

@@ -0,0 +1,71 @@
+#ifndef __APPLE__
+#ifndef __GPU_INFO_NVCUDA_H__
+#define __GPU_INFO_NVCUDA_H__
+#include "gpu_info.h"
+
+// Just enough typedef's to dlopen/dlsym for memory information
+typedef enum cudaError_enum {
+  CUDA_SUCCESS = 0,
+  CUDA_ERROR_INVALID_VALUE = 1,
+  CUDA_ERROR_MEMORY_ALLOCATION = 2,
+  CUDA_ERROR_NOT_INITIALIZED = 3,
+  CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
+  // Other values omitted for now...
+} CUresult;
+
+typedef enum CUdevice_attribute_enum {
+  CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75,
+  CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76,
+
+  // TODO - not yet wired up but may be useful for Jetson or other
+  // integrated GPU scenarios with shared memory
+  CU_DEVICE_ATTRIBUTE_INTEGRATED = 18
+
+} CUdevice_attribute;
+
+typedef void *nvcudaDevice_t;  // Opaque is sufficient
+typedef struct nvcudaMemory_st {
+  uint64_t total;
+  uint64_t free;
+} nvcudaMemory_t;
+
+typedef struct nvcudaDriverVersion {
+  int major;
+  int minor;
+} nvcudaDriverVersion_t;
+
+typedef struct CUuuid_st {
+    unsigned char bytes[16];
+} CUuuid;
+
+typedef int CUdevice;
+typedef void* CUcontext;
+
+typedef struct nvcuda_handle {
+  void *handle;
+  uint16_t verbose;
+  CUresult (*cuInit)(unsigned int Flags);
+  CUresult (*cuDriverGetVersion)(int *driverVersion);
+  CUresult (*cuDeviceGetCount)(int *);
+  CUresult (*cuDeviceGet)(CUdevice* device, int ordinal);
+  CUresult (*cuDeviceGetAttribute)(int* pi, CUdevice_attribute attrib, CUdevice dev);
+  CUresult (*cuDeviceGetUuid)(CUuuid* uuid, CUdevice dev); // signature compatible with cuDeviceGetUuid_v2
+
+  // Context specific aspects
+  CUresult (*cuCtxCreate_v3)(CUcontext* pctx, void *params, int len, unsigned int flags, CUdevice dev);
+  CUresult (*cuMemGetInfo_v2)(uint64_t* free, uint64_t* total);
+  CUresult (*cuCtxDestroy)(CUcontext ctx);
+} nvcuda_handle_t;
+
+typedef struct nvcuda_init_resp {
+  char *err;  // If err is non-null handle is invalid
+  nvcuda_handle_t ch;
+  int num_devices;
+} nvcuda_init_resp_t;
+
+void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp);
+void nvcuda_check_vram(nvcuda_handle_t ch, int device_id, mem_info_t *resp);
+void nvcuda_release(nvcuda_handle_t ch);
+
+#endif  // __GPU_INFO_NVCUDA_H__
+#endif  // __APPLE__

+ 0 - 221
gpu/gpu_info_nvml.c

@@ -1,221 +0,0 @@
-#ifndef __APPLE__  // TODO - maybe consider nvidia support on intel macs?
-
-#include <string.h>
-
-#include "gpu_info_nvml.h"
-
-void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) {
-  nvmlReturn_t ret;
-  resp->err = NULL;
-  const int buflen = 256;
-  char buf[buflen + 1];
-  int i;
-
-  struct lookup {
-    char *s;
-    void **p;
-  } l[] = {
-      {"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2},
-      {"nvmlShutdown", (void *)&resp->ch.nvmlShutdown},
-      {"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.nvmlDeviceGetHandleByIndex},
-      {"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo},
-      {"nvmlDeviceGetCount_v2", (void *)&resp->ch.nvmlDeviceGetCount_v2},
-      {"nvmlDeviceGetCudaComputeCapability", (void *)&resp->ch.nvmlDeviceGetCudaComputeCapability},
-      {"nvmlSystemGetDriverVersion", (void *)&resp->ch.nvmlSystemGetDriverVersion},
-      {"nvmlDeviceGetName", (void *)&resp->ch.nvmlDeviceGetName},
-      {"nvmlDeviceGetSerial", (void *)&resp->ch.nvmlDeviceGetSerial},
-      {"nvmlDeviceGetVbiosVersion", (void *)&resp->ch.nvmlDeviceGetVbiosVersion},
-      {"nvmlDeviceGetBoardPartNumber", (void *)&resp->ch.nvmlDeviceGetBoardPartNumber},
-      {"nvmlDeviceGetBrand", (void *)&resp->ch.nvmlDeviceGetBrand},
-      {NULL, NULL},
-  };
-
-  resp->ch.handle = LOAD_LIBRARY(nvml_lib_path, RTLD_LAZY);
-  if (!resp->ch.handle) {
-    char *msg = LOAD_ERR();
-    LOG(resp->ch.verbose, "library %s load err: %s\n", nvml_lib_path, msg);
-    snprintf(buf, buflen,
-             "Unable to load %s library to query for Nvidia GPUs: %s",
-             nvml_lib_path, msg);
-    free(msg);
-    resp->err = strdup(buf);
-    return;
-  }
-
-  // TODO once we've squashed the remaining corner cases remove this log
-  LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", nvml_lib_path);
-  
-  for (i = 0; l[i].s != NULL; i++) {
-    // TODO once we've squashed the remaining corner cases remove this log
-    LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
-
-    *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
-    if (!l[i].p) {
-      resp->ch.handle = NULL;
-      char *msg = LOAD_ERR();
-      LOG(resp->ch.verbose, "dlerr: %s\n", msg);
-      UNLOAD_LIBRARY(resp->ch.handle);
-      snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
-               msg);
-      free(msg);
-      resp->err = strdup(buf);
-      return;
-    }
-  }
-
-  ret = (*resp->ch.nvmlInit_v2)();
-  if (ret != NVML_SUCCESS) {
-    LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret);
-    UNLOAD_LIBRARY(resp->ch.handle);
-    resp->ch.handle = NULL;
-    snprintf(buf, buflen, "nvml vram init failure: %d", ret);
-    resp->err = strdup(buf);
-    return;
-  }
-
-  // Report driver version if we're in verbose mode, ignore errors
-  ret = (*resp->ch.nvmlSystemGetDriverVersion)(buf, buflen);
-  if (ret != NVML_SUCCESS) {
-    LOG(resp->ch.verbose, "nvmlSystemGetDriverVersion failed: %d\n", ret);
-  } else {
-    LOG(resp->ch.verbose, "CUDA driver version: %s\n", buf);
-  }
-}
-
-void nvml_check_vram(nvml_handle_t h, mem_info_t *resp) {
-  resp->err = NULL;
-  nvmlDevice_t device;
-  nvmlMemory_t memInfo = {0};
-  nvmlReturn_t ret;
-  const int buflen = 256;
-  char buf[buflen + 1];
-  int i;
-
-  if (h.handle == NULL) {
-    resp->err = strdup("nvml handle isn't initialized");
-    return;
-  }
-
-  ret = (*h.nvmlDeviceGetCount_v2)(&resp->count);
-  if (ret != NVML_SUCCESS) {
-    snprintf(buf, buflen, "unable to get device count: %d", ret);
-    resp->err = strdup(buf);
-    return;
-  }
-
-  resp->total = 0;
-  resp->free = 0;
-  for (i = 0; i < resp->count; i++) {
-    ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
-    if (ret != NVML_SUCCESS) {
-      snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-
-    ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo);
-    if (ret != NVML_SUCCESS) {
-      snprintf(buf, buflen, "device memory info lookup failure %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-    if (h.verbose) {
-      nvmlBrandType_t brand = 0;
-      // When in verbose mode, report more information about
-      // the card we discover, but don't fail on error
-      ret = (*h.nvmlDeviceGetName)(device, buf, buflen);
-      if (ret != NVML_SUCCESS) {
-        LOG(h.verbose, "nvmlDeviceGetName failed: %d\n", ret);
-      } else {
-        LOG(h.verbose, "[%d] CUDA device name: %s\n", i, buf);
-      }
-      ret = (*h.nvmlDeviceGetBoardPartNumber)(device, buf, buflen);
-      if (ret != NVML_SUCCESS) {
-        LOG(h.verbose, "nvmlDeviceGetBoardPartNumber failed: %d\n", ret);
-      } else {
-        LOG(h.verbose, "[%d] CUDA part number: %s\n", i, buf);
-      }
-      ret = (*h.nvmlDeviceGetSerial)(device, buf, buflen);
-      if (ret != NVML_SUCCESS) {
-        LOG(h.verbose, "nvmlDeviceGetSerial failed: %d\n", ret);
-      } else {
-        LOG(h.verbose, "[%d] CUDA S/N: %s\n", i, buf);
-      }
-      ret = (*h.nvmlDeviceGetVbiosVersion)(device, buf, buflen);
-      if (ret != NVML_SUCCESS) {
-        LOG(h.verbose, "nvmlDeviceGetVbiosVersion failed: %d\n", ret);
-      } else {
-        LOG(h.verbose, "[%d] CUDA vbios version: %s\n", i, buf);
-      }
-      ret = (*h.nvmlDeviceGetBrand)(device, &brand);
-      if (ret != NVML_SUCCESS) {
-        LOG(h.verbose, "nvmlDeviceGetBrand failed: %d\n", ret);
-      } else {
-        LOG(h.verbose, "[%d] CUDA brand: %d\n", i, brand);
-      }
-    }
-
-    LOG(h.verbose, "[%d] CUDA totalMem %ld\n", i, memInfo.total);
-    LOG(h.verbose, "[%d] CUDA freeMem %ld\n", i, memInfo.free);
-
-    resp->total += memInfo.total;
-    resp->free += memInfo.free;
-  }
-}
-
-void nvml_compute_capability(nvml_handle_t h, nvml_compute_capability_t *resp) {
-  resp->err = NULL;
-  resp->major = 0;
-  resp->minor = 0;
-  nvmlDevice_t device;
-  int major = 0;
-  int minor = 0;
-  nvmlReturn_t ret;
-  const int buflen = 256;
-  char buf[buflen + 1];
-  int i;
-
-  if (h.handle == NULL) {
-    resp->err = strdup("nvml handle not initialized");
-    return;
-  }
-
-  unsigned int devices;
-  ret = (*h.nvmlDeviceGetCount_v2)(&devices);
-  if (ret != NVML_SUCCESS) {
-    snprintf(buf, buflen, "unable to get device count: %d", ret);
-    resp->err = strdup(buf);
-    return;
-  }
-
-  for (i = 0; i < devices; i++) {
-    ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
-    if (ret != NVML_SUCCESS) {
-      snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-
-    ret = (*h.nvmlDeviceGetCudaComputeCapability)(device, &major, &minor);
-    if (ret != NVML_SUCCESS) {
-      snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
-      resp->err = strdup(buf);
-      return;
-    }
-    // Report the lowest major.minor we detect as that limits our compatibility
-    if (resp->major == 0 || resp->major > major ) {
-      resp->major = major;
-      resp->minor = minor;
-    } else if ( resp->major == major && resp->minor > minor ) {
-      resp->minor = minor;
-    }
-  }
-}
-
-void nvml_release(nvml_handle_t h) {
-  LOG(h.verbose, "releasing nvml library\n");
-  UNLOAD_LIBRARY(h.handle);
-  h.handle = NULL;
-}
-
-#endif  // __APPLE__

+ 0 - 57
gpu/gpu_info_nvml.h

@@ -1,57 +0,0 @@
-#ifndef __APPLE__
-#ifndef __GPU_INFO_NVML_H__
-#define __GPU_INFO_NVML_H__
-#include "gpu_info.h"
-
-// Just enough typedef's to dlopen/dlsym for memory information
-typedef enum nvmlReturn_enum {
-  NVML_SUCCESS = 0,
-  // Other values omitted for now...
-} nvmlReturn_t;
-typedef void *nvmlDevice_t;  // Opaque is sufficient
-typedef struct nvmlMemory_st {
-  unsigned long long total;
-  unsigned long long free;
-  unsigned long long used;
-} nvmlMemory_t;
-
-typedef enum nvmlBrandType_enum
-{
-    NVML_BRAND_UNKNOWN          = 0,
-} nvmlBrandType_t;
-
-typedef struct nvml_handle {
-  void *handle;
-  uint16_t verbose;
-  nvmlReturn_t (*nvmlInit_v2)(void);
-  nvmlReturn_t (*nvmlShutdown)(void);
-  nvmlReturn_t (*nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t *);
-  nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
-  nvmlReturn_t (*nvmlDeviceGetCount_v2)(unsigned int *);
-  nvmlReturn_t (*nvmlDeviceGetCudaComputeCapability)(nvmlDevice_t, int* major, int* minor);
-  nvmlReturn_t (*nvmlSystemGetDriverVersion) (char* version, unsigned int  length);
-  nvmlReturn_t (*nvmlDeviceGetName) (nvmlDevice_t device, char* name, unsigned int  length);
-  nvmlReturn_t (*nvmlDeviceGetSerial) (nvmlDevice_t device, char* serial, unsigned int  length);
-  nvmlReturn_t (*nvmlDeviceGetVbiosVersion) (nvmlDevice_t device, char* version, unsigned int  length);
-  nvmlReturn_t (*nvmlDeviceGetBoardPartNumber) (nvmlDevice_t device, char* partNumber, unsigned int  length);
-  nvmlReturn_t (*nvmlDeviceGetBrand) (nvmlDevice_t device, nvmlBrandType_t* type);
-} nvml_handle_t;
-
-typedef struct nvml_init_resp {
-  char *err;  // If err is non-null handle is invalid
-  nvml_handle_t ch;
-} nvml_init_resp_t;
-
-typedef struct nvml_compute_capability {
-  char *err;
-  int major;
-  int minor;
-} nvml_compute_capability_t;
-
-void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp);
-void nvml_check_vram(nvml_handle_t ch, mem_info_t *resp);
-void nvml_compute_capability(nvml_handle_t ch, nvml_compute_capability_t *cc);
-void nvml_release(nvml_handle_t ch);
-
-#endif  // __GPU_INFO_NVML_H__
-#endif  // __APPLE__

+ 6 - 13
gpu/gpu_test.go

@@ -9,23 +9,16 @@ import (
 
 func TestBasicGetGPUInfo(t *testing.T) {
 	info := GetGPUInfo()
-	assert.Contains(t, "cuda rocm cpu metal", info.Library)
-
-	switch runtime.GOOS {
-	case "darwin":
-		// TODO - remove this once MacOS returns some size for CPU
-		return
-	case "linux", "windows":
-		assert.Greater(t, info.TotalMemory, uint64(0))
-		assert.Greater(t, info.FreeMemory, uint64(0))
-		assert.Greater(t, info.DeviceCount, uint32(0))
-	default:
-		return
+	assert.Greater(t, len(info), 0)
+	assert.Contains(t, "cuda rocm cpu metal", info[0].Library)
+	if info[0].Library != "cpu" {
+		assert.Greater(t, info[0].TotalMemory, uint64(0))
+		assert.Greater(t, info[0].FreeMemory, uint64(0))
 	}
 }
 
 func TestCPUMemInfo(t *testing.T) {
-	info, err := getCPUMem()
+	info, err := GetCPUMem()
 	assert.NoError(t, err)
 	switch runtime.GOOS {
 	case "darwin":

+ 43 - 6
gpu/types.go

@@ -3,7 +3,6 @@ package gpu
 type memInfo struct {
 	TotalMemory uint64 `json:"total_memory,omitempty"`
 	FreeMemory  uint64 `json:"free_memory,omitempty"`
-	DeviceCount uint32 `json:"device_count,omitempty"`
 }
 
 // Beginning of an `ollama info` command
@@ -17,11 +16,49 @@ type GpuInfo struct {
 	// MinimumMemory represents the minimum memory required to use the GPU
 	MinimumMemory uint64 `json:"-"`
 
-	// TODO add other useful attributes about the card here for discovery information
+	// Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
+	DependencyPath string `json:"lib_path,omitempty"`
+
+	// GPU information
+	ID    string `json:"gpu_id"`          // string to use for selection of this specific GPU
+	Name  string `json:"name"`            // user friendly name if available
+	Major int    `json:"major,omitempty"` // Major compatibility version (CC or gfx)
+	Minor int    `json:"minor,omitempty"` // Minor compatibility version (CC or gfx)
+	Patch int    `json:"patch,omitempty"` // Patch compatibility only matters on AMD
+
+	// TODO other performance capability info to help in scheduling decisions
 }
 
-type Version struct {
-	Major uint
-	Minor uint
-	Patch uint
+type GpuInfoList []GpuInfo
+
+// Split up the set of gpu info's by Library and variant
+func (l GpuInfoList) ByLibrary() []GpuInfoList {
+	resp := []GpuInfoList{}
+	libs := []string{}
+	for _, info := range l {
+		found := false
+		requested := info.Library
+		if info.Variant != "" {
+			requested += "_" + info.Variant
+		}
+		for i, lib := range libs {
+			if lib == requested {
+				resp[i] = append(resp[i], info)
+				found = true
+				break
+			}
+		}
+		if !found {
+			libs = append(libs, info.Library)
+			resp = append(resp, []GpuInfo{info})
+		}
+	}
+	return resp
 }
+
+// Sort by Free Space
+type ByFreeMemory []GpuInfo
+
+func (a ByFreeMemory) Len() int           { return len(a) }
+func (a ByFreeMemory) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
+func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory }

+ 44 - 2
integration/basic_test.go

@@ -4,11 +4,14 @@ package integration
 
 import (
 	"context"
-	"net/http"
+	"log/slog"
+	"os"
+	"runtime"
 	"testing"
 	"time"
 
 	"github.com/ollama/ollama/api"
+	"github.com/stretchr/testify/require"
 )
 
 func TestOrcaMiniBlueSky(t *testing.T) {
@@ -24,5 +27,44 @@ func TestOrcaMiniBlueSky(t *testing.T) {
 			"seed":        123,
 		},
 	}
-	GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh", "scattering"})
+	GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
+}
+
+func TestUnicodeModelDir(t *testing.T) {
+	// This is only useful for Windows with utf-16 characters, so skip this test for other platforms
+	if runtime.GOOS != "windows" {
+		t.Skip("Unicode test only applicable to windows")
+	}
+	// Only works for local testing
+	if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
+		t.Skip("TestUnicodeModelDir only works for local testing, skipping")
+	}
+
+	modelDir, err := os.MkdirTemp("", "ollama_埃")
+	require.NoError(t, err)
+	defer os.RemoveAll(modelDir)
+	slog.Info("unicode", "OLLAMA_MODELS", modelDir)
+
+	oldModelsDir := os.Getenv("OLLAMA_MODELS")
+	if oldModelsDir == "" {
+		defer os.Unsetenv("OLLAMA_MODELS")
+	} else {
+		defer os.Setenv("OLLAMA_MODELS", oldModelsDir)
+	}
+	err = os.Setenv("OLLAMA_MODELS", modelDir)
+	require.NoError(t, err)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+	defer cancel()
+
+	req := api.GenerateRequest{
+		Model:  "orca-mini",
+		Prompt: "why is the sky blue?",
+		Stream: &stream,
+		Options: map[string]interface{}{
+			"temperature": 0,
+			"seed":        123,
+		},
+	}
+	GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
 }

+ 225 - 0
integration/concurrency_test.go

@@ -0,0 +1,225 @@
+//go:build integration
+
+package integration
+
+import (
+	"context"
+	"log/slog"
+	"os"
+	"strconv"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/ollama/ollama/api"
+	"github.com/stretchr/testify/require"
+)
+
+func TestMultiModelConcurrency(t *testing.T) {
+	var (
+		req = [2]api.GenerateRequest{
+			{
+				Model:  "orca-mini",
+				Prompt: "why is the ocean blue?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "tinydolphin",
+				Prompt: "what is the origin of the us thanksgiving holiday?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			},
+		}
+		resp = [2][]string{
+			[]string{"sunlight"},
+			[]string{"england", "english", "massachusetts", "pilgrims"},
+		}
+	)
+	var wg sync.WaitGroup
+	wg.Add(len(req))
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
+	defer cancel()
+	for i := 0; i < len(req); i++ {
+		go func(i int) {
+			defer wg.Done()
+			GenerateTestHelper(ctx, t, req[i], resp[i])
+		}(i)
+	}
+	wg.Wait()
+}
+
+func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) // GTX 750 2G card takes ~9 minutes
+	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	req, resp := GenerateRequests()
+	// Get the server running (if applicable) warm the model up with a single initial request
+	DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 5*time.Second)
+
+	var wg sync.WaitGroup
+	wg.Add(len(req))
+	for i := 0; i < len(req); i++ {
+		go func(i int) {
+			defer wg.Done()
+			for j := 0; j < 5; j++ {
+				slog.Info("Starting", "req", i, "iter", j)
+				// On slower GPUs it can take a while to process the 4 concurrent requests
+				// so we allow a much longer initial timeout
+				DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
+			}
+		}(i)
+	}
+	wg.Wait()
+}
+
+// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
+func TestMultiModelStress(t *testing.T) {
+	vram := os.Getenv("OLLAMA_MAX_VRAM")
+	if vram == "" {
+		t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
+	}
+	max, err := strconv.ParseUint(vram, 10, 64)
+	require.NoError(t, err)
+	const MB = uint64(1024 * 1024)
+	type model struct {
+		name string
+		size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM
+	}
+
+	smallModels := []model{
+		{
+			name: "orca-mini",
+			size: 2992 * MB,
+		},
+		{
+			name: "phi",
+			size: 2616 * MB,
+		},
+		{
+			name: "gemma:2b",
+			size: 2364 * MB,
+		},
+		{
+			name: "stable-code:3b",
+			size: 2608 * MB,
+		},
+		{
+			name: "starcoder2:3b",
+			size: 2166 * MB,
+		},
+	}
+	mediumModels := []model{
+		{
+			name: "llama2",
+			size: 5118 * MB,
+		},
+		{
+			name: "mistral",
+			size: 4620 * MB,
+		},
+		{
+			name: "orca-mini:7b",
+			size: 5118 * MB,
+		},
+		{
+			name: "dolphin-mistral",
+			size: 4620 * MB,
+		},
+		{
+			name: "gemma:7b",
+			size: 5000 * MB,
+		},
+		// TODO - uncomment this once #3565 is merged and this is rebased on it
+		// {
+		// 	name: "codellama:7b",
+		// 	size: 5118 * MB,
+		// },
+	}
+
+	// These seem to be too slow to be useful...
+	// largeModels := []model{
+	// 	{
+	// 		name: "llama2:13b",
+	// 		size: 7400 * MB,
+	// 	},
+	// 	{
+	// 		name: "codellama:13b",
+	// 		size: 7400 * MB,
+	// 	},
+	// 	{
+	// 		name: "orca-mini:13b",
+	// 		size: 7400 * MB,
+	// 	},
+	// 	{
+	// 		name: "gemma:7b",
+	// 		size: 5000 * MB,
+	// 	},
+	// 	{
+	// 		name: "starcoder2:15b",
+	// 		size: 9100 * MB,
+	// 	},
+	// }
+
+	var chosenModels []model
+	switch {
+	case max < 10000*MB:
+		slog.Info("selecting small models")
+		chosenModels = smallModels
+	// case max < 30000*MB:
+	default:
+		slog.Info("selecting medium models")
+		chosenModels = mediumModels
+		// default:
+		// 	slog.Info("selecting large models")
+		// 	chosenModels = largModels
+	}
+
+	req, resp := GenerateRequests()
+
+	for i := range req {
+		if i > len(chosenModels) {
+			break
+		}
+		req[i].Model = chosenModels[i].name
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
+	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	// Make sure all the models are pulled before we get started
+	for _, r := range req {
+		require.NoError(t, PullIfMissing(ctx, client, r.Model))
+	}
+
+	var wg sync.WaitGroup
+	consumed := uint64(256 * MB) // Assume some baseline usage
+	for i := 0; i < len(req); i++ {
+		// Always get at least 2 models, but dont' overshoot VRAM too much or we'll take too long
+		if i > 1 && consumed > max {
+			slog.Info("achieved target vram exhaustion", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
+			break
+		}
+		consumed += chosenModels[i].size
+		slog.Info("target vram", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
+
+		wg.Add(1)
+		go func(i int) {
+			defer wg.Done()
+			for j := 0; j < 3; j++ {
+				slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
+				DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
+			}
+		}(i)
+	}
+	wg.Wait()
+}

+ 1 - 2
integration/context_test.go

@@ -4,7 +4,6 @@ package integration
 
 import (
 	"context"
-	"net/http"
 	"testing"
 	"time"
 
@@ -25,5 +24,5 @@ func TestContextExhaustion(t *testing.T) {
 			"num_ctx":     128,
 		},
 	}
-	GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"once", "upon", "lived"})
+	GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
 }

+ 3 - 3
integration/llm_image_test.go

@@ -5,7 +5,6 @@ package integration
 import (
 	"context"
 	"encoding/base64"
-	"net/http"
 	"testing"
 	"time"
 
@@ -29,10 +28,11 @@ func TestIntegrationMultimodal(t *testing.T) {
 		},
 	}
 
-	resp := "the ollamas"
+	// Note: sometimes it returns "the ollamas" sometimes "the ollams"
+	resp := "the ollam"
 	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
 	defer cancel()
-	GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp})
+	GenerateTestHelper(ctx, t, req, []string{resp})
 }
 
 const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb

+ 1 - 23
integration/llm_test.go

@@ -4,8 +4,6 @@ package integration
 
 import (
 	"context"
-	"net/http"
-	"sync"
 	"testing"
 	"time"
 
@@ -45,25 +43,5 @@ var (
 func TestIntegrationSimpleOrcaMini(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
 	defer cancel()
-	GenerateTestHelper(ctx, t, &http.Client{}, req[0], resp[0])
+	GenerateTestHelper(ctx, t, req[0], resp[0])
 }
-
-// TODO
-// The server always loads a new runner and closes the old one, which forces serial execution
-// At present this test case fails with concurrency problems.  Eventually we should try to
-// get true concurrency working with n_parallel support in the backend
-func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
-	var wg sync.WaitGroup
-	wg.Add(len(req))
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
-	defer cancel()
-	for i := 0; i < len(req); i++ {
-		go func(i int) {
-			defer wg.Done()
-			GenerateTestHelper(ctx, t, &http.Client{}, req[i], resp[i])
-		}(i)
-	}
-	wg.Wait()
-}
-
-// TODO - create a parallel test with 2 different models once we support concurrency

+ 117 - 0
integration/max_queue_test.go

@@ -0,0 +1,117 @@
+//go:build integration
+
+package integration
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"log/slog"
+	"os"
+	"strconv"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/ollama/ollama/api"
+	"github.com/stretchr/testify/require"
+)
+
+func TestMaxQueue(t *testing.T) {
+	// Note: This test can be quite slow when running in CPU mode, so keep the threadCount low unless your on GPU
+	// Also note that by default Darwin can't sustain > ~128 connections without adjusting limits
+	threadCount := 32
+	mq := os.Getenv("OLLAMA_MAX_QUEUE")
+	if mq != "" {
+		var err error
+		threadCount, err = strconv.Atoi(mq)
+		require.NoError(t, err)
+	} else {
+		os.Setenv("OLLAMA_MAX_QUEUE", fmt.Sprintf("%d", threadCount))
+	}
+
+	req := api.GenerateRequest{
+		Model:  "orca-mini",
+		Prompt: "write a long historical fiction story about christopher columbus.  use at least 10 facts from his actual journey",
+		Options: map[string]interface{}{
+			"seed":        42,
+			"temperature": 0.0,
+		},
+	}
+	resp := []string{"explore", "discover", "ocean"}
+
+	// CPU mode takes much longer at the limit with a large queue setting
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	require.NoError(t, PullIfMissing(ctx, client, req.Model))
+
+	// Context for the worker threads so we can shut them down
+	// embedCtx, embedCancel := context.WithCancel(ctx)
+	embedCtx := ctx
+
+	var genwg sync.WaitGroup
+	go func() {
+		genwg.Add(1)
+		defer genwg.Done()
+		slog.Info("Starting generate request")
+		DoGenerate(ctx, t, client, req, resp, 45*time.Second, 5*time.Second)
+		slog.Info("generate completed")
+	}()
+
+	// Give the generate a chance to get started before we start hammering on embed requests
+	time.Sleep(5 * time.Millisecond)
+
+	threadCount += 10 // Add a few extra to ensure we push the queue past its limit
+	busyCount := 0
+	resetByPeerCount := 0
+	canceledCount := 0
+	succesCount := 0
+	counterMu := sync.Mutex{}
+	var embedwg sync.WaitGroup
+	for i := 0; i < threadCount; i++ {
+		go func(i int) {
+			embedwg.Add(1)
+			defer embedwg.Done()
+			slog.Info("embed started", "id", i)
+			embedReq := api.EmbeddingRequest{
+				Model:   req.Model,
+				Prompt:  req.Prompt,
+				Options: req.Options,
+			}
+			// Fresh client for every request
+			client, _ = GetTestEndpoint()
+
+			resp, genErr := client.Embeddings(embedCtx, &embedReq)
+			counterMu.Lock()
+			defer counterMu.Unlock()
+			switch {
+			case genErr == nil:
+				succesCount++
+				require.Greater(t, len(resp.Embedding), 5) // somewhat arbitrary, but sufficient to be reasonable
+			case errors.Is(genErr, context.Canceled):
+				canceledCount++
+			case strings.Contains(genErr.Error(), "busy"):
+				busyCount++
+			case strings.Contains(genErr.Error(), "connection reset by peer"):
+				resetByPeerCount++
+			default:
+				require.NoError(t, genErr, "%d request failed", i)
+			}
+
+			slog.Info("embed finished", "id", i)
+		}(i)
+	}
+	genwg.Wait()
+	slog.Info("generate done, waiting for embeds")
+	embedwg.Wait()
+
+	require.Equal(t, resetByPeerCount, 0, "Connections reset by peer, have you updated your fd and socket limits?")
+	require.True(t, busyCount > 0, "no requests hit busy error but some should have")
+	require.True(t, canceledCount == 0, "no requests should have been canceled due to timeout")
+
+	slog.Info("embeds completed", "success", succesCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
+}

+ 172 - 99
integration/utils_test.go

@@ -5,13 +5,14 @@ package integration
 import (
 	"bytes"
 	"context"
-	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"log/slog"
 	"math/rand"
 	"net"
 	"net/http"
+	"net/url"
 	"os"
 	"path/filepath"
 	"runtime"
@@ -23,9 +24,13 @@ import (
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/app/lifecycle"
-	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
+func Init() {
+	lifecycle.InitLogging()
+}
+
 func FindPort() string {
 	port := 0
 	if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
@@ -41,7 +46,7 @@ func FindPort() string {
 	return strconv.Itoa(port)
 }
 
-func GetTestEndpoint() (string, string) {
+func GetTestEndpoint() (*api.Client, string) {
 	defaultPort := "11434"
 	ollamaHost := os.Getenv("OLLAMA_HOST")
 
@@ -67,16 +72,20 @@ func GetTestEndpoint() (string, string) {
 		port = FindPort()
 	}
 
-	url := fmt.Sprintf("%s:%s", host, port)
-	slog.Info("server connection", "url", url)
-	return scheme, url
+	slog.Info("server connection", "host", host, "port", port)
+
+	return api.NewClient(
+		&url.URL{
+			Scheme: scheme,
+			Host:   net.JoinHostPort(host, port),
+		},
+		http.DefaultClient), fmt.Sprintf("%s:%s", host, port)
 }
 
-// TODO make fanicier, grab logs, etc.
 var serverMutex sync.Mutex
 var serverReady bool
 
-func StartServer(ctx context.Context, ollamaHost string) error {
+func startServer(ctx context.Context, ollamaHost string) error {
 	// Make sure the server has been built
 	CLIName, err := filepath.Abs("../ollama")
 	if err != nil {
@@ -98,7 +107,7 @@ func StartServer(ctx context.Context, ollamaHost string) error {
 
 	if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost {
 		slog.Info("setting env", "OLLAMA_HOST", ollamaHost)
-		os.Setenv("OLLAMA_HOST", ollamaHost)
+		t.Setenv("OLLAMA_HOST", ollamaHost)
 	}
 
 	slog.Info("starting server", "url", ollamaHost)
@@ -125,67 +134,76 @@ func StartServer(ctx context.Context, ollamaHost string) error {
 	return nil
 }
 
-func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error {
+func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
 	slog.Info("checking status of model", "model", modelName)
 	showReq := &api.ShowRequest{Name: modelName}
-	requestJSON, err := json.Marshal(showReq)
-	if err != nil {
-		return err
-	}
-
-	req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/show", bytes.NewReader(requestJSON))
-	if err != nil {
-		return err
-	}
 
-	// Make the request with the HTTP client
-	response, err := client.Do(req.WithContext(ctx))
-	if err != nil {
+	showCtx, cancel := context.WithDeadlineCause(
+		ctx,
+		time.Now().Add(5*time.Second),
+		fmt.Errorf("show for existing model %s took too long", modelName),
+	)
+	defer cancel()
+	_, err := client.Show(showCtx, showReq)
+	var statusError api.StatusError
+	switch {
+	case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
+		break
+	case err != nil:
 		return err
-	}
-	defer response.Body.Close()
-	if response.StatusCode == 200 {
+	default:
 		slog.Info("model already present", "model", modelName)
 		return nil
 	}
-	slog.Info("model missing", "status", response.StatusCode)
+	slog.Info("model missing", "model", modelName)
+
+	stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
+	stallTimer := time.NewTimer(stallDuration)
+	fn := func(resp api.ProgressResponse) error {
+		// fmt.Print(".")
+		if !stallTimer.Reset(stallDuration) {
+			return fmt.Errorf("stall was detected, aborting status reporting")
+		}
+		return nil
+	}
 
+	stream := true
 	pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
-	requestJSON, err = json.Marshal(pullReq)
-	if err != nil {
-		return err
-	}
 
-	req, err = http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/pull", bytes.NewReader(requestJSON))
-	if err != nil {
-		return err
-	}
-	slog.Info("pulling", "model", modelName)
+	var pullError error
 
-	response, err = client.Do(req.WithContext(ctx))
-	if err != nil {
-		return err
-	}
-	defer response.Body.Close()
-	if response.StatusCode != 200 {
-		return fmt.Errorf("failed to pull model") // TODO more details perhaps
+	done := make(chan int)
+	go func() {
+		pullError = client.Pull(ctx, pullReq, fn)
+		done <- 0
+	}()
+
+	select {
+	case <-stallTimer.C:
+		return fmt.Errorf("download stalled")
+	case <-done:
+		return pullError
 	}
-	slog.Info("model pulled", "model", modelName)
-	return nil
 }
 
 var serverProcMutex sync.Mutex
 
-func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
-
-	// TODO maybe stuff in an init routine?
-	lifecycle.InitLogging()
-
-	requestJSON, err := json.Marshal(genReq)
-	if err != nil {
-		t.Fatalf("Error serializing request: %v", err)
+// Returns an Client, the testEndpoint, and a cleanup function, fails the test on errors
+// Starts the server if needed
+func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, string, func()) {
+	client, testEndpoint := GetTestEndpoint()
+	if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
+		serverProcMutex.Lock()
+		fp, err := os.CreateTemp("", "ollama-server-*.log")
+		if err != nil {
+			t.Fatalf("failed to generate log file: %s", err)
+		}
+		lifecycle.ServerLogFile = fp.Name()
+		fp.Close()
+		require.NoError(t, startServer(ctx, testEndpoint))
 	}
-	defer func() {
+
+	return client, testEndpoint, func() {
 		if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
 			defer serverProcMutex.Unlock()
 			if t.Failed() {
@@ -203,63 +221,118 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
 				os.Stderr.Write(data)
 				slog.Warn("END OF SERVER")
 			}
-			err = os.Remove(lifecycle.ServerLogFile)
+			err := os.Remove(lifecycle.ServerLogFile)
 			if err != nil && !os.IsNotExist(err) {
 				slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
 			}
 		}
-	}()
-	scheme, testEndpoint := GetTestEndpoint()
-
-	if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
-		serverProcMutex.Lock()
-		fp, err := os.CreateTemp("", "ollama-server-*.log")
-		if err != nil {
-			t.Fatalf("failed to generate log file: %s", err)
-		}
-		lifecycle.ServerLogFile = fp.Name()
-		fp.Close()
-		assert.NoError(t, StartServer(ctx, testEndpoint))
 	}
+}
 
-	err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model)
-	if err != nil {
-		t.Fatalf("Error pulling model: %v", err)
-	}
+func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+	require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
+	DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
+}
 
-	// Make the request and get the response
-	req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
-	if err != nil {
-		t.Fatalf("Error creating request: %v", err)
+func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
+	stallTimer := time.NewTimer(initialTimeout)
+	var buf bytes.Buffer
+	fn := func(response api.GenerateResponse) error {
+		// fmt.Print(".")
+		buf.Write([]byte(response.Response))
+		if !stallTimer.Reset(streamTimeout) {
+			return fmt.Errorf("stall was detected while streaming response, aborting")
+		}
+		return nil
 	}
 
-	// Set the content type for the request
-	req.Header.Set("Content-Type", "application/json")
+	stream := true
+	genReq.Stream = &stream
+	done := make(chan int)
+	var genErr error
+	go func() {
+		genErr = client.Generate(ctx, &genReq, fn)
+		done <- 0
+	}()
 
-	// Make the request with the HTTP client
-	response, err := client.Do(req.WithContext(ctx))
-	if err != nil {
-		t.Fatalf("Error making request: %v", err)
-	}
-	defer response.Body.Close()
-	body, err := io.ReadAll(response.Body)
-	assert.NoError(t, err)
-	assert.Equal(t, response.StatusCode, 200, string(body))
-
-	// Verify the response is valid JSON
-	var payload api.GenerateResponse
-	err = json.Unmarshal(body, &payload)
-	if err != nil {
-		assert.NoError(t, err, body)
+	select {
+	case <-stallTimer.C:
+		if buf.Len() == 0 {
+			t.Errorf("generate never started.  Timed out after :%s", initialTimeout.String())
+		} else {
+			t.Errorf("generate stalled.  Response so far:%s", buf.String())
+		}
+	case <-done:
+		require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
+		// Verify the response contains the expected data
+		response := buf.String()
+		atLeastOne := false
+		for _, resp := range anyResp {
+			if strings.Contains(strings.ToLower(response), resp) {
+				atLeastOne = true
+				break
+			}
+		}
+		require.True(t, atLeastOne, "none of %v found in %s", anyResp, response)
+		slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
+	case <-ctx.Done():
+		t.Error("outer test context done while waiting for generate")
 	}
+}
 
-	// Verify the response contains the expected data
-	atLeastOne := false
-	for _, resp := range anyResp {
-		if strings.Contains(strings.ToLower(payload.Response), resp) {
-			atLeastOne = true
-			break
+// Generate a set of requests
+// By default each request uses orca-mini as the model
+func GenerateRequests() ([]api.GenerateRequest, [][]string) {
+	return []api.GenerateRequest{
+			{
+				Model:  "orca-mini",
+				Prompt: "why is the ocean blue?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "orca-mini",
+				Prompt: "why is the color of dirt brown?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "orca-mini",
+				Prompt: "what is the origin of the us thanksgiving holiday?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "orca-mini",
+				Prompt: "what is the origin of independence day?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			}, {
+				Model:  "orca-mini",
+				Prompt: "what is the composition of air?",
+				Stream: &stream,
+				Options: map[string]interface{}{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+			},
+		},
+		[][]string{
+			[]string{"sunlight"},
+			[]string{"soil", "organic", "earth", "black", "tan"},
+			[]string{"england", "english", "massachusetts", "pilgrims"},
+			[]string{"fourth", "july", "declaration", "independence"},
+			[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
 		}
-	}
-	assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response)
 }

+ 9 - 8
llm/ext_server/server.cpp

@@ -1032,7 +1032,7 @@ struct llama_server_context
             slot.has_next_token = false;
         }
 
-        if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model))
+        if (!slot.cache_tokens.empty() && llama_token_is_eog(model, result.tok))
         {
             slot.stopped_eos = true;
             slot.has_next_token = false;
@@ -1144,12 +1144,15 @@ struct llama_server_context
 
         res.result_json = json
         {
-            {"content",    tkn.text_to_send},
             {"stop",       false},
             {"slot_id",    slot.id},
             {"multimodal", multimodal}
         };
 
+        if (!llama_token_is_eog(model, tkn.tok)) {
+            res.result_json["content"] = tkn.text_to_send;
+        }
+
         if (slot.sparams.n_probs > 0)
         {
             std::vector<completion_token_output> probs_output = {};
@@ -1183,8 +1186,6 @@ struct llama_server_context
             {"model",               params.model_alias},
             {"tokens_predicted",    slot.n_decoded},
             {"tokens_evaluated",    slot.n_prompt_tokens},
-            {"generation_settings", get_formated_generation(slot)},
-            {"prompt",              slot.prompt},
             {"truncated",           slot.truncated},
             {"stopped_eos",         slot.stopped_eos},
             {"stopped_word",        slot.stopped_word},
@@ -2644,18 +2645,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             if (strncmp(sep, "int:", 4) == 0) {
                 sep += 4;
                 kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
-                kvo.int_value = std::atol(sep);
+                kvo.val_i64 = std::atol(sep);
             } else if (strncmp(sep, "float:", 6) == 0) {
                 sep += 6;
                 kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
-                kvo.float_value = std::atof(sep);
+                kvo.val_f64 = std::atof(sep);
             } else if (strncmp(sep, "bool:", 5) == 0) {
                 sep += 5;
                 kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
                 if (std::strcmp(sep, "true") == 0) {
-                    kvo.bool_value = true;
+                    kvo.val_bool = true;
                 } else if (std::strcmp(sep, "false") == 0) {
-                    kvo.bool_value = false;
+                    kvo.val_bool = false;
                 } else {
                     fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
                     invalid_param = true;

+ 140 - 0
llm/filetype.go

@@ -0,0 +1,140 @@
+package llm
+
+import "fmt"
+
+type fileType uint32
+
+const (
+	fileTypeF32 fileType = iota
+	fileTypeF16
+	fileTypeQ4_0
+	fileTypeQ4_1
+	fileTypeQ4_1_F16
+	fileTypeQ4_2 // unused
+	fileTypeQ4_3 // unused
+	fileTypeQ8_0
+	fileTypeQ5_0
+	fileTypeQ5_1
+	fileTypeQ2_K
+	fileTypeQ3_K_S
+	fileTypeQ3_K_M
+	fileTypeQ3_K_L
+	fileTypeQ4_K_S
+	fileTypeQ4_K_M
+	fileTypeQ5_K_S
+	fileTypeQ5_K_M
+	fileTypeQ6_K
+	fileTypeIQ2_XXS
+	fileTypeIQ2_XS
+	fileTypeQ2_K_S
+	fileTypeQ3_K_XS
+	fileTypeIQ3_XXS
+
+	fileTypeUnknown
+)
+
+func ParseFileType(s string) (fileType, error) {
+	switch s {
+	case "F32":
+		return fileTypeF32, nil
+	case "F16":
+		return fileTypeF16, nil
+	case "Q4_0":
+		return fileTypeQ4_0, nil
+	case "Q4_1":
+		return fileTypeQ4_1, nil
+	case "Q4_1_F16":
+		return fileTypeQ4_1_F16, nil
+	case "Q8_0":
+		return fileTypeQ8_0, nil
+	case "Q5_0":
+		return fileTypeQ5_0, nil
+	case "Q5_1":
+		return fileTypeQ5_1, nil
+	case "Q2_K":
+		return fileTypeQ2_K, nil
+	case "Q3_K_S":
+		return fileTypeQ3_K_S, nil
+	case "Q3_K_M":
+		return fileTypeQ3_K_M, nil
+	case "Q3_K_L":
+		return fileTypeQ3_K_L, nil
+	case "Q4_K_S":
+		return fileTypeQ4_K_S, nil
+	case "Q4_K_M":
+		return fileTypeQ4_K_M, nil
+	case "Q5_K_S":
+		return fileTypeQ5_K_S, nil
+	case "Q5_K_M":
+		return fileTypeQ5_K_M, nil
+	case "Q6_K":
+		return fileTypeQ6_K, nil
+	case "IQ2_XXS":
+		return fileTypeIQ2_XXS, nil
+	case "IQ2_XS":
+		return fileTypeIQ2_XS, nil
+	case "Q2_K_S":
+		return fileTypeQ2_K_S, nil
+	case "Q3_K_XS":
+		return fileTypeQ3_K_XS, nil
+	case "IQ3_XXS":
+		return fileTypeIQ3_XXS, nil
+	default:
+		return fileTypeUnknown, fmt.Errorf("unknown fileType: %s", s)
+	}
+}
+
+func (t fileType) String() string {
+	switch t {
+	case fileTypeF32:
+		return "F32"
+	case fileTypeF16:
+		return "F16"
+	case fileTypeQ4_0:
+		return "Q4_0"
+	case fileTypeQ4_1:
+		return "Q4_1"
+	case fileTypeQ4_1_F16:
+		return "Q4_1_F16"
+	case fileTypeQ8_0:
+		return "Q8_0"
+	case fileTypeQ5_0:
+		return "Q5_0"
+	case fileTypeQ5_1:
+		return "Q5_1"
+	case fileTypeQ2_K:
+		return "Q2_K"
+	case fileTypeQ3_K_S:
+		return "Q3_K_S"
+	case fileTypeQ3_K_M:
+		return "Q3_K_M"
+	case fileTypeQ3_K_L:
+		return "Q3_K_L"
+	case fileTypeQ4_K_S:
+		return "Q4_K_S"
+	case fileTypeQ4_K_M:
+		return "Q4_K_M"
+	case fileTypeQ5_K_S:
+		return "Q5_K_S"
+	case fileTypeQ5_K_M:
+		return "Q5_K_M"
+	case fileTypeQ6_K:
+		return "Q6_K"
+	case fileTypeIQ2_XXS:
+		return "IQ2_XXS"
+	case fileTypeIQ2_XS:
+		return "IQ2_XS"
+	case fileTypeQ2_K_S:
+		return "Q2_K_S"
+	case fileTypeQ3_K_XS:
+		return "Q3_K_XS"
+	case fileTypeIQ3_XXS:
+		return "IQ3_XXS"
+	default:
+		return "unknown"
+	}
+}
+
+func (t fileType) Value() uint32 {
+	return uint32(t)
+}

+ 1 - 1
llm/generate/gen_common.sh

@@ -21,7 +21,7 @@ init_vars() {
         # TODO - add additional optimization flags...
         CMAKE_DEFS="-DCMAKE_BUILD_TYPE=Release -DLLAMA_SERVER_VERBOSE=off ${CMAKE_DEFS}"
     fi
-    case $(uname -s) in 
+    case $(uname -s) in
     "Darwin")
         LIB_EXT="dylib"
         WHOLE_ARCHIVE="-Wl,-force_load"

+ 30 - 16
llm/generate/gen_linux.sh

@@ -57,21 +57,21 @@ init_vars
 git_module_setup
 apply_patches
 
+init_vars
+if [ -z "${OLLAMA_SKIP_STATIC_GENERATE}" -o "${OLLAMA_CPU_TARGET}" = "static" ]; then
+    # Builds by default, allows skipping, forces build if OLLAMA_CPU_TARGET="static"
+    # Enables optimized Dockerfile builds using a blanket skip and targeted overrides
+    # Static build for linking into the Go binary
+    init_vars
+    CMAKE_TARGETS="--target llama --target ggml"
+    CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+    BUILD_DIR="../build/linux/${ARCH}_static"
+    echo "Building static library"
+    build
+fi
 
 init_vars
 if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
-
-    if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "static" ]; then
-        # Static build for linking into the Go binary
-        init_vars
-        CMAKE_TARGETS="--target llama --target ggml"
-        CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
-        BUILD_DIR="../build/linux/${ARCH}_static"
-        echo "Building static library"
-        build
-    fi
-
-
     # Users building from source can tune the exact flags we pass to cmake for configuring
     # llama.cpp, and we'll build only 1 CPU variant in that case as the default.
     if [ -n "${OLLAMA_CUSTOM_CPU_DEFS}" ]; then
@@ -165,14 +165,22 @@ if [ -d "${CUDA_LIB_DIR}" ]; then
     fi
     if [ "${ARCH}" == "arm64" ]; then
         echo "ARM CPU detected - disabling unsupported AVX instructions"
-        
+
         # ARM-based CPUs such as M1 and Tegra do not support AVX extensions.
         #
-        # CUDA compute < 6.0 lacks proper FP16 support on ARM. 
-        # Disabling has minimal performance effect while maintaining compatibility. 
+        # CUDA compute < 6.0 lacks proper FP16 support on ARM.
+        # Disabling has minimal performance effect while maintaining compatibility.
         ARM64_DEFS="-DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_CUDA_F16=off"
     fi
-    CMAKE_DEFS="-DLLAMA_CUDA=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS}"
+    # Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
+    if [ -n "${OLLAMA_CUSTOM_CUDA_DEFS}" ]; then
+        echo "OLLAMA_CUSTOM_CUDA_DEFS=\"${OLLAMA_CUSTOM_CUDA_DEFS}\""
+        CMAKE_CUDA_DEFS="-DLLAMA_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}"
+        echo "Building custom CUDA GPU"
+    else
+        CMAKE_CUDA_DEFS="-DLLAMA_CUDA=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}"
+    fi
+    CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS} ${CMAKE_CUDA_DEFS}"
     BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
     EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda"
     build
@@ -217,6 +225,12 @@ if [ -d "${ROCM_PATH}" ]; then
     fi
     init_vars
     CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
+    # Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
+    if [ -n "${OLLAMA_CUSTOM_ROCM_DEFS}" ]; then
+        echo "OLLAMA_CUSTOM_ROCM_DEFS=\"${OLLAMA_CUSTOM_ROCM_DEFS}\""
+        CMAKE_DEFS="${CMAKE_DEFS} ${OLLAMA_CUSTOM_ROCM_DEFS}"
+        echo "Building custom ROCM GPU"
+    fi
     BUILD_DIR="../build/linux/${ARCH}/rocm${ROCM_VARIANT}"
     EXTRA_LIBS="-L${ROCM_PATH}/lib -L/opt/amdgpu/lib/x86_64-linux-gnu/ -Wl,-rpath,\$ORIGIN/../../rocm/ -lhipblas -lrocblas -lamdhip64 -lrocsolver -lamd_comgr -lhsa-runtime64 -lrocsparse -ldrm -ldrm_amdgpu"
     build

+ 190 - 112
llm/generate/gen_windows.ps1

@@ -26,15 +26,25 @@ function amdGPUs {
     $GPU_LIST -join ';'
 }
 
+
 function init_vars {
-    $script:SRC_DIR = $(resolve-path "..\..\")
-    $script:llamacppDir = "../llama.cpp"
+    if (!$script:SRC_DIR) {
+        $script:SRC_DIR = $(resolve-path "..\..\")
+    }
+    if (!$script:llamacppDir) {
+        $script:llamacppDir = "../llama.cpp"
+    }
+    if (!$script:cmakeTargets) {
+        $script:cmakeTargets = @("ollama_llama_server")
+    }
     $script:cmakeDefs = @(
         "-DBUILD_SHARED_LIBS=on",
         "-DLLAMA_NATIVE=off"
         )
-    $script:cmakeTargets = @("ollama_llama_server")
-    $script:ARCH = "amd64" # arm not yet supported.
+    $script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
+    $script:ARCH = $Env:PROCESSOR_ARCHITECTURE.ToLower()
+    $script:DIST_BASE = "${script:SRC_DIR}\dist\windows-${script:ARCH}\ollama_runners"
+    md "$script:DIST_BASE" -ea 0 > $null
     if ($env:CGO_CFLAGS -contains "-g") {
         $script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on", "-DLLAMA_SERVER_VERBOSE=on", "-DCMAKE_BUILD_TYPE=RelWithDebInfo")
         $script:config = "RelWithDebInfo"
@@ -55,7 +65,6 @@ function init_vars {
     } else {
         $script:CUDA_LIB_DIR=$env:CUDA_LIB_DIR
     }
-    $script:GZIP=(get-command -ea 'silentlycontinue' gzip).path
     $script:DUMPBIN=(get-command -ea 'silentlycontinue' dumpbin).path
     if ($null -eq $env:CMAKE_CUDA_ARCHITECTURES) {
         $script:CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
@@ -134,21 +143,18 @@ function sign {
     }
 }
 
-function compress {
-    if ($script:GZIP -eq $null) {
-        write-host "gzip not installed, not compressing files"
-        return
-    }
-    write-host "Compressing binaries..."
+function install {
+    write-host "Installing binaries to dist dir ${script:distDir}"
+    mkdir ${script:distDir} -ErrorAction SilentlyContinue
     $binaries = dir "${script:buildDir}/bin/*.exe"
     foreach ($file in $binaries) {
-        & "$script:GZIP" --best -f $file
+        copy-item -Path $file -Destination ${script:distDir} -Force
     }
 
-    write-host "Compressing dlls..."
+    write-host "Installing dlls to dist dir ${script:distDir}"
     $dlls = dir "${script:buildDir}/bin/*.dll"
     foreach ($file in $dlls) {
-        & "$script:GZIP" --best -f $file
+        copy-item -Path $file -Destination ${script:distDir} -Force
     }
 }
 
@@ -169,123 +175,195 @@ function cleanup {
     }
 }
 
-init_vars
-git_module_setup
-apply_patches
 
 # -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
 # -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
 # -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
 
-$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
-
-if ($null -eq ${env:OLLAMA_SKIP_CPU_GENERATE}) {
-
-# GCC build for direct linking into the Go binary
-init_vars
-# cmake will silently fallback to msvc compilers if mingw isn't in the path, so detect and fail fast
-# as we need this to be compiled by gcc for golang to be able to link with itx
-write-host "Checking for MinGW..."
-# error action ensures we exit on failure
-get-command gcc
-get-command mingw32-make
-$script:cmakeTargets = @("llama", "ggml")
-$script:cmakeDefs = @(
-    "-G", "MinGW Makefiles"
-    "-DCMAKE_C_COMPILER=gcc.exe",
-    "-DCMAKE_CXX_COMPILER=g++.exe",
-    "-DBUILD_SHARED_LIBS=off",
-    "-DLLAMA_NATIVE=off",
-    "-DLLAMA_AVX=off",
-    "-DLLAMA_AVX2=off",
-    "-DLLAMA_AVX512=off",
-    "-DLLAMA_F16C=off",
-    "-DLLAMA_FMA=off")
-$script:buildDir="../build/windows/${script:ARCH}_static"
-write-host "Building static library"
-build
 
-# remaining llama.cpp builds use MSVC 
-    init_vars
-    $script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
-    $script:buildDir="../build/windows/${script:ARCH}/cpu"
-    write-host "Building LCD CPU"
-    build
-    sign
-    compress
+function build_static() {
+    if ((-not "${env:OLLAMA_SKIP_STATIC_GENERATE}") -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "static"))) {
+        # GCC build for direct linking into the Go binary
+        init_vars
+        # cmake will silently fallback to msvc compilers if mingw isn't in the path, so detect and fail fast
+        # as we need this to be compiled by gcc for golang to be able to link with itx
+        write-host "Checking for MinGW..."
+        # error action ensures we exit on failure
+        get-command gcc
+        get-command mingw32-make
+        $oldTargets = $script:cmakeTargets
+        $script:cmakeTargets = @("llama", "ggml")
+        $script:cmakeDefs = @(
+            "-G", "MinGW Makefiles"
+            "-DCMAKE_C_COMPILER=gcc.exe",
+            "-DCMAKE_CXX_COMPILER=g++.exe",
+            "-DBUILD_SHARED_LIBS=off",
+            "-DLLAMA_NATIVE=off",
+            "-DLLAMA_AVX=off",
+            "-DLLAMA_AVX2=off",
+            "-DLLAMA_AVX512=off",
+            "-DLLAMA_F16C=off",
+            "-DLLAMA_FMA=off")
+        $script:buildDir="../build/windows/${script:ARCH}_static"
+        write-host "Building static library"
+        build
+        $script:cmakeTargets = $oldTargets
+    } else {
+        write-host "Skipping CPU generation step as requested"
+    }
+}
 
-    init_vars
-    $script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
-    $script:buildDir="../build/windows/${script:ARCH}/cpu_avx"
-    write-host "Building AVX CPU"
-    build
-    sign
-    compress
+function build_cpu($gen_arch) {
+    if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu"))) {
+        # remaining llama.cpp builds use MSVC 
+        init_vars
+        $script:cmakeDefs = $script:commonCpuDefs + @("-A", $gen_arch, "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
+        $script:buildDir="../build/windows/${script:ARCH}/cpu"
+        $script:distDir="$script:DIST_BASE\cpu"
+        write-host "Building LCD CPU"
+        build
+        sign
+        install
+    } else {
+        write-host "Skipping CPU generation step as requested"
+    }
+}
 
-    init_vars
-    $script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
-    $script:buildDir="../build/windows/${script:ARCH}/cpu_avx2"
-    write-host "Building AVX2 CPU"
-    build
-    sign
-    compress
-} else {
-    write-host "Skipping CPU generation step as requested"
+function build_cpu_avx() {
+    if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx"))) {
+        init_vars
+        $script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
+        $script:buildDir="../build/windows/${script:ARCH}/cpu_avx"
+        $script:distDir="$script:DIST_BASE\cpu_avx"
+        write-host "Building AVX CPU"
+        build
+        sign
+        install
+    } else {
+        write-host "Skipping CPU AVX generation step as requested"
+    }
 }
 
-if ($null -ne $script:CUDA_LIB_DIR) {
-    # Then build cuda as a dynamically loaded library
-    $nvcc = "$script:CUDA_LIB_DIR\nvcc.exe"
-    $script:CUDA_VERSION=(get-item ($nvcc | split-path | split-path)).Basename
-    if ($null -ne $script:CUDA_VERSION) {
-        $script:CUDA_VARIANT="_"+$script:CUDA_VERSION
+function build_cpu_avx2() {
+    if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx2"))) {
+        init_vars
+        $script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
+        $script:buildDir="../build/windows/${script:ARCH}/cpu_avx2"
+        $script:distDir="$script:DIST_BASE\cpu_avx2"
+        write-host "Building AVX2 CPU"
+        build
+        sign
+        install
+    } else {
+        write-host "Skipping CPU AVX2 generation step as requested"
     }
-    init_vars
-    $script:buildDir="../build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
-    $script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUDA=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
-    build
-    sign
-    compress
 }
 
-if ($null -ne $env:HIP_PATH) {
-    $script:ROCM_VERSION=(get-item $env:HIP_PATH).Basename
-    if ($null -ne $script:ROCM_VERSION) {
-        $script:ROCM_VARIANT="_v"+$script:ROCM_VERSION
+function build_cuda() {
+    if ((-not "${env:OLLAMA_SKIP_CUDA_GENERATE}") -and ("${script:CUDA_LIB_DIR}")) {
+        # Then build cuda as a dynamically loaded library
+        $nvcc = "$script:CUDA_LIB_DIR\nvcc.exe"
+        $script:CUDA_VERSION=(get-item ($nvcc | split-path | split-path)).Basename
+        if ($null -ne $script:CUDA_VERSION) {
+            $script:CUDA_VARIANT="_"+$script:CUDA_VERSION
+        }
+        init_vars
+        $script:buildDir="../build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
+        $script:distDir="$script:DIST_BASE\cuda$script:CUDA_VARIANT"
+        $script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUDA=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
+        if ($null -ne $env:OLLAMA_CUSTOM_CUDA_DEFS) {
+            write-host "OLLAMA_CUSTOM_CUDA_DEFS=`"${env:OLLAMA_CUSTOM_CUDA_DEFS}`""
+            $script:cmakeDefs +=@("${env:OLLAMA_CUSTOM_CUDA_DEFS}")
+            write-host "building custom CUDA GPU"
+        }
+        build
+        sign
+        install
+
+        write-host "copying CUDA dependencies to ${script:SRC_DIR}\dist\windows-${script:ARCH}\"
+        cp "${script:CUDA_LIB_DIR}\cudart64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
+        cp "${script:CUDA_LIB_DIR}\cublas64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
+        cp "${script:CUDA_LIB_DIR}\cublasLt64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
+    } else {
+        write-host "Skipping CUDA generation step"
     }
+}
 
-    init_vars
-    $script:buildDir="../build/windows/${script:ARCH}/rocm$script:ROCM_VARIANT"
-    $script:cmakeDefs += @(
-        "-G", "Ninja", 
-        "-DCMAKE_C_COMPILER=clang.exe",
-        "-DCMAKE_CXX_COMPILER=clang++.exe",
-        "-DLLAMA_HIPBLAS=on",
-        "-DHIP_PLATFORM=amd",
-        "-DLLAMA_AVX=on",
-        "-DLLAMA_AVX2=off",
-        "-DCMAKE_POSITION_INDEPENDENT_CODE=on",
-        "-DAMDGPU_TARGETS=$(amdGPUs)",
-        "-DGPU_TARGETS=$(amdGPUs)"
-        )
+function build_rocm() {
+    if ((-not "${env:OLLAMA_SKIP_ROCM_GENERATE}") -and ("${env:HIP_PATH}")) {
+        $script:ROCM_VERSION=(get-item $env:HIP_PATH).Basename
+        if ($null -ne $script:ROCM_VERSION) {
+            $script:ROCM_VARIANT="_v"+$script:ROCM_VERSION
+        }
+
+        init_vars
+        $script:buildDir="../build/windows/${script:ARCH}/rocm$script:ROCM_VARIANT"
+        $script:distDir="$script:DIST_BASE\rocm$script:ROCM_VARIANT"
+        $script:cmakeDefs += @(
+            "-G", "Ninja", 
+            "-DCMAKE_C_COMPILER=clang.exe",
+            "-DCMAKE_CXX_COMPILER=clang++.exe",
+            "-DLLAMA_HIPBLAS=on",
+            "-DHIP_PLATFORM=amd",
+            "-DLLAMA_AVX=on",
+            "-DLLAMA_AVX2=off",
+            "-DCMAKE_POSITION_INDEPENDENT_CODE=on",
+            "-DAMDGPU_TARGETS=$(amdGPUs)",
+            "-DGPU_TARGETS=$(amdGPUs)"
+            )
 
-    # Make sure the ROCm binary dir is first in the path
-    $env:PATH="$env:HIP_PATH\bin;$env:PATH"
+        # Make sure the ROCm binary dir is first in the path
+        $env:PATH="$env:HIP_PATH\bin;$env:PATH"
 
-    # We have to clobber the LIB var from the developer shell for clang to work properly
-    $env:LIB=""
+        # We have to clobber the LIB var from the developer shell for clang to work properly
+        $env:LIB=""
+        if ($null -ne $env:OLLAMA_CUSTOM_ROCM_DEFS) {
+            write-host "OLLAMA_CUSTOM_ROCM_DEFS=`"${env:OLLAMA_CUSTOM_ROCM_DEFS}`""
+            $script:cmakeDefs += @("${env:OLLAMA_CUSTOM_ROCM_DEFS}")
+            write-host "building custom ROCM GPU"
+        }
+        write-host "Building ROCm"
+        build
+        # Ninja doesn't prefix with config name
+        ${script:config}=""
+        if ($null -ne $script:DUMPBIN) {
+            & "$script:DUMPBIN" /dependents "${script:buildDir}/bin/ollama_llama_server.exe" | select-string ".dll"
+        }
+        sign
+        install
 
-    write-host "Building ROCm"
-    build
-    # Ninja doesn't prefix with config name
-    ${script:config}=""
-    if ($null -ne $script:DUMPBIN) {
-        & "$script:DUMPBIN" /dependents "${script:buildDir}/bin/ollama_llama_server.exe" | select-string ".dll"
+        # Assumes v5.7, may need adjustments for v6
+        rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
+        md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null
+        cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
+        cp "${env:HIP_PATH}\bin\rocblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
+        # amdhip64.dll dependency comes from the driver and must be installed on the host to use AMD GPUs
+        cp "${env:HIP_PATH}\bin\rocblas\library\*" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\"
+    } else {
+        write-host "Skipping ROCm generation step"
     }
-    sign
-    compress
 }
 
+init_vars
+if ($($args.count) -eq 0) {
+    git_module_setup
+    apply_patches
+    build_static
+    if ($script:ARCH -eq "arm64") {
+        build_cpu("ARM64")
+    } else { # amd64
+        build_cpu("x64")
+        build_cpu_avx
+        build_cpu_avx2
+        build_cuda
+        build_rocm
+    }
 
-cleanup
-write-host "`ngo generate completed.  LLM runners: $(get-childitem -path ${script:SRC_DIR}\llm\build\windows\${script:ARCH})"
+    cleanup
+    write-host "`ngo generate completed.  LLM runners: $(get-childitem -path $script:DIST_BASE)"
+} else {
+    for ( $i = 0; $i -lt $args.count; $i++ ) {
+        write-host "performing $($args[$i])"
+        & $($args[$i])
+    } 
+}

+ 28 - 79
llm/ggml.go

@@ -13,82 +13,6 @@ type GGML struct {
 	model
 }
 
-const (
-	fileTypeF32 uint32 = iota
-	fileTypeF16
-	fileTypeQ4_0
-	fileTypeQ4_1
-	fileTypeQ4_1_F16
-	fileTypeQ8_0 uint32 = iota + 2
-	fileTypeQ5_0
-	fileTypeQ5_1
-	fileTypeQ2_K
-	fileTypeQ3_K_S
-	fileTypeQ3_K_M
-	fileTypeQ3_K_L
-	fileTypeQ4_K_S
-	fileTypeQ4_K_M
-	fileTypeQ5_K_S
-	fileTypeQ5_K_M
-	fileTypeQ6_K
-	fileTypeIQ2_XXS
-	fileTypeIQ2_XS
-	fileTypeQ2_K_S
-	fileTypeQ3_K_XS
-	fileTypeIQ3_XXS
-)
-
-func fileType(fileType uint32) string {
-	switch fileType {
-	case fileTypeF32:
-		return "F32"
-	case fileTypeF16:
-		return "F16"
-	case fileTypeQ4_0:
-		return "Q4_0"
-	case fileTypeQ4_1:
-		return "Q4_1"
-	case fileTypeQ4_1_F16:
-		return "Q4_1_F16"
-	case fileTypeQ8_0:
-		return "Q8_0"
-	case fileTypeQ5_0:
-		return "Q5_0"
-	case fileTypeQ5_1:
-		return "Q5_1"
-	case fileTypeQ2_K:
-		return "Q2_K"
-	case fileTypeQ3_K_S:
-		return "Q3_K_S"
-	case fileTypeQ3_K_M:
-		return "Q3_K_M"
-	case fileTypeQ3_K_L:
-		return "Q3_K_L"
-	case fileTypeQ4_K_S:
-		return "Q4_K_S"
-	case fileTypeQ4_K_M:
-		return "Q4_K_M"
-	case fileTypeQ5_K_S:
-		return "Q5_K_S"
-	case fileTypeQ5_K_M:
-		return "Q5_K_M"
-	case fileTypeQ6_K:
-		return "Q6_K"
-	case fileTypeIQ2_XXS:
-		return "IQ2_XXS"
-	case fileTypeIQ2_XS:
-		return "IQ2_XS"
-	case fileTypeQ2_K_S:
-		return "Q2_K_S"
-	case fileTypeQ3_K_XS:
-		return "Q3_K_XS"
-	case fileTypeIQ3_XXS:
-		return "IQ3_XXS"
-	default:
-		return "unknown"
-	}
-}
-
 type model interface {
 	KV() KV
 	Tensors() Tensors
@@ -121,12 +45,12 @@ func (kv KV) ParameterCount() uint64 {
 	return kv.u64("general.parameter_count")
 }
 
-func (kv KV) FileType() string {
+func (kv KV) FileType() fileType {
 	if u64 := kv.u64("general.file_type"); u64 > 0 {
 		return fileType(uint32(u64))
 	}
 
-	return "unknown"
+	return fileTypeUnknown
 }
 
 func (kv KV) BlockCount() uint64 {
@@ -286,6 +210,23 @@ const (
 
 var ErrUnsupportedFormat = errors.New("unsupported model format")
 
+func DetectGGMLType(b []byte) string {
+	switch binary.LittleEndian.Uint32(b[:4]) {
+	case FILE_MAGIC_GGML:
+		return "ggml"
+	case FILE_MAGIC_GGMF:
+		return "ggmf"
+	case FILE_MAGIC_GGJT:
+		return "ggjt"
+	case FILE_MAGIC_GGLA:
+		return "ggla"
+	case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE:
+		return "gguf"
+	default:
+		return ""
+	}
+}
+
 func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
 	var magic uint32
 	if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
@@ -343,7 +284,15 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
 		)
 
-		if ffnGateWeight, ok := layers["0"]["ffn_gate.0.weight"]; ok {
+		if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
+			// mixtral 8x22b
+			ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
+			partialOffload = max(
+				3*ffnGateExpsWeight.size()+4*batch*(2*ff+headsKV+embedding+context+embedding/heads*headsKV),
+				4*(context*batch*heads+context*embedding/heads*headsKV+batch*1024+embedding/heads*headsKV*batch),
+			)
+		} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
+			// mixtral 8x7b
 			ffnGateWeight1 := ffnGateWeight.Shape[1]
 			fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
 			partialOffload = max(

+ 10 - 6
llm/gguf.go

@@ -190,8 +190,6 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
 		llm.kv[k] = v
 	}
 
-	slog.Debug(fmt.Sprintf("general.architecture = %s", llm.kv["general.architecture"]))
-
 	// decode tensors
 	for i := 0; uint64(i) < llm.numTensor(); i++ {
 		name, err := readGGUFString(llm, rs)
@@ -465,11 +463,13 @@ var ggufKVOrder = map[string][]string{
 		"llama.embedding_length",
 		"llama.block_count",
 		"llama.feed_forward_length",
-		"llama.rope.dimension_count",
 		"llama.attention.head_count",
 		"llama.attention.head_count_kv",
 		"llama.attention.layer_norm_rms_epsilon",
 		"llama.rope.freq_base",
+		"llama.rope.dimension_count",
+		"llama.expert_count",
+		"llama.expert_used_count",
 		"gemma.context_length",
 		"gemma.embedding_length",
 		"gemma.block_count",
@@ -577,6 +577,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
 					return err
 				}
 			}
+		default:
+			return fmt.Errorf("improper type for '%s'", k)
 		}
 		if err != nil {
 			return err
@@ -598,9 +600,11 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
 			return err
 		}
 
-		dims := 1
-		if tensor.Shape[1] > 0 {
-			dims = 2
+		dims := 0
+		for cnt := 0; cnt < len(tensor.Shape); cnt++ {
+			if tensor.Shape[cnt] > 0 {
+				dims++
+			}
 		}
 
 		if err := binary.Write(ws, llm.ByteOrder, uint32(dims)); err != nil {

+ 1 - 1
llm/llama.cpp

@@ -1 +1 @@
-Subproject commit 7593639ce335e8d7f89aa9a54d616951f273af60
+Subproject commit 952d03dbead16e4dbdd1d3458486340673cc2465

+ 5 - 52
llm/llm.go

@@ -4,6 +4,7 @@ package llm
 // #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/build/darwin/arm64_static/libllama.a -lstdc++
 // #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/build/darwin/x86_64_static/libllama.a -lstdc++
 // #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++
+// #cgo windows,arm64 LDFLAGS: ${SRCDIR}/build/windows/arm64_static/libllama.a -static -lstdc++
 // #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++
 // #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++
 // #include <stdlib.h>
@@ -19,7 +20,7 @@ func SystemInfo() string {
 	return C.GoString(C.llama_print_system_info())
 }
 
-func Quantize(infile, outfile, filetype string) error {
+func Quantize(infile, outfile string, ftype fileType) error {
 	cinfile := C.CString(infile)
 	defer C.free(unsafe.Pointer(cinfile))
 
@@ -28,58 +29,10 @@ func Quantize(infile, outfile, filetype string) error {
 
 	params := C.llama_model_quantize_default_params()
 	params.nthread = -1
+	params.ftype = ftype.Value()
 
-	switch filetype {
-	case "F32":
-		params.ftype = fileTypeF32
-	case "F16":
-		params.ftype = fileTypeF16
-	case "Q4_0":
-		params.ftype = fileTypeQ4_0
-	case "Q4_1":
-		params.ftype = fileTypeQ4_1
-	case "Q4_1_F16":
-		params.ftype = fileTypeQ4_1_F16
-	case "Q8_0":
-		params.ftype = fileTypeQ8_0
-	case "Q5_0":
-		params.ftype = fileTypeQ5_0
-	case "Q5_1":
-		params.ftype = fileTypeQ5_1
-	case "Q2_K":
-		params.ftype = fileTypeQ2_K
-	case "Q3_K_S":
-		params.ftype = fileTypeQ3_K_S
-	case "Q3_K_M":
-		params.ftype = fileTypeQ3_K_M
-	case "Q3_K_L":
-		params.ftype = fileTypeQ3_K_L
-	case "Q4_K_S":
-		params.ftype = fileTypeQ4_K_S
-	case "Q4_K_M":
-		params.ftype = fileTypeQ4_K_M
-	case "Q5_K_S":
-		params.ftype = fileTypeQ5_K_S
-	case "Q5_K_M":
-		params.ftype = fileTypeQ5_K_M
-	case "Q6_K":
-		params.ftype = fileTypeQ6_K
-	case "IQ2_XXS":
-		params.ftype = fileTypeIQ2_XXS
-	case "IQ2_XS":
-		params.ftype = fileTypeIQ2_XS
-	case "Q2_K_S":
-		params.ftype = fileTypeQ2_K_S
-	case "Q3_K_XS":
-		params.ftype = fileTypeQ3_K_XS
-	case "IQ3_XXS":
-		params.ftype = fileTypeIQ3_XXS
-	default:
-		return fmt.Errorf("unknown filetype: %s", filetype)
-	}
-
-	if retval := C.llama_model_quantize(cinfile, coutfile, &params); retval != 0 {
-		return fmt.Errorf("llama_model_quantize: %d", retval)
+	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
+		return fmt.Errorf("llama_model_quantize: %d", rc)
 	}
 
 	return nil

+ 1 - 1
llm/llm_windows.go

@@ -2,5 +2,5 @@ package llm
 
 import "embed"
 
-//go:embed build/windows/*/*/bin/*
+// unused on windows
 var libEmbed embed.FS

+ 185 - 0
llm/memory.go

@@ -0,0 +1,185 @@
+package llm
+
+import (
+	"fmt"
+	"log/slog"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/format"
+	"github.com/ollama/ollama/gpu"
+	"github.com/ollama/ollama/server/envconfig"
+)
+
+// This algorithm looks for a complete fit to determine if we need to unload other models
+func PredictServerFit(allGpus gpu.GpuInfoList, ggml *GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
+	var estimatedVRAM uint64
+	if opts.NumCtx > int(ggml.KV().ContextLength()) {
+		slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
+		opts.NumCtx = int(ggml.KV().ContextLength())
+	}
+
+	if opts.NumCtx < 4 {
+		opts.NumCtx = 4
+	}
+
+	// Split up the GPUs by type and try them
+	for _, gpus := range allGpus.ByLibrary() {
+		var layerCount int
+		layerCount, estimatedVRAM, _ = EstimateGPULayers(gpus, ggml, projectors, opts)
+		if opts.NumGPU < 0 {
+			if layerCount > 0 && layerCount >= int(ggml.KV().BlockCount()+1) {
+				return true, estimatedVRAM
+			}
+		} else {
+			if layerCount > 0 && layerCount >= opts.NumGPU {
+				return true, estimatedVRAM
+			}
+		}
+	}
+	return false, estimatedVRAM
+}
+
+// Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
+// The GPUs provided must all be the same Library
+func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts api.Options) (int, uint64, uint64) {
+	var memoryAvailable uint64
+	for _, info := range gpus {
+		memoryAvailable += info.FreeMemory
+	}
+	if envconfig.MaxVRAM > 0 {
+		memoryAvailable = envconfig.MaxVRAM
+	}
+
+	slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", format.HumanBytes2(memoryAvailable))
+
+	// TODO - this is probably wrong, first GPU vs secondaries will have different overheads
+	memoryMinimum := gpus[0].MinimumMemory
+
+	for _, projector := range projectors {
+		memoryMinimum += projectorMemoryRequirements(projector)
+
+		// multimodal models require at least 2048 context
+		opts.NumCtx = max(opts.NumCtx, 2048)
+	}
+
+	// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
+	var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
+
+	graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
+	if graphPartialOffload == 0 {
+		graphPartialOffload = ggml.KV().GQA() * kv / 6
+	}
+
+	if graphFullOffload == 0 {
+		graphFullOffload = graphPartialOffload
+	}
+
+	graphFullOffload *= uint64(len(gpus))
+	graphPartialOffload *= uint64(len(gpus))
+
+	// on metal there's no partial offload overhead
+	if gpus[0].Library == "metal" {
+		graphPartialOffload = graphFullOffload
+	}
+
+	layers := ggml.Tensors().Layers()
+
+	// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
+	memoryRequiredTotal := memoryMinimum + graphFullOffload + layers["blk.0"].size()
+
+	// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
+	memoryRequiredPartial := memoryMinimum + graphPartialOffload + layers["blk.0"].size()
+
+	var memoryLayerOutput uint64
+	if layer, ok := layers["output_norm"]; ok {
+		memoryLayerOutput += layer.size()
+	}
+
+	if layer, ok := layers["output"]; ok {
+		memoryLayerOutput += layer.size()
+	} else if layer, ok := layers["token_embd"]; ok {
+		memoryLayerOutput += layer.size()
+	}
+
+	if gpus[0].Library == "metal" && opts.UseMMap {
+		// memory is preallocated for output tensors
+		memoryRequiredTotal += memoryLayerOutput
+		memoryRequiredPartial += memoryLayerOutput
+	}
+
+	var layerCount int
+	for i := 0; i < int(ggml.KV().BlockCount()); i++ {
+		memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
+
+		// KV is proportional to the number of layers
+		memoryLayer += kv / ggml.KV().BlockCount()
+
+		memoryRequiredTotal += memoryLayer
+		if memoryAvailable > memoryRequiredPartial+memoryLayer {
+			memoryRequiredPartial += memoryLayer
+			layerCount++
+		}
+	}
+
+	if gpus[0].Library != "metal" || !opts.UseMMap {
+		// memory was not preallocated for output tensors
+		memoryRequiredTotal += memoryLayerOutput
+	}
+
+	if memoryAvailable > memoryRequiredTotal {
+		layerCount = int(ggml.KV().BlockCount()) + 1
+		memoryRequiredPartial = memoryRequiredTotal
+	}
+
+	memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
+
+	slog.Info(
+		"offload to gpu",
+		slog.Group(
+			"layers",
+			// actual number of layers offloaded
+			"real", opts.NumGPU,
+			// estimated number of layers that can be offloaded
+			"estimate", layerCount,
+		),
+		slog.Group(
+			"memory",
+			// memory available for offloading
+			"available", format.HumanBytes2(memoryAvailable),
+			slog.Group(
+				"required",
+				// memory required for full offloading
+				"full", format.HumanBytes2(memoryRequiredTotal),
+				// memory required to offload layers.estimate layers
+				"partial", format.HumanBytes2(memoryRequiredPartial),
+				// memory of KV cache
+				"kv", format.HumanBytes2(kv),
+			),
+			slog.Group(
+				"weights",
+				// memory of the weights
+				"total", format.HumanBytes2(memoryWeights),
+				// memory of repeating layers
+				"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
+				// memory of non-repeating layers
+				"nonrepeating", format.HumanBytes2(memoryLayerOutput),
+			),
+			slog.Group(
+				"graph",
+				// memory of graph when fully offloaded
+				"full", format.HumanBytes2(graphFullOffload),
+				// memory of graph when not fully offloaded
+				"partial", format.HumanBytes2(graphPartialOffload),
+			),
+		),
+	)
+	if gpus[0].Library == "cpu" {
+		return 0, 0, memoryRequiredTotal
+	}
+	if memoryRequiredPartial > memoryAvailable {
+		slog.Debug("insufficient VRAM to load any model layers")
+		return 0, 0, memoryRequiredTotal
+	}
+
+	return layerCount, memoryRequiredPartial, memoryRequiredTotal
+}

+ 12 - 0
llm/patches/02-clip-log.diff

@@ -0,0 +1,12 @@
+diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
+index e431c7f7..f077e688 100644
+--- a/examples/llava/clip.cpp
++++ b/examples/llava/clip.cpp
+@@ -3,6 +3,7 @@
+ // I'll gradually clean and extend it
+ // Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
+ #include "clip.h"
++#include "common.h"
+ #include "log.h"
+ #include "ggml.h"
+ #include "ggml-alloc.h"

+ 45 - 0
llm/patches/04-metal.diff

@@ -0,0 +1,45 @@
+diff --git a/ggml-metal.m b/ggml-metal.m
+index 0207b787..b5e9884b 100644
+--- a/ggml-metal.m
++++ b/ggml-metal.m
+@@ -1396,27 +1396,23 @@ static enum ggml_status ggml_metal_graph_compute(
+                         // to the matrix-vector kernel
+                         int ne11_mm_min = 1;
+ 
+-#if 0
+                         // the numbers below are measured on M2 Ultra for 7B and 13B models
+                         // these numbers do not translate to other devices or model sizes
+                         // TODO: need to find a better approach
+-                        if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
+-                            switch (src0t) {
+-                                case GGML_TYPE_F16:  ne11_mm_min = 2;  break;
+-                                case GGML_TYPE_Q8_0: ne11_mm_min = 7;  break;
+-                                case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
+-                                case GGML_TYPE_Q3_K: ne11_mm_min = 7;  break;
+-                                case GGML_TYPE_Q4_0:
+-                                case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
+-                                case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
+-                                case GGML_TYPE_Q5_0:                          // not tested yet
+-                                case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
+-                                case GGML_TYPE_Q5_K: ne11_mm_min = 7;  break;
+-                                case GGML_TYPE_Q6_K: ne11_mm_min = 7;  break;
+-                                default:             ne11_mm_min = 1;  break;
+-                            }
++                        switch (src0t) {
++                            case GGML_TYPE_F16:  ne11_mm_min = 2;  break;
++                            case GGML_TYPE_Q8_0: ne11_mm_min = 7;  break;
++                            case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
++                            case GGML_TYPE_Q3_K: ne11_mm_min = 7;  break;
++                            case GGML_TYPE_Q4_0:
++                            case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
++                            case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
++                            case GGML_TYPE_Q5_0:                          // not tested yet
++                            case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
++                            case GGML_TYPE_Q5_K: ne11_mm_min = 7;  break;
++                            case GGML_TYPE_Q6_K: ne11_mm_min = 7;  break;
++                            default:             ne11_mm_min = 1;  break;
+                         }
+-#endif
+ 
+                         // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+                         // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel

+ 24 - 0
llm/patches/05-clip-fix.diff

@@ -0,0 +1,24 @@
+diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
+index e3c9bcd4..b43f892d 100644
+--- a/examples/llava/clip.cpp
++++ b/examples/llava/clip.cpp
+@@ -573,14 +573,16 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
+     struct ggml_tensor * embeddings = inp;
+     if (ctx->has_class_embedding) {
+         embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
++    }
++    ggml_set_name(embeddings, "embeddings");
++    ggml_set_input(embeddings);
++
++    if (ctx->has_class_embedding) {
+         embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
+                 embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
+         embeddings = ggml_acc(ctx0, embeddings, inp,
+                 embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
+     }
+-    ggml_set_name(embeddings, "embeddings");
+-    ggml_set_input(embeddings);
+-
+ 
+     struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
+     ggml_set_name(positions, "positions");

+ 27 - 7
llm/payload.go

@@ -9,6 +9,7 @@ import (
 	"log/slog"
 	"os"
 	"path/filepath"
+	"runtime"
 	"strings"
 
 	"golang.org/x/exp/slices"
@@ -17,7 +18,7 @@ import (
 	"github.com/ollama/ollama/gpu"
 )
 
-var errPayloadMissing = fmt.Errorf("expected payloads not included in this build of ollama")
+var errPayloadMissing = errors.New("expected payloads not included in this build of ollama")
 
 func Init() error {
 	payloadsDir, err := gpu.PayloadsDir()
@@ -25,13 +26,15 @@ func Init() error {
 		return err
 	}
 
-	slog.Info("extracting embedded files", "dir", payloadsDir)
-	binGlob := "build/*/*/*/bin/*"
+	if runtime.GOOS != "windows" {
+		slog.Info("extracting embedded files", "dir", payloadsDir)
+		binGlob := "build/*/*/*/bin/*"
 
-	// extract server libraries
-	err = extractFiles(payloadsDir, binGlob)
-	if err != nil {
-		return fmt.Errorf("extract binaries: %v", err)
+		// extract server libraries
+		err = extractFiles(payloadsDir, binGlob)
+		if err != nil {
+			return fmt.Errorf("extract binaries: %v", err)
+		}
 	}
 
 	var variants []string
@@ -138,6 +141,23 @@ func serversForGpu(info gpu.GpuInfo) []string {
 	return servers
 }
 
+// Return the optimal server for this CPU architecture
+func serverForCpu() string {
+	if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
+		return "metal"
+	}
+	variant := gpu.GetCPUVariant()
+	availableServers := availableServers()
+	if variant != "" {
+		for cmp := range availableServers {
+			if cmp == "cpu_"+variant {
+				return cmp
+			}
+		}
+	}
+	return "cpu"
+}
+
 // extract extracts the embedded files to the target directory
 func extractFiles(targetDir string, glob string) error {
 	files, err := fs.Glob(libEmbed, glob)

+ 352 - 268
llm/server.go

@@ -21,21 +21,47 @@ import (
 	"strings"
 	"time"
 
+	"golang.org/x/sync/semaphore"
+
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/gpu"
+	"github.com/ollama/ollama/server/envconfig"
 )
 
-// LlamaServer is an instance of the llama.cpp server
-type LlamaServer struct {
+type LlamaServer interface {
+	Ping(ctx context.Context) error
+	WaitUntilRunning(ctx context.Context) error
+	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
+	Embedding(ctx context.Context, prompt string) ([]float64, error)
+	Tokenize(ctx context.Context, content string) ([]int, error)
+	Detokenize(ctx context.Context, tokens []int) (string, error)
+	Close() error
+	EstimatedVRAM() uint64
+}
+
+// llmServer is an instance of the llama.cpp server
+type llmServer struct {
 	port    int
 	cmd     *exec.Cmd
 	done    chan error // Channel to signal when the process exits
 	status  *StatusWriter
 	options api.Options
+
+	// TODO - this should be broken down by GPU
+	estimatedVRAM  uint64 // Estimated usage of VRAM by the loaded model
+	estimatedTotal uint64 // Total size of model
+	totalLayers    uint64
+	gpuCount       int
+
+	sem *semaphore.Weighted
 }
 
-func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) {
+func LoadModel(model string) (*GGML, error) {
+	if _, err := os.Stat(model); err != nil {
+		return nil, err
+	}
+
 	f, err := os.Open(model)
 	if err != nil {
 		return nil, err
@@ -43,144 +69,69 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 	defer f.Close()
 
 	ggml, _, err := DecodeGGML(f)
-	if err != nil {
-		return nil, err
-	}
+	return ggml, err
+}
 
+// NewLlamaServer will run a server for the given GPUs
+// The gpu list must be a single family.
+func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
+	var err error
 	if opts.NumCtx > int(ggml.KV().ContextLength()) {
-		slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
-		opts.NumCtx = int(ggml.KV().ContextLength())
+		slog.Warn("requested context length is greater than the model's training context window size", "requested", opts.NumCtx, "training size", ggml.KV().ContextLength())
 	}
 
 	if opts.NumCtx < 4 {
 		opts.NumCtx = 4
 	}
 
-	memoryAvailable, _ := gpu.CheckVRAM()
-	info := gpu.GetGPUInfo()
+	cpuRunner := ""
+	var estimatedVRAM uint64
+	var estimatedTotal uint64
+	var systemMemory uint64
+	gpuCount := len(gpus)
+	if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 {
 
-	memoryMinimum := info.MinimumMemory
-	for _, projector := range projectors {
-		memoryMinimum += projectorMemoryRequirements(projector)
-
-		// multimodal models require at least 2048 context
-		opts.NumCtx = max(opts.NumCtx, 2048)
-	}
+		// TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner
 
-	// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
-	var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
-
-	graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
-	if graphPartialOffload == 0 {
-		graphPartialOffload = ggml.KV().GQA() * kv / 6
-	}
-
-	if graphFullOffload == 0 {
-		graphFullOffload = graphPartialOffload
-	}
-
-	graphFullOffload *= uint64(info.DeviceCount)
-	graphPartialOffload *= uint64(info.DeviceCount)
-
-	// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
-	memoryRequiredTotal := memoryMinimum + graphFullOffload
-
-	// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
-	memoryRequiredPartial := memoryMinimum + graphPartialOffload
-
-	if info.Library != "metal" {
-		if memoryRequiredPartial > memoryAvailable {
-			info.Library = "cpu"
-		}
-	}
-
-	var layerCount int
-	layers := ggml.Tensors().Layers()
-	for i := 0; i < int(ggml.KV().BlockCount()); i++ {
-		memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
-
-		// KV is proportional to the number of layers
-		memoryLayer += kv / ggml.KV().BlockCount()
-
-		memoryRequiredTotal += memoryLayer
-		if memoryAvailable > memoryRequiredPartial+memoryLayer {
-			memoryRequiredPartial += memoryLayer
-			layerCount++
+		cpuRunner = serverForCpu()
+		gpuCount = 0
+	} else {
+		if gpus[0].Library == "metal" {
+			memInfo, err := gpu.GetCPUMem()
+			if err != nil {
+				slog.Error("failed to lookup system memory", "error", err)
+			} else {
+				systemMemory = memInfo.TotalMemory
+				slog.Debug("system memory", "total", format.HumanBytes2(systemMemory))
+			}
 		}
-	}
-
-	var memoryLayerOutput uint64
-	for k, v := range layers {
-		if !strings.HasPrefix(k, "blk.") {
-			memoryLayerOutput += v.size()
+		var layers int
+		layers, estimatedVRAM, estimatedTotal = EstimateGPULayers(gpus, ggml, projectors, opts)
+
+		if gpus[0].Library == "metal" && estimatedVRAM > systemMemory {
+			// disable partial offloading when model is greater than total system memory as this
+			// can lead to locking up the system
+			opts.NumGPU = 0
+		} else if opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu" {
+			opts.NumGPU = layers
 		}
 	}
 
-	memoryRequiredTotal += memoryLayerOutput
-
-	if info.Library == "metal" && memoryRequiredTotal > info.TotalMemory {
-		// disable partial offloading when model is greater than total system memory
-		opts.NumGPU = 0
-	} else if memoryAvailable > memoryRequiredTotal {
-		layerCount = int(ggml.KV().BlockCount()) + 1
-		memoryRequiredPartial = memoryRequiredTotal
-	}
-
-	if opts.NumGPU < 0 {
-		opts.NumGPU = layerCount
-	}
-
-	memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
-
-	slog.Info(
-		"offload to gpu",
-		slog.Group(
-			"layers",
-			// actual number of layers offloaded
-			"real", opts.NumGPU,
-			// estimated number of layers that can be offloaded
-			"estimate", layerCount,
-		),
-		slog.Group(
-			"memory",
-			// memory available for offloading
-			"available", format.HumanBytes2(memoryAvailable),
-			slog.Group(
-				"required",
-				// memory required for full offloading
-				"full", format.HumanBytes2(memoryRequiredTotal),
-				// memory required to offload layers.estimate layers
-				"partial", format.HumanBytes2(memoryRequiredPartial),
-				// memory of KV cache
-				"kv", format.HumanBytes2(kv),
-			),
-			slog.Group(
-				"weights",
-				// memory of the weights
-				"total", format.HumanBytes2(memoryWeights),
-				// memory of repeating layers
-				"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
-				// memory of non-repeating layers
-				"nonrepeating", format.HumanBytes2(memoryLayerOutput),
-			),
-			slog.Group(
-				"graph",
-				// memory of graph when fully offloaded
-				"full", format.HumanBytes2(graphFullOffload),
-				// memory of graph when not fully offloaded
-				"partial", format.HumanBytes2(graphPartialOffload),
-			),
-		),
-	)
+	// Loop through potential servers
+	finalErr := fmt.Errorf("no suitable llama servers found")
 
 	if len(adapters) > 1 {
 		return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
 	}
 
 	availableServers := availableServers()
-	servers := serversForGpu(info)
-
-	demandLib := os.Getenv("OLLAMA_LLM_LIBRARY")
+	var servers []string
+	if cpuRunner != "" {
+		servers = []string{cpuRunner}
+	} else {
+		servers = serversForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant
+	}
+	demandLib := envconfig.LLMLibrary
 	if demandLib != "" {
 		serverPath := availableServers[demandLib]
 		if serverPath == "" {
@@ -188,11 +139,15 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 		} else {
 			slog.Info("user override", "OLLAMA_LLM_LIBRARY", demandLib, "path", serverPath)
 			servers = []string{demandLib}
+			if strings.HasPrefix(demandLib, "cpu") {
+				// Omit the GPU flag to silence the warning
+				opts.NumGPU = -1
+			}
 		}
 	}
 
 	if len(servers) == 0 {
-		return nil, fmt.Errorf("no servers found for %v", info)
+		return nil, fmt.Errorf("no servers found for %v", gpus)
 	}
 
 	params := []string{
@@ -201,7 +156,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 		"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
 		"--embedding",
 	}
-	if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
+	if envconfig.Debug {
 		params = append(params, "--log-format", "json")
 	} else {
 		params = append(params, "--log-disable")
@@ -211,7 +166,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 		params = append(params, "--n-gpu-layers", fmt.Sprintf("%d", opts.NumGPU))
 	}
 
-	if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
+	if envconfig.Debug {
 		params = append(params, "--verbose")
 	}
 
@@ -249,10 +204,30 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 		params = append(params, "--numa")
 	}
 
-	// Loop through potential servers
-	var finalErr error
+	numParallel := envconfig.NumParallel
+
+	// TODO (jmorganca): multimodal models don't support parallel yet
+	// see https://github.com/ollama/ollama/issues/4165
+	if len(projectors) > 0 {
+		numParallel = 1
+		slog.Warn("multimodal models don't support parallel requests yet")
+	}
+
+	params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
+
 	for i := 0; i < len(servers); i++ {
 		dir := availableServers[servers[i]]
+		if dir == "" {
+			// Shouldn't happen
+			finalErr = fmt.Errorf("[%d] server %s not listed in available servers %v", i, servers[i], availableServers)
+			slog.Error("sever list inconsistent", "error", finalErr)
+			continue
+		}
+
+		if strings.HasPrefix(servers[i], "cpu") {
+			// TODO if we tried a gpu runner first, and it failed, record the error and bubble that back up
+			gpuCount = 0
+		}
 
 		// Find an availableServers  port, retry on each iterration in case the failure was a port conflict race
 		port := 0
@@ -273,12 +248,21 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 		if runtime.GOOS == "windows" {
 			pathEnv = "PATH"
 		}
-		// append the server directory to LD_LIBRARY_PATH/PATH
+		// prepend the server directory to LD_LIBRARY_PATH/PATH
 		libraryPaths := []string{dir}
+
 		if libraryPath, ok := os.LookupEnv(pathEnv); ok {
 			// Append our runner directory to the path
 			// This will favor system libraries over our bundled library dependencies
-			libraryPaths = append(filepath.SplitList(libraryPath), libraryPaths...)
+			libraryPaths = append(libraryPaths, filepath.SplitList(libraryPath)...)
+		}
+
+		// Note: we always put the dependency path first
+		// since this was the exact version we verified for AMD GPUs
+		// and we favor what the user had in their path
+		if gpus[0].DependencyPath != "" {
+			// TODO refine for multi-gpu support
+			libraryPaths = append([]string{gpus[0].DependencyPath}, libraryPaths...)
 		}
 
 		server := filepath.Join(dir, "ollama_llama_server")
@@ -286,21 +270,66 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 			server = server + ".exe"
 		}
 
-		s := &LlamaServer{
-			port:    port,
-			cmd:     exec.Command(server, finalParams...),
-			status:  NewStatusWriter(os.Stderr),
-			options: opts,
+		// Detect tmp cleaners wiping out the file
+		_, err := os.Stat(server)
+		if errors.Is(err, os.ErrNotExist) {
+			slog.Warn("llama server disappeared, reinitializing payloads", "path", server, "error", err)
+			err = Init()
+			if err != nil {
+				slog.Warn("failed to reinitialize payloads", "error", err)
+				return nil, err
+			}
 		}
-		libEnv := fmt.Sprintf("%s=%s", pathEnv, strings.Join(libraryPaths, string(filepath.ListSeparator)))
-		slog.Debug(libEnv)
-		s.cmd.Env = append(os.Environ(), libEnv)
+
+		s := &llmServer{
+			port:           port,
+			cmd:            exec.Command(server, finalParams...),
+			status:         NewStatusWriter(os.Stderr),
+			options:        opts,
+			estimatedVRAM:  estimatedVRAM,
+			estimatedTotal: estimatedTotal,
+			sem:            semaphore.NewWeighted(int64(numParallel)),
+			totalLayers:    ggml.KV().BlockCount() + 1,
+			gpuCount:       gpuCount,
+		}
+
+		s.cmd.Env = os.Environ()
 		s.cmd.Stdout = os.Stdout
 		s.cmd.Stderr = s.status
 
+		visibleDevicesEnv, visibleDevicesEnvVal := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv()
+		pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
+
+		// Update or add the path and visible devices variable with our adjusted version
+		pathNeeded := true
+		devicesNeeded := visibleDevicesEnv != ""
+		for i := range s.cmd.Env {
+			cmp := strings.SplitN(s.cmd.Env[i], "=", 2)
+			if strings.EqualFold(cmp[0], pathEnv) {
+				s.cmd.Env[i] = pathEnv + "=" + pathEnvVal
+				pathNeeded = false
+			} else if devicesNeeded && strings.EqualFold(cmp[0], visibleDevicesEnv) {
+				s.cmd.Env[i] = visibleDevicesEnv + "=" + visibleDevicesEnvVal
+				devicesNeeded = false
+			}
+		}
+		if pathNeeded {
+			s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal)
+		}
+		if devicesNeeded {
+			s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
+		}
+
 		slog.Info("starting llama server", "cmd", s.cmd.String())
+		// Log at debug as the environment is inherited and might contain sensitive information
+		slog.Debug("subprocess", "environment", s.cmd.Env)
 
 		if err = s.cmd.Start(); err != nil {
+			// Detect permission denied and augment them essage about noexec
+			if errors.Is(err, os.ErrPermission) {
+				finalErr = fmt.Errorf("unable to start server %w.  %s may have noexec set.  Set OLLAMA_TMPDIR for server to a writable executable directory", err, dir)
+				continue
+			}
 			msg := ""
 			if s.status != nil && s.status.LastErrMsg != "" {
 				msg = s.status.LastErrMsg
@@ -310,12 +339,6 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 			continue
 		}
 
-		// reap subprocess when it exits
-		go func() {
-			// Exit status managed via getServerStatus
-			_ = s.cmd.Wait()
-		}()
-
 		return s, nil
 	}
 
@@ -347,12 +370,27 @@ type ServerStatus int
 
 const ( // iota is reset to 0
 	ServerStatusReady ServerStatus = iota
-	ServerStatusNoSlotsAvaialble
+	ServerStatusNoSlotsAvailable
 	ServerStatusLoadingModel
 	ServerStatusNotResponding
 	ServerStatusError
 )
 
+func (s ServerStatus) ToString() string {
+	switch s {
+	case ServerStatusReady:
+		return "llm server ready"
+	case ServerStatusNoSlotsAvailable:
+		return "llm busy - no slots available"
+	case ServerStatusLoadingModel:
+		return "llm server loading model"
+	case ServerStatusNotResponding:
+		return "llm server not responding"
+	default:
+		return "llm server error"
+	}
+}
+
 type ServerStatusResp struct {
 	Status          string `json:"status"`
 	SlotsIdle       int    `json:"slots_idle"`
@@ -360,13 +398,17 @@ type ServerStatusResp struct {
 	Error           string `json:"error"`
 }
 
-func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
+func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
 	// Fail fast if its exited
 	if s.cmd.ProcessState != nil {
 		msg := ""
 		if s.status != nil && s.status.LastErrMsg != "" {
 			msg = s.status.LastErrMsg
 		}
+		if s.cmd.ProcessState.ExitCode() == -1 {
+			// Most likely a signal killed it, log some more details to try to help troubleshoot
+			slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String())
+		}
 		return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
 	}
 
@@ -399,7 +441,7 @@ func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error)
 	case "ok":
 		return ServerStatusReady, nil
 	case "no slot available":
-		return ServerStatusNoSlotsAvaialble, nil
+		return ServerStatusNoSlotsAvailable, nil
 	case "loading model":
 		return ServerStatusLoadingModel, nil
 	default:
@@ -407,7 +449,30 @@ func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error)
 	}
 }
 
-func (s *LlamaServer) Ping(ctx context.Context) error {
+// getServerStatusRetry will retry if ServerStatusNoSlotsAvailable is received
+func (s *llmServer) getServerStatusRetry(ctx context.Context) (ServerStatus, error) {
+	var retries int
+	for {
+		status, err := s.getServerStatus(ctx)
+		if err != nil {
+			return status, err
+		}
+
+		if status == ServerStatusNoSlotsAvailable {
+			if retries >= 10 {
+				return status, fmt.Errorf("no slots available after %d retries", retries)
+			}
+
+			time.Sleep(5 * time.Millisecond)
+			retries++
+			continue
+		}
+
+		return status, nil
+	}
+}
+
+func (s *llmServer) Ping(ctx context.Context) error {
 	_, err := s.getServerStatus(ctx)
 	if err != nil {
 		slog.Debug("server unhealthy", "error", err)
@@ -416,13 +481,25 @@ func (s *LlamaServer) Ping(ctx context.Context) error {
 	return nil
 }
 
-func (s *LlamaServer) WaitUntilRunning() error {
+func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 	start := time.Now()
 	expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
 
 	slog.Info("waiting for llama runner to start responding")
 
 	for {
+		select {
+		case <-ctx.Done():
+			slog.Info("context expired before server started")
+			return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
+		case err := <-s.done:
+			msg := ""
+			if s.status != nil && s.status.LastErrMsg != "" {
+				msg = s.status.LastErrMsg
+			}
+			return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
+		default:
+		}
 		ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
 		defer cancel()
 		status, err := s.getServerStatus(ctx)
@@ -487,7 +564,6 @@ ws ::= ([ \t\n] ws)?
 `
 
 const maxBufferSize = 512 * format.KiloByte
-const maxRetries = 3
 
 type ImageData struct {
 	Data []byte `json:"data"`
@@ -524,7 +600,19 @@ type CompletionResponse struct {
 	EvalDuration       time.Duration
 }
 
-func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
+func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
+	if err := s.sem.Acquire(ctx, 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return err
+	}
+	defer s.sem.Release(1)
+
+	// only allow maximum 10 "context shifts" to avoid infinite generation
+	if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
+		req.Options.NumPredict = 10 * s.options.NumCtx
+		slog.Debug("setting token limit to 10x num_ctx", "num_ctx", s.options.NumCtx, "num_predict", req.Options.NumPredict)
+	}
+
 	request := map[string]any{
 		"prompt":            req.Prompt,
 		"stream":            true,
@@ -551,11 +639,11 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
 	}
 
 	// Make sure the server is ready
-	status, err := s.getServerStatus(ctx)
+	status, err := s.getServerStatusRetry(ctx)
 	if err != nil {
 		return err
 	} else if status != ServerStatusReady {
-		return fmt.Errorf("unexpected server status: %d", status)
+		return fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
 	if req.Format == "json" {
@@ -565,133 +653,113 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
 		}
 	}
 
-	retryDelay := 100 * time.Microsecond
-	for retries := 0; retries < maxRetries; retries++ {
-		if retries > 0 {
-			time.Sleep(retryDelay) // wait before retrying
-			retryDelay *= 2        // exponential backoff
-		}
+	// Handling JSON marshaling with special characters unescaped.
+	buffer := &bytes.Buffer{}
+	enc := json.NewEncoder(buffer)
+	enc.SetEscapeHTML(false)
 
-		// Handling JSON marshaling with special characters unescaped.
-		buffer := &bytes.Buffer{}
-		enc := json.NewEncoder(buffer)
-		enc.SetEscapeHTML(false)
+	if err := enc.Encode(request); err != nil {
+		return fmt.Errorf("failed to marshal data: %v", err)
+	}
 
-		if err := enc.Encode(request); err != nil {
-			return fmt.Errorf("failed to marshal data: %v", err)
-		}
+	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
+	serverReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
+	if err != nil {
+		return fmt.Errorf("error creating POST request: %v", err)
+	}
+	serverReq.Header.Set("Content-Type", "application/json")
 
-		endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
-		req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
-		if err != nil {
-			return fmt.Errorf("error creating POST request: %v", err)
-		}
-		req.Header.Set("Content-Type", "application/json")
+	res, err := http.DefaultClient.Do(serverReq)
+	if err != nil {
+		return fmt.Errorf("POST predict: %v", err)
+	}
+	defer res.Body.Close()
 
-		resp, err := http.DefaultClient.Do(req)
+	if res.StatusCode >= 400 {
+		bodyBytes, err := io.ReadAll(res.Body)
 		if err != nil {
-			return fmt.Errorf("POST predict: %v", err)
+			return fmt.Errorf("failed reading llm error response: %w", err)
 		}
-		defer resp.Body.Close()
+		log.Printf("llm predict error: %s", bodyBytes)
+		return fmt.Errorf("%s", bodyBytes)
+	}
 
-		if resp.StatusCode >= 400 {
-			bodyBytes, err := io.ReadAll(resp.Body)
-			if err != nil {
-				return fmt.Errorf("failed reading llm error response: %w", err)
+	scanner := bufio.NewScanner(res.Body)
+	buf := make([]byte, 0, maxBufferSize)
+	scanner.Buffer(buf, maxBufferSize)
+
+	// keep track of the last token generated, this is used to abort if the model starts looping
+	var lastToken string
+	var tokenRepeat int
+
+	for scanner.Scan() {
+		select {
+		case <-ctx.Done():
+			// This handles the request cancellation
+			return ctx.Err()
+		default:
+			line := scanner.Bytes()
+			if len(line) == 0 {
+				continue
 			}
-			log.Printf("llm predict error: %s", bodyBytes)
-			return fmt.Errorf("%s", bodyBytes)
-		}
 
-		scanner := bufio.NewScanner(resp.Body)
-		buf := make([]byte, 0, maxBufferSize)
-		scanner.Buffer(buf, maxBufferSize)
+			evt, ok := bytes.CutPrefix(line, []byte("data: "))
+			if !ok {
+				return fmt.Errorf("error parsing llm response stream: %s", line)
+			}
 
-		retryNeeded := false
-		// keep track of the last token generated, this is used to abort if the model starts looping
-		var lastToken string
-		var tokenRepeat int
+			var c completion
+			if err := json.Unmarshal(evt, &c); err != nil {
+				return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
+			}
 
-		for scanner.Scan() {
-			select {
-			case <-ctx.Done():
-				// This handles the request cancellation
-				return ctx.Err()
+			switch {
+			case strings.TrimSpace(c.Content) == lastToken:
+				tokenRepeat++
 			default:
-				line := scanner.Bytes()
-				if len(line) == 0 {
-					continue
-				}
-
-				// try again on slot unavailable
-				if bytes.Contains(line, []byte("slot unavailable")) {
-					retryNeeded = true
-					break
-				}
-
-				evt, ok := bytes.CutPrefix(line, []byte("data: "))
-				if !ok {
-					return fmt.Errorf("error parsing llm response stream: %s", line)
-				}
-
-				var c completion
-				if err := json.Unmarshal(evt, &c); err != nil {
-					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
-				}
-
-				switch {
-				case strings.TrimSpace(c.Content) == lastToken:
-					tokenRepeat++
-				default:
-					lastToken = strings.TrimSpace(c.Content)
-					tokenRepeat = 0
-				}
-
-				// 30 picked as an arbitrary max token repeat limit, modify as needed
-				if tokenRepeat > 30 {
-					slog.Debug("prediction aborted, token repeat limit reached")
-					return ctx.Err()
-				}
-
-				if c.Content != "" {
-					fn(CompletionResponse{
-						Content: c.Content,
-					})
-				}
-
-				if c.Stop {
-					fn(CompletionResponse{
-						Done:               true,
-						PromptEvalCount:    c.Timings.PromptN,
-						PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
-						EvalCount:          c.Timings.PredictedN,
-						EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
-					})
-					return nil
-				}
+				lastToken = strings.TrimSpace(c.Content)
+				tokenRepeat = 0
+			}
+
+			// 30 picked as an arbitrary max token repeat limit, modify as needed
+			if tokenRepeat > 30 {
+				slog.Debug("prediction aborted, token repeat limit reached")
+				return ctx.Err()
 			}
-		}
 
-		if err := scanner.Err(); err != nil {
-			if strings.Contains(err.Error(), "unexpected EOF") {
-				s.Close()
-				msg := ""
-				if s.status != nil && s.status.LastErrMsg != "" {
-					msg = s.status.LastErrMsg
-				}
+			if c.Content != "" {
+				fn(CompletionResponse{
+					Content: c.Content,
+				})
+			}
 
-				return fmt.Errorf("an unknown error was encountered while running the model %s", msg)
+			if c.Stop {
+				fn(CompletionResponse{
+					Done:               true,
+					PromptEvalCount:    c.Timings.PromptN,
+					PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
+					EvalCount:          c.Timings.PredictedN,
+					EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
+				})
+				return nil
 			}
-			return fmt.Errorf("error reading llm response: %v", err)
 		}
+	}
 
-		if !retryNeeded {
-			return nil // success
+	if err := scanner.Err(); err != nil {
+		if strings.Contains(err.Error(), "unexpected EOF") {
+			s.Close()
+			msg := ""
+			if s.status != nil && s.status.LastErrMsg != "" {
+				msg = s.status.LastErrMsg
+			}
+			return fmt.Errorf("an unknown error was encountered while running the model %s", msg)
 		}
+
+		return fmt.Errorf("error reading llm response: %v", err)
 	}
 
-	// should never reach here ideally
-	return fmt.Errorf("max retries exceeded")
+	return nil
 }
 
 type EmbeddingRequest struct {
@@ -702,13 +770,19 @@ type EmbeddingResponse struct {
 	Embedding []float64 `json:"embedding"`
 }
 
-func (s *LlamaServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+	if err := s.sem.Acquire(ctx, 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return nil, err
+	}
+	defer s.sem.Release(1)
+
 	// Make sure the server is ready
-	status, err := s.getServerStatus(ctx)
+	status, err := s.getServerStatusRetry(ctx)
 	if err != nil {
 		return nil, err
 	} else if status != ServerStatusReady {
-		return nil, fmt.Errorf("unexpected server status: %d", status)
+		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
 	data, err := json.Marshal(TokenizeRequest{Content: prompt})
@@ -754,13 +828,13 @@ type TokenizeResponse struct {
 	Tokens []int `json:"tokens"`
 }
 
-func (s *LlamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
+func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
 	// Make sure the server is ready
 	status, err := s.getServerStatus(ctx)
 	if err != nil {
 		return nil, err
-	} else if status != ServerStatusReady {
-		return nil, fmt.Errorf("unexpected server status: %d", status)
+	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
+		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
 	data, err := json.Marshal(TokenizeRequest{Content: content})
@@ -806,13 +880,13 @@ type DetokenizeResponse struct {
 	Content string `json:"content"`
 }
 
-func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
+func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
 	// Make sure the server is ready
 	status, err := s.getServerStatus(ctx)
 	if err != nil {
 		return "", err
-	} else if status != ServerStatusReady {
-		return "", fmt.Errorf("unexpected server status: %d", status)
+	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
+		return "", fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
 	data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
@@ -850,15 +924,25 @@ func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, err
 	return decoded.Content, nil
 }
 
-func (s *LlamaServer) Close() error {
+func (s *llmServer) Close() error {
 	if s.cmd != nil {
 		slog.Debug("stopping llama server")
-		return s.cmd.Process.Kill()
+		if err := s.cmd.Process.Kill(); err != nil {
+			return err
+		}
+
+		_ = s.cmd.Wait()
+
+		slog.Debug("llama server stopped")
 	}
 
 	return nil
 }
 
+func (s *llmServer) EstimatedVRAM() uint64 {
+	return s.estimatedVRAM
+}
+
 func parseDurationMs(ms float64) time.Duration {
 	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
 	if err != nil {

+ 1 - 1
macapp/src/app.tsx

@@ -19,7 +19,7 @@ export default function () {
   const [step, setStep] = useState<Step>(Step.WELCOME)
   const [commandCopied, setCommandCopied] = useState<boolean>(false)
 
-  const command = 'ollama run llama2'
+  const command = 'ollama run llama3'
 
   return (
     <div className='drag'>

Algúns arquivos non se mostraron porque demasiados arquivos cambiaron neste cambio