main.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. from fastapi import FastAPI, Depends
  2. from fastapi.routing import APIRoute
  3. from fastapi.responses import StreamingResponse
  4. from fastapi.middleware.cors import CORSMiddleware
  5. from apps.webui.routers import (
  6. auths,
  7. users,
  8. chats,
  9. documents,
  10. tools,
  11. models,
  12. prompts,
  13. configs,
  14. memories,
  15. utils,
  16. files,
  17. functions,
  18. )
  19. from apps.webui.models.functions import Functions
  20. from apps.webui.utils import load_function_module_by_id
  21. from utils.misc import stream_message_template
  22. from config import (
  23. WEBUI_BUILD_HASH,
  24. SHOW_ADMIN_DETAILS,
  25. ADMIN_EMAIL,
  26. WEBUI_AUTH,
  27. DEFAULT_MODELS,
  28. DEFAULT_PROMPT_SUGGESTIONS,
  29. DEFAULT_USER_ROLE,
  30. ENABLE_SIGNUP,
  31. USER_PERMISSIONS,
  32. WEBHOOK_URL,
  33. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  34. WEBUI_AUTH_TRUSTED_NAME_HEADER,
  35. JWT_EXPIRES_IN,
  36. WEBUI_BANNERS,
  37. ENABLE_COMMUNITY_SHARING,
  38. AppConfig,
  39. )
  40. import inspect
  41. import uuid
  42. import time
  43. import json
  44. from typing import Iterator, Generator
  45. from pydantic import BaseModel
  46. app = FastAPI()
  47. origins = ["*"]
  48. app.state.config = AppConfig()
  49. app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
  50. app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
  51. app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
  52. app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
  53. app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
  54. app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
  55. app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
  56. app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
  57. app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
  58. app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
  59. app.state.config.WEBHOOK_URL = WEBHOOK_URL
  60. app.state.config.BANNERS = WEBUI_BANNERS
  61. app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
  62. app.state.MODELS = {}
  63. app.state.TOOLS = {}
  64. app.state.FUNCTIONS = {}
  65. app.add_middleware(
  66. CORSMiddleware,
  67. allow_origins=origins,
  68. allow_credentials=True,
  69. allow_methods=["*"],
  70. allow_headers=["*"],
  71. )
  72. app.include_router(configs.router, prefix="/configs", tags=["configs"])
  73. app.include_router(auths.router, prefix="/auths", tags=["auths"])
  74. app.include_router(users.router, prefix="/users", tags=["users"])
  75. app.include_router(chats.router, prefix="/chats", tags=["chats"])
  76. app.include_router(documents.router, prefix="/documents", tags=["documents"])
  77. app.include_router(models.router, prefix="/models", tags=["models"])
  78. app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
  79. app.include_router(memories.router, prefix="/memories", tags=["memories"])
  80. app.include_router(files.router, prefix="/files", tags=["files"])
  81. app.include_router(tools.router, prefix="/tools", tags=["tools"])
  82. app.include_router(functions.router, prefix="/functions", tags=["functions"])
  83. app.include_router(utils.router, prefix="/utils", tags=["utils"])
  84. @app.get("/")
  85. async def get_status():
  86. return {
  87. "status": True,
  88. "auth": WEBUI_AUTH,
  89. "default_models": app.state.config.DEFAULT_MODELS,
  90. "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
  91. }
  92. async def get_pipe_models():
  93. pipes = Functions.get_functions_by_type("pipe", active_only=True)
  94. pipe_models = []
  95. for pipe in pipes:
  96. # Check if function is already loaded
  97. if pipe.id not in app.state.FUNCTIONS:
  98. function_module, function_type, frontmatter = load_function_module_by_id(
  99. pipe.id
  100. )
  101. app.state.FUNCTIONS[pipe.id] = function_module
  102. else:
  103. function_module = app.state.FUNCTIONS[pipe.id]
  104. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  105. print(f"Getting valves for {pipe.id}")
  106. valves = Functions.get_function_valves_by_id(pipe.id)
  107. function_module.valves = function_module.Valves(
  108. **(valves if valves else {})
  109. )
  110. # Check if function is a manifold
  111. if hasattr(function_module, "type"):
  112. if function_module.type == "manifold":
  113. manifold_pipes = []
  114. # Check if pipes is a function or a list
  115. if callable(function_module.pipes):
  116. manifold_pipes = function_module.pipes()
  117. else:
  118. manifold_pipes = function_module.pipes
  119. for p in manifold_pipes:
  120. manifold_pipe_id = f'{pipe.id}.{p["id"]}'
  121. manifold_pipe_name = p["name"]
  122. if hasattr(function_module, "name"):
  123. manifold_pipe_name = (
  124. f"{function_module.name}{manifold_pipe_name}"
  125. )
  126. pipe_models.append(
  127. {
  128. "id": manifold_pipe_id,
  129. "name": manifold_pipe_name,
  130. "object": "model",
  131. "created": pipe.created_at,
  132. "owned_by": "openai",
  133. "pipe": {"type": pipe.type},
  134. }
  135. )
  136. else:
  137. pipe_models.append(
  138. {
  139. "id": pipe.id,
  140. "name": pipe.name,
  141. "object": "model",
  142. "created": pipe.created_at,
  143. "owned_by": "openai",
  144. "pipe": {"type": "pipe"},
  145. }
  146. )
  147. return pipe_models
  148. async def generate_function_chat_completion(form_data, user):
  149. async def job():
  150. pipe_id = form_data["model"]
  151. if "." in pipe_id:
  152. pipe_id, sub_pipe_id = pipe_id.split(".", 1)
  153. print(pipe_id)
  154. # Check if function is already loaded
  155. if pipe_id not in app.state.FUNCTIONS:
  156. function_module, function_type, frontmatter = load_function_module_by_id(
  157. pipe_id
  158. )
  159. app.state.FUNCTIONS[pipe_id] = function_module
  160. else:
  161. function_module = app.state.FUNCTIONS[pipe_id]
  162. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  163. valves = Functions.get_function_valves_by_id(pipe_id)
  164. function_module.valves = function_module.Valves(
  165. **(valves if valves else {})
  166. )
  167. pipe = function_module.pipe
  168. # Get the signature of the function
  169. sig = inspect.signature(pipe)
  170. params = {"body": form_data}
  171. if "__user__" in sig.parameters:
  172. __user__ = {
  173. "id": user.id,
  174. "email": user.email,
  175. "name": user.name,
  176. "role": user.role,
  177. }
  178. try:
  179. if hasattr(function_module, "UserValves"):
  180. __user__["valves"] = function_module.UserValves(
  181. **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
  182. )
  183. except Exception as e:
  184. print(e)
  185. params = {**params, "__user__": __user__}
  186. if form_data["stream"]:
  187. async def stream_content():
  188. try:
  189. if inspect.iscoroutinefunction(pipe):
  190. res = await pipe(**params)
  191. else:
  192. res = pipe(**params)
  193. except Exception as e:
  194. print(f"Error: {e}")
  195. yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
  196. return
  197. if isinstance(res, str):
  198. message = stream_message_template(form_data["model"], res)
  199. yield f"data: {json.dumps(message)}\n\n"
  200. if isinstance(res, Iterator):
  201. for line in res:
  202. if isinstance(line, BaseModel):
  203. line = line.model_dump_json()
  204. line = f"data: {line}"
  205. try:
  206. line = line.decode("utf-8")
  207. except:
  208. pass
  209. if line.startswith("data:"):
  210. yield f"{line}\n\n"
  211. else:
  212. line = stream_message_template(form_data["model"], line)
  213. yield f"data: {json.dumps(line)}\n\n"
  214. if isinstance(res, str) or isinstance(res, Generator):
  215. finish_message = {
  216. "id": f"{form_data['model']}-{str(uuid.uuid4())}",
  217. "object": "chat.completion.chunk",
  218. "created": int(time.time()),
  219. "model": form_data["model"],
  220. "choices": [
  221. {
  222. "index": 0,
  223. "delta": {},
  224. "logprobs": None,
  225. "finish_reason": "stop",
  226. }
  227. ],
  228. }
  229. yield f"data: {json.dumps(finish_message)}\n\n"
  230. yield f"data: [DONE]"
  231. return StreamingResponse(stream_content(), media_type="text/event-stream")
  232. else:
  233. try:
  234. if inspect.iscoroutinefunction(pipe):
  235. res = await pipe(**params)
  236. else:
  237. res = pipe(**params)
  238. except Exception as e:
  239. print(f"Error: {e}")
  240. return {"error": {"detail": str(e)}}
  241. if inspect.iscoroutinefunction(pipe):
  242. res = await pipe(**params)
  243. else:
  244. res = pipe(**params)
  245. if isinstance(res, dict):
  246. return res
  247. elif isinstance(res, BaseModel):
  248. return res.model_dump()
  249. else:
  250. message = ""
  251. if isinstance(res, str):
  252. message = res
  253. if isinstance(res, Generator):
  254. for stream in res:
  255. message = f"{message}{stream}"
  256. return {
  257. "id": f"{form_data['model']}-{str(uuid.uuid4())}",
  258. "object": "chat.completion",
  259. "created": int(time.time()),
  260. "model": form_data["model"],
  261. "choices": [
  262. {
  263. "index": 0,
  264. "message": {
  265. "role": "assistant",
  266. "content": message,
  267. },
  268. "logprobs": None,
  269. "finish_reason": "stop",
  270. }
  271. ],
  272. }
  273. return await job()