main.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. from fastapi import FastAPI
  2. from fastapi.responses import StreamingResponse
  3. from fastapi.middleware.cors import CORSMiddleware
  4. from apps.webui.routers import (
  5. auths,
  6. users,
  7. chats,
  8. documents,
  9. tools,
  10. models,
  11. prompts,
  12. configs,
  13. memories,
  14. utils,
  15. files,
  16. functions,
  17. )
  18. from apps.webui.models.functions import Functions
  19. from apps.webui.models.models import Models
  20. from apps.webui.utils import load_function_module_by_id
  21. from utils.misc import stream_message_template
  22. from utils.task import prompt_template
  23. from config import (
  24. SHOW_ADMIN_DETAILS,
  25. ADMIN_EMAIL,
  26. WEBUI_AUTH,
  27. DEFAULT_MODELS,
  28. DEFAULT_PROMPT_SUGGESTIONS,
  29. DEFAULT_USER_ROLE,
  30. ENABLE_SIGNUP,
  31. ENABLE_LOGIN_FORM,
  32. USER_PERMISSIONS,
  33. WEBHOOK_URL,
  34. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  35. WEBUI_AUTH_TRUSTED_NAME_HEADER,
  36. JWT_EXPIRES_IN,
  37. WEBUI_BANNERS,
  38. ENABLE_COMMUNITY_SHARING,
  39. AppConfig,
  40. OAUTH_USERNAME_CLAIM,
  41. OAUTH_PICTURE_CLAIM,
  42. )
  43. from apps.socket.main import get_event_call, get_event_emitter
  44. import inspect
  45. import uuid
  46. import time
  47. import json
  48. from typing import Iterator, Generator, AsyncGenerator
  49. from pydantic import BaseModel
  50. app = FastAPI()
  51. origins = ["*"]
  52. app.state.config = AppConfig()
  53. app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
  54. app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
  55. app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
  56. app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
  57. app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
  58. app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
  59. app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
  60. app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
  61. app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
  62. app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
  63. app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
  64. app.state.config.WEBHOOK_URL = WEBHOOK_URL
  65. app.state.config.BANNERS = WEBUI_BANNERS
  66. app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
  67. app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
  68. app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
  69. app.state.MODELS = {}
  70. app.state.TOOLS = {}
  71. app.state.FUNCTIONS = {}
  72. app.add_middleware(
  73. CORSMiddleware,
  74. allow_origins=origins,
  75. allow_credentials=True,
  76. allow_methods=["*"],
  77. allow_headers=["*"],
  78. )
  79. app.include_router(configs.router, prefix="/configs", tags=["configs"])
  80. app.include_router(auths.router, prefix="/auths", tags=["auths"])
  81. app.include_router(users.router, prefix="/users", tags=["users"])
  82. app.include_router(chats.router, prefix="/chats", tags=["chats"])
  83. app.include_router(documents.router, prefix="/documents", tags=["documents"])
  84. app.include_router(models.router, prefix="/models", tags=["models"])
  85. app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
  86. app.include_router(memories.router, prefix="/memories", tags=["memories"])
  87. app.include_router(files.router, prefix="/files", tags=["files"])
  88. app.include_router(tools.router, prefix="/tools", tags=["tools"])
  89. app.include_router(functions.router, prefix="/functions", tags=["functions"])
  90. app.include_router(utils.router, prefix="/utils", tags=["utils"])
  91. @app.get("/")
  92. async def get_status():
  93. return {
  94. "status": True,
  95. "auth": WEBUI_AUTH,
  96. "default_models": app.state.config.DEFAULT_MODELS,
  97. "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
  98. }
  99. def get_function_module(pipe_id: str):
  100. # Check if function is already loaded
  101. if pipe_id not in app.state.FUNCTIONS:
  102. function_module, _, _ = load_function_module_by_id(pipe_id)
  103. app.state.FUNCTIONS[pipe_id] = function_module
  104. else:
  105. function_module = app.state.FUNCTIONS[pipe_id]
  106. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  107. valves = Functions.get_function_valves_by_id(pipe_id)
  108. function_module.valves = function_module.Valves(**(valves if valves else {}))
  109. return function_module
  110. async def get_pipe_models():
  111. pipes = Functions.get_functions_by_type("pipe", active_only=True)
  112. pipe_models = []
  113. for pipe in pipes:
  114. function_module = get_function_module(pipe.id)
  115. # Check if function is a manifold
  116. if hasattr(function_module, "type"):
  117. if not function_module.type == "manifold":
  118. continue
  119. manifold_pipes = []
  120. # Check if pipes is a function or a list
  121. if callable(function_module.pipes):
  122. manifold_pipes = function_module.pipes()
  123. else:
  124. manifold_pipes = function_module.pipes
  125. for p in manifold_pipes:
  126. manifold_pipe_id = f'{pipe.id}.{p["id"]}'
  127. manifold_pipe_name = p["name"]
  128. if hasattr(function_module, "name"):
  129. manifold_pipe_name = f"{function_module.name}{manifold_pipe_name}"
  130. pipe_flag = {"type": pipe.type}
  131. if hasattr(function_module, "ChatValves"):
  132. pipe_flag["valves_spec"] = function_module.ChatValves.schema()
  133. pipe_models.append(
  134. {
  135. "id": manifold_pipe_id,
  136. "name": manifold_pipe_name,
  137. "object": "model",
  138. "created": pipe.created_at,
  139. "owned_by": "openai",
  140. "pipe": pipe_flag,
  141. }
  142. )
  143. else:
  144. pipe_flag = {"type": "pipe"}
  145. if hasattr(function_module, "ChatValves"):
  146. pipe_flag["valves_spec"] = function_module.ChatValves.schema()
  147. pipe_models.append(
  148. {
  149. "id": pipe.id,
  150. "name": pipe.name,
  151. "object": "model",
  152. "created": pipe.created_at,
  153. "owned_by": "openai",
  154. "pipe": pipe_flag,
  155. }
  156. )
  157. return pipe_models
  158. async def execute_pipe(pipe, params):
  159. if inspect.iscoroutinefunction(pipe):
  160. return await pipe(**params)
  161. else:
  162. return pipe(**params)
  163. async def get_message(res: str | Generator | AsyncGenerator) -> str:
  164. if isinstance(res, str):
  165. return res
  166. if isinstance(res, Generator):
  167. return "".join(map(str, res))
  168. if isinstance(res, AsyncGenerator):
  169. return "".join([str(stream) async for stream in res])
  170. def get_final_message(form_data: dict, message: str | None = None) -> dict:
  171. choice = {
  172. "index": 0,
  173. "logprobs": None,
  174. "finish_reason": "stop",
  175. }
  176. # If message is None, we're dealing with a chunk
  177. if not message:
  178. choice["delta"] = {}
  179. else:
  180. choice["message"] = {"role": "assistant", "content": message}
  181. return {
  182. "id": f"{form_data['model']}-{str(uuid.uuid4())}",
  183. "created": int(time.time()),
  184. "model": form_data["model"],
  185. "object": "chat.completion" if message is not None else "chat.completion.chunk",
  186. "choices": [choice],
  187. }
  188. def process_line(form_data: dict, line):
  189. if isinstance(line, BaseModel):
  190. line = line.model_dump_json()
  191. line = f"data: {line}"
  192. if isinstance(line, dict):
  193. line = f"data: {json.dumps(line)}"
  194. try:
  195. line = line.decode("utf-8")
  196. except Exception:
  197. pass
  198. if line.startswith("data:"):
  199. return f"{line}\n\n"
  200. else:
  201. line = stream_message_template(form_data["model"], line)
  202. return f"data: {json.dumps(line)}\n\n"
  203. def get_pipe_id(form_data: dict) -> str:
  204. pipe_id = form_data["model"]
  205. if "." in pipe_id:
  206. pipe_id, _ = pipe_id.split(".", 1)
  207. print(pipe_id)
  208. return pipe_id
  209. def get_params_dict(pipe, form_data, user, extra_params, function_module):
  210. pipe_id = get_pipe_id(form_data)
  211. # Get the signature of the function
  212. sig = inspect.signature(pipe)
  213. params = {"body": form_data}
  214. for key, value in extra_params.items():
  215. if key in sig.parameters:
  216. params[key] = value
  217. if "__user__" in sig.parameters:
  218. __user__ = {
  219. "id": user.id,
  220. "email": user.email,
  221. "name": user.name,
  222. "role": user.role,
  223. }
  224. try:
  225. if hasattr(function_module, "UserValves"):
  226. __user__["valves"] = function_module.UserValves(
  227. **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
  228. )
  229. except Exception as e:
  230. print(e)
  231. params["__user__"] = __user__
  232. return params
  233. async def generate_function_chat_completion(form_data, user):
  234. model_id = form_data.get("model")
  235. model_info = Models.get_model_by_id(model_id)
  236. metadata = form_data.pop("metadata", None)
  237. __event_emitter__ = __event_call__ = __task__ = None
  238. if metadata:
  239. if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
  240. __event_emitter__ = get_event_emitter(metadata)
  241. __event_call__ = get_event_call(metadata)
  242. __task__ = metadata.get("task", None)
  243. if not model_info:
  244. return
  245. if model_info.base_model_id:
  246. form_data["model"] = model_info.base_model_id
  247. params = model_info.params.model_dump()
  248. if params:
  249. mappings = {
  250. "temperature": float,
  251. "top_p": int,
  252. "max_tokens": int,
  253. "frequency_penalty": int,
  254. "seed": lambda x: x,
  255. "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
  256. }
  257. for key, cast_func in mappings.items():
  258. if (value := params.get(key)) is not None:
  259. form_data[key] = cast_func(value)
  260. system = params.get("system", None)
  261. if not system:
  262. return
  263. if user:
  264. template_params = {
  265. "user_name": user.name,
  266. "user_location": user.info.get("location") if user.info else None,
  267. }
  268. else:
  269. template_params = {}
  270. system = prompt_template(system, **template_params)
  271. # Check if the payload already has a system message
  272. # If not, add a system message to the payload
  273. for message in form_data.get("messages", []):
  274. if message.get("role") == "system":
  275. message["content"] = system + message["content"]
  276. break
  277. else:
  278. if form_data.get("messages"):
  279. form_data["messages"].insert(0, {"role": "system", "content": system})
  280. extra_params = {
  281. "__event_emitter__": __event_emitter__,
  282. "__event_call__": __event_call__,
  283. "__task__": __task__,
  284. }
  285. async def job():
  286. pipe_id = get_pipe_id(form_data)
  287. function_module = get_function_module(pipe_id)
  288. pipe = function_module.pipe
  289. params = get_params_dict(pipe, form_data, user, extra_params, function_module)
  290. if form_data["stream"]:
  291. async def stream_content():
  292. try:
  293. res = await execute_pipe(pipe, params)
  294. # Directly return if the response is a StreamingResponse
  295. if isinstance(res, StreamingResponse):
  296. async for data in res.body_iterator:
  297. yield data
  298. return
  299. if isinstance(res, dict):
  300. yield f"data: {json.dumps(res)}\n\n"
  301. return
  302. except Exception as e:
  303. print(f"Error: {e}")
  304. yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
  305. return
  306. if isinstance(res, str):
  307. message = stream_message_template(form_data["model"], res)
  308. yield f"data: {json.dumps(message)}\n\n"
  309. if isinstance(res, Iterator):
  310. for line in res:
  311. yield process_line(form_data, line)
  312. if isinstance(res, AsyncGenerator):
  313. async for line in res:
  314. yield process_line(form_data, line)
  315. if isinstance(res, str) or isinstance(res, Generator):
  316. finish_message = get_final_message(form_data)
  317. yield f"data: {json.dumps(finish_message)}\n\n"
  318. yield "data: [DONE]"
  319. return StreamingResponse(stream_content(), media_type="text/event-stream")
  320. else:
  321. try:
  322. res = await execute_pipe(pipe, params)
  323. except Exception as e:
  324. print(f"Error: {e}")
  325. return {"error": {"detail": str(e)}}
  326. if isinstance(res, StreamingResponse) or isinstance(res, dict):
  327. return res
  328. if isinstance(res, BaseModel):
  329. return res.model_dump()
  330. message = await get_message(res)
  331. return get_final_message(form_data, message)
  332. return await job()