main.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. from fastapi import FastAPI, Depends
  2. from fastapi.routing import APIRoute
  3. from fastapi.responses import StreamingResponse
  4. from fastapi.middleware.cors import CORSMiddleware
  5. from starlette.middleware.sessions import SessionMiddleware
  6. from sqlalchemy.orm import Session
  7. from apps.webui.routers import (
  8. auths,
  9. users,
  10. chats,
  11. documents,
  12. tools,
  13. models,
  14. prompts,
  15. configs,
  16. memories,
  17. utils,
  18. files,
  19. functions,
  20. )
  21. from apps.webui.models.functions import Functions
  22. from apps.webui.utils import load_function_module_by_id
  23. from utils.misc import stream_message_template
  24. from config import (
  25. WEBUI_BUILD_HASH,
  26. SHOW_ADMIN_DETAILS,
  27. ADMIN_EMAIL,
  28. WEBUI_AUTH,
  29. DEFAULT_MODELS,
  30. DEFAULT_PROMPT_SUGGESTIONS,
  31. DEFAULT_USER_ROLE,
  32. ENABLE_SIGNUP,
  33. USER_PERMISSIONS,
  34. WEBHOOK_URL,
  35. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  36. WEBUI_AUTH_TRUSTED_NAME_HEADER,
  37. JWT_EXPIRES_IN,
  38. WEBUI_BANNERS,
  39. ENABLE_COMMUNITY_SHARING,
  40. AppConfig,
  41. )
  42. import inspect
  43. import uuid
  44. import time
  45. import json
  46. from typing import Iterator, Generator
  47. from pydantic import BaseModel
  48. app = FastAPI()
  49. origins = ["*"]
  50. app.state.config = AppConfig()
  51. app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
  52. app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
  53. app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
  54. app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
  55. app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
  56. app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
  57. app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
  58. app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
  59. app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
  60. app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
  61. app.state.config.WEBHOOK_URL = WEBHOOK_URL
  62. app.state.config.BANNERS = WEBUI_BANNERS
  63. app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
  64. app.state.MODELS = {}
  65. app.state.TOOLS = {}
  66. app.state.FUNCTIONS = {}
  67. app.add_middleware(
  68. CORSMiddleware,
  69. allow_origins=origins,
  70. allow_credentials=True,
  71. allow_methods=["*"],
  72. allow_headers=["*"],
  73. )
  74. app.include_router(configs.router, prefix="/configs", tags=["configs"])
  75. app.include_router(auths.router, prefix="/auths", tags=["auths"])
  76. app.include_router(users.router, prefix="/users", tags=["users"])
  77. app.include_router(chats.router, prefix="/chats", tags=["chats"])
  78. app.include_router(documents.router, prefix="/documents", tags=["documents"])
  79. app.include_router(models.router, prefix="/models", tags=["models"])
  80. app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
  81. app.include_router(memories.router, prefix="/memories", tags=["memories"])
  82. app.include_router(files.router, prefix="/files", tags=["files"])
  83. app.include_router(tools.router, prefix="/tools", tags=["tools"])
  84. app.include_router(functions.router, prefix="/functions", tags=["functions"])
  85. app.include_router(utils.router, prefix="/utils", tags=["utils"])
  86. @app.get("/")
  87. async def get_status():
  88. return {
  89. "status": True,
  90. "auth": WEBUI_AUTH,
  91. "default_models": app.state.config.DEFAULT_MODELS,
  92. "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
  93. }
  94. async def get_pipe_models():
  95. pipes = Functions.get_functions_by_type("pipe", active_only=True)
  96. pipe_models = []
  97. for pipe in pipes:
  98. # Check if function is already loaded
  99. if pipe.id not in app.state.FUNCTIONS:
  100. function_module, function_type, frontmatter = load_function_module_by_id(
  101. pipe.id
  102. )
  103. app.state.FUNCTIONS[pipe.id] = function_module
  104. else:
  105. function_module = app.state.FUNCTIONS[pipe.id]
  106. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  107. print(f"Getting valves for {pipe.id}")
  108. valves = Functions.get_function_valves_by_id(pipe.id)
  109. function_module.valves = function_module.Valves(
  110. **(valves if valves else {})
  111. )
  112. # Check if function is a manifold
  113. if hasattr(function_module, "type"):
  114. if function_module.type == "manifold":
  115. manifold_pipes = []
  116. # Check if pipes is a function or a list
  117. if callable(function_module.pipes):
  118. manifold_pipes = function_module.pipes()
  119. else:
  120. manifold_pipes = function_module.pipes
  121. for p in manifold_pipes:
  122. manifold_pipe_id = f'{pipe.id}.{p["id"]}'
  123. manifold_pipe_name = p["name"]
  124. if hasattr(function_module, "name"):
  125. manifold_pipe_name = (
  126. f"{function_module.name}{manifold_pipe_name}"
  127. )
  128. pipe_models.append(
  129. {
  130. "id": manifold_pipe_id,
  131. "name": manifold_pipe_name,
  132. "object": "model",
  133. "created": pipe.created_at,
  134. "owned_by": "openai",
  135. "pipe": {"type": pipe.type},
  136. }
  137. )
  138. else:
  139. pipe_models.append(
  140. {
  141. "id": pipe.id,
  142. "name": pipe.name,
  143. "object": "model",
  144. "created": pipe.created_at,
  145. "owned_by": "openai",
  146. "pipe": {"type": "pipe"},
  147. }
  148. )
  149. return pipe_models
  150. async def generate_function_chat_completion(form_data, user):
  151. async def job():
  152. pipe_id = form_data["model"]
  153. if "." in pipe_id:
  154. pipe_id, sub_pipe_id = pipe_id.split(".", 1)
  155. print(pipe_id)
  156. # Check if function is already loaded
  157. if pipe_id not in app.state.FUNCTIONS:
  158. function_module, function_type, frontmatter = load_function_module_by_id(
  159. pipe_id
  160. )
  161. app.state.FUNCTIONS[pipe_id] = function_module
  162. else:
  163. function_module = app.state.FUNCTIONS[pipe_id]
  164. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  165. valves = Functions.get_function_valves_by_id(pipe_id)
  166. function_module.valves = function_module.Valves(
  167. **(valves if valves else {})
  168. )
  169. pipe = function_module.pipe
  170. # Get the signature of the function
  171. sig = inspect.signature(pipe)
  172. params = {"body": form_data}
  173. if "__user__" in sig.parameters:
  174. __user__ = {
  175. "id": user.id,
  176. "email": user.email,
  177. "name": user.name,
  178. "role": user.role,
  179. }
  180. try:
  181. if hasattr(function_module, "UserValves"):
  182. __user__["valves"] = function_module.UserValves(
  183. **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
  184. )
  185. except Exception as e:
  186. print(e)
  187. params = {**params, "__user__": __user__}
  188. if form_data["stream"]:
  189. async def stream_content():
  190. try:
  191. if inspect.iscoroutinefunction(pipe):
  192. res = await pipe(**params)
  193. else:
  194. res = pipe(**params)
  195. # Directly return if the response is a StreamingResponse
  196. if isinstance(res, StreamingResponse):
  197. async for data in res.body_iterator:
  198. yield data
  199. return
  200. if isinstance(res, dict):
  201. yield f"data: {json.dumps(res)}\n\n"
  202. return
  203. except Exception as e:
  204. print(f"Error: {e}")
  205. yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
  206. return
  207. if isinstance(res, str):
  208. message = stream_message_template(form_data["model"], res)
  209. yield f"data: {json.dumps(message)}\n\n"
  210. if isinstance(res, Iterator):
  211. for line in res:
  212. if isinstance(line, BaseModel):
  213. line = line.model_dump_json()
  214. line = f"data: {line}"
  215. try:
  216. line = line.decode("utf-8")
  217. except:
  218. pass
  219. if line.startswith("data:"):
  220. yield f"{line}\n\n"
  221. else:
  222. line = stream_message_template(form_data["model"], line)
  223. yield f"data: {json.dumps(line)}\n\n"
  224. if isinstance(res, str) or isinstance(res, Generator):
  225. finish_message = {
  226. "id": f"{form_data['model']}-{str(uuid.uuid4())}",
  227. "object": "chat.completion.chunk",
  228. "created": int(time.time()),
  229. "model": form_data["model"],
  230. "choices": [
  231. {
  232. "index": 0,
  233. "delta": {},
  234. "logprobs": None,
  235. "finish_reason": "stop",
  236. }
  237. ],
  238. }
  239. yield f"data: {json.dumps(finish_message)}\n\n"
  240. yield f"data: [DONE]"
  241. return StreamingResponse(stream_content(), media_type="text/event-stream")
  242. else:
  243. try:
  244. if inspect.iscoroutinefunction(pipe):
  245. res = await pipe(**params)
  246. else:
  247. res = pipe(**params)
  248. if isinstance(res, StreamingResponse):
  249. return res
  250. except Exception as e:
  251. print(f"Error: {e}")
  252. return {"error": {"detail": str(e)}}
  253. if isinstance(res, dict):
  254. return res
  255. elif isinstance(res, BaseModel):
  256. return res.model_dump()
  257. else:
  258. message = ""
  259. if isinstance(res, str):
  260. message = res
  261. if isinstance(res, Generator):
  262. for stream in res:
  263. message = f"{message}{stream}"
  264. return {
  265. "id": f"{form_data['model']}-{str(uuid.uuid4())}",
  266. "object": "chat.completion",
  267. "created": int(time.time()),
  268. "model": form_data["model"],
  269. "choices": [
  270. {
  271. "index": 0,
  272. "message": {
  273. "role": "assistant",
  274. "content": message,
  275. },
  276. "logprobs": None,
  277. "finish_reason": "stop",
  278. }
  279. ],
  280. }
  281. return await job()