@@ -1009,9 +1009,12 @@ async def get_body_and_model_and_user(request, models):
 
 class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
-        if not request.method == "POST" and any(
-            endpoint in request.url.path
-            for endpoint in ["/ollama/api/chat", "/chat/completions"]
+        if not (
+            request.method == "POST"
+            and any(
+                endpoint in request.url.path
+                for endpoint in ["/ollama/api/chat", "/chat/completions"]
+            )
         ):
             return await call_next(request)
         log.debug(f"request.url.path: {request.url.path}")
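
This hunk fixes an operator-precedence bug: the old condition not request.method == "POST" and any(...) parses as (request.method != "POST") and any(...), so a POST to a non-chat path never took the early return and fell through into the chat-completion handling. A standalone sketch of the two guards, with hypothetical request values:

    # Sketch of the precedence bug, using hypothetical request values.
    method, path = "POST", "/api/v1/files"  # a POST to a non-chat endpoint
    endpoints = ["/ollama/api/chat", "/chat/completions"]

    # Old guard: `not` binds to the comparison only, i.e.
    # (method != "POST") and any(...). For this request it is False,
    # so the middleware did NOT pass the request through early.
    old_guard = not method == "POST" and any(e in path for e in endpoints)

    # New guard: pass through everything except a POST to a chat endpoint.
    new_guard = not (method == "POST" and any(e in path for e in endpoints))

    print(old_guard, new_guard)  # False True
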
@@ -1214,9 +1217,12 @@ app.add_middleware(ChatCompletionMiddleware)
 
 class PipelineMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
-        if not request.method == "POST" and any(
-            endpoint in request.url.path
-            for endpoint in ["/ollama/api/chat", "/chat/completions"]
+        if not (
+            request.method == "POST"
+            and any(
+                endpoint in request.url.path
+                for endpoint in ["/ollama/api/chat", "/chat/completions"]
+            )
         ):
             return await call_next(request)
 
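
PipelineMiddleware receives the same corrected guard as ChatCompletionMiddleware above. With the predicate now duplicated verbatim in both dispatch methods, one way to keep the two in sync would be a shared helper; this is a sketch only, and is_chat_completion_request is a hypothetical name, not part of this diff:

    from fastapi import Request

    CHAT_ENDPOINTS = ["/ollama/api/chat", "/chat/completions"]

    def is_chat_completion_request(request: Request) -> bool:
        # True only for a POST to one of the chat endpoints.
        return request.method == "POST" and any(
            endpoint in request.url.path for endpoint in CHAT_ENDPOINTS
        )

Both middlewares could then open with: if not is_chat_completion_request(request): return await call_next(request).
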
@@ -1664,17 +1670,17 @@ async def generate_function_chat_completion(form_data, user, models: dict = {}):
     return openai_chat_completion_message_template(form_data["model"], message)
 
 
-async def get_all_base_models():
+async def get_all_base_models(request):
     function_models = []
     openai_models = []
     ollama_models = []
 
     if app.state.config.ENABLE_OPENAI_API:
-        openai_models = await openai.get_all_models()
+        openai_models = await openai.get_all_models(request)
         openai_models = openai_models["data"]
 
     if app.state.config.ENABLE_OLLAMA_API:
-        ollama_models = await ollama.get_all_models()
+        ollama_models = await ollama.get_all_models(request)
         ollama_models = [
             {
                 "id": model["model"],
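
The remaining hunks thread the incoming request object through the model-listing call chain and into the openai and ollama modules. A common motive for this pattern is letting nested helpers reach per-request and application state without module-level imports; a minimal sketch under that assumption (the state attribute and return values are illustrative, not taken from this diff):

    from fastapi import FastAPI, Request

    app = FastAPI()
    app.state.config = {"ENABLE_OPENAI_API": True}  # illustrative state

    async def get_all_models(request: Request) -> list:
        # request.app is a handle back to the FastAPI application, so a
        # helper can read shared config without importing the app module.
        config = request.app.state.config
        return ["model-a"] if config["ENABLE_OPENAI_API"] else []

    @app.get("/api/models")
    async def get_models(request: Request) -> list:
        return await get_all_models(request)
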
@@ -1729,8 +1735,8 @@ async def get_all_base_models():
 
 
 @cached(ttl=3)
-async def get_all_models():
-    models = await get_all_base_models()
+async def get_all_models(request):
+    models = await get_all_base_models(request)
 
     # If there are no models, return an empty list
     if len([model for model in models if not model.get("arena", False)]) == 0:
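
One interaction worth verifying here: get_all_models keeps its @cached(ttl=3) decorator while gaining a request argument. If @cached is aiocache's decorator, the default cache key is derived from the function's arguments, and a fresh Request object on every call would give each call a distinct key, quietly disabling the 3-second cache. A hedged sketch of one mitigation using aiocache's key_builder hook (the builder name is hypothetical):

    from aiocache import cached

    def _ignore_request_key_builder(func, *args, **kwargs):
        # Hypothetical key builder: key on the function name alone so the
        # per-call Request object never enters the cache key.
        return func.__name__

    @cached(ttl=3, key_builder=_ignore_request_key_builder)
    async def get_all_models(request):
        ...
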
@@ -1859,8 +1865,8 @@ async def get_all_models():
 
 
 @app.get("/api/models")
-async def get_models(user=Depends(get_verified_user)):
-    models = await get_all_models()
+async def get_models(request: Request, user=Depends(get_verified_user)):
+    models = await get_all_models(request)
 
     # Filter out filter pipelines
     models = [
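
In the route handler, request: Request is added alongside the existing dependency. FastAPI injects the live Request for any parameter annotated with that type; no Depends() wrapper is needed for it. A self-contained sketch (the stub dependency stands in for the real auth check):

    from fastapi import Depends, FastAPI, Request

    app = FastAPI()

    async def get_verified_user():
        # Stub for illustration; the real dependency validates the session.
        return {"role": "user"}

    @app.get("/api/models")
    async def get_models(request: Request, user=Depends(get_verified_user)):
        # request is injected by its type annotation; user via Depends().
        return {"path": request.url.path, "role": user["role"]}
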
@@ -2042,7 +2048,7 @@ async def generate_chat_completions(
 async def chat_completed(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-    model_list = await get_all_models()
+    model_list = await get_all_models(request)
     models = {model["id"]: model for model in model_list}
 
     data = form_data