Timothy J. Baek 10 mesi fa
parent
commit
448ca9d836
2 ha cambiato i file con 132 aggiunte e 89 eliminazioni
  1. 127 85
      backend/main.py
  2. 5 4
      src/lib/components/workspace/Functions/FunctionEditor.svelte

+ 127 - 85
backend/main.py

@@ -170,6 +170,13 @@ app.state.MODELS = {}
 origins = ["*"]
 origins = ["*"]
 
 
 
 
+##################################
+#
+# ChatCompletion Middleware
+#
+##################################
+
+
 async def get_function_call_response(
 async def get_function_call_response(
     messages, files, tool_id, template, task_model_id, user
     messages, files, tool_id, template, task_model_id, user
 ):
 ):
@@ -469,6 +476,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 
 
 app.add_middleware(ChatCompletionMiddleware)
 app.add_middleware(ChatCompletionMiddleware)
 
 
+##################################
+#
+# Pipeline Middleware
+#
+##################################
+
 
 
 def filter_pipeline(payload, user):
 def filter_pipeline(payload, user):
     user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
     user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
@@ -628,7 +641,6 @@ async def update_embedding_function(request: Request, call_next):
 
 
 app.mount("/ws", socket_app)
 app.mount("/ws", socket_app)
 
 
-
 app.mount("/ollama", ollama_app)
 app.mount("/ollama", ollama_app)
 app.mount("/openai", openai_app)
 app.mount("/openai", openai_app)
 
 
@@ -730,6 +742,104 @@ async def get_models(user=Depends(get_verified_user)):
     return {"data": models}
     return {"data": models}
 
 
 
 
+@app.post("/api/chat/completions")
+async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
+    model_id = form_data["model"]
+    if model_id not in app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+
+    model = app.state.MODELS[model_id]
+    print(model)
+
+    if model["owned_by"] == "ollama":
+        return await generate_ollama_chat_completion(form_data, user=user)
+    else:
+        return await generate_openai_chat_completion(form_data, user=user)
+
+
+@app.post("/api/chat/completed")
+async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
+    data = form_data
+    model_id = data["model"]
+
+    filters = [
+        model
+        for model in app.state.MODELS.values()
+        if "pipeline" in model
+        and "type" in model["pipeline"]
+        and model["pipeline"]["type"] == "filter"
+        and (
+            model["pipeline"]["pipelines"] == ["*"]
+            or any(
+                model_id == target_model_id
+                for target_model_id in model["pipeline"]["pipelines"]
+            )
+        )
+    ]
+    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
+
+    print(model_id)
+
+    if model_id in app.state.MODELS:
+        model = app.state.MODELS[model_id]
+        if "pipeline" in model:
+            sorted_filters = [model] + sorted_filters
+
+    for filter in sorted_filters:
+        r = None
+        try:
+            urlIdx = filter["urlIdx"]
+
+            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
+            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
+
+            if key != "":
+                headers = {"Authorization": f"Bearer {key}"}
+                r = requests.post(
+                    f"{url}/{filter['id']}/filter/outlet",
+                    headers=headers,
+                    json={
+                        "user": {"id": user.id, "name": user.name, "role": user.role},
+                        "body": data,
+                    },
+                )
+
+                r.raise_for_status()
+                data = r.json()
+        except Exception as e:
+            # Handle connection error here
+            print(f"Connection error: {e}")
+
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "detail" in res:
+                        return JSONResponse(
+                            status_code=r.status_code,
+                            content=res,
+                        )
+                except:
+                    pass
+
+            else:
+                pass
+
+    return data
+
+
+##################################
+#
+# Task Endpoints
+#
+##################################
+
+
+# TODO: Refactor task API endpoints below into a separate file
+
+
 @app.get("/api/task/config")
 @app.get("/api/task/config")
 async def get_task_config(user=Depends(get_verified_user)):
 async def get_task_config(user=Depends(get_verified_user)):
     return {
     return {
@@ -1015,92 +1125,14 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
         )
         )
 
 
 
 
-@app.post("/api/chat/completions")
-async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
-    model_id = form_data["model"]
-    if model_id not in app.state.MODELS:
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail="Model not found",
-        )
-
-    model = app.state.MODELS[model_id]
-    print(model)
-
-    if model["owned_by"] == "ollama":
-        return await generate_ollama_chat_completion(form_data, user=user)
-    else:
-        return await generate_openai_chat_completion(form_data, user=user)
-
-
-@app.post("/api/chat/completed")
-async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
-    data = form_data
-    model_id = data["model"]
-
-    filters = [
-        model
-        for model in app.state.MODELS.values()
-        if "pipeline" in model
-        and "type" in model["pipeline"]
-        and model["pipeline"]["type"] == "filter"
-        and (
-            model["pipeline"]["pipelines"] == ["*"]
-            or any(
-                model_id == target_model_id
-                for target_model_id in model["pipeline"]["pipelines"]
-            )
-        )
-    ]
-    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
-
-    print(model_id)
-
-    if model_id in app.state.MODELS:
-        model = app.state.MODELS[model_id]
-        if "pipeline" in model:
-            sorted_filters = [model] + sorted_filters
-
-    for filter in sorted_filters:
-        r = None
-        try:
-            urlIdx = filter["urlIdx"]
-
-            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
-            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
-
-            if key != "":
-                headers = {"Authorization": f"Bearer {key}"}
-                r = requests.post(
-                    f"{url}/{filter['id']}/filter/outlet",
-                    headers=headers,
-                    json={
-                        "user": {"id": user.id, "name": user.name, "role": user.role},
-                        "body": data,
-                    },
-                )
-
-                r.raise_for_status()
-                data = r.json()
-        except Exception as e:
-            # Handle connection error here
-            print(f"Connection error: {e}")
+##################################
+#
+# Pipelines Endpoints
+#
+##################################
 
 
-            if r is not None:
-                try:
-                    res = r.json()
-                    if "detail" in res:
-                        return JSONResponse(
-                            status_code=r.status_code,
-                            content=res,
-                        )
-                except:
-                    pass
 
 
-            else:
-                pass
-
-    return data
+# TODO: Refactor pipelines API endpoints below into a separate file
 
 
 
 
 @app.get("/api/pipelines/list")
 @app.get("/api/pipelines/list")
@@ -1423,6 +1455,13 @@ async def update_pipeline_valves(
         )
         )
 
 
 
 
+##################################
+#
+# Config Endpoints
+#
+##################################
+
+
 @app.get("/api/config")
 @app.get("/api/config")
 async def get_app_config():
 async def get_app_config():
     # Checking and Handling the Absence of 'ui' in CONFIG_DATA
     # Checking and Handling the Absence of 'ui' in CONFIG_DATA
@@ -1486,6 +1525,9 @@ async def update_model_filter_config(
     }
     }
 
 
 
 
+# TODO: webhook endpoint should be under config endpoints
+
+
 @app.get("/api/webhook")
 @app.get("/api/webhook")
 async def get_webhook_url(user=Depends(get_admin_user)):
 async def get_webhook_url(user=Depends(get_admin_user)):
     return {
     return {

+ 5 - 4
src/lib/components/workspace/Functions/FunctionEditor.svelte

@@ -30,9 +30,10 @@
 	let boilerplate = `from pydantic import BaseModel
 	let boilerplate = `from pydantic import BaseModel
 from typing import Optional
 from typing import Optional
 
 
+
 class Filter:
 class Filter:
     class Valves(BaseModel):
     class Valves(BaseModel):
-        max_turns: int
+        max_turns: int = 4
         pass
         pass
 
 
     def __init__(self):
     def __init__(self):
@@ -42,14 +43,14 @@ class Filter:
 
 
         # Initialize 'valves' with specific configurations. Using 'Valves' instance helps encapsulate settings,
         # Initialize 'valves' with specific configurations. Using 'Valves' instance helps encapsulate settings,
         # which ensures settings are managed cohesively and not confused with operational flags like 'file_handler'.
         # which ensures settings are managed cohesively and not confused with operational flags like 'file_handler'.
-        self.valves = self.Valves(**{"max_turns": 10})
+        self.valves = self.Valves(**{"max_turns": 2})
         pass
         pass
 
 
     def inlet(self, body: dict, user: Optional[dict] = None) -> dict:
     def inlet(self, body: dict, user: Optional[dict] = None) -> dict:
         # Modify the request body or validate it before processing by the chat completion API.
         # Modify the request body or validate it before processing by the chat completion API.
         # This function is the pre-processor for the API where various checks on the input can be performed.
         # This function is the pre-processor for the API where various checks on the input can be performed.
         # It can also modify the request before sending it to the API.
         # It can also modify the request before sending it to the API.
-        
+
         print("inlet")
         print("inlet")
         print(body)
         print(body)
         print(user)
         print(user)
@@ -65,7 +66,7 @@ class Filter:
 
 
     def outlet(self, body: dict, user: Optional[dict] = None) -> dict:
     def outlet(self, body: dict, user: Optional[dict] = None) -> dict:
         # Modify or analyze the response body after processing by the API.
         # Modify or analyze the response body after processing by the API.
-        # This function is the post-processor for the API, which can be used to modify the response 
+        # This function is the post-processor for the API, which can be used to modify the response
         # or perform additional checks and analytics.
         # or perform additional checks and analytics.
         print(f"outlet")
         print(f"outlet")
         print(body)
         print(body)