main.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. from fastapi import FastAPI, Depends
  2. from fastapi.routing import APIRoute
  3. from fastapi.responses import StreamingResponse
  4. from fastapi.middleware.cors import CORSMiddleware
  5. from starlette.middleware.sessions import SessionMiddleware
  6. from sqlalchemy.orm import Session
  7. from apps.webui.routers import (
  8. auths,
  9. users,
  10. chats,
  11. documents,
  12. tools,
  13. models,
  14. prompts,
  15. configs,
  16. memories,
  17. utils,
  18. files,
  19. functions,
  20. )
  21. from apps.webui.models.functions import Functions
  22. from apps.webui.utils import load_function_module_by_id
  23. from utils.misc import stream_message_template
  24. from config import (
  25. WEBUI_BUILD_HASH,
  26. SHOW_ADMIN_DETAILS,
  27. ADMIN_EMAIL,
  28. WEBUI_AUTH,
  29. DEFAULT_MODELS,
  30. DEFAULT_PROMPT_SUGGESTIONS,
  31. DEFAULT_USER_ROLE,
  32. ENABLE_SIGNUP,
  33. USER_PERMISSIONS,
  34. WEBHOOK_URL,
  35. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  36. WEBUI_AUTH_TRUSTED_NAME_HEADER,
  37. JWT_EXPIRES_IN,
  38. WEBUI_BANNERS,
  39. ENABLE_COMMUNITY_SHARING,
  40. AppConfig,
  41. OAUTH_USERNAME_CLAIM,
  42. OAUTH_PICTURE_CLAIM,
  43. )
  44. import inspect
  45. import uuid
  46. import time
  47. import json
  48. from typing import Iterator, Generator
  49. from pydantic import BaseModel
  50. app = FastAPI()
  51. origins = ["*"]
  52. app.state.config = AppConfig()
  53. app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
  54. app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
  55. app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
  56. app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
  57. app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
  58. app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
  59. app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
  60. app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
  61. app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
  62. app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
  63. app.state.config.WEBHOOK_URL = WEBHOOK_URL
  64. app.state.config.BANNERS = WEBUI_BANNERS
  65. app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
  66. app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
  67. app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
  68. app.state.MODELS = {}
  69. app.state.TOOLS = {}
  70. app.state.FUNCTIONS = {}
  71. app.add_middleware(
  72. CORSMiddleware,
  73. allow_origins=origins,
  74. allow_credentials=True,
  75. allow_methods=["*"],
  76. allow_headers=["*"],
  77. )
  78. app.include_router(configs.router, prefix="/configs", tags=["configs"])
  79. app.include_router(auths.router, prefix="/auths", tags=["auths"])
  80. app.include_router(users.router, prefix="/users", tags=["users"])
  81. app.include_router(chats.router, prefix="/chats", tags=["chats"])
  82. app.include_router(documents.router, prefix="/documents", tags=["documents"])
  83. app.include_router(models.router, prefix="/models", tags=["models"])
  84. app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
  85. app.include_router(memories.router, prefix="/memories", tags=["memories"])
  86. app.include_router(files.router, prefix="/files", tags=["files"])
  87. app.include_router(tools.router, prefix="/tools", tags=["tools"])
  88. app.include_router(functions.router, prefix="/functions", tags=["functions"])
  89. app.include_router(utils.router, prefix="/utils", tags=["utils"])
  90. @app.get("/")
  91. async def get_status():
  92. return {
  93. "status": True,
  94. "auth": WEBUI_AUTH,
  95. "default_models": app.state.config.DEFAULT_MODELS,
  96. "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
  97. }
  98. async def get_pipe_models():
  99. pipes = Functions.get_functions_by_type("pipe", active_only=True)
  100. pipe_models = []
  101. for pipe in pipes:
  102. # Check if function is already loaded
  103. if pipe.id not in app.state.FUNCTIONS:
  104. function_module, function_type, frontmatter = load_function_module_by_id(
  105. pipe.id
  106. )
  107. app.state.FUNCTIONS[pipe.id] = function_module
  108. else:
  109. function_module = app.state.FUNCTIONS[pipe.id]
  110. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  111. print(f"Getting valves for {pipe.id}")
  112. valves = Functions.get_function_valves_by_id(pipe.id)
  113. function_module.valves = function_module.Valves(
  114. **(valves if valves else {})
  115. )
  116. # Check if function is a manifold
  117. if hasattr(function_module, "type"):
  118. if function_module.type == "manifold":
  119. manifold_pipes = []
  120. # Check if pipes is a function or a list
  121. if callable(function_module.pipes):
  122. manifold_pipes = function_module.pipes()
  123. else:
  124. manifold_pipes = function_module.pipes
  125. for p in manifold_pipes:
  126. manifold_pipe_id = f'{pipe.id}.{p["id"]}'
  127. manifold_pipe_name = p["name"]
  128. if hasattr(function_module, "name"):
  129. manifold_pipe_name = (
  130. f"{function_module.name}{manifold_pipe_name}"
  131. )
  132. pipe_models.append(
  133. {
  134. "id": manifold_pipe_id,
  135. "name": manifold_pipe_name,
  136. "object": "model",
  137. "created": pipe.created_at,
  138. "owned_by": "openai",
  139. "pipe": {"type": pipe.type},
  140. }
  141. )
  142. else:
  143. pipe_models.append(
  144. {
  145. "id": pipe.id,
  146. "name": pipe.name,
  147. "object": "model",
  148. "created": pipe.created_at,
  149. "owned_by": "openai",
  150. "pipe": {"type": "pipe"},
  151. }
  152. )
  153. return pipe_models
  154. async def generate_function_chat_completion(form_data, user):
  155. async def job():
  156. pipe_id = form_data["model"]
  157. if "." in pipe_id:
  158. pipe_id, sub_pipe_id = pipe_id.split(".", 1)
  159. print(pipe_id)
  160. # Check if function is already loaded
  161. if pipe_id not in app.state.FUNCTIONS:
  162. function_module, function_type, frontmatter = load_function_module_by_id(
  163. pipe_id
  164. )
  165. app.state.FUNCTIONS[pipe_id] = function_module
  166. else:
  167. function_module = app.state.FUNCTIONS[pipe_id]
  168. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  169. valves = Functions.get_function_valves_by_id(pipe_id)
  170. function_module.valves = function_module.Valves(
  171. **(valves if valves else {})
  172. )
  173. pipe = function_module.pipe
  174. # Get the signature of the function
  175. sig = inspect.signature(pipe)
  176. params = {"body": form_data}
  177. if "__user__" in sig.parameters:
  178. __user__ = {
  179. "id": user.id,
  180. "email": user.email,
  181. "name": user.name,
  182. "role": user.role,
  183. }
  184. try:
  185. if hasattr(function_module, "UserValves"):
  186. __user__["valves"] = function_module.UserValves(
  187. **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
  188. )
  189. except Exception as e:
  190. print(e)
  191. params = {**params, "__user__": __user__}
  192. if form_data["stream"]:
  193. async def stream_content():
  194. try:
  195. if inspect.iscoroutinefunction(pipe):
  196. res = await pipe(**params)
  197. else:
  198. res = pipe(**params)
  199. # Directly return if the response is a StreamingResponse
  200. if isinstance(res, StreamingResponse):
  201. async for data in res.body_iterator:
  202. yield data
  203. return
  204. if isinstance(res, dict):
  205. yield f"data: {json.dumps(res)}\n\n"
  206. return
  207. except Exception as e:
  208. print(f"Error: {e}")
  209. yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
  210. return
  211. if isinstance(res, str):
  212. message = stream_message_template(form_data["model"], res)
  213. yield f"data: {json.dumps(message)}\n\n"
  214. if isinstance(res, Iterator):
  215. for line in res:
  216. if isinstance(line, BaseModel):
  217. line = line.model_dump_json()
  218. line = f"data: {line}"
  219. if isinstance(line, dict):
  220. line = f"data: {json.dumps(line)}"
  221. try:
  222. line = line.decode("utf-8")
  223. except:
  224. pass
  225. if line.startswith("data:"):
  226. yield f"{line}\n\n"
  227. else:
  228. line = stream_message_template(form_data["model"], line)
  229. yield f"data: {json.dumps(line)}\n\n"
  230. if isinstance(res, str) or isinstance(res, Generator):
  231. finish_message = {
  232. "id": f"{form_data['model']}-{str(uuid.uuid4())}",
  233. "object": "chat.completion.chunk",
  234. "created": int(time.time()),
  235. "model": form_data["model"],
  236. "choices": [
  237. {
  238. "index": 0,
  239. "delta": {},
  240. "logprobs": None,
  241. "finish_reason": "stop",
  242. }
  243. ],
  244. }
  245. yield f"data: {json.dumps(finish_message)}\n\n"
  246. yield f"data: [DONE]"
  247. return StreamingResponse(stream_content(), media_type="text/event-stream")
  248. else:
  249. try:
  250. if inspect.iscoroutinefunction(pipe):
  251. res = await pipe(**params)
  252. else:
  253. res = pipe(**params)
  254. if isinstance(res, StreamingResponse):
  255. return res
  256. except Exception as e:
  257. print(f"Error: {e}")
  258. return {"error": {"detail": str(e)}}
  259. if isinstance(res, dict):
  260. return res
  261. elif isinstance(res, BaseModel):
  262. return res.model_dump()
  263. else:
  264. message = ""
  265. if isinstance(res, str):
  266. message = res
  267. if isinstance(res, Generator):
  268. for stream in res:
  269. message = f"{message}{stream}"
  270. return {
  271. "id": f"{form_data['model']}-{str(uuid.uuid4())}",
  272. "object": "chat.completion",
  273. "created": int(time.time()),
  274. "model": form_data["model"],
  275. "choices": [
  276. {
  277. "index": 0,
  278. "message": {
  279. "role": "assistant",
  280. "content": message,
  281. },
  282. "logprobs": None,
  283. "finish_reason": "stop",
  284. }
  285. ],
  286. }
  287. return await job()