main.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555
  1. import asyncio
  2. import base64
  3. import json
  4. import logging
  5. import mimetypes
  6. import re
  7. import uuid
  8. from pathlib import Path
  9. from typing import Optional
  10. import requests
  11. from open_webui.apps.images.utils.comfyui import (
  12. ComfyUIGenerateImageForm,
  13. ComfyUIWorkflow,
  14. comfyui_generate_image,
  15. )
  16. from open_webui.config import (
  17. AUTOMATIC1111_API_AUTH,
  18. AUTOMATIC1111_BASE_URL,
  19. CACHE_DIR,
  20. COMFYUI_BASE_URL,
  21. COMFYUI_WORKFLOW,
  22. COMFYUI_WORKFLOW_NODES,
  23. CORS_ALLOW_ORIGIN,
  24. ENABLE_IMAGE_GENERATION,
  25. IMAGE_GENERATION_ENGINE,
  26. IMAGE_GENERATION_MODEL,
  27. IMAGE_SIZE,
  28. IMAGE_STEPS,
  29. IMAGES_OPENAI_API_BASE_URL,
  30. IMAGES_OPENAI_API_KEY,
  31. AppConfig,
  32. )
  33. from open_webui.constants import ERROR_MESSAGES
  34. from open_webui.env import SRC_LOG_LEVELS
  35. from fastapi import Depends, FastAPI, HTTPException, Request
  36. from fastapi.middleware.cors import CORSMiddleware
  37. from pydantic import BaseModel
  38. from open_webui.utils.utils import get_admin_user, get_verified_user
  39. log = logging.getLogger(__name__)
  40. log.setLevel(SRC_LOG_LEVELS["IMAGES"])
  41. IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
  42. IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  43. app = FastAPI()
  44. app.add_middleware(
  45. CORSMiddleware,
  46. allow_origins=CORS_ALLOW_ORIGIN,
  47. allow_credentials=True,
  48. allow_methods=["*"],
  49. allow_headers=["*"],
  50. )
  51. app.state.config = AppConfig()
  52. app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
  53. app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
  54. app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
  55. app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
  56. app.state.config.MODEL = IMAGE_GENERATION_MODEL
  57. app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  58. app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
  59. app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  60. app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
  61. app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
  62. app.state.config.IMAGE_SIZE = IMAGE_SIZE
  63. app.state.config.IMAGE_STEPS = IMAGE_STEPS
  64. @app.get("/config")
  65. async def get_config(request: Request, user=Depends(get_admin_user)):
  66. return {
  67. "enabled": app.state.config.ENABLED,
  68. "engine": app.state.config.ENGINE,
  69. "openai": {
  70. "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
  71. "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
  72. },
  73. "automatic1111": {
  74. "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
  75. "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
  76. },
  77. "comfyui": {
  78. "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
  79. "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
  80. "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
  81. },
  82. }
class OpenAIConfigForm(BaseModel):
    """OpenAI-compatible image API settings submitted to /config/update."""

    # Base URL; image requests go to f"{OPENAI_API_BASE_URL}/images/generations".
    OPENAI_API_BASE_URL: str
    # Sent as "Authorization: Bearer <key>" on generation requests.
    OPENAI_API_KEY: str
class Automatic1111ConfigForm(BaseModel):
    """AUTOMATIC1111 backend settings submitted to /config/update."""

    # Server root; the /sdapi/v1/... endpoint paths are appended to it.
    AUTOMATIC1111_BASE_URL: str
    # Credentials string, base64-encoded into an HTTP "Basic" authorization
    # header by get_automatic1111_api_auth.
    AUTOMATIC1111_API_AUTH: str
class ComfyUIConfigForm(BaseModel):
    """ComfyUI backend settings submitted to /config/update."""

    # Server root; e.g. "/object_info" is appended to it.
    COMFYUI_BASE_URL: str
    # The workflow as a JSON string (parsed with json.loads in get_models).
    COMFYUI_WORKFLOW: str
    # Node mapping entries; get_models reads their "type" and "node_ids" keys.
    COMFYUI_WORKFLOW_NODES: list[dict]
class ConfigForm(BaseModel):
    """Request body for POST /config/update: global switches plus per-engine settings."""

    # Master switch for image generation.
    enabled: bool
    # Backend selector; the handlers recognize "openai", "comfyui" and
    # "automatic1111" (an empty string is treated like "automatic1111").
    engine: str
    openai: OpenAIConfigForm
    automatic1111: Automatic1111ConfigForm
    comfyui: ComfyUIConfigForm
  99. @app.post("/config/update")
  100. async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)):
  101. app.state.config.ENGINE = form_data.engine
  102. app.state.config.ENABLED = form_data.enabled
  103. app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
  104. app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
  105. app.state.config.AUTOMATIC1111_BASE_URL = (
  106. form_data.automatic1111.AUTOMATIC1111_BASE_URL
  107. )
  108. app.state.config.AUTOMATIC1111_API_AUTH = (
  109. form_data.automatic1111.AUTOMATIC1111_API_AUTH
  110. )
  111. app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL.strip("/")
  112. app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
  113. app.state.config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES
  114. return {
  115. "enabled": app.state.config.ENABLED,
  116. "engine": app.state.config.ENGINE,
  117. "openai": {
  118. "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
  119. "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
  120. },
  121. "automatic1111": {
  122. "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
  123. "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
  124. },
  125. "comfyui": {
  126. "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
  127. "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
  128. "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
  129. },
  130. }
  131. def get_automatic1111_api_auth():
  132. if app.state.config.AUTOMATIC1111_API_AUTH is None:
  133. return ""
  134. else:
  135. auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
  136. auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
  137. auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
  138. return f"Basic {auth1111_base64_encoded_string}"
  139. @app.get("/config/url/verify")
  140. async def verify_url(user=Depends(get_admin_user)):
  141. if app.state.config.ENGINE == "automatic1111":
  142. try:
  143. r = requests.get(
  144. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  145. headers={"authorization": get_automatic1111_api_auth()},
  146. )
  147. r.raise_for_status()
  148. return True
  149. except Exception:
  150. app.state.config.ENABLED = False
  151. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  152. elif app.state.config.ENGINE == "comfyui":
  153. try:
  154. r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
  155. r.raise_for_status()
  156. return True
  157. except Exception:
  158. app.state.config.ENABLED = False
  159. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  160. else:
  161. return True
  162. def set_image_model(model: str):
  163. log.info(f"Setting image model to {model}")
  164. app.state.config.MODEL = model
  165. if app.state.config.ENGINE in ["", "automatic1111"]:
  166. api_auth = get_automatic1111_api_auth()
  167. r = requests.get(
  168. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  169. headers={"authorization": api_auth},
  170. )
  171. options = r.json()
  172. if model != options["sd_model_checkpoint"]:
  173. options["sd_model_checkpoint"] = model
  174. r = requests.post(
  175. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  176. json=options,
  177. headers={"authorization": api_auth},
  178. )
  179. return app.state.config.MODEL
  180. def get_image_model():
  181. if app.state.config.ENGINE == "openai":
  182. return app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
  183. elif app.state.config.ENGINE == "comfyui":
  184. return app.state.config.MODEL if app.state.config.MODEL else ""
  185. elif app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "":
  186. try:
  187. r = requests.get(
  188. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  189. headers={"authorization": get_automatic1111_api_auth()},
  190. )
  191. options = r.json()
  192. return options["sd_model_checkpoint"]
  193. except Exception as e:
  194. app.state.config.ENABLED = False
  195. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
class ImageConfigForm(BaseModel):
    """Request body for POST /image/config/update."""

    # Model name, applied through set_image_model.
    MODEL: str
    # "<width>x<height>", e.g. "512x512"; format-checked by update_image_config.
    IMAGE_SIZE: str
    # Sampling steps; update_image_config rejects negative values.
    IMAGE_STEPS: int
  200. @app.get("/image/config")
  201. async def get_image_config(user=Depends(get_admin_user)):
  202. return {
  203. "MODEL": app.state.config.MODEL,
  204. "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
  205. "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
  206. }
  207. @app.post("/image/config/update")
  208. async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin_user)):
  209. set_image_model(form_data.MODEL)
  210. pattern = r"^\d+x\d+$"
  211. if re.match(pattern, form_data.IMAGE_SIZE):
  212. app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
  213. else:
  214. raise HTTPException(
  215. status_code=400,
  216. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
  217. )
  218. if form_data.IMAGE_STEPS >= 0:
  219. app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
  220. else:
  221. raise HTTPException(
  222. status_code=400,
  223. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
  224. )
  225. return {
  226. "MODEL": app.state.config.MODEL,
  227. "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
  228. "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
  229. }
  230. @app.get("/models")
  231. def get_models(user=Depends(get_verified_user)):
  232. try:
  233. if app.state.config.ENGINE == "openai":
  234. return [
  235. {"id": "dall-e-2", "name": "DALL·E 2"},
  236. {"id": "dall-e-3", "name": "DALL·E 3"},
  237. ]
  238. elif app.state.config.ENGINE == "comfyui":
  239. # TODO - get models from comfyui
  240. r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
  241. info = r.json()
  242. workflow = json.loads(app.state.config.COMFYUI_WORKFLOW)
  243. model_node_id = None
  244. for node in app.state.config.COMFYUI_WORKFLOW_NODES:
  245. if node["type"] == "model":
  246. if node["node_ids"]:
  247. model_node_id = node["node_ids"][0]
  248. break
  249. if model_node_id:
  250. model_list_key = None
  251. print(workflow[model_node_id]["class_type"])
  252. for key in info[workflow[model_node_id]["class_type"]]["input"][
  253. "required"
  254. ]:
  255. if "_name" in key:
  256. model_list_key = key
  257. break
  258. if model_list_key:
  259. return list(
  260. map(
  261. lambda model: {"id": model, "name": model},
  262. info[workflow[model_node_id]["class_type"]]["input"][
  263. "required"
  264. ][model_list_key][0],
  265. )
  266. )
  267. else:
  268. return list(
  269. map(
  270. lambda model: {"id": model, "name": model},
  271. info["CheckpointLoaderSimple"]["input"]["required"][
  272. "ckpt_name"
  273. ][0],
  274. )
  275. )
  276. elif (
  277. app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
  278. ):
  279. r = requests.get(
  280. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
  281. headers={"authorization": get_automatic1111_api_auth()},
  282. )
  283. models = r.json()
  284. return list(
  285. map(
  286. lambda model: {"id": model["title"], "name": model["model_name"]},
  287. models,
  288. )
  289. )
  290. except Exception as e:
  291. app.state.config.ENABLED = False
  292. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
class GenerateImageForm(BaseModel):
    """Request body for POST /generations."""

    # Only used by the AUTOMATIC1111 engine, which switches to it first.
    model: Optional[str] = None
    prompt: str
    # Only forwarded to the OpenAI engine; other engines use IMAGE_SIZE.
    size: Optional[str] = None
    # Number of images (sent as "batch_size" to AUTOMATIC1111).
    n: int = 1
    # Forwarded to ComfyUI and AUTOMATIC1111 only.
    negative_prompt: Optional[str] = None
  299. def save_b64_image(b64_str):
  300. try:
  301. image_id = str(uuid.uuid4())
  302. if "," in b64_str:
  303. header, encoded = b64_str.split(",", 1)
  304. mime_type = header.split(";")[0]
  305. img_data = base64.b64decode(encoded)
  306. image_format = mimetypes.guess_extension(mime_type)
  307. image_filename = f"{image_id}{image_format}"
  308. file_path = IMAGE_CACHE_DIR / f"{image_filename}"
  309. with open(file_path, "wb") as f:
  310. f.write(img_data)
  311. return image_filename
  312. else:
  313. image_filename = f"{image_id}.png"
  314. file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
  315. img_data = base64.b64decode(b64_str)
  316. # Write the image data to a file
  317. with open(file_path, "wb") as f:
  318. f.write(img_data)
  319. return image_filename
  320. except Exception as e:
  321. log.exception(f"Error saving image: {e}")
  322. return None
  323. def save_url_image(url):
  324. image_id = str(uuid.uuid4())
  325. try:
  326. r = requests.get(url)
  327. r.raise_for_status()
  328. if r.headers["content-type"].split("/")[0] == "image":
  329. mime_type = r.headers["content-type"]
  330. image_format = mimetypes.guess_extension(mime_type)
  331. if not image_format:
  332. raise ValueError("Could not determine image type from MIME type")
  333. image_filename = f"{image_id}{image_format}"
  334. file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
  335. with open(file_path, "wb") as image_file:
  336. for chunk in r.iter_content(chunk_size=8192):
  337. image_file.write(chunk)
  338. return image_filename
  339. else:
  340. log.error("Url does not point to an image.")
  341. return None
  342. except Exception as e:
  343. log.exception(f"Error saving image: {e}")
  344. return None
  345. @app.post("/generations")
  346. async def image_generations(
  347. form_data: GenerateImageForm,
  348. user=Depends(get_verified_user),
  349. ):
  350. width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
  351. r = None
  352. try:
  353. if app.state.config.ENGINE == "openai":
  354. headers = {}
  355. headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
  356. headers["Content-Type"] = "application/json"
  357. data = {
  358. "model": (
  359. app.state.config.MODEL
  360. if app.state.config.MODEL != ""
  361. else "dall-e-2"
  362. ),
  363. "prompt": form_data.prompt,
  364. "n": form_data.n,
  365. "size": (
  366. form_data.size if form_data.size else app.state.config.IMAGE_SIZE
  367. ),
  368. "response_format": "b64_json",
  369. }
  370. r = requests.post(
  371. url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
  372. json=data,
  373. headers=headers,
  374. )
  375. r.raise_for_status()
  376. res = r.json()
  377. images = []
  378. for image in res["data"]:
  379. image_filename = save_b64_image(image["b64_json"])
  380. images.append({"url": f"/cache/image/generations/{image_filename}"})
  381. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  382. with open(file_body_path, "w") as f:
  383. json.dump(data, f)
  384. return images
  385. elif app.state.config.ENGINE == "comfyui":
  386. data = {
  387. "prompt": form_data.prompt,
  388. "width": width,
  389. "height": height,
  390. "n": form_data.n,
  391. }
  392. if app.state.config.IMAGE_STEPS is not None:
  393. data["steps"] = app.state.config.IMAGE_STEPS
  394. if form_data.negative_prompt is not None:
  395. data["negative_prompt"] = form_data.negative_prompt
  396. form_data = ComfyUIGenerateImageForm(
  397. **{
  398. "workflow": ComfyUIWorkflow(
  399. **{
  400. "workflow": app.state.config.COMFYUI_WORKFLOW,
  401. "nodes": app.state.config.COMFYUI_WORKFLOW_NODES,
  402. }
  403. ),
  404. **data,
  405. }
  406. )
  407. res = await comfyui_generate_image(
  408. app.state.config.MODEL,
  409. form_data,
  410. user.id,
  411. app.state.config.COMFYUI_BASE_URL,
  412. )
  413. log.debug(f"res: {res}")
  414. images = []
  415. for image in res["data"]:
  416. image_filename = save_url_image(image["url"])
  417. images.append({"url": f"/cache/image/generations/{image_filename}"})
  418. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  419. with open(file_body_path, "w") as f:
  420. json.dump(form_data.model_dump(exclude_none=True), f)
  421. log.debug(f"images: {images}")
  422. return images
  423. elif (
  424. app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
  425. ):
  426. if form_data.model:
  427. set_image_model(form_data.model)
  428. data = {
  429. "prompt": form_data.prompt,
  430. "batch_size": form_data.n,
  431. "width": width,
  432. "height": height,
  433. }
  434. if app.state.config.IMAGE_STEPS is not None:
  435. data["steps"] = app.state.config.IMAGE_STEPS
  436. if form_data.negative_prompt is not None:
  437. data["negative_prompt"] = form_data.negative_prompt
  438. # Use asyncio.to_thread for the requests.post call
  439. r = await asyncio.to_thread(
  440. requests.post,
  441. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  442. json=data,
  443. headers={"authorization": get_automatic1111_api_auth()},
  444. )
  445. res = r.json()
  446. log.debug(f"res: {res}")
  447. images = []
  448. for image in res["images"]:
  449. image_filename = save_b64_image(image)
  450. images.append({"url": f"/cache/image/generations/{image_filename}"})
  451. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  452. with open(file_body_path, "w") as f:
  453. json.dump({**data, "info": res["info"]}, f)
  454. return images
  455. except Exception as e:
  456. error = e
  457. if r != None:
  458. data = r.json()
  459. if "error" in data:
  460. error = data["error"]["message"]
  461. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))