# utils.py — upload/download router: streams model files to/from ./uploads
# and registers uploads as Ollama blobs.
  1. from fastapi import APIRouter, UploadFile, File, BackgroundTasks
  2. from fastapi import Depends, HTTPException, status
  3. from starlette.responses import StreamingResponse
  4. from pydantic import BaseModel
  5. from utils.misc import calculate_sha256
  6. import requests
  7. import os
  8. import asyncio
  9. import json
  10. from config import OLLAMA_API_BASE_URL
  11. router = APIRouter()
  12. class UploadBlobForm(BaseModel):
  13. filename: str
  14. from urllib.parse import urlparse
  15. def parse_huggingface_url(hf_url):
  16. # Parse the URL
  17. parsed_url = urlparse(hf_url)
  18. # Get the path and split it into components
  19. path_components = parsed_url.path.split("/")
  20. # Extract the desired output
  21. user_repo = "/".join(path_components[1:3])
  22. model_file = path_components[-1]
  23. return [user_repo, model_file]
  24. def download_file_stream(url, file_path, chunk_size=1024 * 1024):
  25. done = False
  26. if os.path.exists(file_path):
  27. current_size = os.path.getsize(file_path)
  28. else:
  29. current_size = 0
  30. headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
  31. with requests.get(url, headers=headers, stream=True) as response:
  32. total_size = int(response.headers.get("content-length", 0)) + current_size
  33. with open(file_path, "ab") as file:
  34. for data in response.iter_content(chunk_size=chunk_size):
  35. current_size += len(data)
  36. file.write(data)
  37. done = current_size == total_size
  38. progress = round((current_size / total_size) * 100, 2)
  39. yield f'data: {{"progress": {progress}, "current": {current_size}, "total": {total_size}}}\n\n'
  40. @router.get("/download")
  41. async def download(
  42. url: str = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf",
  43. ):
  44. user_repo, model_file = parse_huggingface_url(url)
  45. os.makedirs("./uploads", exist_ok=True)
  46. file_path = os.path.join("./uploads", f"{model_file}")
  47. return StreamingResponse(
  48. download_file_stream(url, file_path), media_type="text/event-stream"
  49. )
  50. @router.post("/upload")
  51. async def upload(file: UploadFile = File(...)):
  52. os.makedirs("./uploads", exist_ok=True)
  53. file_path = os.path.join("./uploads", file.filename)
  54. async def file_write_stream():
  55. total = 0
  56. total_size = file.size
  57. chunk_size = 1024 * 1024
  58. done = False
  59. try:
  60. with open(file_path, "wb+") as f:
  61. while True:
  62. chunk = file.file.read(chunk_size)
  63. if not chunk:
  64. break
  65. f.write(chunk)
  66. total += len(chunk)
  67. done = total_size == total
  68. res = {
  69. "total": total_size,
  70. "uploaded": total,
  71. }
  72. yield f"data: {json.dumps(res)}\n\n"
  73. if done:
  74. f.seek(0)
  75. hashed = calculate_sha256(f)
  76. f.seek(0)
  77. url = f"{OLLAMA_API_BASE_URL}/blobs/sha256:{hashed}"
  78. response = requests.post(url, data=f)
  79. if response.ok:
  80. res = {
  81. "done": done,
  82. "blob": f"sha256:{hashed}",
  83. }
  84. os.remove(file_path)
  85. yield f"data: {json.dumps(res)}\n\n"
  86. else:
  87. raise "Ollama: Could not create blob, Please try again."
  88. except Exception as e:
  89. res = {"error": str(e)}
  90. yield f"data: {json.dumps(res)}\n\n"
  91. return StreamingResponse(file_write_stream(), media_type="text/event-stream")