main.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. from fastapi import FastAPI, Depends
  2. from fastapi.routing import APIRoute
  3. from fastapi.responses import StreamingResponse
  4. from fastapi.middleware.cors import CORSMiddleware
  5. from starlette.middleware.sessions import SessionMiddleware
  6. from apps.webui.routers import (
  7. auths,
  8. users,
  9. chats,
  10. documents,
  11. tools,
  12. models,
  13. prompts,
  14. configs,
  15. memories,
  16. utils,
  17. files,
  18. functions,
  19. )
  20. from apps.webui.models.functions import Functions
  21. from apps.webui.utils import load_function_module_by_id
  22. from utils.misc import stream_message_template
  23. from config import (
  24. WEBUI_BUILD_HASH,
  25. SHOW_ADMIN_DETAILS,
  26. ADMIN_EMAIL,
  27. WEBUI_AUTH,
  28. DEFAULT_MODELS,
  29. DEFAULT_PROMPT_SUGGESTIONS,
  30. DEFAULT_USER_ROLE,
  31. ENABLE_SIGNUP,
  32. USER_PERMISSIONS,
  33. WEBHOOK_URL,
  34. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  35. WEBUI_AUTH_TRUSTED_NAME_HEADER,
  36. JWT_EXPIRES_IN,
  37. WEBUI_BANNERS,
  38. ENABLE_COMMUNITY_SHARING,
  39. AppConfig,
  40. )
  41. import inspect
  42. import uuid
  43. import time
  44. import json
  45. from typing import Iterator, Generator
  46. from pydantic import BaseModel
  47. app = FastAPI()
  48. origins = ["*"]
  49. app.state.config = AppConfig()
  50. app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
  51. app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
  52. app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
  53. app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
  54. app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
  55. app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
  56. app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
  57. app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
  58. app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
  59. app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
  60. app.state.config.WEBHOOK_URL = WEBHOOK_URL
  61. app.state.config.BANNERS = WEBUI_BANNERS
  62. app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
  63. app.state.MODELS = {}
  64. app.state.TOOLS = {}
  65. app.state.FUNCTIONS = {}
  66. app.add_middleware(
  67. CORSMiddleware,
  68. allow_origins=origins,
  69. allow_credentials=True,
  70. allow_methods=["*"],
  71. allow_headers=["*"],
  72. )
  73. app.include_router(configs.router, prefix="/configs", tags=["configs"])
  74. app.include_router(auths.router, prefix="/auths", tags=["auths"])
  75. app.include_router(users.router, prefix="/users", tags=["users"])
  76. app.include_router(chats.router, prefix="/chats", tags=["chats"])
  77. app.include_router(documents.router, prefix="/documents", tags=["documents"])
  78. app.include_router(models.router, prefix="/models", tags=["models"])
  79. app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
  80. app.include_router(memories.router, prefix="/memories", tags=["memories"])
  81. app.include_router(files.router, prefix="/files", tags=["files"])
  82. app.include_router(tools.router, prefix="/tools", tags=["tools"])
  83. app.include_router(functions.router, prefix="/functions", tags=["functions"])
  84. app.include_router(utils.router, prefix="/utils", tags=["utils"])
  85. @app.get("/")
  86. async def get_status():
  87. return {
  88. "status": True,
  89. "auth": WEBUI_AUTH,
  90. "default_models": app.state.config.DEFAULT_MODELS,
  91. "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
  92. }
  93. async def get_pipe_models():
  94. pipes = Functions.get_functions_by_type("pipe", active_only=True)
  95. pipe_models = []
  96. for pipe in pipes:
  97. # Check if function is already loaded
  98. if pipe.id not in app.state.FUNCTIONS:
  99. function_module, function_type, frontmatter = load_function_module_by_id(
  100. pipe.id
  101. )
  102. app.state.FUNCTIONS[pipe.id] = function_module
  103. else:
  104. function_module = app.state.FUNCTIONS[pipe.id]
  105. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  106. print(f"Getting valves for {pipe.id}")
  107. valves = Functions.get_function_valves_by_id(pipe.id)
  108. function_module.valves = function_module.Valves(
  109. **(valves if valves else {})
  110. )
  111. # Check if function is a manifold
  112. if hasattr(function_module, "type"):
  113. if function_module.type == "manifold":
  114. manifold_pipes = []
  115. # Check if pipes is a function or a list
  116. if callable(function_module.pipes):
  117. manifold_pipes = function_module.pipes()
  118. else:
  119. manifold_pipes = function_module.pipes
  120. for p in manifold_pipes:
  121. manifold_pipe_id = f'{pipe.id}.{p["id"]}'
  122. manifold_pipe_name = p["name"]
  123. if hasattr(function_module, "name"):
  124. manifold_pipe_name = (
  125. f"{function_module.name}{manifold_pipe_name}"
  126. )
  127. pipe_models.append(
  128. {
  129. "id": manifold_pipe_id,
  130. "name": manifold_pipe_name,
  131. "object": "model",
  132. "created": pipe.created_at,
  133. "owned_by": "openai",
  134. "pipe": {"type": pipe.type},
  135. }
  136. )
  137. else:
  138. pipe_models.append(
  139. {
  140. "id": pipe.id,
  141. "name": pipe.name,
  142. "object": "model",
  143. "created": pipe.created_at,
  144. "owned_by": "openai",
  145. "pipe": {"type": "pipe"},
  146. }
  147. )
  148. return pipe_models
  149. async def generate_function_chat_completion(form_data, user):
  150. async def job():
  151. pipe_id = form_data["model"]
  152. if "." in pipe_id:
  153. pipe_id, sub_pipe_id = pipe_id.split(".", 1)
  154. print(pipe_id)
  155. # Check if function is already loaded
  156. if pipe_id not in app.state.FUNCTIONS:
  157. function_module, function_type, frontmatter = load_function_module_by_id(
  158. pipe_id
  159. )
  160. app.state.FUNCTIONS[pipe_id] = function_module
  161. else:
  162. function_module = app.state.FUNCTIONS[pipe_id]
  163. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  164. valves = Functions.get_function_valves_by_id(pipe_id)
  165. function_module.valves = function_module.Valves(
  166. **(valves if valves else {})
  167. )
  168. pipe = function_module.pipe
  169. # Get the signature of the function
  170. sig = inspect.signature(pipe)
  171. params = {"body": form_data}
  172. if "__user__" in sig.parameters:
  173. __user__ = {
  174. "id": user.id,
  175. "email": user.email,
  176. "name": user.name,
  177. "role": user.role,
  178. }
  179. try:
  180. if hasattr(function_module, "UserValves"):
  181. __user__["valves"] = function_module.UserValves(
  182. **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
  183. )
  184. except Exception as e:
  185. print(e)
  186. params = {**params, "__user__": __user__}
  187. if form_data["stream"]:
  188. async def stream_content():
  189. try:
  190. if inspect.iscoroutinefunction(pipe):
  191. res = await pipe(**params)
  192. else:
  193. res = pipe(**params)
  194. # Directly return if the response is a StreamingResponse
  195. if isinstance(res, StreamingResponse):
  196. async for data in res.body_iterator:
  197. yield data
  198. return
  199. if isinstance(res, dict):
  200. yield f"data: {json.dumps(res)}\n\n"
  201. return
  202. except Exception as e:
  203. print(f"Error: {e}")
  204. yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
  205. return
  206. if isinstance(res, str):
  207. message = stream_message_template(form_data["model"], res)
  208. yield f"data: {json.dumps(message)}\n\n"
  209. if isinstance(res, Iterator):
  210. for line in res:
  211. if isinstance(line, BaseModel):
  212. line = line.model_dump_json()
  213. line = f"data: {line}"
  214. if isinstance(line, dict):
  215. line = f"data: {json.dumps(line)}"
  216. try:
  217. line = line.decode("utf-8")
  218. except:
  219. pass
  220. if line.startswith("data:"):
  221. yield f"{line}\n\n"
  222. else:
  223. line = stream_message_template(form_data["model"], line)
  224. yield f"data: {json.dumps(line)}\n\n"
  225. if isinstance(res, str) or isinstance(res, Generator):
  226. finish_message = {
  227. "id": f"{form_data['model']}-{str(uuid.uuid4())}",
  228. "object": "chat.completion.chunk",
  229. "created": int(time.time()),
  230. "model": form_data["model"],
  231. "choices": [
  232. {
  233. "index": 0,
  234. "delta": {},
  235. "logprobs": None,
  236. "finish_reason": "stop",
  237. }
  238. ],
  239. }
  240. yield f"data: {json.dumps(finish_message)}\n\n"
  241. yield f"data: [DONE]"
  242. return StreamingResponse(stream_content(), media_type="text/event-stream")
  243. else:
  244. try:
  245. if inspect.iscoroutinefunction(pipe):
  246. res = await pipe(**params)
  247. else:
  248. res = pipe(**params)
  249. if isinstance(res, StreamingResponse):
  250. return res
  251. except Exception as e:
  252. print(f"Error: {e}")
  253. return {"error": {"detail": str(e)}}
  254. if isinstance(res, dict):
  255. return res
  256. elif isinstance(res, BaseModel):
  257. return res.model_dump()
  258. else:
  259. message = ""
  260. if isinstance(res, str):
  261. message = res
  262. if isinstance(res, Generator):
  263. for stream in res:
  264. message = f"{message}{stream}"
  265. return {
  266. "id": f"{form_data['model']}-{str(uuid.uuid4())}",
  267. "object": "chat.completion",
  268. "created": int(time.time()),
  269. "model": form_data["model"],
  270. "choices": [
  271. {
  272. "index": 0,
  273. "message": {
  274. "role": "assistant",
  275. "content": message,
  276. },
  277. "logprobs": None,
  278. "finish_reason": "stop",
  279. }
  280. ],
  281. }
  282. return await job()