@@ -875,15 +875,88 @@ async def generate_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    model_id = get_model_id_from_custom_model_id(form_data.model)
-    model = model_id
+
+    log.debug(
+        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
+            form_data.model_dump_json(exclude_none=True).encode()
+        )
+    )
+
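+    # Start from the validated request body; stored model params are merged in below.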
+    payload = {
+        **form_data.model_dump(exclude_none=True),
+    }
+
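+    # Look up the stored model definition (custom models may override a base model).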
+    model_id = form_data.model
+    model_info = Models.get_model_by_id(model_id)
+
+    if model_info:
+        print(model_info)
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id
+
+        model_info.params = model_info.params.model_dump()
+
+        if model_info.params:
+            payload["options"] = {}
+
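+            # Copy stored params into Ollama's "options" field, translating the
+            # OpenAI-style names (frequency_penalty, max_tokens) to their Ollama
+            # equivalents (repeat_penalty, num_predict).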
+            payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
+            payload["options"]["mirostat_eta"] = model_info.params.get(
+                "mirostat_eta", None
+            )
+            payload["options"]["mirostat_tau"] = model_info.params.get(
+                "mirostat_tau", None
+            )
+            payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
+
+            payload["options"]["repeat_last_n"] = model_info.params.get(
+                "repeat_last_n", None
+            )
+            payload["options"]["repeat_penalty"] = model_info.params.get(
+                "frequency_penalty", None
+            )
+
+            payload["options"]["temperature"] = model_info.params.get(
+                "temperature", None
+            )
+            payload["options"]["seed"] = model_info.params.get("seed", None)
+
+            # TODO: add "stop" back in
+            # payload["stop"] = model_info.params.get("stop", None)
+
+            payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
+
+            payload["options"]["num_predict"] = model_info.params.get(
+                "max_tokens", None
+            )
+            payload["options"]["top_k"] = model_info.params.get("top_k", None)
+
+            payload["options"]["top_p"] = model_info.params.get("top_p", None)
+
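+        # Prepend the stored system prompt to an existing system message, or
+        # insert one at the head of the conversation if none is present.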
+        if model_info.params.get("system", None):
+            # Check if the payload already has a system message
+            # If not, add a system message to the payload
+            if payload.get("messages"):
+                for message in payload["messages"]:
+                    if message.get("role") == "system":
+                        message["content"] = (
+                            model_info.params.get("system", None) + message["content"]
+                        )
+                        break
+                else:
+                    payload["messages"].insert(
+                        0,
+                        {
+                            "role": "system",
+                            "content": model_info.params.get("system", None),
+                        },
+                    )
 
     if url_idx == None:
-        if ":" not in model:
-            model = f"{model}:latest"
+        if ":" not in payload["model"]:
+            payload["model"] = f"{payload['model']}:latest"
 
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
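+        # Choose one of the base URLs serving this model at random.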
+        if payload["model"] in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
         else:
             raise HTTPException(
                 status_code=400,
@@ -893,23 +966,12 @@ async def generate_chat_completion(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
-    r = None
-
-    # payload = {
-    #     **form_data.model_dump_json(exclude_none=True).encode(),
-    #     "model": model,
-    #     "messages": form_data.messages,
-    # }
-
-    log.debug(
-        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
-            form_data.model_dump_json(exclude_none=True).encode()
-        )
-    )
+    print(payload)
+
+    r = None
 
     def get_request():
-        nonlocal form_data
+        nonlocal payload
         nonlocal r
 
         request_id = str(uuid.uuid4())
@@ -918,7 +980,7 @@ async def generate_chat_completion(
 
         def stream_content():
             try:
-                if form_data.stream:
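+                # The stream flag is now read from the merged payload dict.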
+                if payload.get("stream", None):
                     yield json.dumps({"id": request_id, "done": False}) + "\n"
 
                 for chunk in r.iter_content(chunk_size=8192):
@@ -936,7 +998,7 @@ async def generate_chat_completion(
 
             r = requests.request(
                 method="POST",
                 url=f"{url}/api/chat",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
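+                # Send the merged payload rather than the raw form body.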
+                data=json.dumps(payload),
                 stream=True,
             )
@@ -992,14 +1054,56 @@ async def generate_openai_chat_completion(
     user=Depends(get_verified_user),
 ):
-    if url_idx == None:
-        model = form_data.model
+    payload = {
+        **form_data.model_dump(exclude_none=True),
+    }
 
-        if ":" not in model:
-            model = f"{model}:latest"
+    model_id = form_data.model
+    model_info = Models.get_model_by_id(model_id)
 
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
+    if model_info:
+        print(model_info)
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id
+
+        model_info.params = model_info.params.model_dump()
+
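+        # OpenAI-compatible endpoint: stored params map straight onto top-level
+        # request keys instead of an Ollama "options" dict.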
+        if model_info.params:
+            payload["temperature"] = model_info.params.get("temperature", None)
+            payload["top_p"] = model_info.params.get("top_p", None)
+            payload["max_tokens"] = model_info.params.get("max_tokens", None)
+            payload["frequency_penalty"] = model_info.params.get(
+                "frequency_penalty", None
+            )
+            payload["seed"] = model_info.params.get("seed", None)
+            # TODO: add "stop" back in
+            # payload["stop"] = model_info.params.get("stop", None)
+
+        if model_info.params.get("system", None):
+            # Check if the payload already has a system message
+            # If not, add a system message to the payload
+            if payload.get("messages"):
+                for message in payload["messages"]:
+                    if message.get("role") == "system":
+                        message["content"] = (
+                            model_info.params.get("system", None) + message["content"]
+                        )
+                        break
+                else:
+                    payload["messages"].insert(
+                        0,
+                        {
+                            "role": "system",
+                            "content": model_info.params.get("system", None),
+                        },
+                    )
+
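+    # Only auto-select a backend when the caller didn't pin one via url_idx.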
+    if url_idx == None:
+        if ":" not in payload["model"]:
+            payload["model"] = f"{payload['model']}:latest"
+
+        if payload["model"] in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
         else:
             raise HTTPException(
                 status_code=400,
@@ -1012,7 +1116,7 @@ async def generate_openai_chat_completion(
     r = None
 
     def get_request():
-        nonlocal form_data
+        nonlocal payload
        nonlocal r
 
         request_id = str(uuid.uuid4())
@@ -1021,7 +1125,7 @@ async def generate_openai_chat_completion(
 
         def stream_content():
             try:
-                if form_data.stream:
+                if payload.get("stream"):
                     yield json.dumps(
                         {"request_id": request_id, "done": False}
                     ) + "\n"
@@ -1041,7 +1145,7 @@ async def generate_openai_chat_completion(
 
             r = requests.request(
                 method="POST",
                 url=f"{url}/v1/chat/completions",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
+                data=json.dumps(payload),
                 stream=True,
             )