Ver código fonte

Allow seting CORS origin

Craig Quiter 8 meses atrás
pai
commit
d2f10d50bf

+ 2 - 1
backend/apps/audio/main.py

@@ -38,6 +38,7 @@ from config import (
     AUDIO_TTS_MODEL,
     AUDIO_TTS_MODEL,
     AUDIO_TTS_VOICE,
     AUDIO_TTS_VOICE,
     AppConfig,
     AppConfig,
+    CORS_ALLOW_ORIGIN,
 )
 )
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
 from utils.utils import (
 from utils.utils import (
@@ -52,7 +53,7 @@ log.setLevel(SRC_LOG_LEVELS["AUDIO"])
 app = FastAPI()
 app = FastAPI()
 app.add_middleware(
 app.add_middleware(
     CORSMiddleware,
     CORSMiddleware,
-    allow_origins=["*"],
+    allow_origins=CORS_ALLOW_ORIGIN,
     allow_credentials=True,
     allow_credentials=True,
     allow_methods=["*"],
     allow_methods=["*"],
     allow_headers=["*"],
     allow_headers=["*"],

+ 2 - 1
backend/apps/images/main.py

@@ -51,6 +51,7 @@ from config import (
     IMAGE_SIZE,
     IMAGE_SIZE,
     IMAGE_STEPS,
     IMAGE_STEPS,
     AppConfig,
     AppConfig,
+    CORS_ALLOW_ORIGIN,
 )
 )
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
@@ -62,7 +63,7 @@ IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
 app = FastAPI()
 app = FastAPI()
 app.add_middleware(
 app.add_middleware(
     CORSMiddleware,
     CORSMiddleware,
-    allow_origins=["*"],
+    allow_origins=CORS_ALLOW_ORIGIN,
     allow_credentials=True,
     allow_credentials=True,
     allow_methods=["*"],
     allow_methods=["*"],
     allow_headers=["*"],
     allow_headers=["*"],

+ 2 - 1
backend/apps/ollama/main.py

@@ -41,6 +41,7 @@ from config import (
     MODEL_FILTER_LIST,
     MODEL_FILTER_LIST,
     UPLOAD_DIR,
     UPLOAD_DIR,
     AppConfig,
     AppConfig,
+    CORS_ALLOW_ORIGIN,
 )
 )
 from utils.misc import (
 from utils.misc import (
     calculate_sha256,
     calculate_sha256,
@@ -55,7 +56,7 @@ log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
 app = FastAPI()
 app = FastAPI()
 app.add_middleware(
 app.add_middleware(
     CORSMiddleware,
     CORSMiddleware,
-    allow_origins=["*"],
+    allow_origins=CORS_ALLOW_ORIGIN,
     allow_credentials=True,
     allow_credentials=True,
     allow_methods=["*"],
     allow_methods=["*"],
     allow_headers=["*"],
     allow_headers=["*"],

+ 2 - 1
backend/apps/openai/main.py

@@ -32,6 +32,7 @@ from config import (
     ENABLE_MODEL_FILTER,
     ENABLE_MODEL_FILTER,
     MODEL_FILTER_LIST,
     MODEL_FILTER_LIST,
     AppConfig,
     AppConfig,
+    CORS_ALLOW_ORIGIN,
 )
 )
 from typing import Optional, Literal, overload
 from typing import Optional, Literal, overload
 
 
@@ -45,7 +46,7 @@ log.setLevel(SRC_LOG_LEVELS["OPENAI"])
 app = FastAPI()
 app = FastAPI()
 app.add_middleware(
 app.add_middleware(
     CORSMiddleware,
     CORSMiddleware,
-    allow_origins=["*"],
+    allow_origins=CORS_ALLOW_ORIGIN,
     allow_credentials=True,
     allow_credentials=True,
     allow_methods=["*"],
     allow_methods=["*"],
     allow_headers=["*"],
     allow_headers=["*"],

+ 2 - 4
backend/apps/rag/main.py

@@ -129,6 +129,7 @@ from config import (
     RAG_WEB_SEARCH_RESULT_COUNT,
     RAG_WEB_SEARCH_RESULT_COUNT,
     RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
     RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
     RAG_EMBEDDING_OPENAI_BATCH_SIZE,
     RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+    CORS_ALLOW_ORIGIN,
 )
 )
 
 
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
@@ -240,12 +241,9 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
     app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
     app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
 )
 )
 
 
-origins = ["*"]
-
-
 app.add_middleware(
 app.add_middleware(
     CORSMiddleware,
     CORSMiddleware,
-    allow_origins=origins,
+    allow_origins=CORS_ALLOW_ORIGIN,
     allow_credentials=True,
     allow_credentials=True,
     allow_methods=["*"],
     allow_methods=["*"],
     allow_headers=["*"],
     allow_headers=["*"],

+ 2 - 3
backend/apps/webui/main.py

@@ -47,6 +47,7 @@ from config import (
     OAUTH_USERNAME_CLAIM,
     OAUTH_USERNAME_CLAIM,
     OAUTH_PICTURE_CLAIM,
     OAUTH_PICTURE_CLAIM,
     OAUTH_EMAIL_CLAIM,
     OAUTH_EMAIL_CLAIM,
+    CORS_ALLOW_ORIGIN,
 )
 )
 
 
 from apps.socket.main import get_event_call, get_event_emitter
 from apps.socket.main import get_event_call, get_event_emitter
@@ -59,8 +60,6 @@ from pydantic import BaseModel
 
 
 app = FastAPI()
 app = FastAPI()
 
 
-origins = ["*"]
-
 app.state.config = AppConfig()
 app.state.config = AppConfig()
 
 
 app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
 app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
@@ -93,7 +92,7 @@ app.state.FUNCTIONS = {}
 
 
 app.add_middleware(
 app.add_middleware(
     CORSMiddleware,
     CORSMiddleware,
-    allow_origins=origins,
+    allow_origins=CORS_ALLOW_ORIGIN,
     allow_credentials=True,
     allow_credentials=True,
     allow_methods=["*"],
     allow_methods=["*"],
     allow_headers=["*"],
     allow_headers=["*"],

+ 31 - 0
backend/config.py

@@ -3,6 +3,8 @@ import sys
 import logging
 import logging
 import importlib.metadata
 import importlib.metadata
 import pkgutil
 import pkgutil
+from urllib.parse import urlparse
+
 import chromadb
 import chromadb
 from chromadb import Settings
 from chromadb import Settings
 from bs4 import BeautifulSoup
 from bs4 import BeautifulSoup
@@ -840,6 +842,35 @@ ENABLE_COMMUNITY_SHARING = PersistentConfig(
     os.environ.get("ENABLE_COMMUNITY_SHARING", "True").lower() == "true",
     os.environ.get("ENABLE_COMMUNITY_SHARING", "True").lower() == "true",
 )
 )
 
 
+def validate_cors_origins(origins):
+    for origin in origins:
+        if origin != "*":
+            validate_cors_origin(origin)
+
+
+def validate_cors_origin(origin):
+    parsed_url = urlparse(origin)
+
+    # Check if the scheme is either http or https
+    if parsed_url.scheme not in ["http", "https"]:
+        raise ValueError(f"Invalid scheme in CORS_ALLOW_ORIGIN: '{origin}'. Only 'http' and 'https' are allowed.")
+
+    # Ensure that the netloc (domain + port) is present, indicating it's a valid URL
+    if not parsed_url.netloc:
+        raise ValueError(f"Invalid URL structure in CORS_ALLOW_ORIGIN: '{origin}'.")
+
+
+# For production, you should only need one host as
+# fastapi serves the svelte-kit built frontend and backend from the same host and port.
+# To test CORS_ALLOW_ORIGIN locally, you can set something like
+# CORS_ALLOW_ORIGIN=http://localhost:5173;http://localhost:8080
+# in your .env file depending on your frontend port, 5173 in this case.
+CORS_ALLOW_ORIGIN = os.environ.get("CORS_ALLOW_ORIGIN", "*").split(";")
+
+if "*" in CORS_ALLOW_ORIGIN:
+    log.warning("\n\nWARNING: CORS_ALLOW_ORIGIN IS SET TO '*' - NOT RECOMMENDED FOR PRODUCTION DEPLOYMENTS.\n")
+
+validate_cors_origins(CORS_ALLOW_ORIGIN)
 
 
 class BannerModel(BaseModel):
 class BannerModel(BaseModel):
     id: str
     id: str

+ 2 - 3
backend/main.py

@@ -119,6 +119,7 @@ from config import (
     WEBUI_SESSION_COOKIE_SECURE,
     WEBUI_SESSION_COOKIE_SECURE,
     ENABLE_ADMIN_CHAT_ACCESS,
     ENABLE_ADMIN_CHAT_ACCESS,
     AppConfig,
     AppConfig,
+    CORS_ALLOW_ORIGIN,
 )
 )
 
 
 from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
 from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
@@ -209,8 +210,6 @@ app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
 
 
 app.state.MODELS = {}
 app.state.MODELS = {}
 
 
-origins = ["*"]
-
 
 
 ##################################
 ##################################
 #
 #
@@ -833,7 +832,7 @@ app.add_middleware(PipelineMiddleware)
 
 
 app.add_middleware(
 app.add_middleware(
     CORSMiddleware,
     CORSMiddleware,
-    allow_origins=origins,
+    allow_origins=CORS_ALLOW_ORIGIN,
     allow_credentials=True,
     allow_credentials=True,
     allow_methods=["*"],
     allow_methods=["*"],
     allow_headers=["*"],
     allow_headers=["*"],