Przeglądaj źródła

Merge pull request #9918 from df-cgdm/main

feat: Add  X-OpenWebUI when forwarding to ollama servers
Timothy Jaeryang Baek 2 miesięcy temu
rodzic
commit
d8bc3098db

+ 2 - 2
backend/open_webui/main.py

@@ -919,7 +919,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
 
         return filtered_models
 
-    models = await get_all_models(request)
+    models = await get_all_models(request, user=user)
 
     # Filter out filter pipelines
     models = [
@@ -959,7 +959,7 @@ async def chat_completion(
     user=Depends(get_verified_user),
 ):
     if not request.app.state.MODELS:
-        await get_all_models(request)
+        await get_all_models(request, user=user)
 
     model_item = form_data.pop("model_item", {})
     tasks = form_data.pop("background_tasks", None)

+ 130 - 16
backend/open_webui/routers/ollama.py

@@ -14,6 +14,11 @@ from urllib.parse import urlparse
 import aiohttp
 from aiocache import cached
 import requests
+from open_webui.models.users import UserModel
+
+from open_webui.env import (
+    ENABLE_FORWARD_USER_INFO_HEADERS,
+)
 
 from fastapi import (
     Depends,
@@ -66,12 +71,26 @@ log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
 ##########################################
 
 
-async def send_get_request(url, key=None):
+async def send_get_request(url, key=None, user: UserModel = None):
     timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
     try:
         async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
             async with session.get(
-                url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
+                url,
+                headers={
+                    "Content-Type": "application/json",
+                    **({"Authorization": f"Bearer {key}"} if key else {}),
+                    **(
+                        {
+                            "X-OpenWebUI-User-Name": user.name,
+                            "X-OpenWebUI-User-Id": user.id,
+                            "X-OpenWebUI-User-Email": user.email,
+                            "X-OpenWebUI-User-Role": user.role,
+                        }
+                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                        else {}
+                    ),
+                },
             ) as response:
                 return await response.json()
     except Exception as e:
@@ -96,6 +115,7 @@ async def send_post_request(
     stream: bool = True,
     key: Optional[str] = None,
     content_type: Optional[str] = None,
+    user: UserModel = None,
 ):
 
     r = None
@@ -110,6 +130,16 @@ async def send_post_request(
             headers={
                 "Content-Type": "application/json",
                 **({"Authorization": f"Bearer {key}"} if key else {}),
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                    else {}
+                ),
             },
         )
         r.raise_for_status()
@@ -191,7 +221,19 @@ async def verify_connection(
         try:
             async with session.get(
                 f"{url}/api/version",
-                headers={**({"Authorization": f"Bearer {key}"} if key else {})},
+                headers={
+                    **({"Authorization": f"Bearer {key}"} if key else {}),
+                    **(
+                        {
+                            "X-OpenWebUI-User-Name": user.name,
+                            "X-OpenWebUI-User-Id": user.id,
+                            "X-OpenWebUI-User-Email": user.email,
+                            "X-OpenWebUI-User-Role": user.role,
+                        }
+                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                        else {}
+                    ),
+                },
             ) as r:
                 if r.status != 200:
                     detail = f"HTTP Error: {r.status}"
@@ -254,7 +296,7 @@ async def update_config(
 
 
 @cached(ttl=3)
-async def get_all_models(request: Request):
+async def get_all_models(request: Request, user: UserModel = None):
     log.info("get_all_models()")
     if request.app.state.config.ENABLE_OLLAMA_API:
         request_tasks = []
@@ -262,7 +304,7 @@ async def get_all_models(request: Request):
             if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
                 url not in request.app.state.config.OLLAMA_API_CONFIGS  # Legacy support
             ):
-                request_tasks.append(send_get_request(f"{url}/api/tags"))
+                request_tasks.append(send_get_request(f"{url}/api/tags", user=user))
             else:
                 api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
                     str(idx),
@@ -275,7 +317,9 @@ async def get_all_models(request: Request):
                 key = api_config.get("key", None)
 
                 if enable:
-                    request_tasks.append(send_get_request(f"{url}/api/tags", key))
+                    request_tasks.append(
+                        send_get_request(f"{url}/api/tags", key, user=user)
+                    )
                 else:
                     request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
 
@@ -360,7 +404,7 @@ async def get_ollama_tags(
     models = []
 
     if url_idx is None:
-        models = await get_all_models(request)
+        models = await get_all_models(request, user=user)
     else:
         url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
         key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
@@ -370,7 +414,19 @@ async def get_ollama_tags(
             r = requests.request(
                 method="GET",
                 url=f"{url}/api/tags",
-                headers={**({"Authorization": f"Bearer {key}"} if key else {})},
+                headers={
+                    **({"Authorization": f"Bearer {key}"} if key else {}),
+                    **(
+                        {
+                            "X-OpenWebUI-User-Name": user.name,
+                            "X-OpenWebUI-User-Id": user.id,
+                            "X-OpenWebUI-User-Email": user.email,
+                            "X-OpenWebUI-User-Role": user.role,
+                        }
+                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                        else {}
+                    ),
+                },
             )
             r.raise_for_status()
 
@@ -477,6 +533,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u
                         url, {}
                     ),  # Legacy support
                 ).get("key", None),
+                user=user,
             )
             for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
         ]
@@ -509,6 +566,7 @@ async def pull_model(
         url=f"{url}/api/pull",
         payload=json.dumps(payload),
         key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
+        user=user,
     )
 
 
@@ -527,7 +585,7 @@ async def push_model(
     user=Depends(get_admin_user),
 ):
     if url_idx is None:
-        await get_all_models(request)
+        await get_all_models(request, user=user)
         models = request.app.state.OLLAMA_MODELS
 
         if form_data.name in models:
@@ -545,6 +603,7 @@ async def push_model(
         url=f"{url}/api/push",
         payload=form_data.model_dump_json(exclude_none=True).encode(),
         key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
+        user=user,
     )
 
 
@@ -571,6 +630,7 @@ async def create_model(
         url=f"{url}/api/create",
         payload=form_data.model_dump_json(exclude_none=True).encode(),
         key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
+        user=user,
     )
 
 
@@ -588,7 +648,7 @@ async def copy_model(
     user=Depends(get_admin_user),
 ):
     if url_idx is None:
-        await get_all_models(request)
+        await get_all_models(request, user=user)
         models = request.app.state.OLLAMA_MODELS
 
         if form_data.source in models:
@@ -609,6 +669,16 @@ async def copy_model(
             headers={
                 "Content-Type": "application/json",
                 **({"Authorization": f"Bearer {key}"} if key else {}),
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                    else {}
+                ),
             },
             data=form_data.model_dump_json(exclude_none=True).encode(),
         )
@@ -643,7 +713,7 @@ async def delete_model(
     user=Depends(get_admin_user),
 ):
     if url_idx is None:
-        await get_all_models(request)
+        await get_all_models(request, user=user)
         models = request.app.state.OLLAMA_MODELS
 
         if form_data.name in models:
@@ -665,6 +735,16 @@ async def delete_model(
             headers={
                 "Content-Type": "application/json",
                 **({"Authorization": f"Bearer {key}"} if key else {}),
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                    else {}
+                ),
             },
         )
         r.raise_for_status()
@@ -693,7 +773,7 @@ async def delete_model(
 async def show_model_info(
     request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
 ):
-    await get_all_models(request)
+    await get_all_models(request, user=user)
     models = request.app.state.OLLAMA_MODELS
 
     if form_data.name not in models:
@@ -714,6 +794,16 @@ async def show_model_info(
             headers={
                 "Content-Type": "application/json",
                 **({"Authorization": f"Bearer {key}"} if key else {}),
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                    else {}
+                ),
             },
             data=form_data.model_dump_json(exclude_none=True).encode(),
         )
@@ -757,7 +847,7 @@ async def embed(
     log.info(f"generate_ollama_batch_embeddings {form_data}")
 
     if url_idx is None:
-        await get_all_models(request)
+        await get_all_models(request, user=user)
         models = request.app.state.OLLAMA_MODELS
 
         model = form_data.model
@@ -783,6 +873,16 @@ async def embed(
             headers={
                 "Content-Type": "application/json",
                 **({"Authorization": f"Bearer {key}"} if key else {}),
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                    else {}
+                ),
             },
             data=form_data.model_dump_json(exclude_none=True).encode(),
         )
@@ -826,7 +926,7 @@ async def embeddings(
     log.info(f"generate_ollama_embeddings {form_data}")
 
     if url_idx is None:
-        await get_all_models(request)
+        await get_all_models(request, user=user)
         models = request.app.state.OLLAMA_MODELS
 
         model = form_data.model
@@ -852,6 +952,16 @@ async def embeddings(
             headers={
                 "Content-Type": "application/json",
                 **({"Authorization": f"Bearer {key}"} if key else {}),
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                    else {}
+                ),
             },
             data=form_data.model_dump_json(exclude_none=True).encode(),
         )
@@ -901,7 +1011,7 @@ async def generate_completion(
     user=Depends(get_verified_user),
 ):
     if url_idx is None:
-        await get_all_models(request)
+        await get_all_models(request, user=user)
         models = request.app.state.OLLAMA_MODELS
 
         model = form_data.model
@@ -931,6 +1041,7 @@ async def generate_completion(
         url=f"{url}/api/generate",
         payload=form_data.model_dump_json(exclude_none=True).encode(),
         key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
+        user=user,
     )
 
 
@@ -1060,6 +1171,7 @@ async def generate_chat_completion(
         stream=form_data.stream,
         key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
         content_type="application/x-ndjson",
+        user=user,
     )
 
 
@@ -1162,6 +1274,7 @@ async def generate_openai_completion(
         payload=json.dumps(payload),
         stream=payload.get("stream", False),
         key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
+        user=user,
     )
 
 
@@ -1240,6 +1353,7 @@ async def generate_openai_chat_completion(
         payload=json.dumps(payload),
         stream=payload.get("stream", False),
         key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
+        user=user,
     )
 
 
@@ -1253,7 +1367,7 @@ async def get_openai_models(
 
     models = []
     if url_idx is None:
-        model_list = await get_all_models(request)
+        model_list = await get_all_models(request, user=user)
         models = [
             {
                 "id": model["model"],

+ 35 - 8
backend/open_webui/routers/openai.py

@@ -26,6 +26,7 @@ from open_webui.env import (
     ENABLE_FORWARD_USER_INFO_HEADERS,
     BYPASS_MODEL_ACCESS_CONTROL,
 )
+from open_webui.models.users import UserModel
 
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import ENV, SRC_LOG_LEVELS
@@ -51,12 +52,25 @@ log.setLevel(SRC_LOG_LEVELS["OPENAI"])
 ##########################################
 
 
-async def send_get_request(url, key=None):
+async def send_get_request(url, key=None, user: UserModel = None):
     timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
     try:
         async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
             async with session.get(
-                url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
+                url,
+                headers={
+                    **({"Authorization": f"Bearer {key}"} if key else {}),
+                    **(
+                        {
+                            "X-OpenWebUI-User-Name": user.name,
+                            "X-OpenWebUI-User-Id": user.id,
+                            "X-OpenWebUI-User-Email": user.email,
+                            "X-OpenWebUI-User-Role": user.role,
+                        }
+                        if ENABLE_FORWARD_USER_INFO_HEADERS
+                        else {}
+                    ),
+                },
             ) as response:
                 return await response.json()
     except Exception as e:
@@ -247,7 +261,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
 
 
-async def get_all_models_responses(request: Request) -> list:
+async def get_all_models_responses(request: Request, user: UserModel) -> list:
     if not request.app.state.config.ENABLE_OPENAI_API:
         return []
 
@@ -271,7 +285,9 @@ async def get_all_models_responses(request: Request) -> list:
         ):
             request_tasks.append(
                 send_get_request(
-                    f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx]
+                    f"{url}/models",
+                    request.app.state.config.OPENAI_API_KEYS[idx],
+                    user=user,
                 )
             )
         else:
@@ -291,6 +307,7 @@ async def get_all_models_responses(request: Request) -> list:
                         send_get_request(
                             f"{url}/models",
                             request.app.state.config.OPENAI_API_KEYS[idx],
+                            user=user,
                         )
                     )
                 else:
@@ -352,13 +369,13 @@ async def get_filtered_models(models, user):
 
 
 @cached(ttl=3)
-async def get_all_models(request: Request) -> dict[str, list]:
+async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
     log.info("get_all_models()")
 
     if not request.app.state.config.ENABLE_OPENAI_API:
         return {"data": []}
 
-    responses = await get_all_models_responses(request)
+    responses = await get_all_models_responses(request, user=user)
 
     def extract_data(response):
         if response and "data" in response:
@@ -418,7 +435,7 @@ async def get_models(
     }
 
     if url_idx is None:
-        models = await get_all_models(request)
+        models = await get_all_models(request, user=user)
     else:
         url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
         key = request.app.state.config.OPENAI_API_KEYS[url_idx]
@@ -515,6 +532,16 @@ async def verify_connection(
                 headers={
                     "Authorization": f"Bearer {key}",
                     "Content-Type": "application/json",
+                    **(
+                        {
+                            "X-OpenWebUI-User-Name": user.name,
+                            "X-OpenWebUI-User-Id": user.id,
+                            "X-OpenWebUI-User-Email": user.email,
+                            "X-OpenWebUI-User-Role": user.role,
+                        }
+                        if ENABLE_FORWARD_USER_INFO_HEADERS
+                        else {}
+                    ),
                 },
             ) as r:
                 if r.status != 200:
@@ -587,7 +614,7 @@ async def generate_chat_completion(
                 detail="Model not found",
             )
 
-    await get_all_models(request)
+    await get_all_models(request, user=user)
     model = request.app.state.OPENAI_MODELS.get(model_id)
     if model:
         idx = model["urlIdx"]

+ 2 - 2
backend/open_webui/utils/chat.py

@@ -285,7 +285,7 @@ chat_completion = generate_chat_completion
 
 async def chat_completed(request: Request, form_data: dict, user: Any):
     if not request.app.state.MODELS:
-        await get_all_models(request)
+        await get_all_models(request, user=user)
 
     if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
         models = {
@@ -351,7 +351,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
         raise Exception(f"Action not found: {action_id}")
 
     if not request.app.state.MODELS:
-        await get_all_models(request)
+        await get_all_models(request, user=user)
 
     if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
         models = {

+ 6 - 5
backend/open_webui/utils/models.py

@@ -22,6 +22,7 @@ from open_webui.config import (
 )
 
 from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
+from open_webui.models.users import UserModel
 
 
 logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
@@ -29,17 +30,17 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])
 
 
-async def get_all_base_models(request: Request):
+async def get_all_base_models(request: Request, user: UserModel = None):
     function_models = []
     openai_models = []
     ollama_models = []
 
     if request.app.state.config.ENABLE_OPENAI_API:
-        openai_models = await openai.get_all_models(request)
+        openai_models = await openai.get_all_models(request, user=user)
         openai_models = openai_models["data"]
 
     if request.app.state.config.ENABLE_OLLAMA_API:
-        ollama_models = await ollama.get_all_models(request)
+        ollama_models = await ollama.get_all_models(request, user=user)
         ollama_models = [
             {
                 "id": model["model"],
@@ -58,8 +59,8 @@ async def get_all_base_models(request: Request):
     return models
 
 
-async def get_all_models(request):
-    models = await get_all_base_models(request)
+async def get_all_models(request, user: UserModel = None):
+    models = await get_all_base_models(request, user=user)
 
     # If there are no models, return an empty list
     if len(models) == 0: