# functions.py
  1. import logging
  2. import sys
  3. import inspect
  4. import json
  5. import asyncio
  6. from pydantic import BaseModel
  7. from typing import AsyncGenerator, Generator, Iterator
  8. from fastapi import (
  9. Depends,
  10. FastAPI,
  11. File,
  12. Form,
  13. HTTPException,
  14. Request,
  15. UploadFile,
  16. status,
  17. )
  18. from starlette.responses import Response, StreamingResponse
  19. from open_webui.socket.main import (
  20. get_event_call,
  21. get_event_emitter,
  22. )
  23. from open_webui.models.functions import Functions
  24. from open_webui.models.models import Models
  25. from open_webui.utils.plugin import load_function_module_by_id
  26. from open_webui.utils.tools import get_tools
  27. from open_webui.utils.access_control import has_access
  28. from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
  29. from open_webui.utils.misc import (
  30. add_or_update_system_message,
  31. get_last_user_message,
  32. prepend_to_first_user_message_content,
  33. openai_chat_chunk_message_template,
  34. openai_chat_completion_message_template,
  35. )
  36. from open_webui.utils.payload import (
  37. apply_model_params_to_body_openai,
  38. apply_model_system_prompt_to_body,
  39. )
# Send all log output to stdout and apply the globally configured level;
# this module's own logger is further narrowed to the "MAIN" source level.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
  43. def get_function_module_by_id(request: Request, pipe_id: str):
  44. # Check if function is already loaded
  45. if pipe_id not in request.app.state.FUNCTIONS:
  46. function_module, _, _ = load_function_module_by_id(pipe_id)
  47. request.app.state.FUNCTIONS[pipe_id] = function_module
  48. else:
  49. function_module = request.app.state.FUNCTIONS[pipe_id]
  50. if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
  51. valves = Functions.get_function_valves_by_id(pipe_id)
  52. function_module.valves = function_module.Valves(**(valves if valves else {}))
  53. return function_module
  54. async def get_function_models(request):
  55. pipes = Functions.get_functions_by_type("pipe", active_only=True)
  56. pipe_models = []
  57. for pipe in pipes:
  58. function_module = get_function_module_by_id(request, pipe.id)
  59. # Check if function is a manifold
  60. if hasattr(function_module, "pipes"):
  61. sub_pipes = []
  62. # Handle pipes being a list, sync function, or async function
  63. try:
  64. if callable(function_module.pipes):
  65. if asyncio.iscoroutinefunction(function_module.pipes):
  66. sub_pipes = await function_module.pipes()
  67. else:
  68. sub_pipes = function_module.pipes()
  69. else:
  70. sub_pipes = function_module.pipes
  71. except Exception as e:
  72. log.exception(e)
  73. sub_pipes = []
  74. log.debug(
  75. f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
  76. )
  77. for p in sub_pipes:
  78. sub_pipe_id = f'{pipe.id}.{p["id"]}'
  79. sub_pipe_name = p["name"]
  80. if hasattr(function_module, "name"):
  81. sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
  82. pipe_flag = {"type": pipe.type}
  83. pipe_models.append(
  84. {
  85. "id": sub_pipe_id,
  86. "name": sub_pipe_name,
  87. "object": "model",
  88. "created": pipe.created_at,
  89. "owned_by": "openai",
  90. "pipe": pipe_flag,
  91. }
  92. )
  93. else:
  94. pipe_flag = {"type": "pipe"}
  95. log.debug(
  96. f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
  97. )
  98. pipe_models.append(
  99. {
  100. "id": pipe.id,
  101. "name": pipe.name,
  102. "object": "model",
  103. "created": pipe.created_at,
  104. "owned_by": "openai",
  105. "pipe": pipe_flag,
  106. }
  107. )
  108. return pipe_models
async def generate_function_chat_completion(
    request, form_data, user, models: dict = {}
):
    """Execute a pipe function as an OpenAI-style chat completion.

    Resolves the pipe module referenced by ``form_data["model"]``, builds the
    keyword arguments the pipe's signature declares, then either streams the
    result as SSE chunks (``stream: true``) or returns a single completion
    payload.

    NOTE(review): ``models`` is a mutable default dict; it is only read here,
    but the default should not be relied on to be per-call fresh.
    """

    async def execute_pipe(pipe, params):
        # Await coroutine pipes; call plain callables directly.
        if inspect.iscoroutinefunction(pipe):
            return await pipe(**params)
        else:
            return pipe(**params)

    async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
        # Flatten a pipe result into one string.
        # NOTE(review): falls through and returns None for any other result
        # type — confirm callers only pass str/Generator/AsyncGenerator here.
        if isinstance(res, str):
            return res
        if isinstance(res, Generator):
            return "".join(map(str, res))
        if isinstance(res, AsyncGenerator):
            return "".join([str(stream) async for stream in res])

    def process_line(form_data: dict, line):
        # Normalize one streamed item into an SSE "data: ..." frame.
        if isinstance(line, BaseModel):
            line = line.model_dump_json()
            line = f"data: {line}"
        if isinstance(line, dict):
            line = f"data: {json.dumps(line)}"

        try:
            # Bytes decode to text; str has no .decode, so it passes through.
            line = line.decode("utf-8")
        except Exception:
            pass

        if line.startswith("data:"):
            return f"{line}\n\n"
        else:
            # Plain text: wrap it in an OpenAI chunk envelope.
            line = openai_chat_chunk_message_template(form_data["model"], line)
            return f"data: {json.dumps(line)}\n\n"

    def get_pipe_id(form_data: dict) -> str:
        # "manifold_id.sub_pipe_id" -> "manifold_id"
        pipe_id = form_data["model"]
        if "." in pipe_id:
            pipe_id, _ = pipe_id.split(".", 1)
        return pipe_id

    def get_function_params(function_module, form_data, user, extra_params=None):
        # Build the pipe's kwargs: always "body", plus only those extra params
        # the pipe's signature explicitly declares.
        if extra_params is None:
            extra_params = {}

        pipe_id = get_pipe_id(form_data)

        # Get the signature of the function
        sig = inspect.signature(function_module.pipe)
        params = {"body": form_data} | {
            k: v for k, v in extra_params.items() if k in sig.parameters
        }

        if "__user__" in params and hasattr(function_module, "UserValves"):
            user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
            try:
                params["__user__"]["valves"] = function_module.UserValves(**user_valves)
            except Exception as e:
                log.exception(e)
                # Stored values failed validation — fall back to defaults.
                params["__user__"]["valves"] = function_module.UserValves()

        return params

    model_id = form_data.get("model")
    model_info = Models.get_model_by_id(model_id)

    metadata = form_data.pop("metadata", {})

    files = metadata.get("files", [])
    tool_ids = metadata.get("tool_ids", [])
    # Check if tool_ids is None
    if tool_ids is None:
        tool_ids = []

    __event_emitter__ = None
    __event_call__ = None
    __task__ = None
    __task_body__ = None

    if metadata:
        # Event plumbing requires a fully identified chat session.
        if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
            __event_emitter__ = get_event_emitter(metadata)
            __event_call__ = get_event_call(metadata)

        __task__ = metadata.get("task", None)
        __task_body__ = metadata.get("task_body", None)

    # Optional context the pipe may request by declaring matching parameters.
    extra_params = {
        "__event_emitter__": __event_emitter__,
        "__event_call__": __event_call__,
        "__task__": __task__,
        "__task_body__": __task_body__,
        "__files__": files,
        "__user__": {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
        },
        "__metadata__": metadata,
        "__request__": request,
    }
    extra_params["__tools__"] = get_tools(
        request,
        tool_ids,
        user,
        {
            **extra_params,
            "__model__": models.get(form_data["model"], None),
            "__messages__": form_data["messages"],
            "__files__": files,
        },
    )

    if model_info:
        # Custom models delegate to their configured base pipe model.
        if model_info.base_model_id:
            form_data["model"] = model_info.base_model_id

        params = model_info.params.model_dump()
        form_data = apply_model_params_to_body_openai(params, form_data)
        form_data = apply_model_system_prompt_to_body(params, form_data, metadata, user)

    pipe_id = get_pipe_id(form_data)
    function_module = get_function_module_by_id(request, pipe_id)

    pipe = function_module.pipe
    params = get_function_params(function_module, form_data, user, extra_params)

    if form_data.get("stream", False):

        async def stream_content():
            try:
                res = await execute_pipe(pipe, params)

                # Directly return if the response is a StreamingResponse
                if isinstance(res, StreamingResponse):
                    async for data in res.body_iterator:
                        yield data
                    return
                if isinstance(res, dict):
                    yield f"data: {json.dumps(res)}\n\n"
                    return

            except Exception as e:
                log.error(f"Error: {e}")
                yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
                return

            if isinstance(res, str):
                message = openai_chat_chunk_message_template(form_data["model"], res)
                yield f"data: {json.dumps(message)}\n\n"

            if isinstance(res, Iterator):
                for line in res:
                    yield process_line(form_data, line)

            if isinstance(res, AsyncGenerator):
                async for line in res:
                    yield process_line(form_data, line)

            # str and sync generators get an explicit stop chunk + terminator.
            if isinstance(res, str) or isinstance(res, Generator):
                finish_message = openai_chat_chunk_message_template(
                    form_data["model"], ""
                )
                finish_message["choices"][0]["finish_reason"] = "stop"
                yield f"data: {json.dumps(finish_message)}\n\n"
                yield "data: [DONE]"

        return StreamingResponse(stream_content(), media_type="text/event-stream")
    else:
        try:
            res = await execute_pipe(pipe, params)

        except Exception as e:
            log.error(f"Error: {e}")
            return {"error": {"detail": str(e)}}

        # Pass through ready-made responses; serialize Pydantic models.
        if isinstance(res, StreamingResponse) or isinstance(res, dict):
            return res
        if isinstance(res, BaseModel):
            return res.model_dump()

        message = await get_message_content(res)
        return openai_chat_completion_message_template(form_data["model"], message)