# chat.py (14 KB)
import asyncio
import inspect
import json
import random

from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

# NOTE(review): `requests` and the project helpers used below (get_verified_user,
# has_access, Models, Functions, get_all_models, webui_app, openai_app, ...) are
# not imported in this file as shown — confirm where they are meant to come from.
  3. router = APIRouter()
  4. @app.post("/api/chat/completions")
  5. async def generate_chat_completions(
  6. request: Request,
  7. form_data: dict,
  8. user=Depends(get_verified_user),
  9. bypass_filter: bool = False,
  10. ):
  11. if BYPASS_MODEL_ACCESS_CONTROL:
  12. bypass_filter = True
  13. model_list = request.state.models
  14. models = {model["id"]: model for model in model_list}
  15. model_id = form_data["model"]
  16. if model_id not in models:
  17. raise HTTPException(
  18. status_code=status.HTTP_404_NOT_FOUND,
  19. detail="Model not found",
  20. )
  21. model = models[model_id]
  22. # Check if user has access to the model
  23. if not bypass_filter and user.role == "user":
  24. if model.get("arena"):
  25. if not has_access(
  26. user.id,
  27. type="read",
  28. access_control=model.get("info", {})
  29. .get("meta", {})
  30. .get("access_control", {}),
  31. ):
  32. raise HTTPException(
  33. status_code=403,
  34. detail="Model not found",
  35. )
  36. else:
  37. model_info = Models.get_model_by_id(model_id)
  38. if not model_info:
  39. raise HTTPException(
  40. status_code=404,
  41. detail="Model not found",
  42. )
  43. elif not (
  44. user.id == model_info.user_id
  45. or has_access(
  46. user.id, type="read", access_control=model_info.access_control
  47. )
  48. ):
  49. raise HTTPException(
  50. status_code=403,
  51. detail="Model not found",
  52. )
  53. if model["owned_by"] == "arena":
  54. model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
  55. filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
  56. if model_ids and filter_mode == "exclude":
  57. model_ids = [
  58. model["id"]
  59. for model in await get_all_models()
  60. if model.get("owned_by") != "arena" and model["id"] not in model_ids
  61. ]
  62. selected_model_id = None
  63. if isinstance(model_ids, list) and model_ids:
  64. selected_model_id = random.choice(model_ids)
  65. else:
  66. model_ids = [
  67. model["id"]
  68. for model in await get_all_models()
  69. if model.get("owned_by") != "arena"
  70. ]
  71. selected_model_id = random.choice(model_ids)
  72. form_data["model"] = selected_model_id
  73. if form_data.get("stream") == True:
  74. async def stream_wrapper(stream):
  75. yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
  76. async for chunk in stream:
  77. yield chunk
  78. response = await generate_chat_completions(
  79. form_data, user, bypass_filter=True
  80. )
  81. return StreamingResponse(
  82. stream_wrapper(response.body_iterator), media_type="text/event-stream"
  83. )
  84. else:
  85. return {
  86. **(
  87. await generate_chat_completions(form_data, user, bypass_filter=True)
  88. ),
  89. "selected_model_id": selected_model_id,
  90. }
  91. if model.get("pipe"):
  92. # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
  93. return await generate_function_chat_completion(
  94. form_data, user=user, models=models
  95. )
  96. if model["owned_by"] == "ollama":
  97. # Using /ollama/api/chat endpoint
  98. form_data = convert_payload_openai_to_ollama(form_data)
  99. form_data = GenerateChatCompletionForm(**form_data)
  100. response = await generate_ollama_chat_completion(
  101. form_data=form_data, user=user, bypass_filter=bypass_filter
  102. )
  103. if form_data.stream:
  104. response.headers["content-type"] = "text/event-stream"
  105. return StreamingResponse(
  106. convert_streaming_response_ollama_to_openai(response),
  107. headers=dict(response.headers),
  108. )
  109. else:
  110. return convert_response_ollama_to_openai(response)
  111. else:
  112. return await generate_openai_chat_completion(
  113. form_data, user=user, bypass_filter=bypass_filter
  114. )
  115. @app.post("/api/chat/completed")
  116. async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
  117. model_list = await get_all_models()
  118. models = {model["id"]: model for model in model_list}
  119. data = form_data
  120. model_id = data["model"]
  121. if model_id not in models:
  122. raise HTTPException(
  123. status_code=status.HTTP_404_NOT_FOUND,
  124. detail="Model not found",
  125. )
  126. model = models[model_id]
  127. sorted_filters = get_sorted_filters(model_id, models)
  128. if "pipeline" in model:
  129. sorted_filters = [model] + sorted_filters
  130. for filter in sorted_filters:
  131. r = None
  132. try:
  133. urlIdx = filter["urlIdx"]
  134. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  135. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  136. if key != "":
  137. headers = {"Authorization": f"Bearer {key}"}
  138. r = requests.post(
  139. f"{url}/{filter['id']}/filter/outlet",
  140. headers=headers,
  141. json={
  142. "user": {
  143. "id": user.id,
  144. "name": user.name,
  145. "email": user.email,
  146. "role": user.role,
  147. },
  148. "body": data,
  149. },
  150. )
  151. r.raise_for_status()
  152. data = r.json()
  153. except Exception as e:
  154. # Handle connection error here
  155. print(f"Connection error: {e}")
  156. if r is not None:
  157. try:
  158. res = r.json()
  159. if "detail" in res:
  160. return JSONResponse(
  161. status_code=r.status_code,
  162. content=res,
  163. )
  164. except Exception:
  165. pass
  166. else:
  167. pass
  168. __event_emitter__ = get_event_emitter(
  169. {
  170. "chat_id": data["chat_id"],
  171. "message_id": data["id"],
  172. "session_id": data["session_id"],
  173. }
  174. )
  175. __event_call__ = get_event_call(
  176. {
  177. "chat_id": data["chat_id"],
  178. "message_id": data["id"],
  179. "session_id": data["session_id"],
  180. }
  181. )
  182. def get_priority(function_id):
  183. function = Functions.get_function_by_id(function_id)
  184. if function is not None and hasattr(function, "valves"):
  185. # TODO: Fix FunctionModel to include vavles
  186. return (function.valves if function.valves else {}).get("priority", 0)
  187. return 0
  188. filter_ids = [function.id for function in Functions.get_global_filter_functions()]
  189. if "info" in model and "meta" in model["info"]:
  190. filter_ids.extend(model["info"]["meta"].get("filterIds", []))
  191. filter_ids = list(set(filter_ids))
  192. enabled_filter_ids = [
  193. function.id
  194. for function in Functions.get_functions_by_type("filter", active_only=True)
  195. ]
  196. filter_ids = [
  197. filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
  198. ]
  199. # Sort filter_ids by priority, using the get_priority function
  200. filter_ids.sort(key=get_priority)
  201. for filter_id in filter_ids:
  202. filter = Functions.get_function_by_id(filter_id)
  203. if not filter:
  204. continue
  205. if filter_id in webui_app.state.FUNCTIONS:
  206. function_module = webui_app.state.FUNCTIONS[filter_id]
  207. else:
  208. function_module, _, _ = load_function_module_by_id(filter_id)
  209. webui_app.state.FUNCTIONS[filter_id] = function_module
  210. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  211. valves = Functions.get_function_valves_by_id(filter_id)
  212. function_module.valves = function_module.Valves(
  213. **(valves if valves else {})
  214. )
  215. if not hasattr(function_module, "outlet"):
  216. continue
  217. try:
  218. outlet = function_module.outlet
  219. # Get the signature of the function
  220. sig = inspect.signature(outlet)
  221. params = {"body": data}
  222. # Extra parameters to be passed to the function
  223. extra_params = {
  224. "__model__": model,
  225. "__id__": filter_id,
  226. "__event_emitter__": __event_emitter__,
  227. "__event_call__": __event_call__,
  228. }
  229. # Add extra params in contained in function signature
  230. for key, value in extra_params.items():
  231. if key in sig.parameters:
  232. params[key] = value
  233. if "__user__" in sig.parameters:
  234. __user__ = {
  235. "id": user.id,
  236. "email": user.email,
  237. "name": user.name,
  238. "role": user.role,
  239. }
  240. try:
  241. if hasattr(function_module, "UserValves"):
  242. __user__["valves"] = function_module.UserValves(
  243. **Functions.get_user_valves_by_id_and_user_id(
  244. filter_id, user.id
  245. )
  246. )
  247. except Exception as e:
  248. print(e)
  249. params = {**params, "__user__": __user__}
  250. if inspect.iscoroutinefunction(outlet):
  251. data = await outlet(**params)
  252. else:
  253. data = outlet(**params)
  254. except Exception as e:
  255. print(f"Error: {e}")
  256. return JSONResponse(
  257. status_code=status.HTTP_400_BAD_REQUEST,
  258. content={"detail": str(e)},
  259. )
  260. return data
  261. @app.post("/api/chat/actions/{action_id}")
  262. async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified_user)):
  263. if "." in action_id:
  264. action_id, sub_action_id = action_id.split(".")
  265. else:
  266. sub_action_id = None
  267. action = Functions.get_function_by_id(action_id)
  268. if not action:
  269. raise HTTPException(
  270. status_code=status.HTTP_404_NOT_FOUND,
  271. detail="Action not found",
  272. )
  273. model_list = await get_all_models()
  274. models = {model["id"]: model for model in model_list}
  275. data = form_data
  276. model_id = data["model"]
  277. if model_id not in models:
  278. raise HTTPException(
  279. status_code=status.HTTP_404_NOT_FOUND,
  280. detail="Model not found",
  281. )
  282. model = models[model_id]
  283. __event_emitter__ = get_event_emitter(
  284. {
  285. "chat_id": data["chat_id"],
  286. "message_id": data["id"],
  287. "session_id": data["session_id"],
  288. }
  289. )
  290. __event_call__ = get_event_call(
  291. {
  292. "chat_id": data["chat_id"],
  293. "message_id": data["id"],
  294. "session_id": data["session_id"],
  295. }
  296. )
  297. if action_id in webui_app.state.FUNCTIONS:
  298. function_module = webui_app.state.FUNCTIONS[action_id]
  299. else:
  300. function_module, _, _ = load_function_module_by_id(action_id)
  301. webui_app.state.FUNCTIONS[action_id] = function_module
  302. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  303. valves = Functions.get_function_valves_by_id(action_id)
  304. function_module.valves = function_module.Valves(**(valves if valves else {}))
  305. if hasattr(function_module, "action"):
  306. try:
  307. action = function_module.action
  308. # Get the signature of the function
  309. sig = inspect.signature(action)
  310. params = {"body": data}
  311. # Extra parameters to be passed to the function
  312. extra_params = {
  313. "__model__": model,
  314. "__id__": sub_action_id if sub_action_id is not None else action_id,
  315. "__event_emitter__": __event_emitter__,
  316. "__event_call__": __event_call__,
  317. }
  318. # Add extra params in contained in function signature
  319. for key, value in extra_params.items():
  320. if key in sig.parameters:
  321. params[key] = value
  322. if "__user__" in sig.parameters:
  323. __user__ = {
  324. "id": user.id,
  325. "email": user.email,
  326. "name": user.name,
  327. "role": user.role,
  328. }
  329. try:
  330. if hasattr(function_module, "UserValves"):
  331. __user__["valves"] = function_module.UserValves(
  332. **Functions.get_user_valves_by_id_and_user_id(
  333. action_id, user.id
  334. )
  335. )
  336. except Exception as e:
  337. print(e)
  338. params = {**params, "__user__": __user__}
  339. if inspect.iscoroutinefunction(action):
  340. data = await action(**params)
  341. else:
  342. data = action(**params)
  343. except Exception as e:
  344. print(f"Error: {e}")
  345. return JSONResponse(
  346. status_code=status.HTTP_400_BAD_REQUEST,
  347. content={"detail": str(e)},
  348. )
  349. return data