main.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584
  import asyncio
  import base64
  import json
  import logging
  import mimetypes
  import re
  import uuid
  from pathlib import Path
  from typing import Optional, Union

  import requests
  from fastapi import Depends, FastAPI, HTTPException, Request
  from fastapi.middleware.cors import CORSMiddleware
  from pydantic import BaseModel

  from open_webui.apps.images.utils.comfyui import (
      ComfyUIGenerateImageForm,
      ComfyUIWorkflow,
      comfyui_generate_image,
  )
  from open_webui.config import (
      AUTOMATIC1111_API_AUTH,
      AUTOMATIC1111_BASE_URL,
      AUTOMATIC1111_CFG_SCALE,
      AUTOMATIC1111_SAMPLER,
      AUTOMATIC1111_SCHEDULER,
      CACHE_DIR,
      COMFYUI_BASE_URL,
      COMFYUI_WORKFLOW,
      COMFYUI_WORKFLOW_NODES,
      CORS_ALLOW_ORIGIN,
      ENABLE_IMAGE_GENERATION,
      IMAGE_GENERATION_ENGINE,
      IMAGE_GENERATION_MODEL,
      IMAGE_SIZE,
      IMAGE_STEPS,
      IMAGES_OPENAI_API_BASE_URL,
      IMAGES_OPENAI_API_KEY,
      AppConfig,
  )
  from open_webui.constants import ERROR_MESSAGES
  from open_webui.env import SRC_LOG_LEVELS
  from open_webui.utils.utils import get_admin_user, get_verified_user
  log = logging.getLogger(__name__)
  log.setLevel(SRC_LOG_LEVELS["IMAGES"])

  # Generated images, each with a "<file>.json" sidecar describing the request that
  # produced it, are written here; the directory is created on import if missing.
  IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
  IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)

  app = FastAPI()
  app.add_middleware(
      CORSMiddleware,
      allow_origins=CORS_ALLOW_ORIGIN,
      allow_credentials=True,
      allow_methods=["*"],
      allow_headers=["*"],
  )

  # Seed the runtime settings object from the values imported from open_webui.config.
  # Insertion order below is the assignment order.
  app.state.config = AppConfig()
  _INITIAL_SETTINGS = {
      "ENGINE": IMAGE_GENERATION_ENGINE,
      "ENABLED": ENABLE_IMAGE_GENERATION,
      "OPENAI_API_BASE_URL": IMAGES_OPENAI_API_BASE_URL,
      "OPENAI_API_KEY": IMAGES_OPENAI_API_KEY,
      "MODEL": IMAGE_GENERATION_MODEL,
      "AUTOMATIC1111_BASE_URL": AUTOMATIC1111_BASE_URL,
      "AUTOMATIC1111_API_AUTH": AUTOMATIC1111_API_AUTH,
      "AUTOMATIC1111_CFG_SCALE": AUTOMATIC1111_CFG_SCALE,
      "AUTOMATIC1111_SAMPLER": AUTOMATIC1111_SAMPLER,
      "AUTOMATIC1111_SCHEDULER": AUTOMATIC1111_SCHEDULER,
      "COMFYUI_BASE_URL": COMFYUI_BASE_URL,
      "COMFYUI_WORKFLOW": COMFYUI_WORKFLOW,
      "COMFYUI_WORKFLOW_NODES": COMFYUI_WORKFLOW_NODES,
      "IMAGE_SIZE": IMAGE_SIZE,
      "IMAGE_STEPS": IMAGE_STEPS,
  }
  for _setting, _value in _INITIAL_SETTINGS.items():
      # setattr() dispatches exactly like `app.state.config.<NAME> = value`.
      setattr(app.state.config, _setting, _value)
  # Keep the module namespace identical to the explicit-assignment version.
  del _INITIAL_SETTINGS, _setting, _value
  @app.get("/config")
  async def get_config(request: Request, user=Depends(get_admin_user)):
      """Return the image-generation settings grouped by backend (admin only).

      Inside each backend group the keys are exactly the ``app.state.config``
      attribute names the values are read from.
      """
      settings = app.state.config

      def pick(*names):
          # Build {NAME: settings.NAME} preserving the given order.
          return {name: getattr(settings, name) for name in names}

      return {
          "enabled": settings.ENABLED,
          "engine": settings.ENGINE,
          "openai": pick("OPENAI_API_BASE_URL", "OPENAI_API_KEY"),
          "automatic1111": pick(
              "AUTOMATIC1111_BASE_URL",
              "AUTOMATIC1111_API_AUTH",
              "AUTOMATIC1111_CFG_SCALE",
              "AUTOMATIC1111_SAMPLER",
              "AUTOMATIC1111_SCHEDULER",
          ),
          "comfyui": pick(
              "COMFYUI_BASE_URL",
              "COMFYUI_WORKFLOW",
              "COMFYUI_WORKFLOW_NODES",
          ),
      }
  class OpenAIConfigForm(BaseModel):
      """OpenAI-compatible image API settings (the ``openai`` group of ConfigForm)."""

      # Base URL; "/images/generations" is appended to it when generating.
      OPENAI_API_BASE_URL: str
      # Sent as a Bearer token in the Authorization header.
      OPENAI_API_KEY: str
  class Automatic1111ConfigForm(BaseModel):
      """AUTOMATIC1111 settings (the ``automatic1111`` group of ConfigForm).

      The optional fields accept ``None`` or ``""``; update_config treats an empty
      string as "unset".
      """

      AUTOMATIC1111_BASE_URL: str
      # Raw credential string; it is base64-encoded into an HTTP Basic header.
      AUTOMATIC1111_API_AUTH: str
      # Numbers are accepted as well as strings: the stored value is a float and is
      # returned as such by GET /config, so a client echoing it back must not be
      # rejected (pydantic v2 does not coerce int/float into a ``str`` field).
      AUTOMATIC1111_CFG_SCALE: Optional[Union[str, float, int]]
      AUTOMATIC1111_SAMPLER: Optional[str]
      AUTOMATIC1111_SCHEDULER: Optional[str]
  class ComfyUIConfigForm(BaseModel):
      """ComfyUI settings (the ``comfyui`` group of ConfigForm)."""

      # Server root URL; endpoint paths such as "/object_info" are appended to it.
      COMFYUI_BASE_URL: str
      # The workflow as a JSON-encoded string (get_models decodes it with json.loads).
      COMFYUI_WORKFLOW: str
      # Node descriptors; get_models reads their "type" and "node_ids" keys.
      COMFYUI_WORKFLOW_NODES: list[dict]
  class ConfigForm(BaseModel):
      """Complete request body for POST /config/update."""

      enabled: bool  # master on/off switch for image generation
      engine: str  # "openai", "comfyui", or "automatic1111" ("" also selects automatic1111)
      openai: OpenAIConfigForm
      automatic1111: Automatic1111ConfigForm
      comfyui: ComfyUIConfigForm
  @app.post("/config/update")
  async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)):
      """Apply a complete engine configuration and echo the stored values (admin only).

      ``None`` or ``""`` for the optional AUTOMATIC1111 fields means "unset" and is
      stored as ``None``; the CFG scale is stored as a float.

      Returns:
          The same nested structure that GET /config returns.

      Raises:
          HTTPException(400): AUTOMATIC1111_CFG_SCALE is not a number. No setting
              is changed in that case.
      """
      a1111 = form_data.automatic1111

      # Parse the CFG scale before mutating anything so a bad value cannot leave a
      # half-applied configuration. None must be checked explicitly: float(None)
      # raises TypeError.
      raw_cfg_scale = a1111.AUTOMATIC1111_CFG_SCALE
      if raw_cfg_scale is None or raw_cfg_scale == "":
          cfg_scale = None
      else:
          try:
              cfg_scale = float(raw_cfg_scale)
          except (TypeError, ValueError) as e:
              raise HTTPException(
                  status_code=400,
                  detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 7)."),
              ) from e

      config = app.state.config
      config.ENGINE = form_data.engine
      config.ENABLED = form_data.enabled
      config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
      config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
      config.AUTOMATIC1111_BASE_URL = a1111.AUTOMATIC1111_BASE_URL
      config.AUTOMATIC1111_API_AUTH = a1111.AUTOMATIC1111_API_AUTH
      config.AUTOMATIC1111_CFG_SCALE = cfg_scale
      # "" -> None; None stays None; any other string is kept as-is.
      config.AUTOMATIC1111_SAMPLER = a1111.AUTOMATIC1111_SAMPLER or None
      config.AUTOMATIC1111_SCHEDULER = a1111.AUTOMATIC1111_SCHEDULER or None
      # Only trailing slashes need removing: endpoint paths are appended as "/...".
      config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL.rstrip("/")
      config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
      config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES

      return {
          "enabled": config.ENABLED,
          "engine": config.ENGINE,
          "openai": {
              "OPENAI_API_BASE_URL": config.OPENAI_API_BASE_URL,
              "OPENAI_API_KEY": config.OPENAI_API_KEY,
          },
          "automatic1111": {
              "AUTOMATIC1111_BASE_URL": config.AUTOMATIC1111_BASE_URL,
              "AUTOMATIC1111_API_AUTH": config.AUTOMATIC1111_API_AUTH,
              "AUTOMATIC1111_CFG_SCALE": config.AUTOMATIC1111_CFG_SCALE,
              "AUTOMATIC1111_SAMPLER": config.AUTOMATIC1111_SAMPLER,
              "AUTOMATIC1111_SCHEDULER": config.AUTOMATIC1111_SCHEDULER,
          },
          "comfyui": {
              "COMFYUI_BASE_URL": config.COMFYUI_BASE_URL,
              "COMFYUI_WORKFLOW": config.COMFYUI_WORKFLOW,
              "COMFYUI_WORKFLOW_NODES": config.COMFYUI_WORKFLOW_NODES,
          },
      }
  def get_automatic1111_api_auth():
      """Return the ``authorization`` header value for AUTOMATIC1111 requests.

      ``app.state.config.AUTOMATIC1111_API_AUTH`` holds the raw credential string
      (presumably ``user:password`` as AUTOMATIC1111's ``--api-auth`` expects;
      confirm). It is UTF-8 encoded and base64-encoded into an HTTP Basic value.

      Returns:
          ``"Basic <base64>"``, or ``""`` when no credentials are configured.
          Both ``None`` and the empty string count as "not configured", so a
          malformed ``"Basic "`` header with empty credentials is never produced.
      """
      credentials = app.state.config.AUTOMATIC1111_API_AUTH
      if not credentials:
          return ""
      token = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
      return f"Basic {token}"
  @app.get("/config/url/verify")
  async def verify_url(user=Depends(get_admin_user)):
      """Check that the active engine's server is reachable (admin only).

      AUTOMATIC1111 is probed with ``GET <base>/sdapi/v1/options`` and ComfyUI with
      ``GET <base>/object_info``. Any other engine is not probed and is reported
      as valid.

      Returns:
          True when the probe succeeds or no probe applies.

      Raises:
          HTTPException(400): the probe failed. Image generation is also disabled
              (``ENABLED`` is set to False).
      """
      engine = app.state.config.ENGINE
      if engine not in ("automatic1111", "comfyui"):
          return True

      try:
          # requests is blocking; run it in a worker thread so this async endpoint
          # does not stall the event loop while the remote server answers.
          if engine == "automatic1111":
              r = await asyncio.to_thread(
                  requests.get,
                  url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
                  headers={"authorization": get_automatic1111_api_auth()},
              )
          else:
              r = await asyncio.to_thread(
                  requests.get,
                  url=f"{app.state.config.COMFYUI_BASE_URL}/object_info",
              )
          r.raise_for_status()
      except Exception as e:
          app.state.config.ENABLED = False
          raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) from e
      return True
  def set_image_model(model: str):
      """Store *model* as the configured image model; for AUTOMATIC1111 also load it.

      ``app.state.config.MODEL`` is updated first, for every engine. When ENGINE is
      ``""`` or ``"automatic1111"``, the server's options are fetched. If its current
      ``sd_model_checkpoint`` differs from *model*, the options are posted back with
      the new checkpoint.

      Returns:
          The stored model (``app.state.config.MODEL``).

      Raises:
          requests.HTTPError: AUTOMATIC1111 rejected the options read or the
              checkpoint switch, so a failed switch is not reported as success.
      """
      log.info(f"Setting image model to {model}")
      app.state.config.MODEL = model

      if app.state.config.ENGINE in ["", "automatic1111"]:
          options_url = f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
          headers = {"authorization": get_automatic1111_api_auth()}

          r = requests.get(url=options_url, headers=headers)
          # Fail with a clear HTTP error instead of a KeyError on an error body.
          r.raise_for_status()
          options = r.json()

          if model != options["sd_model_checkpoint"]:
              options["sd_model_checkpoint"] = model
              r = requests.post(url=options_url, json=options, headers=headers)
              r.raise_for_status()

      return app.state.config.MODEL
  def get_image_model():
      """Return the model currently in effect for the configured engine.

      * ``"automatic1111"`` or ``""``: the server's live ``sd_model_checkpoint``.
      * ``"openai"``: the configured MODEL, or ``"dall-e-2"`` when it is falsy.
      * ``"comfyui"``: the configured MODEL, or ``""`` when it is falsy.
      * any other engine: None.

      Raises:
          HTTPException(400): querying AUTOMATIC1111 failed. Image generation is
              disabled as a side effect.
      """
      engine = app.state.config.ENGINE

      if engine in ("automatic1111", ""):
          try:
              response = requests.get(
                  url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
                  headers={"authorization": get_automatic1111_api_auth()},
              )
              return response.json()["sd_model_checkpoint"]
          except Exception as e:
              app.state.config.ENABLED = False
              raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))

      if engine == "openai":
          return app.state.config.MODEL or "dall-e-2"
      if engine == "comfyui":
          return app.state.config.MODEL or ""
      return None
  class ImageConfigForm(BaseModel):
      """Request body for POST /image/config/update."""

      MODEL: str  # model identifier for the active engine
      IMAGE_SIZE: str  # "<width>x<height>", e.g. "512x512"
      IMAGE_STEPS: int  # sampling step count; negative values are rejected on update
  @app.get("/image/config")
  async def get_image_config(user=Depends(get_admin_user)):
      """Return the configured model, image size and step count (admin only)."""
      exposed = ("MODEL", "IMAGE_SIZE", "IMAGE_STEPS")
      return {key: getattr(app.state.config, key) for key in exposed}
  @app.post("/image/config/update")
  async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin_user)):
      """Validate and apply the model, image size and step count (admin only).

      Both size and steps are validated before anything is changed, so a rejected
      request leaves the model (and the AUTOMATIC1111 checkpoint) untouched.

      Returns:
          ``{"MODEL", "IMAGE_SIZE", "IMAGE_STEPS"}`` with the stored values.

      Raises:
          HTTPException(400): IMAGE_SIZE is not ``<digits>x<digits>`` or
              IMAGE_STEPS is negative.
      """
      # fullmatch: re.match with a "$" anchor would also accept a trailing newline.
      if not re.fullmatch(r"\d+x\d+", form_data.IMAGE_SIZE):
          raise HTTPException(
              status_code=400,
              detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
          )
      if form_data.IMAGE_STEPS < 0:
          raise HTTPException(
              status_code=400,
              detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
          )

      # set_image_model may perform blocking HTTP calls to AUTOMATIC1111; keep them
      # off the event loop.
      await asyncio.to_thread(set_image_model, form_data.MODEL)
      app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
      app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS

      return {
          "MODEL": app.state.config.MODEL,
          "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
          "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
      }
  def _get_comfyui_models():
      """List the models selectable in the configured ComfyUI workflow.

      Uses the first workflow node of type ``"model"`` that has a node id. Its
      choices come from the first required input whose name contains ``"_name"``
      (e.g. ``"ckpt_name"``). Falls back to ``CheckpointLoaderSimple.ckpt_name``
      when there is no such node or input.

      Returns:
          ``[{"id": name, "name": name}, ...]``.
      """
      r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
      r.raise_for_status()
      info = r.json()
      workflow = json.loads(app.state.config.COMFYUI_WORKFLOW)

      model_node_id = next(
          (
              node["node_ids"][0]
              for node in app.state.config.COMFYUI_WORKFLOW_NODES
              if node["type"] == "model" and node["node_ids"]
          ),
          None,
      )

      model_names = None
      if model_node_id:
          class_type = workflow[model_node_id]["class_type"]
          log.debug(f"ComfyUI model node class_type: {class_type}")
          required_inputs = info[class_type]["input"]["required"]
          model_list_key = next(
              (key for key in required_inputs if "_name" in key), None
          )
          if model_list_key:
              model_names = required_inputs[model_list_key][0]

      if model_names is None:
          # No usable model node/input: use the stock checkpoint loader rather than
          # returning nothing.
          model_names = info["CheckpointLoaderSimple"]["input"]["required"][
              "ckpt_name"
          ][0]

      return [{"id": name, "name": name} for name in model_names]


  def _get_automatic1111_models():
      """List AUTOMATIC1111 checkpoints as ``[{"id": title, "name": model_name}]``."""
      r = requests.get(
          url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
          headers={"authorization": get_automatic1111_api_auth()},
      )
      r.raise_for_status()
      return [
          {"id": model["title"], "name": model["model_name"]} for model in r.json()
      ]


  @app.get("/models")
  def get_models(user=Depends(get_verified_user)):
      """Return the models available for the configured engine.

      Returns:
          A list of ``{"id", "name"}`` dicts. The list is empty for an unknown
          engine, instead of a null response.

      Raises:
          HTTPException(400): fetching the model list failed. Image generation is
              disabled as a side effect.
      """
      try:
          engine = app.state.config.ENGINE
          if engine == "openai":
              return [
                  {"id": "dall-e-2", "name": "DALL·E 2"},
                  {"id": "dall-e-3", "name": "DALL·E 3"},
              ]
          if engine == "comfyui":
              return _get_comfyui_models()
          if engine in ("automatic1111", ""):
              return _get_automatic1111_models()
          return []
      except Exception as e:
          app.state.config.ENABLED = False
          raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  class GenerateImageForm(BaseModel):
      """Request body for POST /generations."""

      model: Optional[str] = None  # AUTOMATIC1111 only: checkpoint to switch to first
      prompt: str
      size: Optional[str] = None  # OpenAI only; falls back to the configured IMAGE_SIZE
      n: int = 1  # number of images to generate
      negative_prompt: Optional[str] = None  # forwarded to ComfyUI / AUTOMATIC1111
  def save_b64_image(b64_str, cache_dir=None):
      """Decode a base64 image (bare or ``data:`` URI) and write it to the cache.

      Args:
          b64_str: Either bare base64 data, which is saved as ``.png``, or a data
              URI such as ``data:image/jpeg;base64,<payload>``. For a data URI the
              MIME type selects the file extension, falling back to ``.png``.
          cache_dir: Directory to write into; defaults to ``IMAGE_CACHE_DIR``.

      Returns:
          The generated file name (``<uuid4><ext>``), or None if decoding or
          writing failed (the error is logged).
      """
      try:
          target_dir = IMAGE_CACHE_DIR if cache_dir is None else Path(cache_dir)
          image_format = ".png"
          encoded = b64_str
          if "," in b64_str:
              header, encoded = b64_str.split(",", 1)
              # "data:image/png;base64" -> "image/png". Without dropping the "data:"
              # scheme, guess_extension() returns None and the file name would end
              # in "None".
              mime_type = header.split(";")[0].strip().removeprefix("data:")
              image_format = mimetypes.guess_extension(mime_type) or ".png"

          img_data = base64.b64decode(encoded)
          image_filename = f"{uuid.uuid4()}{image_format}"
          with open(target_dir.joinpath(image_filename), "wb") as f:
              f.write(img_data)
          return image_filename
      except Exception as e:
          log.exception(f"Error saving image: {e}")
          return None
  def save_url_image(url, cache_dir=None):
      """Download the image at *url* into the cache and return its file name.

      Args:
          url: Location of the image.
          cache_dir: Directory to write into; defaults to ``IMAGE_CACHE_DIR``.

      Returns:
          ``<uuid4><ext>``, with the extension derived from the response's
          Content-Type. Returns None (after logging) when the download fails, the
          response is not an image, or the MIME type has no known extension.
      """
      image_id = str(uuid.uuid4())
      try:
          target_dir = IMAGE_CACHE_DIR if cache_dir is None else Path(cache_dir)
          # stream=True writes chunks as they arrive instead of buffering the whole
          # body; the context manager releases the connection on every path.
          with requests.get(url, stream=True) as r:
              r.raise_for_status()
              # Drop parameters such as "; charset=..." before the MIME lookup.
              mime_type = (
                  r.headers.get("content-type", "").split(";")[0].strip().lower()
              )
              if mime_type.split("/")[0] != "image":
                  log.error("Url does not point to an image.")
                  return None

              image_format = mimetypes.guess_extension(mime_type)
              if not image_format:
                  raise ValueError("Could not determine image type from MIME type")

              image_filename = f"{image_id}{image_format}"
              with open(target_dir.joinpath(image_filename), "wb") as image_file:
                  for chunk in r.iter_content(chunk_size=8192):
                      image_file.write(chunk)
              return image_filename
      except Exception as e:
          log.exception(f"Error saving image: {e}")
          return None
  def _response_error_message(response, fallback):
      """Return the backend's error message carried by *response*, else *fallback*.

      Understands an ``{"error": {"message": ...}}`` body and also tolerates a
      plain ``"error"`` value. Falls back when there is no response, the body is
      not JSON, or it has no ``"error"`` key.
      """
      if response is None:
          return fallback
      try:
          body = response.json()
      except ValueError:
          # Non-JSON body, e.g. an HTML error page from a proxy.
          return fallback
      if not isinstance(body, dict) or "error" not in body:
          return fallback
      error = body["error"]
      if isinstance(error, dict):
          return error.get("message") or fallback
      return error or fallback


  def _register_generated_image(image_filename, metadata, images):
      """Append a saved image's cache URL to *images* and write its JSON sidecar.

      *image_filename* is None when saving failed. Such images are skipped so the
      client never receives a ".../None" URL and no "None.json" file is written.
      """
      if image_filename is None:
          log.warning("Skipping generated image that could not be saved.")
          return
      images.append({"url": f"/cache/image/generations/{image_filename}"})
      with open(IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json"), "w") as f:
          json.dump(metadata, f)


  @app.post("/generations")
  async def image_generations(
      form_data: GenerateImageForm,
      user=Depends(get_verified_user),
  ):
      """Generate images with the configured engine and return their cache URLs.

      Every saved image gets a JSON sidecar in IMAGE_CACHE_DIR describing the
      request that produced it.

      Returns:
          A list of ``{"url": "/cache/image/generations/<file>"}`` dicts.

      Raises:
          HTTPException(400): the backend call or image handling failed. The
              detail carries the backend's own error message when one is available.
      """
      width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
      # Last backend HTTP response, kept so the error handler can report the
      # backend's own error message.
      r = None
      try:
          if app.state.config.ENGINE == "openai":
              headers = {
                  "Authorization": f"Bearer {app.state.config.OPENAI_API_KEY}",
                  "Content-Type": "application/json",
              }
              data = {
                  # `or` covers a None model as well as "".
                  "model": app.state.config.MODEL or "dall-e-2",
                  "prompt": form_data.prompt,
                  "n": form_data.n,
                  "size": (
                      form_data.size if form_data.size else app.state.config.IMAGE_SIZE
                  ),
                  "response_format": "b64_json",
              }
              # Blocking HTTP call: run it in a worker thread like the txt2img call.
              r = await asyncio.to_thread(
                  requests.post,
                  url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
                  json=data,
                  headers=headers,
              )
              r.raise_for_status()
              res = r.json()

              images = []
              for image in res["data"]:
                  _register_generated_image(
                      save_b64_image(image["b64_json"]), data, images
                  )
              return images

          elif app.state.config.ENGINE == "comfyui":
              data = {
                  "prompt": form_data.prompt,
                  "width": width,
                  "height": height,
                  "n": form_data.n,
              }
              if app.state.config.IMAGE_STEPS is not None:
                  data["steps"] = app.state.config.IMAGE_STEPS
              if form_data.negative_prompt is not None:
                  data["negative_prompt"] = form_data.negative_prompt

              # Distinct name so the incoming request body is not shadowed.
              comfyui_form = ComfyUIGenerateImageForm(
                  workflow=ComfyUIWorkflow(
                      workflow=app.state.config.COMFYUI_WORKFLOW,
                      nodes=app.state.config.COMFYUI_WORKFLOW_NODES,
                  ),
                  **data,
              )
              res = await comfyui_generate_image(
                  app.state.config.MODEL,
                  comfyui_form,
                  user.id,
                  app.state.config.COMFYUI_BASE_URL,
              )
              log.debug(f"res: {res}")

              metadata = comfyui_form.model_dump(exclude_none=True)
              images = []
              for image in res["data"]:
                  # The download is blocking network I/O; keep it off the loop.
                  image_filename = await asyncio.to_thread(save_url_image, image["url"])
                  _register_generated_image(image_filename, metadata, images)
              log.debug(f"images: {images}")
              return images

          elif app.state.config.ENGINE in ("automatic1111", ""):
              if form_data.model:
                  # Switching checkpoints is a blocking HTTP round-trip.
                  await asyncio.to_thread(set_image_model, form_data.model)

              data = {
                  "prompt": form_data.prompt,
                  "batch_size": form_data.n,
                  "width": width,
                  "height": height,
              }
              if app.state.config.IMAGE_STEPS is not None:
                  data["steps"] = app.state.config.IMAGE_STEPS
              if form_data.negative_prompt is not None:
                  data["negative_prompt"] = form_data.negative_prompt
              if app.state.config.AUTOMATIC1111_CFG_SCALE:
                  data["cfg_scale"] = app.state.config.AUTOMATIC1111_CFG_SCALE
              if app.state.config.AUTOMATIC1111_SAMPLER:
                  data["sampler_name"] = app.state.config.AUTOMATIC1111_SAMPLER
              if app.state.config.AUTOMATIC1111_SCHEDULER:
                  data["scheduler"] = app.state.config.AUTOMATIC1111_SCHEDULER

              r = await asyncio.to_thread(
                  requests.post,
                  url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
                  json=data,
                  headers={"authorization": get_automatic1111_api_auth()},
              )
              # Surface backend failures as errors instead of a KeyError on "images".
              r.raise_for_status()
              res = r.json()
              log.debug(f"res: {res}")

              metadata = {**data, "info": res.get("info")}
              images = []
              for image in res["images"]:
                  _register_generated_image(save_b64_image(image), metadata, images)
              return images
      except Exception as e:
          # Previously `r.json()["error"]["message"]` could itself raise here
          # (non-JSON body, or a string "error"), turning a 400 into a 500.
          raise HTTPException(
              status_code=400,
              detail=ERROR_MESSAGES.DEFAULT(_response_error_message(r, e)),
          ) from e