Timothy J. Baek 10 mesi fa
parent
commit
de26a78a16
1 ha cambiato i file con 16 aggiunte e 6 eliminazioni
  1. 16 6
      backend/main.py

+ 16 - 6
backend/main.py

@@ -42,7 +42,7 @@ from apps.openai.main import (
 from apps.audio.main import app as audio_app
 from apps.images.main import app as images_app
 from apps.rag.main import app as rag_app
-from apps.webui.main import app as webui_app
+from apps.webui.main import app as webui_app, get_pipe_models
 
 
 from pydantic import BaseModel
@@ -448,10 +448,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 
                     if citations and data.get("citations"):
                         data_items.append({"citations": citations})
-                        del data["citations"]
 
                 del data["files"]
 
+            if data.get("citations"):
+                del data["citations"]
+
             if context != "":
                 system_prompt = rag_template(
                     rag_app.state.config.RAG_TEMPLATE, context, prompt
@@ -691,17 +693,18 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
 
 
 async def get_all_models():
+    pipe_models = []
     openai_models = []
     ollama_models = []
 
+    pipe_models = await get_pipe_models()
+
     if app.state.config.ENABLE_OPENAI_API:
         openai_models = await get_openai_models()
-
         openai_models = openai_models["data"]
 
     if app.state.config.ENABLE_OLLAMA_API:
         ollama_models = await get_ollama_models()
-
         ollama_models = [
             {
                 "id": model["model"],
@@ -714,9 +717,9 @@ async def get_all_models():
             for model in ollama_models["models"]
         ]
 
-    models = openai_models + ollama_models
-    custom_models = Models.get_all_models()
+    models = pipe_models + openai_models + ollama_models
 
+    custom_models = Models.get_all_models()
     for custom_model in custom_models:
         if custom_model.base_model_id == None:
             for model in models:
@@ -791,6 +794,13 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
     model = app.state.MODELS[model_id]
     print(model)
 
+    
+
+    if model.get('pipe') == True:
+        print('hi')
+    
+    
+    
     if model["owned_by"] == "ollama":
         return await generate_ollama_chat_completion(form_data, user=user)
     else: