main.py

from fastapi import FastAPI, Request, Response, HTTPException, Depends, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, ConfigDict

import random
import requests
import json
import uuid
import aiohttp
import asyncio

from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user, get_admin_user
from config import OLLAMA_BASE_URL, WEBUI_AUTH

from typing import Optional, List, Union


app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.state.OLLAMA_BASE_URL = OLLAMA_BASE_URL
app.state.OLLAMA_BASE_URLS = [OLLAMA_BASE_URL]
app.state.MODELS = {}

# Ids of in-flight streaming requests; hitting /cancel/{request_id} removes an
# id so the corresponding stream loop stops at its next chunk.
REQUEST_POOL = []

# TODO: Implement a more intelligent load balancing mechanism for distributing
# requests among multiple backend instances. The current implementation uses
# simple random selection (random.choice). Consider incorporating algorithms like
# weighted round-robin, least connections, or least response time for better
# resource utilization and performance (a minimal sketch follows below).
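# A minimal sketch of the least-connections idea from the TODO above, under the
# assumption that each backend's in-flight count is tracked at dispatch time.
# Not wired into the endpoints below; ACTIVE_REQUESTS and pick_least_connections
# are illustrative names, not part of the original module.
ACTIVE_REQUESTS: dict = {}  # url_idx -> number of in-flight requests


def pick_least_connections(url_idxs: List[int]) -> int:
    # Choose the candidate backend currently serving the fewest requests;
    # ties resolve to the first lowest-count candidate.
    return min(url_idxs, key=lambda idx: ACTIVE_REQUESTS.get(idx, 0))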

@app.middleware("http")
async def check_url(request: Request, call_next):
    # Lazily populate the model registry on the first request after startup.
    if len(app.state.MODELS) == 0:
        await get_all_models()

    response = await call_next(request)
    return response

@app.get("/urls")
async def get_ollama_api_urls(user=Depends(get_admin_user)):
    return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}


class UrlUpdateForm(BaseModel):
    urls: List[str]


@app.post("/urls/update")
async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
    app.state.OLLAMA_BASE_URLS = form_data.urls
    print(app.state.OLLAMA_BASE_URLS)
    return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}

@app.get("/cancel/{request_id}")
async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
    if user:
        if request_id in REQUEST_POOL:
            REQUEST_POOL.remove(request_id)
        return True
    else:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
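
# Hedged usage sketch (illustrative client code, not part of the app): a streamed
# /api/generate or /api/chat response begins with a JSON line carrying the request
# id, and posting that id to /cancel/{request_id} stops the stream server-side.
# The base_url default is an assumption, and a real call would also need the
# user's Authorization header.
def cancel_streamed_request(first_line: bytes, base_url: str = "http://localhost:8080/ollama"):
    # The first streamed line looks like {"id": "<uuid>", "done": false}.
    request_id = json.loads(first_line)["id"]
    return requests.get(f"{base_url}/cancel/{request_id}")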

async def fetch_url(url):
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                return await response.json()
    except Exception as e:
        # Swallow connection errors so a single unreachable backend does not
        # break aggregation across OLLAMA_BASE_URLS.
        print(f"Connection error: {e}")
        return None

def merge_models_lists(model_lists):
    merged_models = {}

    for idx, model_list in enumerate(model_lists):
        for model in model_list:
            digest = model["digest"]
            if digest not in merged_models:
                model["urls"] = [idx]
                merged_models[digest] = model
            else:
                # Same model (by digest) available on another backend: record
                # the additional URL index instead of duplicating the entry.
                merged_models[digest]["urls"].append(idx)

    return list(merged_models.values())
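
# Worked example (illustrative): two backends reporting the same digest collapse
# into a single entry whose "urls" field lists both backend indices, e.g.
#
#   merge_models_lists([
#       [{"model": "llama2:latest", "digest": "abc"}],
#       [{"model": "llama2:latest", "digest": "abc"}],
#   ])
#   # -> [{"model": "llama2:latest", "digest": "abc", "urls": [0, 1]}]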

async def get_all_models():
    print("get_all_models")

    tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS]
    responses = await asyncio.gather(*tasks)
    responses = list(filter(lambda x: x is not None, responses))

    models = {
        "models": merge_models_lists(
            map(lambda response: response["models"], responses)
        )
    }

    app.state.MODELS = {model["model"]: model for model in models["models"]}

    return models

@app.get("/api/tags")
@app.get("/api/tags/{url_idx}")
async def get_ollama_tags(
    url_idx: Optional[int] = None, user=Depends(get_current_user)
):
    if url_idx is None:
        return await get_all_models()
    else:
        url = app.state.OLLAMA_BASE_URLS[url_idx]

        # Initialize r so the error handler can safely check it even when the
        # request itself raised before assignment.
        r = None
        try:
            r = requests.request(method="GET", url=f"{url}/api/tags")
            r.raise_for_status()

            return r.json()
        except Exception as e:
            print(e)
            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"Ollama: {res['error']}"
                except Exception:
                    error_detail = f"Ollama: {e}"

            raise HTTPException(
                status_code=r.status_code if r else 500,
                detail=error_detail,
            )

@app.get("/api/version")
@app.get("/api/version/{url_idx}")
async def get_ollama_versions(url_idx: Optional[int] = None):
    if url_idx is None:
        # Return the lowest version across all configured backends, so the
        # frontend only relies on features every backend supports.
        tasks = [fetch_url(f"{url}/api/version") for url in app.state.OLLAMA_BASE_URLS]
        responses = await asyncio.gather(*tasks)
        responses = list(filter(lambda x: x is not None, responses))

        # tuple(map(int, ...)) compares numerically, so "0.1.9" sorts below
        # "0.1.10" (unlike plain string comparison).
        lowest_version = min(
            responses, key=lambda x: tuple(map(int, x["version"].split(".")))
        )

        return {"version": lowest_version["version"]}
    else:
        url = app.state.OLLAMA_BASE_URLS[url_idx]

        r = None
        try:
            r = requests.request(method="GET", url=f"{url}/api/version")
            r.raise_for_status()

            return r.json()
        except Exception as e:
            print(e)
            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"Ollama: {res['error']}"
                except Exception:
                    error_detail = f"Ollama: {e}"

            raise HTTPException(
                status_code=r.status_code if r else 500,
                detail=error_detail,
            )

class ModelNameForm(BaseModel):
    name: str


@app.post("/api/pull")
@app.post("/api/pull/{url_idx}")
async def pull_model(
    form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
):
    url = app.state.OLLAMA_BASE_URLS[url_idx]
    print(url)

    r = None

    def get_request():
        nonlocal url
        nonlocal r

        def stream_content():
            for chunk in r.iter_content(chunk_size=8192):
                yield chunk

        r = requests.request(
            method="POST",
            url=f"{url}/api/pull",
            data=form_data.model_dump_json(exclude_none=True),
            stream=True,
        )

        r.raise_for_status()

        return StreamingResponse(
            stream_content(),
            status_code=r.status_code,
            headers=dict(r.headers),
        )

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )

class PushModelForm(BaseModel):
    name: str
    insecure: Optional[bool] = None
    stream: Optional[bool] = None


@app.delete("/api/push")
@app.delete("/api/push/{url_idx}")
async def push_model(
    form_data: PushModelForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        if form_data.name in app.state.MODELS:
            url_idx = app.state.MODELS[form_data.name]["urls"][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
            )

    url = app.state.OLLAMA_BASE_URLS[url_idx]
    print(url)

    r = None

    def get_request():
        nonlocal url
        nonlocal r

        def stream_content():
            for chunk in r.iter_content(chunk_size=8192):
                yield chunk

        r = requests.request(
            method="POST",
            url=f"{url}/api/push",
            data=form_data.model_dump_json(exclude_none=True),
        )

        r.raise_for_status()

        return StreamingResponse(
            stream_content(),
            status_code=r.status_code,
            headers=dict(r.headers),
        )

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )

class CreateModelForm(BaseModel):
    name: str
    modelfile: Optional[str] = None
    stream: Optional[bool] = None
    path: Optional[str] = None


@app.post("/api/create")
@app.post("/api/create/{url_idx}")
async def create_model(
    form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
):
    print(form_data)
    url = app.state.OLLAMA_BASE_URLS[url_idx]
    print(url)

    r = None

    def get_request():
        nonlocal url
        nonlocal r

        def stream_content():
            for chunk in r.iter_content(chunk_size=8192):
                yield chunk

        r = requests.request(
            method="POST",
            url=f"{url}/api/create",
            data=form_data.model_dump_json(exclude_none=True),
            stream=True,
        )

        r.raise_for_status()
        print(r)

        return StreamingResponse(
            stream_content(),
            status_code=r.status_code,
            headers=dict(r.headers),
        )

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )

class CopyModelForm(BaseModel):
    source: str
    destination: str


@app.post("/api/copy")
@app.post("/api/copy/{url_idx}")
async def copy_model(
    form_data: CopyModelForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        if form_data.source in app.state.MODELS:
            url_idx = app.state.MODELS[form_data.source]["urls"][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
            )

    url = app.state.OLLAMA_BASE_URLS[url_idx]
    print(url)

    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/copy",
            data=form_data.model_dump_json(exclude_none=True),
        )
        r.raise_for_status()

        print(r.text)

        return True
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )

@app.delete("/api/delete")
@app.delete("/api/delete/{url_idx}")
async def delete_model(
    form_data: ModelNameForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        if form_data.name in app.state.MODELS:
            url_idx = app.state.MODELS[form_data.name]["urls"][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
            )

    url = app.state.OLLAMA_BASE_URLS[url_idx]
    print(url)

    r = None
    try:
        r = requests.request(
            method="DELETE",
            url=f"{url}/api/delete",
            data=form_data.model_dump_json(exclude_none=True),
        )
        r.raise_for_status()

        print(r.text)

        return True
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )

@app.post("/api/show")
async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_user)):
    if form_data.name not in app.state.MODELS:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
        )

    url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
    url = app.state.OLLAMA_BASE_URLS[url_idx]
    print(url)

    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/show",
            data=form_data.model_dump_json(exclude_none=True),
        )
        r.raise_for_status()

        return r.json()
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )

class GenerateEmbeddingsForm(BaseModel):
    model: str
    prompt: str
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None


@app.post("/api/embeddings")
@app.post("/api/embeddings/{url_idx}")
async def generate_embeddings(
    form_data: GenerateEmbeddingsForm,
    url_idx: Optional[int] = None,
    user=Depends(get_current_user),
):
    if url_idx is None:
        if form_data.model in app.state.MODELS:
            url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = app.state.OLLAMA_BASE_URLS[url_idx]
    print(url)

    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/embeddings",
            data=form_data.model_dump_json(exclude_none=True),
        )
        r.raise_for_status()

        return r.json()
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )

class GenerateCompletionForm(BaseModel):
    model: str
    prompt: str
    images: Optional[List[str]] = None
    format: Optional[str] = None
    options: Optional[dict] = None
    system: Optional[str] = None
    template: Optional[str] = None
    context: Optional[str] = None
    stream: Optional[bool] = True
    raw: Optional[bool] = None
    keep_alive: Optional[Union[int, str]] = None


@app.post("/api/generate")
@app.post("/api/generate/{url_idx}")
async def generate_completion(
    form_data: GenerateCompletionForm,
    url_idx: Optional[int] = None,
    user=Depends(get_current_user),
):
    if url_idx is None:
        if form_data.model in app.state.MODELS:
            url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = app.state.OLLAMA_BASE_URLS[url_idx]
    print(url)

    r = None

    def get_request():
        nonlocal form_data
        nonlocal r

        request_id = str(uuid.uuid4())
        try:
            REQUEST_POOL.append(request_id)

            def stream_content():
                try:
                    if form_data.stream:
                        # The first line hands the client the request id so it
                        # can cancel via /cancel/{request_id}.
                        yield json.dumps({"id": request_id, "done": False}) + "\n"

                    for chunk in r.iter_content(chunk_size=8192):
                        if request_id in REQUEST_POOL:
                            yield chunk
                        else:
                            print("User: canceled request")
                            break
                finally:
                    if hasattr(r, "close"):
                        r.close()
                    if request_id in REQUEST_POOL:
                        REQUEST_POOL.remove(request_id)

            r = requests.request(
                method="POST",
                url=f"{url}/api/generate",
                data=form_data.model_dump_json(exclude_none=True),
                stream=True,
            )

            r.raise_for_status()

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            # Drop the id if the upstream call failed before streaming began,
            # so failed requests do not leak into REQUEST_POOL.
            if request_id in REQUEST_POOL:
                REQUEST_POOL.remove(request_id)
            raise e

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )

class ChatMessage(BaseModel):
    role: str
    content: str
    images: Optional[List[str]] = None


class GenerateChatCompletionForm(BaseModel):
    model: str
    messages: List[ChatMessage]
    format: Optional[str] = None
    options: Optional[dict] = None
    template: Optional[str] = None
    stream: Optional[bool] = True
    keep_alive: Optional[Union[int, str]] = None


@app.post("/api/chat")
@app.post("/api/chat/{url_idx}")
async def generate_chat_completion(
    form_data: GenerateChatCompletionForm,
    url_idx: Optional[int] = None,
    user=Depends(get_current_user),
):
    if url_idx is None:
        if form_data.model in app.state.MODELS:
            url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = app.state.OLLAMA_BASE_URLS[url_idx]
    print(url)

    r = None

    print(form_data.model_dump_json(exclude_none=True))

    def get_request():
        nonlocal form_data
        nonlocal r

        request_id = str(uuid.uuid4())
        try:
            REQUEST_POOL.append(request_id)

            def stream_content():
                try:
                    if form_data.stream:
                        yield json.dumps({"id": request_id, "done": False}) + "\n"

                    for chunk in r.iter_content(chunk_size=8192):
                        if request_id in REQUEST_POOL:
                            yield chunk
                        else:
                            print("User: canceled request")
                            break
                finally:
                    if hasattr(r, "close"):
                        r.close()
                    if request_id in REQUEST_POOL:
                        REQUEST_POOL.remove(request_id)

            r = requests.request(
                method="POST",
                url=f"{url}/api/chat",
                data=form_data.model_dump_json(exclude_none=True),
                stream=True,
            )

            r.raise_for_status()

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            # Clean up the pool if the upstream call failed before streaming began.
            if request_id in REQUEST_POOL:
                REQUEST_POOL.remove(request_id)
            raise e

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )

# TODO: we should update this part once Ollama supports other types
class OpenAIChatMessage(BaseModel):
    role: str
    content: str

    model_config = ConfigDict(extra="allow")


class OpenAIChatCompletionForm(BaseModel):
    model: str
    messages: List[OpenAIChatMessage]

    model_config = ConfigDict(extra="allow")
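
# Hedged sketch for the TODO above (an assumption, not wired in): OpenAI clients
# may send list-style message content (e.g. [{"type": "text", "text": "hi"}]),
# while this form only accepts plain strings. A pre-processing step could flatten
# the text parts; flatten_openai_content is an illustrative name, not an existing
# helper in this module.
def flatten_openai_content(content) -> str:
    if isinstance(content, str):
        return content
    # Keep only the text parts; image parts would need separate handling.
    return " ".join(
        part.get("text", "") for part in content if isinstance(part, dict)
    )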

@app.post("/v1/chat/completions")
@app.post("/v1/chat/completions/{url_idx}")
async def generate_openai_chat_completion(
    form_data: OpenAIChatCompletionForm,
    url_idx: Optional[int] = None,
    user=Depends(get_current_user),
):
    if url_idx is None:
        if form_data.model in app.state.MODELS:
            url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = app.state.OLLAMA_BASE_URLS[url_idx]
    print(url)

    r = None

    def get_request():
        nonlocal form_data
        nonlocal r

        request_id = str(uuid.uuid4())
        try:
            REQUEST_POOL.append(request_id)

            def stream_content():
                try:
                    # "stream" is an extra (undeclared) field on this form, so
                    # use getattr with a default rather than direct attribute
                    # access, which raises when the field is absent.
                    if getattr(form_data, "stream", False):
                        yield json.dumps(
                            {"request_id": request_id, "done": False}
                        ) + "\n"

                    for chunk in r.iter_content(chunk_size=8192):
                        if request_id in REQUEST_POOL:
                            yield chunk
                        else:
                            print("User: canceled request")
                            break
                finally:
                    if hasattr(r, "close"):
                        r.close()
                    if request_id in REQUEST_POOL:
                        REQUEST_POOL.remove(request_id)

            r = requests.request(
                method="POST",
                url=f"{url}/v1/chat/completions",
                data=form_data.model_dump_json(exclude_none=True),
                stream=True,
            )

            r.raise_for_status()

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            # Clean up the pool if the upstream call failed before streaming began.
            if request_id in REQUEST_POOL:
                REQUEST_POOL.remove(request_id)
            raise e

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )

@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)):
    url = app.state.OLLAMA_BASE_URLS[0]
    target_url = f"{url}/{path}"

    body = await request.body()
    headers = dict(request.headers)

    if user.role in ["user", "admin"]:
        # Model-management paths are restricted to admins.
        if path in ["pull", "delete", "push", "copy", "create"]:
            if user.role != "admin":
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
                )
    else:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    # Strip headers that should not be forwarded to the upstream backend.
    headers.pop("host", None)
    headers.pop("authorization", None)
    headers.pop("origin", None)
    headers.pop("referer", None)

    r = None

    def get_request():
        nonlocal r

        request_id = str(uuid.uuid4())
        try:
            REQUEST_POOL.append(request_id)

            def stream_content():
                try:
                    if path == "generate":
                        data = json.loads(body.decode("utf-8"))

                        if not ("stream" in data and data["stream"] == False):
                            yield json.dumps({"id": request_id, "done": False}) + "\n"

                    elif path == "chat":
                        yield json.dumps({"id": request_id, "done": False}) + "\n"

                    for chunk in r.iter_content(chunk_size=8192):
                        if request_id in REQUEST_POOL:
                            yield chunk
                        else:
                            print("User: canceled request")
                            break
                finally:
                    if hasattr(r, "close"):
                        r.close()
                    if request_id in REQUEST_POOL:
                        REQUEST_POOL.remove(request_id)

            r = requests.request(
                method=request.method,
                url=target_url,
                data=body,
                headers=headers,
                stream=True,
            )

            r.raise_for_status()

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            # Clean up the pool if the upstream call failed before streaming began.
            if request_id in REQUEST_POOL:
                REQUEST_POOL.remove(request_id)
            raise e

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )