images.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590
  1. import asyncio
  2. import base64
  3. import json
  4. import logging
  5. import mimetypes
  6. import re
  7. import uuid
  8. from pathlib import Path
  9. from typing import Optional
  10. import requests
  11. from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
  12. from fastapi.middleware.cors import CORSMiddleware
  13. from pydantic import BaseModel
  14. from open_webui.config import CACHE_DIR
  15. from open_webui.constants import ERROR_MESSAGES
  16. from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
  17. from open_webui.utils.auth import get_admin_user, get_verified_user
  18. from open_webui.utils.images.comfyui import (
  19. ComfyUIGenerateImageForm,
  20. ComfyUIWorkflow,
  21. comfyui_generate_image,
  22. )
  23. log = logging.getLogger(__name__)
  24. log.setLevel(SRC_LOG_LEVELS["IMAGES"])
  25. IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
  26. IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  27. router = APIRouter()
  28. @router.get("/config")
  29. async def get_config(request: Request, user=Depends(get_admin_user)):
  30. return {
  31. "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
  32. "engine": request.app.state.config.ENGINE,
  33. "openai": {
  34. "OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL,
  35. "OPENAI_API_KEY": request.app.state.config.OPENAI_API_KEY,
  36. },
  37. "automatic1111": {
  38. "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
  39. "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
  40. "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
  41. "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
  42. "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
  43. },
  44. "comfyui": {
  45. "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
  46. "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
  47. "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
  48. },
  49. }
  50. class OpenAIConfigForm(BaseModel):
  51. OPENAI_API_BASE_URL: str
  52. OPENAI_API_KEY: str
  53. class Automatic1111ConfigForm(BaseModel):
  54. AUTOMATIC1111_BASE_URL: str
  55. AUTOMATIC1111_API_AUTH: str
  56. AUTOMATIC1111_CFG_SCALE: Optional[str | float | int]
  57. AUTOMATIC1111_SAMPLER: Optional[str]
  58. AUTOMATIC1111_SCHEDULER: Optional[str]
  59. class ComfyUIConfigForm(BaseModel):
  60. COMFYUI_BASE_URL: str
  61. COMFYUI_WORKFLOW: str
  62. COMFYUI_WORKFLOW_NODES: list[dict]
  63. class ConfigForm(BaseModel):
  64. enabled: bool
  65. engine: str
  66. openai: OpenAIConfigForm
  67. automatic1111: Automatic1111ConfigForm
  68. comfyui: ComfyUIConfigForm
  69. @router.post("/config/update")
  70. async def update_config(
  71. request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
  72. ):
  73. request.app.state.config.ENGINE = form_data.engine
  74. request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled
  75. request.app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
  76. request.app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
  77. request.app.state.config.AUTOMATIC1111_BASE_URL = (
  78. form_data.automatic1111.AUTOMATIC1111_BASE_URL
  79. )
  80. request.app.state.config.AUTOMATIC1111_API_AUTH = (
  81. form_data.automatic1111.AUTOMATIC1111_API_AUTH
  82. )
  83. request.app.state.config.AUTOMATIC1111_CFG_SCALE = (
  84. float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE)
  85. if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE
  86. else None
  87. )
  88. request.app.state.config.AUTOMATIC1111_SAMPLER = (
  89. form_data.automatic1111.AUTOMATIC1111_SAMPLER
  90. if form_data.automatic1111.AUTOMATIC1111_SAMPLER
  91. else None
  92. )
  93. request.app.state.config.AUTOMATIC1111_SCHEDULER = (
  94. form_data.automatic1111.AUTOMATIC1111_SCHEDULER
  95. if form_data.automatic1111.AUTOMATIC1111_SCHEDULER
  96. else None
  97. )
  98. request.app.state.config.COMFYUI_BASE_URL = (
  99. form_data.comfyui.COMFYUI_BASE_URL.strip("/")
  100. )
  101. request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
  102. request.app.state.config.COMFYUI_WORKFLOW_NODES = (
  103. form_data.comfyui.COMFYUI_WORKFLOW_NODES
  104. )
  105. return {
  106. "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
  107. "engine": request.app.state.config.ENGINE,
  108. "openai": {
  109. "OPENAI_API_BASE_URL": request.app.state.config.OPENAI_API_BASE_URL,
  110. "OPENAI_API_KEY": request.app.state.config.OPENAI_API_KEY,
  111. },
  112. "automatic1111": {
  113. "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
  114. "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
  115. "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
  116. "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
  117. "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
  118. },
  119. "comfyui": {
  120. "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
  121. "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
  122. "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
  123. },
  124. }
  125. def get_automatic1111_api_auth(request: Request):
  126. if request.app.state.config.AUTOMATIC1111_API_AUTH is None:
  127. return ""
  128. else:
  129. auth1111_byte_string = request.app.state.config.AUTOMATIC1111_API_AUTH.encode(
  130. "utf-8"
  131. )
  132. auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
  133. auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
  134. return f"Basic {auth1111_base64_encoded_string}"
  135. @router.get("/config/url/verify")
  136. async def verify_url(request: Request, user=Depends(get_admin_user)):
  137. if request.app.state.config.ENGINE == "automatic1111":
  138. try:
  139. r = requests.get(
  140. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  141. headers={"authorization": get_automatic1111_api_auth(request)},
  142. )
  143. r.raise_for_status()
  144. return True
  145. except Exception:
  146. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  147. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  148. elif request.app.state.config.ENGINE == "comfyui":
  149. try:
  150. r = requests.get(
  151. url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
  152. )
  153. r.raise_for_status()
  154. return True
  155. except Exception:
  156. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  157. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  158. else:
  159. return True
  160. def set_image_model(request: Request, model: str):
  161. log.info(f"Setting image model to {model}")
  162. request.app.state.config.MODEL = model
  163. if request.app.state.config.ENGINE in ["", "automatic1111"]:
  164. api_auth = get_automatic1111_api_auth()
  165. r = requests.get(
  166. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  167. headers={"authorization": api_auth},
  168. )
  169. options = r.json()
  170. if model != options["sd_model_checkpoint"]:
  171. options["sd_model_checkpoint"] = model
  172. r = requests.post(
  173. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  174. json=options,
  175. headers={"authorization": api_auth},
  176. )
  177. return request.app.state.config.MODEL
  178. def get_image_model():
  179. if request.app.state.config.ENGINE == "openai":
  180. return (
  181. request.app.state.config.MODEL
  182. if request.app.state.config.MODEL
  183. else "dall-e-2"
  184. )
  185. elif request.app.state.config.ENGINE == "comfyui":
  186. return request.app.state.config.MODEL if request.app.state.config.MODEL else ""
  187. elif (
  188. request.app.state.config.ENGINE == "automatic1111"
  189. or request.app.state.config.ENGINE == ""
  190. ):
  191. try:
  192. r = requests.get(
  193. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  194. headers={"authorization": get_automatic1111_api_auth()},
  195. )
  196. options = r.json()
  197. return options["sd_model_checkpoint"]
  198. except Exception as e:
  199. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  200. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  201. class ImageConfigForm(BaseModel):
  202. MODEL: str
  203. IMAGE_SIZE: str
  204. IMAGE_STEPS: int
  205. @router.get("/image/config")
  206. async def get_image_config(request: Request, user=Depends(get_admin_user)):
  207. return {
  208. "MODEL": request.app.state.config.MODEL,
  209. "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
  210. "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
  211. }
  212. @router.post("/image/config/update")
  213. async def update_image_config(
  214. request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
  215. ):
  216. set_image_model(request, form_data.MODEL)
  217. pattern = r"^\d+x\d+$"
  218. if re.match(pattern, form_data.IMAGE_SIZE):
  219. request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
  220. else:
  221. raise HTTPException(
  222. status_code=400,
  223. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
  224. )
  225. if form_data.IMAGE_STEPS >= 0:
  226. request.app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
  227. else:
  228. raise HTTPException(
  229. status_code=400,
  230. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
  231. )
  232. return {
  233. "MODEL": request.app.state.config.MODEL,
  234. "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
  235. "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
  236. }
  237. @router.get("/models")
  238. def get_models(request: Request, user=Depends(get_verified_user)):
  239. try:
  240. if request.app.state.config.ENGINE == "openai":
  241. return [
  242. {"id": "dall-e-2", "name": "DALL·E 2"},
  243. {"id": "dall-e-3", "name": "DALL·E 3"},
  244. ]
  245. elif request.app.state.config.ENGINE == "comfyui":
  246. # TODO - get models from comfyui
  247. r = requests.get(
  248. url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
  249. )
  250. info = r.json()
  251. workflow = json.loads(request.app.state.config.COMFYUI_WORKFLOW)
  252. model_node_id = None
  253. for node in request.app.state.config.COMFYUI_WORKFLOW_NODES:
  254. if node["type"] == "model":
  255. if node["node_ids"]:
  256. model_node_id = node["node_ids"][0]
  257. break
  258. if model_node_id:
  259. model_list_key = None
  260. print(workflow[model_node_id]["class_type"])
  261. for key in info[workflow[model_node_id]["class_type"]]["input"][
  262. "required"
  263. ]:
  264. if "_name" in key:
  265. model_list_key = key
  266. break
  267. if model_list_key:
  268. return list(
  269. map(
  270. lambda model: {"id": model, "name": model},
  271. info[workflow[model_node_id]["class_type"]]["input"][
  272. "required"
  273. ][model_list_key][0],
  274. )
  275. )
  276. else:
  277. return list(
  278. map(
  279. lambda model: {"id": model, "name": model},
  280. info["CheckpointLoaderSimple"]["input"]["required"][
  281. "ckpt_name"
  282. ][0],
  283. )
  284. )
  285. elif (
  286. request.app.state.config.ENGINE == "automatic1111"
  287. or request.app.state.config.ENGINE == ""
  288. ):
  289. r = requests.get(
  290. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
  291. headers={"authorization": get_automatic1111_api_auth()},
  292. )
  293. models = r.json()
  294. return list(
  295. map(
  296. lambda model: {"id": model["title"], "name": model["model_name"]},
  297. models,
  298. )
  299. )
  300. except Exception as e:
  301. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  302. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  303. class GenerateImageForm(BaseModel):
  304. model: Optional[str] = None
  305. prompt: str
  306. size: Optional[str] = None
  307. n: int = 1
  308. negative_prompt: Optional[str] = None
  309. def save_b64_image(b64_str):
  310. try:
  311. image_id = str(uuid.uuid4())
  312. if "," in b64_str:
  313. header, encoded = b64_str.split(",", 1)
  314. mime_type = header.split(";")[0]
  315. img_data = base64.b64decode(encoded)
  316. image_format = mimetypes.guess_extension(mime_type)
  317. image_filename = f"{image_id}{image_format}"
  318. file_path = IMAGE_CACHE_DIR / f"{image_filename}"
  319. with open(file_path, "wb") as f:
  320. f.write(img_data)
  321. return image_filename
  322. else:
  323. image_filename = f"{image_id}.png"
  324. file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
  325. img_data = base64.b64decode(b64_str)
  326. # Write the image data to a file
  327. with open(file_path, "wb") as f:
  328. f.write(img_data)
  329. return image_filename
  330. except Exception as e:
  331. log.exception(f"Error saving image: {e}")
  332. return None
  333. def save_url_image(url):
  334. image_id = str(uuid.uuid4())
  335. try:
  336. r = requests.get(url)
  337. r.raise_for_status()
  338. if r.headers["content-type"].split("/")[0] == "image":
  339. mime_type = r.headers["content-type"]
  340. image_format = mimetypes.guess_extension(mime_type)
  341. if not image_format:
  342. raise ValueError("Could not determine image type from MIME type")
  343. image_filename = f"{image_id}{image_format}"
  344. file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
  345. with open(file_path, "wb") as image_file:
  346. for chunk in r.iter_content(chunk_size=8192):
  347. image_file.write(chunk)
  348. return image_filename
  349. else:
  350. log.error("Url does not point to an image.")
  351. return None
  352. except Exception as e:
  353. log.exception(f"Error saving image: {e}")
  354. return None
  355. @router.post("/generations")
  356. async def image_generations(
  357. request: Request,
  358. form_data: GenerateImageForm,
  359. user=Depends(get_verified_user),
  360. ):
  361. width, height = tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x")))
  362. r = None
  363. try:
  364. if request.app.state.config.ENGINE == "openai":
  365. headers = {}
  366. headers["Authorization"] = (
  367. f"Bearer {request.app.state.config.OPENAI_API_KEY}"
  368. )
  369. headers["Content-Type"] = "application/json"
  370. if ENABLE_FORWARD_USER_INFO_HEADERS:
  371. headers["X-OpenWebUI-User-Name"] = user.name
  372. headers["X-OpenWebUI-User-Id"] = user.id
  373. headers["X-OpenWebUI-User-Email"] = user.email
  374. headers["X-OpenWebUI-User-Role"] = user.role
  375. data = {
  376. "model": (
  377. request.app.state.config.MODEL
  378. if request.app.state.config.MODEL != ""
  379. else "dall-e-2"
  380. ),
  381. "prompt": form_data.prompt,
  382. "n": form_data.n,
  383. "size": (
  384. form_data.size
  385. if form_data.size
  386. else request.app.state.config.IMAGE_SIZE
  387. ),
  388. "response_format": "b64_json",
  389. }
  390. # Use asyncio.to_thread for the requests.post call
  391. r = await asyncio.to_thread(
  392. requests.post,
  393. url=f"{request.app.state.config.OPENAI_API_BASE_URL}/images/generations",
  394. json=data,
  395. headers=headers,
  396. )
  397. r.raise_for_status()
  398. res = r.json()
  399. images = []
  400. for image in res["data"]:
  401. image_filename = save_b64_image(image["b64_json"])
  402. images.append({"url": f"/cache/image/generations/{image_filename}"})
  403. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  404. with open(file_body_path, "w") as f:
  405. json.dump(data, f)
  406. return images
  407. elif request.app.state.config.ENGINE == "comfyui":
  408. data = {
  409. "prompt": form_data.prompt,
  410. "width": width,
  411. "height": height,
  412. "n": form_data.n,
  413. }
  414. if request.app.state.config.IMAGE_STEPS is not None:
  415. data["steps"] = request.app.state.config.IMAGE_STEPS
  416. if form_data.negative_prompt is not None:
  417. data["negative_prompt"] = form_data.negative_prompt
  418. form_data = ComfyUIGenerateImageForm(
  419. **{
  420. "workflow": ComfyUIWorkflow(
  421. **{
  422. "workflow": request.app.state.config.COMFYUI_WORKFLOW,
  423. "nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES,
  424. }
  425. ),
  426. **data,
  427. }
  428. )
  429. res = await comfyui_generate_image(
  430. request.app.state.config.MODEL,
  431. form_data,
  432. user.id,
  433. request.app.state.config.COMFYUI_BASE_URL,
  434. )
  435. log.debug(f"res: {res}")
  436. images = []
  437. for image in res["data"]:
  438. image_filename = save_url_image(image["url"])
  439. images.append({"url": f"/cache/image/generations/{image_filename}"})
  440. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  441. with open(file_body_path, "w") as f:
  442. json.dump(form_data.model_dump(exclude_none=True), f)
  443. log.debug(f"images: {images}")
  444. return images
  445. elif (
  446. request.app.state.config.ENGINE == "automatic1111"
  447. or request.app.state.config.ENGINE == ""
  448. ):
  449. if form_data.model:
  450. set_image_model(form_data.model)
  451. data = {
  452. "prompt": form_data.prompt,
  453. "batch_size": form_data.n,
  454. "width": width,
  455. "height": height,
  456. }
  457. if request.app.state.config.IMAGE_STEPS is not None:
  458. data["steps"] = request.app.state.config.IMAGE_STEPS
  459. if form_data.negative_prompt is not None:
  460. data["negative_prompt"] = form_data.negative_prompt
  461. if request.app.state.config.AUTOMATIC1111_CFG_SCALE:
  462. data["cfg_scale"] = request.app.state.config.AUTOMATIC1111_CFG_SCALE
  463. if request.app.state.config.AUTOMATIC1111_SAMPLER:
  464. data["sampler_name"] = request.app.state.config.AUTOMATIC1111_SAMPLER
  465. if request.app.state.config.AUTOMATIC1111_SCHEDULER:
  466. data["scheduler"] = request.app.state.config.AUTOMATIC1111_SCHEDULER
  467. # Use asyncio.to_thread for the requests.post call
  468. r = await asyncio.to_thread(
  469. requests.post,
  470. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  471. json=data,
  472. headers={"authorization": get_automatic1111_api_auth()},
  473. )
  474. res = r.json()
  475. log.debug(f"res: {res}")
  476. images = []
  477. for image in res["images"]:
  478. image_filename = save_b64_image(image)
  479. images.append({"url": f"/cache/image/generations/{image_filename}"})
  480. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  481. with open(file_body_path, "w") as f:
  482. json.dump({**data, "info": res["info"]}, f)
  483. return images
  484. except Exception as e:
  485. error = e
  486. if r != None:
  487. data = r.json()
  488. if "error" in data:
  489. error = data["error"]["message"]
  490. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))