@@ -345,108 +345,156 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
)


-@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
-async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
+@app.post("/chat/completions")
+@app.post("/chat/completions/{url_idx}")
+async def generate_chat_completion(
+    form_data: dict,
+    url_idx: Optional[int] = None,
+    user=Depends(get_verified_user),
+):
idx = 0
+    payload = {**form_data}

-    body = await request.body()
-    # TODO: Remove below after gpt-4-vision fix from Open AI
-    # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
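+    # Look up this model's stored entry so any per-model parameter overrides can be applied to the payload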
+    model_id = form_data.get("model")
+    model_info = Models.get_model_by_id(model_id)

-    payload = None
+    if model_info:
+        print(model_info)
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id

-    try:
-        if "chat/completions" in path:
-            body = body.decode("utf-8")
-            body = json.loads(body)
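+        # params is a Pydantic model; dump it to a plain dict before reading individual overrides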
+        model_info.params = model_info.params.model_dump()

-            payload = {**body}
+        if model_info.params:
+            if model_info.params.get("temperature", None) is not None:
+                payload["temperature"] = float(model_info.params.get("temperature"))

-            model_id = body.get("model")
-            model_info = Models.get_model_by_id(model_id)
+            if model_info.params.get("top_p", None):
+                payload["top_p"] = int(model_info.params.get("top_p", None))

-            if model_info:
-                print(model_info)
-                if model_info.base_model_id:
-                    payload["model"] = model_info.base_model_id
+            if model_info.params.get("max_tokens", None):
+                payload["max_tokens"] = int(model_info.params.get("max_tokens", None))

-                model_info.params = model_info.params.model_dump()
+            if model_info.params.get("frequency_penalty", None):
+                payload["frequency_penalty"] = int(
+                    model_info.params.get("frequency_penalty", None)
+                )
+
+            if model_info.params.get("seed", None):
+                payload["seed"] = model_info.params.get("seed", None)
+
+            if model_info.params.get("stop", None):
+                payload["stop"] = (
+                    [
+                        bytes(stop, "utf-8").decode("unicode_escape")
+                        for stop in model_info.params["stop"]
+                    ]
+                    if model_info.params.get("stop", None)
+                    else None
+                )

-                if model_info.params:
-                    if model_info.params.get("temperature", None) is not None:
-                        payload["temperature"] = float(
-                            model_info.params.get("temperature")
+            if model_info.params.get("system", None):
+                # Check if the payload already has a system message
+                # If not, add a system message to the payload
+                if payload.get("messages"):
+                    for message in payload["messages"]:
+                        if message.get("role") == "system":
+                            message["content"] = (
+                                model_info.params.get("system", None) + message["content"]
)
+                            break
+                        else:
+                            payload["messages"].insert(
+                                0,
+                                {
+                                    "role": "system",
+                                    "content": model_info.params.get("system", None),
+                                },
+                            )

-                    if model_info.params.get("top_p", None):
-                        payload["top_p"] = int(model_info.params.get("top_p", None))
+    else:
+        pass

-                    if model_info.params.get("max_tokens", None):
-                        payload["max_tokens"] = int(
-                            model_info.params.get("max_tokens", None)
-                        )
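+    # Resolve which of the configured OpenAI-compatible base URLs serves this model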
+    model = app.state.MODELS[payload.get("model")]
+    idx = model["urlIdx"]

-                    if model_info.params.get("frequency_penalty", None):
-                        payload["frequency_penalty"] = int(
-                            model_info.params.get("frequency_penalty", None)
-                        )
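+    # Models served by a pipeline get the requesting user's name and id added to the payload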
+    if "pipeline" in model and model.get("pipeline"):
+        payload["user"] = {"name": user.name, "id": user.id}

-                    if model_info.params.get("seed", None):
-                        payload["seed"] = model_info.params.get("seed", None)
-
-                    if model_info.params.get("stop", None):
-                        payload["stop"] = (
-                            [
-                                bytes(stop, "utf-8").decode("unicode_escape")
-                                for stop in model_info.params["stop"]
-                            ]
-                            if model_info.params.get("stop", None)
-                            else None
-                        )
+    # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
+    # This is a workaround until OpenAI fixes the issue with this model
+    if payload.get("model") == "gpt-4-vision-preview":
+        if "max_tokens" not in payload:
+            payload["max_tokens"] = 4000
+            log.debug("Modified payload:", payload)

-                    if model_info.params.get("system", None):
-                        # Check if the payload already has a system message
-                        # If not, add a system message to the payload
-                        if payload.get("messages"):
-                            for message in payload["messages"]:
-                                if message.get("role") == "system":
-                                    message["content"] = (
-                                        model_info.params.get("system", None)
-                                        + message["content"]
-                                    )
-                                    break
-                            else:
-                                payload["messages"].insert(
-                                    0,
-                                    {
-                                        "role": "system",
-                                        "content": model_info.params.get("system", None),
-                                    },
-                                )
-            else:
-                pass
+    # Convert the modified body back to JSON
+    payload = json.dumps(payload)
+
+    print(payload)

-            model = app.state.MODELS[payload.get("model")]
+    url = app.state.config.OPENAI_API_BASE_URLS[idx]
+    key = app.state.config.OPENAI_API_KEYS[idx]

-            idx = model["urlIdx"]
+    print(payload)

-            if "pipeline" in model and model.get("pipeline"):
-                payload["user"] = {"name": user.name, "id": user.id}
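+    # Build the outbound headers using the API key configured for the selected backend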
+    headers = {}
+    headers["Authorization"] = f"Bearer {key}"
+    headers["Content-Type"] = "application/json"

-            # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
-            # This is a workaround until OpenAI fixes the issue with this model
-            if payload.get("model") == "gpt-4-vision-preview":
-                if "max_tokens" not in payload:
-                    payload["max_tokens"] = 4000
-                    log.debug("Modified payload:", payload)
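+    # Track the response and session so the finally block can close them for non-streaming requests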
+    r = None
+    session = None
+    streaming = False

-            # Convert the modified body back to JSON
-            payload = json.dumps(payload)
+    try:
+        session = aiohttp.ClientSession(trust_env=True)
+        r = await session.request(
+            method="POST",
+            url=f"{url}/chat/completions",
+            data=payload,
+            headers=headers,
+        )

-    except json.JSONDecodeError as e:
-        log.error("Error loading request body into a dictionary:", e)
+        r.raise_for_status()

-    print(payload)
+        # Check if response is SSE
+        if "text/event-stream" in r.headers.get("Content-Type", ""):
+            streaming = True
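+            # Stream the upstream body to the client; cleanup_response runs as a background task once the response finishes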
+            return StreamingResponse(
+                r.content,
+                status_code=r.status,
+                headers=dict(r.headers),
+                background=BackgroundTask(
+                    cleanup_response, response=r, session=session
+                ),
+            )
+        else:
+            response_data = await r.json()
+            return response_data
+    except Exception as e:
+        log.exception(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = await r.json()
+                print(res)
+                if "error" in res:
+                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
+            except:
+                error_detail = f"External: {e}"
+        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
+    finally:
+        if not streaming and session:
+            if r:
+                r.close()
+            await session.close()
+
+
+@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
+async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
+    idx = 0
+
+    body = await request.body()
url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx]
@@ -466,7 +514,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
r = await session.request(
method=request.method,
url=target_url,
- data=payload if payload else body,
+ data=body,
headers=headers,
)