from fastapi import (
    FastAPI,
    Request,
    Response,
    HTTPException,
    Depends,
    status,
    UploadFile,
    File,
    BackgroundTasks,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, ConfigDict

import os
import re
import copy
import random
import requests
import json
import uuid
import aiohttp
import asyncio
import logging
import time
from urllib.parse import urlparse
from typing import Optional, List, Union

from starlette.background import BackgroundTask

from apps.webui.models.models import Models
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
from utils.task import prompt_template

from config import (
    SRC_LOG_LEVELS,
    OLLAMA_BASE_URLS,
    ENABLE_OLLAMA_API,
    AIOHTTP_CLIENT_TIMEOUT,
    ENABLE_MODEL_FILTER,
    MODEL_FILTER_LIST,
    UPLOAD_DIR,
    AppConfig,
)
from utils.misc import calculate_sha256, add_or_update_system_message

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.state.config = AppConfig()

app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST

app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {}

# TODO: Implement a more intelligent load balancing mechanism for distributing
# requests among multiple backend instances. The current implementation uses a
# simple round-robin approach (random.choice). Consider incorporating
# algorithms like weighted round-robin, least connections, or least response
# time for better resource utilization and performance.
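# A minimal sketch of a weighted alternative (illustrative only; this helper
# is not wired in anywhere, and `weights` is a hypothetical per-URL weight
# list that does not exist in the current config):
def pick_url_idx(urls: List[str], weights: List[int]) -> int:
    # random.choices performs a weighted draw; equal weights reduce this
    # to the current random.choice behavior.
    return random.choices(range(len(urls)), weights=weights, k=1)[0]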


@app.middleware("http")
async def check_url(request: Request, call_next):
    # Lazily populate the model registry on the first request.
    if len(app.state.MODELS) == 0:
        await get_all_models()

    response = await call_next(request)
    return response


@app.head("/")
@app.get("/")
async def get_status():
    return {"status": True}


@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
    return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API}


class OllamaConfigForm(BaseModel):
    enable_ollama_api: Optional[bool] = None


@app.post("/config/update")
async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)):
    app.state.config.ENABLE_OLLAMA_API = form_data.enable_ollama_api
    return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API}


@app.get("/urls")
async def get_ollama_api_urls(user=Depends(get_admin_user)):
    return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}


class UrlUpdateForm(BaseModel):
    urls: List[str]


@app.post("/urls/update")
async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
    app.state.config.OLLAMA_BASE_URLS = form_data.urls
    log.info(f"app.state.config.OLLAMA_BASE_URLS: {app.state.config.OLLAMA_BASE_URLS}")
    return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}


async def fetch_url(url):
    timeout = aiohttp.ClientTimeout(total=5)
    try:
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.get(url) as response:
                return await response.json()
    except Exception as e:
        # Handle connection errors here
        log.error(f"Connection error: {e}")
        return None


async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    if response:
        response.close()
    if session:
        await session.close()


async def post_streaming_url(url: str, payload: str, stream: bool = True):
    r = None
    try:
        session = aiohttp.ClientSession(
            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
        )
        r = await session.post(url, data=payload)
        r.raise_for_status()

        if stream:
            # Stream the upstream response through; the response and session
            # are closed by the background task once streaming finishes.
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=dict(r.headers),
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        else:
            res = await r.json()
            await cleanup_response(r, session)
            return res
    except Exception as e:
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = await r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status if r else 500,
            detail=error_detail,
        )


def merge_models_lists(model_lists):
    """Merge per-backend model lists, deduplicating by digest and recording
    which backend indices serve each model in the model's "urls" field."""
    merged_models = {}

    for idx, model_list in enumerate(model_lists):
        if model_list is not None:
            for model in model_list:
                digest = model["digest"]
                if digest not in merged_models:
                    model["urls"] = [idx]
                    merged_models[digest] = model
                else:
                    merged_models[digest]["urls"].append(idx)

    return list(merged_models.values())
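# For example (illustrative values): if backends 0 and 1 both report a model
# with digest "abc", the merged entry carries {"urls": [0, 1]}, so routing can
# later pick either backend with random.choice(model["urls"]).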


async def get_all_models():
    log.info("get_all_models()")

    if app.state.config.ENABLE_OLLAMA_API:
        tasks = [
            fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS
        ]
        responses = await asyncio.gather(*tasks)

        models = {
            "models": merge_models_lists(
                map(
                    lambda response: response["models"] if response else None, responses
                )
            )
        }
    else:
        models = {"models": []}

    app.state.MODELS = {model["model"]: model for model in models["models"]}

    return models


@app.get("/api/tags")
@app.get("/api/tags/{url_idx}")
async def get_ollama_tags(
    url_idx: Optional[int] = None, user=Depends(get_verified_user)
):
    if url_idx is None:
        models = await get_all_models()

        if app.state.config.ENABLE_MODEL_FILTER:
            if user.role == "user":
                models["models"] = list(
                    filter(
                        lambda model: model["name"]
                        in app.state.config.MODEL_FILTER_LIST,
                        models["models"],
                    )
                )

        return models
    else:
        url = app.state.config.OLLAMA_BASE_URLS[url_idx]

        r = None
        try:
            r = requests.request(method="GET", url=f"{url}/api/tags")
            r.raise_for_status()

            return r.json()
        except Exception as e:
            log.exception(e)
            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"Ollama: {res['error']}"
                except Exception:
                    error_detail = f"Ollama: {e}"

            raise HTTPException(
                status_code=r.status_code if r else 500,
                detail=error_detail,
            )


@app.get("/api/version")
@app.get("/api/version/{url_idx}")
async def get_ollama_versions(url_idx: Optional[int] = None):
    if app.state.config.ENABLE_OLLAMA_API:
        if url_idx is None:
            # Returns the lowest version across all configured backends.
            tasks = [
                fetch_url(f"{url}/api/version")
                for url in app.state.config.OLLAMA_BASE_URLS
            ]
            responses = await asyncio.gather(*tasks)
            responses = list(filter(lambda x: x is not None, responses))

            if len(responses) > 0:
                lowest_version = min(
                    responses,
                    # Strip a leading "v" and any pre-release suffix
                    # (e.g. "v0.1.32-rc1" -> "0.1.32") before comparing.
                    key=lambda x: tuple(
                        map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
                    ),
                )

                return {"version": lowest_version["version"]}
            else:
                raise HTTPException(
                    status_code=500,
                    detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
                )
        else:
            url = app.state.config.OLLAMA_BASE_URLS[url_idx]

            r = None
            try:
                r = requests.request(method="GET", url=f"{url}/api/version")
                r.raise_for_status()

                return r.json()
            except Exception as e:
                log.exception(e)
                error_detail = "Open WebUI: Server Connection Error"
                if r is not None:
                    try:
                        res = r.json()
                        if "error" in res:
                            error_detail = f"Ollama: {res['error']}"
                    except Exception:
                        error_detail = f"Ollama: {e}"

                raise HTTPException(
                    status_code=r.status_code if r else 500,
                    detail=error_detail,
                )
    else:
        return {"version": False}


class ModelNameForm(BaseModel):
    name: str


@app.post("/api/pull")
@app.post("/api/pull/{url_idx}")
async def pull_model(
    form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
):
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    # Admins should be able to pull models from any source.
    payload = {**form_data.model_dump(exclude_none=True), "insecure": True}

    return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))


class PushModelForm(BaseModel):
    name: str
    insecure: Optional[bool] = None
    stream: Optional[bool] = None


@app.delete("/api/push")
@app.delete("/api/push/{url_idx}")
async def push_model(
    form_data: PushModelForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        if form_data.name in app.state.MODELS:
            url_idx = app.state.MODELS[form_data.name]["urls"][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
            )

    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.debug(f"url: {url}")

    return await post_streaming_url(
        f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
    )


class CreateModelForm(BaseModel):
    name: str
    modelfile: Optional[str] = None
    stream: Optional[bool] = None
    path: Optional[str] = None


@app.post("/api/create")
@app.post("/api/create/{url_idx}")
async def create_model(
    form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
):
    log.debug(f"form_data: {form_data}")
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    return await post_streaming_url(
        f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
    )


class CopyModelForm(BaseModel):
    source: str
    destination: str


@app.post("/api/copy")
@app.post("/api/copy/{url_idx}")
async def copy_model(
    form_data: CopyModelForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        if form_data.source in app.state.MODELS:
            url_idx = app.state.MODELS[form_data.source]["urls"][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
            )

    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/copy",
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        log.debug(f"r.text: {r.text}")

        return True
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )


@app.delete("/api/delete")
@app.delete("/api/delete/{url_idx}")
async def delete_model(
    form_data: ModelNameForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        if form_data.name in app.state.MODELS:
            url_idx = app.state.MODELS[form_data.name]["urls"][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
            )

    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    r = None
    try:
        r = requests.request(
            method="DELETE",
            url=f"{url}/api/delete",
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        log.debug(f"r.text: {r.text}")

        return True
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )


@app.post("/api/show")
async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
    if form_data.name not in app.state.MODELS:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
        )

    url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/show",
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        return r.json()
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )


class GenerateEmbeddingsForm(BaseModel):
    model: str
    prompt: str
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None


@app.post("/api/embeddings")
@app.post("/api/embeddings/{url_idx}")
async def generate_embeddings(
    form_data: GenerateEmbeddingsForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    if url_idx is None:
        model = form_data.model

        if ":" not in model:
            model = f"{model}:latest"

        if model in app.state.MODELS:
            url_idx = random.choice(app.state.MODELS[model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/embeddings",
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        return r.json()
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )


def generate_ollama_embeddings(
    form_data: GenerateEmbeddingsForm,
    url_idx: Optional[int] = None,
):
    log.info(f"generate_ollama_embeddings {form_data}")

    if url_idx is None:
        model = form_data.model

        if ":" not in model:
            model = f"{model}:latest"

        if model in app.state.MODELS:
            url_idx = random.choice(app.state.MODELS[model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/embeddings",
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        data = r.json()
        log.info(f"generate_ollama_embeddings {data}")

        if "embedding" in data:
            return data["embedding"]
        else:
            raise Exception("Something went wrong :/")
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise Exception(error_detail)


class GenerateCompletionForm(BaseModel):
    model: str
    prompt: str
    images: Optional[List[str]] = None
    format: Optional[str] = None
    options: Optional[dict] = None
    system: Optional[str] = None
    template: Optional[str] = None
    context: Optional[str] = None
    stream: Optional[bool] = True
    raw: Optional[bool] = None
    keep_alive: Optional[Union[int, str]] = None


@app.post("/api/generate")
@app.post("/api/generate/{url_idx}")
async def generate_completion(
    form_data: GenerateCompletionForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    if url_idx is None:
        model = form_data.model

        if ":" not in model:
            model = f"{model}:latest"

        if model in app.state.MODELS:
            url_idx = random.choice(app.state.MODELS[model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    return await post_streaming_url(
        f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
    )


class ChatMessage(BaseModel):
    role: str
    content: str
    images: Optional[List[str]] = None


class GenerateChatCompletionForm(BaseModel):
    model: str
    messages: List[ChatMessage]
    format: Optional[str] = None
    options: Optional[dict] = None
    template: Optional[str] = None
    stream: Optional[bool] = None
    keep_alive: Optional[Union[int, str]] = None


@app.post("/api/chat")
@app.post("/api/chat/{url_idx}")
async def generate_chat_completion(
    form_data: GenerateChatCompletionForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    log.debug(
        f"form_data.model_dump_json(exclude_none=True).encode(): "
        f"{form_data.model_dump_json(exclude_none=True).encode()}"
    )

    payload = {
        **form_data.model_dump(exclude_none=True, exclude={"metadata"}),
    }
    if "metadata" in payload:
        del payload["metadata"]

    model_id = form_data.model
    model_info = Models.get_model_by_id(model_id)

    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        model_info.params = model_info.params.model_dump()

        if model_info.params:
            if payload.get("options") is None:
                payload["options"] = {}

            # Map model params onto Ollama options, renaming where the two
            # schemas differ, without overriding options the caller already set.
            param_to_option = {
                "mirostat": "mirostat",
                "mirostat_eta": "mirostat_eta",
                "mirostat_tau": "mirostat_tau",
                "num_ctx": "num_ctx",
                "num_batch": "num_batch",
                "num_keep": "num_keep",
                "repeat_last_n": "repeat_last_n",
                "frequency_penalty": "repeat_penalty",
                "temperature": "temperature",
                "seed": "seed",
                "tfs_z": "tfs_z",
                "max_tokens": "num_predict",
                "top_k": "top_k",
                "top_p": "top_p",
                "use_mmap": "use_mmap",
                "use_mlock": "use_mlock",
                "num_thread": "num_thread",
            }
            for param, option in param_to_option.items():
                if (
                    model_info.params.get(param, None)
                    and payload["options"].get(option) is None
                ):
                    payload["options"][option] = model_info.params[param]

            # Stop sequences may contain escaped characters (e.g. "\\n"),
            # so decode them before passing them through.
            if (
                model_info.params.get("stop", None)
                and payload["options"].get("stop") is None
            ):
                payload["options"]["stop"] = [
                    bytes(stop, "utf-8").decode("unicode_escape")
                    for stop in model_info.params["stop"]
                ]

        system = model_info.params.get("system", None)
        if system:
            system = prompt_template(
                system,
                **(
                    {
                        "user_name": user.name,
                        "user_location": (
                            user.info.get("location") if user.info else None
                        ),
                    }
                    if user
                    else {}
                ),
            )

            if payload.get("messages"):
                payload["messages"] = add_or_update_system_message(
                    system, payload["messages"]
                )

    if url_idx is None:
        if ":" not in payload["model"]:
            payload["model"] = f"{payload['model']}:latest"

        if payload["model"] in app.state.MODELS:
            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")
    log.debug(payload)

    return await post_streaming_url(f"{url}/api/chat", json.dumps(payload))


# TODO: Update this once Ollama supports other message content types.
class OpenAIChatMessageContent(BaseModel):
    type: str
    model_config = ConfigDict(extra="allow")


class OpenAIChatMessage(BaseModel):
    role: str
    content: Union[str, OpenAIChatMessageContent]

    model_config = ConfigDict(extra="allow")


class OpenAIChatCompletionForm(BaseModel):
    model: str
    messages: List[OpenAIChatMessage]

    model_config = ConfigDict(extra="allow")
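# Example of the accepted OpenAI-style request shape (illustrative values;
# because of extra="allow", additional OpenAI fields such as "stream" or
# "temperature" pass through unvalidated):
#
#     {
#         "model": "llama3:latest",
#         "messages": [{"role": "user", "content": "Hello!"}],
#         "stream": true
#     }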


@app.post("/v1/chat/completions")
@app.post("/v1/chat/completions/{url_idx}")
async def generate_openai_chat_completion(
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    form_data = OpenAIChatCompletionForm(**form_data)
    payload = {**form_data.model_dump(exclude_none=True, exclude={"metadata"})}

    if "metadata" in payload:
        del payload["metadata"]

    model_id = form_data.model
    model_info = Models.get_model_by_id(model_id)

    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        model_info.params = model_info.params.model_dump()

        if model_info.params:
            payload["temperature"] = model_info.params.get("temperature", None)
            payload["top_p"] = model_info.params.get("top_p", None)
            payload["max_tokens"] = model_info.params.get("max_tokens", None)
            payload["frequency_penalty"] = model_info.params.get(
                "frequency_penalty", None
            )
            payload["seed"] = model_info.params.get("seed", None)
            payload["stop"] = (
                [
                    bytes(stop, "utf-8").decode("unicode_escape")
                    for stop in model_info.params["stop"]
                ]
                if model_info.params.get("stop", None)
                else None
            )

        system = model_info.params.get("system", None)

        if system:
            system = prompt_template(
                system,
                **(
                    {
                        "user_name": user.name,
                        "user_location": (
                            user.info.get("location") if user.info else None
                        ),
                    }
                    if user
                    else {}
                ),
            )

            # If the payload already has a system message, prepend to it;
            # otherwise insert a new system message at the front (the for/else
            # runs the else branch only when no break occurred).
            if payload.get("messages"):
                for message in payload["messages"]:
                    if message.get("role") == "system":
                        message["content"] = system + message["content"]
                        break
                else:
                    payload["messages"].insert(
                        0,
                        {
                            "role": "system",
                            "content": system,
                        },
                    )

    if url_idx is None:
        if ":" not in payload["model"]:
            payload["model"] = f"{payload['model']}:latest"

        if payload["model"] in app.state.MODELS:
            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    return await post_streaming_url(
        f"{url}/v1/chat/completions",
        json.dumps(payload),
        stream=payload.get("stream", False),
    )
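# Illustrative usage (assumptions: the app is mounted under an "/ollama"
# prefix and `token` holds a valid API token; adjust both to your deployment):
#
#     requests.post(
#         "http://localhost:8080/ollama/v1/chat/completions",
#         headers={"Authorization": f"Bearer {token}"},
#         json={
#             "model": "llama3:latest",
#             "messages": [{"role": "user", "content": "Hi"}],
#             "stream": False,
#         },
#     )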


@app.get("/v1/models")
@app.get("/v1/models/{url_idx}")
async def get_openai_models(
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    if url_idx is None:
        models = await get_all_models()

        if app.state.config.ENABLE_MODEL_FILTER:
            if user.role == "user":
                models["models"] = list(
                    filter(
                        lambda model: model["name"]
                        in app.state.config.MODEL_FILTER_LIST,
                        models["models"],
                    )
                )

        return {
            "data": [
                {
                    "id": model["model"],
                    "object": "model",
                    "created": int(time.time()),
                    "owned_by": "openai",
                }
                for model in models["models"]
            ],
            "object": "list",
        }
    else:
        url = app.state.config.OLLAMA_BASE_URLS[url_idx]

        r = None
        try:
            r = requests.request(method="GET", url=f"{url}/api/tags")
            r.raise_for_status()

            models = r.json()

            return {
                "data": [
                    {
                        "id": model["model"],
                        "object": "model",
                        "created": int(time.time()),
                        "owned_by": "openai",
                    }
                    for model in models["models"]
                ],
                "object": "list",
            }
        except Exception as e:
            log.exception(e)
            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"Ollama: {res['error']}"
                except Exception:
                    error_detail = f"Ollama: {e}"

            raise HTTPException(
                status_code=r.status_code if r else 500,
                detail=error_detail,
            )


class UrlForm(BaseModel):
    url: str


class UploadBlobForm(BaseModel):
    filename: str


def parse_huggingface_url(hf_url):
    try:
        # Parse the URL and split its path into components
        parsed_url = urlparse(hf_url)
        path_components = parsed_url.path.split("/")

        # The model file name is the last path component
        model_file = path_components[-1]

        return model_file
    except ValueError:
        return None
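# For example:
#     parse_huggingface_url(
#         "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF"
#         "/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
#     )
# returns "stablelm-zephyr-3b.Q2_K.gguf".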


async def download_file_stream(
    ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
):
    done = False

    if os.path.exists(file_path):
        current_size = os.path.getsize(file_path)
    else:
        current_size = 0

    # Resume a partial download with a Range request if the file exists.
    headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}

    timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout

    async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
        async with session.get(file_url, headers=headers) as response:
            total_size = int(response.headers.get("content-length", 0)) + current_size

            with open(file_path, "ab+") as file:
                async for data in response.content.iter_chunked(chunk_size):
                    current_size += len(data)
                    file.write(data)

                    done = current_size == total_size
                    progress = round((current_size / total_size) * 100, 2)

                    yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'

                if done:
                    file.seek(0)
                    hashed = calculate_sha256(file)
                    file.seek(0)

                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
                    response = requests.post(url, data=file)

                    if response.ok:
                        res = {
                            "done": done,
                            "blob": f"sha256:{hashed}",
                            "name": file_name,
                        }
                        os.remove(file_path)

                        yield f"data: {json.dumps(res)}\n\n"
                    else:
                        raise Exception(
                            "Ollama: Could not create blob, please try again."
                        )


# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
@app.post("/models/download")
@app.post("/models/download/{url_idx}")
async def download_model(
    form_data: UrlForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    allowed_hosts = ["https://huggingface.co/", "https://github.com/"]

    if not any(form_data.url.startswith(host) for host in allowed_hosts):
        raise HTTPException(
            status_code=400,
            detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
        )

    if url_idx is None:
        url_idx = 0
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]

    file_name = parse_huggingface_url(form_data.url)

    if file_name:
        file_path = f"{UPLOAD_DIR}/{file_name}"

        return StreamingResponse(
            download_file_stream(url, form_data.url, file_path, file_name),
        )
    else:
        return None


@app.post("/models/upload")
@app.post("/models/upload/{url_idx}")
def upload_model(
    file: UploadFile = File(...),
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        url_idx = 0
    ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]

    file_path = f"{UPLOAD_DIR}/{file.filename}"

    # Save the uploaded file in chunks
    with open(file_path, "wb+") as f:
        for chunk in file.file:
            f.write(chunk)

    def file_process_stream():
        nonlocal ollama_url
        total_size = os.path.getsize(file_path)
        chunk_size = 1024 * 1024
        try:
            with open(file_path, "rb") as f:
                total = 0
                done = False

                while not done:
                    chunk = f.read(chunk_size)
                    if not chunk:
                        done = True
                        continue

                    total += len(chunk)
                    progress = round((total / total_size) * 100, 2)

                    res = {
                        "progress": progress,
                        "total": total_size,
                        "completed": total,
                    }
                    yield f"data: {json.dumps(res)}\n\n"

                if done:
                    f.seek(0)
                    hashed = calculate_sha256(f)
                    f.seek(0)

                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
                    response = requests.post(url, data=f)

                    if response.ok:
                        res = {
                            "done": done,
                            "blob": f"sha256:{hashed}",
                            "name": file.filename,
                        }
                        os.remove(file_path)
                        yield f"data: {json.dumps(res)}\n\n"
                    else:
                        raise Exception(
                            "Ollama: Could not create blob, please try again."
                        )

        except Exception as e:
            res = {"error": str(e)}
            yield f"data: {json.dumps(res)}\n\n"

    return StreamingResponse(file_process_stream(), media_type="text/event-stream")