main.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483
  1. from fastapi import FastAPI, Depends
  2. from fastapi.routing import APIRoute
  3. from fastapi.responses import StreamingResponse
  4. from fastapi.middleware.cors import CORSMiddleware
  5. from starlette.middleware.sessions import SessionMiddleware
  6. from sqlalchemy.orm import Session
  7. from apps.webui.routers import (
  8. auths,
  9. users,
  10. chats,
  11. documents,
  12. tools,
  13. models,
  14. prompts,
  15. configs,
  16. memories,
  17. utils,
  18. files,
  19. functions,
  20. )
  21. from apps.webui.models.functions import Functions
  22. from apps.webui.models.models import Models
  23. from apps.webui.utils import load_function_module_by_id
  24. from utils.misc import stream_message_template
  25. from utils.task import prompt_template
  26. from config import (
  27. WEBUI_BUILD_HASH,
  28. SHOW_ADMIN_DETAILS,
  29. ADMIN_EMAIL,
  30. WEBUI_AUTH,
  31. DEFAULT_MODELS,
  32. DEFAULT_PROMPT_SUGGESTIONS,
  33. DEFAULT_USER_ROLE,
  34. ENABLE_SIGNUP,
  35. ENABLE_LOGIN_FORM,
  36. USER_PERMISSIONS,
  37. WEBHOOK_URL,
  38. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  39. WEBUI_AUTH_TRUSTED_NAME_HEADER,
  40. JWT_EXPIRES_IN,
  41. WEBUI_BANNERS,
  42. ENABLE_COMMUNITY_SHARING,
  43. AppConfig,
  44. OAUTH_USERNAME_CLAIM,
  45. OAUTH_PICTURE_CLAIM,
  46. )
  47. from apps.socket.main import get_event_call, get_event_emitter
  48. import inspect
  49. import uuid
  50. import time
  51. import json
  52. from typing import Iterator, Generator, AsyncGenerator, Optional
  53. from pydantic import BaseModel
# FastAPI application object for this module; the routers imported from
# `apps.webui.routers` are attached to it below.
app = FastAPI()

# Origins handed to the CORS middleware further down ("*" is a wildcard entry).
origins = ["*"]

# Runtime configuration container, seeded from the constants imported from
# the `config` module.
app.state.config = AppConfig()
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
# The two trusted-header names are stored directly on `app.state`, not on
# `app.state.config` like the surrounding settings.
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM

# In-process registries, all starting empty. In this file FUNCTIONS is used
# as a cache of loaded function modules keyed by function id; MODELS and
# TOOLS are not touched in the visible code.
app.state.MODELS = {}
app.state.TOOLS = {}
app.state.FUNCTIONS = {}

# NOTE(review): wildcard origins/methods/headers combined with
# allow_credentials=True is a very permissive CORS policy - confirm that this
# is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount each feature router under its own URL prefix and OpenAPI tag.
app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(documents.router, prefix="/documents", tags=["documents"])
app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(memories.router, prefix="/memories", tags=["memories"])
app.include_router(files.router, prefix="/files", tags=["files"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(functions.router, prefix="/functions", tags=["functions"])
app.include_router(utils.router, prefix="/utils", tags=["utils"])
  95. @app.get("/")
  96. async def get_status():
  97. return {
  98. "status": True,
  99. "auth": WEBUI_AUTH,
  100. "default_models": app.state.config.DEFAULT_MODELS,
  101. "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
  102. }
  103. async def get_pipe_models():
  104. pipes = Functions.get_functions_by_type("pipe", active_only=True)
  105. pipe_models = []
  106. for pipe in pipes:
  107. # Check if function is already loaded
  108. if pipe.id not in app.state.FUNCTIONS:
  109. function_module, function_type, frontmatter = load_function_module_by_id(
  110. pipe.id
  111. )
  112. app.state.FUNCTIONS[pipe.id] = function_module
  113. else:
  114. function_module = app.state.FUNCTIONS[pipe.id]
  115. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  116. valves = Functions.get_function_valves_by_id(pipe.id)
  117. function_module.valves = function_module.Valves(
  118. **(valves if valves else {})
  119. )
  120. # Check if function is a manifold
  121. if hasattr(function_module, "type"):
  122. if function_module.type == "manifold":
  123. manifold_pipes = []
  124. # Check if pipes is a function or a list
  125. if callable(function_module.pipes):
  126. manifold_pipes = function_module.pipes()
  127. else:
  128. manifold_pipes = function_module.pipes
  129. for p in manifold_pipes:
  130. manifold_pipe_id = f'{pipe.id}.{p["id"]}'
  131. manifold_pipe_name = p["name"]
  132. if hasattr(function_module, "name"):
  133. manifold_pipe_name = (
  134. f"{function_module.name}{manifold_pipe_name}"
  135. )
  136. pipe_flag = {"type": pipe.type}
  137. if hasattr(function_module, "ChatValves"):
  138. pipe_flag["valves_spec"] = function_module.ChatValves.schema()
  139. pipe_models.append(
  140. {
  141. "id": manifold_pipe_id,
  142. "name": manifold_pipe_name,
  143. "object": "model",
  144. "created": pipe.created_at,
  145. "owned_by": "openai",
  146. "pipe": pipe_flag,
  147. }
  148. )
  149. else:
  150. pipe_flag = {"type": "pipe"}
  151. if hasattr(function_module, "ChatValves"):
  152. pipe_flag["valves_spec"] = function_module.ChatValves.schema()
  153. pipe_models.append(
  154. {
  155. "id": pipe.id,
  156. "name": pipe.name,
  157. "object": "model",
  158. "created": pipe.created_at,
  159. "owned_by": "openai",
  160. "pipe": pipe_flag,
  161. }
  162. )
  163. return pipe_models
  164. async def generate_function_chat_completion(form_data, user):
  165. model_id = form_data.get("model")
  166. model_info = Models.get_model_by_id(model_id)
  167. metadata = None
  168. if "metadata" in form_data:
  169. metadata = form_data["metadata"]
  170. del form_data["metadata"]
  171. __event_emitter__ = None
  172. __event_call__ = None
  173. __task__ = None
  174. if metadata:
  175. if (
  176. metadata.get("session_id")
  177. and metadata.get("chat_id")
  178. and metadata.get("message_id")
  179. ):
  180. __event_emitter__ = await get_event_emitter(metadata)
  181. __event_call__ = await get_event_call(metadata)
  182. if metadata.get("task"):
  183. __task__ = metadata.get("task")
  184. if model_info:
  185. if model_info.base_model_id:
  186. form_data["model"] = model_info.base_model_id
  187. model_info.params = model_info.params.model_dump()
  188. if model_info.params:
  189. if model_info.params.get("temperature", None) is not None:
  190. form_data["temperature"] = float(model_info.params.get("temperature"))
  191. if model_info.params.get("top_p", None):
  192. form_data["top_p"] = int(model_info.params.get("top_p", None))
  193. if model_info.params.get("max_tokens", None):
  194. form_data["max_tokens"] = int(model_info.params.get("max_tokens", None))
  195. if model_info.params.get("frequency_penalty", None):
  196. form_data["frequency_penalty"] = int(
  197. model_info.params.get("frequency_penalty", None)
  198. )
  199. if model_info.params.get("seed", None):
  200. form_data["seed"] = model_info.params.get("seed", None)
  201. if model_info.params.get("stop", None):
  202. form_data["stop"] = (
  203. [
  204. bytes(stop, "utf-8").decode("unicode_escape")
  205. for stop in model_info.params["stop"]
  206. ]
  207. if model_info.params.get("stop", None)
  208. else None
  209. )
  210. system = model_info.params.get("system", None)
  211. if system:
  212. system = prompt_template(
  213. system,
  214. **(
  215. {
  216. "user_name": user.name,
  217. "user_location": (
  218. user.info.get("location") if user.info else None
  219. ),
  220. }
  221. if user
  222. else {}
  223. ),
  224. )
  225. # Check if the payload already has a system message
  226. # If not, add a system message to the payload
  227. if form_data.get("messages"):
  228. for message in form_data["messages"]:
  229. if message.get("role") == "system":
  230. message["content"] = system + message["content"]
  231. break
  232. else:
  233. form_data["messages"].insert(
  234. 0,
  235. {
  236. "role": "system",
  237. "content": system,
  238. },
  239. )
  240. else:
  241. pass
  242. async def job():
  243. pipe_id = form_data["model"]
  244. if "." in pipe_id:
  245. pipe_id, sub_pipe_id = pipe_id.split(".", 1)
  246. print(pipe_id)
  247. # Check if function is already loaded
  248. if pipe_id not in app.state.FUNCTIONS:
  249. function_module, function_type, frontmatter = load_function_module_by_id(
  250. pipe_id
  251. )
  252. app.state.FUNCTIONS[pipe_id] = function_module
  253. else:
  254. function_module = app.state.FUNCTIONS[pipe_id]
  255. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  256. valves = Functions.get_function_valves_by_id(pipe_id)
  257. function_module.valves = function_module.Valves(
  258. **(valves if valves else {})
  259. )
  260. pipe = function_module.pipe
  261. # Get the signature of the function
  262. sig = inspect.signature(pipe)
  263. params = {"body": form_data}
  264. if "__user__" in sig.parameters:
  265. __user__ = {
  266. "id": user.id,
  267. "email": user.email,
  268. "name": user.name,
  269. "role": user.role,
  270. }
  271. try:
  272. if hasattr(function_module, "UserValves"):
  273. __user__["valves"] = function_module.UserValves(
  274. **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
  275. )
  276. except Exception as e:
  277. print(e)
  278. params = {**params, "__user__": __user__}
  279. if "__event_emitter__" in sig.parameters:
  280. params = {**params, "__event_emitter__": __event_emitter__}
  281. if "__event_call__" in sig.parameters:
  282. params = {**params, "__event_call__": __event_call__}
  283. if "__task__" in sig.parameters:
  284. params = {**params, "__task__": __task__}
  285. if form_data["stream"]:
  286. async def stream_content():
  287. try:
  288. if inspect.iscoroutinefunction(pipe):
  289. res = await pipe(**params)
  290. else:
  291. res = pipe(**params)
  292. # Directly return if the response is a StreamingResponse
  293. if isinstance(res, StreamingResponse):
  294. async for data in res.body_iterator:
  295. yield data
  296. return
  297. if isinstance(res, dict):
  298. yield f"data: {json.dumps(res)}\n\n"
  299. return
  300. except Exception as e:
  301. print(f"Error: {e}")
  302. yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
  303. return
  304. if isinstance(res, str):
  305. message = stream_message_template(form_data["model"], res)
  306. yield f"data: {json.dumps(message)}\n\n"
  307. if isinstance(res, Iterator):
  308. for line in res:
  309. if isinstance(line, BaseModel):
  310. line = line.model_dump_json()
  311. line = f"data: {line}"
  312. if isinstance(line, dict):
  313. line = f"data: {json.dumps(line)}"
  314. try:
  315. line = line.decode("utf-8")
  316. except:
  317. pass
  318. if line.startswith("data:"):
  319. yield f"{line}\n\n"
  320. else:
  321. line = stream_message_template(form_data["model"], line)
  322. yield f"data: {json.dumps(line)}\n\n"
  323. if isinstance(res, str) or isinstance(res, Generator):
  324. finish_message = {
  325. "id": f"{form_data['model']}-{str(uuid.uuid4())}",
  326. "object": "chat.completion.chunk",
  327. "created": int(time.time()),
  328. "model": form_data["model"],
  329. "choices": [
  330. {
  331. "index": 0,
  332. "delta": {},
  333. "logprobs": None,
  334. "finish_reason": "stop",
  335. }
  336. ],
  337. }
  338. yield f"data: {json.dumps(finish_message)}\n\n"
  339. yield f"data: [DONE]"
  340. if isinstance(res, AsyncGenerator):
  341. async for line in res:
  342. if isinstance(line, BaseModel):
  343. line = line.model_dump_json()
  344. line = f"data: {line}"
  345. if isinstance(line, dict):
  346. line = f"data: {json.dumps(line)}"
  347. try:
  348. line = line.decode("utf-8")
  349. except:
  350. pass
  351. if line.startswith("data:"):
  352. yield f"{line}\n\n"
  353. else:
  354. line = stream_message_template(form_data["model"], line)
  355. yield f"data: {json.dumps(line)}\n\n"
  356. return StreamingResponse(stream_content(), media_type="text/event-stream")
  357. else:
  358. try:
  359. if inspect.iscoroutinefunction(pipe):
  360. res = await pipe(**params)
  361. else:
  362. res = pipe(**params)
  363. if isinstance(res, StreamingResponse):
  364. return res
  365. except Exception as e:
  366. print(f"Error: {e}")
  367. return {"error": {"detail": str(e)}}
  368. if isinstance(res, dict):
  369. return res
  370. elif isinstance(res, BaseModel):
  371. return res.model_dump()
  372. else:
  373. message = ""
  374. if isinstance(res, str):
  375. message = res
  376. elif isinstance(res, Generator):
  377. for stream in res:
  378. message = f"{message}{stream}"
  379. elif isinstance(res, AsyncGenerator):
  380. async for stream in res:
  381. message = f"{message}{stream}"
  382. return {
  383. "id": f"{form_data['model']}-{str(uuid.uuid4())}",
  384. "object": "chat.completion",
  385. "created": int(time.time()),
  386. "model": form_data["model"],
  387. "choices": [
  388. {
  389. "index": 0,
  390. "message": {
  391. "role": "assistant",
  392. "content": message,
  393. },
  394. "logprobs": None,
  395. "finish_reason": "stop",
  396. }
  397. ],
  398. }
  399. return await job()