Timothy J. Baek 10 月之前
父节点
当前提交
1c4e7f0324
共有 2 个文件被更改,包括 169 次插入170 次删除
  1. 159 0
      backend/apps/webui/main.py
  2. 10 170
      backend/main.py

+ 159 - 0
backend/apps/webui/main.py

@@ -1,5 +1,6 @@
 from fastapi import FastAPI, Depends
 from fastapi.routing import APIRoute
+from fastapi.responses import StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
 from apps.webui.routers import (
     auths,
@@ -17,6 +18,7 @@ from apps.webui.routers import (
 )
 from apps.webui.models.functions import Functions
 from apps.webui.utils import load_function_module_by_id
+from utils.misc import stream_message_template
 
 from config import (
     WEBUI_BUILD_HASH,
@@ -37,6 +39,14 @@ from config import (
     AppConfig,
 )
 
+import inspect
+import uuid
+import time
+import json
+
+from typing import Iterator, Generator
+from pydantic import BaseModel
+
 app = FastAPI()
 
 origins = ["*"]
@@ -166,3 +176,152 @@ async def get_pipe_models():
             )
 
     return pipe_models
+
+
+async def generate_function_chat_completion(form_data, user):
+    async def job():
+        pipe_id = form_data["model"]
+        if "." in pipe_id:
+            pipe_id, sub_pipe_id = pipe_id.split(".", 1)
+        print(pipe_id)
+
+        # Check if function is already loaded
+        if pipe_id not in app.state.FUNCTIONS:
+            function_module, function_type, frontmatter = load_function_module_by_id(
+                pipe_id
+            )
+            app.state.FUNCTIONS[pipe_id] = function_module
+        else:
+            function_module = app.state.FUNCTIONS[pipe_id]
+
+        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
+
+            valves = Functions.get_function_valves_by_id(pipe_id)
+            function_module.valves = function_module.Valves(
+                **(valves if valves else {})
+            )
+
+        pipe = function_module.pipe
+
+        # Get the signature of the function
+        sig = inspect.signature(pipe)
+        params = {"body": form_data}
+
+        if "__user__" in sig.parameters:
+            __user__ = {
+                "id": user.id,
+                "email": user.email,
+                "name": user.name,
+                "role": user.role,
+            }
+
+            try:
+                if hasattr(function_module, "UserValves"):
+                    __user__["valves"] = function_module.UserValves(
+                        **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
+                    )
+            except Exception as e:
+                print(e)
+
+            params = {**params, "__user__": __user__}
+
+        if form_data["stream"]:
+
+            async def stream_content():
+                try:
+                    if inspect.iscoroutinefunction(pipe):
+                        res = await pipe(**params)
+                    else:
+                        res = pipe(**params)
+                except Exception as e:
+                    print(f"Error: {e}")
+                    yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
+                    return
+
+                if isinstance(res, str):
+                    message = stream_message_template(form_data["model"], res)
+                    yield f"data: {json.dumps(message)}\n\n"
+
+                if isinstance(res, Iterator):
+                    for line in res:
+                        if isinstance(line, BaseModel):
+                            line = line.model_dump_json()
+                            line = f"data: {line}"
+                        try:
+                            line = line.decode("utf-8")
+                        except:
+                            pass
+
+                        if line.startswith("data:"):
+                            yield f"{line}\n\n"
+                        else:
+                            line = stream_message_template(form_data["model"], line)
+                            yield f"data: {json.dumps(line)}\n\n"
+
+                if isinstance(res, str) or isinstance(res, Generator):
+                    finish_message = {
+                        "id": f"{form_data['model']}-{str(uuid.uuid4())}",
+                        "object": "chat.completion.chunk",
+                        "created": int(time.time()),
+                        "model": form_data["model"],
+                        "choices": [
+                            {
+                                "index": 0,
+                                "delta": {},
+                                "logprobs": None,
+                                "finish_reason": "stop",
+                            }
+                        ],
+                    }
+
+                    yield f"data: {json.dumps(finish_message)}\n\n"
+                    yield f"data: [DONE]"
+
+            return StreamingResponse(stream_content(), media_type="text/event-stream")
+        else:
+
+            try:
+                if inspect.iscoroutinefunction(pipe):
+                    res = await pipe(**params)
+                else:
+                    res = pipe(**params)
+            except Exception as e:
+                print(f"Error: {e}")
+                return {"error": {"detail": str(e)}}
+
+            if inspect.iscoroutinefunction(pipe):
+                res = await pipe(**params)
+            else:
+                res = pipe(**params)
+
+            if isinstance(res, dict):
+                return res
+            elif isinstance(res, BaseModel):
+                return res.model_dump()
+            else:
+                message = ""
+                if isinstance(res, str):
+                    message = res
+                if isinstance(res, Generator):
+                    for stream in res:
+                        message = f"{message}{stream}"
+
+                return {
+                    "id": f"{form_data['model']}-{str(uuid.uuid4())}",
+                    "object": "chat.completion",
+                    "created": int(time.time()),
+                    "model": form_data["model"],
+                    "choices": [
+                        {
+                            "index": 0,
+                            "message": {
+                                "role": "assistant",
+                                "content": message,
+                            },
+                            "logprobs": None,
+                            "finish_reason": "stop",
+                        }
+                    ],
+                }
+
+    return await job()

+ 10 - 170
backend/main.py

@@ -43,7 +43,11 @@ from apps.openai.main import (
 from apps.audio.main import app as audio_app
 from apps.images.main import app as images_app
 from apps.rag.main import app as rag_app
-from apps.webui.main import app as webui_app, get_pipe_models
+from apps.webui.main import (
+    app as webui_app,
+    get_pipe_models,
+    generate_function_chat_completion,
+)
 
 
 from pydantic import BaseModel
@@ -228,10 +232,7 @@ async def get_function_call_response(
 
     response = None
     try:
-        if model["owned_by"] == "ollama":
-            response = await generate_ollama_chat_completion(payload, user=user)
-        else:
-            response = await generate_openai_chat_completion(payload, user=user)
+        response = await generate_chat_completions(form_data=payload, user=user)
 
         content = None
 
@@ -900,159 +901,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
 
     pipe = model.get("pipe")
     if pipe:
-
-        async def job():
-            pipe_id = form_data["model"]
-            if "." in pipe_id:
-                pipe_id, sub_pipe_id = pipe_id.split(".", 1)
-            print(pipe_id)
-
-            # Check if function is already loaded
-            if pipe_id not in webui_app.state.FUNCTIONS:
-                function_module, function_type, frontmatter = (
-                    load_function_module_by_id(pipe_id)
-                )
-                webui_app.state.FUNCTIONS[pipe_id] = function_module
-            else:
-                function_module = webui_app.state.FUNCTIONS[pipe_id]
-
-            if hasattr(function_module, "valves") and hasattr(
-                function_module, "Valves"
-            ):
-
-                valves = Functions.get_function_valves_by_id(pipe_id)
-                function_module.valves = function_module.Valves(
-                    **(valves if valves else {})
-                )
-
-            pipe = function_module.pipe
-
-            # Get the signature of the function
-            sig = inspect.signature(pipe)
-            params = {"body": form_data}
-
-            if "__user__" in sig.parameters:
-                __user__ = {
-                    "id": user.id,
-                    "email": user.email,
-                    "name": user.name,
-                    "role": user.role,
-                }
-
-                try:
-                    if hasattr(function_module, "UserValves"):
-                        __user__["valves"] = function_module.UserValves(
-                            **Functions.get_user_valves_by_id_and_user_id(
-                                pipe_id, user.id
-                            )
-                        )
-                except Exception as e:
-                    print(e)
-
-                params = {**params, "__user__": __user__}
-
-            if form_data["stream"]:
-
-                async def stream_content():
-                    try:
-                        if inspect.iscoroutinefunction(pipe):
-                            res = await pipe(**params)
-                        else:
-                            res = pipe(**params)
-                    except Exception as e:
-                        print(f"Error: {e}")
-                        yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
-                        return
-
-                    if isinstance(res, str):
-                        message = stream_message_template(form_data["model"], res)
-                        yield f"data: {json.dumps(message)}\n\n"
-
-                    if isinstance(res, Iterator):
-                        for line in res:
-                            if isinstance(line, BaseModel):
-                                line = line.model_dump_json()
-                                line = f"data: {line}"
-                            try:
-                                line = line.decode("utf-8")
-                            except:
-                                pass
-
-                            if line.startswith("data:"):
-                                yield f"{line}\n\n"
-                            else:
-                                line = stream_message_template(form_data["model"], line)
-                                yield f"data: {json.dumps(line)}\n\n"
-
-                    if isinstance(res, str) or isinstance(res, Generator):
-                        finish_message = {
-                            "id": f"{form_data['model']}-{str(uuid.uuid4())}",
-                            "object": "chat.completion.chunk",
-                            "created": int(time.time()),
-                            "model": form_data["model"],
-                            "choices": [
-                                {
-                                    "index": 0,
-                                    "delta": {},
-                                    "logprobs": None,
-                                    "finish_reason": "stop",
-                                }
-                            ],
-                        }
-
-                        yield f"data: {json.dumps(finish_message)}\n\n"
-                        yield f"data: [DONE]"
-
-                return StreamingResponse(
-                    stream_content(), media_type="text/event-stream"
-                )
-            else:
-
-                try:
-                    if inspect.iscoroutinefunction(pipe):
-                        res = await pipe(**params)
-                    else:
-                        res = pipe(**params)
-                except Exception as e:
-                    print(f"Error: {e}")
-                    return {"error": {"detail": str(e)}}
-
-                if inspect.iscoroutinefunction(pipe):
-                    res = await pipe(**params)
-                else:
-                    res = pipe(**params)
-
-                if isinstance(res, dict):
-                    return res
-                elif isinstance(res, BaseModel):
-                    return res.model_dump()
-                else:
-                    message = ""
-                    if isinstance(res, str):
-                        message = res
-                    if isinstance(res, Generator):
-                        for stream in res:
-                            message = f"{message}{stream}"
-
-                    return {
-                        "id": f"{form_data['model']}-{str(uuid.uuid4())}",
-                        "object": "chat.completion",
-                        "created": int(time.time()),
-                        "model": form_data["model"],
-                        "choices": [
-                            {
-                                "index": 0,
-                                "message": {
-                                    "role": "assistant",
-                                    "content": message,
-                                },
-                                "logprobs": None,
-                                "finish_reason": "stop",
-                            }
-                        ],
-                    }
-
-        return await job()
+        return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":
         return await generate_ollama_chat_completion(form_data, user=user)
     else:
@@ -1334,10 +1183,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
             content={"detail": e.args[1]},
         )
 
-    if model["owned_by"] == "ollama":
-        return await generate_ollama_chat_completion(payload, user=user)
-    else:
-        return await generate_openai_chat_completion(payload, user=user)
+    return await generate_chat_completions(form_data=payload, user=user)
 
 
 @app.post("/api/task/query/completions")
@@ -1397,10 +1243,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
             content={"detail": e.args[1]},
         )
 
-    if model["owned_by"] == "ollama":
-        return await generate_ollama_chat_completion(payload, user=user)
-    else:
-        return await generate_openai_chat_completion(payload, user=user)
+    return await generate_chat_completions(form_data=payload, user=user)
 
 
 @app.post("/api/task/emoji/completions")
@@ -1464,10 +1307,7 @@ Message: """{{prompt}}"""
             content={"detail": e.args[1]},
         )
 
-    if model["owned_by"] == "ollama":
-        return await generate_ollama_chat_completion(payload, user=user)
-    else:
-        return await generate_openai_chat_completion(payload, user=user)
+    return await generate_chat_completions(form_data=payload, user=user)
 
 
 @app.post("/api/task/tools/completions")