
refac: embeddings function

Timothy J. Baek, 6 months ago
parent
commit
451f1bae15
1 file changed, 29 insertions(+), 31 deletions(-)

+ 29 - 31
backend/open_webui/apps/retrieval/utils.py

@@ -272,26 +272,26 @@ def get_embedding_function(
         return lambda query: embedding_function.encode(query).tolist()
     elif embedding_engine in ["ollama", "openai"]:
         if embedding_engine == "ollama":
-            func = lambda query: generate_ollama_embeddings(
+            func = lambda query: generate_embeddings(
                 model=embedding_model,
-                input=query,
+                text=query,
             )
         elif embedding_engine == "openai":
-            func = lambda query: generate_openai_embeddings(
+            func = lambda query: generate_embeddings(
                 model=embedding_model,
                 text=query,
                 key=openai_key,
                 url=openai_url,
             )
 
-        def generate_multiple(query, f):
+        def generate_multiple(query, func):
             if isinstance(query, list):
                 embeddings = []
                 for i in range(0, len(query), embedding_batch_size):
-                    embeddings.extend(f(query[i : i + embedding_batch_size]))
+                    embeddings.extend(func(query[i : i + embedding_batch_size]))
                 return embeddings
             else:
-                return f(query)
+                return func(query)
 
         return lambda query: generate_multiple(query, func)
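
A minimal sketch (not part of the commit) of how the generate_multiple closure above batches list inputs, assuming embedding_batch_size = 2 and a stand-in embed() in place of the real generate_embeddings call:

    embedding_batch_size = 2

    def embed(texts):
        # stand-in for the engine-specific embedding call; returns one vector per text
        return [[float(len(t))] for t in texts]

    def generate_multiple(query, func):
        if isinstance(query, list):
            embeddings = []
            # walk the list in slices of embedding_batch_size and concatenate the results
            for i in range(0, len(query), embedding_batch_size):
                embeddings.extend(func(query[i : i + embedding_batch_size]))
            return embeddings
        else:
            return func(query)

    print(generate_multiple(["a", "bb", "ccc"], embed))   # -> [[1.0], [2.0], [3.0]]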
 
@@ -438,20 +438,6 @@ def get_model_path(model: str, update_model: bool = False):
         return model
 
 
-def generate_openai_embeddings(
-    model: str,
-    text: Union[str, list[str]],
-    key: str,
-    url: str = "https://api.openai.com/v1",
-):
-    if isinstance(text, list):
-        embeddings = generate_openai_batch_embeddings(model, text, key, url)
-    else:
-        embeddings = generate_openai_batch_embeddings(model, [text], key, url)
-
-    return embeddings[0] if isinstance(text, str) else embeddings
-
-
 def generate_openai_batch_embeddings(
     model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
 ) -> Optional[list[list[float]]]:
@@ -475,19 +461,31 @@ def generate_openai_batch_embeddings(
         return None
 
 
-def generate_ollama_embeddings(
-    model: str, input: list[str]
-) -> Optional[list[list[float]]]:
-    if isinstance(input, list):
-        embeddings = generate_ollama_batch_embeddings(
-            GenerateEmbedForm(**{"model": model, "input": input})
-        )
-    else:
-        embeddings = generate_ollama_batch_embeddings(
-            GenerateEmbedForm(**{"model": model, "input": [input]})
+def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
+    if engine == "ollama":
+        if isinstance(text, list):
+            embeddings = generate_ollama_batch_embeddings(
+                GenerateEmbedForm(**{"model": model, "input": text})
+            )
+        else:
+            embeddings = generate_ollama_batch_embeddings(
+                GenerateEmbedForm(**{"model": model, "input": [text]})
+            )
+        return (
+            embeddings["embeddings"][0]
+            if isinstance(text, str)
+            else embeddings["embeddings"]
         )
+    elif engine == "openai":
+        key = kwargs.get("key", "")
+        url = kwargs.get("url", "https://api.openai.com/v1")
+
+        if isinstance(text, list):
+            embeddings = generate_openai_batch_embeddings(model, text, key, url)
+        else:
+            embeddings = generate_openai_batch_embeddings(model, [text], key, url)
 
-    return embeddings["embeddings"]
+        return embeddings[0] if isinstance(text, str) else embeddings
 
 
 import operator
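
An illustrative usage sketch of the unified generate_embeddings dispatcher introduced above; the model names, key, and URL below are placeholders, not values taken from the commit:

    # str input returns a single embedding vector
    vector = generate_embeddings(
        engine="ollama",
        model="nomic-embed-text",              # placeholder model name
        text="hello world",
    )

    # list input returns one vector per text, via the batch helpers
    vectors = generate_embeddings(
        engine="openai",
        model="text-embedding-3-small",        # placeholder model name
        text=["first chunk", "second chunk"],
        key="sk-...",                          # placeholder API key
        url="https://api.openai.com/v1",
    )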