code_interpreter.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
import asyncio
import json
import logging
import uuid
from typing import Optional
from urllib.parse import urlencode

import aiohttp
import websockets
from pydantic import BaseModel
from websockets import ClientConnection

from open_webui.env import SRC_LOG_LEVELS
  11. logger = logging.getLogger(__name__)
  12. logger.setLevel(SRC_LOG_LEVELS["MAIN"])
  13. class ResultModel(BaseModel):
  14. """
  15. Execute Code Result Model
  16. """
  17. stdout: Optional[str] = ""
  18. stderr: Optional[str] = ""
  19. result: Optional[str] = ""
  20. class JupyterCodeExecuter:
  21. """
  22. Execute code in jupyter notebook
  23. """
  24. def __init__(self, base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60):
  25. """
  26. :param base_url: Jupyter server URL (e.g., "http://localhost:8888")
  27. :param code: Code to execute
  28. :param token: Jupyter authentication token (optional)
  29. :param password: Jupyter password (optional)
  30. :param timeout: WebSocket timeout in seconds (default: 60s)
  31. """
  32. self.base_url = base_url.rstrip("/")
  33. self.code = code
  34. self.token = token
  35. self.password = password
  36. self.timeout = timeout
  37. self.kernel_id = ""
  38. self.session = aiohttp.ClientSession(base_url=self.base_url)
  39. self.params = {}
  40. self.result = ResultModel()
  41. async def __aenter__(self):
  42. return self
  43. async def __aexit__(self, exc_type, exc_val, exc_tb):
  44. if self.kernel_id:
  45. try:
  46. async with self.session.delete(f"/api/kernels/{self.kernel_id}", params=self.params) as response:
  47. response.raise_for_status()
  48. except Exception as err:
  49. logger.exception("close kernel failed, %s", err)
  50. await self.session.close()
  51. async def run(self) -> ResultModel:
  52. try:
  53. await self.sign_in()
  54. await self.init_kernel()
  55. await self.execute_code()
  56. except Exception as err:
  57. logger.exception("execute code failed, %s", err)
  58. self.result.stderr = f"Error: {err}"
  59. return self.result
  60. async def sign_in(self) -> None:
  61. # password authentication
  62. if self.password and not self.token:
  63. async with self.session.get("/login") as response:
  64. response.raise_for_status()
  65. xsrf_token = response.cookies["_xsrf"].value
  66. if not xsrf_token:
  67. raise ValueError("_xsrf token not found")
  68. self.session.cookie_jar.update_cookies(response.cookies)
  69. self.session.headers.update({"X-XSRFToken": xsrf_token})
  70. async with self.session.post(
  71. "/login", data={"_xsrf": xsrf_token, "password": self.password}, allow_redirects=False
  72. ) as response:
  73. response.raise_for_status()
  74. self.session.cookie_jar.update_cookies(response.cookies)
  75. # token authentication
  76. if self.token:
  77. self.params.update({"token": self.token})
  78. async def init_kernel(self) -> None:
  79. async with self.session.post(url="/api/kernels", params=self.params) as response:
  80. response.raise_for_status()
  81. kernel_data = await response.json()
  82. self.kernel_id = kernel_data["id"]
  83. def init_ws(self) -> (str, dict):
  84. ws_base = self.base_url.replace("http", "ws")
  85. ws_params = "?" + "&".join([f"{key}={val}" for key, val in self.params.items()])
  86. websocket_url = f"{ws_base}/api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ''}"
  87. ws_headers = {}
  88. if self.password and not self.token:
  89. ws_headers = {
  90. "Cookie": "; ".join([f"{cookie.key}={cookie.value}" for cookie in self.session.cookie_jar]),
  91. **self.session.headers,
  92. }
  93. return websocket_url, ws_headers
  94. async def execute_code(self) -> None:
  95. # initialize ws
  96. websocket_url, ws_headers = self.init_ws()
  97. # execute
  98. async with websockets.connect(websocket_url, additional_headers=ws_headers) as ws:
  99. await self.execute_in_jupyter(ws)
  100. async def execute_in_jupyter(self, ws: ClientConnection) -> None:
  101. # send message
  102. msg_id = uuid.uuid4().hex
  103. await ws.send(
  104. json.dumps(
  105. {
  106. "header": {
  107. "msg_id": msg_id,
  108. "msg_type": "execute_request",
  109. "username": "user",
  110. "session": uuid.uuid4().hex,
  111. "date": "",
  112. "version": "5.3",
  113. },
  114. "parent_header": {},
  115. "metadata": {},
  116. "content": {
  117. "code": self.code,
  118. "silent": False,
  119. "store_history": True,
  120. "user_expressions": {},
  121. "allow_stdin": False,
  122. "stop_on_error": True,
  123. },
  124. "channel": "shell",
  125. }
  126. )
  127. )
  128. # parse message
  129. stdout, stderr, result = "", "", []
  130. while True:
  131. try:
  132. # wait for message
  133. message = await asyncio.wait_for(ws.recv(), self.timeout)
  134. message_data = json.loads(message)
  135. # msg id not match, skip
  136. if message_data.get("parent_header", {}).get("msg_id") != msg_id:
  137. continue
  138. # check message type
  139. msg_type = message_data.get("msg_type")
  140. match msg_type:
  141. case "stream":
  142. if message_data["content"]["name"] == "stdout":
  143. stdout += message_data["content"]["text"]
  144. elif message_data["content"]["name"] == "stderr":
  145. stderr += message_data["content"]["text"]
  146. case "execute_result" | "display_data":
  147. data = message_data["content"]["data"]
  148. if "image/png" in data:
  149. result.append(f"data:image/png;base64,{data['image/png']}")
  150. elif "text/plain" in data:
  151. result.append(data["text/plain"])
  152. case "error":
  153. stderr += "\n".join(message_data["content"]["traceback"])
  154. case "status":
  155. if message_data["content"]["execution_state"] == "idle":
  156. break
  157. except asyncio.TimeoutError:
  158. stderr += "\nExecution timed out."
  159. break
  160. self.result.stdout = stdout.strip()
  161. self.result.stderr = stderr.strip()
  162. self.result.result = "\n".join(result).strip() if result else ""
  163. async def execute_code_jupyter(
  164. base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
  165. ) -> dict:
  166. async with JupyterCodeExecuter(base_url, code, token, password, timeout) as executor:
  167. result = await executor.run()
  168. return result.model_dump()