main.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. import inspect
  2. import json
  3. import logging
  4. from typing import AsyncGenerator, Generator, Iterator
  5. from open_webui.apps.socket.main import get_event_call, get_event_emitter
  6. from open_webui.apps.webui.models.functions import Functions
  7. from open_webui.apps.webui.models.models import Models
  8. from open_webui.apps.webui.routers import (
  9. auths,
  10. chats,
  11. folders,
  12. configs,
  13. files,
  14. functions,
  15. memories,
  16. models,
  17. knowledge,
  18. prompts,
  19. tools,
  20. users,
  21. utils,
  22. )
  23. from open_webui.apps.webui.utils import load_function_module_by_id
  24. from open_webui.config import (
  25. ADMIN_EMAIL,
  26. CORS_ALLOW_ORIGIN,
  27. DEFAULT_MODELS,
  28. DEFAULT_PROMPT_SUGGESTIONS,
  29. DEFAULT_USER_ROLE,
  30. ENABLE_COMMUNITY_SHARING,
  31. ENABLE_LOGIN_FORM,
  32. ENABLE_MESSAGE_RATING,
  33. ENABLE_SIGNUP,
  34. JWT_EXPIRES_IN,
  35. ENABLE_OAUTH_ROLE_MANAGEMENT,
  36. OAUTH_ROLES_CLAIM,
  37. OAUTH_EMAIL_CLAIM,
  38. OAUTH_PICTURE_CLAIM,
  39. OAUTH_USERNAME_CLAIM,
  40. OAUTH_ALLOWED_ROLES,
  41. OAUTH_ADMIN_ROLES,
  42. SHOW_ADMIN_DETAILS,
  43. USER_PERMISSIONS,
  44. WEBHOOK_URL,
  45. WEBUI_AUTH,
  46. WEBUI_BANNERS,
  47. AppConfig,
  48. )
  49. from open_webui.env import (
  50. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  51. WEBUI_AUTH_TRUSTED_NAME_HEADER,
  52. )
  53. from fastapi import FastAPI
  54. from fastapi.middleware.cors import CORSMiddleware
  55. from fastapi.responses import StreamingResponse
  56. from pydantic import BaseModel
  57. from open_webui.utils.misc import (
  58. openai_chat_chunk_message_template,
  59. openai_chat_completion_message_template,
  60. )
  61. from open_webui.utils.payload import (
  62. apply_model_params_to_body_openai,
  63. apply_model_system_prompt_to_body,
  64. )
  65. from open_webui.utils.tools import get_tools
  66. app = FastAPI()
  67. log = logging.getLogger(__name__)
  68. app.state.config = AppConfig()
  69. app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
  70. app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
  71. app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
  72. app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
  73. app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
  74. app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
  75. app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
  76. app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
  77. app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
  78. app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
  79. app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
  80. app.state.config.WEBHOOK_URL = WEBHOOK_URL
  81. app.state.config.BANNERS = WEBUI_BANNERS
  82. app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
  83. app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING
  84. app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
  85. app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
  86. app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
  87. app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
  88. app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
  89. app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
  90. app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
  91. app.state.MODELS = {}
  92. app.state.TOOLS = {}
  93. app.state.FUNCTIONS = {}
  94. app.add_middleware(
  95. CORSMiddleware,
  96. allow_origins=CORS_ALLOW_ORIGIN,
  97. allow_credentials=True,
  98. allow_methods=["*"],
  99. allow_headers=["*"],
  100. )
  101. app.include_router(configs.router, prefix="/configs", tags=["configs"])
  102. app.include_router(auths.router, prefix="/auths", tags=["auths"])
  103. app.include_router(users.router, prefix="/users", tags=["users"])
  104. app.include_router(chats.router, prefix="/chats", tags=["chats"])
  105. app.include_router(folders.router, prefix="/folders", tags=["folders"])
  106. app.include_router(models.router, prefix="/models", tags=["models"])
  107. app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
  108. app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
  109. app.include_router(files.router, prefix="/files", tags=["files"])
  110. app.include_router(tools.router, prefix="/tools", tags=["tools"])
  111. app.include_router(functions.router, prefix="/functions", tags=["functions"])
  112. app.include_router(memories.router, prefix="/memories", tags=["memories"])
  113. app.include_router(utils.router, prefix="/utils", tags=["utils"])
  114. @app.get("/")
  115. async def get_status():
  116. return {
  117. "status": True,
  118. "auth": WEBUI_AUTH,
  119. "default_models": app.state.config.DEFAULT_MODELS,
  120. "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
  121. }
  122. def get_function_module(pipe_id: str):
  123. # Check if function is already loaded
  124. if pipe_id not in app.state.FUNCTIONS:
  125. function_module, _, _ = load_function_module_by_id(pipe_id)
  126. app.state.FUNCTIONS[pipe_id] = function_module
  127. else:
  128. function_module = app.state.FUNCTIONS[pipe_id]
  129. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  130. valves = Functions.get_function_valves_by_id(pipe_id)
  131. function_module.valves = function_module.Valves(**(valves if valves else {}))
  132. return function_module
  133. async def get_pipe_models():
  134. pipes = Functions.get_functions_by_type("pipe", active_only=True)
  135. pipe_models = []
  136. for pipe in pipes:
  137. function_module = get_function_module(pipe.id)
  138. # Check if function is a manifold
  139. if hasattr(function_module, "pipes"):
  140. sub_pipes = []
  141. # Check if pipes is a function or a list
  142. try:
  143. if callable(function_module.pipes):
  144. sub_pipes = function_module.pipes()
  145. else:
  146. sub_pipes = function_module.pipes
  147. except Exception as e:
  148. log.exception(e)
  149. sub_pipes = []
  150. print(sub_pipes)
  151. for p in sub_pipes:
  152. sub_pipe_id = f'{pipe.id}.{p["id"]}'
  153. sub_pipe_name = p["name"]
  154. if hasattr(function_module, "name"):
  155. sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
  156. pipe_flag = {"type": pipe.type}
  157. pipe_models.append(
  158. {
  159. "id": sub_pipe_id,
  160. "name": sub_pipe_name,
  161. "object": "model",
  162. "created": pipe.created_at,
  163. "owned_by": "openai",
  164. "pipe": pipe_flag,
  165. }
  166. )
  167. else:
  168. pipe_flag = {"type": "pipe"}
  169. pipe_models.append(
  170. {
  171. "id": pipe.id,
  172. "name": pipe.name,
  173. "object": "model",
  174. "created": pipe.created_at,
  175. "owned_by": "openai",
  176. "pipe": pipe_flag,
  177. }
  178. )
  179. return pipe_models
  180. async def execute_pipe(pipe, params):
  181. if inspect.iscoroutinefunction(pipe):
  182. return await pipe(**params)
  183. else:
  184. return pipe(**params)
  185. async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
  186. if isinstance(res, str):
  187. return res
  188. if isinstance(res, Generator):
  189. return "".join(map(str, res))
  190. if isinstance(res, AsyncGenerator):
  191. return "".join([str(stream) async for stream in res])
  192. def process_line(form_data: dict, line):
  193. if isinstance(line, BaseModel):
  194. line = line.model_dump_json()
  195. line = f"data: {line}"
  196. if isinstance(line, dict):
  197. line = f"data: {json.dumps(line)}"
  198. try:
  199. line = line.decode("utf-8")
  200. except Exception:
  201. pass
  202. if line.startswith("data:"):
  203. return f"{line}\n\n"
  204. else:
  205. line = openai_chat_chunk_message_template(form_data["model"], line)
  206. return f"data: {json.dumps(line)}\n\n"
  207. def get_pipe_id(form_data: dict) -> str:
  208. pipe_id = form_data["model"]
  209. if "." in pipe_id:
  210. pipe_id, _ = pipe_id.split(".", 1)
  211. print(pipe_id)
  212. return pipe_id
  213. def get_function_params(function_module, form_data, user, extra_params=None):
  214. if extra_params is None:
  215. extra_params = {}
  216. pipe_id = get_pipe_id(form_data)
  217. # Get the signature of the function
  218. sig = inspect.signature(function_module.pipe)
  219. params = {"body": form_data} | {
  220. k: v for k, v in extra_params.items() if k in sig.parameters
  221. }
  222. if "__user__" in params and hasattr(function_module, "UserValves"):
  223. user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
  224. try:
  225. params["__user__"]["valves"] = function_module.UserValves(**user_valves)
  226. except Exception as e:
  227. log.exception(e)
  228. params["__user__"]["valves"] = function_module.UserValves()
  229. return params
  230. async def generate_function_chat_completion(form_data, user):
  231. model_id = form_data.get("model")
  232. model_info = Models.get_model_by_id(model_id)
  233. metadata = form_data.pop("metadata", {})
  234. files = metadata.get("files", [])
  235. tool_ids = metadata.get("tool_ids", [])
  236. # Check if tool_ids is None
  237. if tool_ids is None:
  238. tool_ids = []
  239. __event_emitter__ = None
  240. __event_call__ = None
  241. __task__ = None
  242. __task_body__ = None
  243. if metadata:
  244. if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
  245. __event_emitter__ = get_event_emitter(metadata)
  246. __event_call__ = get_event_call(metadata)
  247. __task__ = metadata.get("task", None)
  248. __task_body__ = metadata.get("task_body", None)
  249. extra_params = {
  250. "__event_emitter__": __event_emitter__,
  251. "__event_call__": __event_call__,
  252. "__task__": __task__,
  253. "__task_body__": __task_body__,
  254. "__files__": files,
  255. "__user__": {
  256. "id": user.id,
  257. "email": user.email,
  258. "name": user.name,
  259. "role": user.role,
  260. },
  261. }
  262. extra_params["__tools__"] = get_tools(
  263. app,
  264. tool_ids,
  265. user,
  266. {
  267. **extra_params,
  268. "__model__": app.state.MODELS[form_data["model"]],
  269. "__messages__": form_data["messages"],
  270. "__files__": files,
  271. },
  272. )
  273. if model_info:
  274. if model_info.base_model_id:
  275. form_data["model"] = model_info.base_model_id
  276. params = model_info.params.model_dump()
  277. form_data = apply_model_params_to_body_openai(params, form_data)
  278. form_data = apply_model_system_prompt_to_body(params, form_data, user)
  279. pipe_id = get_pipe_id(form_data)
  280. function_module = get_function_module(pipe_id)
  281. pipe = function_module.pipe
  282. params = get_function_params(function_module, form_data, user, extra_params)
  283. if form_data.get("stream", False):
  284. async def stream_content():
  285. try:
  286. res = await execute_pipe(pipe, params)
  287. # Directly return if the response is a StreamingResponse
  288. if isinstance(res, StreamingResponse):
  289. async for data in res.body_iterator:
  290. yield data
  291. return
  292. if isinstance(res, dict):
  293. yield f"data: {json.dumps(res)}\n\n"
  294. return
  295. except Exception as e:
  296. print(f"Error: {e}")
  297. yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
  298. return
  299. if isinstance(res, str):
  300. message = openai_chat_chunk_message_template(form_data["model"], res)
  301. yield f"data: {json.dumps(message)}\n\n"
  302. if isinstance(res, Iterator):
  303. for line in res:
  304. yield process_line(form_data, line)
  305. if isinstance(res, AsyncGenerator):
  306. async for line in res:
  307. yield process_line(form_data, line)
  308. if isinstance(res, str) or isinstance(res, Generator):
  309. finish_message = openai_chat_chunk_message_template(
  310. form_data["model"], ""
  311. )
  312. finish_message["choices"][0]["finish_reason"] = "stop"
  313. yield f"data: {json.dumps(finish_message)}\n\n"
  314. yield "data: [DONE]"
  315. return StreamingResponse(stream_content(), media_type="text/event-stream")
  316. else:
  317. try:
  318. res = await execute_pipe(pipe, params)
  319. except Exception as e:
  320. print(f"Error: {e}")
  321. return {"error": {"detail": str(e)}}
  322. if isinstance(res, StreamingResponse) or isinstance(res, dict):
  323. return res
  324. if isinstance(res, BaseModel):
  325. return res.model_dump()
  326. message = await get_message_content(res)
  327. return openai_chat_completion_message_template(form_data["model"], message)