utils.py

from fastapi import APIRouter, UploadFile, File, BackgroundTasks
from fastapi import Depends, HTTPException, status
from starlette.responses import StreamingResponse, FileResponse
from pydantic import BaseModel
from urllib.parse import urlparse

import requests
import os
import aiohttp
import json

from utils.utils import get_admin_user
from utils.misc import calculate_sha256, get_gravatar_url

from config import OLLAMA_BASE_URLS, DATA_DIR, UPLOAD_DIR
from constants import ERROR_MESSAGES

router = APIRouter()


class UploadBlobForm(BaseModel):
    filename: str
def parse_huggingface_url(hf_url):
    try:
        # Parse the URL
        parsed_url = urlparse(hf_url)

        # Get the path and split it into components
        path_components = parsed_url.path.split("/")

        # Extract the desired output
        user_repo = "/".join(path_components[1:3])
        model_file = path_components[-1]

        return model_file
    except ValueError:
        return None
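
# For example, for the Hugging Face URL referenced in the /download route below,
# this returns the final path component ("stablelm-zephyr-3b.Q2_K.gguf"), which is
# then used as the local file name under UPLOAD_DIR.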
async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024):
    # Streams a file download as Server-Sent Events with progress updates.
    # If a partial file already exists, the download resumes via an HTTP Range header.
    done = False

    if os.path.exists(file_path):
        current_size = os.path.getsize(file_path)
    else:
        current_size = 0

    headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}

    timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout

    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.get(url, headers=headers) as response:
            total_size = int(response.headers.get("content-length", 0)) + current_size

            with open(file_path, "ab+") as file:
                async for data in response.content.iter_chunked(chunk_size):
                    current_size += len(data)
                    file.write(data)

                    done = current_size == total_size
                    progress = round((current_size / total_size) * 100, 2)

                    yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'

                if done:
                    # Once complete, hash the file and register it with Ollama as a blob.
                    file.seek(0)
                    hashed = calculate_sha256(file)
                    file.seek(0)

                    url = f"{OLLAMA_BASE_URLS[0]}/blobs/sha256:{hashed}"
                    response = requests.post(url, data=file)

                    if response.ok:
                        res = {
                            "done": done,
                            "blob": f"sha256:{hashed}",
                            "name": file_name,
                        }
                        os.remove(file_path)

                        yield f"data: {json.dumps(res)}\n\n"
                    else:
                        raise Exception(
                            "Ollama: Could not create blob, Please try again."
                        )


@router.get("/download")
async def download(
    url: str,
):
    # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
    file_name = parse_huggingface_url(url)

    if file_name:
        file_path = f"{UPLOAD_DIR}/{file_name}"

        return StreamingResponse(
            download_file_stream(url, file_path, file_name),
            media_type="text/event-stream",
        )
    else:
        return None
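
# Usage sketch (assumes this router is mounted under some API prefix by the main app):
# a client can follow the SSE progress stream with, e.g.,
#   curl -N "<host>/<prefix>/download?url=https://huggingface.co/.../model.gguf"
# Each event is a line of the form
#   data: {"progress": 42.0, "completed": <bytes so far>, "total": <total bytes>}
# followed, once the blob has been registered with Ollama, by a final event such as
#   data: {"done": true, "blob": "sha256:<hash>", "name": "<file name>"}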


@router.post("/upload")
def upload(file: UploadFile = File(...)):
    file_path = f"{UPLOAD_DIR}/{file.filename}"

    # Save file in chunks
    with open(file_path, "wb+") as f:
        for chunk in file.file:
            f.write(chunk)

    def file_process_stream():
        total_size = os.path.getsize(file_path)
        chunk_size = 1024 * 1024
        try:
            with open(file_path, "rb") as f:
                total = 0
                done = False

                while not done:
                    chunk = f.read(chunk_size)
                    if not chunk:
                        done = True
                        continue

                    total += len(chunk)
                    progress = round((total / total_size) * 100, 2)

                    res = {
                        "progress": progress,
                        "total": total_size,
                        "completed": total,
                    }
                    yield f"data: {json.dumps(res)}\n\n"

                if done:
                    f.seek(0)
                    hashed = calculate_sha256(f)
                    f.seek(0)

                    url = f"{OLLAMA_BASE_URLS[0]}/blobs/sha256:{hashed}"
                    response = requests.post(url, data=f)

                    if response.ok:
                        res = {
                            "done": done,
                            "blob": f"sha256:{hashed}",
                            "name": file.filename,
                        }
                        os.remove(file_path)
                        yield f"data: {json.dumps(res)}\n\n"
                    else:
                        raise Exception(
                            "Ollama: Could not create blob, Please try again."
                        )
        except Exception as e:
            res = {"error": str(e)}
            yield f"data: {json.dumps(res)}\n\n"

    return StreamingResponse(file_process_stream(), media_type="text/event-stream")


@router.get("/gravatar")
async def get_gravatar(
    email: str,
):
    return get_gravatar_url(email)


@router.get("/db/download")
async def download_db(user=Depends(get_admin_user)):
    return FileResponse(
        f"{DATA_DIR}/webui.db",
        media_type="application/octet-stream",
        filename="webui.db",
    )