@@ -29,6 +29,8 @@ import time
from urllib.parse import urlparse
from typing import Optional, List, Union

+from starlette.background import BackgroundTask
+
from apps.webui.models.models import Models
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES
@@ -75,9 +77,6 @@ app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {}


-REQUEST_POOL = []
-
-
# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
# least connections, or least response time for better resource utilization and performance optimization.
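
The TODO above names weighted round-robin, least connections, and least response time as candidate strategies; today the app picks a backend with random.choice over the URL list. A minimal least-connections sketch over `app.state.config.OLLAMA_BASE_URLS` could look like the following. It is illustrative only and not part of this patch; the `in_flight` counter and `pick_backend` helper are assumed names:

```python
# Illustrative sketch, not part of this patch: least-connections selection
# across the configured Ollama backends. `in_flight` and `pick_backend` are
# hypothetical names.
from collections import defaultdict

in_flight: dict[str, int] = defaultdict(int)  # url -> number of active requests


def pick_backend(urls: list[str]) -> str:
    # Prefer the backend with the fewest in-flight requests; ties resolve to
    # the earlier entry in the configured list.
    return min(urls, key=lambda url: in_flight[url])
```
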
@@ -132,16 +131,6 @@ async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin
    return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}


-@app.get("/cancel/{request_id}")
-async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
-    if user:
-        if request_id in REQUEST_POOL:
-            REQUEST_POOL.remove(request_id)
-        return True
-    else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
-
-
async def fetch_url(url):
    timeout = aiohttp.ClientTimeout(total=5)
    try:
@@ -154,6 +143,45 @@ async def fetch_url(url):
        return None


+async def cleanup_response(
+    response: Optional[aiohttp.ClientResponse],
+    session: Optional[aiohttp.ClientSession],
+):
+    if response:
+        response.close()
+    if session:
+        await session.close()
+
+
+async def post_streaming_url(url, payload):
+    r = None
+    try:
+        session = aiohttp.ClientSession()
+        r = await session.post(url, data=payload)
+        r.raise_for_status()
+
+        return StreamingResponse(
+            r.content,
+            status_code=r.status,
+            headers=dict(r.headers),
+            background=BackgroundTask(cleanup_response, response=r, session=session),
+        )
+    except Exception as e:
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = await r.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except:
+                error_detail = f"Ollama: {e}"
+
+        raise HTTPException(
+            status_code=r.status if r else 500,
+            detail=error_detail,
+        )
+
+
def merge_models_lists(model_lists):
    merged_models = {}

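The `BackgroundTask` passed to `StreamingResponse` above is what keeps the aiohttp session alive long enough: the response body (`r.content`) is only consumed after the handler returns, so closing the session inside the handler would cut the stream short. A stripped-down sketch of the same deferred-cleanup pattern, with a hypothetical endpoint path and upstream URL:

```python
# Minimal sketch of the deferred-cleanup pattern used by post_streaming_url.
# The "/proxy" path and the upstream URL are assumptions, not part of this patch.
import aiohttp
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from starlette.background import BackgroundTask

app = FastAPI()


async def cleanup(r: aiohttp.ClientResponse, session: aiohttp.ClientSession):
    # Runs only after the last chunk has been sent to the client.
    r.close()
    await session.close()


@app.post("/proxy")
async def proxy():
    session = aiohttp.ClientSession()
    r = await session.post("http://localhost:11434/api/generate", data=b"{}")
    # StreamingResponse reads r.content lazily, after this handler returns,
    # so the session and response are closed by a BackgroundTask instead of here.
    return StreamingResponse(
        r.content,
        status_code=r.status,
        headers=dict(r.headers),
        background=BackgroundTask(cleanup, r, session),
    )
```
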
@@ -313,65 +341,7 @@ async def pull_model(
    # Admin should be able to pull models from any source
    payload = {**form_data.model_dump(exclude_none=True), "insecure": True}

-    def get_request():
-        nonlocal url
-        nonlocal r
-
-        request_id = str(uuid.uuid4())
-        try:
-            REQUEST_POOL.append(request_id)
-
-            def stream_content():
-                try:
-                    yield json.dumps({"id": request_id, "done": False}) + "\n"
-
-                    for chunk in r.iter_content(chunk_size=8192):
-                        if request_id in REQUEST_POOL:
-                            yield chunk
-                        else:
-                            log.warning("User: canceled request")
-                            break
-                finally:
-                    if hasattr(r, "close"):
-                        r.close()
-                        if request_id in REQUEST_POOL:
-                            REQUEST_POOL.remove(request_id)
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/pull",
-                data=json.dumps(payload),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))


class PushModelForm(BaseModel):
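
After this change, `/api/pull` relays Ollama's NDJSON progress stream as-is, without the old `{"id": ..., "done": false}` preamble line. A hedged client-side sketch of consuming it; the host, port, mount path, token, model name, and progress fields are assumptions rather than details taken from this patch:

```python
# Hypothetical client for the refactored pull endpoint. The URL, auth header,
# model name, and the progress fields printed below are assumptions.
import json

import requests

with requests.post(
    "http://localhost:8080/ollama/api/pull",
    json={"name": "llama3"},
    headers={"Authorization": "Bearer <token>"},
    stream=True,
) as r:
    for line in r.iter_lines():
        if line:
            progress = json.loads(line)
            print(progress.get("status"), progress.get("completed"), progress.get("total"))
```
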
@@ -399,50 +369,9 @@ async def push_model(
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.debug(f"url: {url}")

-    r = None
-
-    def get_request():
-        nonlocal url
-        nonlocal r
-        try:
-
-            def stream_content():
-                for chunk in r.iter_content(chunk_size=8192):
-                    yield chunk
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/push",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(
+        f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
+    )


class CreateModelForm(BaseModel):
@@ -461,53 +390,9 @@ async def create_model(
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

-    r = None
-
-    def get_request():
-        nonlocal url
-        nonlocal r
-        try:
-
-            def stream_content():
-                for chunk in r.iter_content(chunk_size=8192):
-                    yield chunk
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/create",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            log.debug(f"r: {r}")
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(
+        f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
+    )


class CopyModelForm(BaseModel):
@@ -797,66 +682,9 @@ async def generate_completion(
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

-    r = None
-
-    def get_request():
-        nonlocal form_data
-        nonlocal r
-
-        request_id = str(uuid.uuid4())
-        try:
-            REQUEST_POOL.append(request_id)
-
-            def stream_content():
-                try:
-                    if form_data.stream:
-                        yield json.dumps({"id": request_id, "done": False}) + "\n"
-
-                    for chunk in r.iter_content(chunk_size=8192):
-                        if request_id in REQUEST_POOL:
-                            yield chunk
-                        else:
-                            log.warning("User: canceled request")
-                            break
-                finally:
-                    if hasattr(r, "close"):
-                        r.close()
-                        if request_id in REQUEST_POOL:
-                            REQUEST_POOL.remove(request_id)
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/generate",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(
+        f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
+    )


class ChatMessage(BaseModel):
@@ -981,67 +809,7 @@ async def generate_chat_completion(

    print(payload)

-    r = None
-
-    def get_request():
-        nonlocal payload
-        nonlocal r
-
-        request_id = str(uuid.uuid4())
-        try:
-            REQUEST_POOL.append(request_id)
-
-            def stream_content():
-                try:
-                    if payload.get("stream", True):
-                        yield json.dumps({"id": request_id, "done": False}) + "\n"
-
-                    for chunk in r.iter_content(chunk_size=8192):
-                        if request_id in REQUEST_POOL:
-                            yield chunk
-                        else:
-                            log.warning("User: canceled request")
-                            break
-                finally:
-                    if hasattr(r, "close"):
-                        r.close()
-                        if request_id in REQUEST_POOL:
-                            REQUEST_POOL.remove(request_id)
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/chat",
-                data=json.dumps(payload),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            log.exception(e)
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(f"{url}/api/chat", json.dumps(payload))


# TODO: we should update this part once Ollama supports other types
@@ -1132,68 +900,7 @@ async def generate_openai_chat_completion(
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

-    r = None
-
-    def get_request():
-        nonlocal payload
-        nonlocal r
-
-        request_id = str(uuid.uuid4())
-        try:
-            REQUEST_POOL.append(request_id)
-
-            def stream_content():
-                try:
-                    if payload.get("stream"):
-                        yield json.dumps(
-                            {"request_id": request_id, "done": False}
-                        ) + "\n"
-
-                    for chunk in r.iter_content(chunk_size=8192):
-                        if request_id in REQUEST_POOL:
-                            yield chunk
-                        else:
-                            log.warning("User: canceled request")
-                            break
-                finally:
-                    if hasattr(r, "close"):
-                        r.close()
-                        if request_id in REQUEST_POOL:
-                            REQUEST_POOL.remove(request_id)
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/v1/chat/completions",
-                data=json.dumps(payload),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(f"{url}/v1/chat/completions", json.dumps(payload))


@app.get("/v1/models")