main.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. from fastapi import FastAPI, Depends
  2. from fastapi.routing import APIRoute
  3. from fastapi.responses import StreamingResponse
  4. from fastapi.middleware.cors import CORSMiddleware
  5. from starlette.middleware.sessions import SessionMiddleware
  6. from sqlalchemy.orm import Session
  7. from apps.webui.routers import (
  8. auths,
  9. users,
  10. chats,
  11. documents,
  12. tools,
  13. models,
  14. prompts,
  15. configs,
  16. memories,
  17. utils,
  18. files,
  19. functions,
  20. )
  21. from apps.webui.models.functions import Functions
  22. from apps.webui.models.models import Models
  23. from apps.webui.utils import load_function_module_by_id
  24. from utils.misc import stream_message_template
  25. from utils.task import prompt_template
  26. from config import (
  27. WEBUI_BUILD_HASH,
  28. SHOW_ADMIN_DETAILS,
  29. ADMIN_EMAIL,
  30. WEBUI_AUTH,
  31. DEFAULT_MODELS,
  32. DEFAULT_PROMPT_SUGGESTIONS,
  33. DEFAULT_USER_ROLE,
  34. ENABLE_SIGNUP,
  35. USER_PERMISSIONS,
  36. WEBHOOK_URL,
  37. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  38. WEBUI_AUTH_TRUSTED_NAME_HEADER,
  39. JWT_EXPIRES_IN,
  40. WEBUI_BANNERS,
  41. ENABLE_COMMUNITY_SHARING,
  42. AppConfig,
  43. OAUTH_USERNAME_CLAIM,
  44. OAUTH_PICTURE_CLAIM,
  45. )
  46. import inspect
  47. import uuid
  48. import time
  49. import json
  50. from typing import Iterator, Generator
  51. from pydantic import BaseModel
  52. app = FastAPI()
  53. origins = ["*"]
  54. app.state.config = AppConfig()
  55. app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
  56. app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
  57. app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
  58. app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
  59. app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
  60. app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
  61. app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
  62. app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
  63. app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
  64. app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
  65. app.state.config.WEBHOOK_URL = WEBHOOK_URL
  66. app.state.config.BANNERS = WEBUI_BANNERS
  67. app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
  68. app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
  69. app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
  70. app.state.MODELS = {}
  71. app.state.TOOLS = {}
  72. app.state.FUNCTIONS = {}
  73. app.add_middleware(
  74. CORSMiddleware,
  75. allow_origins=origins,
  76. allow_credentials=True,
  77. allow_methods=["*"],
  78. allow_headers=["*"],
  79. )
  80. app.include_router(configs.router, prefix="/configs", tags=["configs"])
  81. app.include_router(auths.router, prefix="/auths", tags=["auths"])
  82. app.include_router(users.router, prefix="/users", tags=["users"])
  83. app.include_router(chats.router, prefix="/chats", tags=["chats"])
  84. app.include_router(documents.router, prefix="/documents", tags=["documents"])
  85. app.include_router(models.router, prefix="/models", tags=["models"])
  86. app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
  87. app.include_router(memories.router, prefix="/memories", tags=["memories"])
  88. app.include_router(files.router, prefix="/files", tags=["files"])
  89. app.include_router(tools.router, prefix="/tools", tags=["tools"])
  90. app.include_router(functions.router, prefix="/functions", tags=["functions"])
  91. app.include_router(utils.router, prefix="/utils", tags=["utils"])
  92. @app.get("/")
  93. async def get_status():
  94. return {
  95. "status": True,
  96. "auth": WEBUI_AUTH,
  97. "default_models": app.state.config.DEFAULT_MODELS,
  98. "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
  99. }
  100. async def get_pipe_models():
  101. pipes = Functions.get_functions_by_type("pipe", active_only=True)
  102. pipe_models = []
  103. for pipe in pipes:
  104. # Check if function is already loaded
  105. if pipe.id not in app.state.FUNCTIONS:
  106. function_module, function_type, frontmatter = load_function_module_by_id(
  107. pipe.id
  108. )
  109. app.state.FUNCTIONS[pipe.id] = function_module
  110. else:
  111. function_module = app.state.FUNCTIONS[pipe.id]
  112. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  113. print(f"Getting valves for {pipe.id}")
  114. valves = Functions.get_function_valves_by_id(pipe.id)
  115. function_module.valves = function_module.Valves(
  116. **(valves if valves else {})
  117. )
  118. # Check if function is a manifold
  119. if hasattr(function_module, "type"):
  120. if function_module.type == "manifold":
  121. manifold_pipes = []
  122. # Check if pipes is a function or a list
  123. if callable(function_module.pipes):
  124. manifold_pipes = function_module.pipes()
  125. else:
  126. manifold_pipes = function_module.pipes
  127. for p in manifold_pipes:
  128. manifold_pipe_id = f'{pipe.id}.{p["id"]}'
  129. manifold_pipe_name = p["name"]
  130. if hasattr(function_module, "name"):
  131. manifold_pipe_name = (
  132. f"{function_module.name}{manifold_pipe_name}"
  133. )
  134. pipe_models.append(
  135. {
  136. "id": manifold_pipe_id,
  137. "name": manifold_pipe_name,
  138. "object": "model",
  139. "created": pipe.created_at,
  140. "owned_by": "openai",
  141. "pipe": {"type": pipe.type},
  142. }
  143. )
  144. else:
  145. pipe_models.append(
  146. {
  147. "id": pipe.id,
  148. "name": pipe.name,
  149. "object": "model",
  150. "created": pipe.created_at,
  151. "owned_by": "openai",
  152. "pipe": {"type": "pipe"},
  153. }
  154. )
  155. return pipe_models
  156. async def generate_function_chat_completion(form_data, user):
  157. model_id = form_data.get("model")
  158. model_info = Models.get_model_by_id(model_id)
  159. if model_info:
  160. if model_info.base_model_id:
  161. form_data["model"] = model_info.base_model_id
  162. model_info.params = model_info.params.model_dump()
  163. if model_info.params:
  164. if model_info.params.get("temperature", None) is not None:
  165. form_data["temperature"] = float(model_info.params.get("temperature"))
  166. if model_info.params.get("top_p", None):
  167. form_data["top_p"] = int(model_info.params.get("top_p", None))
  168. if model_info.params.get("max_tokens", None):
  169. form_data["max_tokens"] = int(model_info.params.get("max_tokens", None))
  170. if model_info.params.get("frequency_penalty", None):
  171. form_data["frequency_penalty"] = int(
  172. model_info.params.get("frequency_penalty", None)
  173. )
  174. if model_info.params.get("seed", None):
  175. form_data["seed"] = model_info.params.get("seed", None)
  176. if model_info.params.get("stop", None):
  177. form_data["stop"] = (
  178. [
  179. bytes(stop, "utf-8").decode("unicode_escape")
  180. for stop in model_info.params["stop"]
  181. ]
  182. if model_info.params.get("stop", None)
  183. else None
  184. )
  185. system = model_info.params.get("system", None)
  186. if system:
  187. system = prompt_template(
  188. system,
  189. **(
  190. {
  191. "user_name": user.name,
  192. "user_location": (
  193. user.info.get("location") if user.info else None
  194. ),
  195. }
  196. if user
  197. else {}
  198. ),
  199. )
  200. # Check if the payload already has a system message
  201. # If not, add a system message to the payload
  202. if form_data.get("messages"):
  203. for message in form_data["messages"]:
  204. if message.get("role") == "system":
  205. message["content"] = system + message["content"]
  206. break
  207. else:
  208. form_data["messages"].insert(
  209. 0,
  210. {
  211. "role": "system",
  212. "content": system,
  213. },
  214. )
  215. else:
  216. pass
  217. async def job():
  218. pipe_id = form_data["model"]
  219. if "." in pipe_id:
  220. pipe_id, sub_pipe_id = pipe_id.split(".", 1)
  221. print(pipe_id)
  222. # Check if function is already loaded
  223. if pipe_id not in app.state.FUNCTIONS:
  224. function_module, function_type, frontmatter = load_function_module_by_id(
  225. pipe_id
  226. )
  227. app.state.FUNCTIONS[pipe_id] = function_module
  228. else:
  229. function_module = app.state.FUNCTIONS[pipe_id]
  230. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  231. valves = Functions.get_function_valves_by_id(pipe_id)
  232. function_module.valves = function_module.Valves(
  233. **(valves if valves else {})
  234. )
  235. pipe = function_module.pipe
  236. # Get the signature of the function
  237. sig = inspect.signature(pipe)
  238. params = {"body": form_data}
  239. if "__user__" in sig.parameters:
  240. __user__ = {
  241. "id": user.id,
  242. "email": user.email,
  243. "name": user.name,
  244. "role": user.role,
  245. }
  246. try:
  247. if hasattr(function_module, "UserValves"):
  248. __user__["valves"] = function_module.UserValves(
  249. **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
  250. )
  251. except Exception as e:
  252. print(e)
  253. params = {**params, "__user__": __user__}
  254. if form_data["stream"]:
  255. async def stream_content():
  256. try:
  257. if inspect.iscoroutinefunction(pipe):
  258. res = await pipe(**params)
  259. else:
  260. res = pipe(**params)
  261. # Directly return if the response is a StreamingResponse
  262. if isinstance(res, StreamingResponse):
  263. async for data in res.body_iterator:
  264. yield data
  265. return
  266. if isinstance(res, dict):
  267. yield f"data: {json.dumps(res)}\n\n"
  268. return
  269. except Exception as e:
  270. print(f"Error: {e}")
  271. yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
  272. return
  273. if isinstance(res, str):
  274. message = stream_message_template(form_data["model"], res)
  275. yield f"data: {json.dumps(message)}\n\n"
  276. if isinstance(res, Iterator):
  277. for line in res:
  278. if isinstance(line, BaseModel):
  279. line = line.model_dump_json()
  280. line = f"data: {line}"
  281. if isinstance(line, dict):
  282. line = f"data: {json.dumps(line)}"
  283. try:
  284. line = line.decode("utf-8")
  285. except:
  286. pass
  287. if line.startswith("data:"):
  288. yield f"{line}\n\n"
  289. else:
  290. line = stream_message_template(form_data["model"], line)
  291. yield f"data: {json.dumps(line)}\n\n"
  292. if isinstance(res, str) or isinstance(res, Generator):
  293. finish_message = {
  294. "id": f"{form_data['model']}-{str(uuid.uuid4())}",
  295. "object": "chat.completion.chunk",
  296. "created": int(time.time()),
  297. "model": form_data["model"],
  298. "choices": [
  299. {
  300. "index": 0,
  301. "delta": {},
  302. "logprobs": None,
  303. "finish_reason": "stop",
  304. }
  305. ],
  306. }
  307. yield f"data: {json.dumps(finish_message)}\n\n"
  308. yield f"data: [DONE]"
  309. return StreamingResponse(stream_content(), media_type="text/event-stream")
  310. else:
  311. try:
  312. if inspect.iscoroutinefunction(pipe):
  313. res = await pipe(**params)
  314. else:
  315. res = pipe(**params)
  316. if isinstance(res, StreamingResponse):
  317. return res
  318. except Exception as e:
  319. print(f"Error: {e}")
  320. return {"error": {"detail": str(e)}}
  321. if isinstance(res, dict):
  322. return res
  323. elif isinstance(res, BaseModel):
  324. return res.model_dump()
  325. else:
  326. message = ""
  327. if isinstance(res, str):
  328. message = res
  329. if isinstance(res, Generator):
  330. for stream in res:
  331. message = f"{message}{stream}"
  332. return {
  333. "id": f"{form_data['model']}-{str(uuid.uuid4())}",
  334. "object": "chat.completion",
  335. "created": int(time.time()),
  336. "model": form_data["model"],
  337. "choices": [
  338. {
  339. "index": 0,
  340. "message": {
  341. "role": "assistant",
  342. "content": message,
  343. },
  344. "logprobs": None,
  345. "finish_reason": "stop",
  346. }
  347. ],
  348. }
  349. return await job()