瀏覽代碼

feat: preset backend logic

Timothy J. Baek 11 月之前
父節點
當前提交
88d053833d
共有 2 個文件被更改,包括 208 次插入58 次删除
  1. 135 31
      backend/apps/ollama/main.py
  2. 73 27
      backend/apps/openai/main.py

+ 135 - 31
backend/apps/ollama/main.py

@@ -875,15 +875,88 @@ async def generate_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    model_id = get_model_id_from_custom_model_id(form_data.model)
-    model = model_id
+
+    log.debug(
+        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
+            form_data.model_dump_json(exclude_none=True).encode()
+        )
+    )
+
+    payload = {
+        **form_data.model_dump(exclude_none=True),
+    }
+
+    model_id = form_data.model
+    model_info = Models.get_model_by_id(model_id)
+
+    if model_info:
+        print(model_info)
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id
+
+        model_info.params = model_info.params.model_dump()
+
+        if model_info.params:
+            payload["options"] = {}
+
+            payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
+            payload["options"]["mirostat_eta"] = model_info.params.get(
+                "mirostat_eta", None
+            )
+            payload["options"]["mirostat_tau"] = model_info.params.get(
+                "mirostat_tau", None
+            )
+            payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
+
+            payload["options"]["repeat_last_n"] = model_info.params.get(
+                "repeat_last_n", None
+            )
+            payload["options"]["repeat_penalty"] = model_info.params.get(
+                "frequency_penalty", None
+            )
+
+            payload["options"]["temperature"] = model_info.params.get(
+                "temperature", None
+            )
+            payload["options"]["seed"] = model_info.params.get("seed", None)
+
+            # TODO: add "stop" back in
+            # payload["stop"] = model_info.params.get("stop", None)
+
+            payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
+
+            payload["options"]["num_predict"] = model_info.params.get(
+                "max_tokens", None
+            )
+            payload["options"]["top_k"] = model_info.params.get("top_k", None)
+
+            payload["options"]["top_p"] = model_info.params.get("top_p", None)
+
+        if model_info.params.get("system", None):
+            # Check if the payload already has a system message
+            # If not, add a system message to the payload
+            if payload.get("messages"):
+                for message in payload["messages"]:
+                    if message.get("role") == "system":
+                        message["content"] = (
+                            model_info.params.get("system", None) + message["content"]
+                        )
+                        break
+                else:
+                    payload["messages"].insert(
+                        0,
+                        {
+                            "role": "system",
+                            "content": model_info.params.get("system", None),
+                        },
+                    )
 
     if url_idx == None:
-        if ":" not in model:
-            model = f"{model}:latest"
+        if ":" not in payload["model"]:
+            payload["model"] = f"{payload['model']}:latest"
 
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
+        if payload["model"] in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
         else:
             raise HTTPException(
                 status_code=400,
@@ -893,23 +966,12 @@ async def generate_chat_completion(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
-    r = None
-
-    # payload = {
-    #     **form_data.model_dump_json(exclude_none=True).encode(),
-    #     "model": model,
-    #     "messages": form_data.messages,
+    print(payload)
 
-    # }
-
-    log.debug(
-        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
-            form_data.model_dump_json(exclude_none=True).encode()
-        )
-    )
+    r = None
 
     def get_request():
-        nonlocal form_data
+        nonlocal payload
         nonlocal r
 
         request_id = str(uuid.uuid4())
@@ -918,7 +980,7 @@ async def generate_chat_completion(
 
             def stream_content():
                 try:
-                    if form_data.stream:
+                    if payload.get("stream", None):
                         yield json.dumps({"id": request_id, "done": False}) + "\n"
 
                     for chunk in r.iter_content(chunk_size=8192):
@@ -936,7 +998,7 @@ async def generate_chat_completion(
             r = requests.request(
                 method="POST",
                 url=f"{url}/api/chat",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
+                data=json.dumps(payload),
                 stream=True,
             )
 
@@ -992,14 +1054,56 @@ async def generate_openai_chat_completion(
     user=Depends(get_verified_user),
 ):
 
-    if url_idx == None:
-        model = form_data.model
+    payload = {
+        **form_data.model_dump(exclude_none=True),
+    }
 
-        if ":" not in model:
-            model = f"{model}:latest"
+    model_id = form_data.model
+    model_info = Models.get_model_by_id(model_id)
 
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
+    if model_info:
+        print(model_info)
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id
+
+        model_info.params = model_info.params.model_dump()
+
+        if model_info.params:
+            payload["temperature"] = model_info.params.get("temperature", None)
+            payload["top_p"] = model_info.params.get("top_p", None)
+            payload["max_tokens"] = model_info.params.get("max_tokens", None)
+            payload["frequency_penalty"] = model_info.params.get(
+                "frequency_penalty", None
+            )
+            payload["seed"] = model_info.params.get("seed", None)
+            # TODO: add "stop" back in
+            # payload["stop"] = model_info.params.get("stop", None)
+
+        if model_info.params.get("system", None):
+            # Check if the payload already has a system message
+            # If not, add a system message to the payload
+            if payload.get("messages"):
+                for message in payload["messages"]:
+                    if message.get("role") == "system":
+                        message["content"] = (
+                            model_info.params.get("system", None) + message["content"]
+                        )
+                        break
+                else:
+                    payload["messages"].insert(
+                        0,
+                        {
+                            "role": "system",
+                            "content": model_info.params.get("system", None),
+                        },
+                    )
+
+    if url_idx == None:
+        if ":" not in payload["model"]:
+            payload["model"] = f"{payload['model']}:latest"
+
+        if payload["model"] in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
         else:
             raise HTTPException(
                 status_code=400,
@@ -1012,7 +1116,7 @@ async def generate_openai_chat_completion(
     r = None
 
     def get_request():
-        nonlocal form_data
+        nonlocal payload
         nonlocal r
 
         request_id = str(uuid.uuid4())
@@ -1021,7 +1125,7 @@ async def generate_openai_chat_completion(
 
             def stream_content():
                 try:
-                    if form_data.stream:
+                    if payload.get("stream"):
                         yield json.dumps(
                             {"request_id": request_id, "done": False}
                         ) + "\n"
@@ -1041,7 +1145,7 @@ async def generate_openai_chat_completion(
             r = requests.request(
                 method="POST",
                 url=f"{url}/v1/chat/completions",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
+                data=json.dumps(payload),
                 stream=True,
             )
 

+ 73 - 27
backend/apps/openai/main.py

@@ -315,41 +315,87 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
     body = await request.body()
     # TODO: Remove below after gpt-4-vision fix from Open AI
     # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
+
+    payload = None
+
     try:
-        body = body.decode("utf-8")
-        body = json.loads(body)
+        if "chat/completions" in path:
+            body = body.decode("utf-8")
+            body = json.loads(body)
 
-        print(app.state.MODELS)
+            payload = {**body}
 
-        model = app.state.MODELS[body.get("model")]
+            model_id = body.get("model")
+            model_info = Models.get_model_by_id(model_id)
 
-        idx = model["urlIdx"]
+            if model_info:
+                print(model_info)
+                if model_info.base_model_id:
+                    payload["model"] = model_info.base_model_id
 
-        if "pipeline" in model and model.get("pipeline"):
-            body["user"] = {"name": user.name, "id": user.id}
-            body["title"] = (
-                True if body["stream"] == False and body["max_tokens"] == 50 else False
-            )
+                model_info.params = model_info.params.model_dump()
+
+                if model_info.params:
+                    payload["temperature"] = model_info.params.get("temperature", None)
+                    payload["top_p"] = model_info.params.get("top_p", None)
+                    payload["max_tokens"] = model_info.params.get("max_tokens", None)
+                    payload["frequency_penalty"] = model_info.params.get(
+                        "frequency_penalty", None
+                    )
+                    payload["seed"] = model_info.params.get("seed", None)
+                    # TODO: add "stop" back in
+                    # payload["stop"] = model_info.params.get("stop", None)
+
+                if model_info.params.get("system", None):
+                    # Check if the payload already has a system message
+                    # If not, add a system message to the payload
+                    if payload.get("messages"):
+                        for message in payload["messages"]:
+                            if message.get("role") == "system":
+                                message["content"] = (
+                                    model_info.params.get("system", None)
+                                    + message["content"]
+                                )
+                                break
+                        else:
+                            payload["messages"].insert(
+                                0,
+                                {
+                                    "role": "system",
+                                    "content": model_info.params.get("system", None),
+                                },
+                            )
+            else:
+                pass
+
+            print(app.state.MODELS)
+            model = app.state.MODELS[payload.get("model")]
+
+            idx = model["urlIdx"]
+
+            if "pipeline" in model and model.get("pipeline"):
+                payload["user"] = {"name": user.name, "id": user.id}
+                payload["title"] = (
+                    True
+                    if payload["stream"] == False and payload["max_tokens"] == 50
+                    else False
+                )
+
+            # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
+            # This is a workaround until OpenAI fixes the issue with this model
+            if payload.get("model") == "gpt-4-vision-preview":
+                if "max_tokens" not in payload:
+                    payload["max_tokens"] = 4000
+                log.debug("Modified payload:", payload)
+
+            # Convert the modified body back to JSON
+            payload = json.dumps(payload)
 
-        # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
-        # This is a workaround until OpenAI fixes the issue with this model
-        if body.get("model") == "gpt-4-vision-preview":
-            if "max_tokens" not in body:
-                body["max_tokens"] = 4000
-            log.debug("Modified body_dict:", body)
-
-        # Fix for ChatGPT calls failing because the num_ctx key is in body
-        if "num_ctx" in body:
-            # If 'num_ctx' is in the dictionary, delete it
-            # Leaving it there generates an error with the
-            # OpenAI API (Feb 2024)
-            del body["num_ctx"]
-
-        # Convert the modified body back to JSON
-        body = json.dumps(body)
     except json.JSONDecodeError as e:
         log.error("Error loading request body into a dictionary:", e)
 
+    print(payload)
+
     url = app.state.config.OPENAI_API_BASE_URLS[idx]
     key = app.state.config.OPENAI_API_KEYS[idx]
 
@@ -368,7 +414,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
         r = requests.request(
             method=request.method,
             url=target_url,
-            data=body,
+            data=payload if payload else body,
             headers=headers,
             stream=True,
         )