ollama.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507
  1. # TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
  2. # The current implementation picks a backend uniformly at random (random.choice); it is not round-robin. Consider algorithms such as weighted round-robin,
  3. # least connections, or least response time for better resource utilization and performance.
  4. import asyncio
  5. import json
  6. import logging
  7. import os
  8. import random
  9. import re
  10. import time
  11. from typing import Optional, Union
  12. from urllib.parse import urlparse
  13. import aiohttp
  14. from aiocache import cached
  15. import requests
  16. from fastapi import (
  17. Depends,
  18. FastAPI,
  19. File,
  20. HTTPException,
  21. Request,
  22. UploadFile,
  23. APIRouter,
  24. )
  25. from fastapi.middleware.cors import CORSMiddleware
  26. from fastapi.responses import StreamingResponse
  27. from pydantic import BaseModel, ConfigDict
  28. from starlette.background import BackgroundTask
  29. from open_webui.models.models import Models
  30. from open_webui.utils.misc import (
  31. calculate_sha256,
  32. )
  33. from open_webui.utils.payload import (
  34. apply_model_params_to_body_ollama,
  35. apply_model_params_to_body_openai,
  36. apply_model_system_prompt_to_body,
  37. )
  38. from open_webui.utils.auth import get_admin_user, get_verified_user
  39. from open_webui.utils.access_control import has_access
  40. from open_webui.config import (
  41. UPLOAD_DIR,
  42. )
  43. from open_webui.env import (
  44. ENV,
  45. SRC_LOG_LEVELS,
  46. AIOHTTP_CLIENT_TIMEOUT,
  47. AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
  48. BYPASS_MODEL_ACCESS_CONTROL,
  49. )
  50. from open_webui.constants import ERROR_MESSAGES
  # Module logger; its verbosity comes from the "OLLAMA" entry of SRC_LOG_LEVELS.
  log = logging.getLogger(name=__name__)
  log.setLevel(level=SRC_LOG_LEVELS["OLLAMA"])
  53. ##########################################
  54. #
  55. # Utility functions
  56. #
  57. ##########################################
  async def send_get_request(url, key=None):
      """GET ``url`` (optionally with a bearer token) and return the decoded JSON body.

      Best-effort helper used to fan out to every configured Ollama backend:
      any failure (connection error, timeout, non-2xx status, or a body that
      is not JSON) is logged and reported as ``None`` so callers can skip that
      backend.

      :param url: Absolute URL to request.
      :param key: Optional API key, sent as ``Authorization: Bearer <key>``.
      :return: The parsed JSON payload, or ``None`` on any failure.
      """
      timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
      headers = {"Authorization": f"Bearer {key}"} if key else {}
      try:
          async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
              async with session.get(url, headers=headers) as response:
                  # Without this check an error body such as {"error": "..."}
                  # would be handed back as if it were a real payload, and
                  # callers indexing e.g. ["version"] would crash.
                  response.raise_for_status()
                  return await response.json()
      except Exception as e:
          # Handle connection / HTTP status / decoding errors here
          log.error(f"Connection error: {e}")
          return None
  async def cleanup_response(
      response: Optional[aiohttp.ClientResponse],
      session: Optional[aiohttp.ClientSession],
  ):
      """Close an upstream aiohttp response, then its client session.

      Either argument may be ``None`` (or otherwise falsy), in which case that
      step is skipped. This is scheduled as a background task once a streamed
      reply has been fully sent.
      """
      if response:
          response.close()
      if not session:
          return
      await session.close()
  async def send_post_request(
      url: str,
      payload: Union[str, bytes],
      stream: bool = True,
      key: Optional[str] = None,
      content_type: Optional[str] = None,
  ):
      """POST ``payload`` to an upstream Ollama endpoint and relay the reply.

      :param url: Full upstream URL, e.g. ``f"{base}/api/generate"``.
      :param payload: Request body, already JSON-encoded (sent as
          ``application/json``).
      :param stream: When true, return a ``StreamingResponse`` that pipes the
          upstream body through with the upstream status and headers. The
          aiohttp response and session are then released by a background task
          after streaming. When false, return the decoded JSON body.
      :param key: Optional API key, sent as ``Authorization: Bearer <key>``.
      :param content_type: Optional override for the ``Content-Type`` header of
          the streamed response.
      :raises HTTPException: On any failure. The status is the upstream status
          if a response was received (500 otherwise). The detail is Ollama's
          ``error`` message when the body carries one.
      """
      session = None
      r = None
      try:
          session = aiohttp.ClientSession(
              trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
          )
          r = await session.post(
              url,
              data=payload,
              headers={
                  "Content-Type": "application/json",
                  **({"Authorization": f"Bearer {key}"} if key else {}),
              },
          )
          r.raise_for_status()
          if stream:
              response_headers = dict(r.headers)
              if content_type:
                  response_headers["Content-Type"] = content_type
              # From here on the background task owns r and session.
              return StreamingResponse(
                  r.content,
                  status_code=r.status,
                  headers=response_headers,
                  background=BackgroundTask(
                      cleanup_response, response=r, session=session
                  ),
              )
          res = await r.json()
          await cleanup_response(r, session)
          return res
      except Exception as e:
          detail = None
          if r is not None:
              try:
                  res = await r.json()
                  if isinstance(res, dict) and "error" in res:
                      detail = f"Ollama: {res.get('error', 'Unknown error')}"
              except Exception:
                  detail = f"Ollama: {e}"
          status_code = r.status if r is not None else 500
          # On the error path nothing else will release the upstream connection
          # or the session, so close both here instead of leaking them.
          await cleanup_response(r, session)
          raise HTTPException(
              status_code=status_code,
              detail=detail if detail else "Open WebUI: Server Connection Error",
          ) from e
  def get_api_key(idx, url, configs):
      """Return the API key configured for the Ollama backend at position ``idx``.

      Config entries are looked up in this order, and the first entry found
      wins (even if it has no ``"key"``):

      1. ``configs[str(idx)]`` - the current per-index configuration;
      2. ``configs[url]`` - legacy entries keyed by the full base URL;
      3. ``configs["<scheme>://<netloc>"]`` - legacy entries keyed by the URL
         with its path stripped.

      :param idx: Position of the backend in the configured base-URL list.
      :param url: That backend's base URL.
      :param configs: Mapping of config key -> per-backend config dict.
      :return: The key, or ``None`` when no matching entry provides one.
      """
      parsed_url = urlparse(url)
      base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
      for candidate in (str(idx), url, base_url):  # Legacy support for the last two
          if candidate in configs:
              return configs[candidate].get("key", None)
      return None
  134. ##########################################
  135. #
  136. # API routes
  137. #
  138. ##########################################
  router = APIRouter()


  @router.head("/")
  @router.get("/")
  async def get_status():
      """Liveness probe for this router; answers both GET and HEAD on ``/``."""
      return dict(status=True)
  class ConnectionVerificationForm(BaseModel):
      # Base URL of the Ollama server to probe; "/api/version" is appended to it.
      url: str
      # Optional API key, sent as a bearer token when present.
      key: Optional[str] = None
  @router.post("/verify")
  async def verify_connection(
      form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
  ):
      """Check that an Ollama server answers ``GET {url}/api/version`` (admin only).

      :return: The server's version payload.
      :raises HTTPException: 500 when the server cannot be reached, or when it
          replies with a non-200 status. In the latter case the detail carries
          the HTTP status, or the upstream ``error`` message when the body
          provides one.
      """
      url = form_data.url
      key = form_data.key
      headers = {"Authorization": f"Bearer {key}"} if key else {}
      async with aiohttp.ClientSession(
          timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST),
          # Honour proxy environment variables like the module's other sessions.
          trust_env=True,
      ) as session:
          try:
              async with session.get(f"{url}/api/version", headers=headers) as r:
                  if r.status != 200:
                      detail = f"HTTP Error: {r.status}"
                      # Error bodies are not guaranteed to be JSON (e.g. a proxy's
                      # HTML page); a decode failure must not hide the status.
                      try:
                          res = await r.json(content_type=None)
                      except ValueError:
                          res = None
                      if isinstance(res, dict) and "error" in res:
                          detail = f"External Error: {res['error']}"
                      raise Exception(detail)
                  data = await r.json()
                  return data
          except aiohttp.ClientError as e:
              log.exception(f"Client error: {str(e)}")
              raise HTTPException(
                  status_code=500, detail="Open WebUI: Server Connection Error"
              )
          except Exception as e:
              log.exception(f"Unexpected error: {e}")
              error_detail = f"Unexpected error: {str(e)}"
              raise HTTPException(status_code=500, detail=error_detail)
  @router.get("/config")
  async def get_config(request: Request, user=Depends(get_admin_user)):
      """Return the current Ollama connection settings (admin only)."""
      config = request.app.state.config
      setting_names = ("ENABLE_OLLAMA_API", "OLLAMA_BASE_URLS", "OLLAMA_API_CONFIGS")
      return {name: getattr(config, name) for name in setting_names}
  class OllamaConfigForm(BaseModel):
      # Master switch for the Ollama integration; a missing value is stored as None.
      ENABLE_OLLAMA_API: Optional[bool] = None
      # Ordered backend base URLs; a backend's position here is its url_idx.
      OLLAMA_BASE_URLS: list[str]
      # Per-backend settings keyed by the stringified index into OLLAMA_BASE_URLS.
      OLLAMA_API_CONFIGS: dict
  @router.post("/config/update")
  async def update_config(
      request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user)
  ):
      """Replace the Ollama connection settings and echo back what was stored.

      Entries of ``OLLAMA_API_CONFIGS`` survive only if their key is the string
      index of one of the submitted base URLs. Anything else, including legacy
      URL-keyed entries, is dropped.
      """
      config = request.app.state.config
      config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
      config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
      config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS

      # Remove the API configs that are not in the API URLS
      valid_keys = {str(position) for position in range(len(config.OLLAMA_BASE_URLS))}
      pruned = {}
      for config_key, config_value in config.OLLAMA_API_CONFIGS.items():
          if config_key in valid_keys:
              pruned[config_key] = config_value
      config.OLLAMA_API_CONFIGS = pruned

      return {
          "ENABLE_OLLAMA_API": config.ENABLE_OLLAMA_API,
          "OLLAMA_BASE_URLS": config.OLLAMA_BASE_URLS,
          "OLLAMA_API_CONFIGS": config.OLLAMA_API_CONFIGS,
      }
  # NOTE(review): aiocache derives the cache key from the call arguments, and a
  # fresh Request object is passed on each HTTP call - confirm this 3-second
  # cache is actually hit in practice.
  @cached(ttl=3)
  async def get_all_models(request: Request):
      """Fetch and merge the model lists of all configured Ollama backends.

      Each backend's ``/api/tags`` is requested concurrently. A backend whose
      config has ``"enable": False`` is not queried. Instead a task resolving to
      ``None`` is queued, so ``responses`` stays index-aligned with
      ``OLLAMA_BASE_URLS``.

      For each reply, the backend's optional ``model_ids`` allow-list is applied
      first, then its optional ``prefix_id`` is prepended as
      ``"<prefix_id>.<model>"``. Entries with the same ``"model"`` name across
      backends are merged into one whose ``"urls"`` lists every backend index
      that reported it.

      Side effect: stores ``{model_name: entry}`` on
      ``request.app.state.OLLAMA_MODELS``.

      :return: ``{"models": [...]}``; an empty list when ENABLE_OLLAMA_API is off.
      """
      log.info("get_all_models()")
      if request.app.state.config.ENABLE_OLLAMA_API:
          request_tasks = []
          for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
              if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
                  url not in request.app.state.config.OLLAMA_API_CONFIGS  # Legacy support
              ):
                  # No config for this backend: query it without an API key.
                  request_tasks.append(send_get_request(f"{url}/api/tags"))
              else:
                  api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
                      str(idx),
                      request.app.state.config.OLLAMA_API_CONFIGS.get(
                          url, {}
                      ),  # Legacy support
                  )
                  enable = api_config.get("enable", True)
                  key = api_config.get("key", None)
                  if enable:
                      request_tasks.append(send_get_request(f"{url}/api/tags", key))
                  else:
                      # asyncio.sleep(0, None) resolves to None: a placeholder that
                      # keeps the gathered results aligned with the URL indices.
                      request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
          responses = await asyncio.gather(*request_tasks)
          for idx, response in enumerate(responses):
              if response:
                  url = request.app.state.config.OLLAMA_BASE_URLS[idx]
                  api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
                      str(idx),
                      request.app.state.config.OLLAMA_API_CONFIGS.get(
                          url, {}
                      ),  # Legacy support
                  )
                  prefix_id = api_config.get("prefix_id", None)
                  model_ids = api_config.get("model_ids", [])
                  # A non-empty model_ids list acts as an allow-list for this backend.
                  if len(model_ids) != 0 and "models" in response:
                      response["models"] = list(
                          filter(
                              lambda model: model["model"] in model_ids,
                              response["models"],
                          )
                      )
                  # Mutates the entries in place so the prefix shows up in the merge below.
                  if prefix_id:
                      for model in response.get("models", []):
                          model["model"] = f"{prefix_id}.{model['model']}"

          def merge_models_lists(model_lists):
              # Merge per-backend lists by model name. The first occurrence's
              # dict is kept (and mutated) and gains a "urls" list of backend
              # indices; later duplicates only append their index.
              merged_models = {}
              for idx, model_list in enumerate(model_lists):
                  if model_list is not None:
                      for model in model_list:
                          id = model["model"]
                          if id not in merged_models:
                              model["urls"] = [idx]
                              merged_models[id] = model
                          else:
                              merged_models[id]["urls"].append(idx)
              return list(merged_models.values())

          models = {
              "models": merge_models_lists(
                  map(
                      lambda response: response.get("models", []) if response else None,
                      responses,
                  )
              )
          }
      else:
          models = {"models": []}
      # Name -> merged entry lookup used by the routes that resolve a backend by model.
      request.app.state.OLLAMA_MODELS = {
          model["model"]: model for model in models["models"]
      }
      return models
  async def get_filtered_models(models, user):
      """Return the entries of ``models["models"]`` that ``user`` may read.

      An entry is kept only when it has a record in ``Models`` and the user
      either owns that record or is granted "read" by its access control.
      Entries without a record are dropped.
      """
      return [
          model
          for model in models.get("models", [])
          if (model_info := Models.get_model_by_id(model["model"]))
          and (
              user.id == model_info.user_id
              or has_access(
                  user.id, type="read", access_control=model_info.access_control
              )
          )
      ]
  @router.get("/api/tags")
  @router.get("/api/tags/{url_idx}")
  async def get_ollama_tags(
      request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)
  ):
      """List the models visible to the current user.

      Without ``url_idx``, the merged listing of all backends from
      ``get_all_models`` is used. With ``url_idx``, that single backend's
      ``/api/tags`` is queried directly. Accounts with role ``"user"`` only see
      models they own or can read, unless BYPASS_MODEL_ACCESS_CONTROL is set.

      :raises HTTPException: When the single-backend request fails. The status
          is the upstream status code, or 500 if no response was received.
      """
      models = []
      if url_idx is None:
          models = await get_all_models(request)
      else:
          url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
          key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
          r = None
          try:
              # requests is blocking; run it in a worker thread so the event loop
              # keeps serving other requests meanwhile.
              r = await asyncio.to_thread(
                  requests.request,
                  method="GET",
                  url=f"{url}/api/tags",
                  headers={**({"Authorization": f"Bearer {key}"} if key else {})},
              )
              r.raise_for_status()
              models = r.json()
          except Exception as e:
              log.exception(e)
              detail = None
              if r is not None:
                  try:
                      res = r.json()
                      if "error" in res:
                          detail = f"Ollama: {res['error']}"
                  except Exception:
                      detail = f"Ollama: {e}"
              raise HTTPException(
                  # Compare with None: a requests.Response is falsy for 4xx/5xx,
                  # so `if r` reported 500 for every upstream HTTP error.
                  status_code=r.status_code if r is not None else 500,
                  detail=detail if detail else "Open WebUI: Server Connection Error",
              )
      if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
          models["models"] = await get_filtered_models(models, user)
      return models
  @router.get("/api/version")
  @router.get("/api/version/{url_idx}")
  async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
      """Report the Ollama server version.

      Without ``url_idx``, every configured backend is queried and the lowest
      version among those that answered is returned. With ``url_idx``, only
      that backend is asked and its reply is returned unchanged. When the
      Ollama API is disabled, ``{"version": False}`` is returned.

      :raises HTTPException: 500 (OLLAMA_NOT_FOUND) when no backend reports a
          version. For the single-backend case, the upstream status code (500
          if no response was received).
      """
      config = request.app.state.config
      if not config.ENABLE_OLLAMA_API:
          return {"version": False}

      if url_idx is None:
          # returns lowest version
          request_tasks = [
              send_get_request(
                  f"{url}/api/version",
                  config.OLLAMA_API_CONFIGS.get(
                      str(idx),
                      config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
                  ).get("key", None),
              )
              for idx, url in enumerate(config.OLLAMA_BASE_URLS)
          ]
          responses = await asyncio.gather(*request_tasks)
          # Skip unreachable backends (None) and replies without a version (e.g.
          # an error payload); indexing ["version"] on those raised KeyError.
          responses = [
              x for x in responses if isinstance(x, dict) and "version" in x
          ]
          if not responses:
              raise HTTPException(
                  status_code=500,
                  detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
              )
          # "v0.5.4-rc1" -> (0, 5, 4): drop a leading "v" and any "-suffix".
          lowest_version = min(
              responses,
              key=lambda x: tuple(
                  map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
              ),
          )
          return {"version": lowest_version["version"]}

      url = config.OLLAMA_BASE_URLS[url_idx]
      # Authenticate like the all-backends branch does; this call used to send no key.
      key = get_api_key(url_idx, url, config.OLLAMA_API_CONFIGS)
      r = None
      try:
          # requests is blocking; keep it off the event loop.
          r = await asyncio.to_thread(
              requests.request,
              method="GET",
              url=f"{url}/api/version",
              headers={**({"Authorization": f"Bearer {key}"} if key else {})},
          )
          r.raise_for_status()
          return r.json()
      except Exception as e:
          log.exception(e)
          detail = None
          if r is not None:
              try:
                  res = r.json()
                  if "error" in res:
                      detail = f"Ollama: {res['error']}"
              except Exception:
                  detail = f"Ollama: {e}"
          raise HTTPException(
              # Compare with None: a requests.Response is falsy for 4xx/5xx.
              status_code=r.status_code if r is not None else 500,
              detail=detail if detail else "Open WebUI: Server Connection Error",
          )
  @router.get("/api/ps")
  async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
      """
      Report which models each Ollama backend currently has loaded in memory.

      Returns a mapping of base URL -> that backend's ``/api/ps`` reply, or
      ``None`` for a backend that could not be queried. Returns an empty dict
      when the Ollama API is disabled.
      """
      config = request.app.state.config
      if not config.ENABLE_OLLAMA_API:
          return {}

      base_urls = config.OLLAMA_BASE_URLS
      api_configs = config.OLLAMA_API_CONFIGS
      pending = []
      for position, base_url in enumerate(base_urls):
          # Per-index config first, then the legacy URL-keyed entry.
          connection = api_configs.get(str(position), api_configs.get(base_url, {}))
          pending.append(
              send_get_request(f"{base_url}/api/ps", connection.get("key", None))
          )

      results = await asyncio.gather(*pending)
      return dict(zip(base_urls, results))
  class ModelNameForm(BaseModel):
      # Model name; matched against the merged model list when no url_idx is given.
      name: str
  @router.post("/api/pull")
  @router.post("/api/pull/{url_idx}")
  async def pull_model(
      request: Request,
      form_data: ModelNameForm,
      url_idx: int = 0,
      user=Depends(get_admin_user),
  ):
      """Ask the Ollama backend at ``url_idx`` (default 0) to pull a model.

      Admin only. The upstream reply is relayed as a streaming response.
      """
      config = request.app.state.config
      url = config.OLLAMA_BASE_URLS[url_idx]
      log.info(f"url: {url}")

      # Admin should be able to pull models from any source
      body = form_data.model_dump(exclude_none=True)
      body["insecure"] = True
      return await send_post_request(
          url=f"{url}/api/pull",
          payload=json.dumps(body),
          key=get_api_key(url_idx, url, config.OLLAMA_API_CONFIGS),
      )
  class PushModelForm(BaseModel):
      # Name of the model to push.
      name: str
      # The optional fields below are forwarded only when set (the payload is
      # serialized with exclude_none=True).
      insecure: Optional[bool] = None
      stream: Optional[bool] = None
  @router.post("/api/push")
  @router.post("/api/push/{url_idx}")
  @router.delete("/api/push")
  @router.delete("/api/push/{url_idx}")
  async def push_model(
      request: Request,
      form_data: PushModelForm,
      url_idx: Optional[int] = None,
      user=Depends(get_admin_user),
  ):
      """Push a model through a backend's ``/api/push`` (admin only).

      The upstream call is a POST, so the route is now also exposed as POST.
      The historical DELETE registration is kept so existing clients continue
      to work. Without ``url_idx``, the first backend known to host the model
      is used. The upstream reply is relayed as a stream.

      :raises HTTPException: 400 if the model is unknown. Upstream failures
          are raised by ``send_post_request``.
      """
      if url_idx is None:
          await get_all_models(request)
          models = request.app.state.OLLAMA_MODELS
          if form_data.name in models:
              url_idx = models[form_data.name]["urls"][0]
          else:
              raise HTTPException(
                  status_code=400,
                  detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
              )
      url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
      log.debug(f"url: {url}")
      return await send_post_request(
          url=f"{url}/api/push",
          payload=form_data.model_dump_json(exclude_none=True).encode(),
          key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
      )
  class CreateModelForm(BaseModel):
      # Name of the model to create.
      model: Optional[str] = None
      stream: Optional[bool] = None
      path: Optional[str] = None
      # Accept extra fields; they are included in the JSON forwarded to /api/create.
      model_config = ConfigDict(extra="allow")
  @router.post("/api/create")
  @router.post("/api/create/{url_idx}")
  async def create_model(
      request: Request,
      form_data: CreateModelForm,
      url_idx: int = 0,
      user=Depends(get_admin_user),
  ):
      """Forward a model-creation request to the backend at ``url_idx`` (default 0).

      Admin only. The request body, including any extra fields, is relayed
      as-is, and the upstream reply is streamed back.
      """
      log.debug(f"form_data: {form_data}")
      config = request.app.state.config
      base_url = config.OLLAMA_BASE_URLS[url_idx]
      body = form_data.model_dump_json(exclude_none=True).encode()
      api_key = get_api_key(url_idx, base_url, config.OLLAMA_API_CONFIGS)
      return await send_post_request(
          url=f"{base_url}/api/create", payload=body, key=api_key
      )
  class CopyModelForm(BaseModel):
      # Existing model name; also used to locate a backend when no url_idx is given.
      source: str
      # Name for the copy.
      destination: str
  @router.post("/api/copy")
  @router.post("/api/copy/{url_idx}")
  async def copy_model(
      request: Request,
      form_data: CopyModelForm,
      url_idx: Optional[int] = None,
      user=Depends(get_admin_user),
  ):
      """Copy ``form_data.source`` to ``form_data.destination`` on one backend.

      Admin only. Without ``url_idx``, the first backend known to host the
      source model is used.

      :return: ``True`` on success.
      :raises HTTPException: 400 if the source model is unknown. Otherwise the
          upstream status code (500 if no response was received), with Ollama's
          ``error`` message when available.
      """
      if url_idx is None:
          await get_all_models(request)
          models = request.app.state.OLLAMA_MODELS
          if form_data.source in models:
              url_idx = models[form_data.source]["urls"][0]
          else:
              raise HTTPException(
                  status_code=400,
                  detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
              )
      url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
      key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
      # Bind before the try: if the request itself raises (e.g. backend down),
      # the handler below still reads `r`. Unbound, that raised UnboundLocalError.
      r = None
      try:
          # requests is blocking; run it in a worker thread.
          r = await asyncio.to_thread(
              requests.request,
              method="POST",
              url=f"{url}/api/copy",
              headers={
                  "Content-Type": "application/json",
                  **({"Authorization": f"Bearer {key}"} if key else {}),
              },
              data=form_data.model_dump_json(exclude_none=True).encode(),
          )
          r.raise_for_status()
          log.debug(f"r.text: {r.text}")
          return True
      except Exception as e:
          log.exception(e)
          detail = None
          if r is not None:
              try:
                  res = r.json()
                  if "error" in res:
                      detail = f"Ollama: {res['error']}"
              except Exception:
                  detail = f"Ollama: {e}"
          raise HTTPException(
              # Compare with None: a requests.Response is falsy for 4xx/5xx.
              status_code=r.status_code if r is not None else 500,
              detail=detail if detail else "Open WebUI: Server Connection Error",
          )
  @router.delete("/api/delete")
  @router.delete("/api/delete/{url_idx}")
  async def delete_model(
      request: Request,
      form_data: ModelNameForm,
      url_idx: Optional[int] = None,
      user=Depends(get_admin_user),
  ):
      """Delete model ``form_data.name`` from one backend (admin only).

      Without ``url_idx``, the first backend known to host the model is used.

      :return: ``True`` on success.
      :raises HTTPException: 400 if the model is unknown. Otherwise the upstream
          status code (500 if no response was received), with Ollama's
          ``error`` message when available.
      """
      if url_idx is None:
          await get_all_models(request)
          models = request.app.state.OLLAMA_MODELS
          if form_data.name in models:
              url_idx = models[form_data.name]["urls"][0]
          else:
              raise HTTPException(
                  status_code=400,
                  detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
              )
      url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
      key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
      # Bind before the try so the handler can read `r` even when the request
      # itself raises; previously that path failed with UnboundLocalError.
      r = None
      try:
          # requests is blocking; run it in a worker thread.
          r = await asyncio.to_thread(
              requests.request,
              method="DELETE",
              url=f"{url}/api/delete",
              data=form_data.model_dump_json(exclude_none=True).encode(),
              headers={
                  "Content-Type": "application/json",
                  **({"Authorization": f"Bearer {key}"} if key else {}),
              },
          )
          r.raise_for_status()
          log.debug(f"r.text: {r.text}")
          return True
      except Exception as e:
          log.exception(e)
          detail = None
          if r is not None:
              try:
                  res = r.json()
                  if "error" in res:
                      detail = f"Ollama: {res['error']}"
              except Exception:
                  detail = f"Ollama: {e}"
          raise HTTPException(
              # Compare with None: a requests.Response is falsy for 4xx/5xx.
              status_code=r.status_code if r is not None else 500,
              detail=detail if detail else "Open WebUI: Server Connection Error",
          )
  @router.post("/api/show")
  async def show_model_info(
      request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
  ):
      """Return Ollama's ``/api/show`` details for ``form_data.name``.

      The request goes to a randomly chosen backend among those hosting the
      model.

      :raises HTTPException: 400 if the model is unknown. Otherwise the upstream
          status code (500 if no response was received), with Ollama's
          ``error`` message when available.
      """
      await get_all_models(request)
      models = request.app.state.OLLAMA_MODELS
      if form_data.name not in models:
          raise HTTPException(
              status_code=400,
              detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
          )
      url_idx = random.choice(models[form_data.name]["urls"])
      url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
      key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
      # Bind before the try so the handler can read `r` even when the request
      # itself raises; previously that path failed with UnboundLocalError.
      r = None
      try:
          # requests is blocking; run it in a worker thread.
          r = await asyncio.to_thread(
              requests.request,
              method="POST",
              url=f"{url}/api/show",
              headers={
                  "Content-Type": "application/json",
                  **({"Authorization": f"Bearer {key}"} if key else {}),
              },
              data=form_data.model_dump_json(exclude_none=True).encode(),
          )
          r.raise_for_status()
          return r.json()
      except Exception as e:
          log.exception(e)
          detail = None
          if r is not None:
              try:
                  res = r.json()
                  if "error" in res:
                      detail = f"Ollama: {res['error']}"
              except Exception:
                  detail = f"Ollama: {e}"
          raise HTTPException(
              # Compare with None: a requests.Response is falsy for 4xx/5xx.
              status_code=r.status_code if r is not None else 500,
              detail=detail if detail else "Open WebUI: Server Connection Error",
          )
  class GenerateEmbedForm(BaseModel):
      # Model name; a name without a ":tag" is looked up as "<name>:latest".
      model: str
      # A single string or a batch of strings to embed.
      input: list[str] | str
      # The optional fields below are forwarded only when set (exclude_none=True).
      truncate: Optional[bool] = None
      options: Optional[dict] = None
      keep_alive: Optional[Union[int, str]] = None
  @router.post("/api/embed")
  @router.post("/api/embed/{url_idx}")
  async def embed(
      request: Request,
      form_data: GenerateEmbedForm,
      url_idx: Optional[int] = None,
      user=Depends(get_verified_user),
  ):
      """Proxy a (batch) embedding request to Ollama's ``/api/embed``.

      Without ``url_idx``, a random backend hosting the model is chosen; a
      model name without a tag is looked up as ``"<name>:latest"``.

      :raises HTTPException: 400 if the model is unknown. Otherwise the upstream
          status code (500 if no response was received), with Ollama's
          ``error`` message when available.
      """
      log.info(f"generate_ollama_batch_embeddings {form_data}")
      if url_idx is None:
          await get_all_models(request)
          models = request.app.state.OLLAMA_MODELS
          model = form_data.model
          if ":" not in model:
              model = f"{model}:latest"
          if model in models:
              url_idx = random.choice(models[model]["urls"])
          else:
              raise HTTPException(
                  status_code=400,
                  detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
              )
      url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
      key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
      # Bind before the try so the handler can read `r` even when the request
      # itself raises; previously that path failed with UnboundLocalError.
      r = None
      try:
          # requests is blocking; run it in a worker thread.
          r = await asyncio.to_thread(
              requests.request,
              method="POST",
              url=f"{url}/api/embed",
              headers={
                  "Content-Type": "application/json",
                  **({"Authorization": f"Bearer {key}"} if key else {}),
              },
              data=form_data.model_dump_json(exclude_none=True).encode(),
          )
          r.raise_for_status()
          data = r.json()
          return data
      except Exception as e:
          log.exception(e)
          detail = None
          if r is not None:
              try:
                  res = r.json()
                  if "error" in res:
                      detail = f"Ollama: {res['error']}"
              except Exception:
                  detail = f"Ollama: {e}"
          raise HTTPException(
              # Compare with None: a requests.Response is falsy for 4xx/5xx.
              status_code=r.status_code if r is not None else 500,
              detail=detail if detail else "Open WebUI: Server Connection Error",
          )
  class GenerateEmbeddingsForm(BaseModel):
      # Model name; a name without a ":tag" is looked up as "<name>:latest".
      model: str
      # Single text to embed.
      prompt: str
      # The optional fields below are forwarded only when set (exclude_none=True).
      options: Optional[dict] = None
      keep_alive: Optional[Union[int, str]] = None
  @router.post("/api/embeddings")
  @router.post("/api/embeddings/{url_idx}")
  async def embeddings(
      request: Request,
      form_data: GenerateEmbeddingsForm,
      url_idx: Optional[int] = None,
      user=Depends(get_verified_user),
  ):
      """Proxy a single-prompt embedding request to Ollama's ``/api/embeddings``.

      Without ``url_idx``, a random backend hosting the model is chosen; a
      model name without a tag is looked up as ``"<name>:latest"``.

      :raises HTTPException: 400 if the model is unknown. Otherwise the upstream
          status code (500 if no response was received), with Ollama's
          ``error`` message when available.
      """
      log.info(f"generate_ollama_embeddings {form_data}")
      if url_idx is None:
          await get_all_models(request)
          models = request.app.state.OLLAMA_MODELS
          model = form_data.model
          if ":" not in model:
              model = f"{model}:latest"
          if model in models:
              url_idx = random.choice(models[model]["urls"])
          else:
              raise HTTPException(
                  status_code=400,
                  detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
              )
      url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
      key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
      # Bind before the try so the handler can read `r` even when the request
      # itself raises; previously that path failed with UnboundLocalError.
      r = None
      try:
          # requests is blocking; run it in a worker thread.
          r = await asyncio.to_thread(
              requests.request,
              method="POST",
              url=f"{url}/api/embeddings",
              headers={
                  "Content-Type": "application/json",
                  **({"Authorization": f"Bearer {key}"} if key else {}),
              },
              data=form_data.model_dump_json(exclude_none=True).encode(),
          )
          r.raise_for_status()
          data = r.json()
          return data
      except Exception as e:
          log.exception(e)
          detail = None
          if r is not None:
              try:
                  res = r.json()
                  if "error" in res:
                      detail = f"Ollama: {res['error']}"
              except Exception:
                  detail = f"Ollama: {e}"
          raise HTTPException(
              # Compare with None: a requests.Response is falsy for 4xx/5xx.
              status_code=r.status_code if r is not None else 500,
              detail=detail if detail else "Open WebUI: Server Connection Error",
          )
  class GenerateCompletionForm(BaseModel):
      # Model name; a name without a ":tag" is looked up as "<name>:latest".
      model: str
      prompt: str
      # The optional fields below are forwarded only when set (exclude_none=True).
      suffix: Optional[str] = None
      images: Optional[list[str]] = None
      format: Optional[str] = None
      options: Optional[dict] = None
      system: Optional[str] = None
      template: Optional[str] = None
      context: Optional[list[int]] = None
      # Defaults to True, so it is forwarded unless the client sends null.
      stream: Optional[bool] = True
      raw: Optional[bool] = None
      keep_alive: Optional[Union[int, str]] = None
  @router.post("/api/generate")
  @router.post("/api/generate/{url_idx}")
  async def generate_completion(
      request: Request,
      form_data: GenerateCompletionForm,
      url_idx: Optional[int] = None,
      user=Depends(get_verified_user),
  ):
      """Proxy a completion request to Ollama's ``/api/generate``.

      Without ``url_idx``, a random backend hosting the model is chosen; a
      model name without a tag is looked up as ``"<name>:latest"``. If the
      chosen backend has a ``prefix_id``, the leading ``"<prefix_id>."`` is
      removed from the model name before forwarding. The upstream reply is
      relayed as a stream.

      :raises HTTPException: 400 if the model is unknown. Upstream failures
          are raised by ``send_post_request``.
      """
      if url_idx is None:
          await get_all_models(request)
          models = request.app.state.OLLAMA_MODELS
          model = form_data.model
          if ":" not in model:
              model = f"{model}:latest"
          if model in models:
              url_idx = random.choice(models[model]["urls"])
          else:
              raise HTTPException(
                  status_code=400,
                  detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
              )
      url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
      api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
          str(url_idx),
          request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),  # Legacy support
      )
      prefix_id = api_config.get("prefix_id", None)
      if prefix_id:
          # Strip only the leading prefix; str.replace would also remove any
          # later occurrence of "<prefix_id>." inside the real model name.
          form_data.model = form_data.model.removeprefix(f"{prefix_id}.")
      return await send_post_request(
          url=f"{url}/api/generate",
          payload=form_data.model_dump_json(exclude_none=True).encode(),
          key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
      )
  770. class ChatMessage(BaseModel):
  771. role: str
  772. content: str
  773. tool_calls: Optional[list[dict]] = None
  774. images: Optional[list[str]] = None
  775. class GenerateChatCompletionForm(BaseModel):
  776. model: str
  777. messages: list[ChatMessage]
  778. format: Optional[Union[dict, str]] = None
  779. options: Optional[dict] = None
  780. template: Optional[str] = None
  781. stream: Optional[bool] = True
  782. keep_alive: Optional[Union[int, str]] = None
  783. tools: Optional[list[dict]] = None
  784. async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None):
  785. if url_idx is None:
  786. models = request.app.state.OLLAMA_MODELS
  787. if model not in models:
  788. raise HTTPException(
  789. status_code=400,
  790. detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
  791. )
  792. url_idx = random.choice(models[model].get("urls", []))
  793. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  794. return url, url_idx
  795. @router.post("/api/chat")
  796. @router.post("/api/chat/{url_idx}")
  797. async def generate_chat_completion(
  798. request: Request,
  799. form_data: dict,
  800. url_idx: Optional[int] = None,
  801. user=Depends(get_verified_user),
  802. bypass_filter: Optional[bool] = False,
  803. ):
  804. if BYPASS_MODEL_ACCESS_CONTROL:
  805. bypass_filter = True
  806. metadata = form_data.pop("metadata", None)
  807. try:
  808. form_data = GenerateChatCompletionForm(**form_data)
  809. except Exception as e:
  810. log.exception(e)
  811. raise HTTPException(
  812. status_code=400,
  813. detail=str(e),
  814. )
  815. payload = {**form_data.model_dump(exclude_none=True)}
  816. if "metadata" in payload:
  817. del payload["metadata"]
  818. model_id = payload["model"]
  819. model_info = Models.get_model_by_id(model_id)
  820. if model_info:
  821. if model_info.base_model_id:
  822. payload["model"] = model_info.base_model_id
  823. params = model_info.params.model_dump()
  824. if params:
  825. if payload.get("options") is None:
  826. payload["options"] = {}
  827. payload["options"] = apply_model_params_to_body_ollama(
  828. params, payload["options"]
  829. )
  830. payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
  831. # Check if user has access to the model
  832. if not bypass_filter and user.role == "user":
  833. if not (
  834. user.id == model_info.user_id
  835. or has_access(
  836. user.id, type="read", access_control=model_info.access_control
  837. )
  838. ):
  839. raise HTTPException(
  840. status_code=403,
  841. detail="Model not found",
  842. )
  843. elif not bypass_filter:
  844. if user.role != "admin":
  845. raise HTTPException(
  846. status_code=403,
  847. detail="Model not found",
  848. )
  849. if ":" not in payload["model"]:
  850. payload["model"] = f"{payload['model']}:latest"
  851. url, url_idx = await get_ollama_url(request, payload["model"], url_idx)
  852. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  853. str(url_idx),
  854. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  855. )
  856. prefix_id = api_config.get("prefix_id", None)
  857. if prefix_id:
  858. payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
  859. return await send_post_request(
  860. url=f"{url}/api/chat",
  861. payload=json.dumps(payload),
  862. stream=form_data.stream,
  863. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  864. content_type="application/x-ndjson",
  865. )
  866. # TODO: we should update this part once Ollama supports other types
  867. class OpenAIChatMessageContent(BaseModel):
  868. type: str
  869. model_config = ConfigDict(extra="allow")
  870. class OpenAIChatMessage(BaseModel):
  871. role: str
  872. content: Union[str, list[OpenAIChatMessageContent]]
  873. model_config = ConfigDict(extra="allow")
  874. class OpenAIChatCompletionForm(BaseModel):
  875. model: str
  876. messages: list[OpenAIChatMessage]
  877. model_config = ConfigDict(extra="allow")
  878. class OpenAICompletionForm(BaseModel):
  879. model: str
  880. prompt: str
  881. model_config = ConfigDict(extra="allow")
  882. @router.post("/v1/completions")
  883. @router.post("/v1/completions/{url_idx}")
  884. async def generate_openai_completion(
  885. request: Request,
  886. form_data: dict,
  887. url_idx: Optional[int] = None,
  888. user=Depends(get_verified_user),
  889. ):
  890. try:
  891. form_data = OpenAICompletionForm(**form_data)
  892. except Exception as e:
  893. log.exception(e)
  894. raise HTTPException(
  895. status_code=400,
  896. detail=str(e),
  897. )
  898. payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])}
  899. if "metadata" in payload:
  900. del payload["metadata"]
  901. model_id = form_data.model
  902. if ":" not in model_id:
  903. model_id = f"{model_id}:latest"
  904. model_info = Models.get_model_by_id(model_id)
  905. if model_info:
  906. if model_info.base_model_id:
  907. payload["model"] = model_info.base_model_id
  908. params = model_info.params.model_dump()
  909. if params:
  910. payload = apply_model_params_to_body_openai(params, payload)
  911. # Check if user has access to the model
  912. if user.role == "user":
  913. if not (
  914. user.id == model_info.user_id
  915. or has_access(
  916. user.id, type="read", access_control=model_info.access_control
  917. )
  918. ):
  919. raise HTTPException(
  920. status_code=403,
  921. detail="Model not found",
  922. )
  923. else:
  924. if user.role != "admin":
  925. raise HTTPException(
  926. status_code=403,
  927. detail="Model not found",
  928. )
  929. if ":" not in payload["model"]:
  930. payload["model"] = f"{payload['model']}:latest"
  931. url, url_idx = await get_ollama_url(request, payload["model"], url_idx)
  932. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  933. str(url_idx),
  934. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  935. )
  936. prefix_id = api_config.get("prefix_id", None)
  937. if prefix_id:
  938. payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
  939. return await send_post_request(
  940. url=f"{url}/v1/completions",
  941. payload=json.dumps(payload),
  942. stream=payload.get("stream", False),
  943. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  944. )
  945. @router.post("/v1/chat/completions")
  946. @router.post("/v1/chat/completions/{url_idx}")
  947. async def generate_openai_chat_completion(
  948. request: Request,
  949. form_data: dict,
  950. url_idx: Optional[int] = None,
  951. user=Depends(get_verified_user),
  952. ):
  953. metadata = form_data.pop("metadata", None)
  954. try:
  955. completion_form = OpenAIChatCompletionForm(**form_data)
  956. except Exception as e:
  957. log.exception(e)
  958. raise HTTPException(
  959. status_code=400,
  960. detail=str(e),
  961. )
  962. payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
  963. if "metadata" in payload:
  964. del payload["metadata"]
  965. model_id = completion_form.model
  966. if ":" not in model_id:
  967. model_id = f"{model_id}:latest"
  968. model_info = Models.get_model_by_id(model_id)
  969. if model_info:
  970. if model_info.base_model_id:
  971. payload["model"] = model_info.base_model_id
  972. params = model_info.params.model_dump()
  973. if params:
  974. payload = apply_model_params_to_body_openai(params, payload)
  975. payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
  976. # Check if user has access to the model
  977. if user.role == "user":
  978. if not (
  979. user.id == model_info.user_id
  980. or has_access(
  981. user.id, type="read", access_control=model_info.access_control
  982. )
  983. ):
  984. raise HTTPException(
  985. status_code=403,
  986. detail="Model not found",
  987. )
  988. else:
  989. if user.role != "admin":
  990. raise HTTPException(
  991. status_code=403,
  992. detail="Model not found",
  993. )
  994. if ":" not in payload["model"]:
  995. payload["model"] = f"{payload['model']}:latest"
  996. url, url_idx = await get_ollama_url(request, payload["model"], url_idx)
  997. api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
  998. str(url_idx),
  999. request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
  1000. )
  1001. prefix_id = api_config.get("prefix_id", None)
  1002. if prefix_id:
  1003. payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
  1004. return await send_post_request(
  1005. url=f"{url}/v1/chat/completions",
  1006. payload=json.dumps(payload),
  1007. stream=payload.get("stream", False),
  1008. key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
  1009. )
  1010. @router.get("/v1/models")
  1011. @router.get("/v1/models/{url_idx}")
  1012. async def get_openai_models(
  1013. request: Request,
  1014. url_idx: Optional[int] = None,
  1015. user=Depends(get_verified_user),
  1016. ):
  1017. models = []
  1018. if url_idx is None:
  1019. model_list = await get_all_models(request)
  1020. models = [
  1021. {
  1022. "id": model["model"],
  1023. "object": "model",
  1024. "created": int(time.time()),
  1025. "owned_by": "openai",
  1026. }
  1027. for model in model_list["models"]
  1028. ]
  1029. else:
  1030. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  1031. try:
  1032. r = requests.request(method="GET", url=f"{url}/api/tags")
  1033. r.raise_for_status()
  1034. model_list = r.json()
  1035. models = [
  1036. {
  1037. "id": model["model"],
  1038. "object": "model",
  1039. "created": int(time.time()),
  1040. "owned_by": "openai",
  1041. }
  1042. for model in models["models"]
  1043. ]
  1044. except Exception as e:
  1045. log.exception(e)
  1046. error_detail = "Open WebUI: Server Connection Error"
  1047. if r is not None:
  1048. try:
  1049. res = r.json()
  1050. if "error" in res:
  1051. error_detail = f"Ollama: {res['error']}"
  1052. except Exception:
  1053. error_detail = f"Ollama: {e}"
  1054. raise HTTPException(
  1055. status_code=r.status_code if r else 500,
  1056. detail=error_detail,
  1057. )
  1058. if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
  1059. # Filter models based on user access control
  1060. filtered_models = []
  1061. for model in models:
  1062. model_info = Models.get_model_by_id(model["id"])
  1063. if model_info:
  1064. if user.id == model_info.user_id or has_access(
  1065. user.id, type="read", access_control=model_info.access_control
  1066. ):
  1067. filtered_models.append(model)
  1068. models = filtered_models
  1069. return {
  1070. "data": models,
  1071. "object": "list",
  1072. }
  1073. class UrlForm(BaseModel):
  1074. url: str
  1075. class UploadBlobForm(BaseModel):
  1076. filename: str
  1077. def parse_huggingface_url(hf_url):
  1078. try:
  1079. # Parse the URL
  1080. parsed_url = urlparse(hf_url)
  1081. # Get the path and split it into components
  1082. path_components = parsed_url.path.split("/")
  1083. # Extract the desired output
  1084. model_file = path_components[-1]
  1085. return model_file
  1086. except ValueError:
  1087. return None
  1088. async def download_file_stream(
  1089. ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
  1090. ):
  1091. done = False
  1092. if os.path.exists(file_path):
  1093. current_size = os.path.getsize(file_path)
  1094. else:
  1095. current_size = 0
  1096. headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
  1097. timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
  1098. async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
  1099. async with session.get(file_url, headers=headers) as response:
  1100. total_size = int(response.headers.get("content-length", 0)) + current_size
  1101. with open(file_path, "ab+") as file:
  1102. async for data in response.content.iter_chunked(chunk_size):
  1103. current_size += len(data)
  1104. file.write(data)
  1105. done = current_size == total_size
  1106. progress = round((current_size / total_size) * 100, 2)
  1107. yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
  1108. if done:
  1109. file.seek(0)
  1110. hashed = calculate_sha256(file)
  1111. file.seek(0)
  1112. url = f"{ollama_url}/api/blobs/sha256:{hashed}"
  1113. response = requests.post(url, data=file)
  1114. if response.ok:
  1115. res = {
  1116. "done": done,
  1117. "blob": f"sha256:{hashed}",
  1118. "name": file_name,
  1119. }
  1120. os.remove(file_path)
  1121. yield f"data: {json.dumps(res)}\n\n"
  1122. else:
  1123. raise "Ollama: Could not create blob, Please try again."
# Example file URL accepted by download_model: "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
  1125. @router.post("/models/download")
  1126. @router.post("/models/download/{url_idx}")
  1127. async def download_model(
  1128. request: Request,
  1129. form_data: UrlForm,
  1130. url_idx: Optional[int] = None,
  1131. user=Depends(get_admin_user),
  1132. ):
  1133. allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
  1134. if not any(form_data.url.startswith(host) for host in allowed_hosts):
  1135. raise HTTPException(
  1136. status_code=400,
  1137. detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
  1138. )
  1139. if url_idx is None:
  1140. url_idx = 0
  1141. url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  1142. file_name = parse_huggingface_url(form_data.url)
  1143. if file_name:
  1144. file_path = f"{UPLOAD_DIR}/{file_name}"
  1145. return StreamingResponse(
  1146. download_file_stream(url, form_data.url, file_path, file_name),
  1147. )
  1148. else:
  1149. return None
  1150. # TODO: Progress bar does not reflect size & duration of upload.
  1151. @router.post("/models/upload")
  1152. @router.post("/models/upload/{url_idx}")
  1153. async def upload_model(
  1154. request: Request,
  1155. file: UploadFile = File(...),
  1156. url_idx: Optional[int] = None,
  1157. user=Depends(get_admin_user),
  1158. ):
  1159. if url_idx is None:
  1160. url_idx = 0
  1161. ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
  1162. file_path = os.path.join(UPLOAD_DIR, file.filename)
  1163. os.makedirs(UPLOAD_DIR, exist_ok=True)
  1164. # --- P1: save file locally ---
  1165. chunk_size = 1024 * 1024 * 2 # 2 MB chunks
  1166. with open(file_path, "wb") as out_f:
  1167. while True:
  1168. chunk = file.file.read(chunk_size)
  1169. # log.info(f"Chunk: {str(chunk)}") # DEBUG
  1170. if not chunk:
  1171. break
  1172. out_f.write(chunk)
  1173. async def file_process_stream():
  1174. nonlocal ollama_url
  1175. total_size = os.path.getsize(file_path)
  1176. log.info(f"Total Model Size: {str(total_size)}") # DEBUG
  1177. # --- P2: SSE progress + calculate sha256 hash ---
  1178. file_hash = calculate_sha256(file_path, chunk_size)
  1179. log.info(f"Model Hash: {str(file_hash)}") # DEBUG
  1180. try:
  1181. with open(file_path, "rb") as f:
  1182. bytes_read = 0
  1183. while chunk := f.read(chunk_size):
  1184. bytes_read += len(chunk)
  1185. progress = round(bytes_read / total_size * 100, 2)
  1186. data_msg = {
  1187. "progress": progress,
  1188. "total": total_size,
  1189. "completed": bytes_read,
  1190. }
  1191. yield f"data: {json.dumps(data_msg)}\n\n"
  1192. # --- P3: Upload to ollama /api/blobs ---
  1193. with open(file_path, "rb") as f:
  1194. url = f"{ollama_url}/api/blobs/sha256:{file_hash}"
  1195. response = requests.post(url, data=f)
  1196. if response.ok:
  1197. log.info(f"Uploaded to /api/blobs") # DEBUG
  1198. # Remove local file
  1199. os.remove(file_path)
  1200. # Create model in ollama
  1201. model_name, ext = os.path.splitext(file.filename)
  1202. log.info(f"Created Model: {model_name}") # DEBUG
  1203. create_payload = {
  1204. "model": model_name,
  1205. # Reference the file by its original name => the uploaded blob's digest
  1206. "files": {file.filename: f"sha256:{file_hash}"},
  1207. }
  1208. log.info(f"Model Payload: {create_payload}") # DEBUG
  1209. # Call ollama /api/create
  1210. # https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model
  1211. create_resp = requests.post(
  1212. url=f"{ollama_url}/api/create",
  1213. headers={"Content-Type": "application/json"},
  1214. data=json.dumps(create_payload),
  1215. )
  1216. if create_resp.ok:
  1217. log.info(f"API SUCCESS!") # DEBUG
  1218. done_msg = {
  1219. "done": True,
  1220. "blob": f"sha256:{file_hash}",
  1221. "name": file.filename,
  1222. "model_created": model_name,
  1223. }
  1224. yield f"data: {json.dumps(done_msg)}\n\n"
  1225. else:
  1226. raise Exception(
  1227. f"Failed to create model in Ollama. {create_resp.text}"
  1228. )
  1229. else:
  1230. raise Exception("Ollama: Could not create blob, Please try again.")
  1231. except Exception as e:
  1232. res = {"error": str(e)}
  1233. yield f"data: {json.dumps(res)}\n\n"
  1234. return StreamingResponse(file_process_stream(), media_type="text/event-stream")