ollama.py 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634
# TODO: Implement a more intelligent load-balancing mechanism for distributing requests among multiple backend instances.
# The current implementation simply picks a backend at random (random.choice), which is not true round-robin. Consider
# algorithms such as weighted round-robin, least connections, or least response time for better resource utilization and performance.
  4. import asyncio
  5. import json
  6. import logging
  7. import os
  8. import random
  9. import re
  10. import time
  11. from typing import Optional, Union
  12. from urllib.parse import urlparse
  13. import aiohttp
  14. from aiocache import cached
  15. import requests
  16. from open_webui.models.users import UserModel
  17. from open_webui.env import (
  18. ENABLE_FORWARD_USER_INFO_HEADERS,
  19. )
  20. from fastapi import (
  21. Depends,
  22. FastAPI,
  23. File,
  24. HTTPException,
  25. Request,
  26. UploadFile,
  27. APIRouter,
  28. )
  29. from fastapi.middleware.cors import CORSMiddleware
  30. from fastapi.responses import StreamingResponse
  31. from pydantic import BaseModel, ConfigDict, validator
  32. from starlette.background import BackgroundTask
  33. from open_webui.models.models import Models
  34. from open_webui.utils.misc import (
  35. calculate_sha256,
  36. )
  37. from open_webui.utils.payload import (
  38. apply_model_params_to_body_ollama,
  39. apply_model_params_to_body_openai,
  40. apply_model_system_prompt_to_body,
  41. )
  42. from open_webui.utils.auth import get_admin_user, get_verified_user
  43. from open_webui.utils.access_control import has_access
  44. from open_webui.config import (
  45. UPLOAD_DIR,
  46. )
  47. from open_webui.env import (
  48. ENV,
  49. SRC_LOG_LEVELS,
  50. AIOHTTP_CLIENT_TIMEOUT,
  51. AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
  52. BYPASS_MODEL_ACCESS_CONTROL,
  53. )
  54. from open_webui.constants import ERROR_MESSAGES
# Module-level logger; its level comes from the "OLLAMA" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
  57. ##########################################
  58. #
  59. # Utility functions
  60. #
  61. ##########################################
  62. async def send_get_request(url, key=None, user: UserModel = None):
  63. timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
  64. try:
  65. async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
  66. async with session.get(
  67. url,
  68. headers={
  69. "Content-Type": "application/json",
  70. **({"Authorization": f"Bearer {key}"} if key else {}),
  71. **(
  72. {
  73. "X-OpenWebUI-User-Name": user.name,
  74. "X-OpenWebUI-User-Id": user.id,
  75. "X-OpenWebUI-User-Email": user.email,
  76. "X-OpenWebUI-User-Role": user.role,
  77. }
  78. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  79. else {}
  80. ),
  81. },
  82. ) as response:
  83. return await response.json()
  84. except Exception as e:
  85. # Handle connection error here
  86. log.error(f"Connection error: {e}")
  87. return None
  88. async def cleanup_response(
  89. response: Optional[aiohttp.ClientResponse],
  90. session: Optional[aiohttp.ClientSession],
  91. ):
  92. if response:
  93. response.close()
  94. if session:
  95. await session.close()
  96. async def send_post_request(
  97. url: str,
  98. payload: Union[str, bytes],
  99. stream: bool = True,
  100. key: Optional[str] = None,
  101. content_type: Optional[str] = None,
  102. user: UserModel = None,
  103. ):
  104. r = None
  105. try:
  106. session = aiohttp.ClientSession(
  107. trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
  108. )
  109. r = await session.post(
  110. url,
  111. data=payload,
  112. headers={
  113. "Content-Type": "application/json",
  114. **({"Authorization": f"Bearer {key}"} if key else {}),
  115. **(
  116. {
  117. "X-OpenWebUI-User-Name": user.name,
  118. "X-OpenWebUI-User-Id": user.id,
  119. "X-OpenWebUI-User-Email": user.email,
  120. "X-OpenWebUI-User-Role": user.role,
  121. }
  122. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  123. else {}
  124. ),
  125. },
  126. )
  127. r.raise_for_status()
  128. if stream:
  129. response_headers = dict(r.headers)
  130. if content_type:
  131. response_headers["Content-Type"] = content_type
  132. return StreamingResponse(
  133. r.content,
  134. status_code=r.status,
  135. headers=response_headers,
  136. background=BackgroundTask(
  137. cleanup_response, response=r, session=session
  138. ),
  139. )
  140. else:
  141. res = await r.json()
  142. await cleanup_response(r, session)
  143. return res
  144. except Exception as e:
  145. detail = None
  146. if r is not None:
  147. try:
  148. res = await r.json()
  149. if "error" in res:
  150. detail = f"Ollama: {res.get('error', 'Unknown error')}"
  151. except Exception:
  152. detail = f"Ollama: {e}"
  153. raise HTTPException(
  154. status_code=r.status if r else 500,
  155. detail=detail if detail else "Open WebUI: Server Connection Error",
  156. )
  157. def get_api_key(idx, url, configs):
  158. parsed_url = urlparse(url)
  159. base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
  160. return configs.get(str(idx), configs.get(base_url, {})).get(
  161. "key", None
  162. ) # Legacy support
  163. ##########################################
  164. #
  165. # API routes
  166. #
  167. ##########################################
# Router on which the endpoints defined below are registered.
router = APIRouter()
  169. @router.head("/")
  170. @router.get("/")
  171. async def get_status():
  172. return {"status": True}
class ConnectionVerificationForm(BaseModel):
    # Request body for POST /verify.
    # Base URL of the server to probe; "/api/version" is appended to it.
    url: str
    # Optional bearer token sent in the Authorization header.
    key: Optional[str] = None
  176. @router.post("/verify")
  177. async def verify_connection(
  178. form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
  179. ):
  180. url = form_data.url
  181. key = form_data.key
  182. async with aiohttp.ClientSession(
  183. timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
  184. ) as session:
  185. try:
  186. async with session.get(
  187. f"{url}/api/version",
  188. headers={
  189. **({"Authorization": f"Bearer {key}"} if key else {}),
  190. **(
  191. {
  192. "X-OpenWebUI-User-Name": user.name,
  193. "X-OpenWebUI-User-Id": user.id,
  194. "X-OpenWebUI-User-Email": user.email,
  195. "X-OpenWebUI-User-Role": user.role,
  196. }
  197. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  198. else {}
  199. ),
  200. },
  201. ) as r:
  202. if r.status != 200:
  203. detail = f"HTTP Error: {r.status}"
  204. res = await r.json()
  205. if "error" in res:
  206. detail = f"External Error: {res['error']}"
  207. raise Exception(detail)
  208. data = await r.json()
  209. return data
  210. except aiohttp.ClientError as e:
  211. log.exception(f"Client error: {str(e)}")
  212. raise HTTPException(
  213. status_code=500, detail="Open WebUI: Server Connection Error"
  214. )
  215. except Exception as e:
  216. log.exception(f"Unexpected error: {e}")
  217. error_detail = f"Unexpected error: {str(e)}"
  218. raise HTTPException(status_code=500, detail=error_detail)
  219. @router.get("/config")
  220. async def get_config(request: Request, user=Depends(get_admin_user)):
  221. return {
  222. "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
  223. "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
  224. "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
  225. }
class OllamaConfigForm(BaseModel):
    # Request body for POST /config/update.
    ENABLE_OLLAMA_API: Optional[bool] = None
    OLLAMA_BASE_URLS: list[str]
    # Per-backend settings; update_config keeps only the entries whose key is
    # the string index of a URL in OLLAMA_BASE_URLS.
    OLLAMA_API_CONFIGS: dict
  230. @router.post("/config/update")
  231. async def update_config(
  232. request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user)
  233. ):
  234. request.app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
  235. request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
  236. request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
  237. # Remove the API configs that are not in the API URLS
  238. keys = list(map(str, range(len(request.app.state.config.OLLAMA_BASE_URLS))))
  239. request.app.state.config.OLLAMA_API_CONFIGS = {
  240. key: value
  241. for key, value in request.app.state.config.OLLAMA_API_CONFIGS.items()
  242. if key in keys
  243. }
  244. return {
  245. "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
  246. "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
  247. "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
  248. }
  249. @cached(ttl=3)
  250. async def get_all_models(request: Request, user: UserModel = None):
  251. log.info("get_all_models()")
  252. if request.app.state.config.ENABLE_OLLAMA_API:
  253. request_tasks = []
  254. for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
  255. if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
  256. url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
  257. ):
  258. request_tasks.append(send_get_request(f"{url}/api/tags", user=user))
  259. else:
  260. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  261. str(idx),
  262. request.app.state.config.OLLAMA_API_CONFIGS.get(
  263. url, {}
  264. ), # Legacy support
  265. )
  266. enable = api_config.get("enable", True)
  267. key = api_config.get("key", None)
  268. if enable:
  269. request_tasks.append(
  270. send_get_request(f"{url}/api/tags", key, user=user)
  271. )
  272. else:
  273. request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
  274. responses = await asyncio.gather(*request_tasks)
  275. for idx, response in enumerate(responses):
  276. if response:
  277. url = request.app.state.config.OLLAMA_BASE_URLS[idx]
  278. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  279. str(idx),
  280. request.app.state.config.OLLAMA_API_CONFIGS.get(
  281. url, {}
  282. ), # Legacy support
  283. )
  284. prefix_id = api_config.get("prefix_id", None)
  285. model_ids = api_config.get("model_ids", [])
  286. if len(model_ids) != 0 and "models" in response:
  287. response["models"] = list(
  288. filter(
  289. lambda model: model["model"] in model_ids,
  290. response["models"],
  291. )
  292. )
  293. if prefix_id:
  294. for model in response.get("models", []):
  295. model["model"] = f"{prefix_id}.{model['model']}"
  296. def merge_models_lists(model_lists):
  297. merged_models = {}
  298. for idx, model_list in enumerate(model_lists):
  299. if model_list is not None:
  300. for model in model_list:
  301. id = model["model"]
  302. if id not in merged_models:
  303. model["urls"] = [idx]
  304. merged_models[id] = model
  305. else:
  306. merged_models[id]["urls"].append(idx)
  307. return list(merged_models.values())
  308. models = {
  309. "models": merge_models_lists(
  310. map(
  311. lambda response: response.get("models", []) if response else None,
  312. responses,
  313. )
  314. )
  315. }
  316. else:
  317. models = {"models": []}
  318. request.app.state.OLLAMA_MODELS = {
  319. model["model"]: model for model in models["models"]
  320. }
  321. return models
  322. async def get_filtered_models(models, user):
  323. # Filter models based on user access control
  324. filtered_models = []
  325. for model in models.get("models", []):
  326. model_info = Models.get_model_by_id(model["model"])
  327. if model_info:
  328. if user.id == model_info.user_id or has_access(
  329. user.id, type="read", access_control=model_info.access_control
  330. ):
  331. filtered_models.append(model)
  332. return filtered_models
  333. @router.get("/api/tags")
  334. @router.get("/api/tags/{url_idx}")
  335. async def get_ollama_tags(
  336. request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)
  337. ):
  338. models = []
  339. if url_idx is None:
  340. models = await get_all_models(request, user=user)
  341. else:
  342. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  343. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  344. r = None
  345. try:
  346. r = requests.request(
  347. method="GET",
  348. url=f"{url}/api/tags",
  349. headers={
  350. **({"Authorization": f"Bearer {key}"} if key else {}),
  351. **(
  352. {
  353. "X-OpenWebUI-User-Name": user.name,
  354. "X-OpenWebUI-User-Id": user.id,
  355. "X-OpenWebUI-User-Email": user.email,
  356. "X-OpenWebUI-User-Role": user.role,
  357. }
  358. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  359. else {}
  360. ),
  361. },
  362. )
  363. r.raise_for_status()
  364. models = r.json()
  365. except Exception as e:
  366. log.exception(e)
  367. detail = None
  368. if r is not None:
  369. try:
  370. res = r.json()
  371. if "error" in res:
  372. detail = f"Ollama: {res['error']}"
  373. except Exception:
  374. detail = f"Ollama: {e}"
  375. raise HTTPException(
  376. status_code=r.status_code if r else 500,
  377. detail=detail if detail else "Open WebUI: Server Connection Error",
  378. )
  379. if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
  380. models["models"] = await get_filtered_models(models, user)
  381. return models
  382. @router.get("/api/version")
  383. @router.get("/api/version/{url_idx}")
  384. async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
  385. if request.app.state.config.ENABLE_OLLAMA_API:
  386. if url_idx is None:
  387. # returns lowest version
  388. request_tasks = [
  389. send_get_request(
  390. f"{url}/api/version",
  391. request.app.state.config.OLLAMA_API_CONFIGS.get(
  392. str(idx),
  393. request.app.state.config.OLLAMA_API_CONFIGS.get(
  394. url, {}
  395. ), # Legacy support
  396. ).get("key", None),
  397. )
  398. for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
  399. ]
  400. responses = await asyncio.gather(*request_tasks)
  401. responses = list(filter(lambda x: x is not None, responses))
  402. if len(responses) > 0:
  403. lowest_version = min(
  404. responses,
  405. key=lambda x: tuple(
  406. map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
  407. ),
  408. )
  409. return {"version": lowest_version["version"]}
  410. else:
  411. raise HTTPException(
  412. status_code=500,
  413. detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
  414. )
  415. else:
  416. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  417. r = None
  418. try:
  419. r = requests.request(method="GET", url=f"{url}/api/version")
  420. r.raise_for_status()
  421. return r.json()
  422. except Exception as e:
  423. log.exception(e)
  424. detail = None
  425. if r is not None:
  426. try:
  427. res = r.json()
  428. if "error" in res:
  429. detail = f"Ollama: {res['error']}"
  430. except Exception:
  431. detail = f"Ollama: {e}"
  432. raise HTTPException(
  433. status_code=r.status_code if r else 500,
  434. detail=detail if detail else "Open WebUI: Server Connection Error",
  435. )
  436. else:
  437. return {"version": False}
  438. @router.get("/api/ps")
  439. async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
  440. """
  441. List models that are currently loaded into Ollama memory, and which node they are loaded on.
  442. """
  443. if request.app.state.config.ENABLE_OLLAMA_API:
  444. request_tasks = [
  445. send_get_request(
  446. f"{url}/api/ps",
  447. request.app.state.config.OLLAMA_API_CONFIGS.get(
  448. str(idx),
  449. request.app.state.config.OLLAMA_API_CONFIGS.get(
  450. url, {}
  451. ), # Legacy support
  452. ).get("key", None),
  453. user=user,
  454. )
  455. for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
  456. ]
  457. responses = await asyncio.gather(*request_tasks)
  458. return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
  459. else:
  460. return {}
class ModelNameForm(BaseModel):
    # Request body carrying a single model name; used by the pull, delete and
    # show endpoints.
    name: str
  463. @router.post("/api/pull")
  464. @router.post("/api/pull/{url_idx}")
  465. async def pull_model(
  466. request: Request,
  467. form_data: ModelNameForm,
  468. url_idx: int = 0,
  469. user=Depends(get_admin_user),
  470. ):
  471. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  472. log.info(f"url: {url}")
  473. # Admin should be able to pull models from any source
  474. payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
  475. return await send_post_request(
  476. url=f"{url}/api/pull",
  477. payload=json.dumps(payload),
  478. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  479. user=user,
  480. )
class PushModelForm(BaseModel):
    # Request body for the push endpoint; fields left as None are omitted from
    # the payload forwarded upstream (push_model dumps with exclude_none=True).
    name: str
    insecure: Optional[bool] = None
    stream: Optional[bool] = None
  485. @router.delete("/api/push")
  486. @router.delete("/api/push/{url_idx}")
  487. async def push_model(
  488. request: Request,
  489. form_data: PushModelForm,
  490. url_idx: Optional[int] = None,
  491. user=Depends(get_admin_user),
  492. ):
  493. if url_idx is None:
  494. await get_all_models(request, user=user)
  495. models = request.app.state.OLLAMA_MODELS
  496. if form_data.name in models:
  497. url_idx = models[form_data.name]["urls"][0]
  498. else:
  499. raise HTTPException(
  500. status_code=400,
  501. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
  502. )
  503. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  504. log.debug(f"url: {url}")
  505. return await send_post_request(
  506. url=f"{url}/api/push",
  507. payload=form_data.model_dump_json(exclude_none=True).encode(),
  508. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  509. user=user,
  510. )
class CreateModelForm(BaseModel):
    # Request body for the create endpoint.
    model: Optional[str] = None
    stream: Optional[bool] = None
    path: Optional[str] = None

    # Fields beyond the three declared above are accepted rather than rejected.
    model_config = ConfigDict(extra="allow")
  516. @router.post("/api/create")
  517. @router.post("/api/create/{url_idx}")
  518. async def create_model(
  519. request: Request,
  520. form_data: CreateModelForm,
  521. url_idx: int = 0,
  522. user=Depends(get_admin_user),
  523. ):
  524. log.debug(f"form_data: {form_data}")
  525. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  526. return await send_post_request(
  527. url=f"{url}/api/create",
  528. payload=form_data.model_dump_json(exclude_none=True).encode(),
  529. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  530. user=user,
  531. )
class CopyModelForm(BaseModel):
    # Request body for the copy endpoint.
    # `source` is also used to pick the backend when no url_idx is given.
    source: str
    destination: str
  535. @router.post("/api/copy")
  536. @router.post("/api/copy/{url_idx}")
  537. async def copy_model(
  538. request: Request,
  539. form_data: CopyModelForm,
  540. url_idx: Optional[int] = None,
  541. user=Depends(get_admin_user),
  542. ):
  543. if url_idx is None:
  544. await get_all_models(request, user=user)
  545. models = request.app.state.OLLAMA_MODELS
  546. if form_data.source in models:
  547. url_idx = models[form_data.source]["urls"][0]
  548. else:
  549. raise HTTPException(
  550. status_code=400,
  551. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
  552. )
  553. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  554. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  555. try:
  556. r = requests.request(
  557. method="POST",
  558. url=f"{url}/api/copy",
  559. headers={
  560. "Content-Type": "application/json",
  561. **({"Authorization": f"Bearer {key}"} if key else {}),
  562. **(
  563. {
  564. "X-OpenWebUI-User-Name": user.name,
  565. "X-OpenWebUI-User-Id": user.id,
  566. "X-OpenWebUI-User-Email": user.email,
  567. "X-OpenWebUI-User-Role": user.role,
  568. }
  569. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  570. else {}
  571. ),
  572. },
  573. data=form_data.model_dump_json(exclude_none=True).encode(),
  574. )
  575. r.raise_for_status()
  576. log.debug(f"r.text: {r.text}")
  577. return True
  578. except Exception as e:
  579. log.exception(e)
  580. detail = None
  581. if r is not None:
  582. try:
  583. res = r.json()
  584. if "error" in res:
  585. detail = f"Ollama: {res['error']}"
  586. except Exception:
  587. detail = f"Ollama: {e}"
  588. raise HTTPException(
  589. status_code=r.status_code if r else 500,
  590. detail=detail if detail else "Open WebUI: Server Connection Error",
  591. )
  592. @router.delete("/api/delete")
  593. @router.delete("/api/delete/{url_idx}")
  594. async def delete_model(
  595. request: Request,
  596. form_data: ModelNameForm,
  597. url_idx: Optional[int] = None,
  598. user=Depends(get_admin_user),
  599. ):
  600. if url_idx is None:
  601. await get_all_models(request, user=user)
  602. models = request.app.state.OLLAMA_MODELS
  603. if form_data.name in models:
  604. url_idx = models[form_data.name]["urls"][0]
  605. else:
  606. raise HTTPException(
  607. status_code=400,
  608. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
  609. )
  610. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  611. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  612. try:
  613. r = requests.request(
  614. method="DELETE",
  615. url=f"{url}/api/delete",
  616. data=form_data.model_dump_json(exclude_none=True).encode(),
  617. headers={
  618. "Content-Type": "application/json",
  619. **({"Authorization": f"Bearer {key}"} if key else {}),
  620. **(
  621. {
  622. "X-OpenWebUI-User-Name": user.name,
  623. "X-OpenWebUI-User-Id": user.id,
  624. "X-OpenWebUI-User-Email": user.email,
  625. "X-OpenWebUI-User-Role": user.role,
  626. }
  627. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  628. else {}
  629. ),
  630. },
  631. )
  632. r.raise_for_status()
  633. log.debug(f"r.text: {r.text}")
  634. return True
  635. except Exception as e:
  636. log.exception(e)
  637. detail = None
  638. if r is not None:
  639. try:
  640. res = r.json()
  641. if "error" in res:
  642. detail = f"Ollama: {res['error']}"
  643. except Exception:
  644. detail = f"Ollama: {e}"
  645. raise HTTPException(
  646. status_code=r.status_code if r else 500,
  647. detail=detail if detail else "Open WebUI: Server Connection Error",
  648. )
  649. @router.post("/api/show")
  650. async def show_model_info(
  651. request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
  652. ):
  653. await get_all_models(request, user=user)
  654. models = request.app.state.OLLAMA_MODELS
  655. if form_data.name not in models:
  656. raise HTTPException(
  657. status_code=400,
  658. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
  659. )
  660. url_idx = random.choice(models[form_data.name]["urls"])
  661. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  662. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  663. try:
  664. r = requests.request(
  665. method="POST",
  666. url=f"{url}/api/show",
  667. headers={
  668. "Content-Type": "application/json",
  669. **({"Authorization": f"Bearer {key}"} if key else {}),
  670. **(
  671. {
  672. "X-OpenWebUI-User-Name": user.name,
  673. "X-OpenWebUI-User-Id": user.id,
  674. "X-OpenWebUI-User-Email": user.email,
  675. "X-OpenWebUI-User-Role": user.role,
  676. }
  677. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  678. else {}
  679. ),
  680. },
  681. data=form_data.model_dump_json(exclude_none=True).encode(),
  682. )
  683. r.raise_for_status()
  684. return r.json()
  685. except Exception as e:
  686. log.exception(e)
  687. detail = None
  688. if r is not None:
  689. try:
  690. res = r.json()
  691. if "error" in res:
  692. detail = f"Ollama: {res['error']}"
  693. except Exception:
  694. detail = f"Ollama: {e}"
  695. raise HTTPException(
  696. status_code=r.status_code if r else 500,
  697. detail=detail if detail else "Open WebUI: Server Connection Error",
  698. )
class GenerateEmbedForm(BaseModel):
    # Request body for the /api/embed endpoint; fields left as None are
    # omitted from the payload forwarded upstream.
    model: str
    # A single string or a list of strings.
    input: list[str] | str
    truncate: Optional[bool] = None
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None
  705. @router.post("/api/embed")
  706. @router.post("/api/embed/{url_idx}")
  707. async def embed(
  708. request: Request,
  709. form_data: GenerateEmbedForm,
  710. url_idx: Optional[int] = None,
  711. user=Depends(get_verified_user),
  712. ):
  713. log.info(f"generate_ollama_batch_embeddings {form_data}")
  714. if url_idx is None:
  715. await get_all_models(request, user=user)
  716. models = request.app.state.OLLAMA_MODELS
  717. model = form_data.model
  718. if ":" not in model:
  719. model = f"{model}:latest"
  720. if model in models:
  721. url_idx = random.choice(models[model]["urls"])
  722. else:
  723. raise HTTPException(
  724. status_code=400,
  725. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
  726. )
  727. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  728. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  729. try:
  730. r = requests.request(
  731. method="POST",
  732. url=f"{url}/api/embed",
  733. headers={
  734. "Content-Type": "application/json",
  735. **({"Authorization": f"Bearer {key}"} if key else {}),
  736. **(
  737. {
  738. "X-OpenWebUI-User-Name": user.name,
  739. "X-OpenWebUI-User-Id": user.id,
  740. "X-OpenWebUI-User-Email": user.email,
  741. "X-OpenWebUI-User-Role": user.role,
  742. }
  743. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  744. else {}
  745. ),
  746. },
  747. data=form_data.model_dump_json(exclude_none=True).encode(),
  748. )
  749. r.raise_for_status()
  750. data = r.json()
  751. return data
  752. except Exception as e:
  753. log.exception(e)
  754. detail = None
  755. if r is not None:
  756. try:
  757. res = r.json()
  758. if "error" in res:
  759. detail = f"Ollama: {res['error']}"
  760. except Exception:
  761. detail = f"Ollama: {e}"
  762. raise HTTPException(
  763. status_code=r.status_code if r else 500,
  764. detail=detail if detail else "Open WebUI: Server Connection Error",
  765. )
class GenerateEmbeddingsForm(BaseModel):
    # Request body for the /api/embeddings endpoint; fields left as None are
    # omitted from the payload forwarded upstream.
    model: str
    prompt: str
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None
  771. @router.post("/api/embeddings")
  772. @router.post("/api/embeddings/{url_idx}")
  773. async def embeddings(
  774. request: Request,
  775. form_data: GenerateEmbeddingsForm,
  776. url_idx: Optional[int] = None,
  777. user=Depends(get_verified_user),
  778. ):
  779. log.info(f"generate_ollama_embeddings {form_data}")
  780. if url_idx is None:
  781. await get_all_models(request, user=user)
  782. models = request.app.state.OLLAMA_MODELS
  783. model = form_data.model
  784. if ":" not in model:
  785. model = f"{model}:latest"
  786. if model in models:
  787. url_idx = random.choice(models[model]["urls"])
  788. else:
  789. raise HTTPException(
  790. status_code=400,
  791. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
  792. )
  793. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  794. key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
  795. try:
  796. r = requests.request(
  797. method="POST",
  798. url=f"{url}/api/embeddings",
  799. headers={
  800. "Content-Type": "application/json",
  801. **({"Authorization": f"Bearer {key}"} if key else {}),
  802. **(
  803. {
  804. "X-OpenWebUI-User-Name": user.name,
  805. "X-OpenWebUI-User-Id": user.id,
  806. "X-OpenWebUI-User-Email": user.email,
  807. "X-OpenWebUI-User-Role": user.role,
  808. }
  809. if ENABLE_FORWARD_USER_INFO_HEADERS and user
  810. else {}
  811. ),
  812. },
  813. data=form_data.model_dump_json(exclude_none=True).encode(),
  814. )
  815. r.raise_for_status()
  816. data = r.json()
  817. return data
  818. except Exception as e:
  819. log.exception(e)
  820. detail = None
  821. if r is not None:
  822. try:
  823. res = r.json()
  824. if "error" in res:
  825. detail = f"Ollama: {res['error']}"
  826. except Exception:
  827. detail = f"Ollama: {e}"
  828. raise HTTPException(
  829. status_code=r.status_code if r else 500,
  830. detail=detail if detail else "Open WebUI: Server Connection Error",
  831. )
  832. class GenerateCompletionForm(BaseModel):
  833. model: str
  834. prompt: str
  835. suffix: Optional[str] = None
  836. images: Optional[list[str]] = None
  837. format: Optional[str] = None
  838. options: Optional[dict] = None
  839. system: Optional[str] = None
  840. template: Optional[str] = None
  841. context: Optional[list[int]] = None
  842. stream: Optional[bool] = True
  843. raw: Optional[bool] = None
  844. keep_alive: Optional[Union[int, str]] = None
  845. @router.post("/api/generate")
  846. @router.post("/api/generate/{url_idx}")
  847. async def generate_completion(
  848. request: Request,
  849. form_data: GenerateCompletionForm,
  850. url_idx: Optional[int] = None,
  851. user=Depends(get_verified_user),
  852. ):
  853. if url_idx is None:
  854. await get_all_models(request, user=user)
  855. models = request.app.state.OLLAMA_MODELS
  856. model = form_data.model
  857. if ":" not in model:
  858. model = f"{model}:latest"
  859. if model in models:
  860. url_idx = random.choice(models[model]["urls"])
  861. else:
  862. raise HTTPException(
  863. status_code=400,
  864. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
  865. )
  866. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  867. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  868. str(url_idx),
  869. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  870. )
  871. prefix_id = api_config.get("prefix_id", None)
  872. if prefix_id:
  873. form_data.model = form_data.model.replace(f"{prefix_id}.", "")
  874. return await send_post_request(
  875. url=f"{url}/api/generate",
  876. payload=form_data.model_dump_json(exclude_none=True).encode(),
  877. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  878. user=user,
  879. )
  880. class ChatMessage(BaseModel):
  881. role: str
  882. content: Optional[str] = None
  883. tool_calls: Optional[list[dict]] = None
  884. images: Optional[list[str]] = None
  885. @validator("content", pre=True)
  886. @classmethod
  887. def check_at_least_one_field(cls, field_value, values, **kwargs):
  888. # Raise an error if both 'content' and 'tool_calls' are None
  889. if field_value is None and (
  890. "tool_calls" not in values or values["tool_calls"] is None
  891. ):
  892. raise ValueError(
  893. "At least one of 'content' or 'tool_calls' must be provided"
  894. )
  895. return field_value
  896. class GenerateChatCompletionForm(BaseModel):
  897. model: str
  898. messages: list[ChatMessage]
  899. format: Optional[Union[dict, str]] = None
  900. options: Optional[dict] = None
  901. template: Optional[str] = None
  902. stream: Optional[bool] = True
  903. keep_alive: Optional[Union[int, str]] = None
  904. tools: Optional[list[dict]] = None
  905. async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None):
  906. if url_idx is None:
  907. models = request.app.state.OLLAMA_MODELS
  908. if model not in models:
  909. raise HTTPException(
  910. status_code=400,
  911. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
  912. )
  913. url_idx = random.choice(models[model].get("urls", []))
  914. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  915. return url, url_idx
  916. @router.post("/api/chat")
  917. @router.post("/api/chat/{url_idx}")
  918. async def generate_chat_completion(
  919. request: Request,
  920. form_data: dict,
  921. url_idx: Optional[int] = None,
  922. user=Depends(get_verified_user),
  923. bypass_filter: Optional[bool] = False,
  924. ):
  925. if BYPASS_MODEL_ACCESS_CONTROL:
  926. bypass_filter = True
  927. metadata = form_data.pop("metadata", None)
  928. try:
  929. form_data = GenerateChatCompletionForm(**form_data)
  930. except Exception as e:
  931. log.exception(e)
  932. raise HTTPException(
  933. status_code=400,
  934. detail=str(e),
  935. )
  936. payload = {**form_data.model_dump(exclude_none=True)}
  937. if "metadata" in payload:
  938. del payload["metadata"]
  939. model_id = payload["model"]
  940. model_info = Models.get_model_by_id(model_id)
  941. if model_info:
  942. if model_info.base_model_id:
  943. payload["model"] = model_info.base_model_id
  944. params = model_info.params.model_dump()
  945. if params:
  946. if payload.get("options") is None:
  947. payload["options"] = {}
  948. payload["options"] = apply_model_params_to_body_ollama(
  949. params, payload["options"]
  950. )
  951. payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
  952. # Check if user has access to the model
  953. if not bypass_filter and user.role == "user":
  954. if not (
  955. user.id == model_info.user_id
  956. or has_access(
  957. user.id, type="read", access_control=model_info.access_control
  958. )
  959. ):
  960. raise HTTPException(
  961. status_code=403,
  962. detail="Model not found",
  963. )
  964. elif not bypass_filter:
  965. if user.role != "admin":
  966. raise HTTPException(
  967. status_code=403,
  968. detail="Model not found",
  969. )
  970. if ":" not in payload["model"]:
  971. payload["model"] = f"{payload['model']}:latest"
  972. url, url_idx = await get_ollama_url(request, payload["model"], url_idx)
  973. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  974. str(url_idx),
  975. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  976. )
  977. prefix_id = api_config.get("prefix_id", None)
  978. if prefix_id:
  979. payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
  980. return await send_post_request(
  981. url=f"{url}/api/chat",
  982. payload=json.dumps(payload),
  983. stream=form_data.stream,
  984. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  985. content_type="application/x-ndjson",
  986. user=user,
  987. )
  988. # TODO: we should update this part once Ollama supports other types
  989. class OpenAIChatMessageContent(BaseModel):
  990. type: str
  991. model_config = ConfigDict(extra="allow")
  992. class OpenAIChatMessage(BaseModel):
  993. role: str
  994. content: Union[str, list[OpenAIChatMessageContent]]
  995. model_config = ConfigDict(extra="allow")
  996. class OpenAIChatCompletionForm(BaseModel):
  997. model: str
  998. messages: list[OpenAIChatMessage]
  999. model_config = ConfigDict(extra="allow")
  1000. class OpenAICompletionForm(BaseModel):
  1001. model: str
  1002. prompt: str
  1003. model_config = ConfigDict(extra="allow")
  1004. @router.post("/v1/completions")
  1005. @router.post("/v1/completions/{url_idx}")
  1006. async def generate_openai_completion(
  1007. request: Request,
  1008. form_data: dict,
  1009. url_idx: Optional[int] = None,
  1010. user=Depends(get_verified_user),
  1011. ):
  1012. try:
  1013. form_data = OpenAICompletionForm(**form_data)
  1014. except Exception as e:
  1015. log.exception(e)
  1016. raise HTTPException(
  1017. status_code=400,
  1018. detail=str(e),
  1019. )
  1020. payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])}
  1021. if "metadata" in payload:
  1022. del payload["metadata"]
  1023. model_id = form_data.model
  1024. if ":" not in model_id:
  1025. model_id = f"{model_id}:latest"
  1026. model_info = Models.get_model_by_id(model_id)
  1027. if model_info:
  1028. if model_info.base_model_id:
  1029. payload["model"] = model_info.base_model_id
  1030. params = model_info.params.model_dump()
  1031. if params:
  1032. payload = apply_model_params_to_body_openai(params, payload)
  1033. # Check if user has access to the model
  1034. if user.role == "user":
  1035. if not (
  1036. user.id == model_info.user_id
  1037. or has_access(
  1038. user.id, type="read", access_control=model_info.access_control
  1039. )
  1040. ):
  1041. raise HTTPException(
  1042. status_code=403,
  1043. detail="Model not found",
  1044. )
  1045. else:
  1046. if user.role != "admin":
  1047. raise HTTPException(
  1048. status_code=403,
  1049. detail="Model not found",
  1050. )
  1051. if ":" not in payload["model"]:
  1052. payload["model"] = f"{payload['model']}:latest"
  1053. url, url_idx = await get_ollama_url(request, payload["model"], url_idx)
  1054. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  1055. str(url_idx),
  1056. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  1057. )
  1058. prefix_id = api_config.get("prefix_id", None)
  1059. if prefix_id:
  1060. payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
  1061. return await send_post_request(
  1062. url=f"{url}/v1/completions",
  1063. payload=json.dumps(payload),
  1064. stream=payload.get("stream", False),
  1065. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  1066. user=user,
  1067. )
  1068. @router.post("/v1/chat/completions")
  1069. @router.post("/v1/chat/completions/{url_idx}")
  1070. async def generate_openai_chat_completion(
  1071. request: Request,
  1072. form_data: dict,
  1073. url_idx: Optional[int] = None,
  1074. user=Depends(get_verified_user),
  1075. ):
  1076. metadata = form_data.pop("metadata", None)
  1077. try:
  1078. completion_form = OpenAIChatCompletionForm(**form_data)
  1079. except Exception as e:
  1080. log.exception(e)
  1081. raise HTTPException(
  1082. status_code=400,
  1083. detail=str(e),
  1084. )
  1085. payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
  1086. if "metadata" in payload:
  1087. del payload["metadata"]
  1088. model_id = completion_form.model
  1089. if ":" not in model_id:
  1090. model_id = f"{model_id}:latest"
  1091. model_info = Models.get_model_by_id(model_id)
  1092. if model_info:
  1093. if model_info.base_model_id:
  1094. payload["model"] = model_info.base_model_id
  1095. params = model_info.params.model_dump()
  1096. if params:
  1097. payload = apply_model_params_to_body_openai(params, payload)
  1098. payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
  1099. # Check if user has access to the model
  1100. if user.role == "user":
  1101. if not (
  1102. user.id == model_info.user_id
  1103. or has_access(
  1104. user.id, type="read", access_control=model_info.access_control
  1105. )
  1106. ):
  1107. raise HTTPException(
  1108. status_code=403,
  1109. detail="Model not found",
  1110. )
  1111. else:
  1112. if user.role != "admin":
  1113. raise HTTPException(
  1114. status_code=403,
  1115. detail="Model not found",
  1116. )
  1117. if ":" not in payload["model"]:
  1118. payload["model"] = f"{payload['model']}:latest"
  1119. url, url_idx = await get_ollama_url(request, payload["model"], url_idx)
  1120. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  1121. str(url_idx),
  1122. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  1123. )
  1124. prefix_id = api_config.get("prefix_id", None)
  1125. if prefix_id:
  1126. payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
  1127. return await send_post_request(
  1128. url=f"{url}/v1/chat/completions",
  1129. payload=json.dumps(payload),
  1130. stream=payload.get("stream", False),
  1131. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  1132. user=user,
  1133. )
  1134. @router.get("/v1/models")
  1135. @router.get("/v1/models/{url_idx}")
  1136. async def get_openai_models(
  1137. request: Request,
  1138. url_idx: Optional[int] = None,
  1139. user=Depends(get_verified_user),
  1140. ):
  1141. models = []
  1142. if url_idx is None:
  1143. model_list = await get_all_models(request, user=user)
  1144. models = [
  1145. {
  1146. "id": model["model"],
  1147. "object": "model",
  1148. "created": int(time.time()),
  1149. "owned_by": "openai",
  1150. }
  1151. for model in model_list["models"]
  1152. ]
  1153. else:
  1154. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  1155. try:
  1156. r = requests.request(method="GET", url=f"{url}/api/tags")
  1157. r.raise_for_status()
  1158. model_list = r.json()
  1159. models = [
  1160. {
  1161. "id": model["model"],
  1162. "object": "model",
  1163. "created": int(time.time()),
  1164. "owned_by": "openai",
  1165. }
  1166. for model in models["models"]
  1167. ]
  1168. except Exception as e:
  1169. log.exception(e)
  1170. error_detail = "Open WebUI: Server Connection Error"
  1171. if r is not None:
  1172. try:
  1173. res = r.json()
  1174. if "error" in res:
  1175. error_detail = f"Ollama: {res['error']}"
  1176. except Exception:
  1177. error_detail = f"Ollama: {e}"
  1178. raise HTTPException(
  1179. status_code=r.status_code if r else 500,
  1180. detail=error_detail,
  1181. )
  1182. if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
  1183. # Filter models based on user access control
  1184. filtered_models = []
  1185. for model in models:
  1186. model_info = Models.get_model_by_id(model["id"])
  1187. if model_info:
  1188. if user.id == model_info.user_id or has_access(
  1189. user.id, type="read", access_control=model_info.access_control
  1190. ):
  1191. filtered_models.append(model)
  1192. models = filtered_models
  1193. return {
  1194. "data": models,
  1195. "object": "list",
  1196. }
  1197. class UrlForm(BaseModel):
  1198. url: str
  1199. class UploadBlobForm(BaseModel):
  1200. filename: str
  1201. def parse_huggingface_url(hf_url):
  1202. try:
  1203. # Parse the URL
  1204. parsed_url = urlparse(hf_url)
  1205. # Get the path and split it into components
  1206. path_components = parsed_url.path.split("/")
  1207. # Extract the desired output
  1208. model_file = path_components[-1]
  1209. return model_file
  1210. except ValueError:
  1211. return None
  1212. async def download_file_stream(
  1213. ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
  1214. ):
  1215. done = False
  1216. if os.path.exists(file_path):
  1217. current_size = os.path.getsize(file_path)
  1218. else:
  1219. current_size = 0
  1220. headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
  1221. timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
  1222. async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
  1223. async with session.get(file_url, headers=headers) as response:
  1224. total_size = int(response.headers.get("content-length", 0)) + current_size
  1225. with open(file_path, "ab+") as file:
  1226. async for data in response.content.iter_chunked(chunk_size):
  1227. current_size += len(data)
  1228. file.write(data)
  1229. done = current_size == total_size
  1230. progress = round((current_size / total_size) * 100, 2)
  1231. yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
  1232. if done:
  1233. file.seek(0)
  1234. hashed = calculate_sha256(file)
  1235. file.seek(0)
  1236. url = f"{ollama_url}/api/blobs/sha256:{hashed}"
  1237. response = requests.post(url, data=file)
  1238. if response.ok:
  1239. res = {
  1240. "done": done,
  1241. "blob": f"sha256:{hashed}",
  1242. "name": file_name,
  1243. }
  1244. os.remove(file_path)
  1245. yield f"data: {json.dumps(res)}\n\n"
  1246. else:
  1247. raise "Ollama: Could not create blob, Please try again."
  1248. # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
  1249. @router.post("/models/download")
  1250. @router.post("/models/download/{url_idx}")
  1251. async def download_model(
  1252. request: Request,
  1253. form_data: UrlForm,
  1254. url_idx: Optional[int] = None,
  1255. user=Depends(get_admin_user),
  1256. ):
  1257. allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
  1258. if not any(form_data.url.startswith(host) for host in allowed_hosts):
  1259. raise HTTPException(
  1260. status_code=400,
  1261. detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
  1262. )
  1263. if url_idx is None:
  1264. url_idx = 0
  1265. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  1266. file_name = parse_huggingface_url(form_data.url)
  1267. if file_name:
  1268. file_path = f"{UPLOAD_DIR}/{file_name}"
  1269. return StreamingResponse(
  1270. download_file_stream(url, form_data.url, file_path, file_name),
  1271. )
  1272. else:
  1273. return None
  1274. # TODO: Progress bar does not reflect size & duration of upload.
  1275. @router.post("/models/upload")
  1276. @router.post("/models/upload/{url_idx}")
  1277. async def upload_model(
  1278. request: Request,
  1279. file: UploadFile = File(...),
  1280. url_idx: Optional[int] = None,
  1281. user=Depends(get_admin_user),
  1282. ):
  1283. if url_idx is None:
  1284. url_idx = 0
  1285. ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  1286. file_path = os.path.join(UPLOAD_DIR, file.filename)
  1287. os.makedirs(UPLOAD_DIR, exist_ok=True)
  1288. # --- P1: save file locally ---
  1289. chunk_size = 1024 * 1024 * 2 # 2 MB chunks
  1290. with open(file_path, "wb") as out_f:
  1291. while True:
  1292. chunk = file.file.read(chunk_size)
  1293. # log.info(f"Chunk: {str(chunk)}") # DEBUG
  1294. if not chunk:
  1295. break
  1296. out_f.write(chunk)
  1297. async def file_process_stream():
  1298. nonlocal ollama_url
  1299. total_size = os.path.getsize(file_path)
  1300. log.info(f"Total Model Size: {str(total_size)}") # DEBUG
  1301. # --- P2: SSE progress + calculate sha256 hash ---
  1302. file_hash = calculate_sha256(file_path, chunk_size)
  1303. log.info(f"Model Hash: {str(file_hash)}") # DEBUG
  1304. try:
  1305. with open(file_path, "rb") as f:
  1306. bytes_read = 0
  1307. while chunk := f.read(chunk_size):
  1308. bytes_read += len(chunk)
  1309. progress = round(bytes_read / total_size * 100, 2)
  1310. data_msg = {
  1311. "progress": progress,
  1312. "total": total_size,
  1313. "completed": bytes_read,
  1314. }
  1315. yield f"data: {json.dumps(data_msg)}\n\n"
  1316. # --- P3: Upload to ollama /api/blobs ---
  1317. with open(file_path, "rb") as f:
  1318. url = f"{ollama_url}/api/blobs/sha256:{file_hash}"
  1319. response = requests.post(url, data=f)
  1320. if response.ok:
  1321. log.info(f"Uploaded to /api/blobs") # DEBUG
  1322. # Remove local file
  1323. os.remove(file_path)
  1324. # Create model in ollama
  1325. model_name, ext = os.path.splitext(file.filename)
  1326. log.info(f"Created Model: {model_name}") # DEBUG
  1327. create_payload = {
  1328. "model": model_name,
  1329. # Reference the file by its original name => the uploaded blob's digest
  1330. "files": {file.filename: f"sha256:{file_hash}"},
  1331. }
  1332. log.info(f"Model Payload: {create_payload}") # DEBUG
  1333. # Call ollama /api/create
  1334. # https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model
  1335. create_resp = requests.post(
  1336. url=f"{ollama_url}/api/create",
  1337. headers={"Content-Type": "application/json"},
  1338. data=json.dumps(create_payload),
  1339. )
  1340. if create_resp.ok:
  1341. log.info(f"API SUCCESS!") # DEBUG
  1342. done_msg = {
  1343. "done": True,
  1344. "blob": f"sha256:{file_hash}",
  1345. "name": file.filename,
  1346. "model_created": model_name,
  1347. }
  1348. yield f"data: {json.dumps(done_msg)}\n\n"
  1349. else:
  1350. raise Exception(
  1351. f"Failed to create model in Ollama. {create_resp.text}"
  1352. )
  1353. else:
  1354. raise Exception("Ollama: Could not create blob, Please try again.")
  1355. except Exception as e:
  1356. res = {"error": str(e)}
  1357. yield f"data: {json.dumps(res)}\n\n"
  1358. return StreamingResponse(file_process_stream(), media_type="text/event-stream")