Browse Source

fix: ollama rag issue workaround

Timothy J. Baek 9 months ago
parent
commit
1aaa2e8219
2 changed files with 33 additions and 6 deletions
  1. 18 6
      backend/main.py
  2. 15 0
      backend/utils/misc.py

+ 18 - 6
backend/main.py

@@ -79,6 +79,7 @@ from utils.task import (
 from utils.misc import (
     get_last_user_message,
     add_or_update_system_message,
+    prepend_to_first_user_message_content,
     parse_duration,
 )
 
@@ -686,12 +687,23 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             if len(contexts) > 0:
                 context_string = "/n".join(contexts).strip()
                 prompt = get_last_user_message(body["messages"])
-                body["messages"] = add_or_update_system_message(
-                    rag_template(
-                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
-                    ),
-                    body["messages"],
-                )
+
+                # Workaround for Ollama 2.0+ system prompt issue
+                # TODO: replace with add_or_update_system_message
+                if model["owned_by"] == "ollama":
+                    body["messages"] = prepend_to_first_user_message_content(
+                        rag_template(
+                            rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                        ),
+                        body["messages"],
+                    )
+                else:
+                    body["messages"] = add_or_update_system_message(
+                        rag_template(
+                            rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                        ),
+                        body["messages"],
+                    )
 
             # If there are citations, add them to the data_items
             if len(citations) > 0:

+ 15 - 0
backend/utils/misc.py

@@ -53,6 +53,21 @@ def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]:
     return get_system_message(messages), remove_system_message(messages)
 
 
def prepend_to_first_user_message_content(
    content: str, messages: List[dict]
) -> List[dict]:
    """
    Prepends the given content to the first user message in the messages list.

    Only the first message whose role is "user" is touched; every other
    message is left as-is. The messages (and their content parts) are
    modified in place, and the same list object is returned.

    :param content: The text to put in front of the user's message.
    :param messages: The list of message dictionaries.
    :return: The updated list of message dictionaries.
    """
    for message in messages:
        # .get() so a message without a "role" key is skipped, not a KeyError.
        if message.get("role") != "user":
            continue

        existing = message.get("content")
        if isinstance(existing, list):
            # Multi-part content: prepend to the first text part only, so the
            # content is not duplicated when there are several text parts.
            for item in existing:
                if isinstance(item, dict) and item.get("type") == "text":
                    item["text"] = f"{content}\n{item['text']}"
                    break
            else:
                # No text part at all (e.g. image-only message): add one at
                # the front so the content is not silently dropped.
                existing.insert(0, {"type": "text", "text": content})
        elif existing is None:
            # Avoid formatting None into the literal string "None".
            message["content"] = content
        else:
            message["content"] = f"{content}\n{existing}"

        # Only the first user message is updated.
        break
    return messages
+
+
 def add_or_update_system_message(content: str, messages: List[dict]):
     """
     Adds a new system message at the beginning of the messages list