Selaa lähdekoodia

Merge pull request #10309 from JoaoCostaIFG/gemini_image_gen

feat: add Google Imagen/Gemini API image generation
Timothy Jaeryang Baek 2 kuukautta sitten
vanhempi
commit
32a90deeaf

+ 14 - 0
backend/open_webui/config.py

@@ -783,6 +783,9 @@ ENABLE_OPENAI_API = PersistentConfig(
 OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
 OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")
 
+GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
+GEMINI_API_BASE_URL = os.environ.get("GEMINI_API_BASE_URL", "")
+
 
 if OPENAI_API_BASE_URL == "":
     OPENAI_API_BASE_URL = "https://api.openai.com/v1"
@@ -2135,6 +2138,17 @@ IMAGES_OPENAI_API_KEY = PersistentConfig(
     os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY),
 )
 
+IMAGES_GEMINI_API_BASE_URL = PersistentConfig(
+    "IMAGES_GEMINI_API_BASE_URL",
+    "image_generation.gemini.api_base_url",
+    os.getenv("IMAGES_GEMINI_API_BASE_URL", GEMINI_API_BASE_URL),
+)
+IMAGES_GEMINI_API_KEY = PersistentConfig(
+    "IMAGES_GEMINI_API_KEY",
+    "image_generation.gemini.api_key",
+    os.getenv("IMAGES_GEMINI_API_KEY", GEMINI_API_KEY),
+)
+
 IMAGE_SIZE = PersistentConfig(
     "IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512")
 )

+ 5 - 0
backend/open_webui/main.py

@@ -131,6 +131,8 @@ from open_webui.config import (
     IMAGE_STEPS,
     IMAGES_OPENAI_API_BASE_URL,
     IMAGES_OPENAI_API_KEY,
+    IMAGES_GEMINI_API_BASE_URL,
+    IMAGES_GEMINI_API_KEY,
     # Audio
     AUDIO_STT_ENGINE,
     AUDIO_STT_MODEL,
@@ -658,6 +660,9 @@ app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ENABLE_IMAGE_PROMPT_GENERATION
 app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
 app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
 
+app.state.config.IMAGES_GEMINI_API_BASE_URL = IMAGES_GEMINI_API_BASE_URL
+app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY
+
 app.state.config.IMAGE_GENERATION_MODEL = IMAGE_GENERATION_MODEL
 
 app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL

+ 64 - 0
backend/open_webui/routers/images.py

@@ -55,6 +55,10 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
             "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
             "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
         },
+        "gemini": {
+            "GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
+            "GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
+        },
     }
 
 
@@ -78,6 +82,11 @@ class ComfyUIConfigForm(BaseModel):
     COMFYUI_WORKFLOW_NODES: list[dict]
 
 
+class GeminiConfigForm(BaseModel):
+    GEMINI_API_BASE_URL: str
+    GEMINI_API_KEY: str
+
+
 class ConfigForm(BaseModel):
     enabled: bool
     engine: str
@@ -85,6 +94,7 @@ class ConfigForm(BaseModel):
     openai: OpenAIConfigForm
     automatic1111: Automatic1111ConfigForm
     comfyui: ComfyUIConfigForm
+    gemini: GeminiConfigForm
 
 
 @router.post("/config/update")
@@ -103,6 +113,11 @@ async def update_config(
     )
     request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
 
+    request.app.state.config.IMAGES_GEMINI_API_BASE_URL = (
+        form_data.gemini.GEMINI_API_BASE_URL
+    )
+    request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.gemini.GEMINI_API_KEY
+
     request.app.state.config.AUTOMATIC1111_BASE_URL = (
         form_data.automatic1111.AUTOMATIC1111_BASE_URL
     )
@@ -155,6 +170,10 @@ async def update_config(
             "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
             "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
         },
+        "gemini": {
+            "GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
+            "GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
+        },
     }
 
 
@@ -224,6 +243,12 @@ def get_image_model(request):
             if request.app.state.config.IMAGE_GENERATION_MODEL
             else "dall-e-2"
         )
+    elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
+        return (
+            request.app.state.config.IMAGE_GENERATION_MODEL
+            if request.app.state.config.IMAGE_GENERATION_MODEL
+            else "imagen-3.0-generate-002"
+        )
     elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
         return (
             request.app.state.config.IMAGE_GENERATION_MODEL
@@ -299,6 +324,10 @@ def get_models(request: Request, user=Depends(get_verified_user)):
                 {"id": "dall-e-2", "name": "DALL·E 2"},
                 {"id": "dall-e-3", "name": "DALL·E 3"},
             ]
+        elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
+            return [
+                {"id": "imagen-3-0-generate-002", "name": "imagen-3.0 generate-002"},
+            ]
         elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
             # TODO - get models from comfyui
             headers = {
@@ -483,6 +512,41 @@ async def image_generations(
                 images.append({"url": url})
             return images
 
+        elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
+            headers = {}
+            headers["Content-Type"] = "application/json"
+            headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY
+
+            model = get_image_model(request)
+            data = {
+                "instances": {"prompt": form_data.prompt},
+                "parameters": {
+                    "sampleCount": form_data.n,
+                    "outputOptions": {"mimeType": "image/png"},
+                },
+            }
+
+            # Use asyncio.to_thread for the requests.post call
+            r = await asyncio.to_thread(
+                requests.post,
+                url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict",
+                json=data,
+                headers=headers,
+            )
+
+            r.raise_for_status()
+            res = r.json()
+
+            images = []
+            for image in res["predictions"]:
+                image_data, content_type = load_b64_image_data(
+                    image["bytesBase64Encoded"]
+                )
+                url = upload_image(request, data, image_data, content_type, user)
+                images.append({"url": url})
+
+            return images
+
         elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
             data = {
                 "prompt": form_data.prompt,

+ 22 - 0
src/lib/components/admin/Settings/Images.svelte

@@ -261,6 +261,9 @@
 										} else if (config.engine === 'openai' && config.openai.OPENAI_API_KEY === '') {
 											toast.error($i18n.t('OpenAI API Key is required.'));
 											config.enabled = false;
+										} else if (config.engine === 'gemini' && config.gemini.GEMINI_API_KEY === '') {
+											toast.error($i18n.t('Gemini API Key is required.'));
+											config.enabled = false;
 										}
 									}
 
@@ -294,6 +297,7 @@
 							<option value="openai">{$i18n.t('Default (Open AI)')}</option>
 							<option value="comfyui">{$i18n.t('ComfyUI')}</option>
 							<option value="automatic1111">{$i18n.t('Automatic1111')}</option>
+							<option value="gemini">{$i18n.t('Gemini')}</option>
 						</select>
 					</div>
 				</div>
@@ -605,6 +609,24 @@
 							/>
 						</div>
 					</div>
+				{:else if config?.engine === 'gemini'}
+					<div>
+						<div class=" mb-1.5 text-sm font-medium">{$i18n.t('Gemini API Config')}</div>
+
+						<div class="flex gap-2 mb-1">
+							<input
+								class="flex-1 w-full text-sm bg-transparent outline-none"
+								placeholder={$i18n.t('API Base URL')}
+								bind:value={config.gemini.GEMINI_API_BASE_URL}
+								required
+							/>
+
+							<SensitiveInput
+								placeholder={$i18n.t('API Key')}
+								bind:value={config.gemini.GEMINI_API_KEY}
+							/>
+						</div>
+					</div>
 				{/if}
 			</div>