from fastapi import FastAPI, Request, Response, HTTPException, Depends, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, ConfigDict

import random
import requests
import json
import uuid
import aiohttp
import asyncio

from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user, get_admin_user
from config import OLLAMA_BASE_URL, WEBUI_AUTH

from typing import Optional, List, Union


app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.state.OLLAMA_BASE_URL = OLLAMA_BASE_URL
app.state.OLLAMA_BASE_URLS = [OLLAMA_BASE_URL]
app.state.MODELS = {}

REQUEST_POOL = []


# TODO: Implement a more intelligent load-balancing mechanism for distributing
# requests among multiple backend instances. The current implementation uses
# simple random selection (random.choice); consider algorithms like weighted
# round-robin, least connections, or least response time for better resource
# utilization and performance.
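
# A minimal sketch of a least-connections picker, assuming callers maintain an
# in-flight request counter per backend index (hypothetical helper; nothing
# below wires it in -- random.choice is still used throughout).
def pick_least_busy(candidate_idxs: List[int], inflight: dict) -> int:
    # Prefer the backend index with the fewest in-flight requests.
    return min(candidate_idxs, key=lambda idx: inflight.get(idx, 0))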


@app.middleware("http")
async def check_url(request: Request, call_next):
    # Lazily populate the model registry on the first request.
    if len(app.state.MODELS) == 0:
        await get_all_models()

    response = await call_next(request)
    return response


@app.get("/urls")
async def get_ollama_api_urls(user=Depends(get_admin_user)):
    return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}


class UrlUpdateForm(BaseModel):
    urls: List[str]


@app.post("/urls/update")
async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
    app.state.OLLAMA_BASE_URLS = form_data.urls
    print(app.state.OLLAMA_BASE_URLS)
    return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}


@app.get("/cancel/{request_id}")
async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
    if user:
        if request_id in REQUEST_POOL:
            REQUEST_POOL.remove(request_id)
        return True
    else:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
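
# Cancellation is cooperative: the streaming handlers below check REQUEST_POOL
# on every chunk and stop yielding once this endpoint has removed their id.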


async def fetch_url(url):
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                return await response.json()
    except Exception as e:
        # Connection errors return None so this backend is skipped.
        print(f"Connection error: {e}")
        return None


def merge_models_lists(model_lists):
    merged_models = {}

    for idx, model_list in enumerate(model_lists):
        for model in model_list:
            digest = model["digest"]
            if digest not in merged_models:
                model["urls"] = [idx]
                merged_models[digest] = model
            else:
                merged_models[digest]["urls"].append(idx)

    return list(merged_models.values())
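
# Example: two backends serving the same digest are merged into one entry
# whose "urls" lists both backend indices, e.g.
#   [[{"digest": "abc", ...}], [{"digest": "abc", ...}]]
#       -> [{"digest": "abc", "urls": [0, 1], ...}]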


# user=Depends(get_current_user)
async def get_all_models():
    print("get_all_models")

    tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS]
    responses = await asyncio.gather(*tasks)
    responses = list(filter(lambda x: x is not None, responses))

    models = {
        "models": merge_models_lists(
            map(lambda response: response["models"], responses)
        )
    }

    app.state.MODELS = {model["model"]: model for model in models["models"]}
    return models


@app.get("/api/tags")
@app.get("/api/tags/{url_idx}")
async def get_ollama_tags(
    url_idx: Optional[int] = None, user=Depends(get_current_user)
):
    if url_idx is None:
        return await get_all_models()
    else:
        url = app.state.OLLAMA_BASE_URLS[url_idx]

        r = None  # initialized so the except handler can safely inspect it
        try:
            r = requests.request(method="GET", url=f"{url}/api/tags")
            r.raise_for_status()
            return r.json()
        except Exception as e:
            print(e)
            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"Ollama: {res['error']}"
                except Exception:
                    error_detail = f"Ollama: {e}"

            raise HTTPException(
                status_code=r.status_code if r else 500,
                detail=error_detail,
            )


@app.get("/api/version")
@app.get("/api/version/{url_idx}")
async def get_ollama_versions(url_idx: Optional[int] = None):
    if url_idx is None:
        # Return the lowest version across all backends.
        tasks = [fetch_url(f"{url}/api/version") for url in app.state.OLLAMA_BASE_URLS]
        responses = await asyncio.gather(*tasks)
        responses = list(filter(lambda x: x is not None, responses))

        lowest_version = min(
            responses, key=lambda x: tuple(map(int, x["version"].split(".")))
        )

        return {"version": lowest_version["version"]}
    else:
        url = app.state.OLLAMA_BASE_URLS[url_idx]

        r = None  # initialized so the except handler can safely inspect it
        try:
            r = requests.request(method="GET", url=f"{url}/api/version")
            r.raise_for_status()
            return r.json()
        except Exception as e:
            print(e)
            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"Ollama: {res['error']}"
                except Exception:
                    error_detail = f"Ollama: {e}"

            raise HTTPException(
                status_code=r.status_code if r else 500,
                detail=error_detail,
            )


class ModelNameForm(BaseModel):
    name: str


@app.post("/api/pull")
@app.post("/api/pull/{url_idx}")
async def pull_model(
    form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
):
    url = app.state.OLLAMA_BASE_URLS[url_idx]
    print(url)

    r = None

    def get_request():
        nonlocal url
        nonlocal r

        try:
            def stream_content():
                for chunk in r.iter_content(chunk_size=8192):
                    yield chunk

            r = requests.request(
                method="POST",
                url=f"{url}/api/pull",
                data=form_data.model_dump_json(exclude_none=True),
                stream=True,
            )

            r.raise_for_status()

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            raise e

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )
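
# NOTE: requests is blocking, so each proxied call runs inside
# run_in_threadpool to keep the event loop free; the same
# get_request/stream_content shape repeats for push, create, generate, chat,
# and the OpenAI-compatible endpoint below.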


class PushModelForm(BaseModel):
    name: str
    insecure: Optional[bool] = None
    stream: Optional[bool] = None


@app.delete("/api/push")
@app.delete("/api/push/{url_idx}")
async def push_model(
    form_data: PushModelForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        if form_data.name in app.state.MODELS:
            url_idx = app.state.MODELS[form_data.name]["urls"][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
            )

    url = app.state.OLLAMA_BASE_URLS[url_idx]

    r = None

    def get_request():
        nonlocal url
        nonlocal r

        try:
            def stream_content():
                for chunk in r.iter_content(chunk_size=8192):
                    yield chunk

            r = requests.request(
                method="POST",
                url=f"{url}/api/push",
                data=form_data.model_dump_json(exclude_none=True),
                stream=True,  # stream so iter_content yields chunks as they arrive
            )

            r.raise_for_status()

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            raise e

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )


class CreateModelForm(BaseModel):
    name: str
    modelfile: Optional[str] = None
    stream: Optional[bool] = None
    path: Optional[str] = None


@app.post("/api/create")
@app.post("/api/create/{url_idx}")
async def create_model(
    form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
):
    print(form_data)
    url = app.state.OLLAMA_BASE_URLS[url_idx]

    r = None

    def get_request():
        nonlocal url
        nonlocal r

        try:
            def stream_content():
                for chunk in r.iter_content(chunk_size=8192):
                    yield chunk

            r = requests.request(
                method="POST",
                url=f"{url}/api/create",
                data=form_data.model_dump_json(exclude_none=True),
                stream=True,
            )

            r.raise_for_status()
            print(r)

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            raise e

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )


class CopyModelForm(BaseModel):
    source: str
    destination: str


@app.post("/api/copy")
@app.post("/api/copy/{url_idx}")
async def copy_model(
    form_data: CopyModelForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        if form_data.source in app.state.MODELS:
            url_idx = app.state.MODELS[form_data.source]["urls"][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
            )

    url = app.state.OLLAMA_BASE_URLS[url_idx]

    r = None  # initialized so the except handler can safely inspect it
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/copy",
            data=form_data.model_dump_json(exclude_none=True),
        )
        r.raise_for_status()

        print(r.text)
        return True
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )


@app.delete("/api/delete")
@app.delete("/api/delete/{url_idx}")
async def delete_model(
    form_data: ModelNameForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        if form_data.name in app.state.MODELS:
            url_idx = app.state.MODELS[form_data.name]["urls"][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
            )

    url = app.state.OLLAMA_BASE_URLS[url_idx]
    print(url)

    r = None  # initialized so the except handler can safely inspect it
    try:
        r = requests.request(
            method="DELETE",
            url=f"{url}/api/delete",
            data=form_data.model_dump_json(exclude_none=True),
        )
        r.raise_for_status()

        print(r.text)
        return True
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )


@app.post("/api/show")
async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_user)):
    if form_data.name not in app.state.MODELS:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
        )

    url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
    url = app.state.OLLAMA_BASE_URLS[url_idx]

    r = None  # initialized so the except handler can safely inspect it
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/show",
            data=form_data.model_dump_json(exclude_none=True),
        )
        r.raise_for_status()
        return r.json()
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )


class GenerateEmbeddingsForm(BaseModel):
    model: str
    prompt: str
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None


@app.post("/api/embeddings")
@app.post("/api/embeddings/{url_idx}")
async def generate_embeddings(
    form_data: GenerateEmbeddingsForm,
    url_idx: Optional[int] = None,
    user=Depends(get_current_user),
):
    if url_idx is None:
        if form_data.model in app.state.MODELS:
            url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = app.state.OLLAMA_BASE_URLS[url_idx]

    r = None  # initialized so the except handler can safely inspect it
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/embeddings",
            data=form_data.model_dump_json(exclude_none=True),
        )
        r.raise_for_status()
        return r.json()
    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )
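
# Example payload (model name illustrative; any model present in
# app.state.MODELS works):
#   {"model": "llama2", "prompt": "Why is the sky blue?"}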


class GenerateCompletionForm(BaseModel):
    model: str
    prompt: str
    images: Optional[List[str]] = None
    format: Optional[str] = None
    options: Optional[dict] = None
    system: Optional[str] = None
    template: Optional[str] = None
    context: Optional[str] = None
    stream: Optional[bool] = True
    raw: Optional[bool] = None
    keep_alive: Optional[Union[int, str]] = None


@app.post("/api/generate")
@app.post("/api/generate/{url_idx}")
async def generate_completion(
    form_data: GenerateCompletionForm,
    url_idx: Optional[int] = None,
    user=Depends(get_current_user),
):
    if url_idx is None:
        if form_data.model in app.state.MODELS:
            url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = app.state.OLLAMA_BASE_URLS[url_idx]

    r = None

    def get_request():
        nonlocal form_data
        nonlocal r

        request_id = str(uuid.uuid4())
        try:
            REQUEST_POOL.append(request_id)

            def stream_content():
                try:
                    if form_data.stream:
                        yield json.dumps({"id": request_id, "done": False}) + "\n"

                    for chunk in r.iter_content(chunk_size=8192):
                        if request_id in REQUEST_POOL:
                            yield chunk
                        else:
                            print("User: canceled request")
                            break
                finally:
                    if hasattr(r, "close"):
                        r.close()
                    if request_id in REQUEST_POOL:
                        REQUEST_POOL.remove(request_id)

            r = requests.request(
                method="POST",
                url=f"{url}/api/generate",
                data=form_data.model_dump_json(exclude_none=True),
                stream=True,
            )

            r.raise_for_status()

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            raise e

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )
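
# The first streamed line is an out-of-band {"id": ..., "done": false}
# marker injected by this proxy (not part of Ollama's response); clients can
# call GET /cancel/{id} with that id to abort the generation mid-stream.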


class ChatMessage(BaseModel):
    role: str
    content: str
    images: Optional[List[str]] = None


class GenerateChatCompletionForm(BaseModel):
    model: str
    messages: List[ChatMessage]
    format: Optional[str] = None
    options: Optional[dict] = None
    template: Optional[str] = None
    stream: Optional[bool] = True
    keep_alive: Optional[Union[int, str]] = None


@app.post("/api/chat")
@app.post("/api/chat/{url_idx}")
async def generate_chat_completion(
    form_data: GenerateChatCompletionForm,
    url_idx: Optional[int] = None,
    user=Depends(get_current_user),
):
    if url_idx is None:
        if form_data.model in app.state.MODELS:
            url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = app.state.OLLAMA_BASE_URLS[url_idx]

    r = None
    print(form_data.model_dump_json(exclude_none=True))

    def get_request():
        nonlocal form_data
        nonlocal r

        request_id = str(uuid.uuid4())
        try:
            REQUEST_POOL.append(request_id)

            def stream_content():
                try:
                    if form_data.stream:
                        yield json.dumps({"id": request_id, "done": False}) + "\n"

                    for chunk in r.iter_content(chunk_size=8192):
                        if request_id in REQUEST_POOL:
                            yield chunk
                        else:
                            print("User: canceled request")
                            break
                finally:
                    if hasattr(r, "close"):
                        r.close()
                    if request_id in REQUEST_POOL:
                        REQUEST_POOL.remove(request_id)

            r = requests.request(
                method="POST",
                url=f"{url}/api/chat",
                data=form_data.model_dump_json(exclude_none=True),
                stream=True,
            )

            r.raise_for_status()

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            raise e

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )


# TODO: we should update this part once Ollama supports other message
# content types
class OpenAIChatMessage(BaseModel):
    role: str
    content: str

    model_config = ConfigDict(extra="allow")


class OpenAIChatCompletionForm(BaseModel):
    model: str
    messages: List[OpenAIChatMessage]

    model_config = ConfigDict(extra="allow")
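
# NOTE: extra="allow" lets OpenAI-style parameters not modeled above
# (temperature, stream, max_tokens, ...) pass through to the backend intact.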


@app.post("/v1/chat/completions")
@app.post("/v1/chat/completions/{url_idx}")
async def generate_openai_chat_completion(
    form_data: OpenAIChatCompletionForm,
    url_idx: Optional[int] = None,
    user=Depends(get_current_user),
):
    if url_idx is None:
        if form_data.model in app.state.MODELS:
            url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = app.state.OLLAMA_BASE_URLS[url_idx]

    r = None

    def get_request():
        nonlocal form_data
        nonlocal r

        request_id = str(uuid.uuid4())
        try:
            REQUEST_POOL.append(request_id)

            def stream_content():
                try:
                    # "stream" is an extra field here, so it only exists as an
                    # attribute when the client actually sent it.
                    if getattr(form_data, "stream", False):
                        yield json.dumps(
                            {"request_id": request_id, "done": False}
                        ) + "\n"

                    for chunk in r.iter_content(chunk_size=8192):
                        if request_id in REQUEST_POOL:
                            yield chunk
                        else:
                            print("User: canceled request")
                            break
                finally:
                    if hasattr(r, "close"):
                        r.close()
                    if request_id in REQUEST_POOL:
                        REQUEST_POOL.remove(request_id)

            r = requests.request(
                method="POST",
                url=f"{url}/v1/chat/completions",
                data=form_data.model_dump_json(exclude_none=True),
                stream=True,
            )

            r.raise_for_status()

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            raise e

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)):
    url = app.state.OLLAMA_BASE_URLS[0]
    target_url = f"{url}/{path}"

    body = await request.body()
    headers = dict(request.headers)

    if user.role in ["user", "admin"]:
        if path in ["pull", "delete", "push", "copy", "create"]:
            if user.role != "admin":
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
                )
    else:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    # Strip headers that should not be forwarded to the backend.
    headers.pop("host", None)
    headers.pop("authorization", None)
    headers.pop("origin", None)
    headers.pop("referer", None)

    r = None

    def get_request():
        nonlocal r

        request_id = str(uuid.uuid4())
        try:
            REQUEST_POOL.append(request_id)

            def stream_content():
                try:
                    if path == "generate":
                        data = json.loads(body.decode("utf-8"))

                        if not ("stream" in data and data["stream"] is False):
                            yield json.dumps({"id": request_id, "done": False}) + "\n"

                    elif path == "chat":
                        yield json.dumps({"id": request_id, "done": False}) + "\n"

                    for chunk in r.iter_content(chunk_size=8192):
                        if request_id in REQUEST_POOL:
                            yield chunk
                        else:
                            print("User: canceled request")
                            break
                finally:
                    if hasattr(r, "close"):
                        r.close()
                    if request_id in REQUEST_POOL:
                        REQUEST_POOL.remove(request_id)

            r = requests.request(
                method=request.method,
                url=target_url,
                data=body,
                headers=headers,
                stream=True,
            )

            r.raise_for_status()

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            raise e

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )