functions.py

import logging
import sys
import inspect
import json

from pydantic import BaseModel
from typing import AsyncGenerator, Generator, Iterator

from fastapi import (
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    Request,
    UploadFile,
    status,
)
from starlette.responses import Response, StreamingResponse

from open_webui.socket.main import (
    get_event_call,
    get_event_emitter,
)

from open_webui.models.functions import Functions
from open_webui.models.models import Models

from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.tools import get_tools
from open_webui.utils.access_control import has_access

from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL

from open_webui.utils.misc import (
    add_or_update_system_message,
    get_last_user_message,
    prepend_to_first_user_message_content,
    openai_chat_chunk_message_template,
    openai_chat_completion_message_template,
)
from open_webui.utils.payload import (
    apply_model_params_to_body_openai,
    apply_model_system_prompt_to_body,
)

logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
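

# Loaded function modules are cached on request.app.state.FUNCTIONS; valves are
# refreshed from the Functions store on every lookup, so configuration changes
# take effect without reloading the module.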
def get_function_module_by_id(request: Request, pipe_id: str):
    # Check if function is already loaded
    if pipe_id not in request.app.state.FUNCTIONS:
        function_module, _, _ = load_function_module_by_id(pipe_id)
        request.app.state.FUNCTIONS[pipe_id] = function_module
    else:
        function_module = request.app.state.FUNCTIONS[pipe_id]

    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
        valves = Functions.get_function_valves_by_id(pipe_id)
        function_module.valves = function_module.Valves(**(valves if valves else {}))

    return function_module
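

# A "manifold" pipe exposes several sub-models via a `pipes` attribute, which
# may be a list or a zero-argument callable returning one. Each entry is
# expected to be a dict with at least "id" and "name" keys, e.g.
# [{"id": "gpt-4", "name": "GPT-4"}] (values here are illustrative only).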
async def get_function_models(request):
    pipes = Functions.get_functions_by_type("pipe", active_only=True)
    pipe_models = []

    for pipe in pipes:
        function_module = get_function_module_by_id(request, pipe.id)

        # Check if function is a manifold
        if hasattr(function_module, "pipes"):
            sub_pipes = []

            # Check if pipes is a function or a list
            try:
                if callable(function_module.pipes):
                    sub_pipes = function_module.pipes()
                else:
                    sub_pipes = function_module.pipes
            except Exception as e:
                log.exception(e)
                sub_pipes = []

            log.debug(
                f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
            )

            for p in sub_pipes:
                sub_pipe_id = f'{pipe.id}.{p["id"]}'
                sub_pipe_name = p["name"]

                if hasattr(function_module, "name"):
                    sub_pipe_name = f"{function_module.name}{sub_pipe_name}"

                pipe_flag = {"type": pipe.type}

                pipe_models.append(
                    {
                        "id": sub_pipe_id,
                        "name": sub_pipe_name,
                        "object": "model",
                        "created": pipe.created_at,
                        "owned_by": "openai",
                        "pipe": pipe_flag,
                    }
                )
        else:
            pipe_flag = {"type": "pipe"}

            log.debug(
                f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
            )

            pipe_models.append(
                {
                    "id": pipe.id,
                    "name": pipe.name,
                    "object": "model",
                    "created": pipe.created_at,
                    "owned_by": "openai",
                    "pipe": pipe_flag,
                }
            )

    return pipe_models
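

# Entry point for chat completions served by a pipe function. The request body
# follows the OpenAI chat-completions shape; the pipe's return value (str,
# dict, BaseModel, generator, async generator, or StreamingResponse) is
# normalized into either an SSE stream or a single completion object.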
async def generate_function_chat_completion(
    request, form_data, user, models: dict = {}
):
    async def execute_pipe(pipe, params):
        if inspect.iscoroutinefunction(pipe):
            return await pipe(**params)
        else:
            return pipe(**params)

    async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
        if isinstance(res, str):
            return res
        if isinstance(res, Generator):
            return "".join(map(str, res))
        if isinstance(res, AsyncGenerator):
            return "".join([str(stream) async for stream in res])
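
    # Normalize a single streamed item into an SSE "data:" line. Items may be
    # Pydantic models, dicts, bytes, or plain strings; strings that are not
    # already SSE-framed get wrapped in an OpenAI-style chunk first.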
    def process_line(form_data: dict, line):
        if isinstance(line, BaseModel):
            line = line.model_dump_json()
            line = f"data: {line}"
        if isinstance(line, dict):
            line = f"data: {json.dumps(line)}"

        try:
            line = line.decode("utf-8")
        except Exception:
            pass

        if line.startswith("data:"):
            return f"{line}\n\n"
        else:
            line = openai_chat_chunk_message_template(form_data["model"], line)
            return f"data: {json.dumps(line)}\n\n"

    def get_pipe_id(form_data: dict) -> str:
        pipe_id = form_data["model"]
        if "." in pipe_id:
            pipe_id, _ = pipe_id.split(".", 1)
        return pipe_id
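
    # Build the keyword arguments for the pipe call: the request body is always
    # passed as `body`, and each extra parameter (e.g. `__user__`, `__tools__`)
    # is forwarded only if the pipe's signature declares it.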
    def get_function_params(function_module, form_data, user, extra_params=None):
        if extra_params is None:
            extra_params = {}

        pipe_id = get_pipe_id(form_data)

        # Get the signature of the function
        sig = inspect.signature(function_module.pipe)

        params = {"body": form_data} | {
            k: v for k, v in extra_params.items() if k in sig.parameters
        }

        if "__user__" in params and hasattr(function_module, "UserValves"):
            user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
            try:
                params["__user__"]["valves"] = function_module.UserValves(**user_valves)
            except Exception as e:
                log.exception(e)
                params["__user__"]["valves"] = function_module.UserValves()

        return params
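
    # Resolve the target model and pull request metadata (attached files, tool
    # ids, and session identifiers for event emitters) before building the
    # pipe's parameters.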
    model_id = form_data.get("model")
    model_info = Models.get_model_by_id(model_id)

    metadata = form_data.pop("metadata", {})

    files = metadata.get("files", [])
    tool_ids = metadata.get("tool_ids", [])

    # Check if tool_ids is None
    if tool_ids is None:
        tool_ids = []

    __event_emitter__ = None
    __event_call__ = None
    __task__ = None
    __task_body__ = None

    if metadata:
        if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
            __event_emitter__ = get_event_emitter(metadata)
            __event_call__ = get_event_call(metadata)

        __task__ = metadata.get("task", None)
        __task_body__ = metadata.get("task_body", None)

    extra_params = {
        "__event_emitter__": __event_emitter__,
        "__event_call__": __event_call__,
        "__task__": __task__,
        "__task_body__": __task_body__,
        "__files__": files,
        "__user__": {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
        },
        "__metadata__": metadata,
        "__request__": request,
    }
    extra_params["__tools__"] = get_tools(
        request,
        tool_ids,
        user,
        {
            **extra_params,
            "__model__": models.get(form_data["model"], None),
            "__messages__": form_data["messages"],
            "__files__": files,
        },
    )

    if model_info:
        if model_info.base_model_id:
            form_data["model"] = model_info.base_model_id

        params = model_info.params.model_dump()
        form_data = apply_model_params_to_body_openai(params, form_data)
        form_data = apply_model_system_prompt_to_body(params, form_data, user)

    pipe_id = get_pipe_id(form_data)
    function_module = get_function_module_by_id(request, pipe_id)

    pipe = function_module.pipe
    params = get_function_params(function_module, form_data, user, extra_params)
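
    # Streaming requests are answered as OpenAI-style SSE chunks; otherwise the
    # pipe's output is collected into a single chat completion response below.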
    if form_data.get("stream", False):

        async def stream_content():
            try:
                res = await execute_pipe(pipe, params)

                # Directly return if the response is a StreamingResponse
                if isinstance(res, StreamingResponse):
                    async for data in res.body_iterator:
                        yield data
                    return
                if isinstance(res, dict):
                    yield f"data: {json.dumps(res)}\n\n"
                    return

            except Exception as e:
                log.error(f"Error: {e}")
                yield f"data: {json.dumps({'error': {'detail': str(e)}})}\n\n"
                return

            if isinstance(res, str):
                message = openai_chat_chunk_message_template(form_data["model"], res)
                yield f"data: {json.dumps(message)}\n\n"

            if isinstance(res, Iterator):
                for line in res:
                    yield process_line(form_data, line)

            if isinstance(res, AsyncGenerator):
                async for line in res:
                    yield process_line(form_data, line)

            if isinstance(res, str) or isinstance(res, Generator):
                finish_message = openai_chat_chunk_message_template(
                    form_data["model"], ""
                )
                finish_message["choices"][0]["finish_reason"] = "stop"
                yield f"data: {json.dumps(finish_message)}\n\n"
                yield "data: [DONE]"

        return StreamingResponse(stream_content(), media_type="text/event-stream")
    else:
        try:
            res = await execute_pipe(pipe, params)
        except Exception as e:
            log.error(f"Error: {e}")
            return {"error": {"detail": str(e)}}

        if isinstance(res, StreamingResponse) or isinstance(res, dict):
            return res
        if isinstance(res, BaseModel):
            return res.model_dump()

        message = await get_message_content(res)
        return openai_chat_completion_message_template(form_data["model"], message)
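

# For orientation, a minimal pipe module that this file can execute might look
# like the sketch below. This is an illustrative assumption, not part of this
# module: only the `pipe` callable and the optional `Valves`/`UserValves`/
# `pipes` attributes are relied on above.
#
#     class Pipe:
#         def pipe(self, body: dict) -> str:
#             # May also return a dict, BaseModel, (async) generator, or
#             # StreamingResponse; all of these are handled above.
#             return "Hello from a pipe"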