main.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481
  1. from fastapi import FastAPI, Depends
  2. from fastapi.routing import APIRoute
  3. from fastapi.responses import StreamingResponse
  4. from fastapi.middleware.cors import CORSMiddleware
  5. from starlette.middleware.sessions import SessionMiddleware
  6. from sqlalchemy.orm import Session
  7. from apps.webui.routers import (
  8. auths,
  9. users,
  10. chats,
  11. documents,
  12. tools,
  13. models,
  14. prompts,
  15. configs,
  16. memories,
  17. utils,
  18. files,
  19. functions,
  20. )
  21. from apps.webui.models.functions import Functions
  22. from apps.webui.models.models import Models
  23. from apps.webui.utils import load_function_module_by_id
  24. from utils.misc import stream_message_template
  25. from utils.task import prompt_template
  26. from config import (
  27. WEBUI_BUILD_HASH,
  28. SHOW_ADMIN_DETAILS,
  29. ADMIN_EMAIL,
  30. WEBUI_AUTH,
  31. DEFAULT_MODELS,
  32. DEFAULT_PROMPT_SUGGESTIONS,
  33. DEFAULT_USER_ROLE,
  34. ENABLE_SIGNUP,
  35. USER_PERMISSIONS,
  36. WEBHOOK_URL,
  37. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  38. WEBUI_AUTH_TRUSTED_NAME_HEADER,
  39. JWT_EXPIRES_IN,
  40. WEBUI_BANNERS,
  41. ENABLE_COMMUNITY_SHARING,
  42. AppConfig,
  43. OAUTH_USERNAME_CLAIM,
  44. OAUTH_PICTURE_CLAIM,
  45. )
  46. from apps.socket.main import get_event_call, get_event_emitter
  47. import inspect
  48. import uuid
  49. import time
  50. import json
  51. from typing import Iterator, Generator, AsyncGenerator, Optional
  52. from pydantic import BaseModel
  53. app = FastAPI()
  54. origins = ["*"]
  55. app.state.config = AppConfig()
  56. app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
  57. app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
  58. app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
  59. app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
  60. app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
  61. app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
  62. app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
  63. app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
  64. app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
  65. app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
  66. app.state.config.WEBHOOK_URL = WEBHOOK_URL
  67. app.state.config.BANNERS = WEBUI_BANNERS
  68. app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
  69. app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
  70. app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
  71. app.state.MODELS = {}
  72. app.state.TOOLS = {}
  73. app.state.FUNCTIONS = {}
  74. app.add_middleware(
  75. CORSMiddleware,
  76. allow_origins=origins,
  77. allow_credentials=True,
  78. allow_methods=["*"],
  79. allow_headers=["*"],
  80. )
  81. app.include_router(configs.router, prefix="/configs", tags=["configs"])
  82. app.include_router(auths.router, prefix="/auths", tags=["auths"])
  83. app.include_router(users.router, prefix="/users", tags=["users"])
  84. app.include_router(chats.router, prefix="/chats", tags=["chats"])
  85. app.include_router(documents.router, prefix="/documents", tags=["documents"])
  86. app.include_router(models.router, prefix="/models", tags=["models"])
  87. app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
  88. app.include_router(memories.router, prefix="/memories", tags=["memories"])
  89. app.include_router(files.router, prefix="/files", tags=["files"])
  90. app.include_router(tools.router, prefix="/tools", tags=["tools"])
  91. app.include_router(functions.router, prefix="/functions", tags=["functions"])
  92. app.include_router(utils.router, prefix="/utils", tags=["utils"])
  93. @app.get("/")
  94. async def get_status():
  95. return {
  96. "status": True,
  97. "auth": WEBUI_AUTH,
  98. "default_models": app.state.config.DEFAULT_MODELS,
  99. "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
  100. }
  101. async def get_pipe_models():
  102. pipes = Functions.get_functions_by_type("pipe", active_only=True)
  103. pipe_models = []
  104. for pipe in pipes:
  105. # Check if function is already loaded
  106. if pipe.id not in app.state.FUNCTIONS:
  107. function_module, function_type, frontmatter = load_function_module_by_id(
  108. pipe.id
  109. )
  110. app.state.FUNCTIONS[pipe.id] = function_module
  111. else:
  112. function_module = app.state.FUNCTIONS[pipe.id]
  113. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  114. valves = Functions.get_function_valves_by_id(pipe.id)
  115. function_module.valves = function_module.Valves(
  116. **(valves if valves else {})
  117. )
  118. # Check if function is a manifold
  119. if hasattr(function_module, "type"):
  120. if function_module.type == "manifold":
  121. manifold_pipes = []
  122. # Check if pipes is a function or a list
  123. if callable(function_module.pipes):
  124. manifold_pipes = function_module.pipes()
  125. else:
  126. manifold_pipes = function_module.pipes
  127. for p in manifold_pipes:
  128. manifold_pipe_id = f'{pipe.id}.{p["id"]}'
  129. manifold_pipe_name = p["name"]
  130. if hasattr(function_module, "name"):
  131. manifold_pipe_name = (
  132. f"{function_module.name}{manifold_pipe_name}"
  133. )
  134. pipe_flag = {"type": pipe.type}
  135. if hasattr(function_module, "ChatValves"):
  136. pipe_flag["valves_spec"] = function_module.ChatValves.schema()
  137. pipe_models.append(
  138. {
  139. "id": manifold_pipe_id,
  140. "name": manifold_pipe_name,
  141. "object": "model",
  142. "created": pipe.created_at,
  143. "owned_by": "openai",
  144. "pipe": pipe_flag,
  145. }
  146. )
  147. else:
  148. pipe_flag = {"type": "pipe"}
  149. if hasattr(function_module, "ChatValves"):
  150. pipe_flag["valves_spec"] = function_module.ChatValves.schema()
  151. pipe_models.append(
  152. {
  153. "id": pipe.id,
  154. "name": pipe.name,
  155. "object": "model",
  156. "created": pipe.created_at,
  157. "owned_by": "openai",
  158. "pipe": pipe_flag,
  159. }
  160. )
  161. return pipe_models
  162. async def generate_function_chat_completion(form_data, user):
  163. model_id = form_data.get("model")
  164. model_info = Models.get_model_by_id(model_id)
  165. metadata = None
  166. if "metadata" in form_data:
  167. metadata = form_data["metadata"]
  168. del form_data["metadata"]
  169. __event_emitter__ = None
  170. __event_call__ = None
  171. __task__ = None
  172. if metadata:
  173. if (
  174. metadata.get("session_id")
  175. and metadata.get("chat_id")
  176. and metadata.get("message_id")
  177. ):
  178. __event_emitter__ = await get_event_emitter(metadata)
  179. __event_call__ = await get_event_call(metadata)
  180. if metadata.get("task"):
  181. __task__ = metadata.get("task")
  182. if model_info:
  183. if model_info.base_model_id:
  184. form_data["model"] = model_info.base_model_id
  185. model_info.params = model_info.params.model_dump()
  186. if model_info.params:
  187. if model_info.params.get("temperature", None) is not None:
  188. form_data["temperature"] = float(model_info.params.get("temperature"))
  189. if model_info.params.get("top_p", None):
  190. form_data["top_p"] = int(model_info.params.get("top_p", None))
  191. if model_info.params.get("max_tokens", None):
  192. form_data["max_tokens"] = int(model_info.params.get("max_tokens", None))
  193. if model_info.params.get("frequency_penalty", None):
  194. form_data["frequency_penalty"] = int(
  195. model_info.params.get("frequency_penalty", None)
  196. )
  197. if model_info.params.get("seed", None):
  198. form_data["seed"] = model_info.params.get("seed", None)
  199. if model_info.params.get("stop", None):
  200. form_data["stop"] = (
  201. [
  202. bytes(stop, "utf-8").decode("unicode_escape")
  203. for stop in model_info.params["stop"]
  204. ]
  205. if model_info.params.get("stop", None)
  206. else None
  207. )
  208. system = model_info.params.get("system", None)
  209. if system:
  210. system = prompt_template(
  211. system,
  212. **(
  213. {
  214. "user_name": user.name,
  215. "user_location": (
  216. user.info.get("location") if user.info else None
  217. ),
  218. }
  219. if user
  220. else {}
  221. ),
  222. )
  223. # Check if the payload already has a system message
  224. # If not, add a system message to the payload
  225. if form_data.get("messages"):
  226. for message in form_data["messages"]:
  227. if message.get("role") == "system":
  228. message["content"] = system + message["content"]
  229. break
  230. else:
  231. form_data["messages"].insert(
  232. 0,
  233. {
  234. "role": "system",
  235. "content": system,
  236. },
  237. )
  238. else:
  239. pass
  240. async def job():
  241. pipe_id = form_data["model"]
  242. if "." in pipe_id:
  243. pipe_id, sub_pipe_id = pipe_id.split(".", 1)
  244. print(pipe_id)
  245. # Check if function is already loaded
  246. if pipe_id not in app.state.FUNCTIONS:
  247. function_module, function_type, frontmatter = load_function_module_by_id(
  248. pipe_id
  249. )
  250. app.state.FUNCTIONS[pipe_id] = function_module
  251. else:
  252. function_module = app.state.FUNCTIONS[pipe_id]
  253. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  254. valves = Functions.get_function_valves_by_id(pipe_id)
  255. function_module.valves = function_module.Valves(
  256. **(valves if valves else {})
  257. )
  258. pipe = function_module.pipe
  259. # Get the signature of the function
  260. sig = inspect.signature(pipe)
  261. params = {"body": form_data}
  262. if "__user__" in sig.parameters:
  263. __user__ = {
  264. "id": user.id,
  265. "email": user.email,
  266. "name": user.name,
  267. "role": user.role,
  268. }
  269. try:
  270. if hasattr(function_module, "UserValves"):
  271. __user__["valves"] = function_module.UserValves(
  272. **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
  273. )
  274. except Exception as e:
  275. print(e)
  276. params = {**params, "__user__": __user__}
  277. if "__event_emitter__" in sig.parameters:
  278. params = {**params, "__event_emitter__": __event_emitter__}
  279. if "__event_call__" in sig.parameters:
  280. params = {**params, "__event_call__": __event_call__}
  281. if "__task__" in sig.parameters:
  282. params = {**params, "__task__": __task__}
  283. if form_data["stream"]:
  284. async def stream_content():
  285. try:
  286. if inspect.iscoroutinefunction(pipe):
  287. res = await pipe(**params)
  288. else:
  289. res = pipe(**params)
  290. # Directly return if the response is a StreamingResponse
  291. if isinstance(res, StreamingResponse):
  292. async for data in res.body_iterator:
  293. yield data
  294. return
  295. if isinstance(res, dict):
  296. yield f"data: {json.dumps(res)}\n\n"
  297. return
  298. except Exception as e:
  299. print(f"Error: {e}")
  300. yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
  301. return
  302. if isinstance(res, str):
  303. message = stream_message_template(form_data["model"], res)
  304. yield f"data: {json.dumps(message)}\n\n"
  305. if isinstance(res, Iterator):
  306. for line in res:
  307. if isinstance(line, BaseModel):
  308. line = line.model_dump_json()
  309. line = f"data: {line}"
  310. if isinstance(line, dict):
  311. line = f"data: {json.dumps(line)}"
  312. try:
  313. line = line.decode("utf-8")
  314. except:
  315. pass
  316. if line.startswith("data:"):
  317. yield f"{line}\n\n"
  318. else:
  319. line = stream_message_template(form_data["model"], line)
  320. yield f"data: {json.dumps(line)}\n\n"
  321. if isinstance(res, str) or isinstance(res, Generator):
  322. finish_message = {
  323. "id": f"{form_data['model']}-{str(uuid.uuid4())}",
  324. "object": "chat.completion.chunk",
  325. "created": int(time.time()),
  326. "model": form_data["model"],
  327. "choices": [
  328. {
  329. "index": 0,
  330. "delta": {},
  331. "logprobs": None,
  332. "finish_reason": "stop",
  333. }
  334. ],
  335. }
  336. yield f"data: {json.dumps(finish_message)}\n\n"
  337. yield f"data: [DONE]"
  338. if isinstance(res, AsyncGenerator):
  339. async for line in res:
  340. if isinstance(line, BaseModel):
  341. line = line.model_dump_json()
  342. line = f"data: {line}"
  343. if isinstance(line, dict):
  344. line = f"data: {json.dumps(line)}"
  345. try:
  346. line = line.decode("utf-8")
  347. except:
  348. pass
  349. if line.startswith("data:"):
  350. yield f"{line}\n\n"
  351. else:
  352. line = stream_message_template(form_data["model"], line)
  353. yield f"data: {json.dumps(line)}\n\n"
  354. return StreamingResponse(stream_content(), media_type="text/event-stream")
  355. else:
  356. try:
  357. if inspect.iscoroutinefunction(pipe):
  358. res = await pipe(**params)
  359. else:
  360. res = pipe(**params)
  361. if isinstance(res, StreamingResponse):
  362. return res
  363. except Exception as e:
  364. print(f"Error: {e}")
  365. return {"error": {"detail": str(e)}}
  366. if isinstance(res, dict):
  367. return res
  368. elif isinstance(res, BaseModel):
  369. return res.model_dump()
  370. else:
  371. message = ""
  372. if isinstance(res, str):
  373. message = res
  374. elif isinstance(res, Generator):
  375. for stream in res:
  376. message = f"{message}{stream}"
  377. elif isinstance(res, AsyncGenerator):
  378. async for stream in res:
  379. message = f"{message}{stream}"
  380. return {
  381. "id": f"{form_data['model']}-{str(uuid.uuid4())}",
  382. "object": "chat.completion",
  383. "created": int(time.time()),
  384. "model": form_data["model"],
  385. "choices": [
  386. {
  387. "index": 0,
  388. "message": {
  389. "role": "assistant",
  390. "content": message,
  391. },
  392. "logprobs": None,
  393. "finish_reason": "stop",
  394. }
  395. ],
  396. }
  397. return await job()