# chat.py
import time
import logging
import sys

from aiocache import cached
from typing import Any, Optional
import random
import json
import inspect
import uuid
import asyncio

from fastapi import Request, status
from starlette.responses import Response, StreamingResponse, JSONResponse

from open_webui.models.users import UserModel
from open_webui.socket.main import (
    sio,
    get_event_call,
    get_event_emitter,
)
from open_webui.functions import generate_function_chat_completion

from open_webui.routers.openai import (
    generate_chat_completion as generate_openai_chat_completion,
)
from open_webui.routers.ollama import (
    generate_chat_completion as generate_ollama_chat_completion,
)
from open_webui.routers.pipelines import (
    process_pipeline_inlet_filter,
    process_pipeline_outlet_filter,
)

from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.models import get_all_models, check_model_access
from open_webui.utils.payload import convert_payload_openai_to_ollama
from open_webui.utils.response import (
    convert_response_ollama_to_openai,
    convert_streaming_response_ollama_to_openai,
)
from open_webui.utils.filter import (
    get_sorted_filter_ids,
    process_filter_functions,
)

from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL

logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])


async def generate_direct_chat_completion(
    request: Request,
    form_data: dict,
    user: Any,
    models: dict,
):
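    """
    Handle a chat completion against a direct, client-side model connection.

    The request is relayed to the client's session over Socket.IO as a
    "request:chat:completion" event; when streaming, response chunks arrive on
    a per-request channel and are re-emitted as server-sent events.
    """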
    log.info("generate_direct_chat_completion")

    metadata = form_data.pop("metadata", {})

    user_id = metadata.get("user_id")
    session_id = metadata.get("session_id")
    request_id = str(uuid.uuid4())  # Generate a unique request ID

    event_caller = get_event_call(metadata)

    channel = f"{user_id}:{session_id}:{request_id}"

    if form_data.get("stream"):
        q = asyncio.Queue()

        async def message_listener(sid, data):
            """
            Handle received socket messages and push them into the queue.
            """
            await q.put(data)

        # Register the listener
        sio.on(channel, message_listener)

        # Start processing the chat completion in the background
        res = await event_caller(
            {
                "type": "request:chat:completion",
                "data": {
                    "form_data": form_data,
                    "model": models[form_data["model"]],
                    "channel": channel,
                    "session_id": session_id,
                },
            }
        )

        log.info(f"res: {res}")

        if res.get("status", False):
            # Define a generator to stream responses
            async def event_generator():
                nonlocal q
                try:
                    while True:
                        data = await q.get()  # Wait for new messages
                        if isinstance(data, dict):
                            if "done" in data and data["done"]:
                                break  # Stop streaming when 'done' is received
                            yield f"data: {json.dumps(data)}\n\n"
                        elif isinstance(data, str):
                            yield data
                except Exception as e:
                    log.debug(f"Error in event generator: {e}")

            # Define a background task that removes the listener once streaming ends
            async def background():
                try:
                    del sio.handlers["/"][channel]
                except Exception:
                    pass

            # Return the streaming response
            return StreamingResponse(
                event_generator(), media_type="text/event-stream", background=background
            )
        else:
            raise Exception(str(res))
    else:
        res = await event_caller(
            {
                "type": "request:chat:completion",
                "data": {
                    "form_data": form_data,
                    "model": models[form_data["model"]],
                    "channel": channel,
                    "session_id": session_id,
                },
            }
        )

        if "error" in res and res["error"]:
            raise Exception(res["error"])

        return res


async def generate_chat_completion(
    request: Request,
    form_data: dict,
    user: Any,
    bypass_filter: bool = False,
):
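    """
    Route a chat completion to the appropriate backend: a direct client
    connection, an arena model (which delegates to a randomly selected member
    model), a function pipe, an Ollama model (converting the payload and
    response between the OpenAI and Ollama formats), or an OpenAI-compatible
    endpoint.
    """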
    log.debug(f"generate_chat_completion: {form_data}")

    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True

    if hasattr(request.state, "metadata"):
        if "metadata" not in form_data:
            form_data["metadata"] = request.state.metadata
        else:
            form_data["metadata"] = {
                **form_data["metadata"],
                **request.state.metadata,
            }

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
        log.debug(f"direct connection to model: {models}")
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise Exception("Model not found")

    model = models[model_id]

    if getattr(request.state, "direct", False):
        return await generate_direct_chat_completion(
            request, form_data, user=user, models=models
        )
    else:
        # Check whether the user has access to the model
        if not bypass_filter and user.role == "user":
            check_model_access(user, model)
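
        # Arena models delegate to a randomly chosen member model. Purely for
        # illustration, a hypothetical arena model's metadata might look like:
        #
        #   {"info": {"meta": {"model_ids": ["gpt-4o", "llama3:8b"], "filter_mode": "exclude"}}}
        #
        # With filter_mode == "exclude", the listed ids are removed from the
        # pool of all non-arena models; otherwise model_ids itself is the
        # pool, falling back to every non-arena model when unset.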
        if model.get("owned_by") == "arena":
            model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
            filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
            if model_ids and filter_mode == "exclude":
                model_ids = [
                    model["id"]
                    for model in list(request.app.state.MODELS.values())
                    if model.get("owned_by") != "arena" and model["id"] not in model_ids
                ]

            selected_model_id = None
            if isinstance(model_ids, list) and model_ids:
                selected_model_id = random.choice(model_ids)
            else:
                model_ids = [
                    model["id"]
                    for model in list(request.app.state.MODELS.values())
                    if model.get("owned_by") != "arena"
                ]
                selected_model_id = random.choice(model_ids)

            form_data["model"] = selected_model_id

            if form_data.get("stream") == True:

                async def stream_wrapper(stream):
                    yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
                    async for chunk in stream:
                        yield chunk

                response = await generate_chat_completion(
                    request, form_data, user, bypass_filter=True
                )
                return StreamingResponse(
                    stream_wrapper(response.body_iterator),
                    media_type="text/event-stream",
                    background=response.background,
                )
            else:
                return {
                    **(
                        await generate_chat_completion(
                            request, form_data, user, bypass_filter=True
                        )
                    ),
                    "selected_model_id": selected_model_id,
                }

        if model.get("pipe"):
            # This does not require bypass_filter: it is the only route that
            # uses this function, and that route already bypasses the filter.
            return await generate_function_chat_completion(
                request, form_data, user=user, models=models
            )
        if model.get("owned_by") == "ollama":
            # Using /ollama/api/chat endpoint
            form_data = convert_payload_openai_to_ollama(form_data)
            response = await generate_ollama_chat_completion(
                request=request,
                form_data=form_data,
                user=user,
                bypass_filter=bypass_filter,
            )
            if form_data.get("stream"):
                response.headers["content-type"] = "text/event-stream"
                return StreamingResponse(
                    convert_streaming_response_ollama_to_openai(response),
                    headers=dict(response.headers),
                    background=response.background,
                )
            else:
                return convert_response_ollama_to_openai(response)
        else:
            return await generate_openai_chat_completion(
                request=request,
                form_data=form_data,
                user=user,
                bypass_filter=bypass_filter,
            )


chat_completion = generate_chat_completion
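
# Example (sketch only): a FastAPI route delegating to this module. The route
# path, dependency, and payload shape below are illustrative assumptions, not
# definitions made in this file.
#
#   @app.post("/api/chat/completions")
#   async def chat_completion_endpoint(
#       request: Request, form_data: dict, user=Depends(get_verified_user)
#   ):
#       return await chat_completion(request, form_data, user)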


async def chat_completed(request: Request, form_data: dict, user: Any):
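    """
    Run post-completion ("outlet") processing on a finished chat message:
    first the pipeline outlet filter, then any outlet filter functions
    configured for the model.
    """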
    if not request.app.state.MODELS:
        await get_all_models(request, user=user)

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    data = form_data
    model_id = data["model"]
    if model_id not in models:
        raise Exception("Model not found")
    model = models[model_id]

    try:
        data = await process_pipeline_outlet_filter(request, data, user, models)
    except Exception as e:
        raise Exception(f"Error: {e}")

    metadata = {
        "chat_id": data["chat_id"],
        "message_id": data["id"],
        "session_id": data["session_id"],
        "user_id": user.id,
    }

    extra_params = {
        "__event_emitter__": get_event_emitter(metadata),
        "__event_call__": get_event_call(metadata),
        "__user__": {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
        },
        "__metadata__": metadata,
        "__request__": request,
        "__model__": model,
    }
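
    # For reference, an outlet filter function can accept any of the keys
    # above by declaring them in its signature. A minimal sketch (the class
    # and method shape is an assumption for illustration, not defined here):
    #
    #   class Filter:
    #       async def outlet(self, body: dict, __user__: dict, __model__: dict):
    #           return body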

    try:
        filter_functions = [
            Functions.get_function_by_id(filter_id)
            for filter_id in get_sorted_filter_ids(model)
        ]

        result, _ = await process_filter_functions(
            request=request,
            filter_functions=filter_functions,
            filter_type="outlet",
            form_data=data,
            extra_params=extra_params,
        )
        return result
    except Exception as e:
        raise Exception(f"Error: {e}")


async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
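    """
    Execute an action function against a chat message. `action_id` may be a
    dotted "action_id.sub_action_id" pair; the module is loaded (and cached on
    app state), its valves are applied, and only the extra parameters declared
    in the action's signature are passed through.
    """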
    if "." in action_id:
        action_id, sub_action_id = action_id.split(".")
    else:
        sub_action_id = None

    action = Functions.get_function_by_id(action_id)
    if not action:
        raise Exception(f"Action not found: {action_id}")

    if not request.app.state.MODELS:
        await get_all_models(request, user=user)

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    data = form_data
    model_id = data["model"]
    if model_id not in models:
        raise Exception("Model not found")
    model = models[model_id]

    __event_emitter__ = get_event_emitter(
        {
            "chat_id": data["chat_id"],
            "message_id": data["id"],
            "session_id": data["session_id"],
            "user_id": user.id,
        }
    )
    __event_call__ = get_event_call(
        {
            "chat_id": data["chat_id"],
            "message_id": data["id"],
            "session_id": data["session_id"],
            "user_id": user.id,
        }
    )

    if action_id in request.app.state.FUNCTIONS:
        function_module = request.app.state.FUNCTIONS[action_id]
    else:
        function_module, _, _ = load_function_module_by_id(action_id)
        request.app.state.FUNCTIONS[action_id] = function_module

    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
        valves = Functions.get_function_valves_by_id(action_id)
        function_module.valves = function_module.Valves(**(valves if valves else {}))

    if hasattr(function_module, "action"):
        try:
            action = function_module.action

            # Get the signature of the function
            sig = inspect.signature(action)
            params = {"body": data}

            # Extra parameters to be passed to the function
            extra_params = {
                "__model__": model,
                "__id__": sub_action_id if sub_action_id is not None else action_id,
                "__event_emitter__": __event_emitter__,
                "__event_call__": __event_call__,
                "__request__": request,
            }

            # Add the extra params that appear in the function's signature
            for key, value in extra_params.items():
                if key in sig.parameters:
                    params[key] = value

            if "__user__" in sig.parameters:
                __user__ = {
                    "id": user.id,
                    "email": user.email,
                    "name": user.name,
                    "role": user.role,
                }

                try:
                    if hasattr(function_module, "UserValves"):
                        __user__["valves"] = function_module.UserValves(
                            **Functions.get_user_valves_by_id_and_user_id(
                                action_id, user.id
                            )
                        )
                except Exception as e:
                    log.exception(f"Failed to get user valves: {e}")

                params = {**params, "__user__": __user__}

            if inspect.iscoroutinefunction(action):
                data = await action(**params)
            else:
                data = action(**params)
        except Exception as e:
            raise Exception(f"Error: {e}")

    return data
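
# Example (sketch only): invoking an action from a route handler. The action
# id and payload values are hypothetical placeholders; the keys mirror what
# chat_action reads above.
#
#   result = await chat_action(
#       request,
#       "my_action_function",  # hypothetical action id
#       {"model": "gpt-4o", "chat_id": "...", "id": "...", "session_id": "..."},
#       user,
#   )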