main.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. from fastapi import FastAPI
  2. from fastapi.responses import StreamingResponse
  3. from fastapi.middleware.cors import CORSMiddleware
  4. from apps.webui.routers import (
  5. auths,
  6. users,
  7. chats,
  8. documents,
  9. tools,
  10. models,
  11. prompts,
  12. configs,
  13. memories,
  14. utils,
  15. files,
  16. functions,
  17. )
  18. from apps.webui.models.functions import Functions
  19. from apps.webui.models.models import Models
  20. from apps.webui.utils import load_function_module_by_id
  21. from utils.misc import (
  22. openai_chat_chunk_message_template,
  23. openai_chat_completion_message_template,
  24. apply_model_params_to_body_openai,
  25. apply_model_system_prompt_to_body,
  26. )
  27. from config import (
  28. SHOW_ADMIN_DETAILS,
  29. ADMIN_EMAIL,
  30. WEBUI_AUTH,
  31. DEFAULT_MODELS,
  32. DEFAULT_PROMPT_SUGGESTIONS,
  33. DEFAULT_USER_ROLE,
  34. ENABLE_SIGNUP,
  35. ENABLE_LOGIN_FORM,
  36. USER_PERMISSIONS,
  37. WEBHOOK_URL,
  38. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  39. WEBUI_AUTH_TRUSTED_NAME_HEADER,
  40. JWT_EXPIRES_IN,
  41. WEBUI_BANNERS,
  42. ENABLE_COMMUNITY_SHARING,
  43. AppConfig,
  44. OAUTH_USERNAME_CLAIM,
  45. OAUTH_PICTURE_CLAIM,
  46. OAUTH_EMAIL_CLAIM,
  47. CORS_ALLOW_ORIGIN,
  48. )
  49. from apps.socket.main import get_event_call, get_event_emitter
  50. import inspect
  51. import json
  52. from typing import Iterator, Generator, AsyncGenerator
  53. from pydantic import BaseModel
  54. app = FastAPI()
  55. app.state.config = AppConfig()
  56. app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
  57. app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
  58. app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
  59. app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
  60. app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
  61. app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
  62. app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
  63. app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
  64. app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
  65. app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
  66. app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
  67. app.state.config.WEBHOOK_URL = WEBHOOK_URL
  68. app.state.config.BANNERS = WEBUI_BANNERS
  69. app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
  70. app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
  71. app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
  72. app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
  73. app.state.MODELS = {}
  74. app.state.TOOLS = {}
  75. app.state.FUNCTIONS = {}
  76. app.add_middleware(
  77. CORSMiddleware,
  78. allow_origins=CORS_ALLOW_ORIGIN,
  79. allow_credentials=True,
  80. allow_methods=["*"],
  81. allow_headers=["*"],
  82. )
  83. app.include_router(configs.router, prefix="/configs", tags=["configs"])
  84. app.include_router(auths.router, prefix="/auths", tags=["auths"])
  85. app.include_router(users.router, prefix="/users", tags=["users"])
  86. app.include_router(chats.router, prefix="/chats", tags=["chats"])
  87. app.include_router(documents.router, prefix="/documents", tags=["documents"])
  88. app.include_router(models.router, prefix="/models", tags=["models"])
  89. app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
  90. app.include_router(memories.router, prefix="/memories", tags=["memories"])
  91. app.include_router(files.router, prefix="/files", tags=["files"])
  92. app.include_router(tools.router, prefix="/tools", tags=["tools"])
  93. app.include_router(functions.router, prefix="/functions", tags=["functions"])
  94. app.include_router(utils.router, prefix="/utils", tags=["utils"])
  95. @app.get("/")
  96. async def get_status():
  97. return {
  98. "status": True,
  99. "auth": WEBUI_AUTH,
  100. "default_models": app.state.config.DEFAULT_MODELS,
  101. "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
  102. }
  103. def get_function_module(pipe_id: str):
  104. # Check if function is already loaded
  105. if pipe_id not in app.state.FUNCTIONS:
  106. function_module, _, _ = load_function_module_by_id(pipe_id)
  107. app.state.FUNCTIONS[pipe_id] = function_module
  108. else:
  109. function_module = app.state.FUNCTIONS[pipe_id]
  110. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  111. valves = Functions.get_function_valves_by_id(pipe_id)
  112. function_module.valves = function_module.Valves(**(valves if valves else {}))
  113. return function_module
  114. async def get_pipe_models():
  115. pipes = Functions.get_functions_by_type("pipe", active_only=True)
  116. pipe_models = []
  117. for pipe in pipes:
  118. function_module = get_function_module(pipe.id)
  119. # Check if function is a manifold
  120. if hasattr(function_module, "pipes"):
  121. manifold_pipes = []
  122. # Check if pipes is a function or a list
  123. if callable(function_module.pipes):
  124. manifold_pipes = function_module.pipes()
  125. else:
  126. manifold_pipes = function_module.pipes
  127. for p in manifold_pipes:
  128. manifold_pipe_id = f'{pipe.id}.{p["id"]}'
  129. manifold_pipe_name = p["name"]
  130. if hasattr(function_module, "name"):
  131. manifold_pipe_name = f"{function_module.name}{manifold_pipe_name}"
  132. pipe_flag = {"type": pipe.type}
  133. if hasattr(function_module, "ChatValves"):
  134. pipe_flag["valves_spec"] = function_module.ChatValves.schema()
  135. pipe_models.append(
  136. {
  137. "id": manifold_pipe_id,
  138. "name": manifold_pipe_name,
  139. "object": "model",
  140. "created": pipe.created_at,
  141. "owned_by": "openai",
  142. "pipe": pipe_flag,
  143. }
  144. )
  145. else:
  146. pipe_flag = {"type": "pipe"}
  147. if hasattr(function_module, "ChatValves"):
  148. pipe_flag["valves_spec"] = function_module.ChatValves.schema()
  149. pipe_models.append(
  150. {
  151. "id": pipe.id,
  152. "name": pipe.name,
  153. "object": "model",
  154. "created": pipe.created_at,
  155. "owned_by": "openai",
  156. "pipe": pipe_flag,
  157. }
  158. )
  159. return pipe_models
  160. async def execute_pipe(pipe, params):
  161. if inspect.iscoroutinefunction(pipe):
  162. return await pipe(**params)
  163. else:
  164. return pipe(**params)
  165. async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
  166. if isinstance(res, str):
  167. return res
  168. if isinstance(res, Generator):
  169. return "".join(map(str, res))
  170. if isinstance(res, AsyncGenerator):
  171. return "".join([str(stream) async for stream in res])
  172. def process_line(form_data: dict, line):
  173. if isinstance(line, BaseModel):
  174. line = line.model_dump_json()
  175. line = f"data: {line}"
  176. if isinstance(line, dict):
  177. line = f"data: {json.dumps(line)}"
  178. try:
  179. line = line.decode("utf-8")
  180. except Exception:
  181. pass
  182. if line.startswith("data:"):
  183. return f"{line}\n\n"
  184. else:
  185. line = openai_chat_chunk_message_template(form_data["model"], line)
  186. return f"data: {json.dumps(line)}\n\n"
  187. def get_pipe_id(form_data: dict) -> str:
  188. pipe_id = form_data["model"]
  189. if "." in pipe_id:
  190. pipe_id, _ = pipe_id.split(".", 1)
  191. print(pipe_id)
  192. return pipe_id
  193. def get_function_params(function_module, form_data, user, extra_params={}):
  194. pipe_id = get_pipe_id(form_data)
  195. # Get the signature of the function
  196. sig = inspect.signature(function_module.pipe)
  197. params = {"body": form_data}
  198. for key, value in extra_params.items():
  199. if key in sig.parameters:
  200. params[key] = value
  201. if "__user__" in sig.parameters:
  202. __user__ = {
  203. "id": user.id,
  204. "email": user.email,
  205. "name": user.name,
  206. "role": user.role,
  207. }
  208. try:
  209. if hasattr(function_module, "UserValves"):
  210. __user__["valves"] = function_module.UserValves(
  211. **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
  212. )
  213. except Exception as e:
  214. print(e)
  215. params["__user__"] = __user__
  216. return params
  217. async def generate_function_chat_completion(form_data, user):
  218. model_id = form_data.get("model")
  219. model_info = Models.get_model_by_id(model_id)
  220. metadata = form_data.pop("metadata", None)
  221. __event_emitter__ = None
  222. __event_call__ = None
  223. __task__ = None
  224. if metadata:
  225. if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
  226. __event_emitter__ = get_event_emitter(metadata)
  227. __event_call__ = get_event_call(metadata)
  228. __task__ = metadata.get("task", None)
  229. if model_info:
  230. if model_info.base_model_id:
  231. form_data["model"] = model_info.base_model_id
  232. params = model_info.params.model_dump()
  233. form_data = apply_model_params_to_body_openai(params, form_data)
  234. form_data = apply_model_system_prompt_to_body(params, form_data, user)
  235. pipe_id = get_pipe_id(form_data)
  236. function_module = get_function_module(pipe_id)
  237. pipe = function_module.pipe
  238. params = get_function_params(
  239. function_module,
  240. form_data,
  241. user,
  242. {
  243. "__event_emitter__": __event_emitter__,
  244. "__event_call__": __event_call__,
  245. "__task__": __task__,
  246. },
  247. )
  248. if form_data["stream"]:
  249. async def stream_content():
  250. try:
  251. res = await execute_pipe(pipe, params)
  252. # Directly return if the response is a StreamingResponse
  253. if isinstance(res, StreamingResponse):
  254. async for data in res.body_iterator:
  255. yield data
  256. return
  257. if isinstance(res, dict):
  258. yield f"data: {json.dumps(res)}\n\n"
  259. return
  260. except Exception as e:
  261. print(f"Error: {e}")
  262. yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
  263. return
  264. if isinstance(res, str):
  265. message = openai_chat_chunk_message_template(form_data["model"], res)
  266. yield f"data: {json.dumps(message)}\n\n"
  267. if isinstance(res, Iterator):
  268. for line in res:
  269. yield process_line(form_data, line)
  270. if isinstance(res, AsyncGenerator):
  271. async for line in res:
  272. yield process_line(form_data, line)
  273. if isinstance(res, str) or isinstance(res, Generator):
  274. finish_message = openai_chat_chunk_message_template(
  275. form_data["model"], ""
  276. )
  277. finish_message["choices"][0]["finish_reason"] = "stop"
  278. yield f"data: {json.dumps(finish_message)}\n\n"
  279. yield "data: [DONE]"
  280. return StreamingResponse(stream_content(), media_type="text/event-stream")
  281. else:
  282. try:
  283. res = await execute_pipe(pipe, params)
  284. except Exception as e:
  285. print(f"Error: {e}")
  286. return {"error": {"detail": str(e)}}
  287. if isinstance(res, StreamingResponse) or isinstance(res, dict):
  288. return res
  289. if isinstance(res, BaseModel):
  290. return res.model_dump()
  291. message = await get_message_content(res)
  292. return openai_chat_completion_message_template(form_data["model"], message)