main.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. from fastapi import FastAPI, Depends
  2. from fastapi.routing import APIRoute
  3. from fastapi.responses import StreamingResponse
  4. from fastapi.middleware.cors import CORSMiddleware
  5. from starlette.middleware.sessions import SessionMiddleware
  6. from sqlalchemy.orm import Session
  7. from apps.webui.routers import (
  8. auths,
  9. users,
  10. chats,
  11. documents,
  12. tools,
  13. models,
  14. prompts,
  15. configs,
  16. memories,
  17. utils,
  18. files,
  19. functions,
  20. )
  21. from apps.webui.models.functions import Functions
  22. from apps.webui.models.models import Models
  23. from apps.webui.utils import load_function_module_by_id
  24. from utils.misc import stream_message_template
  25. from utils.task import prompt_template
  26. from config import (
  27. WEBUI_BUILD_HASH,
  28. SHOW_ADMIN_DETAILS,
  29. ADMIN_EMAIL,
  30. WEBUI_AUTH,
  31. DEFAULT_MODELS,
  32. DEFAULT_PROMPT_SUGGESTIONS,
  33. DEFAULT_USER_ROLE,
  34. ENABLE_SIGNUP,
  35. USER_PERMISSIONS,
  36. WEBHOOK_URL,
  37. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  38. WEBUI_AUTH_TRUSTED_NAME_HEADER,
  39. JWT_EXPIRES_IN,
  40. WEBUI_BANNERS,
  41. ENABLE_COMMUNITY_SHARING,
  42. AppConfig,
  43. OAUTH_USERNAME_CLAIM,
  44. OAUTH_PICTURE_CLAIM,
  45. )
  46. import inspect
  47. import uuid
  48. import time
  49. import json
  50. from typing import Iterator, Generator
  51. from pydantic import BaseModel
  52. app = FastAPI()
  53. origins = ["*"]
  54. app.state.config = AppConfig()
  55. app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
  56. app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
  57. app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
  58. app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
  59. app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
  60. app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
  61. app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
  62. app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
  63. app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
  64. app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
  65. app.state.config.WEBHOOK_URL = WEBHOOK_URL
  66. app.state.config.BANNERS = WEBUI_BANNERS
  67. app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
  68. app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
  69. app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
  70. app.state.MODELS = {}
  71. app.state.TOOLS = {}
  72. app.state.FUNCTIONS = {}
  73. app.add_middleware(
  74. CORSMiddleware,
  75. allow_origins=origins,
  76. allow_credentials=True,
  77. allow_methods=["*"],
  78. allow_headers=["*"],
  79. )
  80. app.include_router(configs.router, prefix="/configs", tags=["configs"])
  81. app.include_router(auths.router, prefix="/auths", tags=["auths"])
  82. app.include_router(users.router, prefix="/users", tags=["users"])
  83. app.include_router(chats.router, prefix="/chats", tags=["chats"])
  84. app.include_router(documents.router, prefix="/documents", tags=["documents"])
  85. app.include_router(models.router, prefix="/models", tags=["models"])
  86. app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
  87. app.include_router(memories.router, prefix="/memories", tags=["memories"])
  88. app.include_router(files.router, prefix="/files", tags=["files"])
  89. app.include_router(tools.router, prefix="/tools", tags=["tools"])
  90. app.include_router(functions.router, prefix="/functions", tags=["functions"])
  91. app.include_router(utils.router, prefix="/utils", tags=["utils"])
  92. @app.get("/")
  93. async def get_status():
  94. return {
  95. "status": True,
  96. "auth": WEBUI_AUTH,
  97. "default_models": app.state.config.DEFAULT_MODELS,
  98. "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
  99. }
  100. async def get_pipe_models():
  101. pipes = Functions.get_functions_by_type("pipe", active_only=True)
  102. pipe_models = []
  103. for pipe in pipes:
  104. # Check if function is already loaded
  105. if pipe.id not in app.state.FUNCTIONS:
  106. function_module, function_type, frontmatter = load_function_module_by_id(
  107. pipe.id
  108. )
  109. app.state.FUNCTIONS[pipe.id] = function_module
  110. else:
  111. function_module = app.state.FUNCTIONS[pipe.id]
  112. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  113. valves = Functions.get_function_valves_by_id(pipe.id)
  114. function_module.valves = function_module.Valves(
  115. **(valves if valves else {})
  116. )
  117. # Check if function is a manifold
  118. if hasattr(function_module, "type"):
  119. if function_module.type == "manifold":
  120. manifold_pipes = []
  121. # Check if pipes is a function or a list
  122. if callable(function_module.pipes):
  123. manifold_pipes = function_module.pipes()
  124. else:
  125. manifold_pipes = function_module.pipes
  126. for p in manifold_pipes:
  127. manifold_pipe_id = f'{pipe.id}.{p["id"]}'
  128. manifold_pipe_name = p["name"]
  129. if hasattr(function_module, "name"):
  130. manifold_pipe_name = (
  131. f"{function_module.name}{manifold_pipe_name}"
  132. )
  133. pipe_models.append(
  134. {
  135. "id": manifold_pipe_id,
  136. "name": manifold_pipe_name,
  137. "object": "model",
  138. "created": pipe.created_at,
  139. "owned_by": "openai",
  140. "pipe": {"type": pipe.type},
  141. }
  142. )
  143. else:
  144. pipe_models.append(
  145. {
  146. "id": pipe.id,
  147. "name": pipe.name,
  148. "object": "model",
  149. "created": pipe.created_at,
  150. "owned_by": "openai",
  151. "pipe": {"type": "pipe"},
  152. }
  153. )
  154. return pipe_models
  155. async def generate_function_chat_completion(form_data, user):
  156. model_id = form_data.get("model")
  157. model_info = Models.get_model_by_id(model_id)
  158. if model_info:
  159. if model_info.base_model_id:
  160. form_data["model"] = model_info.base_model_id
  161. model_info.params = model_info.params.model_dump()
  162. if model_info.params:
  163. if model_info.params.get("temperature", None) is not None:
  164. form_data["temperature"] = float(model_info.params.get("temperature"))
  165. if model_info.params.get("top_p", None):
  166. form_data["top_p"] = int(model_info.params.get("top_p", None))
  167. if model_info.params.get("max_tokens", None):
  168. form_data["max_tokens"] = int(model_info.params.get("max_tokens", None))
  169. if model_info.params.get("frequency_penalty", None):
  170. form_data["frequency_penalty"] = int(
  171. model_info.params.get("frequency_penalty", None)
  172. )
  173. if model_info.params.get("seed", None):
  174. form_data["seed"] = model_info.params.get("seed", None)
  175. if model_info.params.get("stop", None):
  176. form_data["stop"] = (
  177. [
  178. bytes(stop, "utf-8").decode("unicode_escape")
  179. for stop in model_info.params["stop"]
  180. ]
  181. if model_info.params.get("stop", None)
  182. else None
  183. )
  184. system = model_info.params.get("system", None)
  185. if system:
  186. system = prompt_template(
  187. system,
  188. **(
  189. {
  190. "user_name": user.name,
  191. "user_location": (
  192. user.info.get("location") if user.info else None
  193. ),
  194. }
  195. if user
  196. else {}
  197. ),
  198. )
  199. # Check if the payload already has a system message
  200. # If not, add a system message to the payload
  201. if form_data.get("messages"):
  202. for message in form_data["messages"]:
  203. if message.get("role") == "system":
  204. message["content"] = system + message["content"]
  205. break
  206. else:
  207. form_data["messages"].insert(
  208. 0,
  209. {
  210. "role": "system",
  211. "content": system,
  212. },
  213. )
  214. else:
  215. pass
  216. async def job():
  217. pipe_id = form_data["model"]
  218. if "." in pipe_id:
  219. pipe_id, sub_pipe_id = pipe_id.split(".", 1)
  220. print(pipe_id)
  221. # Check if function is already loaded
  222. if pipe_id not in app.state.FUNCTIONS:
  223. function_module, function_type, frontmatter = load_function_module_by_id(
  224. pipe_id
  225. )
  226. app.state.FUNCTIONS[pipe_id] = function_module
  227. else:
  228. function_module = app.state.FUNCTIONS[pipe_id]
  229. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  230. valves = Functions.get_function_valves_by_id(pipe_id)
  231. function_module.valves = function_module.Valves(
  232. **(valves if valves else {})
  233. )
  234. pipe = function_module.pipe
  235. # Get the signature of the function
  236. sig = inspect.signature(pipe)
  237. params = {"body": form_data}
  238. if "__user__" in sig.parameters:
  239. __user__ = {
  240. "id": user.id,
  241. "email": user.email,
  242. "name": user.name,
  243. "role": user.role,
  244. }
  245. try:
  246. if hasattr(function_module, "UserValves"):
  247. __user__["valves"] = function_module.UserValves(
  248. **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
  249. )
  250. except Exception as e:
  251. print(e)
  252. params = {**params, "__user__": __user__}
  253. if form_data["stream"]:
  254. async def stream_content():
  255. try:
  256. if inspect.iscoroutinefunction(pipe):
  257. res = await pipe(**params)
  258. else:
  259. res = pipe(**params)
  260. # Directly return if the response is a StreamingResponse
  261. if isinstance(res, StreamingResponse):
  262. async for data in res.body_iterator:
  263. yield data
  264. return
  265. if isinstance(res, dict):
  266. yield f"data: {json.dumps(res)}\n\n"
  267. return
  268. except Exception as e:
  269. print(f"Error: {e}")
  270. yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
  271. return
  272. if isinstance(res, str):
  273. message = stream_message_template(form_data["model"], res)
  274. yield f"data: {json.dumps(message)}\n\n"
  275. if isinstance(res, Iterator):
  276. for line in res:
  277. if isinstance(line, BaseModel):
  278. line = line.model_dump_json()
  279. line = f"data: {line}"
  280. if isinstance(line, dict):
  281. line = f"data: {json.dumps(line)}"
  282. try:
  283. line = line.decode("utf-8")
  284. except:
  285. pass
  286. if line.startswith("data:"):
  287. yield f"{line}\n\n"
  288. else:
  289. line = stream_message_template(form_data["model"], line)
  290. yield f"data: {json.dumps(line)}\n\n"
  291. if isinstance(res, str) or isinstance(res, Generator):
  292. finish_message = {
  293. "id": f"{form_data['model']}-{str(uuid.uuid4())}",
  294. "object": "chat.completion.chunk",
  295. "created": int(time.time()),
  296. "model": form_data["model"],
  297. "choices": [
  298. {
  299. "index": 0,
  300. "delta": {},
  301. "logprobs": None,
  302. "finish_reason": "stop",
  303. }
  304. ],
  305. }
  306. yield f"data: {json.dumps(finish_message)}\n\n"
  307. yield f"data: [DONE]"
  308. return StreamingResponse(stream_content(), media_type="text/event-stream")
  309. else:
  310. try:
  311. if inspect.iscoroutinefunction(pipe):
  312. res = await pipe(**params)
  313. else:
  314. res = pipe(**params)
  315. if isinstance(res, StreamingResponse):
  316. return res
  317. except Exception as e:
  318. print(f"Error: {e}")
  319. return {"error": {"detail": str(e)}}
  320. if isinstance(res, dict):
  321. return res
  322. elif isinstance(res, BaseModel):
  323. return res.model_dump()
  324. else:
  325. message = ""
  326. if isinstance(res, str):
  327. message = res
  328. if isinstance(res, Generator):
  329. for stream in res:
  330. message = f"{message}{stream}"
  331. return {
  332. "id": f"{form_data['model']}-{str(uuid.uuid4())}",
  333. "object": "chat.completion",
  334. "created": int(time.time()),
  335. "model": form_data["model"],
  336. "choices": [
  337. {
  338. "index": 0,
  339. "message": {
  340. "role": "assistant",
  341. "content": message,
  342. },
  343. "logprobs": None,
  344. "finish_reason": "stop",
  345. }
  346. ],
  347. }
  348. return await job()