Timothy Jaeryang Baek 4 months ago
parent
commit
4819199650
2 changed files with 333 additions and 138 deletions
  1. 162 17
      backend/open_webui/main.py
  2. 171 121
      backend/open_webui/routers/ollama.py

+ 162 - 17
backend/open_webui/main.py

@@ -46,6 +46,21 @@ from open_webui.routers import (
     retrieval,
     pipelines,
     tasks,
+    auths,
+    chats,
+    folders,
+    configs,
+    groups,
+    files,
+    functions,
+    memories,
+    models,
+    knowledge,
+    prompts,
+    evaluations,
+    tools,
+    users,
+    utils,
 )
 
 from open_webui.retrieval.utils import get_sources_from_files
@@ -117,6 +132,60 @@ from open_webui.config import (
     WHISPER_MODEL,
     WHISPER_MODEL_AUTO_UPDATE,
     WHISPER_MODEL_DIR,
+    # Retrieval
+    RAG_TEMPLATE,
+    DEFAULT_RAG_TEMPLATE,
+    RAG_EMBEDDING_MODEL,
+    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
+    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
+    RAG_RERANKING_MODEL,
+    RAG_RERANKING_MODEL_AUTO_UPDATE,
+    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
+    RAG_EMBEDDING_ENGINE,
+    RAG_EMBEDDING_BATCH_SIZE,
+    RAG_RELEVANCE_THRESHOLD,
+    RAG_FILE_MAX_COUNT,
+    RAG_FILE_MAX_SIZE,
+    RAG_OPENAI_API_BASE_URL,
+    RAG_OPENAI_API_KEY,
+    RAG_OLLAMA_BASE_URL,
+    RAG_OLLAMA_API_KEY,
+    CHUNK_OVERLAP,
+    CHUNK_SIZE,
+    CONTENT_EXTRACTION_ENGINE,
+    TIKA_SERVER_URL,
+    RAG_TOP_K,
+    RAG_TEXT_SPLITTER,
+    TIKTOKEN_ENCODING_NAME,
+    PDF_EXTRACT_IMAGES,
+    YOUTUBE_LOADER_LANGUAGE,
+    YOUTUBE_LOADER_PROXY_URL,
+    # Retrieval (Web Search)
+    RAG_WEB_SEARCH_ENGINE,
+    RAG_WEB_SEARCH_RESULT_COUNT,
+    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+    RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
+    JINA_API_KEY,
+    SEARCHAPI_API_KEY,
+    SEARCHAPI_ENGINE,
+    SEARXNG_QUERY_URL,
+    SERPER_API_KEY,
+    SERPLY_API_KEY,
+    SERPSTACK_API_KEY,
+    SERPSTACK_HTTPS,
+    TAVILY_API_KEY,
+    BING_SEARCH_V7_ENDPOINT,
+    BING_SEARCH_V7_SUBSCRIPTION_KEY,
+    BRAVE_SEARCH_API_KEY,
+    KAGI_SEARCH_API_KEY,
+    MOJEEK_SEARCH_API_KEY,
+    GOOGLE_PSE_API_KEY,
+    GOOGLE_PSE_ENGINE_ID,
+    ENABLE_RAG_HYBRID_SEARCH,
+    ENABLE_RAG_LOCAL_WEB_FETCH,
+    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
+    ENABLE_RAG_WEB_SEARCH,
+    UPLOAD_DIR,
     # WebUI
     WEBUI_AUTH,
     WEBUI_NAME,
@@ -383,6 +452,72 @@ app.state.FUNCTIONS = {}
 #
 ########################################
 
+
+app.state.config.TOP_K = RAG_TOP_K
+app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
+app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
+app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
+
+app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
+app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
+    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
+)
+
+app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
+app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
+
+app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
+app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
+
+app.state.config.CHUNK_SIZE = CHUNK_SIZE
+app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
+
+app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
+app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
+app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
+app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
+app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
+
+app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
+app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY
+
+app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
+app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
+
+app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
+
+app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
+app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL
+
+
+app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
+app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
+app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
+
+app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
+app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
+app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
+app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
+app.state.config.KAGI_SEARCH_API_KEY = KAGI_SEARCH_API_KEY
+app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY
+app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
+app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
+app.state.config.SERPER_API_KEY = SERPER_API_KEY
+app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
+app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
+app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY
+app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
+app.state.config.JINA_API_KEY = JINA_API_KEY
+app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
+app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
+
+app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
+app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
+
+
+app.state.YOUTUBE_LOADER_TRANSLATION = None
+app.state.EMBEDDING_FUNCTION = None
+
 ########################################
 #
 # IMAGES
@@ -1083,8 +1218,8 @@ def filter_pipeline(payload, user, models):
         try:
             urlIdx = filter["urlIdx"]
 
-            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
-            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
+            url = app.state.config.OPENAI_API_BASE_URLS[urlIdx]
+            key = app.state.config.OPENAI_API_KEYS[urlIdx]
 
             if key == "":
                 continue
@@ -1230,14 +1365,6 @@ async def check_url(request: Request, call_next):
     return response
 
 
-# @app.middleware("http")
-# async def update_embedding_function(request: Request, call_next):
-#     response = await call_next(request)
-#     if "/embedding/update" in request.url.path:
-#         webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION
-#     return response
-
-
 @app.middleware("http")
 async def inspect_websocket(request: Request, call_next):
     if (
@@ -1268,18 +1395,36 @@ app.add_middleware(
 app.mount("/ws", socket_app)
 
 
-app.mount("/ollama", ollama_app)
-app.mount("/openai", openai_app)
+app.include_router(ollama.router, prefix="/ollama")
+app.include_router(openai.router, prefix="/openai")
+
 
-app.mount("/images/api/v1", images_app)
-app.mount("/audio/api/v1", audio_app)
+app.include_router(images.router, prefix="/api/v1/images")
+app.include_router(audio.router, prefix="/api/v1/audio")
+app.include_router(retrieval.router, prefix="/api/v1/retrieval")
 
 
-app.mount("/retrieval/api/v1", retrieval_app)
+app.include_router(configs.router, prefix="/api/v1/configs", tags=["configs"])
 
-app.mount("/api/v1", webui_app)
+app.include_router(auths.router, prefix="/api/v1/auths", tags=["auths"])
+app.include_router(users.router, prefix="/api/v1/users", tags=["users"])
 
-app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION
+app.include_router(chats.router, prefix="/api/v1/chats", tags=["chats"])
+
+app.include_router(models.router, prefix="/api/v1/models", tags=["models"])
+app.include_router(knowledge.router, prefix="/api/v1/knowledge", tags=["knowledge"])
+app.include_router(prompts.router, prefix="/api/v1/prompts", tags=["prompts"])
+app.include_router(tools.router, prefix="/api/v1/tools", tags=["tools"])
+
+app.include_router(memories.router, prefix="/api/v1/memories", tags=["memories"])
+app.include_router(folders.router, prefix="/api/v1/folders", tags=["folders"])
+app.include_router(groups.router, prefix="/api/v1/groups", tags=["groups"])
+app.include_router(files.router, prefix="/api/v1/files", tags=["files"])
+app.include_router(functions.router, prefix="/api/v1/functions", tags=["functions"])
+app.include_router(
+    evaluations.router, prefix="/api/v1/evaluations", tags=["evaluations"]
+)
+app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"])
 
 
 async def get_all_base_models():

+ 171 - 121
backend/open_webui/routers/ollama.py

@@ -13,7 +13,15 @@ from aiocache import cached
 
 import requests
 
-from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
+from fastapi import (
+    Depends,
+    FastAPI,
+    File,
+    HTTPException,
+    Request,
+    UploadFile,
+    APIRouter,
+)
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, ConfigDict
@@ -26,18 +34,15 @@ from open_webui.models.models import Models
 from open_webui.config import (
     UPLOAD_DIR,
 )
-
-
 from open_webui.env import (
+    ENV,
+    SRC_LOG_LEVELS,
     AIOHTTP_CLIENT_TIMEOUT,
     AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
     BYPASS_MODEL_ACCESS_CONTROL,
 )
 
-
 from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import ENV, SRC_LOG_LEVELS
-
 
 from open_webui.utils.misc import (
     calculate_sha256,
@@ -54,13 +59,15 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
 
 
+router = APIRouter()
+
 # TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
 # Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
 # least connections, or least response time for better resource utilization and performance optimization.
 
 
-@app.head("/")
-@app.get("/")
+@router.head("/")
+@router.get("/")
 async def get_status():
     return {"status": True}
 
@@ -70,7 +77,7 @@ class ConnectionVerificationForm(BaseModel):
     key: Optional[str] = None
 
 
-@app.post("/verify")
+@router.post("/verify")
 async def verify_connection(
     form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
 ):
@@ -110,12 +117,12 @@ async def verify_connection(
             raise HTTPException(status_code=500, detail=error_detail)
 
 
-@app.get("/config")
-async def get_config(user=Depends(get_admin_user)):
+@router.get("/config")
+async def get_config(request: Request, user=Depends(get_admin_user)):
     return {
-        "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
-        "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
-        "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
+        "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
+        "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
+        "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
     }
 
 
@@ -125,23 +132,25 @@ class OllamaConfigForm(BaseModel):
     OLLAMA_API_CONFIGS: dict
 
 
-@app.post("/config/update")
-async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)):
-    app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
-    app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
+@router.post("/config/update")
+async def update_config(
+    request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user)
+):
+    request.app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
+    request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
 
-    app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
+    request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
 
     # Remove any extra configs
-    config_urls = app.state.config.OLLAMA_API_CONFIGS.keys()
-    for url in list(app.state.config.OLLAMA_BASE_URLS):
+    config_urls = request.app.state.config.OLLAMA_API_CONFIGS.keys()
+    for url in list(request.app.state.config.OLLAMA_BASE_URLS):
         if url not in config_urls:
-            app.state.config.OLLAMA_API_CONFIGS.pop(url, None)
+            request.app.state.config.OLLAMA_API_CONFIGS.pop(url, None)
 
     return {
-        "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
-        "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
-        "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
+        "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
+        "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
+        "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
     }
 
 
@@ -158,6 +167,12 @@ async def aiohttp_get(url, key=None):
         return None
 
 
+def get_api_key(url, configs):
+    parsed_url = urlparse(url)
+    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
+    return configs.get(base_url, {}).get("key", None)
+
+
 async def cleanup_response(
     response: Optional[aiohttp.ClientResponse],
     session: Optional[aiohttp.ClientSession],
@@ -169,7 +184,11 @@ async def cleanup_response(
 
 
 async def post_streaming_url(
-    url: str, payload: Union[str, bytes], stream: bool = True, content_type=None
+    url: str,
+    payload: Union[str, bytes],
+    stream: bool = True,
+    key: Optional[str] = None,
+    content_type=None,
 ):
     r = None
     try:
@@ -177,12 +196,6 @@ async def post_streaming_url(
             trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
         )
 
-        parsed_url = urlparse(url)
-        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
-
-        api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
-        key = api_config.get("key", None)
-
         headers = {"Content-Type": "application/json"}
         if key:
             headers["Authorization"] = f"Bearer {key}"
@@ -246,13 +259,13 @@ def merge_models_lists(model_lists):
 @cached(ttl=3)
 async def get_all_models():
     log.info("get_all_models()")
-    if app.state.config.ENABLE_OLLAMA_API:
+    if request.app.state.config.ENABLE_OLLAMA_API:
         tasks = []
-        for idx, url in enumerate(app.state.config.OLLAMA_BASE_URLS):
-            if url not in app.state.config.OLLAMA_API_CONFIGS:
+        for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
+            if url not in request.app.state.config.OLLAMA_API_CONFIGS:
                 tasks.append(aiohttp_get(f"{url}/api/tags"))
             else:
-                api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
                 enable = api_config.get("enable", True)
                 key = api_config.get("key", None)
 
@@ -265,8 +278,8 @@ async def get_all_models():
 
         for idx, response in enumerate(responses):
             if response:
-                url = app.state.config.OLLAMA_BASE_URLS[idx]
-                api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+                url = request.app.state.config.OLLAMA_BASE_URLS[idx]
+                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
 
                 prefix_id = api_config.get("prefix_id", None)
                 model_ids = api_config.get("model_ids", [])
@@ -298,21 +311,21 @@ async def get_all_models():
     return models
 
 
-@app.get("/api/tags")
-@app.get("/api/tags/{url_idx}")
+@router.get("/api/tags")
+@router.get("/api/tags/{url_idx}")
 async def get_ollama_tags(
-    url_idx: Optional[int] = None, user=Depends(get_verified_user)
+    request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)
 ):
     models = []
     if url_idx is None:
         models = await get_all_models()
     else:
-        url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+        url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
 
         parsed_url = urlparse(url)
         base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
 
-        api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
+        api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
         key = api_config.get("key", None)
 
         headers = {}
@@ -356,18 +369,20 @@ async def get_ollama_tags(
     return models
 
 
-@app.get("/api/version")
-@app.get("/api/version/{url_idx}")
-async def get_ollama_versions(url_idx: Optional[int] = None):
-    if app.state.config.ENABLE_OLLAMA_API:
+@router.get("/api/version")
+@router.get("/api/version/{url_idx}")
+async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
+    if request.app.state.config.ENABLE_OLLAMA_API:
         if url_idx is None:
             # returns lowest version
             tasks = [
                 aiohttp_get(
                     f"{url}/api/version",
-                    app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None),
+                    request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
+                        "key", None
+                    ),
                 )
-                for url in app.state.config.OLLAMA_BASE_URLS
+                for url in request.app.state.config.OLLAMA_BASE_URLS
             ]
             responses = await asyncio.gather(*tasks)
             responses = list(filter(lambda x: x is not None, responses))
@@ -387,7 +402,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
                     detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
                 )
         else:
-            url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+            url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
 
             r = None
             try:
@@ -414,22 +429,24 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
         return {"version": False}
 
 
-@app.get("/api/ps")
-async def get_ollama_loaded_models(user=Depends(get_verified_user)):
+@router.get("/api/ps")
+async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
     """
     List models that are currently loaded into Ollama memory, and which node they are loaded on.
     """
-    if app.state.config.ENABLE_OLLAMA_API:
+    if request.app.state.config.ENABLE_OLLAMA_API:
         tasks = [
             aiohttp_get(
                 f"{url}/api/ps",
-                app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None),
+                request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
+                    "key", None
+                ),
             )
-            for url in app.state.config.OLLAMA_BASE_URLS
+            for url in request.app.state.config.OLLAMA_BASE_URLS
         ]
         responses = await asyncio.gather(*tasks)
 
-        return dict(zip(app.state.config.OLLAMA_BASE_URLS, responses))
+        return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
     else:
         return {}
 
@@ -438,18 +455,25 @@ class ModelNameForm(BaseModel):
     name: str
 
 
-@app.post("/api/pull")
-@app.post("/api/pull/{url_idx}")
+@router.post("/api/pull")
+@router.post("/api/pull/{url_idx}")
 async def pull_model(
-    form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
+    request: Request,
+    form_data: ModelNameForm,
+    url_idx: int = 0,
+    user=Depends(get_admin_user),
 ):
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
     # Admin should be able to pull models from any source
     payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
 
-    return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))
+    return await post_streaming_url(
+        url=f"{url}/api/pull",
+        payload=json.dumps(payload),
+        key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
+    )
 
 
 class PushModelForm(BaseModel):
@@ -458,9 +482,10 @@ class PushModelForm(BaseModel):
     stream: Optional[bool] = None
 
 
-@app.delete("/api/push")
-@app.delete("/api/push/{url_idx}")
+@router.delete("/api/push")
+@router.delete("/api/push/{url_idx}")
 async def push_model(
+    request: Request,
     form_data: PushModelForm,
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
@@ -477,11 +502,13 @@ async def push_model(
                 detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
             )
 
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.debug(f"url: {url}")
 
     return await post_streaming_url(
-        f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
+        url=f"{url}/api/push",
+        payload=form_data.model_dump_json(exclude_none=True).encode(),
+        key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
     )
 
 
@@ -492,17 +519,22 @@ class CreateModelForm(BaseModel):
     path: Optional[str] = None
 
 
-@app.post("/api/create")
-@app.post("/api/create/{url_idx}")
+@router.post("/api/create")
+@router.post("/api/create/{url_idx}")
 async def create_model(
-    form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
+    request: Request,
+    form_data: CreateModelForm,
+    url_idx: int = 0,
+    user=Depends(get_admin_user),
 ):
     log.debug(f"form_data: {form_data}")
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
     return await post_streaming_url(
-        f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
+        url=f"{url}/api/create",
+        payload=form_data.model_dump_json(exclude_none=True).encode(),
+        key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
     )
 
 
@@ -511,9 +543,10 @@ class CopyModelForm(BaseModel):
     destination: str
 
 
-@app.post("/api/copy")
-@app.post("/api/copy/{url_idx}")
+@router.post("/api/copy")
+@router.post("/api/copy/{url_idx}")
 async def copy_model(
+    request: Request,
     form_data: CopyModelForm,
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
@@ -530,13 +563,13 @@ async def copy_model(
                 detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
             )
 
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
     parsed_url = urlparse(url)
     base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
 
-    api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
+    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
     key = api_config.get("key", None)
 
     headers = {"Content-Type": "application/json"}
@@ -573,9 +606,10 @@ async def copy_model(
         )
 
 
-@app.delete("/api/delete")
-@app.delete("/api/delete/{url_idx}")
+@router.delete("/api/delete")
+@router.delete("/api/delete/{url_idx}")
 async def delete_model(
+    request: Request,
     form_data: ModelNameForm,
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
@@ -592,13 +626,13 @@ async def delete_model(
                 detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
             )
 
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
     parsed_url = urlparse(url)
     base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
 
-    api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
+    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
     key = api_config.get("key", None)
 
     headers = {"Content-Type": "application/json"}
@@ -634,8 +668,10 @@ async def delete_model(
         )
 
 
-@app.post("/api/show")
-async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
+@router.post("/api/show")
+async def show_model_info(
+    request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
+):
     model_list = await get_all_models()
     models = {model["model"]: model for model in model_list["models"]}
 
@@ -646,13 +682,13 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
         )
 
     url_idx = random.choice(models[form_data.name]["urls"])
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
     parsed_url = urlparse(url)
     base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
 
-    api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
+    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
     key = api_config.get("key", None)
 
     headers = {"Content-Type": "application/json"}
@@ -701,8 +737,8 @@ class GenerateEmbedForm(BaseModel):
     keep_alive: Optional[Union[int, str]] = None
 
 
-@app.post("/api/embed")
-@app.post("/api/embed/{url_idx}")
+@router.post("/api/embed")
+@router.post("/api/embed/{url_idx}")
 async def generate_embeddings(
     form_data: GenerateEmbedForm,
     url_idx: Optional[int] = None,
@@ -711,8 +747,8 @@ async def generate_embeddings(
     return await generate_ollama_batch_embeddings(form_data, url_idx)
 
 
-@app.post("/api/embeddings")
-@app.post("/api/embeddings/{url_idx}")
+@router.post("/api/embeddings")
+@router.post("/api/embeddings/{url_idx}")
 async def generate_embeddings(
     form_data: GenerateEmbeddingsForm,
     url_idx: Optional[int] = None,
@@ -744,13 +780,13 @@ async def generate_ollama_embeddings(
                 detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
             )
 
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
     parsed_url = urlparse(url)
     base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
 
-    api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
+    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
     key = api_config.get("key", None)
 
     headers = {"Content-Type": "application/json"}
@@ -814,13 +850,13 @@ async def generate_ollama_batch_embeddings(
                 detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
             )
 
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
     parsed_url = urlparse(url)
     base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
 
-    api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
+    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
     key = api_config.get("key", None)
 
     headers = {"Content-Type": "application/json"}
@@ -873,9 +909,10 @@ class GenerateCompletionForm(BaseModel):
     keep_alive: Optional[Union[int, str]] = None
 
 
-@app.post("/api/generate")
-@app.post("/api/generate/{url_idx}")
+@router.post("/api/generate")
+@router.post("/api/generate/{url_idx}")
 async def generate_completion(
+    request: Request,
     form_data: GenerateCompletionForm,
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
@@ -897,15 +934,17 @@ async def generate_completion(
                 detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
             )
 
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
+    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
     prefix_id = api_config.get("prefix_id", None)
     if prefix_id:
         form_data.model = form_data.model.replace(f"{prefix_id}.", "")
     log.info(f"url: {url}")
 
     return await post_streaming_url(
-        f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
+        url=f"{url}/api/generate",
+        payload=form_data.model_dump_json(exclude_none=True).encode(),
+        key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
     )
 
 
@@ -936,13 +975,14 @@ async def get_ollama_url(url_idx: Optional[int], model: str):
                 detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
             )
         url_idx = random.choice(models[model]["urls"])
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     return url
 
 
-@app.post("/api/chat")
-@app.post("/api/chat/{url_idx}")
+@router.post("/api/chat")
+@router.post("/api/chat/{url_idx}")
 async def generate_chat_completion(
+    request: Request,
     form_data: GenerateChatCompletionForm,
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
@@ -1003,15 +1043,16 @@ async def generate_chat_completion(
     parsed_url = urlparse(url)
     base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
 
-    api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
+    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
     prefix_id = api_config.get("prefix_id", None)
     if prefix_id:
         payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
 
     return await post_streaming_url(
-        f"{url}/api/chat",
-        json.dumps(payload),
+        url=f"{url}/api/chat",
+        payload=json.dumps(payload),
         stream=form_data.stream,
+        key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
         content_type="application/x-ndjson",
     )
 
@@ -1043,10 +1084,13 @@ class OpenAICompletionForm(BaseModel):
     model_config = ConfigDict(extra="allow")
 
 
-@app.post("/v1/completions")
-@app.post("/v1/completions/{url_idx}")
+@router.post("/v1/completions")
+@router.post("/v1/completions/{url_idx}")
 async def generate_openai_completion(
-    form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user)
+    request: Request,
+    form_data: dict,
+    url_idx: Optional[int] = None,
+    user=Depends(get_verified_user),
 ):
     try:
         form_data = OpenAICompletionForm(**form_data)
@@ -1099,22 +1143,24 @@ async def generate_openai_completion(
     url = await get_ollama_url(url_idx, payload["model"])
     log.info(f"url: {url}")
 
-    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
     prefix_id = api_config.get("prefix_id", None)
 
     if prefix_id:
         payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
 
     return await post_streaming_url(
-        f"{url}/v1/completions",
-        json.dumps(payload),
+        url=f"{url}/v1/completions",
+        payload=json.dumps(payload),
         stream=payload.get("stream", False),
+        key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
     )
 
 
-@app.post("/v1/chat/completions")
-@app.post("/v1/chat/completions/{url_idx}")
+@router.post("/v1/chat/completions")
+@router.post("/v1/chat/completions/{url_idx}")
 async def generate_openai_chat_completion(
+    request: Request,
     form_data: dict,
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
@@ -1172,21 +1218,23 @@ async def generate_openai_chat_completion(
     url = await get_ollama_url(url_idx, payload["model"])
     log.info(f"url: {url}")
 
-    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
     prefix_id = api_config.get("prefix_id", None)
     if prefix_id:
         payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
 
     return await post_streaming_url(
-        f"{url}/v1/chat/completions",
-        json.dumps(payload),
+        url=f"{url}/v1/chat/completions",
+        payload=json.dumps(payload),
         stream=payload.get("stream", False),
+        key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
     )
 
 
-@app.get("/v1/models")
-@app.get("/v1/models/{url_idx}")
+@router.get("/v1/models")
+@router.get("/v1/models/{url_idx}")
 async def get_openai_models(
+    request: Request,
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
@@ -1205,7 +1253,7 @@ async def get_openai_models(
         ]
 
     else:
-        url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+        url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
         try:
             r = requests.request(method="GET", url=f"{url}/api/tags")
             r.raise_for_status()
@@ -1329,9 +1377,10 @@ async def download_file_stream(
 
 
 # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
-@app.post("/models/download")
-@app.post("/models/download/{url_idx}")
+@router.post("/models/download")
+@router.post("/models/download/{url_idx}")
 async def download_model(
+    request: Request,
     form_data: UrlForm,
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
@@ -1346,7 +1395,7 @@ async def download_model(
 
     if url_idx is None:
         url_idx = 0
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
 
     file_name = parse_huggingface_url(form_data.url)
 
@@ -1360,16 +1409,17 @@ async def download_model(
         return None
 
 
-@app.post("/models/upload")
-@app.post("/models/upload/{url_idx}")
+@router.post("/models/upload")
+@router.post("/models/upload/{url_idx}")
 def upload_model(
+    request: Request,
     file: UploadFile = File(...),
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
 ):
     if url_idx is None:
         url_idx = 0
-    ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
 
     file_path = f"{UPLOAD_DIR}/{file.filename}"