
feat: unified chat completions endpoint

Timothy J. Baek, 10 months ago
commit 84defafc14
3 changed files with 39 additions and 6 deletions
  1. backend/apps/ollama/main.py (+6, -1)
  2. backend/main.py (+32, -2)
  3. src/lib/components/chat/Chat.svelte (+1, -3)

backend/apps/ollama/main.py  (+6, -1)

@@ -849,9 +849,14 @@ async def generate_chat_completion(
 
 
 # TODO: we should update this part once Ollama supports other types
+class OpenAIChatMessageContent(BaseModel):
+    type: str
+    model_config = ConfigDict(extra="allow")
+
+
 class OpenAIChatMessage(BaseModel):
     role: str
-    content: str
+    content: Union[str, OpenAIChatMessageContent]
 
     model_config = ConfigDict(extra="allow")
 

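A quick illustration (not part of the commit) of what the widened "content" field accepts after this change, assuming Pydantic v2 union validation: a plain string still validates as before, and a dict is coerced into OpenAIChatMessageContent, with any keys beyond "type" preserved thanks to extra="allow". The image_url payload below is a made-up example, not a format taken from the diff.

    from typing import Union
    from pydantic import BaseModel, ConfigDict

    class OpenAIChatMessageContent(BaseModel):
        type: str
        model_config = ConfigDict(extra="allow")

    class OpenAIChatMessage(BaseModel):
        role: str
        content: Union[str, OpenAIChatMessageContent]
        model_config = ConfigDict(extra="allow")

    # Plain-text content keeps working exactly as before.
    OpenAIChatMessage(role="user", content="Hello!")

    # Structured content now validates too; keys beyond "type" (here "image_url")
    # pass through because of extra="allow".
    OpenAIChatMessage(
        role="user",
        content={"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
    )
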
backend/main.py  (+32, -2)

@@ -25,8 +25,17 @@ from starlette.responses import StreamingResponse, Response
 
 
 from apps.socket.main import app as socket_app
-from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models
-from apps.openai.main import app as openai_app, get_all_models as get_openai_models
+from apps.ollama.main import (
+    app as ollama_app,
+    OpenAIChatCompletionForm,
+    get_all_models as get_ollama_models,
+    generate_openai_chat_completion as generate_ollama_chat_completion,
+)
+from apps.openai.main import (
+    app as openai_app,
+    get_all_models as get_openai_models,
+    generate_chat_completion as generate_openai_chat_completion,
+)
 
 from apps.audio.main import app as audio_app
 from apps.images.main import app as images_app
@@ -485,6 +494,27 @@ async def get_models(user=Depends(get_verified_user)):
     return {"data": models}
 
 
+@app.post("/api/chat/completions")
+async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
+    model_id = form_data["model"]
+    if model_id not in app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+
+    model = app.state.MODELS[model_id]
+
+    print(model)
+
+    if model["owned_by"] == "ollama":
+        return await generate_ollama_chat_completion(
+            OpenAIChatCompletionForm(**form_data), user=user
+        )
+    else:
+        return await generate_openai_chat_completion(form_data, user=user)
+
+
 @app.post("/api/chat/completed")
 async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
     data = form_data

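A minimal client-side sketch (not part of the commit) of how the new unified endpoint can be called; the base URL, token, and model id below are placeholders. The payload is OpenAI-style regardless of backend: the handler looks the model up in app.state.MODELS and dispatches to the Ollama or OpenAI code path based on its owned_by field. This is also why the Chat.svelte change below can drop the per-provider base-URL switch and always target WEBUI_BASE_URL/api.

    import requests

    WEBUI_URL = "http://localhost:8080"  # assumed local instance
    TOKEN = "YOUR_API_TOKEN"             # placeholder bearer token for get_verified_user

    response = requests.post(
        f"{WEBUI_URL}/api/chat/completions",
        headers={"Authorization": f"Bearer {TOKEN}"},
        json={
            "model": "llama3",  # hypothetical model id; must exist in app.state.MODELS
            "messages": [{"role": "user", "content": "Hello!"}],
            "stream": False,
        },
    )
    print(response.json())
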
src/lib/components/chat/Chat.svelte  (+1, -3)

@@ -1134,9 +1134,7 @@
 				titleModelId,
 				userPrompt,
 				$chatId,
-				titleModel?.owned_by === 'openai' ?? false
-					? `${OPENAI_API_BASE_URL}`
-					: `${OLLAMA_API_BASE_URL}/v1`
+				`${WEBUI_BASE_URL}/api`
 			);
 
 			return title;