Timothy Jaeryang Baek hace 4 meses
padre
commit
d8a01cb911
Se han modificado 1 ficheros con 28 adiciones y 26 borrados
  1. 28 26
      backend/open_webui/routers/images.py

+ 28 - 26
backend/open_webui/routers/images.py

@@ -42,10 +42,10 @@ router = APIRouter()
 async def get_config(request: Request, user=Depends(get_admin_user)):
     return {
         "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
-        "engine": request.app.state.config.ENGINE,
+        "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
         "openai": {
-            "OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": request.app.state.config.OPENAI_API_KEY,
+            "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
         },
         "automatic1111": {
             "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
@@ -93,11 +93,13 @@ class ConfigForm(BaseModel):
 async def update_config(
     request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
 ):
-    request.app.state.config.ENGINE = form_data.engine
+    request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.engine
     request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled
 
-    request.app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
-    request.app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
+    request.app.state.config.IMAGES_OPENAI_API_BASE_URL = (
+        form_data.openai.OPENAI_API_BASE_URL
+    )
+    request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
 
     request.app.state.config.AUTOMATIC1111_BASE_URL = (
         form_data.automatic1111.AUTOMATIC1111_BASE_URL
@@ -132,10 +134,10 @@ async def update_config(
 
     return {
         "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
-        "engine": request.app.state.config.ENGINE,
+        "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
         "openai": {
-            "OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": request.app.state.config.OPENAI_API_KEY,
+            "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
         },
         "automatic1111": {
             "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
@@ -166,7 +168,7 @@ def get_automatic1111_api_auth(request: Request):
 
 @router.get("/config/url/verify")
 async def verify_url(request: Request, user=Depends(get_admin_user)):
-    if request.app.state.config.ENGINE == "automatic1111":
+    if request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111":
         try:
             r = requests.get(
                 url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
@@ -177,7 +179,7 @@ async def verify_url(request: Request, user=Depends(get_admin_user)):
         except Exception:
             request.app.state.config.ENABLE_IMAGE_GENERATION = False
             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
-    elif request.app.state.config.ENGINE == "comfyui":
+    elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
         try:
             r = requests.get(
                 url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
@@ -194,7 +196,7 @@ async def verify_url(request: Request, user=Depends(get_admin_user)):
 def set_image_model(request: Request, model: str):
     log.info(f"Setting image model to {model}")
     request.app.state.config.MODEL = model
-    if request.app.state.config.ENGINE in ["", "automatic1111"]:
+    if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
         api_auth = get_automatic1111_api_auth()
         r = requests.get(
             url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
@@ -212,17 +214,17 @@ def set_image_model(request: Request, model: str):
 
 
 def get_image_model():
-    if request.app.state.config.ENGINE == "openai":
+    if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
         return (
             request.app.state.config.MODEL
             if request.app.state.config.MODEL
             else "dall-e-2"
         )
-    elif request.app.state.config.ENGINE == "comfyui":
+    elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
         return request.app.state.config.MODEL if request.app.state.config.MODEL else ""
     elif (
-        request.app.state.config.ENGINE == "automatic1111"
-        or request.app.state.config.ENGINE == ""
+        request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
+        or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
     ):
         try:
             r = requests.get(
@@ -285,12 +287,12 @@ async def update_image_config(
 @router.get("/models")
 def get_models(request: Request, user=Depends(get_verified_user)):
     try:
-        if request.app.state.config.ENGINE == "openai":
+        if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
             return [
                 {"id": "dall-e-2", "name": "DALL·E 2"},
                 {"id": "dall-e-3", "name": "DALL·E 3"},
             ]
-        elif request.app.state.config.ENGINE == "comfyui":
+        elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
             # TODO - get models from comfyui
             r = requests.get(
                 url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
@@ -336,8 +338,8 @@ def get_models(request: Request, user=Depends(get_verified_user)):
                     )
                 )
         elif (
-            request.app.state.config.ENGINE == "automatic1111"
-            or request.app.state.config.ENGINE == ""
+            request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
+            or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
         ):
             r = requests.get(
                 url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
@@ -433,10 +435,10 @@ async def image_generations(
 
     r = None
     try:
-        if request.app.state.config.ENGINE == "openai":
+        if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
             headers = {}
             headers["Authorization"] = (
-                f"Bearer {request.app.state.config.OPENAI_API_KEY}"
+                f"Bearer {request.app.state.config.IMAGES_OPENAI_API_KEY}"
             )
             headers["Content-Type"] = "application/json"
 
@@ -465,7 +467,7 @@ async def image_generations(
             # Use asyncio.to_thread for the requests.post call
             r = await asyncio.to_thread(
                 requests.post,
-                url=f"{request.app.state.config.OPENAI_API_BASE_URL}/images/generations",
+                url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations",
                 json=data,
                 headers=headers,
             )
@@ -485,7 +487,7 @@ async def image_generations(
 
             return images
 
-        elif request.app.state.config.ENGINE == "comfyui":
+        elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
             data = {
                 "prompt": form_data.prompt,
                 "width": width,
@@ -531,8 +533,8 @@ async def image_generations(
             log.debug(f"images: {images}")
             return images
         elif (
-            request.app.state.config.ENGINE == "automatic1111"
-            or request.app.state.config.ENGINE == ""
+            request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
+            or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
         ):
             if form_data.model:
                 set_image_model(form_data.model)