# code_interpreter.py
  1. import asyncio
  2. import json
  3. import logging
  4. import uuid
  5. from typing import Optional
  6. import aiohttp
  7. import websockets
  8. from pydantic import BaseModel
  9. from open_webui.env import SRC_LOG_LEVELS
  10. logger = logging.getLogger(__name__)
  11. logger.setLevel(SRC_LOG_LEVELS["MAIN"])
  12. class ResultModel(BaseModel):
  13. """
  14. Execute Code Result Model
  15. """
  16. stdout: Optional[str] = ""
  17. stderr: Optional[str] = ""
  18. result: Optional[str] = ""
  19. class JupyterCodeExecuter:
  20. """
  21. Execute code in jupyter notebook
  22. """
  23. def __init__(
  24. self,
  25. base_url: str,
  26. code: str,
  27. token: str = "",
  28. password: str = "",
  29. timeout: int = 60,
  30. ):
  31. """
  32. :param base_url: Jupyter server URL (e.g., "http://localhost:8888")
  33. :param code: Code to execute
  34. :param token: Jupyter authentication token (optional)
  35. :param password: Jupyter password (optional)
  36. :param timeout: WebSocket timeout in seconds (default: 60s)
  37. """
  38. self.base_url = base_url.rstrip("/")
  39. self.code = code
  40. self.token = token
  41. self.password = password
  42. self.timeout = timeout
  43. self.kernel_id = ""
  44. self.session = aiohttp.ClientSession(base_url=self.base_url)
  45. self.params = {}
  46. self.result = ResultModel()
  47. async def __aenter__(self):
  48. return self
  49. async def __aexit__(self, exc_type, exc_val, exc_tb):
  50. if self.kernel_id:
  51. try:
  52. async with self.session.delete(
  53. f"/api/kernels/{self.kernel_id}", params=self.params
  54. ) as response:
  55. response.raise_for_status()
  56. except Exception as err:
  57. logger.exception("close kernel failed, %s", err)
  58. await self.session.close()
  59. async def run(self) -> ResultModel:
  60. try:
  61. await self.sign_in()
  62. await self.init_kernel()
  63. await self.execute_code()
  64. except Exception as err:
  65. logger.exception("execute code failed, %s", err)
  66. self.result.stderr = f"Error: {err}"
  67. return self.result
  68. async def sign_in(self) -> None:
  69. # password authentication
  70. if self.password and not self.token:
  71. async with self.session.get("/login") as response:
  72. response.raise_for_status()
  73. xsrf_token = response.cookies["_xsrf"].value
  74. if not xsrf_token:
  75. raise ValueError("_xsrf token not found")
  76. self.session.cookie_jar.update_cookies(response.cookies)
  77. self.session.headers.update({"X-XSRFToken": xsrf_token})
  78. async with self.session.post(
  79. "/login",
  80. data={"_xsrf": xsrf_token, "password": self.password},
  81. allow_redirects=False,
  82. ) as response:
  83. response.raise_for_status()
  84. self.session.cookie_jar.update_cookies(response.cookies)
  85. # token authentication
  86. if self.token:
  87. self.params.update({"token": self.token})
  88. async def init_kernel(self) -> None:
  89. async with self.session.post(
  90. url="/api/kernels", params=self.params
  91. ) as response:
  92. response.raise_for_status()
  93. kernel_data = await response.json()
  94. self.kernel_id = kernel_data["id"]
  95. def init_ws(self) -> (str, dict):
  96. ws_base = self.base_url.replace("http", "ws")
  97. ws_params = "?" + "&".join([f"{key}={val}" for key, val in self.params.items()])
  98. websocket_url = f"{ws_base}/api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ''}"
  99. ws_headers = {}
  100. if self.password and not self.token:
  101. ws_headers = {
  102. "Cookie": "; ".join(
  103. [
  104. f"{cookie.key}={cookie.value}"
  105. for cookie in self.session.cookie_jar
  106. ]
  107. ),
  108. **self.session.headers,
  109. }
  110. return websocket_url, ws_headers
  111. async def execute_code(self) -> None:
  112. # initialize ws
  113. websocket_url, ws_headers = self.init_ws()
  114. # execute
  115. async with websockets.connect(
  116. websocket_url, additional_headers=ws_headers
  117. ) as ws:
  118. await self.execute_in_jupyter(ws)
  119. async def execute_in_jupyter(self, ws) -> None:
  120. # send message
  121. msg_id = uuid.uuid4().hex
  122. await ws.send(
  123. json.dumps(
  124. {
  125. "header": {
  126. "msg_id": msg_id,
  127. "msg_type": "execute_request",
  128. "username": "user",
  129. "session": uuid.uuid4().hex,
  130. "date": "",
  131. "version": "5.3",
  132. },
  133. "parent_header": {},
  134. "metadata": {},
  135. "content": {
  136. "code": self.code,
  137. "silent": False,
  138. "store_history": True,
  139. "user_expressions": {},
  140. "allow_stdin": False,
  141. "stop_on_error": True,
  142. },
  143. "channel": "shell",
  144. }
  145. )
  146. )
  147. # parse message
  148. stdout, stderr, result = "", "", []
  149. while True:
  150. try:
  151. # wait for message
  152. message = await asyncio.wait_for(ws.recv(), self.timeout)
  153. message_data = json.loads(message)
  154. # msg id not match, skip
  155. if message_data.get("parent_header", {}).get("msg_id") != msg_id:
  156. continue
  157. # check message type
  158. msg_type = message_data.get("msg_type")
  159. match msg_type:
  160. case "stream":
  161. if message_data["content"]["name"] == "stdout":
  162. stdout += message_data["content"]["text"]
  163. elif message_data["content"]["name"] == "stderr":
  164. stderr += message_data["content"]["text"]
  165. case "execute_result" | "display_data":
  166. data = message_data["content"]["data"]
  167. if "image/png" in data:
  168. result.append(f"data:image/png;base64,{data['image/png']}")
  169. elif "text/plain" in data:
  170. result.append(data["text/plain"])
  171. case "error":
  172. stderr += "\n".join(message_data["content"]["traceback"])
  173. case "status":
  174. if message_data["content"]["execution_state"] == "idle":
  175. break
  176. except asyncio.TimeoutError:
  177. stderr += "\nExecution timed out."
  178. break
  179. self.result.stdout = stdout.strip()
  180. self.result.stderr = stderr.strip()
  181. self.result.result = "\n".join(result).strip() if result else ""
  182. async def execute_code_jupyter(
  183. base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
  184. ) -> dict:
  185. async with JupyterCodeExecuter(
  186. base_url, code, token, password, timeout
  187. ) as executor:
  188. result = await executor.run()
  189. return result.model_dump()