code_interpreter.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. import asyncio
  2. import json
  3. import logging
  4. import uuid
  5. from typing import Optional
  6. import aiohttp
  7. import websockets
  8. from pydantic import BaseModel
  9. from websockets import ClientConnection
  10. from open_webui.env import SRC_LOG_LEVELS
  11. logger = logging.getLogger(__name__)
  12. logger.setLevel(SRC_LOG_LEVELS["MAIN"])
  13. class ResultModel(BaseModel):
  14. """
  15. Execute Code Result Model
  16. """
  17. stdout: Optional[str] = ""
  18. stderr: Optional[str] = ""
  19. result: Optional[str] = ""
  20. class JupyterCodeExecuter:
  21. """
  22. Execute code in jupyter notebook
  23. """
  24. def __init__(self, base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60):
  25. """
  26. :param base_url: Jupyter server URL (e.g., "http://localhost:8888")
  27. :param code: Code to execute
  28. :param token: Jupyter authentication token (optional)
  29. :param password: Jupyter password (optional)
  30. :param timeout: WebSocket timeout in seconds (default: 60s)
  31. """
  32. self.base_url = base_url.rstrip("/")
  33. self.code = code
  34. self.token = token
  35. self.password = password
  36. self.timeout = timeout
  37. self.kernel_id = ""
  38. self.session = aiohttp.ClientSession(base_url=self.base_url)
  39. self.params = {}
  40. self.result = ResultModel()
  41. async def __aenter__(self):
  42. return self
  43. async def __aexit__(self, exc_type, exc_val, exc_tb):
  44. if self.kernel_id:
  45. try:
  46. await self.session.delete(f"/api/kernels/{self.kernel_id}", params=self.params)
  47. except Exception as err:
  48. logger.exception("close kernel failed, %s", err)
  49. await self.session.close()
  50. async def run(self) -> ResultModel:
  51. try:
  52. await self.sign_in()
  53. await self.init_kernel()
  54. await self.execute_code()
  55. except Exception as err:
  56. logger.error(err)
  57. self.result.stderr = f"Error: {err}"
  58. return self.result
  59. async def sign_in(self) -> None:
  60. # password authentication
  61. if self.password and not self.token:
  62. async with self.session.get("/login") as response:
  63. response.raise_for_status()
  64. xsrf_token = response.cookies["_xsrf"].value
  65. if not xsrf_token:
  66. raise ValueError("_xsrf token not found")
  67. self.session.cookie_jar.update_cookies(response.cookies)
  68. self.session.headers.update({"X-XSRFToken": xsrf_token})
  69. async with self.session.post(
  70. "/login", data={"_xsrf": xsrf_token, "password": self.password}, allow_redirects=False
  71. ) as response:
  72. response.raise_for_status()
  73. self.session.cookie_jar.update_cookies(response.cookies)
  74. # token authentication
  75. if self.token:
  76. self.params.update({"token": self.token})
  77. async def init_kernel(self) -> None:
  78. async with self.session.post(url="/api/kernels", params=self.params) as response:
  79. response.raise_for_status()
  80. kernel_data = await response.json()
  81. self.kernel_id = kernel_data["id"]
  82. def init_ws(self) -> (str, dict):
  83. ws_base = self.base_url.replace("http", "ws")
  84. ws_params = "?" + "&".join([f"{key}={val}" for key, val in self.params.items()])
  85. websocket_url = f"{ws_base}/api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ''}"
  86. ws_headers = {}
  87. if self.password and not self.token:
  88. ws_headers = {
  89. "Cookie": "; ".join([f"{cookie.key}={cookie.value}" for cookie in self.session.cookie_jar]),
  90. **self.session.headers,
  91. }
  92. return websocket_url, ws_headers
  93. async def execute_code(self) -> None:
  94. # initialize ws
  95. websocket_url, ws_headers = self.init_ws()
  96. # execute
  97. async with websockets.connect(websocket_url, additional_headers=ws_headers) as ws:
  98. await self.execute_in_jupyter(ws)
  99. async def execute_in_jupyter(self, ws: ClientConnection) -> None:
  100. # send message
  101. msg_id = uuid.uuid4().hex
  102. await ws.send(
  103. json.dumps(
  104. {
  105. "header": {
  106. "msg_id": msg_id,
  107. "msg_type": "execute_request",
  108. "username": "user",
  109. "session": uuid.uuid4().hex,
  110. "date": "",
  111. "version": "5.3",
  112. },
  113. "parent_header": {},
  114. "metadata": {},
  115. "content": {
  116. "code": self.code,
  117. "silent": False,
  118. "store_history": True,
  119. "user_expressions": {},
  120. "allow_stdin": False,
  121. "stop_on_error": True,
  122. },
  123. "channel": "shell",
  124. }
  125. )
  126. )
  127. # parse message
  128. stdout, stderr, result = "", "", []
  129. while True:
  130. try:
  131. # wait for message
  132. message = await asyncio.wait_for(ws.recv(), self.timeout)
  133. message_data = json.loads(message)
  134. # msg id not match, skip
  135. if message_data.get("parent_header", {}).get("msg_id") != msg_id:
  136. continue
  137. # check message type
  138. msg_type = message_data.get("msg_type")
  139. match msg_type:
  140. case "stream":
  141. if message_data["content"]["name"] == "stdout":
  142. stdout += message_data["content"]["text"]
  143. elif message_data["content"]["name"] == "stderr":
  144. stderr += message_data["content"]["text"]
  145. case "execute_result" | "display_data":
  146. data = message_data["content"]["data"]
  147. if "image/png" in data:
  148. result.append(f"data:image/png;base64,{data['image/png']}")
  149. elif "text/plain" in data:
  150. result.append(data["text/plain"])
  151. case "error":
  152. stderr += "\n".join(message_data["content"]["traceback"])
  153. case "status":
  154. if message_data["content"]["execution_state"] == "idle":
  155. break
  156. except asyncio.TimeoutError:
  157. stderr += "\nExecution timed out."
  158. break
  159. self.result.stdout = stdout.strip()
  160. self.result.stderr = stderr.strip()
  161. self.result.result = "\n".join(result).strip() if result else ""
  162. async def execute_code_jupyter(
  163. base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
  164. ) -> dict:
  165. async with JupyterCodeExecuter(base_url, code, token, password, timeout) as executor:
  166. result = await executor.run()
  167. return result.model_dump()