Переглянути джерело

Merge pull request #391 from ollama-webui/async-proxy

Async proxy
Timothy Jaeryang Baek 1 рік тому
батько
коміт
ca2ff93c3c
1 змінених файлів з 50 додано та 25 видалено
  1. 50 25
      backend/apps/ollama/main.py

+ 50 - 25
backend/apps/ollama/main.py

@@ -11,6 +11,8 @@ from constants import ERROR_MESSAGES
 from utils.utils import decode_token, get_current_user
 from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
 
+import aiohttp
+
 app = FastAPI()
 app.add_middleware(
     CORSMiddleware,
@@ -30,8 +32,7 @@ async def get_ollama_api_url(user=Depends(get_current_user)):
     if user and user.role == "admin":
         return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
     else:
-        raise HTTPException(status_code=401,
-                            detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
 
 
 class UrlUpdateForm(BaseModel):
@@ -39,14 +40,29 @@ class UrlUpdateForm(BaseModel):
 
 
 @app.post("/url/update")
-async def update_ollama_api_url(form_data: UrlUpdateForm,
-                                user=Depends(get_current_user)):
+async def update_ollama_api_url(
+    form_data: UrlUpdateForm, user=Depends(get_current_user)
+):
     if user and user.role == "admin":
         app.state.OLLAMA_API_BASE_URL = form_data.url
         return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
     else:
-        raise HTTPException(status_code=401,
-                            detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+
+
+# async def fetch_sse(method, target_url, body, headers):
+#     async with aiohttp.ClientSession() as session:
+#         try:
+#             async with session.request(
+#                 method, target_url, data=body, headers=headers
+#             ) as response:
+#                 print(response.status)
+#                 async for line in response.content:
+#                     yield line
+#         except Exception as e:
+#             print(e)
+#             error_detail = "Ollama WebUI: Server Connection Error"
+#             yield json.dumps({"error": error_detail, "message": str(e)}).encode()
 
 
 @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
@@ -59,42 +75,51 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
     if user.role in ["user", "admin"]:
         if path in ["pull", "delete", "push", "copy", "create"]:
             if user.role != "admin":
-                raise HTTPException(status_code=401,
-                                    detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+                raise HTTPException(
+                    status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
+                )
     else:
-        raise HTTPException(status_code=401,
-                            detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
 
     headers.pop("Host", None)
     headers.pop("Authorization", None)
     headers.pop("Origin", None)
     headers.pop("Referer", None)
 
+    session = aiohttp.ClientSession()
+    response = None
     try:
-        r = requests.request(
-            method=request.method,
-            url=target_url,
-            data=body,
-            headers=headers,
-            stream=True,
+        response = await session.request(
+            request.method, target_url, data=body, headers=headers
         )
 
-        r.raise_for_status()
+        if not response.ok:
+            data = await response.json()
+            print(data)
+            response.raise_for_status()
+
+        async def generate():
+            async for line in response.content:
+                yield line
+            await session.close()
+
+        return StreamingResponse(generate(), response.status)
 
-        return StreamingResponse(
-            r.iter_content(chunk_size=8192),
-            status_code=r.status_code,
-            headers=dict(r.headers),
-        )
     except Exception as e:
         print(e)
         error_detail = "Ollama WebUI: Server Connection Error"
-        if r is not None:
+
+        if response is not None:
             try:
-                res = r.json()
+                res = await response.json()
                 if "error" in res:
                     error_detail = f"Ollama: {res['error']}"
             except:
                 error_detail = f"Ollama: {e}"
 
-        raise HTTPException(status_code=r.status_code, detail=error_detail)
+        await session.close()
+
+        raise HTTPException(
+            status_code=response.status if response else 500,
+            detail=error_detail,
+        )