main.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500
  1. import inspect
  2. import json
  3. import logging
  4. import time
  5. from typing import AsyncGenerator, Generator, Iterator
  6. from open_webui.apps.socket.main import get_event_call, get_event_emitter
  7. from open_webui.apps.webui.models.functions import Functions
  8. from open_webui.apps.webui.models.models import Models
  9. from open_webui.apps.webui.routers import (
  10. auths,
  11. chats,
  12. folders,
  13. configs,
  14. groups,
  15. files,
  16. functions,
  17. memories,
  18. models,
  19. knowledge,
  20. prompts,
  21. evaluations,
  22. tools,
  23. users,
  24. utils,
  25. )
  26. from open_webui.apps.webui.utils import load_function_module_by_id
  27. from open_webui.config import (
  28. ADMIN_EMAIL,
  29. CORS_ALLOW_ORIGIN,
  30. DEFAULT_MODELS,
  31. DEFAULT_PROMPT_SUGGESTIONS,
  32. DEFAULT_USER_ROLE,
  33. ENABLE_COMMUNITY_SHARING,
  34. ENABLE_LOGIN_FORM,
  35. ENABLE_MESSAGE_RATING,
  36. ENABLE_SIGNUP,
  37. ENABLE_API_KEY,
  38. ENABLE_EVALUATION_ARENA_MODELS,
  39. EVALUATION_ARENA_MODELS,
  40. DEFAULT_ARENA_MODEL,
  41. JWT_EXPIRES_IN,
  42. ENABLE_OAUTH_ROLE_MANAGEMENT,
  43. OAUTH_ROLES_CLAIM,
  44. OAUTH_EMAIL_CLAIM,
  45. OAUTH_PICTURE_CLAIM,
  46. OAUTH_USERNAME_CLAIM,
  47. OAUTH_ALLOWED_ROLES,
  48. OAUTH_ADMIN_ROLES,
  49. SHOW_ADMIN_DETAILS,
  50. USER_PERMISSIONS,
  51. WEBHOOK_URL,
  52. WEBUI_AUTH,
  53. WEBUI_BANNERS,
  54. ENABLE_LDAP,
  55. LDAP_SERVER_LABEL,
  56. LDAP_SERVER_HOST,
  57. LDAP_SERVER_PORT,
  58. LDAP_ATTRIBUTE_FOR_USERNAME,
  59. LDAP_SEARCH_FILTERS,
  60. LDAP_SEARCH_BASE,
  61. LDAP_APP_DN,
  62. LDAP_APP_PASSWORD,
  63. LDAP_USE_TLS,
  64. LDAP_CA_CERT_FILE,
  65. LDAP_CIPHERS,
  66. AppConfig,
  67. )
  68. from open_webui.env import (
  69. ENV,
  70. SRC_LOG_LEVELS,
  71. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  72. WEBUI_AUTH_TRUSTED_NAME_HEADER,
  73. )
  74. from fastapi import FastAPI
  75. from fastapi.middleware.cors import CORSMiddleware
  76. from fastapi.responses import StreamingResponse
  77. from pydantic import BaseModel
  78. from open_webui.utils.misc import (
  79. openai_chat_chunk_message_template,
  80. openai_chat_completion_message_template,
  81. )
  82. from open_webui.utils.payload import (
  83. apply_model_params_to_body_openai,
  84. apply_model_system_prompt_to_body,
  85. )
  86. from open_webui.utils.tools import get_tools
  87. app = FastAPI(
  88. docs_url="/docs" if ENV == "dev" else None,
  89. openapi_url="/openapi.json" if ENV == "dev" else None,
  90. redoc_url=None,
  91. )
  92. log = logging.getLogger(__name__)
  93. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  94. app.state.config = AppConfig()
  95. app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
  96. app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
  97. app.state.config.ENABLE_API_KEY = ENABLE_API_KEY
  98. app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
  99. app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
  100. app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
  101. app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
  102. app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
  103. app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
  104. app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
  105. app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
  106. app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
  107. app.state.config.WEBHOOK_URL = WEBHOOK_URL
  108. app.state.config.BANNERS = WEBUI_BANNERS
  109. app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
  110. app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING
  111. app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS
  112. app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS
  113. app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
  114. app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
  115. app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
  116. app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
  117. app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
  118. app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
  119. app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
  120. app.state.config.ENABLE_LDAP = ENABLE_LDAP
  121. app.state.config.LDAP_SERVER_LABEL = LDAP_SERVER_LABEL
  122. app.state.config.LDAP_SERVER_HOST = LDAP_SERVER_HOST
  123. app.state.config.LDAP_SERVER_PORT = LDAP_SERVER_PORT
  124. app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = LDAP_ATTRIBUTE_FOR_USERNAME
  125. app.state.config.LDAP_APP_DN = LDAP_APP_DN
  126. app.state.config.LDAP_APP_PASSWORD = LDAP_APP_PASSWORD
  127. app.state.config.LDAP_SEARCH_BASE = LDAP_SEARCH_BASE
  128. app.state.config.LDAP_SEARCH_FILTERS = LDAP_SEARCH_FILTERS
  129. app.state.config.LDAP_USE_TLS = LDAP_USE_TLS
  130. app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE
  131. app.state.config.LDAP_CIPHERS = LDAP_CIPHERS
  132. app.state.TOOLS = {}
  133. app.state.FUNCTIONS = {}
  134. app.add_middleware(
  135. CORSMiddleware,
  136. allow_origins=CORS_ALLOW_ORIGIN,
  137. allow_credentials=True,
  138. allow_methods=["*"],
  139. allow_headers=["*"],
  140. )
  141. app.include_router(configs.router, prefix="/configs", tags=["configs"])
  142. app.include_router(auths.router, prefix="/auths", tags=["auths"])
  143. app.include_router(users.router, prefix="/users", tags=["users"])
  144. app.include_router(chats.router, prefix="/chats", tags=["chats"])
  145. app.include_router(models.router, prefix="/models", tags=["models"])
  146. app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
  147. app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
  148. app.include_router(tools.router, prefix="/tools", tags=["tools"])
  149. app.include_router(memories.router, prefix="/memories", tags=["memories"])
  150. app.include_router(folders.router, prefix="/folders", tags=["folders"])
  151. app.include_router(groups.router, prefix="/groups", tags=["groups"])
  152. app.include_router(files.router, prefix="/files", tags=["files"])
  153. app.include_router(functions.router, prefix="/functions", tags=["functions"])
  154. app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"])
  155. app.include_router(utils.router, prefix="/utils", tags=["utils"])
  156. @app.get("/")
  157. async def get_status():
  158. return {
  159. "status": True,
  160. "auth": WEBUI_AUTH,
  161. "default_models": app.state.config.DEFAULT_MODELS,
  162. "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
  163. }
  164. async def get_all_models():
  165. models = []
  166. pipe_models = await get_pipe_models()
  167. models = models + pipe_models
  168. if app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
  169. arena_models = []
  170. if len(app.state.config.EVALUATION_ARENA_MODELS) > 0:
  171. arena_models = [
  172. {
  173. "id": model["id"],
  174. "name": model["name"],
  175. "info": {
  176. "meta": model["meta"],
  177. },
  178. "object": "model",
  179. "created": int(time.time()),
  180. "owned_by": "arena",
  181. "arena": True,
  182. }
  183. for model in app.state.config.EVALUATION_ARENA_MODELS
  184. ]
  185. else:
  186. # Add default arena model
  187. arena_models = [
  188. {
  189. "id": DEFAULT_ARENA_MODEL["id"],
  190. "name": DEFAULT_ARENA_MODEL["name"],
  191. "info": {
  192. "meta": DEFAULT_ARENA_MODEL["meta"],
  193. },
  194. "object": "model",
  195. "created": int(time.time()),
  196. "owned_by": "arena",
  197. "arena": True,
  198. }
  199. ]
  200. models = models + arena_models
  201. return models
  202. def get_function_module(pipe_id: str):
  203. # Check if function is already loaded
  204. if pipe_id not in app.state.FUNCTIONS:
  205. function_module, _, _ = load_function_module_by_id(pipe_id)
  206. app.state.FUNCTIONS[pipe_id] = function_module
  207. else:
  208. function_module = app.state.FUNCTIONS[pipe_id]
  209. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  210. valves = Functions.get_function_valves_by_id(pipe_id)
  211. function_module.valves = function_module.Valves(**(valves if valves else {}))
  212. return function_module
  213. async def get_pipe_models():
  214. pipes = Functions.get_functions_by_type("pipe", active_only=True)
  215. pipe_models = []
  216. for pipe in pipes:
  217. function_module = get_function_module(pipe.id)
  218. # Check if function is a manifold
  219. if hasattr(function_module, "pipes"):
  220. sub_pipes = []
  221. # Check if pipes is a function or a list
  222. try:
  223. if callable(function_module.pipes):
  224. sub_pipes = function_module.pipes()
  225. else:
  226. sub_pipes = function_module.pipes
  227. except Exception as e:
  228. log.exception(e)
  229. sub_pipes = []
  230. log.debug(f"get_pipe_models: function '{pipe.id}' is a manifold of {sub_pipes}")
  231. for p in sub_pipes:
  232. sub_pipe_id = f'{pipe.id}.{p["id"]}'
  233. sub_pipe_name = p["name"]
  234. if hasattr(function_module, "name"):
  235. sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
  236. pipe_flag = {"type": pipe.type}
  237. pipe_models.append(
  238. {
  239. "id": sub_pipe_id,
  240. "name": sub_pipe_name,
  241. "object": "model",
  242. "created": pipe.created_at,
  243. "owned_by": "openai",
  244. "pipe": pipe_flag,
  245. }
  246. )
  247. else:
  248. pipe_flag = {"type": "pipe"}
  249. log.debug(f"get_pipe_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}")
  250. pipe_models.append(
  251. {
  252. "id": pipe.id,
  253. "name": pipe.name,
  254. "object": "model",
  255. "created": pipe.created_at,
  256. "owned_by": "openai",
  257. "pipe": pipe_flag,
  258. }
  259. )
  260. return pipe_models
  261. async def execute_pipe(pipe, params):
  262. if inspect.iscoroutinefunction(pipe):
  263. return await pipe(**params)
  264. else:
  265. return pipe(**params)
  266. async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
  267. if isinstance(res, str):
  268. return res
  269. if isinstance(res, Generator):
  270. return "".join(map(str, res))
  271. if isinstance(res, AsyncGenerator):
  272. return "".join([str(stream) async for stream in res])
  273. def process_line(form_data: dict, line):
  274. if isinstance(line, BaseModel):
  275. line = line.model_dump_json()
  276. line = f"data: {line}"
  277. if isinstance(line, dict):
  278. line = f"data: {json.dumps(line)}"
  279. try:
  280. line = line.decode("utf-8")
  281. except Exception:
  282. pass
  283. if line.startswith("data:"):
  284. return f"{line}\n\n"
  285. else:
  286. line = openai_chat_chunk_message_template(form_data["model"], line)
  287. return f"data: {json.dumps(line)}\n\n"
  288. def get_pipe_id(form_data: dict) -> str:
  289. pipe_id = form_data["model"]
  290. if "." in pipe_id:
  291. pipe_id, _ = pipe_id.split(".", 1)
  292. return pipe_id
  293. def get_function_params(function_module, form_data, user, extra_params=None):
  294. if extra_params is None:
  295. extra_params = {}
  296. pipe_id = get_pipe_id(form_data)
  297. # Get the signature of the function
  298. sig = inspect.signature(function_module.pipe)
  299. params = {"body": form_data} | {
  300. k: v for k, v in extra_params.items() if k in sig.parameters
  301. }
  302. if "__user__" in params and hasattr(function_module, "UserValves"):
  303. user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
  304. try:
  305. params["__user__"]["valves"] = function_module.UserValves(**user_valves)
  306. except Exception as e:
  307. log.exception(e)
  308. params["__user__"]["valves"] = function_module.UserValves()
  309. return params
  310. async def generate_function_chat_completion(form_data, user, models: dict = {}):
  311. model_id = form_data.get("model")
  312. model_info = Models.get_model_by_id(model_id)
  313. metadata = form_data.pop("metadata", {})
  314. files = metadata.get("files", [])
  315. tool_ids = metadata.get("tool_ids", [])
  316. # Check if tool_ids is None
  317. if tool_ids is None:
  318. tool_ids = []
  319. __event_emitter__ = None
  320. __event_call__ = None
  321. __task__ = None
  322. __task_body__ = None
  323. if metadata:
  324. if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
  325. __event_emitter__ = get_event_emitter(metadata)
  326. __event_call__ = get_event_call(metadata)
  327. __task__ = metadata.get("task", None)
  328. __task_body__ = metadata.get("task_body", None)
  329. extra_params = {
  330. "__event_emitter__": __event_emitter__,
  331. "__event_call__": __event_call__,
  332. "__task__": __task__,
  333. "__task_body__": __task_body__,
  334. "__files__": files,
  335. "__user__": {
  336. "id": user.id,
  337. "email": user.email,
  338. "name": user.name,
  339. "role": user.role,
  340. },
  341. "__metadata__": metadata,
  342. }
  343. extra_params["__tools__"] = get_tools(
  344. app,
  345. tool_ids,
  346. user,
  347. {
  348. **extra_params,
  349. "__model__": models.get(form_data["model"], None),
  350. "__messages__": form_data["messages"],
  351. "__files__": files,
  352. },
  353. )
  354. if model_info:
  355. if model_info.base_model_id:
  356. form_data["model"] = model_info.base_model_id
  357. params = model_info.params.model_dump()
  358. form_data = apply_model_params_to_body_openai(params, form_data)
  359. form_data = apply_model_system_prompt_to_body(params, form_data, user)
  360. pipe_id = get_pipe_id(form_data)
  361. function_module = get_function_module(pipe_id)
  362. pipe = function_module.pipe
  363. params = get_function_params(function_module, form_data, user, extra_params)
  364. if form_data.get("stream", False):
  365. async def stream_content():
  366. try:
  367. res = await execute_pipe(pipe, params)
  368. # Directly return if the response is a StreamingResponse
  369. if isinstance(res, StreamingResponse):
  370. async for data in res.body_iterator:
  371. yield data
  372. return
  373. if isinstance(res, dict):
  374. yield f"data: {json.dumps(res)}\n\n"
  375. return
  376. except Exception as e:
  377. log.error(f"Error: {e}")
  378. yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
  379. return
  380. if isinstance(res, str):
  381. message = openai_chat_chunk_message_template(form_data["model"], res)
  382. yield f"data: {json.dumps(message)}\n\n"
  383. if isinstance(res, Iterator):
  384. for line in res:
  385. yield process_line(form_data, line)
  386. if isinstance(res, AsyncGenerator):
  387. async for line in res:
  388. yield process_line(form_data, line)
  389. if isinstance(res, str) or isinstance(res, Generator):
  390. finish_message = openai_chat_chunk_message_template(
  391. form_data["model"], ""
  392. )
  393. finish_message["choices"][0]["finish_reason"] = "stop"
  394. yield f"data: {json.dumps(finish_message)}\n\n"
  395. yield "data: [DONE]"
  396. return StreamingResponse(stream_content(), media_type="text/event-stream")
  397. else:
  398. try:
  399. res = await execute_pipe(pipe, params)
  400. except Exception as e:
  401. log.error(f"Error: {e}")
  402. return {"error": {"detail": str(e)}}
  403. if isinstance(res, StreamingResponse) or isinstance(res, dict):
  404. return res
  405. if isinstance(res, BaseModel):
  406. return res.model_dump()
  407. message = await get_message_content(res)
  408. return openai_chat_completion_message_template(form_data["model"], message)