Browse Source

fix: separate /embed and /embedding ollama endpoint

Timothy J. Baek 8 months ago
parent
commit
f1fae805a2
1 changed file with 50 additions and 1 deletion
  1. 50 1
      backend/open_webui/apps/ollama/main.py

+ 50 - 1
backend/open_webui/apps/ollama/main.py

@@ -545,6 +545,55 @@ class GenerateEmbeddingsForm(BaseModel):
 
 
 @app.post("/api/embed")
 @app.post("/api/embed")
 @app.post("/api/embed/{url_idx}")
 @app.post("/api/embed/{url_idx}")
+async def generate_embeddings(
+    form_data: GenerateEmbeddingsForm,
+    url_idx: Optional[int] = None,
+    user=Depends(get_verified_user),
+):
+    """Forward an embedding request to an Ollama backend's ``/api/embed``.
+
+    When ``url_idx`` is not given, a backend is picked at random from the
+    URL indices registered for the requested model in
+    ``app.state.MODELS``; a model name without a tag is looked up as
+    ``<model>:latest``.
+
+    Returns the decoded JSON body of the upstream response.
+
+    Raises:
+        HTTPException: 400 when the model is unknown; otherwise the
+            upstream error status (or 500 when no usable upstream status
+            exists) with a detail message built from the upstream
+            ``error`` field when available.
+    """
+    # NOTE(review): the handler for /api/embeddings further down reuses the
+    # name ``generate_embeddings``; routes are registered by the decorators,
+    # but the later definition rebinds the module-level name.
+    if url_idx is None:
+        model = form_data.model
+
+        if ":" not in model:
+            model = f"{model}:latest"
+
+        if model in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[model]["urls"])
+        else:
+            raise HTTPException(
+                status_code=400,
+                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
+            )
+
+    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    log.info(f"url: {url}")
+
+    # ``r`` stays None when the request itself fails (connection refused,
+    # DNS error, ...), so the handler below can tell "no response" apart
+    # from "error response".
+    r = None
+    try:
+        # NOTE(review): ``requests`` is a blocking client called from an
+        # ``async def`` handler - confirm whether this should be offloaded.
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/embed",
+            headers={"Content-Type": "application/json"},
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
+        r.raise_for_status()
+
+        return r.json()
+    except Exception as e:
+        log.exception(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except Exception:
+                error_detail = f"Ollama: {e}"
+
+        # A ``requests.Response`` is falsy whenever its status is 4xx/5xx,
+        # so it must be compared against None explicitly; a truthiness
+        # test would turn every upstream error into a 500. A 2xx status can
+        # only reach this point when the body failed to parse, which is
+        # reported as a 500 rather than echoing a success code.
+        if r is not None and r.status_code >= 400:
+            status_code = r.status_code
+        else:
+            status_code = 500
+
+        raise HTTPException(
+            status_code=status_code,
+            detail=error_detail,
+        )
+
+
 @app.post("/api/embeddings")
 @app.post("/api/embeddings")
 @app.post("/api/embeddings/{url_idx}")
 @app.post("/api/embeddings/{url_idx}")
 async def generate_embeddings(
 async def generate_embeddings(
@@ -571,7 +620,7 @@ async def generate_embeddings(
 
 
     r = requests.request(
     r = requests.request(
         method="POST",
         method="POST",
-        url=f"{url}/api/embed",
+        url=f"{url}/api/embeddings",
         headers={"Content-Type": "application/json"},
         headers={"Content-Type": "application/json"},
         data=form_data.model_dump_json(exclude_none=True).encode(),
         data=form_data.model_dump_json(exclude_none=True).encode(),
     )
     )