main.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. from fastapi import FastAPI, Depends
  2. from fastapi.routing import APIRoute
  3. from fastapi.responses import StreamingResponse
  4. from fastapi.middleware.cors import CORSMiddleware
  5. from starlette.middleware.sessions import SessionMiddleware
  6. from sqlalchemy.orm import Session
  7. from apps.webui.routers import (
  8. auths,
  9. users,
  10. chats,
  11. documents,
  12. tools,
  13. models,
  14. prompts,
  15. configs,
  16. memories,
  17. utils,
  18. files,
  19. functions,
  20. )
  21. from apps.webui.models.functions import Functions
  22. from apps.webui.models.models import Models
  23. from apps.webui.utils import load_function_module_by_id
  24. from utils.misc import stream_message_template
  25. from utils.task import prompt_template
  26. from config import (
  27. WEBUI_BUILD_HASH,
  28. SHOW_ADMIN_DETAILS,
  29. ADMIN_EMAIL,
  30. WEBUI_AUTH,
  31. DEFAULT_MODELS,
  32. DEFAULT_PROMPT_SUGGESTIONS,
  33. DEFAULT_USER_ROLE,
  34. ENABLE_SIGNUP,
  35. USER_PERMISSIONS,
  36. WEBHOOK_URL,
  37. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  38. WEBUI_AUTH_TRUSTED_NAME_HEADER,
  39. JWT_EXPIRES_IN,
  40. WEBUI_BANNERS,
  41. ENABLE_COMMUNITY_SHARING,
  42. AppConfig,
  43. OAUTH_USERNAME_CLAIM,
  44. OAUTH_PICTURE_CLAIM,
  45. )
  46. import inspect
  47. import uuid
  48. import time
  49. import json
  50. from typing import Iterator, Generator, Optional
  51. from pydantic import BaseModel
  # Web UI API sub-application; routers and shared state are attached to it below.
  app = FastAPI()

  # NOTE(review): wildcard origins combined with allow_credentials=True lets any
  # site send credentialed cross-origin requests — confirm this is intended.
  origins = ["*"]

  app.state.config = AppConfig()
  # Seed the config object from the imported settings, in a fixed order.
  for _config_key, _config_value in {
      "ENABLE_SIGNUP": ENABLE_SIGNUP,
      "JWT_EXPIRES_IN": JWT_EXPIRES_IN,
      "SHOW_ADMIN_DETAILS": SHOW_ADMIN_DETAILS,
      "ADMIN_EMAIL": ADMIN_EMAIL,
      "DEFAULT_MODELS": DEFAULT_MODELS,
      "DEFAULT_PROMPT_SUGGESTIONS": DEFAULT_PROMPT_SUGGESTIONS,
      "DEFAULT_USER_ROLE": DEFAULT_USER_ROLE,
      "USER_PERMISSIONS": USER_PERMISSIONS,
      "WEBHOOK_URL": WEBHOOK_URL,
      "BANNERS": WEBUI_BANNERS,
      "ENABLE_COMMUNITY_SHARING": ENABLE_COMMUNITY_SHARING,
      "OAUTH_USERNAME_CLAIM": OAUTH_USERNAME_CLAIM,
      "OAUTH_PICTURE_CLAIM": OAUTH_PICTURE_CLAIM,
  }.items():
      setattr(app.state.config, _config_key, _config_value)
  del _config_key, _config_value

  # Trusted-header auth settings are kept on app.state itself, not on the config.
  app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
  app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER

  # Empty in-memory registries; FUNCTIONS is filled lazily with loaded
  # function modules keyed by function id. Each gets its own dict.
  app.state.MODELS = {}
  app.state.TOOLS = {}
  app.state.FUNCTIONS = {}

  app.add_middleware(
      CORSMiddleware,
      allow_origins=origins,
      allow_methods=["*"],
      allow_headers=["*"],
      allow_credentials=True,
  )

  # Mount each router under "/<name>" with a matching OpenAPI tag.
  for _router_module, _router_name in (
      (configs, "configs"),
      (auths, "auths"),
      (users, "users"),
      (chats, "chats"),
      (documents, "documents"),
      (models, "models"),
      (prompts, "prompts"),
      (memories, "memories"),
      (files, "files"),
      (tools, "tools"),
      (functions, "functions"),
      (utils, "utils"),
  ):
      app.include_router(
          _router_module.router, prefix=f"/{_router_name}", tags=[_router_name]
      )
  del _router_module, _router_name
  @app.get("/")
  async def get_status():
      """Report that the API is up, whether auth is enabled, and UI defaults."""
      config = app.state.config
      status = {"status": True, "auth": WEBUI_AUTH}
      status["default_models"] = config.DEFAULT_MODELS
      status["default_prompt_suggestions"] = config.DEFAULT_PROMPT_SUGGESTIONS
      return status
  async def get_pipe_models():
      """Return OpenAI-style model entries for every active "pipe" function.

      Each function module is loaded on first use and cached in
      ``app.state.FUNCTIONS``; if it defines both ``valves`` and ``Valves``, its
      valves are rebuilt from the stored configuration on every call.

      A module whose ``type`` attribute equals ``"manifold"`` contributes one
      entry per item of its ``pipes`` (a list, or a callable returning one), with
      ids of the form ``"<function_id>.<item id>"``. Every other module
      contributes a single entry whose id is the function id.

      Returns:
          list[dict]: model descriptors with ``id``, ``name``, ``object``,
          ``created``, ``owned_by`` and ``pipe`` keys.
      """
      pipes = Functions.get_functions_by_type("pipe", active_only=True)
      pipe_models = []
      for pipe in pipes:
          # Reuse the cached module when available; otherwise load and cache it.
          if pipe.id not in app.state.FUNCTIONS:
              function_module, _, _ = load_function_module_by_id(pipe.id)
              app.state.FUNCTIONS[pipe.id] = function_module
          else:
              function_module = app.state.FUNCTIONS[pipe.id]

          if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
              valves = Functions.get_function_valves_by_id(pipe.id)
              function_module.valves = function_module.Valves(
                  **(valves if valves else {})
              )

          # Only an explicit "manifold" fans out; any other module (with or
          # without a ``type`` attribute) is exposed as a single model.
          if getattr(function_module, "type", None) == "manifold":
              manifold_pipes = (
                  function_module.pipes()
                  if callable(function_module.pipes)
                  else function_module.pipes
              )
              for sub_pipe in manifold_pipes or []:
                  sub_pipe_name = sub_pipe["name"]
                  if hasattr(function_module, "name"):
                      # The module name is used verbatim as a prefix (no separator).
                      sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
                  pipe_models.append(
                      {
                          "id": f'{pipe.id}.{sub_pipe["id"]}',
                          "name": sub_pipe_name,
                          "object": "model",
                          "created": pipe.created_at,
                          "owned_by": "openai",
                          "pipe": {"type": pipe.type},
                      }
                  )
          else:
              pipe_models.append(
                  {
                      "id": pipe.id,
                      "name": pipe.name,
                      "object": "model",
                      "created": pipe.created_at,
                      "owned_by": "openai",
                      "pipe": {"type": "pipe"},
                  }
              )
      return pipe_models
  def _apply_model_params(form_data, model_params, user):
      """Overlay a stored model's default generation params onto ``form_data``.

      ``model_params`` is the dict dumped from the model's ``params``. Values it
      holds overwrite the request's own values. A ``system`` prompt is rendered
      with ``prompt_template``. It is then prepended to the first system message
      in ``form_data["messages"]``, or inserted as a new leading system message
      when none exists.
      """
      if model_params.get("temperature") is not None:
          form_data["temperature"] = float(model_params["temperature"])
      # top_p / frequency_penalty are fractional: int() would truncate 0.9 -> 0.
      if model_params.get("top_p") is not None:
          form_data["top_p"] = float(model_params["top_p"])
      if model_params.get("max_tokens"):
          form_data["max_tokens"] = int(model_params["max_tokens"])
      if model_params.get("frequency_penalty") is not None:
          form_data["frequency_penalty"] = float(model_params["frequency_penalty"])
      # 0 is a valid seed, so compare against None instead of truthiness.
      if model_params.get("seed") is not None:
          form_data["seed"] = model_params["seed"]
      if model_params.get("stop"):
          # Turn escape sequences stored as text (e.g. "\\n") into real characters.
          form_data["stop"] = [
              bytes(stop, "utf-8").decode("unicode_escape")
              for stop in model_params["stop"]
          ]

      system = model_params.get("system")
      if not system:
          return
      template_vars = (
          {
              "user_name": user.name,
              "user_location": user.info.get("location") if user.info else None,
          }
          if user
          else {}
      )
      system = prompt_template(system, **template_vars)

      messages = form_data.get("messages")
      if not messages:
          return
      for message in messages:
          if message.get("role") == "system":
              message["content"] = system + message["content"]
              break
      else:
          messages.insert(0, {"role": "system", "content": system})


  async def _invoke_pipe(pipe, params):
      """Call ``pipe(**params)``, awaiting it when it is a coroutine function."""
      if inspect.iscoroutinefunction(pipe):
          return await pipe(**params)
      return pipe(**params)


  def _pipe_item_to_sse(model, item):
      """Render one item produced by a streaming pipe as a terminated SSE event."""
      if isinstance(item, BaseModel):
          item = f"data: {item.model_dump_json()}"
      elif isinstance(item, dict):
          item = f"data: {json.dumps(item)}"
      elif isinstance(item, (bytes, bytearray)):
          item = item.decode("utf-8")
      elif not isinstance(item, str):
          item = str(item)
      if item.startswith("data:"):
          return f"{item}\n\n"
      # Plain text is wrapped into a completion chunk via stream_message_template.
      return f"data: {json.dumps(stream_message_template(model, item))}\n\n"


  def _stream_finish_chunk(model):
      """Build the final ``chat.completion.chunk`` carrying ``finish_reason="stop"``."""
      return {
          "id": f"{model}-{str(uuid.uuid4())}",
          "object": "chat.completion.chunk",
          "created": int(time.time()),
          "model": model,
          "choices": [
              {
                  "index": 0,
                  "delta": {},
                  "logprobs": None,
                  "finish_reason": "stop",
              }
          ],
      }


  async def generate_function_chat_completion(form_data, user):
      """Run a chat completion through a "pipe" function module.

      Args:
          form_data: OpenAI-style request dict, mutated in place:
              - ``metadata`` is removed;
              - ``model`` may be replaced by the stored model's ``base_model_id``;
              - stored generation params and system prompt may be applied.
          user: requesting user; ``id``/``email``/``name``/``role`` are passed to
              pipes whose signature declares ``__user__``.

      Returns:
          If ``form_data.get("stream")`` is truthy, a ``StreamingResponse`` of SSE
          events. The stream ends with a finish chunk and ``data: [DONE]`` for
          text / iterator output, or with an error event if the pipe raises.

          Otherwise one of:
          - the pipe's own ``StreamingResponse`` or dict;
          - a dumped pydantic model;
          - ``{"error": {"detail": ...}}`` when the pipe call raises;
          - a ``chat.completion`` dict built from text or iterator output.
      """
      from collections.abc import AsyncIterator

      model_info = Models.get_model_by_id(form_data.get("model"))

      # Metadata is server-side only and is never forwarded to the pipe.
      form_data.pop("metadata", None)

      if model_info:
          if model_info.base_model_id:
              form_data["model"] = model_info.base_model_id
          # Work on a plain-dict copy instead of overwriting model_info.params.
          model_params = model_info.params.model_dump() if model_info.params else {}
          if model_params:
              _apply_model_params(form_data, model_params, user)

      model = form_data["model"]
      # Manifold sub-models are addressed "<function_id>.<sub_id>"; the module is
      # cached under the function id only.
      pipe_id = model.split(".", 1)[0]

      if pipe_id not in app.state.FUNCTIONS:
          function_module, _, _ = load_function_module_by_id(pipe_id)
          app.state.FUNCTIONS[pipe_id] = function_module
      else:
          function_module = app.state.FUNCTIONS[pipe_id]

      if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
          valves = Functions.get_function_valves_by_id(pipe_id)
          function_module.valves = function_module.Valves(**(valves if valves else {}))

      pipe = function_module.pipe

      # The optional ``__user__`` argument is only passed to pipes that accept it.
      params = {"body": form_data}
      if "__user__" in inspect.signature(pipe).parameters:
          user_context = {
              "id": user.id,
              "email": user.email,
              "name": user.name,
              "role": user.role,
          }
          try:
              if hasattr(function_module, "UserValves"):
                  user_valves = Functions.get_user_valves_by_id_and_user_id(
                      pipe_id, user.id
                  )
                  # Users without saved valves get the UserValves defaults.
                  user_context["valves"] = function_module.UserValves(
                      **(user_valves if user_valves else {})
                  )
          except Exception as e:
              print(e)
          params["__user__"] = user_context

      if form_data.get("stream"):

          async def stream_content():
              try:
                  res = await _invoke_pipe(pipe, params)

                  # A pipe-provided StreamingResponse is relayed untouched.
                  if isinstance(res, StreamingResponse):
                      async for data in res.body_iterator:
                          yield data
                      return
                  if isinstance(res, dict):
                      yield f"data: {json.dumps(res)}\n\n"
                      return

                  if isinstance(res, str):
                      message = stream_message_template(model, res)
                      yield f"data: {json.dumps(message)}\n\n"
                  elif isinstance(res, Iterator):
                      for item in res:
                          yield _pipe_item_to_sse(model, item)
                  elif isinstance(res, AsyncIterator):
                      async for item in res:
                          yield _pipe_item_to_sse(model, item)

                  if isinstance(res, (str, Iterator, AsyncIterator)):
                      yield f"data: {json.dumps(_stream_finish_chunk(model))}\n\n"
              except Exception as e:
                  # Covers failures both in the pipe call and while iterating
                  # its output, so the client always receives an error event.
                  print(f"Error: {e}")
                  yield f"data: {json.dumps({'error': {'detail': str(e)}})}\n\n"
                  return

              # SSE events must be terminated by a blank line.
              yield "data: [DONE]\n\n"

          return StreamingResponse(stream_content(), media_type="text/event-stream")

      try:
          res = await _invoke_pipe(pipe, params)
      except Exception as e:
          print(f"Error: {e}")
          return {"error": {"detail": str(e)}}

      if isinstance(res, (StreamingResponse, dict)):
          return res
      if isinstance(res, BaseModel):
          return res.model_dump()

      message = ""
      if isinstance(res, str):
          message = res
      elif isinstance(res, Iterator):
          message = "".join(f"{chunk}" for chunk in res)
      elif isinstance(res, AsyncIterator):
          message = "".join([f"{chunk}" async for chunk in res])

      return {
          "id": f"{model}-{str(uuid.uuid4())}",
          "object": "chat.completion",
          "created": int(time.time()),
          "model": model,
          "choices": [
              {
                  "index": 0,
                  "message": {
                      "role": "assistant",
                      "content": message,
                  },
                  "logprobs": None,
                  "finish_reason": "stop",
              }
          ],
      }