Przeglądaj źródła

pull from remote

Bruce MacDonald 1 rok temu
rodzic
commit
52beb0a99e
5 zmienionych plików z 168 dodań i 11 usunięć
  1. 2 2
      README.md
  2. 8 0
      ollama/cmd/cli.py
  3. 73 6
      ollama/model.py
  4. 83 3
      poetry.lock
  5. 2 0
      pyproject.toml

+ 2 - 2
README.md

@@ -93,8 +93,6 @@ Unload a model
 ollama.unload("model")
 ```
 
-## Cooming Soon
-
 ### `ollama.pull(model)`
 
 Download a model
@@ -103,6 +101,8 @@ Download a model
 ollama.pull("huggingface.co/thebloke/llama-7b-ggml")
 ```
 
+## Coming Soon
+
 ### `ollama.search("query")`
 
 Search for compatible models that Ollama can run

+ 8 - 0
ollama/cmd/cli.py

@@ -27,6 +27,10 @@ def main():
     add_parser.add_argument("model")
     add_parser.set_defaults(fn=add)
 
+    pull_parser = subparsers.add_parser("pull")
+    pull_parser.add_argument("remote")
+    pull_parser.set_defaults(fn=pull)
+
     args = parser.parse_args()
     args = vars(args)
 
@@ -55,3 +59,7 @@ def generate(*args, **kwargs):
 
 def add(model, models_home):
     os.rename(model, Path(models_home) / Path(model).name)
+
+
+def pull(*args, **kwargs):
+    """Forward all arguments unchanged to ``model.pull``."""
+    model.pull(*args, **kwargs)

+ 73 - 6
ollama/model.py

@@ -1,9 +1,76 @@
-from os import walk, path
+import os
+import requests
+from urllib.parse import urlsplit, urlunsplit
+from tqdm import tqdm
 
 
-def models(models_home='.', *args, **kwargs):
-    for root, _, files in walk(models_home):
+def models(models_home=".", *args, **kwargs):
+    for root, _, files in os.walk(models_home):
         for file in files:
-            base, ext = path.splitext(file)
-            if ext == '.bin':
-                yield base, path.join(root, file)
+            base, ext = os.path.splitext(file)
+            if ext == ".bin":
+                yield base, os.path.join(root, file)
+
+
+def pull(remote, models_home=".", *args, **kwargs):
+    """Download the first ``.bin`` file of a remote model into *models_home*.
+
+    ``remote`` is a model location such as ``huggingface.co/<owner>/<model>``,
+    optionally followed by ``/tree/<branch>``; the branch defaults to ``main``
+    and a missing scheme defaults to ``https://``. The file listing is read
+    from ``/api/models/<model>/tree/<branch>`` on the remote's host, and the
+    first entry of type ``file`` whose path ends in ``.bin`` is saved under
+    ``models_home``. A file that already exists locally is not fetched again.
+
+    Raises ``requests.HTTPError`` when either request returns an error status.
+    Extra positional and keyword arguments are accepted and ignored.
+    """
+    if not remote.startswith(("http://", "https://")):
+        remote = f"https://{remote}"
+
+    parts = urlsplit(remote)
+
+    # partition() splits on the first "/tree/" only, so a path that contains
+    # the marker more than once no longer fails on tuple unpacking.
+    model, _, branch = parts.path.partition("/tree/")
+    model = model.strip("/")
+    branch = branch.strip("/") or "main"
+
+    # Reconstruct the URL
+    new_url = urlunsplit(
+        (
+            "https",
+            parts.netloc,
+            f"/api/models/{model}/tree/{branch}",
+            parts.query,
+            parts.fragment,
+        )
+    )
+
+    print(f"Fetching model from {new_url}")
+
+    # Without a timeout a stalled connection would block the CLI forever.
+    timeout = 30
+
+    response = requests.get(new_url, timeout=timeout)
+    response.raise_for_status()  # Raises stored HTTPError, if one occurred
+
+    for file_info in response.json():
+        # Entries without a "path" are skipped instead of raising on None.
+        f_path = file_info.get("path") or ""
+        if file_info.get("type") != "file" or not f_path.endswith(".bin"):
+            continue
+
+        local_filename = os.path.join(models_home, os.path.basename(f_path))
+
+        if os.path.exists(local_filename):
+            # TODO: check if the file is the same
+            break
+
+        # Download from the same host the listing came from rather than a
+        # hard-coded one, so both requests always agree.
+        download_url = f"https://{parts.netloc}/{model}/resolve/{branch}/{f_path}"
+        _download_file(download_url, local_filename, timeout)
+
+        break  # Stop after downloading the first .bin file
+
+
+def _download_file(url, local_filename, timeout):
+    """Stream *url* to *local_filename*, showing a progress bar."""
+    # Write to a sibling ".part" file and rename it when complete, so an
+    # interrupted download is never mistaken for a finished model.
+    partial_filename = f"{local_filename}.part"
+
+    # The context manager releases the streamed connection on every path.
+    with requests.get(url, stream=True, timeout=timeout) as response:
+        response.raise_for_status()  # Raises stored HTTPError, if one occurred
+
+        total_size = int(response.headers.get("content-length", 0))
+
+        try:
+            with open(partial_filename, "wb") as file, tqdm(
+                desc=local_filename,
+                total=total_size,
+                unit="iB",
+                unit_scale=True,
+                unit_divisor=1024,
+            ) as bar:
+                # 1 MiB chunks: model files are large and 1 KiB chunks mean
+                # millions of tiny writes and progress-bar updates.
+                for data in response.iter_content(chunk_size=1024 * 1024):
+                    bar.update(file.write(data))
+        except BaseException:
+            # Also covers Ctrl-C: drop the truncated file, then re-raise.
+            if os.path.exists(partial_filename):
+                os.remove(partial_filename)
+            raise
+
+    os.replace(partial_filename, local_filename)

+ 83 - 3
poetry.lock

@@ -165,11 +165,22 @@ docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-
 tests = ["attrs[tests-no-zope]", "zope-interface"]
 tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
 
+[[package]]
+name = "certifi"
+version = "2023.5.7"
+description = "Python package for providing Mozilla's CA Bundle."
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "certifi-2023.5.7-py3-none-any.whl", hash = "sha256:c6c2e98f5c7869efca1f8916fed228dd91539f9f1b444c314c06eef02980c716"},
+    {file = "certifi-2023.5.7.tar.gz", hash = "sha256:0f0d56dc5a6ad56fd4ba36484d6cc34451e1c6548c61daad8c320169f91eddc7"},
+]
+
 [[package]]
 name = "charset-normalizer"
 version = "3.1.0"
 description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
-optional = true
+optional = false
 python-versions = ">=3.7.0"
 files = [
     {file = "charset-normalizer-3.1.0.tar.gz", hash = "sha256:34e0a2f9c370eb95597aae63bf85eb5e96826d81e3dcf88b8886012906f509b5"},
@@ -249,6 +260,17 @@ files = [
     {file = "charset_normalizer-3.1.0-py3-none-any.whl", hash = "sha256:3d9098b479e78c85080c98e1e35ff40b4a31d8953102bb0fd7d1b6f8a2111a3d"},
 ]
 
+[[package]]
+name = "colorama"
+version = "0.4.6"
+description = "Cross-platform colored terminal text."
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
+files = [
+    {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
+    {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
+]
+
 [[package]]
 name = "diskcache"
 version = "5.6.1"
@@ -374,7 +396,7 @@ files = [
 name = "idna"
 version = "3.4"
 description = "Internationalized Domain Names in Applications (IDNA)"
-optional = true
+optional = false
 python-versions = ">=3.5"
 files = [
     {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"},
@@ -680,6 +702,27 @@ test = ["coverage", "flaky", "matplotlib", "numpy", "pandas", "pylint (>=2.5.0,<
 websockets = ["websockets (>=10.3)"]
 yapf = ["whatthepatch (>=1.0.2,<2.0.0)", "yapf (>=0.33.0)"]
 
+[[package]]
+name = "requests"
+version = "2.31.0"
+description = "Python HTTP for Humans."
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"},
+    {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"},
+]
+
+[package.dependencies]
+certifi = ">=2017.4.17"
+charset-normalizer = ">=2,<4"
+idna = ">=2.5,<4"
+urllib3 = ">=1.21.1,<3"
+
+[package.extras]
+socks = ["PySocks (>=1.5.6,!=1.5.7)"]
+use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
+
 [[package]]
 name = "setuptools"
 version = "68.0.0"
@@ -696,6 +739,26 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-g
 testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
 testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
 
+[[package]]
+name = "tqdm"
+version = "4.65.0"
+description = "Fast, Extensible Progress Meter"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "tqdm-4.65.0-py3-none-any.whl", hash = "sha256:c4f53a17fe37e132815abceec022631be8ffe1b9381c2e6e30aa70edc99e9671"},
+    {file = "tqdm-4.65.0.tar.gz", hash = "sha256:1871fb68a86b8fb3b59ca4cdd3dcccbc7e6d613eeed31f4c332531977b89beb5"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "platform_system == \"Windows\""}
+
+[package.extras]
+dev = ["py-make (>=0.1.0)", "twine", "wheel"]
+notebook = ["ipywidgets (>=6)"]
+slack = ["slack-sdk"]
+telegram = ["requests"]
+
 [[package]]
 name = "typing-extensions"
 version = "4.6.3"
@@ -777,6 +840,23 @@ files = [
     {file = "ujson-5.8.0.tar.gz", hash = "sha256:78e318def4ade898a461b3d92a79f9441e7e0e4d2ad5419abed4336d702c7425"},
 ]
 
+[[package]]
+name = "urllib3"
+version = "2.0.3"
+description = "HTTP library with thread-safe connection pooling, file post, and more."
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "urllib3-2.0.3-py3-none-any.whl", hash = "sha256:48e7fafa40319d358848e1bc6809b208340fafe2096f1725d05d67443d0483d1"},
+    {file = "urllib3-2.0.3.tar.gz", hash = "sha256:bee28b5e56addb8226c96f7f13ac28cb4c301dd5ea8a6ca179c0b9835e032825"},
+]
+
+[package.extras]
+brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
+secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"]
+socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
+zstd = ["zstandard (>=0.18.0)"]
+
 [[package]]
 name = "yarl"
 version = "1.9.2"
@@ -870,4 +950,4 @@ server = ["aiohttp", "aiohttp-cors"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "c649ffbb8b8045d35831f9fb09a6f099b8e940abd85a020c3dfd24173b2582d8"
+content-hash = "ba168754266c6c46b2136207415a5b3a879c957e53e924cab1e64267849ceb90"

+ 2 - 0
pyproject.toml

@@ -13,6 +13,8 @@ llama-cpp-python = "^0.1.66"
 
 aiohttp = {version = "^3.8.4", optional = true}
 aiohttp-cors = {version = "^0.7.0", optional = true}
+requests = "^2.31.0"
+tqdm = "^4.65.0"
 
 [tool.poetry.extras]
 server = ["aiohttp", "aiohttp_cors"]