Browse Source

fix: arena access control

Timothy Jaeryang Baek 5 months ago
parent
commit
f37d847521

+ 1 - 1
backend/open_webui/apps/ollama/main.py

@@ -958,7 +958,7 @@ async def generate_chat_completion(
                     status_code=403,
                     status_code=403,
                     detail="Model not found",
                     detail="Model not found",
                 )
                 )
-    else:
+    elif not bypass_filter:
         if user.role != "admin":
         if user.role != "admin":
             raise HTTPException(
             raise HTTPException(
                 status_code=403,
                 status_code=403,

+ 1 - 1
backend/open_webui/apps/openai/main.py

@@ -510,7 +510,7 @@ async def generate_chat_completion(
                     status_code=403,
                     status_code=403,
                     detail="Model not found",
                     detail="Model not found",
                 )
                 )
-    else:
+    elif not bypass_filter:
         if user.role != "admin":
         if user.role != "admin":
             raise HTTPException(
             raise HTTPException(
                 status_code=403,
                 status_code=403,

+ 60 - 36
backend/open_webui/main.py

@@ -557,21 +557,34 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 
 
         model_info = Models.get_model_by_id(model["id"])
         model_info = Models.get_model_by_id(model["id"])
         if user.role == "user":
         if user.role == "user":
-            if not model_info:
-                return JSONResponse(
-                    status_code=status.HTTP_404_NOT_FOUND,
-                    content={"detail": "Model not found"},
-                )
-            elif not (
-                user.id == model_info.user_id
-                or has_access(
-                    user.id, type="read", access_control=model_info.access_control
-                )
-            ):
-                return JSONResponse(
-                    status_code=status.HTTP_403_FORBIDDEN,
-                    content={"detail": "User does not have access to the model"},
-                )
+            if model.get("arena"):
+                if not has_access(
+                    user.id,
+                    type="read",
+                    access_control=model.get("info", {})
+                    .get("meta", {})
+                    .get("access_control", {}),
+                ):
+                    raise HTTPException(
+                        status_code=403,
+                        detail="Model not found",
+                    )
+            else:
+                if not model_info:
+                    return JSONResponse(
+                        status_code=status.HTTP_404_NOT_FOUND,
+                        content={"detail": "Model not found"},
+                    )
+                elif not (
+                    user.id == model_info.user_id
+                    or has_access(
+                        user.id, type="read", access_control=model_info.access_control
+                    )
+                ):
+                    return JSONResponse(
+                        status_code=status.HTTP_403_FORBIDDEN,
+                        content={"detail": "User does not have access to the model"},
+                    )
 
 
         metadata = {
         metadata = {
             "chat_id": body.pop("chat_id", None),
             "chat_id": body.pop("chat_id", None),
@@ -1160,24 +1173,38 @@ async def generate_chat_completions(
         )
         )
 
 
     model = models[model_id]
     model = models[model_id]
+
     # Check if user has access to the model
     # Check if user has access to the model
-    if user.role == "user":
-        model_info = Models.get_model_by_id(model_id)
-        if not model_info:
-            raise HTTPException(
-                status_code=404,
-                detail="Model not found",
-            )
-        elif not (
-            user.id == model_info.user_id
-            or has_access(
-                user.id, type="read", access_control=model_info.access_control
-            )
-        ):
-            raise HTTPException(
-                status_code=403,
-                detail="Model not found",
-            )
+    if not bypass_filter and user.role == "user":
+        if model.get("arena"):
+            if not has_access(
+                user.id,
+                type="read",
+                access_control=model.get("info", {})
+                .get("meta", {})
+                .get("access_control", {}),
+            ):
+                raise HTTPException(
+                    status_code=403,
+                    detail="Model not found",
+                )
+        else:
+            model_info = Models.get_model_by_id(model_id)
+            if not model_info:
+                raise HTTPException(
+                    status_code=404,
+                    detail="Model not found",
+                )
+            elif not (
+                user.id == model_info.user_id
+                or has_access(
+                    user.id, type="read", access_control=model_info.access_control
+                )
+            ):
+                raise HTTPException(
+                    status_code=403,
+                    detail="Model not found",
+                )
 
 
     if model["owned_by"] == "arena":
     if model["owned_by"] == "arena":
         model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
         model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
@@ -1186,9 +1213,7 @@ async def generate_chat_completions(
             model_ids = [
             model_ids = [
                 model["id"]
                 model["id"]
                 for model in await get_all_models()
                 for model in await get_all_models()
-                if model.get("owned_by") != "arena"
-                and not model.get("info", {}).get("meta", {}).get("hidden", False)
-                and model["id"] not in model_ids
+                if model.get("owned_by") != "arena" and model["id"] not in model_ids
             ]
             ]
 
 
         selected_model_id = None
         selected_model_id = None
@@ -1199,7 +1224,6 @@ async def generate_chat_completions(
                 model["id"]
                 model["id"]
                 for model in await get_all_models()
                 for model in await get_all_models()
                 if model.get("owned_by") != "arena"
                 if model.get("owned_by") != "arena"
-                and not model.get("info", {}).get("meta", {}).get("hidden", False)
             ]
             ]
             selected_model_id = random.choice(model_ids)
             selected_model_id = random.choice(model_ids)