Ver código fonte

feat: pipelines filter outlet

Timothy J. Baek 11 meses atrás
pai
commit
ef8d84296e
3 arquivos alterados com 156 adições e 4 exclusões
  1. 70 2
      backend/main.py
  2. 39 0
      src/lib/apis/index.ts
  3. 47 2
      src/lib/components/chat/Chat.svelte

+ 70 - 2
backend/main.py

@@ -141,7 +141,8 @@ class RAGMiddleware(BaseHTTPMiddleware):
         return_citations = False
 
         if request.method == "POST" and (
-            "/api/chat" in request.url.path or "/chat/completions" in request.url.path
+            "/ollama/api/chat" in request.url.path
+            or "/chat/completions" in request.url.path
         ):
             log.debug(f"request.url.path: {request.url.path}")
 
@@ -229,7 +230,8 @@ app.add_middleware(RAGMiddleware)
 class PipelineMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
         if request.method == "POST" and (
-            "/api/chat" in request.url.path or "/chat/completions" in request.url.path
+            "/ollama/api/chat" in request.url.path
+            or "/chat/completions" in request.url.path
         ):
             log.debug(f"request.url.path: {request.url.path}")
 
@@ -308,6 +310,9 @@ class PipelineMiddleware(BaseHTTPMiddleware):
                     else:
                         pass
 
+            if "chat_id" in data:
+                del data["chat_id"]
+
             modified_body_bytes = json.dumps(data).encode("utf-8")
             # Replace the request body with the modified one
             request._body = modified_body_bytes
@@ -464,6 +469,69 @@ async def get_models(user=Depends(get_verified_user)):
     return {"data": models}
 
 
+@app.post("/api/chat/completed")
+async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
+    data = form_data
+    model_id = data["model"]
+
+    filters = [
+        model
+        for model in app.state.MODELS.values()
+        if "pipeline" in model
+        and "type" in model["pipeline"]
+        and model["pipeline"]["type"] == "filter"
+        and (
+            model["pipeline"]["pipelines"] == ["*"]
+            or any(
+                model_id == target_model_id
+                for target_model_id in model["pipeline"]["pipelines"]
+            )
+        )
+    ]
+    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
+
+    for filter in sorted_filters:
+        r = None
+        try:
+            urlIdx = filter["urlIdx"]
+
+            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
+            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
+
+            if key != "":
+                headers = {"Authorization": f"Bearer {key}"}
+                r = requests.post(
+                    f"{url}/{filter['id']}/filter/outlet",
+                    headers=headers,
+                    json={
+                        "user": {"id": user.id, "name": user.name, "role": user.role},
+                        "body": data,
+                    },
+                )
+
+                r.raise_for_status()
+                data = r.json()
+        except Exception as e:
+            # Handle connection error here
+            print(f"Connection error: {e}")
+
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "detail" in res:
+                        return JSONResponse(
+                            status_code=r.status_code,
+                            content=res,
+                        )
+                except:
+                    pass
+
+            else:
+                pass
+
+    return data
+
+
 @app.get("/api/pipelines/list")
 async def get_pipelines_list(user=Depends(get_admin_user)):
     responses = await get_openai_models(raw=True)

+ 39 - 0
src/lib/apis/index.ts

@@ -49,6 +49,45 @@ export const getModels = async (token: string = '') => {
 	return models;
 };
 
+type ChatCompletedForm = {
+	model: string;
+	messages: string[];
+	chat_id: string;
+};
+
+export const chatCompleted = async (token: string, body: ChatCompletedForm) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/chat/completed`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify(body)
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = err;
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const getPipelinesList = async (token: string = '') => {
 	let error = null;
 

+ 47 - 2
src/lib/components/chat/Chat.svelte

@@ -48,6 +48,7 @@
 	import { runWebSearch } from '$lib/apis/rag';
 	import Banner from '../common/Banner.svelte';
 	import { getUserSettings } from '$lib/apis/users';
+	import { chatCompleted } from '$lib/apis';
 
 	const i18n: Writable<i18nType> = getContext('i18n');
 
@@ -576,7 +577,8 @@
 			format: $settings.requestFormat ?? undefined,
 			keep_alive: $settings.keepAlive ?? undefined,
 			docs: docs.length > 0 ? docs : undefined,
-			citations: docs.length > 0
+			citations: docs.length > 0,
+			chat_id: $chatId
 		});
 
 		if (res && res.ok) {
@@ -596,6 +598,27 @@
 					if (stopResponseFlag) {
 						controller.abort('User: Stop Response');
 						await cancelOllamaRequest(localStorage.token, currentRequestId);
+					} else {
+						const res = await chatCompleted(localStorage.token, {
+							model: model,
+							messages: messages.map((m) => ({
+								id: m.id,
+								role: m.role,
+								content: m.content,
+								timestamp: m.timestamp
+							})),
+							chat_id: $chatId
+						}).catch((error) => {
+							console.error(error);
+							return null;
+						});
+
+						if (res !== null) {
+							// Update chat history with the new messages
+							for (const message of res.messages) {
+								history.messages[message.id] = { ...history.messages[message.id], ...message };
+							}
+						}
 					}
 
 					currentRequestId = null;
@@ -829,7 +852,8 @@
 					frequency_penalty: $settings?.params?.frequency_penalty ?? undefined,
 					max_tokens: $settings?.params?.max_tokens ?? undefined,
 					docs: docs.length > 0 ? docs : undefined,
-					citations: docs.length > 0
+					citations: docs.length > 0,
+					chat_id: $chatId
 				},
 				`${OPENAI_API_BASE_URL}`
 			);
@@ -855,6 +879,27 @@
 
 						if (stopResponseFlag) {
 							controller.abort('User: Stop Response');
+						} else {
+							const res = await chatCompleted(localStorage.token, {
+								model: model,
+								messages: messages.map((m) => ({
+									id: m.id,
+									role: m.role,
+									content: m.content,
+									timestamp: m.timestamp
+								})),
+								chat_id: $chatId
+							}).catch((error) => {
+								console.error(error);
+								return null;
+							});
+
+							if (res !== null) {
+								// Update chat history with the new messages
+								for (const message of res.messages) {
+									history.messages[message.id] = { ...history.messages[message.id], ...message };
+								}
+							}
 						}
 
 						break;