code_interpreter.py 7.3 KB


import asyncio
import json
import logging
import uuid
from typing import Optional
from urllib.parse import urlencode

import aiohttp
import websockets
from pydantic import BaseModel

from open_webui.env import SRC_LOG_LEVELS
  10. logger = logging.getLogger(__name__)
  11. logger.setLevel(SRC_LOG_LEVELS["MAIN"])
  12. class ResultModel(BaseModel):
  13. """
  14. Execute Code Result Model
  15. """
  16. stdout: Optional[str] = ""
  17. stderr: Optional[str] = ""
  18. result: Optional[str] = ""
  19. class JupyterCodeExecuter:
  20. """
  21. Execute code in jupyter notebook
  22. """
  23. def __init__(
  24. self,
  25. base_url: str,
  26. code: str,
  27. token: str = "",
  28. password: str = "",
  29. timeout: int = 60,
  30. ):
  31. """
  32. :param base_url: Jupyter server URL (e.g., "http://localhost:8888")
  33. :param code: Code to execute
  34. :param token: Jupyter authentication token (optional)
  35. :param password: Jupyter password (optional)
  36. :param timeout: WebSocket timeout in seconds (default: 60s)
  37. """
  38. self.base_url = base_url.rstrip("/")
  39. self.code = code
  40. self.token = token
  41. self.password = password
  42. self.timeout = timeout
  43. self.kernel_id = ""
  44. self.session = aiohttp.ClientSession(base_url=self.base_url)
  45. self.params = {}
  46. self.result = ResultModel()
  47. async def __aenter__(self):
  48. return self
  49. async def __aexit__(self, exc_type, exc_val, exc_tb):
  50. if self.kernel_id:
  51. try:
  52. async with self.session.delete(
  53. f"/api/kernels/{self.kernel_id}", params=self.params
  54. ) as response:
  55. response.raise_for_status()
  56. except Exception as err:
  57. logger.exception("close kernel failed, %s", err)
  58. await self.session.close()
  59. async def run(self) -> ResultModel:
  60. try:
  61. await self.sign_in()
  62. await self.init_kernel()
  63. await self.execute_code()
  64. except Exception as err:
  65. logger.exception("execute code failed, %s", err)
  66. self.result.stderr = f"Error: {err}"
  67. return self.result
  68. async def sign_in(self) -> None:
  69. # password authentication
  70. if self.password and not self.token:
  71. async with self.session.get("/login") as response:
  72. response.raise_for_status()
  73. xsrf_token = response.cookies["_xsrf"].value
  74. if not xsrf_token:
  75. raise ValueError("_xsrf token not found")
  76. self.session.cookie_jar.update_cookies(response.cookies)
  77. self.session.headers.update({"X-XSRFToken": xsrf_token})
  78. async with self.session.post(
  79. "/login",
  80. data={"_xsrf": xsrf_token, "password": self.password},
  81. allow_redirects=False,
  82. ) as response:
  83. response.raise_for_status()
  84. self.session.cookie_jar.update_cookies(response.cookies)
  85. # token authentication
  86. if self.token:
  87. self.params.update({"token": self.token})
  88. async def init_kernel(self) -> None:
  89. async with self.session.post(
  90. url="/api/kernels", params=self.params
  91. ) as response:
  92. response.raise_for_status()
  93. kernel_data = await response.json()
  94. self.kernel_id = kernel_data["id"]
  95. def init_ws(self) -> (str, dict):
  96. ws_base = self.base_url.replace("http", "ws")
  97. ws_params = "?" + "&".join([f"{key}={val}" for key, val in self.params.items()])
  98. websocket_url = f"{ws_base}/api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ''}"
  99. ws_headers = {}
  100. if self.password and not self.token:
  101. ws_headers = {
  102. "Cookie": "; ".join(
  103. [
  104. f"{cookie.key}={cookie.value}"
  105. for cookie in self.session.cookie_jar
  106. ]
  107. ),
  108. **self.session.headers,
  109. }
  110. return websocket_url, ws_headers
  111. async def execute_code(self) -> None:
  112. # initialize ws
  113. websocket_url, ws_headers = self.init_ws()
  114. # execute
  115. async with websockets.connect(
  116. websocket_url, additional_headers=ws_headers
  117. ) as ws:
  118. await self.execute_in_jupyter(ws)
  119. async def execute_in_jupyter(self, ws) -> None:
  120. # send message
  121. msg_id = uuid.uuid4().hex
  122. await ws.send(
  123. json.dumps(
  124. {
  125. "header": {
  126. "msg_id": msg_id,
  127. "msg_type": "execute_request",
  128. "username": "user",
  129. "session": uuid.uuid4().hex,
  130. "date": "",
  131. "version": "5.3",
  132. },
  133. "parent_header": {},
  134. "metadata": {},
  135. "content": {
  136. "code": self.code,
  137. "silent": False,
  138. "store_history": True,
  139. "user_expressions": {},
  140. "allow_stdin": False,
  141. "stop_on_error": True,
  142. },
  143. "channel": "shell",
  144. }
  145. )
  146. )
  147. # parse message
  148. stdout, stderr, result = "", "", []
  149. while True:
  150. try:
  151. # wait for message
  152. message = await asyncio.wait_for(ws.recv(), self.timeout)
  153. message_data = json.loads(message)
  154. # msg id not match, skip
  155. if message_data.get("parent_header", {}).get("msg_id") != msg_id:
  156. continue
  157. # check message type
  158. msg_type = message_data.get("msg_type")
  159. match msg_type:
  160. case "stream":
  161. if message_data["content"]["name"] == "stdout":
  162. stdout += message_data["content"]["text"]
  163. elif message_data["content"]["name"] == "stderr":
  164. stderr += message_data["content"]["text"]
  165. case "execute_result" | "display_data":
  166. data = message_data["content"]["data"]
  167. if "image/png" in data:
  168. result.append(f"data:image/png;base64,{data['image/png']}")
  169. elif "text/plain" in data:
  170. result.append(data["text/plain"])
  171. case "error":
  172. stderr += "\n".join(message_data["content"]["traceback"])
  173. case "status":
  174. if message_data["content"]["execution_state"] == "idle":
  175. break
  176. except asyncio.TimeoutError:
  177. stderr += "\nExecution timed out."
  178. break
  179. self.result.stdout = stdout.strip()
  180. self.result.stderr = stderr.strip()
  181. self.result.result = "\n".join(result).strip() if result else ""
  182. async def execute_code_jupyter(
  183. base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
  184. ) -> dict:
  185. async with JupyterCodeExecuter(
  186. base_url, code, token, password, timeout
  187. ) as executor:
  188. result = await executor.run()
  189. return result.model_dump()