# ollama.py
import asyncio
import json
import logging
import os
import random
import re
import time
from typing import Optional, Union
from urllib.parse import urlparse

import aiohttp
import requests
from aiocache import cached
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict
from starlette.background import BackgroundTask

from open_webui.models.models import Models
from open_webui.config import (
    UPLOAD_DIR,
)
from open_webui.env import (
    AIOHTTP_CLIENT_TIMEOUT,
    AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
    BYPASS_MODEL_ACCESS_CONTROL,
    ENV,
    SRC_LOG_LEVELS,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.misc import (
    calculate_sha256,
)
from open_webui.utils.payload import (
    apply_model_params_to_body_ollama,
    apply_model_params_to_body_openai,
    apply_model_system_prompt_to_body,
)
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
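

# NOTE: minimal reconstruction of the FastAPI app setup that the @app routes
# below require; it mirrors the other Open WebUI routers. The names imported
# here (AppConfig, CORS_ALLOW_ORIGIN, ENABLE_OLLAMA_API, OLLAMA_BASE_URLS,
# OLLAMA_API_CONFIGS) are assumed to exist in open_webui.config.
from open_webui.config import (
    AppConfig,
    CORS_ALLOW_ORIGIN,
    ENABLE_OLLAMA_API,
    OLLAMA_API_CONFIGS,
    OLLAMA_BASE_URLS,
)

app = FastAPI(
    docs_url="/docs" if ENV == "dev" else None,
    openapi_url="/openapi.json" if ENV == "dev" else None,
    redoc_url=None,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGIN,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.state.config = AppConfig()
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS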


# TODO: Implement a more intelligent load-balancing mechanism for distributing
# requests among multiple backend instances. The current implementation picks a
# backend via simple random selection (random.choice). Consider algorithms such
# as weighted round-robin, least connections, or least response time for better
# resource utilization and performance.
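

# Illustrative sketch for the TODO above (not wired into any route): a weighted
# random selector over the configured base URLs. The per-URL "weight" field is
# hypothetical (nothing in OLLAMA_API_CONFIGS defines it today); plain
# random.choice remains the selection strategy actually used in this module.
def weighted_url_choice(urls: list[str], api_configs: dict) -> str:
    # Default every backend to weight 1 so unconfigured URLs stay eligible
    weights = [api_configs.get(url, {}).get("weight", 1) for url in urls]
    return random.choices(urls, weights=weights, k=1)[0]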


@app.head("/")
@app.get("/")
async def get_status():
    return {"status": True}


class ConnectionVerificationForm(BaseModel):
    url: str
    key: Optional[str] = None


@app.post("/verify")
async def verify_connection(
    form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
):
    url = form_data.url
    key = form_data.key

    headers = {}
    if key:
        headers["Authorization"] = f"Bearer {key}"

    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        try:
            async with session.get(f"{url}/api/version", headers=headers) as r:
                if r.status != 200:
                    # Extract response error details if available
                    error_detail = f"HTTP Error: {r.status}"
                    res = await r.json()
                    if "error" in res:
                        error_detail = f"External Error: {res['error']}"
                    raise Exception(error_detail)

                response_data = await r.json()
                return response_data
        except aiohttp.ClientError as e:
            # ClientError covers all aiohttp request issues
            log.exception(f"Client error: {str(e)}")
            # Handle aiohttp-specific connection issues, timeouts, etc.
            raise HTTPException(
                status_code=500, detail="Open WebUI: Server Connection Error"
            )
        except Exception as e:
            log.exception(f"Unexpected error: {e}")
            # Generic handler in case JSON parsing or another step fails
            error_detail = f"Unexpected error: {str(e)}"
            raise HTTPException(status_code=500, detail=error_detail)


@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
    return {
        "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
        "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
        "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
    }


class OllamaConfigForm(BaseModel):
    ENABLE_OLLAMA_API: Optional[bool] = None
    OLLAMA_BASE_URLS: list[str]
    OLLAMA_API_CONFIGS: dict


@app.post("/config/update")
async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)):
    app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
    app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
    app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS

    # Remove any extra configs for URLs that are no longer in OLLAMA_BASE_URLS
    base_urls = app.state.config.OLLAMA_BASE_URLS
    for url in list(app.state.config.OLLAMA_API_CONFIGS.keys()):
        if url not in base_urls:
            app.state.config.OLLAMA_API_CONFIGS.pop(url, None)

    return {
        "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
        "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
        "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
    }
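

# Shared helpers: aiohttp_get performs a one-off GET against a backend and
# returns parsed JSON (or None on failure); cleanup_response tears down the
# response/session pair left open by streaming proxies.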
async def aiohttp_get(url, key=None):
    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
    try:
        headers = {"Authorization": f"Bearer {key}"} if key else {}
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.get(url, headers=headers) as response:
                return await response.json()
    except Exception as e:
        # Handle connection errors here
        log.error(f"Connection error: {e}")
        return None


async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    if response:
        response.close()
    if session:
        await session.close()
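

# Proxy a POST to an Ollama backend. For streaming requests the backend's body
# is piped through as-is and the aiohttp session is torn down by a Starlette
# BackgroundTask once the client disconnects; for non-streaming requests the
# JSON body is returned directly.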
async def post_streaming_url(
    url: str, payload: Union[str, bytes], stream: bool = True, content_type=None
):
    r = None
    try:
        session = aiohttp.ClientSession(
            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
        )

        parsed_url = urlparse(url)
        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

        api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
        key = api_config.get("key", None)

        headers = {"Content-Type": "application/json"}
        if key:
            headers["Authorization"] = f"Bearer {key}"

        r = await session.post(
            url,
            data=payload,
            headers=headers,
        )
        r.raise_for_status()

        if stream:
            response_headers = dict(r.headers)
            if content_type:
                response_headers["Content-Type"] = content_type
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=response_headers,
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        else:
            res = await r.json()
            await cleanup_response(r, session)
            return res

    except Exception as e:
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = await r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status if r else 500,
            detail=error_detail,
        )


def merge_models_lists(model_lists):
    merged_models = {}

    for idx, model_list in enumerate(model_lists):
        if model_list is not None:
            for model in model_list:
                id = model["model"]
                if id not in merged_models:
                    model["urls"] = [idx]
                    merged_models[id] = model
                else:
                    merged_models[id]["urls"].append(idx)

    return list(merged_models.values())
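

# Aggregate /api/tags across every configured backend. Results are cached
# briefly (ttl=3 seconds) so bursts of UI requests don't fan out to each
# Ollama instance; per-URL configs can disable a backend, restrict it to a
# model_ids allowlist, or namespace its models with a prefix_id.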
@cached(ttl=3)
async def get_all_models():
    log.info("get_all_models()")

    if app.state.config.ENABLE_OLLAMA_API:
        tasks = []
        for idx, url in enumerate(app.state.config.OLLAMA_BASE_URLS):
            if url not in app.state.config.OLLAMA_API_CONFIGS:
                tasks.append(aiohttp_get(f"{url}/api/tags"))
            else:
                api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
                enable = api_config.get("enable", True)
                key = api_config.get("key", None)

                if enable:
                    tasks.append(aiohttp_get(f"{url}/api/tags", key))
                else:
                    # Keep the task list aligned with OLLAMA_BASE_URLS by
                    # inserting a placeholder task that resolves to None
                    tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))

        responses = await asyncio.gather(*tasks)

        for idx, response in enumerate(responses):
            if response:
                url = app.state.config.OLLAMA_BASE_URLS[idx]
                api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})

                prefix_id = api_config.get("prefix_id", None)
                model_ids = api_config.get("model_ids", [])

                if len(model_ids) != 0 and "models" in response:
                    response["models"] = list(
                        filter(
                            lambda model: model["model"] in model_ids,
                            response["models"],
                        )
                    )

                if prefix_id:
                    for model in response.get("models", []):
                        model["model"] = f"{prefix_id}.{model['model']}"

        models = {
            "models": merge_models_lists(
                map(
                    lambda response: response.get("models", []) if response else None,
                    responses,
                )
            )
        }
    else:
        models = {"models": []}

    return models


@app.get("/api/tags")
@app.get("/api/tags/{url_idx}")
async def get_ollama_tags(
    url_idx: Optional[int] = None, user=Depends(get_verified_user)
):
    models = []
    if url_idx is None:
        models = await get_all_models()
    else:
        url = app.state.config.OLLAMA_BASE_URLS[url_idx]

        parsed_url = urlparse(url)
        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

        api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
        key = api_config.get("key", None)

        headers = {}
        if key:
            headers["Authorization"] = f"Bearer {key}"

        r = None
        try:
            r = requests.request(method="GET", url=f"{url}/api/tags", headers=headers)
            r.raise_for_status()

            models = r.json()
        except Exception as e:
            log.exception(e)
            error_detail = "Open WebUI: Server Connection Error"

            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"Ollama: {res['error']}"
                except Exception:
                    error_detail = f"Ollama: {e}"

            raise HTTPException(
                status_code=r.status_code if r else 500,
                detail=error_detail,
            )

    if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
        # Filter models based on user access control
        filtered_models = []
        for model in models.get("models", []):
            model_info = Models.get_model_by_id(model["model"])
            if model_info:
                if user.id == model_info.user_id or has_access(
                    user.id, type="read", access_control=model_info.access_control
                ):
                    filtered_models.append(model)
        models["models"] = filtered_models

    return models


@app.get("/api/version")
@app.get("/api/version/{url_idx}")
async def get_ollama_versions(url_idx: Optional[int] = None):
    if app.state.config.ENABLE_OLLAMA_API:
        if url_idx is None:
            # Returns the lowest version across all backends
            tasks = [
                aiohttp_get(
                    f"{url}/api/version",
                    app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None),
                )
                for url in app.state.config.OLLAMA_BASE_URLS
            ]
            responses = await asyncio.gather(*tasks)
            responses = list(filter(lambda x: x is not None, responses))

            if len(responses) > 0:
                # Strip a leading "v" and any "-suffix" before comparing
                # versions numerically, e.g. "v0.1.32-rc1" -> (0, 1, 32)
                lowest_version = min(
                    responses,
                    key=lambda x: tuple(
                        map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
                    ),
                )

                return {"version": lowest_version["version"]}
            else:
                raise HTTPException(
                    status_code=500,
                    detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
                )
        else:
            url = app.state.config.OLLAMA_BASE_URLS[url_idx]

            r = None
            try:
                r = requests.request(method="GET", url=f"{url}/api/version")
                r.raise_for_status()

                return r.json()
            except Exception as e:
                log.exception(e)
                error_detail = "Open WebUI: Server Connection Error"

                if r is not None:
                    try:
                        res = r.json()
                        if "error" in res:
                            error_detail = f"Ollama: {res['error']}"
                    except Exception:
                        error_detail = f"Ollama: {e}"

                raise HTTPException(
                    status_code=r.status_code if r else 500,
                    detail=error_detail,
                )
    else:
        return {"version": False}


@app.get("/api/ps")
async def get_ollama_loaded_models(user=Depends(get_verified_user)):
    """
    List models that are currently loaded into Ollama memory, and which node
    they are loaded on.
    """
    if app.state.config.ENABLE_OLLAMA_API:
        tasks = [
            aiohttp_get(
                f"{url}/api/ps",
                app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None),
            )
            for url in app.state.config.OLLAMA_BASE_URLS
        ]
        responses = await asyncio.gather(*tasks)

        return dict(zip(app.state.config.OLLAMA_BASE_URLS, responses))
    else:
        return {}
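

# Model management endpoints: pull, push, create, copy, delete, and show.
# When no url_idx is given, the target backend is resolved from the merged
# model list (see get_all_models); if several backends serve the same model,
# one is picked at random.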
class ModelNameForm(BaseModel):
    name: str


@app.post("/api/pull")
@app.post("/api/pull/{url_idx}")
async def pull_model(
    form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
):
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    # Admin should be able to pull models from any source
    payload = {**form_data.model_dump(exclude_none=True), "insecure": True}

    return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))


class PushModelForm(BaseModel):
    name: str
    insecure: Optional[bool] = None
    stream: Optional[bool] = None


@app.delete("/api/push")
@app.delete("/api/push/{url_idx}")
async def push_model(
    form_data: PushModelForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        model_list = await get_all_models()
        models = {model["model"]: model for model in model_list["models"]}

        if form_data.name in models:
            url_idx = models[form_data.name]["urls"][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
            )

    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.debug(f"url: {url}")

    return await post_streaming_url(
        f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
    )


class CreateModelForm(BaseModel):
    name: str
    modelfile: Optional[str] = None
    stream: Optional[bool] = None
    path: Optional[str] = None


@app.post("/api/create")
@app.post("/api/create/{url_idx}")
async def create_model(
    form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
):
    log.debug(f"form_data: {form_data}")
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    return await post_streaming_url(
        f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
    )


class CopyModelForm(BaseModel):
    source: str
    destination: str


@app.post("/api/copy")
@app.post("/api/copy/{url_idx}")
async def copy_model(
    form_data: CopyModelForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        model_list = await get_all_models()
        models = {model["model"]: model for model in model_list["models"]}

        if form_data.source in models:
            url_idx = models[form_data.source]["urls"][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
            )

    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    parsed_url = urlparse(url)
    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

    api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
    key = api_config.get("key", None)

    headers = {"Content-Type": "application/json"}
    if key:
        headers["Authorization"] = f"Bearer {key}"

    r = None
    try:
        # Issue the request inside the try block so connection failures are
        # reported as HTTP errors rather than unhandled exceptions
        r = requests.request(
            method="POST",
            url=f"{url}/api/copy",
            headers=headers,
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        log.debug(f"r.text: {r.text}")
        return True
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"

        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )


@app.delete("/api/delete")
@app.delete("/api/delete/{url_idx}")
async def delete_model(
    form_data: ModelNameForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        model_list = await get_all_models()
        models = {model["model"]: model for model in model_list["models"]}

        if form_data.name in models:
            url_idx = models[form_data.name]["urls"][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
            )

    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    parsed_url = urlparse(url)
    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

    api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
    key = api_config.get("key", None)

    headers = {"Content-Type": "application/json"}
    if key:
        headers["Authorization"] = f"Bearer {key}"

    r = None
    try:
        r = requests.request(
            method="DELETE",
            url=f"{url}/api/delete",
            data=form_data.model_dump_json(exclude_none=True).encode(),
            headers=headers,
        )
        r.raise_for_status()

        log.debug(f"r.text: {r.text}")
        return True
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"

        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )


@app.post("/api/show")
async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
    model_list = await get_all_models()
    models = {model["model"]: model for model in model_list["models"]}

    if form_data.name not in models:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
        )

    url_idx = random.choice(models[form_data.name]["urls"])
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    parsed_url = urlparse(url)
    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

    api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
    key = api_config.get("key", None)

    headers = {"Content-Type": "application/json"}
    if key:
        headers["Authorization"] = f"Bearer {key}"

    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/show",
            headers=headers,
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        return r.json()
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"

        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )
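

# Two embedding APIs are exposed: /api/embed (the newer Ollama endpoint, with
# batched input) and /api/embeddings (the legacy single-prompt endpoint).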
class GenerateEmbeddingsForm(BaseModel):
    model: str
    prompt: str
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None


class GenerateEmbedForm(BaseModel):
    model: str
    input: list[str] | str
    truncate: Optional[bool] = None
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None


@app.post("/api/embed")
@app.post("/api/embed/{url_idx}")
async def generate_batch_embeddings(
    form_data: GenerateEmbedForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    return await generate_ollama_batch_embeddings(form_data, url_idx)


@app.post("/api/embeddings")
@app.post("/api/embeddings/{url_idx}")
async def generate_embeddings(
    form_data: GenerateEmbeddingsForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    return await generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)


async def generate_ollama_embeddings(
    form_data: GenerateEmbeddingsForm,
    url_idx: Optional[int] = None,
):
    log.info(f"generate_ollama_embeddings {form_data}")

    if url_idx is None:
        model_list = await get_all_models()
        models = {model["model"]: model for model in model_list["models"]}

        model = form_data.model
        if ":" not in model:
            model = f"{model}:latest"

        if model in models:
            url_idx = random.choice(models[model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    parsed_url = urlparse(url)
    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

    api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
    key = api_config.get("key", None)

    headers = {"Content-Type": "application/json"}
    if key:
        headers["Authorization"] = f"Bearer {key}"

    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/embeddings",
            headers=headers,
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        data = r.json()
        log.info(f"generate_ollama_embeddings {data}")

        if "embedding" in data:
            return data
        else:
            raise Exception("Ollama: 'embedding' missing from response")
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"

        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )


async def generate_ollama_batch_embeddings(
    form_data: GenerateEmbedForm,
    url_idx: Optional[int] = None,
):
    log.info(f"generate_ollama_batch_embeddings {form_data}")

    if url_idx is None:
        model_list = await get_all_models()
        models = {model["model"]: model for model in model_list["models"]}

        model = form_data.model
        if ":" not in model:
            model = f"{model}:latest"

        if model in models:
            url_idx = random.choice(models[model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    parsed_url = urlparse(url)
    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

    api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
    key = api_config.get("key", None)

    headers = {"Content-Type": "application/json"}
    if key:
        headers["Authorization"] = f"Bearer {key}"

    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/embed",
            headers=headers,
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        data = r.json()
        log.info(f"generate_ollama_batch_embeddings {data}")

        if "embeddings" in data:
            return data
        else:
            raise Exception("Ollama: 'embeddings' missing from response")
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"

        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        # Raise HTTPException for parity with generate_ollama_embeddings
        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )


class GenerateCompletionForm(BaseModel):
    model: str
    prompt: str
    suffix: Optional[str] = None
    images: Optional[list[str]] = None
    format: Optional[str] = None
    options: Optional[dict] = None
    system: Optional[str] = None
    template: Optional[str] = None
    context: Optional[list[int]] = None
    stream: Optional[bool] = True
    raw: Optional[bool] = None
    keep_alive: Optional[Union[int, str]] = None


@app.post("/api/generate")
@app.post("/api/generate/{url_idx}")
async def generate_completion(
    form_data: GenerateCompletionForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    if url_idx is None:
        model_list = await get_all_models()
        models = {model["model"]: model for model in model_list["models"]}

        model = form_data.model
        if ":" not in model:
            model = f"{model}:latest"

        if model in models:
            url_idx = random.choice(models[model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})

    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        form_data.model = form_data.model.replace(f"{prefix_id}.", "")

    log.info(f"url: {url}")

    return await post_streaming_url(
        f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
    )


class ChatMessage(BaseModel):
    role: str
    content: str
    images: Optional[list[str]] = None


class GenerateChatCompletionForm(BaseModel):
    model: str
    messages: list[ChatMessage]
    format: Optional[str] = None
    options: Optional[dict] = None
    template: Optional[str] = None
    stream: Optional[bool] = True
    keep_alive: Optional[Union[int, str]] = None


async def get_ollama_url(url_idx: Optional[int], model: str):
    if url_idx is None:
        model_list = await get_all_models()
        models = {model["model"]: model for model in model_list["models"]}

        if model not in models:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
            )
        url_idx = random.choice(models[model]["urls"])

    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    return url
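

# Native Ollama chat endpoint. The stored model definition (if any) supplies
# its base model, default parameters, and system prompt; access control is
# enforced for non-admin users unless BYPASS_MODEL_ACCESS_CONTROL is set.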
@app.post("/api/chat")
@app.post("/api/chat/{url_idx}")
async def generate_chat_completion(
    form_data: GenerateChatCompletionForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
    bypass_filter: Optional[bool] = False,
):
    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True

    payload = {**form_data.model_dump(exclude_none=True)}
    log.debug(f"generate_chat_completion() - 1.payload = {payload}")
    if "metadata" in payload:
        del payload["metadata"]

    model_id = payload["model"]
    model_info = Models.get_model_by_id(model_id)

    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        params = model_info.params.model_dump()

        if params:
            if payload.get("options") is None:
                payload["options"] = {}

            payload["options"] = apply_model_params_to_body_ollama(
                params, payload["options"]
            )
            payload = apply_model_system_prompt_to_body(params, payload, user)

        # Check if the user has access to the model
        if not bypass_filter and user.role == "user":
            if not (
                user.id == model_info.user_id
                or has_access(
                    user.id, type="read", access_control=model_info.access_control
                )
            ):
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )
    elif not bypass_filter:
        if user.role != "admin":
            raise HTTPException(
                status_code=403,
                detail="Model not found",
            )

    if ":" not in payload["model"]:
        payload["model"] = f"{payload['model']}:latest"

    url = await get_ollama_url(url_idx, payload["model"])
    log.info(f"url: {url}")
    log.debug(f"generate_chat_completion() - 2.payload = {payload}")

    parsed_url = urlparse(url)
    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

    api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        payload["model"] = payload["model"].replace(f"{prefix_id}.", "")

    return await post_streaming_url(
        f"{url}/api/chat",
        json.dumps(payload),
        stream=form_data.stream,
        content_type="application/x-ndjson",
    )
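

# OpenAI-compatible endpoints, proxied through to the Ollama /v1 API.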
# TODO: we should update this part once Ollama supports other types
class OpenAIChatMessageContent(BaseModel):
    type: str
    model_config = ConfigDict(extra="allow")


class OpenAIChatMessage(BaseModel):
    role: str
    content: Union[str, list[OpenAIChatMessageContent]]
    model_config = ConfigDict(extra="allow")


class OpenAIChatCompletionForm(BaseModel):
    model: str
    messages: list[OpenAIChatMessage]
    model_config = ConfigDict(extra="allow")


class OpenAICompletionForm(BaseModel):
    model: str
    prompt: str
    model_config = ConfigDict(extra="allow")


@app.post("/v1/completions")
@app.post("/v1/completions/{url_idx}")
async def generate_openai_completion(
    form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user)
):
    try:
        form_data = OpenAICompletionForm(**form_data)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=400,
            detail=str(e),
        )

    payload = {**form_data.model_dump(exclude_none=True, exclude={"metadata"})}
    if "metadata" in payload:
        del payload["metadata"]

    model_id = form_data.model
    if ":" not in model_id:
        model_id = f"{model_id}:latest"

    model_info = Models.get_model_by_id(model_id)
    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        params = model_info.params.model_dump()
        if params:
            payload = apply_model_params_to_body_openai(params, payload)

        # Check if the user has access to the model
        if user.role == "user":
            if not (
                user.id == model_info.user_id
                or has_access(
                    user.id, type="read", access_control=model_info.access_control
                )
            ):
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )
    else:
        if user.role != "admin":
            raise HTTPException(
                status_code=403,
                detail="Model not found",
            )

    if ":" not in payload["model"]:
        payload["model"] = f"{payload['model']}:latest"

    url = await get_ollama_url(url_idx, payload["model"])
    log.info(f"url: {url}")

    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        payload["model"] = payload["model"].replace(f"{prefix_id}.", "")

    return await post_streaming_url(
        f"{url}/v1/completions",
        json.dumps(payload),
        stream=payload.get("stream", False),
    )


@app.post("/v1/chat/completions")
@app.post("/v1/chat/completions/{url_idx}")
async def generate_openai_chat_completion(
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    try:
        completion_form = OpenAIChatCompletionForm(**form_data)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=400,
            detail=str(e),
        )

    payload = {**completion_form.model_dump(exclude_none=True, exclude={"metadata"})}
    if "metadata" in payload:
        del payload["metadata"]

    model_id = completion_form.model
    if ":" not in model_id:
        model_id = f"{model_id}:latest"

    model_info = Models.get_model_by_id(model_id)
    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        params = model_info.params.model_dump()
        if params:
            payload = apply_model_params_to_body_openai(params, payload)
            payload = apply_model_system_prompt_to_body(params, payload, user)

        # Check if the user has access to the model
        if user.role == "user":
            if not (
                user.id == model_info.user_id
                or has_access(
                    user.id, type="read", access_control=model_info.access_control
                )
            ):
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )
    else:
        if user.role != "admin":
            raise HTTPException(
                status_code=403,
                detail="Model not found",
            )

    if ":" not in payload["model"]:
        payload["model"] = f"{payload['model']}:latest"

    url = await get_ollama_url(url_idx, payload["model"])
    log.info(f"url: {url}")

    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        payload["model"] = payload["model"].replace(f"{prefix_id}.", "")

    return await post_streaming_url(
        f"{url}/v1/chat/completions",
        json.dumps(payload),
        stream=payload.get("stream", False),
    )


@app.get("/v1/models")
@app.get("/v1/models/{url_idx}")
async def get_openai_models(
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    models = []

    if url_idx is None:
        model_list = await get_all_models()
        models = [
            {
                "id": model["model"],
                "object": "model",
                "created": int(time.time()),
                "owned_by": "openai",
            }
            for model in model_list["models"]
        ]
    else:
        url = app.state.config.OLLAMA_BASE_URLS[url_idx]

        r = None
        try:
            r = requests.request(method="GET", url=f"{url}/api/tags")
            r.raise_for_status()

            model_list = r.json()
            models = [
                {
                    "id": model["model"],
                    "object": "model",
                    "created": int(time.time()),
                    "owned_by": "openai",
                }
                for model in model_list["models"]
            ]
        except Exception as e:
            log.exception(e)
            error_detail = "Open WebUI: Server Connection Error"

            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"Ollama: {res['error']}"
                except Exception:
                    error_detail = f"Ollama: {e}"

            raise HTTPException(
                status_code=r.status_code if r else 500,
                detail=error_detail,
            )

    if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
        # Filter models based on user access control
        filtered_models = []
        for model in models:
            model_info = Models.get_model_by_id(model["id"])
            if model_info:
                if user.id == model_info.user_id or has_access(
                    user.id, type="read", access_control=model_info.access_control
                ):
                    filtered_models.append(model)
        models = filtered_models

    return {
        "data": models,
        "object": "list",
    }


class UrlForm(BaseModel):
    url: str


class UploadBlobForm(BaseModel):
    filename: str


def parse_huggingface_url(hf_url):
    try:
        # Parse the URL
        parsed_url = urlparse(hf_url)

        # Get the path and split it into components
        path_components = parsed_url.path.split("/")

        # Extract the file name (the last path component)
        model_file = path_components[-1]

        return model_file
    except ValueError:
        return None
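

# Stream a model file download to disk with resume support (HTTP Range),
# emitting SSE progress events; once complete, upload the file to the Ollama
# blob store keyed by its SHA-256 digest.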
async def download_file_stream(
    ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
):
    done = False

    if os.path.exists(file_path):
        current_size = os.path.getsize(file_path)
    else:
        current_size = 0

    # Resume a partial download if the file already exists locally
    headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}

    timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout

    async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
        async with session.get(file_url, headers=headers) as response:
            total_size = int(response.headers.get("content-length", 0)) + current_size

            with open(file_path, "ab+") as file:
                async for data in response.content.iter_chunked(chunk_size):
                    current_size += len(data)
                    file.write(data)

                    done = current_size == total_size
                    progress = round((current_size / total_size) * 100, 2)

                    yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'

                if done:
                    file.seek(0)
                    hashed = calculate_sha256(file)
                    file.seek(0)

                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
                    response = requests.post(url, data=file)

                    if response.ok:
                        res = {
                            "done": done,
                            "blob": f"sha256:{hashed}",
                            "name": file_name,
                        }
                        os.remove(file_path)

                        yield f"data: {json.dumps(res)}\n\n"
                    else:
                        raise Exception(
                            "Ollama: Could not create blob, Please try again."
                        )


# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
@app.post("/models/download")
@app.post("/models/download/{url_idx}")
async def download_model(
    form_data: UrlForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    allowed_hosts = ["https://huggingface.co/", "https://github.com/"]

    if not any(form_data.url.startswith(host) for host in allowed_hosts):
        raise HTTPException(
            status_code=400,
            detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
        )

    if url_idx is None:
        url_idx = 0
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]

    file_name = parse_huggingface_url(form_data.url)

    if file_name:
        file_path = f"{UPLOAD_DIR}/{file_name}"

        return StreamingResponse(
            download_file_stream(url, form_data.url, file_path, file_name),
        )
    else:
        return None


@app.post("/models/upload")
@app.post("/models/upload/{url_idx}")
def upload_model(
    file: UploadFile = File(...),
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        url_idx = 0
    ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]

    file_path = f"{UPLOAD_DIR}/{file.filename}"

    # Save the uploaded file to disk
    with open(file_path, "wb+") as f:
        for chunk in file.file:
            f.write(chunk)

    def file_process_stream():
        nonlocal ollama_url
        total_size = os.path.getsize(file_path)
        chunk_size = 1024 * 1024
        try:
            with open(file_path, "rb") as f:
                total = 0
                done = False

                while not done:
                    chunk = f.read(chunk_size)
                    if not chunk:
                        done = True
                        continue

                    total += len(chunk)
                    progress = round((total / total_size) * 100, 2)

                    res = {
                        "progress": progress,
                        "total": total_size,
                        "completed": total,
                    }
                    yield f"data: {json.dumps(res)}\n\n"

                if done:
                    f.seek(0)
                    hashed = calculate_sha256(f)
                    f.seek(0)

                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
                    response = requests.post(url, data=f)

                    if response.ok:
                        res = {
                            "done": done,
                            "blob": f"sha256:{hashed}",
                            "name": file.filename,
                        }
                        os.remove(file_path)
                        yield f"data: {json.dumps(res)}\n\n"
                    else:
                        raise Exception(
                            "Ollama: Could not create blob, Please try again."
                        )

        except Exception as e:
            res = {"error": str(e)}
            yield f"data: {json.dumps(res)}\n\n"

    return StreamingResponse(file_process_stream(), media_type="text/event-stream")