Browse Source

fix: access control behaviour

Timothy Jaeryang Baek 5 months ago
parent
commit
1d4c3a8c58

+ 47 - 44
backend/open_webui/apps/ollama/main.py

@@ -351,22 +351,21 @@ async def get_ollama_tags(
                 status_code=r.status_code if r else 500,
                 detail=error_detail,
             )
-        
+
     if user.role == "user":
         # Filter models based on user access control
         filtered_models = []
         for model in models.get("models", []):
             model_info = Models.get_model_by_id(model["model"])
             if model_info:
-                if has_access(
+                if user.id == model_info.user_id or has_access(
                     user.id, type="read", access_control=model_info.access_control
                 ):
                     filtered_models.append(model)
             else:
                 filtered_models.append(model)
         models["models"] = filtered_models
-    
-        
+
     return models
 
 
@@ -953,18 +952,21 @@ async def generate_chat_completion(
             payload = apply_model_system_prompt_to_body(params, payload, user)
 
         # Check if user has access to the model
-        if not bypass_filter and user.role == "user" and not has_access(
-            user.id, type="read", access_control=model_info.access_control
-        ):
-            raise HTTPException(
-                status_code=403,
-                detail="Model not found",
-            )
-    
+        if not bypass_filter and user.role == "user":
+            if not (
+                user.id == model_info.user_id
+                or has_access(
+                    user.id, type="read", access_control=model_info.access_control
+                )
+            ):
+                raise HTTPException(
+                    status_code=403,
+                    detail="Model not found",
+                )
+
     if ":" not in payload["model"]:
         payload["model"] = f"{payload['model']}:latest"
 
-
     url = await get_ollama_url(url_idx, payload["model"])
     log.info(f"url: {url}")
     log.debug(f"generate_chat_completion() - 2.payload = {payload}")
@@ -1026,7 +1028,6 @@ async def generate_openai_chat_completion(
     if ":" not in model_id:
         model_id = f"{model_id}:latest"
 
-
     model_info = Models.get_model_by_id(model_id)
     if model_info:
         if model_info.base_model_id:
@@ -1039,13 +1040,17 @@ async def generate_openai_chat_completion(
             payload = apply_model_system_prompt_to_body(params, payload, user)
 
         # Check if user has access to the model
-        if user.role == "user" and not has_access(
-            user.id, type="read", access_control=model_info.access_control
-        ):
-            raise HTTPException(
-                status_code=403,
-                detail="Model not found",
-            )
+        if user.role == "user":
+            if not (
+                user.id == model_info.user_id
+                or has_access(
+                    user.id, type="read", access_control=model_info.access_control
+                )
+            ):
+                raise HTTPException(
+                    status_code=403,
+                    detail="Model not found",
+                )
 
     if ":" not in payload["model"]:
         payload["model"] = f"{payload['model']}:latest"
@@ -1071,19 +1076,19 @@ async def get_openai_models(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    
+
     models = []
     if url_idx is None:
         model_list = await get_all_models()
         models = [
-                {
-                    "id": model["model"],
-                    "object": "model",
-                    "created": int(time.time()),
-                    "owned_by": "openai",
-                }
-                for model in model_list["models"]
-            ]
+            {
+                "id": model["model"],
+                "object": "model",
+                "created": int(time.time()),
+                "owned_by": "openai",
+            }
+            for model in model_list["models"]
+        ]
 
     else:
         url = app.state.config.OLLAMA_BASE_URLS[url_idx]
@@ -1094,14 +1099,14 @@ async def get_openai_models(
             model_list = r.json()
 
             models = [
-                    {
-                        "id": model["model"],
-                        "object": "model",
-                        "created": int(time.time()),
-                        "owned_by": "openai",
-                    }
-                    for model in models["models"]
-                ]
+                {
+                    "id": model["model"],
+                    "object": "model",
+                    "created": int(time.time()),
+                    "owned_by": "openai",
+                }
+                for model in models["models"]
+            ]
         except Exception as e:
             log.exception(e)
             error_detail = "Open WebUI: Server Connection Error"
@@ -1117,7 +1122,6 @@ async def get_openai_models(
                 status_code=r.status_code if r else 500,
                 detail=error_detail,
             )
-        
 
     if user.role == "user":
         # Filter models based on user access control
@@ -1125,19 +1129,18 @@ async def get_openai_models(
         for model in models:
             model_info = Models.get_model_by_id(model["id"])
             if model_info:
-                if has_access(
+                if user.id == model_info.user_id or has_access(
                     user.id, type="read", access_control=model_info.access_control
                 ):
                     filtered_models.append(model)
             else:
                 filtered_models.append(model)
         models = filtered_models
-        
 
     return {
-            "data": models,
-            "object": "list",
-        }
+        "data": models,
+        "object": "list",
+    }
 
 
 class UrlForm(BaseModel):

+ 12 - 8
backend/open_webui/apps/openai/main.py

@@ -420,7 +420,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
         for model in models.get("data", []):
             model_info = Models.get_model_by_id(model["id"])
             if model_info:
-                if has_access(
+                if user.id == model_info.user_id or has_access(
                     user.id, type="read", access_control=model_info.access_control
                 ):
                     filtered_models.append(model)
@@ -501,13 +501,17 @@ async def generate_chat_completion(
         payload = apply_model_system_prompt_to_body(params, payload, user)
 
         # Check if user has access to the model
-        if not bypass_filter and user.role == "user" and not has_access(
-            user.id, type="read", access_control=model_info.access_control
-        ):
-            raise HTTPException(
-                status_code=403,
-                detail="Model not found",
-            )
+        if not bypass_filter and user.role == "user":
+            if not (
+                user.id == model_info.user_id
+                or has_access(
+                    user.id, type="read", access_control=model_info.access_control
+                )
+            ):
+                raise HTTPException(
+                    status_code=403,
+                    detail="Model not found",
+                )
 
     # Attemp to get urlIdx from the model
     models = await get_all_models()

+ 11 - 5
backend/open_webui/main.py

@@ -557,8 +557,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 
         model_info = Models.get_model_by_id(model["id"])
         if user.role == "user":
-            if model_info and not has_access(
-                user.id, type="read", access_control=model_info.access_control
+            if model_info and not (
+                user.id == model_info.user_id
+                or has_access(
+                    user.id, type="read", access_control=model_info.access_control
+                )
             ):
                 return JSONResponse(
                     status_code=status.HTTP_403_FORBIDDEN,
@@ -1106,7 +1109,7 @@ async def get_models(user=Depends(get_verified_user)):
         for model in models:
             model_info = Models.get_model_by_id(model["id"])
             if model_info:
-                if has_access(
+                if user.id == model_info.user_id or has_access(
                     user.id, type="read", access_control=model_info.access_control
                 ):
                     filtered_models.append(model)
@@ -1144,8 +1147,11 @@ async def generate_chat_completions(
     # Check if user has access to the model
     if user.role == "user":
         model_info = Models.get_model_by_id(model_id)
-        if not has_access(
-            user.id, type="read", access_control=model_info.access_control
+        if not (
+            user.id == model_info.user_id
+            or has_access(
+                user.id, type="read", access_control=model_info.access_control
+            )
         ):
             raise HTTPException(
                 status_code=403,