Browse Source

fix: message type edge case

Timothy J. Baek 1 year ago
parent
commit
d936353da0
3 changed files with 73 additions and 24 deletions
  1. 35 6
      backend/main.py
  2. 19 9
      src/routes/(app)/+page.svelte
  3. 19 9
      src/routes/(app)/c/[id]/+page.svelte

+ 35 - 6
backend/main.py

@@ -85,7 +85,24 @@ class RAGMiddleware(BaseHTTPMiddleware):
                         last_user_message_idx = i
                         break
 
-                query = data["messages"][last_user_message_idx]["content"]
+                user_message = data["messages"][last_user_message_idx]
+
+                if isinstance(user_message["content"], list):
+                    # Handle list content input
+                    content_type = "list"
+                    query = ""
+                    for content_item in user_message["content"]:
+                        if content_item["type"] == "text":
+                            query = content_item["text"]
+                            break
+                elif isinstance(user_message["content"], str):
+                    # Handle text content input
+                    content_type = "text"
+                    query = user_message["content"]
+                else:
+                    # Fallback in case the input does not match expected types
+                    content_type = None
+                    query = ""
 
                 relevant_contexts = []
 
@@ -112,16 +129,28 @@ class RAGMiddleware(BaseHTTPMiddleware):
                     if context:
                         context_string += " ".join(context["documents"][0]) + "\n"
 
-                content = rag_template(
+                ra_content = rag_template(
                     template=rag_app.state.RAG_TEMPLATE,
                     context=context_string,
                     query=query,
                 )
 
-                new_user_message = {
-                    **data["messages"][last_user_message_idx],
-                    "content": content,
-                }
+                if content_type == "list":
+                    new_content = []
+                    for content_item in user_message["content"]:
+                        if content_item["type"] == "text":
+                            # Update the text item's content with ra_content
+                            new_content.append({"type": "text", "text": ra_content})
+                        else:
+                            # Keep other types of content as they are
+                            new_content.append(content_item)
+                    new_user_message = {**user_message, "content": new_content}
+                else:
+                    new_user_message = {
+                        **user_message,
+                        "content": ra_content,
+                    }
+
                 data["messages"][last_user_message_idx] = new_user_message
                 del data["docs"]
 

+ 19 - 9
src/routes/(app)/+page.svelte

@@ -295,15 +295,25 @@
 			...messages
 		]
 			.filter((message) => message)
-			.map((message, idx, arr) => ({
-				role: message.role,
-				content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content,
-				...(message.files && {
-					images: message.files
-						.filter((file) => file.type === 'image')
-						.map((file) => file.url.slice(file.url.indexOf(',') + 1))
-				})
-			}));
+			.map((message, idx, arr) => {
+				// Prepare the base message object
+				const baseMessage = {
+					role: message.role,
+					content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content
+				};
+
+				// Extract and format image URLs if any exist
+				const imageUrls = message.files
+					?.filter((file) => file.type === 'image')
+					.map((file) => file.url.slice(file.url.indexOf(',') + 1));
+
+				// Add images array only if it contains elements
+				if (imageUrls && imageUrls.length > 0) {
+					baseMessage.images = imageUrls;
+				}
+
+				return baseMessage;
+			});
 
 		let lastImageIndex = -1;
 

+ 19 - 9
src/routes/(app)/c/[id]/+page.svelte

@@ -308,15 +308,25 @@
 			...messages
 		]
 			.filter((message) => message)
-			.map((message, idx, arr) => ({
-				role: message.role,
-				content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content,
-				...(message.files && {
-					images: message.files
-						.filter((file) => file.type === 'image')
-						.map((file) => file.url.slice(file.url.indexOf(',') + 1))
-				})
-			}));
+			.map((message, idx, arr) => {
+				// Prepare the base message object
+				const baseMessage = {
+					role: message.role,
+					content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content
+				};
+
+				// Extract and format image URLs if any exist
+				const imageUrls = message.files
+					?.filter((file) => file.type === 'image')
+					.map((file) => file.url.slice(file.url.indexOf(',') + 1));
+
+				// Add images array only if it contains elements
+				if (imageUrls && imageUrls.length > 0) {
+					baseMessage.images = imageUrls;
+				}
+
+				return baseMessage;
+			});
 
 		let lastImageIndex = -1;