|
@@ -1,129 +1,185 @@
|
|
|
import asyncio
|
|
|
import json
|
|
|
+import logging
|
|
|
import uuid
|
|
|
from typing import Optional
|
|
|
|
|
|
-import httpx
|
|
|
+import aiohttp
|
|
|
import websockets
|
|
|
+from pydantic import BaseModel
|
|
|
+from websockets import ClientConnection
|
|
|
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
|
|
|
-async def execute_code_jupyter(
|
|
|
- jupyter_url: str, code: str, token: str = None, password: str = None, timeout: int = 60
|
|
|
-) -> Optional[dict]:
|
|
|
+
|
|
|
+class ResultModel(BaseModel):
|
|
|
"""
|
|
|
- Executes Python code in a Jupyter kernel.
|
|
|
- Supports authentication with a token or password.
|
|
|
- :param jupyter_url: Jupyter server URL (e.g., "http://localhost:8888")
|
|
|
- :param code: Code to execute
|
|
|
- :param token: Jupyter authentication token (optional)
|
|
|
- :param password: Jupyter password (optional)
|
|
|
- :param timeout: WebSocket timeout in seconds (default: 10s)
|
|
|
- :return: Dictionary with stdout, stderr, and result
|
|
|
- - Images are prefixed with "base64:image/png," and separated by newlines if multiple.
|
|
|
+ Execute Code Result Model
|
|
|
"""
|
|
|
|
|
|
- jupyter_url = jupyter_url.rstrip("/")
|
|
|
- client = httpx.AsyncClient(base_url=jupyter_url, timeout=timeout, follow_redirects=True)
|
|
|
- headers = {}
|
|
|
+ stdout: Optional[str] = ""
|
|
|
+ stderr: Optional[str] = ""
|
|
|
+ result: Optional[str] = ""
|
|
|
|
|
|
- # password authentication
|
|
|
- if password and not token:
|
|
|
- try:
|
|
|
- response = await client.get("/login")
|
|
|
- response.raise_for_status()
|
|
|
- xsrf_token = response.cookies.get("_xsrf")
|
|
|
- if not xsrf_token:
|
|
|
- raise ValueError("_xsrf token not found")
|
|
|
- response = await client.post("/login", data={"_xsrf": xsrf_token, "password": password})
|
|
|
- response.raise_for_status()
|
|
|
- headers["X-XSRFToken"] = xsrf_token
|
|
|
- except Exception as e:
|
|
|
- return {"stdout": "", "stderr": f"Authentication Error: {str(e)}", "result": ""}
|
|
|
|
|
|
- # token authentication
|
|
|
- params = {"token": token} if token else {}
|
|
|
+class JupyterCodeExecuter:
|
|
|
+ """
|
|
|
+ Execute code in jupyter notebook
|
|
|
+ """
|
|
|
|
|
|
- kernel_id = ""
|
|
|
- try:
|
|
|
- response = await client.post(url="/api/kernels", params=params, headers=headers)
|
|
|
- response.raise_for_status()
|
|
|
- kernel_id = response.json()["id"]
|
|
|
+ def __init__(self, base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60):
|
|
|
+ """
|
|
|
+ :param base_url: Jupyter server URL (e.g., "http://localhost:8888")
|
|
|
+ :param code: Code to execute
|
|
|
+ :param token: Jupyter authentication token (optional)
|
|
|
+ :param password: Jupyter password (optional)
|
|
|
+ :param timeout: WebSocket timeout in seconds (default: 60s)
|
|
|
+ """
|
|
|
+ self.base_url = base_url.rstrip("/")
|
|
|
+ self.code = code
|
|
|
+ self.token = token
|
|
|
+ self.password = password
|
|
|
+ self.timeout = timeout
|
|
|
+ self.kernel_id = ""
|
|
|
+ self.session = aiohttp.ClientSession(base_url=self.base_url)
|
|
|
+ self.params = {}
|
|
|
+ self.result = ResultModel()
|
|
|
+
|
|
|
+ async def __aenter__(self):
|
|
|
+ return self
|
|
|
+
|
|
|
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
|
+ if self.kernel_id:
|
|
|
+ try:
|
|
|
+ await self.session.delete(f"/api/kernels/{self.kernel_id}", params=self.params)
|
|
|
+ except Exception as err:
|
|
|
+ logger.exception("close kernel failed, %s", err)
|
|
|
+ await self.session.close()
|
|
|
+
|
|
|
+ async def run(self) -> ResultModel:
|
|
|
+ try:
|
|
|
+ await self.sign_in()
|
|
|
+ await self.init_kernel()
|
|
|
+ await self.execute_code()
|
|
|
+ except Exception as err:
|
|
|
+ logger.error(err)
|
|
|
+ self.result.stderr = f"Error: {err}"
|
|
|
+ return self.result
|
|
|
+
|
|
|
+ async def sign_in(self) -> None:
|
|
|
+ # password authentication
|
|
|
+ if self.password and not self.token:
|
|
|
+ async with self.session.get("/login") as response:
|
|
|
+ response.raise_for_status()
|
|
|
+ xsrf_token = response.cookies["_xsrf"].value
|
|
|
+ if not xsrf_token:
|
|
|
+ raise ValueError("_xsrf token not found")
|
|
|
+ self.session.cookie_jar.update_cookies(response.cookies)
|
|
|
+ self.session.headers.update({"X-XSRFToken": xsrf_token})
|
|
|
+ async with self.session.post(
|
|
|
+ "/login", data={"_xsrf": xsrf_token, "password": self.password}, allow_redirects=False
|
|
|
+ ) as response:
|
|
|
+ response.raise_for_status()
|
|
|
+ self.session.cookie_jar.update_cookies(response.cookies)
|
|
|
+
|
|
|
+ # token authentication
|
|
|
+ if self.token:
|
|
|
+ self.params.update({"token": self.token})
|
|
|
+
|
|
|
+ async def init_kernel(self) -> None:
|
|
|
+ async with self.session.post(url="/api/kernels", params=self.params) as response:
|
|
|
+ response.raise_for_status()
|
|
|
+ kernel_data = await response.json()
|
|
|
+ self.kernel_id = kernel_data["id"]
|
|
|
|
|
|
- ws_base = jupyter_url.replace("http", "ws")
|
|
|
- websocket_url = f"{ws_base}/api/kernels/{kernel_id}/channels" + (f"?token={token}" if token else "")
|
|
|
+ def init_ws(self) -> (str, dict):
|
|
|
+ ws_base = self.base_url.replace("http", "ws")
|
|
|
+ ws_params = "?" + "&".join([f"{key}={val}" for key, val in self.params.items()])
|
|
|
+ websocket_url = f"{ws_base}/api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ''}"
|
|
|
ws_headers = {}
|
|
|
- if password and not token:
|
|
|
+ if self.password and not self.token:
|
|
|
ws_headers = {
|
|
|
- "X-XSRFToken": client.cookies.get("_xsrf"),
|
|
|
- "Cookie": "; ".join([f"{name}={value}" for name, value in client.cookies.items()]),
|
|
|
+ "Cookie": "; ".join([f"{cookie.key}={cookie.value}" for cookie in self.session.cookie_jar]),
|
|
|
+ **self.session.headers,
|
|
|
}
|
|
|
+ return websocket_url, ws_headers
|
|
|
|
|
|
+ async def execute_code(self) -> None:
|
|
|
+ # initialize ws
|
|
|
+ websocket_url, ws_headers = self.init_ws()
|
|
|
+ # execute
|
|
|
async with websockets.connect(websocket_url, additional_headers=ws_headers) as ws:
|
|
|
- msg_id = str(uuid.uuid4())
|
|
|
- await ws.send(
|
|
|
- json.dumps(
|
|
|
- {
|
|
|
- "header": {
|
|
|
- "msg_id": msg_id,
|
|
|
- "msg_type": "execute_request",
|
|
|
- "username": "user",
|
|
|
- "session": str(uuid.uuid4()),
|
|
|
- "date": "",
|
|
|
- "version": "5.3",
|
|
|
- },
|
|
|
- "parent_header": {},
|
|
|
- "metadata": {},
|
|
|
- "content": {
|
|
|
- "code": code,
|
|
|
- "silent": False,
|
|
|
- "store_history": True,
|
|
|
- "user_expressions": {},
|
|
|
- "allow_stdin": False,
|
|
|
- "stop_on_error": True,
|
|
|
- },
|
|
|
- "channel": "shell",
|
|
|
- }
|
|
|
- )
|
|
|
+ await self.execute_in_jupyter(ws)
|
|
|
+
|
|
|
+ async def execute_in_jupyter(self, ws: ClientConnection) -> None:
|
|
|
+ # send message
|
|
|
+ msg_id = uuid.uuid4().hex
|
|
|
+ await ws.send(
|
|
|
+ json.dumps(
|
|
|
+ {
|
|
|
+ "header": {
|
|
|
+ "msg_id": msg_id,
|
|
|
+ "msg_type": "execute_request",
|
|
|
+ "username": "user",
|
|
|
+ "session": uuid.uuid4().hex,
|
|
|
+ "date": "",
|
|
|
+ "version": "5.3",
|
|
|
+ },
|
|
|
+ "parent_header": {},
|
|
|
+ "metadata": {},
|
|
|
+ "content": {
|
|
|
+ "code": self.code,
|
|
|
+ "silent": False,
|
|
|
+ "store_history": True,
|
|
|
+ "user_expressions": {},
|
|
|
+ "allow_stdin": False,
|
|
|
+ "stop_on_error": True,
|
|
|
+ },
|
|
|
+ "channel": "shell",
|
|
|
+ }
|
|
|
)
|
|
|
+ )
|
|
|
+ # parse message
|
|
|
+ stdout, stderr, result = "", "", []
|
|
|
+ while True:
|
|
|
+ try:
|
|
|
+ # wait for message
|
|
|
+ message = await asyncio.wait_for(ws.recv(), self.timeout)
|
|
|
+ message_data = json.loads(message)
|
|
|
+ # msg id not match, skip
|
|
|
+ if message_data.get("parent_header", {}).get("msg_id") != msg_id:
|
|
|
+ continue
|
|
|
+ # check message type
|
|
|
+ msg_type = message_data.get("msg_type")
|
|
|
+ match msg_type:
|
|
|
+ case "stream":
|
|
|
+ if message_data["content"]["name"] == "stdout":
|
|
|
+ stdout += message_data["content"]["text"]
|
|
|
+ elif message_data["content"]["name"] == "stderr":
|
|
|
+ stderr += message_data["content"]["text"]
|
|
|
+ case "execute_result" | "display_data":
|
|
|
+ data = message_data["content"]["data"]
|
|
|
+ if "image/png" in data:
|
|
|
+ result.append(f"data:image/png;base64,{data['image/png']}")
|
|
|
+ elif "text/plain" in data:
|
|
|
+ result.append(data["text/plain"])
|
|
|
+ case "error":
|
|
|
+ stderr += "\n".join(message_data["content"]["traceback"])
|
|
|
+ case "status":
|
|
|
+ if message_data["content"]["execution_state"] == "idle":
|
|
|
+ break
|
|
|
+
|
|
|
+ except asyncio.TimeoutError:
|
|
|
+ stderr += "\nExecution timed out."
|
|
|
+ break
|
|
|
+ self.result.stdout = stdout.strip()
|
|
|
+ self.result.stderr = stderr.strip()
|
|
|
+ self.result.result = "\n".join(result).strip() if result else ""
|
|
|
|
|
|
- stdout, stderr, result = "", "", []
|
|
|
-
|
|
|
- while True:
|
|
|
- try:
|
|
|
- message = await asyncio.wait_for(ws.recv(), timeout)
|
|
|
- message_data = json.loads(message)
|
|
|
- if message_data.get("parent_header", {}).get("msg_id") != msg_id:
|
|
|
- continue
|
|
|
-
|
|
|
- msg_type = message_data.get("msg_type")
|
|
|
- match msg_type:
|
|
|
- case "stream":
|
|
|
- if message_data["content"]["name"] == "stdout":
|
|
|
- stdout += message_data["content"]["text"]
|
|
|
- elif message_data["content"]["name"] == "stderr":
|
|
|
- stderr += message_data["content"]["text"]
|
|
|
- case "execute_result" | "display_data":
|
|
|
- data = message_data["content"]["data"]
|
|
|
- if "image/png" in data:
|
|
|
- result.append(f"data:image/png;base64,{data['image/png']}")
|
|
|
- elif "text/plain" in data:
|
|
|
- result.append(data["text/plain"])
|
|
|
- case "error":
|
|
|
- stderr += "\n".join(message_data["content"]["traceback"])
|
|
|
- case "status":
|
|
|
- if message_data["content"]["execution_state"] == "idle":
|
|
|
- break
|
|
|
-
|
|
|
- except asyncio.TimeoutError:
|
|
|
- stderr += "\nExecution timed out."
|
|
|
- break
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- return {"stdout": "", "stderr": f"Error: {str(e)}", "result": ""}
|
|
|
-
|
|
|
- finally:
|
|
|
- if kernel_id:
|
|
|
- await client.delete(f"/api/kernels/{kernel_id}", headers=headers, params=params)
|
|
|
- await client.aclose()
|
|
|
-
|
|
|
- return {"stdout": stdout.strip(), "stderr": stderr.strip(), "result": "\n".join(result).strip() if result else ""}
|
|
|
+
|
|
|
+async def execute_code_jupyter(
|
|
|
+ base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
|
|
|
+) -> dict:
|
|
|
+ async with JupyterCodeExecuter(base_url, code, token, password, timeout) as executor:
|
|
|
+ result = await executor.run()
|
|
|
+ return result.model_dump()
|