Explorar o código

refac: access control

Timothy J. Baek hai 5 meses
pai
achega
4a34ca35f0
Modificáronse 2 ficheiros con 75 adicións e 60 borrados
  1. 74 59
      backend/open_webui/apps/ollama/main.py
  2. 1 1
      backend/open_webui/apps/openai/main.py

+ 74 - 59
backend/open_webui/apps/ollama/main.py

@@ -43,6 +43,7 @@ from open_webui.utils.payload import (
     apply_model_system_prompt_to_body,
 )
 from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.access_control import has_access
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
@@ -316,22 +317,9 @@ async def get_all_models():
 async def get_ollama_tags(
     url_idx: Optional[int] = None, user=Depends(get_verified_user)
 ):
+    models = []
     if url_idx is None:
         models = await get_all_models()
-
-        # TODO: Check User Group and Filter Models
-        # if app.state.config.ENABLE_MODEL_FILTER:
-        #     if user.role == "user":
-        #         models["models"] = list(
-        #             filter(
-        #                 lambda model: model["name"]
-        #                 in app.state.config.MODEL_FILTER_LIST,
-        #                 models["models"],
-        #             )
-        #         )
-        #         return models
-
-        return models
     else:
         url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 
@@ -347,7 +335,7 @@ async def get_ollama_tags(
             r = requests.request(method="GET", url=f"{url}/api/tags", headers=headers)
             r.raise_for_status()
 
-            return r.json()
+            models = r.json()
         except Exception as e:
             log.exception(e)
             error_detail = "Open WebUI: Server Connection Error"
@@ -363,6 +351,23 @@ async def get_ollama_tags(
                 status_code=r.status_code if r else 500,
                 detail=error_detail,
             )
+        
+    if user.role == "user":
+        # Filter models based on user access control
+        filtered_models = []
+        for model in models.get("models", []):
+            model_info = Models.get_model_by_id(model["model"])
+            if model_info:
+                if has_access(
+                    user.id, type="read", access_control=model_info.access_control
+                ):
+                    filtered_models.append(model)
+            else:
+                filtered_models.append(model)
+        models["models"] = filtered_models
+    
+        
+    return models
 
 
 @app.get("/api/version")
@@ -926,16 +931,9 @@ async def generate_chat_completion(
     if "metadata" in payload:
         del payload["metadata"]
 
-    model_id = form_data.model
-
-    # TODO: Check User Group and Filter Models
-    # if not bypass_filter:
-    #     if app.state.config.ENABLE_MODEL_FILTER:
-    #         if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
-    #             raise HTTPException(
-    #                 status_code=403,
-    #                 detail="Model not found",
-    #             )
+    model_id = payload["model"]
+    if ":" not in model_id:
+        model_id = f"{model_id}:latest"
 
     model_info = Models.get_model_by_id(model_id)
 
@@ -954,9 +952,19 @@ async def generate_chat_completion(
             )
             payload = apply_model_system_prompt_to_body(params, payload, user)
 
+        # Check if user has access to the model
+        if not bypass_filter and user.role == "user" and not has_access(
+            user.id, type="read", access_control=model_info.access_control
+        ):
+            raise HTTPException(
+                status_code=403,
+                detail="Model not found",
+            )
+    
     if ":" not in payload["model"]:
         payload["model"] = f"{payload['model']}:latest"
 
+
     url = await get_ollama_url(url_idx, payload["model"])
     log.info(f"url: {url}")
     log.debug(f"generate_chat_completion() - 2.payload = {payload}")
@@ -1015,17 +1023,11 @@ async def generate_openai_chat_completion(
         del payload["metadata"]
 
     model_id = completion_form.model
+    if ":" not in model_id:
+        model_id = f"{model_id}:latest"
 
-    # TODO: Check User Group and Filter Models
-    # if app.state.config.ENABLE_MODEL_FILTER:
-    #     if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
-    #         raise HTTPException(
-    #             status_code=403,
-    #             detail="Model not found",
-    #         )
 
     model_info = Models.get_model_by_id(model_id)
-
     if model_info:
         if model_info.base_model_id:
             payload["model"] = model_info.base_model_id
@@ -1036,6 +1038,15 @@ async def generate_openai_chat_completion(
             payload = apply_model_params_to_body_openai(params, payload)
             payload = apply_model_system_prompt_to_body(params, payload, user)
 
+        # Check if user has access to the model
+        if user.role == "user" and not has_access(
+            user.id, type="read", access_control=model_info.access_control
+        ):
+            raise HTTPException(
+                status_code=403,
+                detail="Model not found",
+            )
+
     if ":" not in payload["model"]:
         payload["model"] = f"{payload['model']}:latest"
 
@@ -1060,32 +1071,19 @@ async def get_openai_models(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
+    
+    models = []
     if url_idx is None:
-        models = await get_all_models()
-
-        # TODO: Check User Group and Filter Models
-        # if app.state.config.ENABLE_MODEL_FILTER:
-        #     if user.role == "user":
-        #         models["models"] = list(
-        #             filter(
-        #                 lambda model: model["name"]
-        #                 in app.state.config.MODEL_FILTER_LIST,
-        #                 models["models"],
-        #             )
-        #         )
-
-        return {
-            "data": [
+        model_list = await get_all_models()
+        models = [
                 {
                     "id": model["model"],
                     "object": "model",
                     "created": int(time.time()),
                     "owned_by": "openai",
                 }
-                for model in models["models"]
-            ],
-            "object": "list",
-        }
+                for model in model_list["models"]
+            ]
 
     else:
         url = app.state.config.OLLAMA_BASE_URLS[url_idx]
@@ -1093,10 +1091,9 @@ async def get_openai_models(
             r = requests.request(method="GET", url=f"{url}/api/tags")
             r.raise_for_status()
 
-            models = r.json()
+            model_list = r.json()
 
-            return {
-                "data": [
+            models = [
                     {
                         "id": model["model"],
                         "object": "model",
@@ -1104,10 +1101,7 @@ async def get_openai_models(
                         "owned_by": "openai",
                     }
                     for model in models["models"]
-                ],
-                "object": "list",
-            }
-
+                ]
         except Exception as e:
             log.exception(e)
             error_detail = "Open WebUI: Server Connection Error"
@@ -1123,6 +1117,27 @@ async def get_openai_models(
                 status_code=r.status_code if r else 500,
                 detail=error_detail,
             )
+        
+
+    if user.role == "user":
+        # Filter models based on user access control
+        filtered_models = []
+        for model in models:
+            model_info = Models.get_model_by_id(model["id"])
+            if model_info:
+                if has_access(
+                    user.id, type="read", access_control=model_info.access_control
+                ):
+                    filtered_models.append(model)
+            else:
+                filtered_models.append(model)
+        models = filtered_models
+        
+
+    return {
+            "data": models,
+            "object": "list",
+        }
 
 
 class UrlForm(BaseModel):

+ 1 - 1
backend/open_webui/apps/openai/main.py

@@ -501,7 +501,7 @@ async def generate_chat_completion(
         payload = apply_model_system_prompt_to_body(params, payload, user)
 
         # Check if user has access to the model
-        if user.role == "user" and not has_access(
+        if not bypass_filter and user.role == "user" and not has_access(
             user.id, type="read", access_control=model_info.access_control
         ):
             raise HTTPException(