ollama.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428
# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
# The current implementation picks a backend at random (random.choice), which is not round-robin. Consider incorporating algorithms like
# weighted round-robin, least connections, or least response time for better resource utilization and performance optimization.
  4. import asyncio
  5. import json
  6. import logging
  7. import os
  8. import random
  9. import re
  10. import time
  11. from typing import Optional, Union
  12. from urllib.parse import urlparse
  13. import aiohttp
  14. from aiocache import cached
  15. import requests
  16. from fastapi import (
  17. Depends,
  18. FastAPI,
  19. File,
  20. HTTPException,
  21. Request,
  22. UploadFile,
  23. APIRouter,
  24. )
  25. from fastapi.middleware.cors import CORSMiddleware
  26. from fastapi.responses import StreamingResponse
  27. from pydantic import BaseModel, ConfigDict
  28. from starlette.background import BackgroundTask
  29. from open_webui.models.models import Models
  30. from open_webui.utils.misc import (
  31. calculate_sha256,
  32. )
  33. from open_webui.utils.payload import (
  34. apply_model_params_to_body_ollama,
  35. apply_model_params_to_body_openai,
  36. apply_model_system_prompt_to_body,
  37. )
  38. from open_webui.utils.auth import get_admin_user, get_verified_user
  39. from open_webui.utils.access_control import has_access
  40. from open_webui.config import (
  41. UPLOAD_DIR,
  42. )
  43. from open_webui.env import (
  44. ENV,
  45. SRC_LOG_LEVELS,
  46. AIOHTTP_CLIENT_TIMEOUT,
  47. AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
  48. BYPASS_MODEL_ACCESS_CONTROL,
  49. )
  50. from open_webui.constants import ERROR_MESSAGES
# Module-level logger; its level is taken from the "OLLAMA" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
  53. ##########################################
  54. #
  55. # Utility functions
  56. #
  57. ##########################################
  58. async def send_get_request(url, key=None):
  59. timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
  60. try:
  61. async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
  62. async with session.get(
  63. url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
  64. ) as response:
  65. return await response.json()
  66. except Exception as e:
  67. # Handle connection error here
  68. log.error(f"Connection error: {e}")
  69. return None
  70. async def send_post_request(
  71. url: str,
  72. payload: Union[str, bytes],
  73. stream: bool = True,
  74. key: Optional[str] = None,
  75. content_type: Optional[str] = None,
  76. ):
  77. async def cleanup_response(
  78. response: Optional[aiohttp.ClientResponse],
  79. session: Optional[aiohttp.ClientSession],
  80. ):
  81. if response:
  82. response.close()
  83. if session:
  84. await session.close()
  85. r = None
  86. try:
  87. session = aiohttp.ClientSession(
  88. trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
  89. )
  90. r = await session.post(
  91. url,
  92. data=payload,
  93. headers={
  94. "Content-Type": "application/json",
  95. **({"Authorization": f"Bearer {key}"} if key else {}),
  96. },
  97. )
  98. r.raise_for_status()
  99. if stream:
  100. response_headers = dict(r.headers)
  101. if content_type:
  102. response_headers["Content-Type"] = content_type
  103. return StreamingResponse(
  104. r.content,
  105. status_code=r.status,
  106. headers=response_headers,
  107. background=BackgroundTask(
  108. cleanup_response, response=r, session=session
  109. ),
  110. )
  111. else:
  112. res = await r.json()
  113. await cleanup_response(r, session)
  114. return res
  115. except Exception as e:
  116. detail = None
  117. if r is not None:
  118. try:
  119. res = await r.json()
  120. if "error" in res:
  121. detail = f"Ollama: {res.get('error', 'Unknown error')}"
  122. except Exception:
  123. detail = f"Ollama: {e}"
  124. raise HTTPException(
  125. status_code=r.status if r else 500,
  126. detail=detail if detail else "Open WebUI: Server Connection Error",
  127. )
  128. def get_api_key(url, configs):
  129. parsed_url = urlparse(url)
  130. base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
  131. return configs.get(base_url, {}).get("key", None)
  132. ##########################################
  133. #
  134. # API routes
  135. #
  136. ##########################################
# Router that every endpoint in this module registers itself on.
router = APIRouter()
  138. @router.head("/")
  139. @router.get("/")
  140. async def get_status():
  141. return {"status": True}
class ConnectionVerificationForm(BaseModel):
    """Request body for ``POST /verify``."""

    # Base URL of the server to probe; "/api/version" is appended by the handler.
    url: str
    # Optional bearer token sent as the Authorization header.
    key: Optional[str] = None
  145. @router.post("/verify")
  146. async def verify_connection(
  147. form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
  148. ):
  149. url = form_data.url
  150. key = form_data.key
  151. async with aiohttp.ClientSession(
  152. timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
  153. ) as session:
  154. try:
  155. async with session.get(
  156. f"{url}/api/version",
  157. headers={**({"Authorization": f"Bearer {key}"} if key else {})},
  158. ) as r:
  159. if r.status != 200:
  160. detail = f"HTTP Error: {r.status}"
  161. res = await r.json()
  162. if "error" in res:
  163. detail = f"External Error: {res['error']}"
  164. raise Exception(detail)
  165. data = await r.json()
  166. return data
  167. except aiohttp.ClientError as e:
  168. log.exception(f"Client error: {str(e)}")
  169. raise HTTPException(
  170. status_code=500, detail="Open WebUI: Server Connection Error"
  171. )
  172. except Exception as e:
  173. log.exception(f"Unexpected error: {e}")
  174. error_detail = f"Unexpected error: {str(e)}"
  175. raise HTTPException(status_code=500, detail=error_detail)
  176. @router.get("/config")
  177. async def get_config(request: Request, user=Depends(get_admin_user)):
  178. return {
  179. "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
  180. "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
  181. "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
  182. }
class OllamaConfigForm(BaseModel):
    """Request body for ``POST /config/update``."""

    ENABLE_OLLAMA_API: Optional[bool] = None
    # Backend base URLs; other handlers address them by list index (url_idx).
    OLLAMA_BASE_URLS: list[str]
    # Per-backend settings; this module looks entries up by base URL.
    OLLAMA_API_CONFIGS: dict
  187. @router.post("/config/update")
  188. async def update_config(
  189. request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user)
  190. ):
  191. request.app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
  192. request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
  193. request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
  194. # Remove any extra configs
  195. config_urls = request.app.state.config.OLLAMA_API_CONFIGS.keys()
  196. for url in list(request.app.state.config.OLLAMA_BASE_URLS):
  197. if url not in config_urls:
  198. request.app.state.config.OLLAMA_API_CONFIGS.pop(url, None)
  199. return {
  200. "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
  201. "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
  202. "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
  203. }
  204. @cached(ttl=3)
  205. async def get_all_models(request: Request):
  206. log.info("get_all_models()")
  207. if request.app.state.config.ENABLE_OLLAMA_API:
  208. request_tasks = []
  209. for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
  210. if url not in request.app.state.config.OLLAMA_API_CONFIGS:
  211. request_tasks.append(send_get_request(f"{url}/api/tags"))
  212. else:
  213. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
  214. enable = api_config.get("enable", True)
  215. key = api_config.get("key", None)
  216. if enable:
  217. request_tasks.append(send_get_request(f"{url}/api/tags", key))
  218. else:
  219. request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
  220. responses = await asyncio.gather(*request_tasks)
  221. for idx, response in enumerate(responses):
  222. if response:
  223. url = request.app.state.config.OLLAMA_BASE_URLS[idx]
  224. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
  225. prefix_id = api_config.get("prefix_id", None)
  226. model_ids = api_config.get("model_ids", [])
  227. if len(model_ids) != 0 and "models" in response:
  228. response["models"] = list(
  229. filter(
  230. lambda model: model["model"] in model_ids,
  231. response["models"],
  232. )
  233. )
  234. if prefix_id:
  235. for model in response.get("models", []):
  236. model["model"] = f"{prefix_id}.{model['model']}"
  237. def merge_models_lists(model_lists):
  238. merged_models = {}
  239. for idx, model_list in enumerate(model_lists):
  240. if model_list is not None:
  241. for model in model_list:
  242. id = model["model"]
  243. if id not in merged_models:
  244. model["urls"] = [idx]
  245. merged_models[id] = model
  246. else:
  247. merged_models[id]["urls"].append(idx)
  248. return list(merged_models.values())
  249. models = {
  250. "models": merge_models_lists(
  251. map(
  252. lambda response: response.get("models", []) if response else None,
  253. responses,
  254. )
  255. )
  256. }
  257. else:
  258. models = {"models": []}
  259. return models
  260. async def get_filtered_models(models, user):
  261. # Filter models based on user access control
  262. filtered_models = []
  263. for model in models.get("models", []):
  264. model_info = Models.get_model_by_id(model["model"])
  265. if model_info:
  266. if user.id == model_info.user_id or has_access(
  267. user.id, type="read", access_control=model_info.access_control
  268. ):
  269. filtered_models.append(model)
  270. return filtered_models
  271. @router.get("/api/tags")
  272. @router.get("/api/tags/{url_idx}")
  273. async def get_ollama_tags(
  274. request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)
  275. ):
  276. models = []
  277. if url_idx is None:
  278. models = await get_all_models()
  279. else:
  280. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  281. key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
  282. r = None
  283. try:
  284. r = requests.request(
  285. method="GET",
  286. url=f"{url}/api/tags",
  287. headers={**({"Authorization": f"Bearer {key}"} if key else {})},
  288. )
  289. r.raise_for_status()
  290. models = r.json()
  291. except Exception as e:
  292. log.exception(e)
  293. detail = None
  294. if r is not None:
  295. try:
  296. res = r.json()
  297. if "error" in res:
  298. detail = f"Ollama: {res['error']}"
  299. except Exception:
  300. detail = f"Ollama: {e}"
  301. raise HTTPException(
  302. status_code=r.status_code if r else 500,
  303. detail=detail if detail else "Open WebUI: Server Connection Error",
  304. )
  305. if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
  306. models["models"] = get_filtered_models(models, user)
  307. return models
  308. @router.get("/api/version")
  309. @router.get("/api/version/{url_idx}")
  310. async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
  311. if request.app.state.config.ENABLE_OLLAMA_API:
  312. if url_idx is None:
  313. # returns lowest version
  314. request_tasks = [
  315. send_get_request(
  316. f"{url}/api/version",
  317. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
  318. "key", None
  319. ),
  320. )
  321. for url in request.app.state.config.OLLAMA_BASE_URLS
  322. ]
  323. responses = await asyncio.gather(*request_tasks)
  324. responses = list(filter(lambda x: x is not None, responses))
  325. if len(responses) > 0:
  326. lowest_version = min(
  327. responses,
  328. key=lambda x: tuple(
  329. map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
  330. ),
  331. )
  332. return {"version": lowest_version["version"]}
  333. else:
  334. raise HTTPException(
  335. status_code=500,
  336. detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
  337. )
  338. else:
  339. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  340. r = None
  341. try:
  342. r = requests.request(method="GET", url=f"{url}/api/version")
  343. r.raise_for_status()
  344. return r.json()
  345. except Exception as e:
  346. log.exception(e)
  347. detail = None
  348. if r is not None:
  349. try:
  350. res = r.json()
  351. if "error" in res:
  352. detail = f"Ollama: {res['error']}"
  353. except Exception:
  354. detail = f"Ollama: {e}"
  355. raise HTTPException(
  356. status_code=r.status_code if r else 500,
  357. detail=detail if detail else "Open WebUI: Server Connection Error",
  358. )
  359. else:
  360. return {"version": False}
  361. @router.get("/api/ps")
  362. async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
  363. """
  364. List models that are currently loaded into Ollama memory, and which node they are loaded on.
  365. """
  366. if request.app.state.config.ENABLE_OLLAMA_API:
  367. request_tasks = [
  368. send_get_request(
  369. f"{url}/api/ps",
  370. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
  371. "key", None
  372. ),
  373. )
  374. for url in request.app.state.config.OLLAMA_BASE_URLS
  375. ]
  376. responses = await asyncio.gather(*request_tasks)
  377. return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
  378. else:
  379. return {}
class ModelNameForm(BaseModel):
    """Request body carrying a single model name (used by pull, delete and show)."""

    name: str
  382. @router.post("/api/pull")
  383. @router.post("/api/pull/{url_idx}")
  384. async def pull_model(
  385. request: Request,
  386. form_data: ModelNameForm,
  387. url_idx: int = 0,
  388. user=Depends(get_admin_user),
  389. ):
  390. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  391. log.info(f"url: {url}")
  392. # Admin should be able to pull models from any source
  393. payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
  394. return await send_post_request(
  395. url=f"{url}/api/pull",
  396. payload=json.dumps(payload),
  397. key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
  398. )
class PushModelForm(BaseModel):
    """Request body for the push endpoint; unset optional fields are not forwarded."""

    name: str
    insecure: Optional[bool] = None
    stream: Optional[bool] = None
  403. @router.delete("/api/push")
  404. @router.delete("/api/push/{url_idx}")
  405. async def push_model(
  406. request: Request,
  407. form_data: PushModelForm,
  408. url_idx: Optional[int] = None,
  409. user=Depends(get_admin_user),
  410. ):
  411. if url_idx is None:
  412. await get_all_models(request)
  413. models = request.app.state.OLLAMA_MODELS
  414. if form_data.name in models:
  415. url_idx = models[form_data.name]["urls"][0]
  416. else:
  417. raise HTTPException(
  418. status_code=400,
  419. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
  420. )
  421. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  422. log.debug(f"url: {url}")
  423. return await send_post_request(
  424. url=f"{url}/api/push",
  425. payload=form_data.model_dump_json(exclude_none=True).encode(),
  426. key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
  427. )
class CreateModelForm(BaseModel):
    """Request body for the create endpoint; unset optional fields are not forwarded."""

    name: str
    modelfile: Optional[str] = None
    stream: Optional[bool] = None
    path: Optional[str] = None
  433. @router.post("/api/create")
  434. @router.post("/api/create/{url_idx}")
  435. async def create_model(
  436. request: Request,
  437. form_data: CreateModelForm,
  438. url_idx: int = 0,
  439. user=Depends(get_admin_user),
  440. ):
  441. log.debug(f"form_data: {form_data}")
  442. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  443. return await send_post_request(
  444. url=f"{url}/api/create",
  445. payload=form_data.model_dump_json(exclude_none=True).encode(),
  446. key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
  447. )
class CopyModelForm(BaseModel):
    """Request body for the copy endpoint."""

    # Name of the existing model; also used to pick the backend when no url_idx is given.
    source: str
    destination: str
  451. @router.post("/api/copy")
  452. @router.post("/api/copy/{url_idx}")
  453. async def copy_model(
  454. request: Request,
  455. form_data: CopyModelForm,
  456. url_idx: Optional[int] = None,
  457. user=Depends(get_admin_user),
  458. ):
  459. if url_idx is None:
  460. await get_all_models()
  461. models = request.app.state.OLLAMA_MODELS
  462. if form_data.source in models:
  463. url_idx = models[form_data.source]["urls"][0]
  464. else:
  465. raise HTTPException(
  466. status_code=400,
  467. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
  468. )
  469. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  470. key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
  471. try:
  472. r = requests.request(
  473. method="POST",
  474. url=f"{url}/api/copy",
  475. headers={
  476. "Content-Type": "application/json",
  477. **({"Authorization": f"Bearer {key}"} if key else {}),
  478. },
  479. data=form_data.model_dump_json(exclude_none=True).encode(),
  480. )
  481. r.raise_for_status()
  482. log.debug(f"r.text: {r.text}")
  483. return True
  484. except Exception as e:
  485. log.exception(e)
  486. detail = None
  487. if r is not None:
  488. try:
  489. res = r.json()
  490. if "error" in res:
  491. detail = f"Ollama: {res['error']}"
  492. except Exception:
  493. detail = f"Ollama: {e}"
  494. raise HTTPException(
  495. status_code=r.status_code if r else 500,
  496. detail=detail if detail else "Open WebUI: Server Connection Error",
  497. )
  498. @router.delete("/api/delete")
  499. @router.delete("/api/delete/{url_idx}")
  500. async def delete_model(
  501. request: Request,
  502. form_data: ModelNameForm,
  503. url_idx: Optional[int] = None,
  504. user=Depends(get_admin_user),
  505. ):
  506. if url_idx is None:
  507. await get_all_models()
  508. models = request.app.state.OLLAMA_MODELS
  509. if form_data.name in models:
  510. url_idx = models[form_data.name]["urls"][0]
  511. else:
  512. raise HTTPException(
  513. status_code=400,
  514. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
  515. )
  516. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  517. key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
  518. try:
  519. r = requests.request(
  520. method="DELETE",
  521. url=f"{url}/api/delete",
  522. data=form_data.model_dump_json(exclude_none=True).encode(),
  523. headers={
  524. "Content-Type": "application/json",
  525. **({"Authorization": f"Bearer {key}"} if key else {}),
  526. },
  527. )
  528. r.raise_for_status()
  529. log.debug(f"r.text: {r.text}")
  530. return True
  531. except Exception as e:
  532. log.exception(e)
  533. detail = None
  534. if r is not None:
  535. try:
  536. res = r.json()
  537. if "error" in res:
  538. detail = f"Ollama: {res['error']}"
  539. except Exception:
  540. detail = f"Ollama: {e}"
  541. raise HTTPException(
  542. status_code=r.status_code if r else 500,
  543. detail=detail if detail else "Open WebUI: Server Connection Error",
  544. )
  545. @router.post("/api/show")
  546. async def show_model_info(
  547. request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
  548. ):
  549. await get_all_models()
  550. models = request.app.state.OLLAMA_MODELS
  551. if form_data.name not in models:
  552. raise HTTPException(
  553. status_code=400,
  554. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
  555. )
  556. url_idx = random.choice(models[form_data.name]["urls"])
  557. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  558. key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
  559. try:
  560. r = requests.request(
  561. method="POST",
  562. url=f"{url}/api/show",
  563. headers={
  564. "Content-Type": "application/json",
  565. **({"Authorization": f"Bearer {key}"} if key else {}),
  566. },
  567. data=form_data.model_dump_json(exclude_none=True).encode(),
  568. )
  569. r.raise_for_status()
  570. return r.json()
  571. except Exception as e:
  572. log.exception(e)
  573. detail = None
  574. if r is not None:
  575. try:
  576. res = r.json()
  577. if "error" in res:
  578. detail = f"Ollama: {res['error']}"
  579. except Exception:
  580. detail = f"Ollama: {e}"
  581. raise HTTPException(
  582. status_code=r.status_code if r else 500,
  583. detail=detail if detail else "Open WebUI: Server Connection Error",
  584. )
class GenerateEmbedForm(BaseModel):
    """Request body for ``/api/embed``; unset optional fields are not forwarded."""

    model: str
    # Either a single string or a list of strings.
    input: list[str] | str
    truncate: Optional[bool] = None
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None
  591. @router.post("/api/embed")
  592. @router.post("/api/embed/{url_idx}")
  593. async def embed(
  594. request: Request,
  595. form_data: GenerateEmbedForm,
  596. url_idx: Optional[int] = None,
  597. user=Depends(get_verified_user),
  598. ):
  599. log.info(f"generate_ollama_batch_embeddings {form_data}")
  600. if url_idx is None:
  601. await get_all_models()
  602. models = request.app.state.OLLAMA_MODELS
  603. model = form_data.model
  604. if ":" not in model:
  605. model = f"{model}:latest"
  606. if model in models:
  607. url_idx = random.choice(models[model]["urls"])
  608. else:
  609. raise HTTPException(
  610. status_code=400,
  611. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
  612. )
  613. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  614. key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
  615. try:
  616. r = requests.request(
  617. method="POST",
  618. url=f"{url}/api/embed",
  619. headers={
  620. "Content-Type": "application/json",
  621. **({"Authorization": f"Bearer {key}"} if key else {}),
  622. },
  623. data=form_data.model_dump_json(exclude_none=True).encode(),
  624. )
  625. r.raise_for_status()
  626. data = r.json()
  627. return data
  628. except Exception as e:
  629. log.exception(e)
  630. detail = None
  631. if r is not None:
  632. try:
  633. res = r.json()
  634. if "error" in res:
  635. detail = f"Ollama: {res['error']}"
  636. except Exception:
  637. detail = f"Ollama: {e}"
  638. raise HTTPException(
  639. status_code=r.status_code if r else 500,
  640. detail=detail if detail else "Open WebUI: Server Connection Error",
  641. )
class GenerateEmbeddingsForm(BaseModel):
    """Request body for ``/api/embeddings``; unset optional fields are not forwarded."""

    model: str
    prompt: str
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None
  647. @router.post("/api/embeddings")
  648. @router.post("/api/embeddings/{url_idx}")
  649. async def embeddings(
  650. request: Request,
  651. form_data: GenerateEmbeddingsForm,
  652. url_idx: Optional[int] = None,
  653. user=Depends(get_verified_user),
  654. ):
  655. log.info(f"generate_ollama_embeddings {form_data}")
  656. if url_idx is None:
  657. await get_all_models()
  658. models = request.app.state.OLLAMA_MODELS
  659. model = form_data.model
  660. if ":" not in model:
  661. model = f"{model}:latest"
  662. if model in models:
  663. url_idx = random.choice(models[model]["urls"])
  664. else:
  665. raise HTTPException(
  666. status_code=400,
  667. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
  668. )
  669. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  670. key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
  671. try:
  672. r = requests.request(
  673. method="POST",
  674. url=f"{url}/api/embeddings",
  675. headers={
  676. "Content-Type": "application/json",
  677. **({"Authorization": f"Bearer {key}"} if key else {}),
  678. },
  679. data=form_data.model_dump_json(exclude_none=True).encode(),
  680. )
  681. r.raise_for_status()
  682. data = r.json()
  683. return data
  684. except Exception as e:
  685. log.exception(e)
  686. detail = None
  687. if r is not None:
  688. try:
  689. res = r.json()
  690. if "error" in res:
  691. detail = f"Ollama: {res['error']}"
  692. except Exception:
  693. detail = f"Ollama: {e}"
  694. raise HTTPException(
  695. status_code=r.status_code if r else 500,
  696. detail=detail if detail else "Open WebUI: Server Connection Error",
  697. )
class GenerateCompletionForm(BaseModel):
    """Request body for ``/api/generate``; unset optional fields are not forwarded."""

    model: str
    prompt: str
    suffix: Optional[str] = None
    images: Optional[list[str]] = None
    format: Optional[str] = None
    options: Optional[dict] = None
    system: Optional[str] = None
    template: Optional[str] = None
    context: Optional[list[int]] = None
    # Streaming is on unless the client explicitly sets it otherwise.
    stream: Optional[bool] = True
    raw: Optional[bool] = None
    keep_alive: Optional[Union[int, str]] = None
  711. @router.post("/api/generate")
  712. @router.post("/api/generate/{url_idx}")
  713. async def generate_completion(
  714. request: Request,
  715. form_data: GenerateCompletionForm,
  716. url_idx: Optional[int] = None,
  717. user=Depends(get_verified_user),
  718. ):
  719. if url_idx is None:
  720. model_list = await get_all_models()
  721. models = {model["model"]: model for model in model_list["models"]}
  722. model = form_data.model
  723. if ":" not in model:
  724. model = f"{model}:latest"
  725. if model in models:
  726. url_idx = random.choice(models[model]["urls"])
  727. else:
  728. raise HTTPException(
  729. status_code=400,
  730. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
  731. )
  732. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  733. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
  734. prefix_id = api_config.get("prefix_id", None)
  735. if prefix_id:
  736. form_data.model = form_data.model.replace(f"{prefix_id}.", "")
  737. return await send_post_request(
  738. url=f"{url}/api/generate",
  739. payload=form_data.model_dump_json(exclude_none=True).encode(),
  740. key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
  741. )
class ChatMessage(BaseModel):
    """One message of a chat conversation sent to ``/api/chat``."""

    role: str
    content: str
    images: Optional[list[str]] = None
class GenerateChatCompletionForm(BaseModel):
    """Request body for ``/api/chat``; unset optional fields are not forwarded."""

    model: str
    messages: list[ChatMessage]
    format: Optional[str] = None
    options: Optional[dict] = None
    template: Optional[str] = None
    # Streaming is on unless the client explicitly sets it otherwise.
    stream: Optional[bool] = True
    keep_alive: Optional[Union[int, str]] = None
  754. async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None):
  755. if url_idx is None:
  756. models = request.app.state.OLLAMA_MODELS
  757. if model not in models:
  758. raise HTTPException(
  759. status_code=400,
  760. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
  761. )
  762. url_idx = random.choice(models[model].get("urls", []))
  763. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  764. return url
  765. @router.post("/api/chat")
  766. @router.post("/api/chat/{url_idx}")
  767. async def generate_chat_completion(
  768. request: Request,
  769. form_data: GenerateChatCompletionForm,
  770. url_idx: Optional[int] = None,
  771. user=Depends(get_verified_user),
  772. bypass_filter: Optional[bool] = False,
  773. ):
  774. if BYPASS_MODEL_ACCESS_CONTROL:
  775. bypass_filter = True
  776. payload = {**form_data.model_dump(exclude_none=True)}
  777. if "metadata" in payload:
  778. del payload["metadata"]
  779. model_id = payload["model"]
  780. model_info = Models.get_model_by_id(model_id)
  781. if model_info:
  782. if model_info.base_model_id:
  783. payload["model"] = model_info.base_model_id
  784. params = model_info.params.model_dump()
  785. if params:
  786. if payload.get("options") is None:
  787. payload["options"] = {}
  788. payload["options"] = apply_model_params_to_body_ollama(
  789. params, payload["options"]
  790. )
  791. payload = apply_model_system_prompt_to_body(params, payload, user)
  792. # Check if user has access to the model
  793. if not bypass_filter and user.role == "user":
  794. if not (
  795. user.id == model_info.user_id
  796. or has_access(
  797. user.id, type="read", access_control=model_info.access_control
  798. )
  799. ):
  800. raise HTTPException(
  801. status_code=403,
  802. detail="Model not found",
  803. )
  804. elif not bypass_filter:
  805. if user.role != "admin":
  806. raise HTTPException(
  807. status_code=403,
  808. detail="Model not found",
  809. )
  810. if ":" not in payload["model"]:
  811. payload["model"] = f"{payload['model']}:latest"
  812. url = await get_ollama_url(request, payload["model"], url_idx)
  813. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
  814. prefix_id = api_config.get("prefix_id", None)
  815. if prefix_id:
  816. payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
  817. return await send_post_request(
  818. url=f"{url}/api/chat",
  819. payload=json.dumps(payload),
  820. stream=form_data.stream,
  821. key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
  822. content_type="application/x-ndjson",
  823. )
  824. # TODO: we should update this part once Ollama supports other types
  825. class OpenAIChatMessageContent(BaseModel):
  826. type: str
  827. model_config = ConfigDict(extra="allow")
  828. class OpenAIChatMessage(BaseModel):
  829. role: str
  830. content: Union[str, list[OpenAIChatMessageContent]]
  831. model_config = ConfigDict(extra="allow")
  832. class OpenAIChatCompletionForm(BaseModel):
  833. model: str
  834. messages: list[OpenAIChatMessage]
  835. model_config = ConfigDict(extra="allow")
  836. class OpenAICompletionForm(BaseModel):
  837. model: str
  838. prompt: str
  839. model_config = ConfigDict(extra="allow")
  840. @router.post("/v1/completions")
  841. @router.post("/v1/completions/{url_idx}")
  842. async def generate_openai_completion(
  843. request: Request,
  844. form_data: dict,
  845. url_idx: Optional[int] = None,
  846. user=Depends(get_verified_user),
  847. ):
  848. try:
  849. form_data = OpenAICompletionForm(**form_data)
  850. except Exception as e:
  851. log.exception(e)
  852. raise HTTPException(
  853. status_code=400,
  854. detail=str(e),
  855. )
  856. payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])}
  857. if "metadata" in payload:
  858. del payload["metadata"]
  859. model_id = form_data.model
  860. if ":" not in model_id:
  861. model_id = f"{model_id}:latest"
  862. model_info = Models.get_model_by_id(model_id)
  863. if model_info:
  864. if model_info.base_model_id:
  865. payload["model"] = model_info.base_model_id
  866. params = model_info.params.model_dump()
  867. if params:
  868. payload = apply_model_params_to_body_openai(params, payload)
  869. # Check if user has access to the model
  870. if user.role == "user":
  871. if not (
  872. user.id == model_info.user_id
  873. or has_access(
  874. user.id, type="read", access_control=model_info.access_control
  875. )
  876. ):
  877. raise HTTPException(
  878. status_code=403,
  879. detail="Model not found",
  880. )
  881. else:
  882. if user.role != "admin":
  883. raise HTTPException(
  884. status_code=403,
  885. detail="Model not found",
  886. )
  887. if ":" not in payload["model"]:
  888. payload["model"] = f"{payload['model']}:latest"
  889. url = await get_ollama_url(request, payload["model"], url_idx)
  890. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
  891. prefix_id = api_config.get("prefix_id", None)
  892. if prefix_id:
  893. payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
  894. return await send_post_request(
  895. url=f"{url}/v1/completions",
  896. payload=json.dumps(payload),
  897. stream=payload.get("stream", False),
  898. key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
  899. )
  900. @router.post("/v1/chat/completions")
  901. @router.post("/v1/chat/completions/{url_idx}")
  902. async def generate_openai_chat_completion(
  903. request: Request,
  904. form_data: dict,
  905. url_idx: Optional[int] = None,
  906. user=Depends(get_verified_user),
  907. ):
  908. try:
  909. completion_form = OpenAIChatCompletionForm(**form_data)
  910. except Exception as e:
  911. log.exception(e)
  912. raise HTTPException(
  913. status_code=400,
  914. detail=str(e),
  915. )
  916. payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
  917. if "metadata" in payload:
  918. del payload["metadata"]
  919. model_id = completion_form.model
  920. if ":" not in model_id:
  921. model_id = f"{model_id}:latest"
  922. model_info = Models.get_model_by_id(model_id)
  923. if model_info:
  924. if model_info.base_model_id:
  925. payload["model"] = model_info.base_model_id
  926. params = model_info.params.model_dump()
  927. if params:
  928. payload = apply_model_params_to_body_openai(params, payload)
  929. payload = apply_model_system_prompt_to_body(params, payload, user)
  930. # Check if user has access to the model
  931. if user.role == "user":
  932. if not (
  933. user.id == model_info.user_id
  934. or has_access(
  935. user.id, type="read", access_control=model_info.access_control
  936. )
  937. ):
  938. raise HTTPException(
  939. status_code=403,
  940. detail="Model not found",
  941. )
  942. else:
  943. if user.role != "admin":
  944. raise HTTPException(
  945. status_code=403,
  946. detail="Model not found",
  947. )
  948. if ":" not in payload["model"]:
  949. payload["model"] = f"{payload['model']}:latest"
  950. url = await get_ollama_url(request, payload["model"], url_idx)
  951. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
  952. prefix_id = api_config.get("prefix_id", None)
  953. if prefix_id:
  954. payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
  955. return await send_post_request(
  956. url=f"{url}/v1/chat/completions",
  957. payload=json.dumps(payload),
  958. stream=payload.get("stream", False),
  959. key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
  960. )
  961. @router.get("/v1/models")
  962. @router.get("/v1/models/{url_idx}")
  963. async def get_openai_models(
  964. request: Request,
  965. url_idx: Optional[int] = None,
  966. user=Depends(get_verified_user),
  967. ):
  968. models = []
  969. if url_idx is None:
  970. model_list = await get_all_models()
  971. models = [
  972. {
  973. "id": model["model"],
  974. "object": "model",
  975. "created": int(time.time()),
  976. "owned_by": "openai",
  977. }
  978. for model in model_list["models"]
  979. ]
  980. else:
  981. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  982. try:
  983. r = requests.request(method="GET", url=f"{url}/api/tags")
  984. r.raise_for_status()
  985. model_list = r.json()
  986. models = [
  987. {
  988. "id": model["model"],
  989. "object": "model",
  990. "created": int(time.time()),
  991. "owned_by": "openai",
  992. }
  993. for model in models["models"]
  994. ]
  995. except Exception as e:
  996. log.exception(e)
  997. error_detail = "Open WebUI: Server Connection Error"
  998. if r is not None:
  999. try:
  1000. res = r.json()
  1001. if "error" in res:
  1002. error_detail = f"Ollama: {res['error']}"
  1003. except Exception:
  1004. error_detail = f"Ollama: {e}"
  1005. raise HTTPException(
  1006. status_code=r.status_code if r else 500,
  1007. detail=error_detail,
  1008. )
  1009. if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
  1010. # Filter models based on user access control
  1011. filtered_models = []
  1012. for model in models:
  1013. model_info = Models.get_model_by_id(model["id"])
  1014. if model_info:
  1015. if user.id == model_info.user_id or has_access(
  1016. user.id, type="read", access_control=model_info.access_control
  1017. ):
  1018. filtered_models.append(model)
  1019. models = filtered_models
  1020. return {
  1021. "data": models,
  1022. "object": "list",
  1023. }
  1024. class UrlForm(BaseModel):
  1025. url: str
  1026. class UploadBlobForm(BaseModel):
  1027. filename: str
  1028. def parse_huggingface_url(hf_url):
  1029. try:
  1030. # Parse the URL
  1031. parsed_url = urlparse(hf_url)
  1032. # Get the path and split it into components
  1033. path_components = parsed_url.path.split("/")
  1034. # Extract the desired output
  1035. model_file = path_components[-1]
  1036. return model_file
  1037. except ValueError:
  1038. return None
  1039. async def download_file_stream(
  1040. ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
  1041. ):
  1042. done = False
  1043. if os.path.exists(file_path):
  1044. current_size = os.path.getsize(file_path)
  1045. else:
  1046. current_size = 0
  1047. headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
  1048. timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
  1049. async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
  1050. async with session.get(file_url, headers=headers) as response:
  1051. total_size = int(response.headers.get("content-length", 0)) + current_size
  1052. with open(file_path, "ab+") as file:
  1053. async for data in response.content.iter_chunked(chunk_size):
  1054. current_size += len(data)
  1055. file.write(data)
  1056. done = current_size == total_size
  1057. progress = round((current_size / total_size) * 100, 2)
  1058. yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
  1059. if done:
  1060. file.seek(0)
  1061. hashed = calculate_sha256(file)
  1062. file.seek(0)
  1063. url = f"{ollama_url}/api/blobs/sha256:{hashed}"
  1064. response = requests.post(url, data=file)
  1065. if response.ok:
  1066. res = {
  1067. "done": done,
  1068. "blob": f"sha256:{hashed}",
  1069. "name": file_name,
  1070. }
  1071. os.remove(file_path)
  1072. yield f"data: {json.dumps(res)}\n\n"
  1073. else:
  1074. raise "Ollama: Could not create blob, Please try again."
  1075. # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
  1076. @router.post("/models/download")
  1077. @router.post("/models/download/{url_idx}")
  1078. async def download_model(
  1079. request: Request,
  1080. form_data: UrlForm,
  1081. url_idx: Optional[int] = None,
  1082. user=Depends(get_admin_user),
  1083. ):
  1084. allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
  1085. if not any(form_data.url.startswith(host) for host in allowed_hosts):
  1086. raise HTTPException(
  1087. status_code=400,
  1088. detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
  1089. )
  1090. if url_idx is None:
  1091. url_idx = 0
  1092. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  1093. file_name = parse_huggingface_url(form_data.url)
  1094. if file_name:
  1095. file_path = f"{UPLOAD_DIR}/{file_name}"
  1096. return StreamingResponse(
  1097. download_file_stream(url, form_data.url, file_path, file_name),
  1098. )
  1099. else:
  1100. return None
  1101. @router.post("/models/upload")
  1102. @router.post("/models/upload/{url_idx}")
  1103. def upload_model(
  1104. request: Request,
  1105. file: UploadFile = File(...),
  1106. url_idx: Optional[int] = None,
  1107. user=Depends(get_admin_user),
  1108. ):
  1109. if url_idx is None:
  1110. url_idx = 0
  1111. ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  1112. file_path = f"{UPLOAD_DIR}/{file.filename}"
  1113. # Save file in chunks
  1114. with open(file_path, "wb+") as f:
  1115. for chunk in file.file:
  1116. f.write(chunk)
  1117. def file_process_stream():
  1118. nonlocal ollama_url
  1119. total_size = os.path.getsize(file_path)
  1120. chunk_size = 1024 * 1024
  1121. try:
  1122. with open(file_path, "rb") as f:
  1123. total = 0
  1124. done = False
  1125. while not done:
  1126. chunk = f.read(chunk_size)
  1127. if not chunk:
  1128. done = True
  1129. continue
  1130. total += len(chunk)
  1131. progress = round((total / total_size) * 100, 2)
  1132. res = {
  1133. "progress": progress,
  1134. "total": total_size,
  1135. "completed": total,
  1136. }
  1137. yield f"data: {json.dumps(res)}\n\n"
  1138. if done:
  1139. f.seek(0)
  1140. hashed = calculate_sha256(f)
  1141. f.seek(0)
  1142. url = f"{ollama_url}/api/blobs/sha256:{hashed}"
  1143. response = requests.post(url, data=f)
  1144. if response.ok:
  1145. res = {
  1146. "done": done,
  1147. "blob": f"sha256:{hashed}",
  1148. "name": file.filename,
  1149. }
  1150. os.remove(file_path)
  1151. yield f"data: {json.dumps(res)}\n\n"
  1152. else:
  1153. raise Exception(
  1154. "Ollama: Could not create blob, Please try again."
  1155. )
  1156. except Exception as e:
  1157. res = {"error": str(e)}
  1158. yield f"data: {json.dumps(res)}\n\n"
  1159. return StreamingResponse(file_process_stream(), media_type="text/event-stream")