utils.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. from fastapi import APIRouter, UploadFile, File, BackgroundTasks
  2. from fastapi import Depends, HTTPException, status
  3. from starlette.responses import StreamingResponse
  4. from pydantic import BaseModel
  5. import requests
  6. import os
  7. import aiohttp
  8. import json
  9. from utils.misc import calculate_sha256
  10. from config import OLLAMA_API_BASE_URL
  11. router = APIRouter()
  12. class UploadBlobForm(BaseModel):
  13. filename: str
  14. from urllib.parse import urlparse
  15. def parse_huggingface_url(hf_url):
  16. try:
  17. # Parse the URL
  18. parsed_url = urlparse(hf_url)
  19. # Get the path and split it into components
  20. path_components = parsed_url.path.split("/")
  21. # Extract the desired output
  22. user_repo = "/".join(path_components[1:3])
  23. model_file = path_components[-1]
  24. return model_file
  25. except ValueError:
  26. return None
  27. async def download_file_stream(url,
  28. file_path,
  29. file_name,
  30. chunk_size=1024 * 1024):
  31. done = False
  32. if os.path.exists(file_path):
  33. current_size = os.path.getsize(file_path)
  34. else:
  35. current_size = 0
  36. headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
  37. timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
  38. async with aiohttp.ClientSession(timeout=timeout) as session:
  39. async with session.get(url, headers=headers) as response:
  40. total_size = int(response.headers.get("content-length",
  41. 0)) + current_size
  42. with open(file_path, "ab+") as file:
  43. async for data in response.content.iter_chunked(chunk_size):
  44. current_size += len(data)
  45. file.write(data)
  46. done = current_size == total_size
  47. progress = round((current_size / total_size) * 100, 2)
  48. yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
  49. if done:
  50. file.seek(0)
  51. hashed = calculate_sha256(file)
  52. file.seek(0)
  53. url = f"{OLLAMA_API_BASE_URL}/blobs/sha256:{hashed}"
  54. response = requests.post(url, data=file)
  55. if response.ok:
  56. res = {
  57. "done": done,
  58. "blob": f"sha256:{hashed}",
  59. "name": file_name,
  60. }
  61. os.remove(file_path)
  62. yield f"data: {json.dumps(res)}\n\n"
  63. else:
  64. raise "Ollama: Could not create blob, Please try again."
  65. @router.get("/download")
  66. async def download(url: str, ):
  67. # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
  68. file_name = parse_huggingface_url(url)
  69. if file_name:
  70. os.makedirs("./uploads", exist_ok=True)
  71. file_path = os.path.join("./uploads", f"{file_name}")
  72. return StreamingResponse(
  73. download_file_stream(url, file_path, file_name),
  74. media_type="text/event-stream",
  75. )
  76. else:
  77. return None
  78. @router.post("/upload")
  79. async def upload(file: UploadFile = File(...)):
  80. os.makedirs("./uploads", exist_ok=True)
  81. file_path = os.path.join("./uploads", file.filename)
  82. async def file_write_stream():
  83. total = 0
  84. total_size = file.size
  85. chunk_size = 1024 * 1024
  86. done = False
  87. try:
  88. with open(file_path, "wb+") as f:
  89. while True:
  90. chunk = file.file.read(chunk_size)
  91. if not chunk:
  92. break
  93. f.write(chunk)
  94. total += len(chunk)
  95. done = total_size == total
  96. progress = round((total / total_size) * 100, 2)
  97. res = {
  98. "progress": progress,
  99. "total": total_size,
  100. "completed": total,
  101. }
  102. yield f"data: {json.dumps(res)}\n\n"
  103. if done:
  104. f.seek(0)
  105. hashed = calculate_sha256(f)
  106. f.seek(0)
  107. url = f"{OLLAMA_API_BASE_URL}/blobs/sha256:{hashed}"
  108. response = requests.post(url, data=f)
  109. if response.ok:
  110. res = {
  111. "done": done,
  112. "blob": f"sha256:{hashed}",
  113. "name": file.filename,
  114. }
  115. os.remove(file_path)
  116. yield f"data: {json.dumps(res)}\n\n"
  117. else:
  118. raise "Ollama: Could not create blob, Please try again."
  119. except Exception as e:
  120. res = {"error": str(e)}
  121. yield f"data: {json.dumps(res)}\n\n"
  122. return StreamingResponse(file_write_stream(),
  123. media_type="text/event-stream")