|
@@ -55,6 +55,10 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
|
|
|
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
|
|
|
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
|
|
|
},
|
|
|
+ "gemini": {
|
|
|
+ "GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
|
|
|
+ "GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
|
|
+ },
|
|
|
}
|
|
|
|
|
|
|
|
@@ -78,6 +82,11 @@ class ComfyUIConfigForm(BaseModel):
|
|
|
COMFYUI_WORKFLOW_NODES: list[dict]
|
|
|
|
|
|
|
|
|
+class GeminiConfigForm(BaseModel):
|
|
|
+ GEMINI_API_BASE_URL: str
|
|
|
+ GEMINI_API_KEY: str
|
|
|
+
|
|
|
+
|
|
|
class ConfigForm(BaseModel):
|
|
|
enabled: bool
|
|
|
engine: str
|
|
@@ -85,6 +94,7 @@ class ConfigForm(BaseModel):
|
|
|
openai: OpenAIConfigForm
|
|
|
automatic1111: Automatic1111ConfigForm
|
|
|
comfyui: ComfyUIConfigForm
|
|
|
+ gemini: GeminiConfigForm
|
|
|
|
|
|
|
|
|
@router.post("/config/update")
|
|
@@ -103,6 +113,11 @@ async def update_config(
|
|
|
)
|
|
|
request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
|
|
|
|
|
|
+ request.app.state.config.IMAGES_GEMINI_API_BASE_URL = (
|
|
|
+ form_data.gemini.GEMINI_API_BASE_URL
|
|
|
+ )
|
|
|
+ request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.gemini.GEMINI_API_KEY
|
|
|
+
|
|
|
request.app.state.config.AUTOMATIC1111_BASE_URL = (
|
|
|
form_data.automatic1111.AUTOMATIC1111_BASE_URL
|
|
|
)
|
|
@@ -155,6 +170,10 @@ async def update_config(
|
|
|
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
|
|
|
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
|
|
|
},
|
|
|
+ "gemini": {
|
|
|
+ "GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
|
|
|
+ "GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
|
|
|
+ },
|
|
|
}
|
|
|
|
|
|
|
|
@@ -224,6 +243,12 @@ def get_image_model(request):
|
|
|
if request.app.state.config.IMAGE_GENERATION_MODEL
|
|
|
else "dall-e-2"
|
|
|
)
|
|
|
+ elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
|
|
+ return (
|
|
|
+ request.app.state.config.IMAGE_GENERATION_MODEL
|
|
|
+ if request.app.state.config.IMAGE_GENERATION_MODEL
|
|
|
+ else "imagen-3.0-generate-002"
|
|
|
+ )
|
|
|
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
|
|
return (
|
|
|
request.app.state.config.IMAGE_GENERATION_MODEL
|
|
@@ -299,6 +324,10 @@ def get_models(request: Request, user=Depends(get_verified_user)):
|
|
|
{"id": "dall-e-2", "name": "DALL·E 2"},
|
|
|
{"id": "dall-e-3", "name": "DALL·E 3"},
|
|
|
]
|
|
|
+ elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
|
|
+ return [
|
|
|
+ {"id": "imagen-3-0-generate-002", "name": "imagen-3.0 generate-002"},
|
|
|
+ ]
|
|
|
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
|
|
# TODO - get models from comfyui
|
|
|
headers = {
|
|
@@ -483,6 +512,40 @@ async def image_generations(
|
|
|
images.append({"url": url})
|
|
|
return images
|
|
|
|
|
|
+ elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
|
|
+ headers = {}
|
|
|
+ headers["Content-Type"] = "application/json"
|
|
|
+ api_key = request.app.state.config.IMAGES_GEMINI_API_KEY
|
|
|
+ model = get_image_model(request)
|
|
|
+ data = {
|
|
|
+ "instances": {"prompt": form_data.prompt},
|
|
|
+ "parameters": {
|
|
|
+ "sampleCount": form_data.n,
|
|
|
+ "outputOptions": {"mimeType": "image/png"},
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ # Use asyncio.to_thread for the requests.post call
|
|
|
+ r = await asyncio.to_thread(
|
|
|
+ requests.post,
|
|
|
+ url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict?key={api_key}",
|
|
|
+ json=data,
|
|
|
+ headers=headers,
|
|
|
+ )
|
|
|
+
|
|
|
+ r.raise_for_status()
|
|
|
+ res = r.json()
|
|
|
+
|
|
|
+ images = []
|
|
|
+ for image in res["predictions"]:
|
|
|
+ image_data, content_type = load_b64_image_data(
|
|
|
+ image["bytesBase64Encoded"]
|
|
|
+ )
|
|
|
+ url = upload_image(request, data, image_data, content_type, user)
|
|
|
+ images.append({"url": url})
|
|
|
+
|
|
|
+ return images
|
|
|
+
|
|
|
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
|
|
data = {
|
|
|
"prompt": form_data.prompt,
|