@@ -1,6 +1,6 @@
-from fastapi import FastAPI, Request, Response, HTTPException, Depends
+from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
+from fastapi.responses import StreamingResponse, FileResponse

 import requests
 import aiohttp
@@ -12,16 +12,12 @@ from pydantic import BaseModel
 from starlette.background import BackgroundTask

 from apps.webui.models.models import Models
-from apps.webui.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import (
-    decode_token,
-    get_verified_user,
     get_verified_user,
     get_admin_user,
 )
-from utils.task import prompt_template
-from utils.misc import add_or_update_system_message
+from utils.misc import apply_model_params_to_body, apply_model_system_prompt_to_body

 from config import (
     SRC_LOG_LEVELS,
@@ -69,8 +65,6 @@ app.state.MODELS = {}
 async def check_url(request: Request, call_next):
     if len(app.state.MODELS) == 0:
         await get_all_models()
-    else:
-        pass

     response = await call_next(request)
     return response
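The lazy cache fill above only makes sense once the function is registered as HTTP middleware. A minimal self-contained sketch of that wiring, assuming FastAPI's standard decorator and a stubbed model loader (the stub is not the file's real loader):

from fastapi import FastAPI, Request

app = FastAPI()
app.state.MODELS = {}

async def get_all_models():
    # Stub standing in for the file's real loader, which fetches /models
    # from every configured OpenAI-compatible endpoint.
    app.state.MODELS = {"demo-model": {"id": "demo-model"}}

@app.middleware("http")
async def check_url(request: Request, call_next):
    # Populate the model cache lazily, on the first request after startup.
    if len(app.state.MODELS) == 0:
        await get_all_models()
    return await call_next(request)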
@@ -175,7 +169,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
                 res = r.json()
                 if "error" in res:
                     error_detail = f"External: {res['error']}"
-            except:
+            except Exception:
                 error_detail = f"External: {e}"

         raise HTTPException(
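Narrowing the bare `except:` is safe here because the `e` interpolated in the fallback is bound by an enclosing `except Exception as e:` handler, not by the inner clause. A reduced sketch of the nested pattern (hypothetical helper name; `r` is any response object exposing `.json()`):

def external_error_detail(r, e: Exception) -> str:
    # e comes from the *outer* handler; the inner "except Exception:"
    # deliberately leaves it unbound and just falls back to its message.
    detail = "Open WebUI: Server Connection Error"
    if r is not None:
        try:
            res = r.json()
            if "error" in res:
                detail = f"External: {res['error']}"
        except Exception:
            # body was not JSON; report the original exception instead
            detail = f"External: {e}"
    return detail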
@@ -234,64 +228,58 @@ def merge_models_lists(model_lists):
     return merged_list


-async def get_all_models(raw: bool = False):
+def is_openai_api_disabled():
+    api_keys = app.state.config.OPENAI_API_KEYS
+    no_keys = len(api_keys) == 1 and api_keys[0] == ""
+    return no_keys or not app.state.config.ENABLE_OPENAI_API
+
+
+async def get_all_models_raw() -> list:
+    if is_openai_api_disabled():
+        return []
+
+    # Check that the number of API keys matches the number of API URLs
+    num_urls = len(app.state.config.OPENAI_API_BASE_URLS)
+    num_keys = len(app.state.config.OPENAI_API_KEYS)
+
+    if num_keys != num_urls:
+        # if there are more keys than urls, remove the extra keys
+        if num_keys > num_urls:
+            new_keys = app.state.config.OPENAI_API_KEYS[:num_urls]
+            app.state.config.OPENAI_API_KEYS = new_keys
+        # if there are more urls than keys, add empty keys
+        else:
+            app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
+
+    tasks = [
+        fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
+        for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
+    ]
+
+    responses = await asyncio.gather(*tasks)
+    log.debug(f"get_all_models:responses() {responses}")
+
+    return responses
+
+
+async def get_all_models() -> dict[str, list]:
     log.info("get_all_models()")
+    if is_openai_api_disabled():
+        return {"data": []}

-    if (
-        len(app.state.config.OPENAI_API_KEYS) == 1
-        and app.state.config.OPENAI_API_KEYS[0] == ""
-    ) or not app.state.config.ENABLE_OPENAI_API:
-        models = {"data": []}
-    else:
-        # Check if API KEYS length is same than API URLS length
-        if len(app.state.config.OPENAI_API_KEYS) != len(
-            app.state.config.OPENAI_API_BASE_URLS
-        ):
-            # if there are more keys than urls, remove the extra keys
-            if len(app.state.config.OPENAI_API_KEYS) > len(
-                app.state.config.OPENAI_API_BASE_URLS
-            ):
-                app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
-                    : len(app.state.config.OPENAI_API_BASE_URLS)
-                ]
-            # if there are more urls than keys, add empty keys
-            else:
-                app.state.config.OPENAI_API_KEYS += [
-                    ""
-                    for _ in range(
-                        len(app.state.config.OPENAI_API_BASE_URLS)
-                        - len(app.state.config.OPENAI_API_KEYS)
-                    )
-                ]

+    responses = await get_all_models_raw()

-        tasks = [
-            fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
-            for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
-        ]
-
-        responses = await asyncio.gather(*tasks)
-        log.debug(f"get_all_models:responses() {responses}")
-
-        if raw:
-            return responses
-
-        models = {
-            "data": merge_models_lists(
-                list(
-                    map(
-                        lambda response: (
-                            response["data"]
-                            if (response and "data" in response)
-                            else (response if isinstance(response, list) else None)
-                        ),
-                        responses,
-                    )
-                )
-            )
-        }
+    def extract_data(response):
+        if response and "data" in response:
+            return response["data"]
+        if isinstance(response, list):
+            return response
+        return None

-        log.debug(f"models: {models}")
-        app.state.MODELS = {model["id"]: model for model in models["data"]}
+    models = {"data": merge_models_lists(map(extract_data, responses))}
+
+    log.debug(f"models: {models}")
+    app.state.MODELS = {model["id"]: model for model in models["data"]}

     return models
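For reference, a sketch of how the new pipeline normalizes what get_all_models_raw() returns before merging; the sample payloads are illustrative only, and merge_models_lists (which flattens the lists and tags each model with its source urlIdx) is assumed from the surrounding file:

# Illustrative fake /models responses: one OpenAI-style dict, one bare
# list, and one failed fetch (fetch_url is assumed to yield None on error).
responses = [
    {"data": [{"id": "gpt-4o"}, {"id": "gpt-4o-mini"}]},  # OpenAI-style body
    [{"id": "llama3"}],                                    # bare list body
    None,                                                  # unreachable endpoint
]

def extract_data(response):
    if response and "data" in response:
        return response["data"]
    if isinstance(response, list):
        return response
    return None

print([extract_data(r) for r in responses])
# [[{'id': 'gpt-4o'}, {'id': 'gpt-4o-mini'}], [{'id': 'llama3'}], None]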
@@ -299,7 +287,7 @@ async def get_all_models(raw: bool = False):
 @app.get("/models")
 @app.get("/models/{url_idx}")
 async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
-    if url_idx == None:
+    if url_idx is None:
         models = await get_all_models()
         if app.state.config.ENABLE_MODEL_FILTER:
             if user.role == "user":
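The `== None` to `is None` change is more than style: `==` dispatches to `__eq__`, which a class can override, while `is` tests identity. A tiny illustration:

class AlwaysEqual:
    def __eq__(self, other):
        return True  # == comparisons against anything succeed

x = AlwaysEqual()
print(x == None)  # True -- an overridden __eq__ makes the comparison lie
print(x is None)  # False -- identity check, the PEP 8 idiom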
@@ -340,7 +328,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
                 res = r.json()
                 if "error" in res:
                     error_detail = f"External: {res['error']}"
-            except:
+            except Exception:
                 error_detail = f"External: {e}"

             raise HTTPException(
@@ -358,8 +346,7 @@ async def generate_chat_completion(
 ):
     idx = 0
     payload = {**form_data}
-    if "metadata" in payload:
-        del payload["metadata"]
+    payload.pop("metadata", None)

     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
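`dict.pop` needs the `None` default to mirror the removed membership guard; without a second argument it raises `KeyError` whenever the client omits the field. For example:

payload = {"model": "gpt-4o"}  # no "metadata" key
payload.pop("metadata", None)  # returns None; payload unchanged
# payload.pop("metadata")      # would raise KeyError here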
@@ -368,70 +355,9 @@ async def generate_chat_completion(
         if model_info.base_model_id:
             payload["model"] = model_info.base_model_id

-        model_info.params = model_info.params.model_dump()
-
-        if model_info.params:
-            if (
-                model_info.params.get("temperature", None) is not None
-                and payload.get("temperature") is None
-            ):
-                payload["temperature"] = float(model_info.params.get("temperature"))
-
-            if model_info.params.get("top_p", None) and payload.get("top_p") is None:
-                payload["top_p"] = int(model_info.params.get("top_p", None))
-
-            if (
-                model_info.params.get("max_tokens", None)
-                and payload.get("max_tokens") is None
-            ):
-                payload["max_tokens"] = int(model_info.params.get("max_tokens", None))
-
-            if (
-                model_info.params.get("frequency_penalty", None)
-                and payload.get("frequency_penalty") is None
-            ):
-                payload["frequency_penalty"] = int(
-                    model_info.params.get("frequency_penalty", None)
-                )
-
-            if (
-                model_info.params.get("seed", None) is not None
-                and payload.get("seed") is None
-            ):
-                payload["seed"] = model_info.params.get("seed", None)
-
-            if model_info.params.get("stop", None) and payload.get("stop") is None:
-                payload["stop"] = (
-                    [
-                        bytes(stop, "utf-8").decode("unicode_escape")
-                        for stop in model_info.params["stop"]
-                    ]
-                    if model_info.params.get("stop", None)
-                    else None
-                )
-
-            system = model_info.params.get("system", None)
-            if system:
-                system = prompt_template(
-                    system,
-                    **(
-                        {
-                            "user_name": user.name,
-                            "user_location": (
-                                user.info.get("location") if user.info else None
-                            ),
-                        }
-                        if user
-                        else {}
-                    ),
-                )
-                if payload.get("messages"):
-                    payload["messages"] = add_or_update_system_message(
-                        system, payload["messages"]
-                    )
-
-    else:
-        pass
+        params = model_info.params.model_dump()
+        payload = apply_model_params_to_body(params, payload)
+        payload = apply_model_system_prompt_to_body(params, payload, user)

     model = app.state.MODELS[payload.get("model")]
     idx = model["urlIdx"]
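The inlined parameter and system-prompt handling now lives in utils.misc. This diff does not show the real implementations; the following is a hedged sketch of the behavior the removed block delegates to them, with the prompt-template expansion elided:

# Sketch only: mirrors the deleted inline logic, not the actual utils.misc code.
def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
    if not params:
        return form_data
    # Per-model defaults fill gaps only; explicit client values win.
    for key in ("temperature", "top_p", "max_tokens", "frequency_penalty", "seed", "stop"):
        if params.get(key) is not None and form_data.get(key) is None:
            form_data[key] = params[key]
    return form_data

def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
    system = (params or {}).get("system")
    if not system or not form_data.get("messages"):
        return form_data
    # The real helper first renders the template against the requesting
    # user's profile (name, location); that step is elided here.
    messages = form_data["messages"]
    if messages and messages[0].get("role") == "system":
        messages[0]["content"] = system + "\n" + messages[0]["content"]
    else:
        messages.insert(0, {"role": "system", "content": system})
    return form_data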
@@ -506,7 +432,7 @@ async def generate_chat_completion(
                 print(res)
                 if "error" in res:
                     error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
-            except:
+            except Exception:
                 error_detail = f"External: {e}"
         raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
     finally:
@@ -569,7 +495,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
                 print(res)
                 if "error" in res:
                     error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
-            except:
+            except Exception:
                 error_detail = f"External: {e}"
         raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
     finally: