main.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. from fastapi import FastAPI, Depends
  2. from fastapi.routing import APIRoute
  3. from fastapi.responses import StreamingResponse
  4. from fastapi.middleware.cors import CORSMiddleware
  5. from starlette.middleware.sessions import SessionMiddleware
  6. from apps.webui.routers import (
  7. auths,
  8. users,
  9. chats,
  10. documents,
  11. tools,
  12. models,
  13. prompts,
  14. configs,
  15. memories,
  16. utils,
  17. files,
  18. functions,
  19. )
  20. from apps.webui.models.functions import Functions
  21. from apps.webui.utils import load_function_module_by_id
  22. from utils.misc import stream_message_template
  23. from config import (
  24. WEBUI_BUILD_HASH,
  25. SHOW_ADMIN_DETAILS,
  26. ADMIN_EMAIL,
  27. WEBUI_AUTH,
  28. DEFAULT_MODELS,
  29. DEFAULT_PROMPT_SUGGESTIONS,
  30. DEFAULT_USER_ROLE,
  31. ENABLE_SIGNUP,
  32. USER_PERMISSIONS,
  33. WEBHOOK_URL,
  34. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  35. WEBUI_AUTH_TRUSTED_NAME_HEADER,
  36. JWT_EXPIRES_IN,
  37. WEBUI_BANNERS,
  38. ENABLE_COMMUNITY_SHARING,
  39. AppConfig,
  40. OAUTH_USERNAME_CLAIM,
  41. OAUTH_PICTURE_CLAIM,
  42. )
  43. import inspect
  44. import uuid
  45. import time
  46. import json
  47. from typing import Iterator, Generator
  48. from pydantic import BaseModel
  49. app = FastAPI()
  50. origins = ["*"]
  51. app.state.config = AppConfig()
  52. app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
  53. app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
  54. app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
  55. app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
  56. app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
  57. app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
  58. app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
  59. app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
  60. app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
  61. app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
  62. app.state.config.WEBHOOK_URL = WEBHOOK_URL
  63. app.state.config.BANNERS = WEBUI_BANNERS
  64. app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
  65. app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
  66. app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
  67. app.state.MODELS = {}
  68. app.state.TOOLS = {}
  69. app.state.FUNCTIONS = {}
  70. app.add_middleware(
  71. CORSMiddleware,
  72. allow_origins=origins,
  73. allow_credentials=True,
  74. allow_methods=["*"],
  75. allow_headers=["*"],
  76. )
  77. app.include_router(configs.router, prefix="/configs", tags=["configs"])
  78. app.include_router(auths.router, prefix="/auths", tags=["auths"])
  79. app.include_router(users.router, prefix="/users", tags=["users"])
  80. app.include_router(chats.router, prefix="/chats", tags=["chats"])
  81. app.include_router(documents.router, prefix="/documents", tags=["documents"])
  82. app.include_router(models.router, prefix="/models", tags=["models"])
  83. app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
  84. app.include_router(memories.router, prefix="/memories", tags=["memories"])
  85. app.include_router(files.router, prefix="/files", tags=["files"])
  86. app.include_router(tools.router, prefix="/tools", tags=["tools"])
  87. app.include_router(functions.router, prefix="/functions", tags=["functions"])
  88. app.include_router(utils.router, prefix="/utils", tags=["utils"])
  89. @app.get("/")
  90. async def get_status():
  91. return {
  92. "status": True,
  93. "auth": WEBUI_AUTH,
  94. "default_models": app.state.config.DEFAULT_MODELS,
  95. "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
  96. }
  97. async def get_pipe_models():
  98. pipes = Functions.get_functions_by_type("pipe", active_only=True)
  99. pipe_models = []
  100. for pipe in pipes:
  101. # Check if function is already loaded
  102. if pipe.id not in app.state.FUNCTIONS:
  103. function_module, function_type, frontmatter = load_function_module_by_id(
  104. pipe.id
  105. )
  106. app.state.FUNCTIONS[pipe.id] = function_module
  107. else:
  108. function_module = app.state.FUNCTIONS[pipe.id]
  109. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  110. print(f"Getting valves for {pipe.id}")
  111. valves = Functions.get_function_valves_by_id(pipe.id)
  112. function_module.valves = function_module.Valves(
  113. **(valves if valves else {})
  114. )
  115. # Check if function is a manifold
  116. if hasattr(function_module, "type"):
  117. if function_module.type == "manifold":
  118. manifold_pipes = []
  119. # Check if pipes is a function or a list
  120. if callable(function_module.pipes):
  121. manifold_pipes = function_module.pipes()
  122. else:
  123. manifold_pipes = function_module.pipes
  124. for p in manifold_pipes:
  125. manifold_pipe_id = f'{pipe.id}.{p["id"]}'
  126. manifold_pipe_name = p["name"]
  127. if hasattr(function_module, "name"):
  128. manifold_pipe_name = (
  129. f"{function_module.name}{manifold_pipe_name}"
  130. )
  131. pipe_models.append(
  132. {
  133. "id": manifold_pipe_id,
  134. "name": manifold_pipe_name,
  135. "object": "model",
  136. "created": pipe.created_at,
  137. "owned_by": "openai",
  138. "pipe": {"type": pipe.type},
  139. }
  140. )
  141. else:
  142. pipe_models.append(
  143. {
  144. "id": pipe.id,
  145. "name": pipe.name,
  146. "object": "model",
  147. "created": pipe.created_at,
  148. "owned_by": "openai",
  149. "pipe": {"type": "pipe"},
  150. }
  151. )
  152. return pipe_models
  153. async def generate_function_chat_completion(form_data, user):
  154. async def job():
  155. pipe_id = form_data["model"]
  156. if "." in pipe_id:
  157. pipe_id, sub_pipe_id = pipe_id.split(".", 1)
  158. print(pipe_id)
  159. # Check if function is already loaded
  160. if pipe_id not in app.state.FUNCTIONS:
  161. function_module, function_type, frontmatter = load_function_module_by_id(
  162. pipe_id
  163. )
  164. app.state.FUNCTIONS[pipe_id] = function_module
  165. else:
  166. function_module = app.state.FUNCTIONS[pipe_id]
  167. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  168. valves = Functions.get_function_valves_by_id(pipe_id)
  169. function_module.valves = function_module.Valves(
  170. **(valves if valves else {})
  171. )
  172. pipe = function_module.pipe
  173. # Get the signature of the function
  174. sig = inspect.signature(pipe)
  175. params = {"body": form_data}
  176. if "__user__" in sig.parameters:
  177. __user__ = {
  178. "id": user.id,
  179. "email": user.email,
  180. "name": user.name,
  181. "role": user.role,
  182. }
  183. try:
  184. if hasattr(function_module, "UserValves"):
  185. __user__["valves"] = function_module.UserValves(
  186. **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
  187. )
  188. except Exception as e:
  189. print(e)
  190. params = {**params, "__user__": __user__}
  191. if form_data["stream"]:
  192. async def stream_content():
  193. try:
  194. if inspect.iscoroutinefunction(pipe):
  195. res = await pipe(**params)
  196. else:
  197. res = pipe(**params)
  198. # Directly return if the response is a StreamingResponse
  199. if isinstance(res, StreamingResponse):
  200. async for data in res.body_iterator:
  201. yield data
  202. return
  203. if isinstance(res, dict):
  204. yield f"data: {json.dumps(res)}\n\n"
  205. return
  206. except Exception as e:
  207. print(f"Error: {e}")
  208. yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
  209. return
  210. if isinstance(res, str):
  211. message = stream_message_template(form_data["model"], res)
  212. yield f"data: {json.dumps(message)}\n\n"
  213. if isinstance(res, Iterator):
  214. for line in res:
  215. if isinstance(line, BaseModel):
  216. line = line.model_dump_json()
  217. line = f"data: {line}"
  218. if isinstance(line, dict):
  219. line = f"data: {json.dumps(line)}"
  220. try:
  221. line = line.decode("utf-8")
  222. except:
  223. pass
  224. if line.startswith("data:"):
  225. yield f"{line}\n\n"
  226. else:
  227. line = stream_message_template(form_data["model"], line)
  228. yield f"data: {json.dumps(line)}\n\n"
  229. if isinstance(res, str) or isinstance(res, Generator):
  230. finish_message = {
  231. "id": f"{form_data['model']}-{str(uuid.uuid4())}",
  232. "object": "chat.completion.chunk",
  233. "created": int(time.time()),
  234. "model": form_data["model"],
  235. "choices": [
  236. {
  237. "index": 0,
  238. "delta": {},
  239. "logprobs": None,
  240. "finish_reason": "stop",
  241. }
  242. ],
  243. }
  244. yield f"data: {json.dumps(finish_message)}\n\n"
  245. yield f"data: [DONE]"
  246. return StreamingResponse(stream_content(), media_type="text/event-stream")
  247. else:
  248. try:
  249. if inspect.iscoroutinefunction(pipe):
  250. res = await pipe(**params)
  251. else:
  252. res = pipe(**params)
  253. if isinstance(res, StreamingResponse):
  254. return res
  255. except Exception as e:
  256. print(f"Error: {e}")
  257. return {"error": {"detail": str(e)}}
  258. if isinstance(res, dict):
  259. return res
  260. elif isinstance(res, BaseModel):
  261. return res.model_dump()
  262. else:
  263. message = ""
  264. if isinstance(res, str):
  265. message = res
  266. if isinstance(res, Generator):
  267. for stream in res:
  268. message = f"{message}{stream}"
  269. return {
  270. "id": f"{form_data['model']}-{str(uuid.uuid4())}",
  271. "object": "chat.completion",
  272. "created": int(time.time()),
  273. "model": form_data["model"],
  274. "choices": [
  275. {
  276. "index": 0,
  277. "message": {
  278. "role": "assistant",
  279. "content": message,
  280. },
  281. "logprobs": None,
  282. "finish_reason": "stop",
  283. }
  284. ],
  285. }
  286. return await job()