Browse Source

Merge pull request #3177 from Yash-1511/main

feat: add tavily web search in web search provider
Timothy Jaeryang Baek 10 months ago
parent
commit
8db439a0d1

+ 1 - 1
README.md

@@ -37,7 +37,7 @@ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature-
 
 - 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
 
-- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, and `Serply` and inject the results directly into your chat experience.
+- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo` and `TavilySearch` and inject the results directly into your chat experience.
 
 - 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions.
 

+ 17 - 1
backend/apps/rag/main.py

@@ -73,6 +73,7 @@ from apps.rag.search.serper import search_serper
 from apps.rag.search.serpstack import search_serpstack
 from apps.rag.search.serply import search_serply
 from apps.rag.search.duckduckgo import search_duckduckgo
+from apps.rag.search.tavily import search_tavily
 
 from utils.misc import (
     calculate_sha256,
@@ -119,6 +120,7 @@ from config import (
     SERPSTACK_HTTPS,
     SERPER_API_KEY,
     SERPLY_API_KEY,
+    TAVILY_API_KEY,
     RAG_WEB_SEARCH_RESULT_COUNT,
     RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
     RAG_EMBEDDING_OPENAI_BATCH_SIZE,
@@ -172,6 +174,7 @@ app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
 app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
 app.state.config.SERPER_API_KEY = SERPER_API_KEY
 app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
+app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
 app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
 
@@ -400,6 +403,7 @@ async def get_rag_config(user=Depends(get_admin_user)):
                 "serpstack_https": app.state.config.SERPSTACK_HTTPS,
                 "serper_api_key": app.state.config.SERPER_API_KEY,
                 "serply_api_key": app.state.config.SERPLY_API_KEY,
+                "tavily_api_key": app.state.config.TAVILY_API_KEY,
                 "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                 "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
             },
@@ -428,6 +432,7 @@ class WebSearchConfig(BaseModel):
     serpstack_https: Optional[bool] = None
     serper_api_key: Optional[str] = None
     serply_api_key: Optional[str] = None
+    tavily_api_key: Optional[str] = None
     result_count: Optional[int] = None
     concurrent_requests: Optional[int] = None
 
@@ -479,6 +484,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
         app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
         app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key
         app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
+        app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
         app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
         app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
             form_data.web.search.concurrent_requests
@@ -508,6 +514,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
                 "serpstack_https": app.state.config.SERPSTACK_HTTPS,
                 "serper_api_key": app.state.config.SERPER_API_KEY,
                 "serply_api_key": app.state.config.SERPLY_API_KEY,
+                "tavily_api_key": app.state.config.TAVILY_API_KEY,
                 "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                 "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
             },
@@ -756,7 +763,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
     - SERPSTACK_API_KEY
     - SERPER_API_KEY
     - SERPLY_API_KEY
-
+    - TAVILY_API_KEY
     Args:
         query (str): The query to search for
     """
@@ -825,6 +832,15 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
             raise Exception("No SERPLY_API_KEY found in environment variables")
     elif engine == "duckduckgo":
         return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
+    elif engine == "tavily":
+        if app.state.config.TAVILY_API_KEY:
+            return search_tavily(
+                app.state.config.TAVILY_API_KEY,
+                query,
+                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+            )
+        else:
+            raise Exception("No TAVILY_API_KEY found in environment variables")
     else:
         raise Exception("No search engine API key found in environment variables")
 

+ 39 - 0
backend/apps/rag/search/tavily.py

@@ -0,0 +1,39 @@
+import logging
+
+import requests
+
+from apps.rag.search.main import SearchResult
+from config import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+
+def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
+    """Search using Tavily's Search API and return the results as a list of SearchResult objects.
+
+    Args:
+        api_key (str): A Tavily Search API key
+        query (str): The query to search for
+
+    Returns:
+        List[SearchResult]: A list of search results
+    """
+    url = "https://api.tavily.com/search"
+    data = {"query": query, "api_key": api_key}
+
+    response = requests.post(url, json=data)
+    response.raise_for_status()
+
+    json_response = response.json()
+
+    raw_search_results = json_response.get("results", [])
+
+    return [
+        SearchResult(
+            link=result["url"],
+            title=result.get("title", ""),
+            snippet=result.get("content"),
+        )
+        for result in raw_search_results[:count]
+    ]

+ 5 - 0
backend/config.py

@@ -943,6 +943,11 @@ SERPLY_API_KEY = PersistentConfig(
     os.getenv("SERPLY_API_KEY", ""),
 )
 
+TAVILY_API_KEY = PersistentConfig(
+    "TAVILY_API_KEY",
+    "rag.web.search.tavily_api_key",
+    os.getenv("TAVILY_API_KEY", ""),
+)
 
 RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
     "RAG_WEB_SEARCH_RESULT_COUNT",

+ 20 - 1
src/lib/components/admin/Settings/WebSearch.svelte

@@ -18,7 +18,8 @@
 		'serpstack',
 		'serper',
 		'serply',
-		'duckduckgo'
+		'duckduckgo',
+		'tavily'
 	];
 
 	let youtubeLanguage = 'en';
@@ -214,6 +215,24 @@
 									</div>
 								</div>
 							</div>
+						{:else if webConfig.search.engine === 'tavily'}
+							<div>
+								<div class=" self-center text-xs font-medium mb-1">
+									{$i18n.t('Tavily API Key')}
+								</div>
+
+								<div class="flex w-full">
+									<div class="flex-1">
+										<input
+											class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+											type="text"
+											placeholder={$i18n.t('Enter Tavily API Key')}
+											bind:value={webConfig.search.tavily_api_key}
+											autocomplete="off"
+										/>
+									</div>
+								</div>
+							</div>
 						{/if}
 					</div>
 				{/if}