# webui.py
  1. import inspect
  2. import json
  3. import logging
  4. import time
  5. from typing import AsyncGenerator, Generator, Iterator
  6. from open_webui.apps.socket.main import get_event_call, get_event_emitter
  7. from open_webui.models.functions import Functions
  8. from open_webui.models.models import Models
  9. from open_webui.routers import (
  10. auths,
  11. chats,
  12. folders,
  13. configs,
  14. groups,
  15. files,
  16. functions,
  17. memories,
  18. models,
  19. knowledge,
  20. prompts,
  21. evaluations,
  22. tools,
  23. users,
  24. utils,
  25. )
  26. from backend.open_webui.utils.plugin import load_function_module_by_id
  27. from open_webui.config import (
  28. ADMIN_EMAIL,
  29. CORS_ALLOW_ORIGIN,
  30. DEFAULT_MODELS,
  31. DEFAULT_PROMPT_SUGGESTIONS,
  32. DEFAULT_USER_ROLE,
  33. MODEL_ORDER_LIST,
  34. ENABLE_COMMUNITY_SHARING,
  35. ENABLE_LOGIN_FORM,
  36. ENABLE_MESSAGE_RATING,
  37. ENABLE_SIGNUP,
  38. ENABLE_API_KEY,
  39. ENABLE_EVALUATION_ARENA_MODELS,
  40. EVALUATION_ARENA_MODELS,
  41. DEFAULT_ARENA_MODEL,
  42. JWT_EXPIRES_IN,
  43. ENABLE_OAUTH_ROLE_MANAGEMENT,
  44. OAUTH_ROLES_CLAIM,
  45. OAUTH_EMAIL_CLAIM,
  46. OAUTH_PICTURE_CLAIM,
  47. OAUTH_USERNAME_CLAIM,
  48. OAUTH_ALLOWED_ROLES,
  49. OAUTH_ADMIN_ROLES,
  50. SHOW_ADMIN_DETAILS,
  51. USER_PERMISSIONS,
  52. WEBHOOK_URL,
  53. WEBUI_AUTH,
  54. WEBUI_BANNERS,
  55. ENABLE_LDAP,
  56. LDAP_SERVER_LABEL,
  57. LDAP_SERVER_HOST,
  58. LDAP_SERVER_PORT,
  59. LDAP_ATTRIBUTE_FOR_USERNAME,
  60. LDAP_SEARCH_FILTERS,
  61. LDAP_SEARCH_BASE,
  62. LDAP_APP_DN,
  63. LDAP_APP_PASSWORD,
  64. LDAP_USE_TLS,
  65. LDAP_CA_CERT_FILE,
  66. LDAP_CIPHERS,
  67. AppConfig,
  68. )
  69. from open_webui.env import (
  70. ENV,
  71. SRC_LOG_LEVELS,
  72. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  73. WEBUI_AUTH_TRUSTED_NAME_HEADER,
  74. )
  75. from fastapi import FastAPI
  76. from fastapi.middleware.cors import CORSMiddleware
  77. from fastapi.responses import StreamingResponse
  78. from pydantic import BaseModel
  79. from open_webui.utils.misc import (
  80. openai_chat_chunk_message_template,
  81. openai_chat_completion_message_template,
  82. )
  83. from open_webui.utils.payload import (
  84. apply_model_params_to_body_openai,
  85. apply_model_system_prompt_to_body,
  86. )
  87. from open_webui.utils.tools import get_tools
# Module-level logger, keyed to this module's dotted path and tuned to the
# project's configured verbosity for the MAIN subsystem.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
  90. @app.get("/")
  91. async def get_status():
  92. return {
  93. "status": True,
  94. "auth": WEBUI_AUTH,
  95. "default_models": app.state.config.DEFAULT_MODELS,
  96. "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
  97. }
  98. async def get_all_models():
  99. models = []
  100. pipe_models = await get_pipe_models()
  101. models = models + pipe_models
  102. if app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
  103. arena_models = []
  104. if len(app.state.config.EVALUATION_ARENA_MODELS) > 0:
  105. arena_models = [
  106. {
  107. "id": model["id"],
  108. "name": model["name"],
  109. "info": {
  110. "meta": model["meta"],
  111. },
  112. "object": "model",
  113. "created": int(time.time()),
  114. "owned_by": "arena",
  115. "arena": True,
  116. }
  117. for model in app.state.config.EVALUATION_ARENA_MODELS
  118. ]
  119. else:
  120. # Add default arena model
  121. arena_models = [
  122. {
  123. "id": DEFAULT_ARENA_MODEL["id"],
  124. "name": DEFAULT_ARENA_MODEL["name"],
  125. "info": {
  126. "meta": DEFAULT_ARENA_MODEL["meta"],
  127. },
  128. "object": "model",
  129. "created": int(time.time()),
  130. "owned_by": "arena",
  131. "arena": True,
  132. }
  133. ]
  134. models = models + arena_models
  135. return models
  136. def get_function_module(pipe_id: str):
  137. # Check if function is already loaded
  138. if pipe_id not in app.state.FUNCTIONS:
  139. function_module, _, _ = load_function_module_by_id(pipe_id)
  140. app.state.FUNCTIONS[pipe_id] = function_module
  141. else:
  142. function_module = app.state.FUNCTIONS[pipe_id]
  143. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  144. valves = Functions.get_function_valves_by_id(pipe_id)
  145. function_module.valves = function_module.Valves(**(valves if valves else {}))
  146. return function_module
  147. async def get_pipe_models():
  148. pipes = Functions.get_functions_by_type("pipe", active_only=True)
  149. pipe_models = []
  150. for pipe in pipes:
  151. function_module = get_function_module(pipe.id)
  152. # Check if function is a manifold
  153. if hasattr(function_module, "pipes"):
  154. sub_pipes = []
  155. # Check if pipes is a function or a list
  156. try:
  157. if callable(function_module.pipes):
  158. sub_pipes = function_module.pipes()
  159. else:
  160. sub_pipes = function_module.pipes
  161. except Exception as e:
  162. log.exception(e)
  163. sub_pipes = []
  164. log.debug(
  165. f"get_pipe_models: function '{pipe.id}' is a manifold of {sub_pipes}"
  166. )
  167. for p in sub_pipes:
  168. sub_pipe_id = f'{pipe.id}.{p["id"]}'
  169. sub_pipe_name = p["name"]
  170. if hasattr(function_module, "name"):
  171. sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
  172. pipe_flag = {"type": pipe.type}
  173. pipe_models.append(
  174. {
  175. "id": sub_pipe_id,
  176. "name": sub_pipe_name,
  177. "object": "model",
  178. "created": pipe.created_at,
  179. "owned_by": "openai",
  180. "pipe": pipe_flag,
  181. }
  182. )
  183. else:
  184. pipe_flag = {"type": "pipe"}
  185. log.debug(
  186. f"get_pipe_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
  187. )
  188. pipe_models.append(
  189. {
  190. "id": pipe.id,
  191. "name": pipe.name,
  192. "object": "model",
  193. "created": pipe.created_at,
  194. "owned_by": "openai",
  195. "pipe": pipe_flag,
  196. }
  197. )
  198. return pipe_models
  199. async def execute_pipe(pipe, params):
  200. if inspect.iscoroutinefunction(pipe):
  201. return await pipe(**params)
  202. else:
  203. return pipe(**params)
  204. async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
  205. if isinstance(res, str):
  206. return res
  207. if isinstance(res, Generator):
  208. return "".join(map(str, res))
  209. if isinstance(res, AsyncGenerator):
  210. return "".join([str(stream) async for stream in res])
  211. def process_line(form_data: dict, line):
  212. if isinstance(line, BaseModel):
  213. line = line.model_dump_json()
  214. line = f"data: {line}"
  215. if isinstance(line, dict):
  216. line = f"data: {json.dumps(line)}"
  217. try:
  218. line = line.decode("utf-8")
  219. except Exception:
  220. pass
  221. if line.startswith("data:"):
  222. return f"{line}\n\n"
  223. else:
  224. line = openai_chat_chunk_message_template(form_data["model"], line)
  225. return f"data: {json.dumps(line)}\n\n"
  226. def get_pipe_id(form_data: dict) -> str:
  227. pipe_id = form_data["model"]
  228. if "." in pipe_id:
  229. pipe_id, _ = pipe_id.split(".", 1)
  230. return pipe_id
  231. def get_function_params(function_module, form_data, user, extra_params=None):
  232. if extra_params is None:
  233. extra_params = {}
  234. pipe_id = get_pipe_id(form_data)
  235. # Get the signature of the function
  236. sig = inspect.signature(function_module.pipe)
  237. params = {"body": form_data} | {
  238. k: v for k, v in extra_params.items() if k in sig.parameters
  239. }
  240. if "__user__" in params and hasattr(function_module, "UserValves"):
  241. user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
  242. try:
  243. params["__user__"]["valves"] = function_module.UserValves(**user_valves)
  244. except Exception as e:
  245. log.exception(e)
  246. params["__user__"]["valves"] = function_module.UserValves()
  247. return params
  248. async def generate_function_chat_completion(form_data, user, models: dict = {}):
  249. model_id = form_data.get("model")
  250. model_info = Models.get_model_by_id(model_id)
  251. metadata = form_data.pop("metadata", {})
  252. files = metadata.get("files", [])
  253. tool_ids = metadata.get("tool_ids", [])
  254. # Check if tool_ids is None
  255. if tool_ids is None:
  256. tool_ids = []
  257. __event_emitter__ = None
  258. __event_call__ = None
  259. __task__ = None
  260. __task_body__ = None
  261. if metadata:
  262. if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
  263. __event_emitter__ = get_event_emitter(metadata)
  264. __event_call__ = get_event_call(metadata)
  265. __task__ = metadata.get("task", None)
  266. __task_body__ = metadata.get("task_body", None)
  267. extra_params = {
  268. "__event_emitter__": __event_emitter__,
  269. "__event_call__": __event_call__,
  270. "__task__": __task__,
  271. "__task_body__": __task_body__,
  272. "__files__": files,
  273. "__user__": {
  274. "id": user.id,
  275. "email": user.email,
  276. "name": user.name,
  277. "role": user.role,
  278. },
  279. "__metadata__": metadata,
  280. }
  281. extra_params["__tools__"] = get_tools(
  282. app,
  283. tool_ids,
  284. user,
  285. {
  286. **extra_params,
  287. "__model__": models.get(form_data["model"], None),
  288. "__messages__": form_data["messages"],
  289. "__files__": files,
  290. },
  291. )
  292. if model_info:
  293. if model_info.base_model_id:
  294. form_data["model"] = model_info.base_model_id
  295. params = model_info.params.model_dump()
  296. form_data = apply_model_params_to_body_openai(params, form_data)
  297. form_data = apply_model_system_prompt_to_body(params, form_data, user)
  298. pipe_id = get_pipe_id(form_data)
  299. function_module = get_function_module(pipe_id)
  300. pipe = function_module.pipe
  301. params = get_function_params(function_module, form_data, user, extra_params)
  302. if form_data.get("stream", False):
  303. async def stream_content():
  304. try:
  305. res = await execute_pipe(pipe, params)
  306. # Directly return if the response is a StreamingResponse
  307. if isinstance(res, StreamingResponse):
  308. async for data in res.body_iterator:
  309. yield data
  310. return
  311. if isinstance(res, dict):
  312. yield f"data: {json.dumps(res)}\n\n"
  313. return
  314. except Exception as e:
  315. log.error(f"Error: {e}")
  316. yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
  317. return
  318. if isinstance(res, str):
  319. message = openai_chat_chunk_message_template(form_data["model"], res)
  320. yield f"data: {json.dumps(message)}\n\n"
  321. if isinstance(res, Iterator):
  322. for line in res:
  323. yield process_line(form_data, line)
  324. if isinstance(res, AsyncGenerator):
  325. async for line in res:
  326. yield process_line(form_data, line)
  327. if isinstance(res, str) or isinstance(res, Generator):
  328. finish_message = openai_chat_chunk_message_template(
  329. form_data["model"], ""
  330. )
  331. finish_message["choices"][0]["finish_reason"] = "stop"
  332. yield f"data: {json.dumps(finish_message)}\n\n"
  333. yield "data: [DONE]"
  334. return StreamingResponse(stream_content(), media_type="text/event-stream")
  335. else:
  336. try:
  337. res = await execute_pipe(pipe, params)
  338. except Exception as e:
  339. log.error(f"Error: {e}")
  340. return {"error": {"detail": str(e)}}
  341. if isinstance(res, StreamingResponse) or isinstance(res, dict):
  342. return res
  343. if isinstance(res, BaseModel):
  344. return res.model_dump()
  345. message = await get_message_content(res)
  346. return openai_chat_completion_message_template(form_data["model"], message)