Browse Source

Merge pull request #4237 from michaelpoluektov/refactor-webui-main

refactor: Simplify functions
Timothy Jaeryang Baek 9 months ago
parent
commit
a9a6ed8b71
5 changed files with 284 additions and 389 deletions
  1. 2 4
      backend/apps/socket/main.py
  2. 225 302
      backend/apps/webui/main.py
  3. 2 10
      backend/apps/webui/models/models.py
  4. 16 41
      backend/main.py
  5. 39 32
      backend/utils/misc.py

+ 2 - 4
backend/apps/socket/main.py

@@ -52,7 +52,6 @@ async def user_join(sid, data):
             user = Users.get_user_by_id(data["id"])
 
         if user:
-
             SESSION_POOL[sid] = user.id
             if user.id in USER_POOL:
                 USER_POOL[user.id].append(sid)
@@ -80,7 +79,6 @@ def get_models_in_use():
 
 @sio.on("usage")
 async def usage(sid, data):
-
     model_id = data["model"]
 
     # Cancel previous callback if there is one
@@ -139,7 +137,7 @@ async def disconnect(sid):
         print(f"Unknown session ID {sid} disconnected")
 
 
-async def get_event_emitter(request_info):
+def get_event_emitter(request_info):
     async def __event_emitter__(event_data):
         await sio.emit(
             "chat-events",
@@ -154,7 +152,7 @@ async def get_event_emitter(request_info):
     return __event_emitter__
 
 
-async def get_event_call(request_info):
+def get_event_call(request_info):
     async def __event_call__(event_data):
         response = await sio.call(
             "chat-events",

+ 225 - 302
backend/apps/webui/main.py

@@ -1,9 +1,6 @@
-from fastapi import FastAPI, Depends
-from fastapi.routing import APIRoute
+from fastapi import FastAPI
 from fastapi.responses import StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
-from starlette.middleware.sessions import SessionMiddleware
-from sqlalchemy.orm import Session
 from apps.webui.routers import (
     auths,
     users,
@@ -22,12 +19,15 @@ from apps.webui.models.functions import Functions
 from apps.webui.models.models import Models
 from apps.webui.utils import load_function_module_by_id
 
-from utils.misc import stream_message_template
+from utils.misc import (
+    openai_chat_chunk_message_template,
+    openai_chat_completion_message_template,
+    add_or_update_system_message,
+)
 from utils.task import prompt_template
 
 
 from config import (
-    WEBUI_BUILD_HASH,
     SHOW_ADMIN_DETAILS,
     ADMIN_EMAIL,
     WEBUI_AUTH,
@@ -51,11 +51,9 @@ from config import (
 from apps.socket.main import get_event_call, get_event_emitter
 
 import inspect
-import uuid
-import time
 import json
 
-from typing import Iterator, Generator, AsyncGenerator, Optional
+from typing import Iterator, Generator, AsyncGenerator
 from pydantic import BaseModel
 
 app = FastAPI()
@@ -127,60 +125,58 @@ async def get_status():
     }
 
 
+def get_function_module(pipe_id: str):
+    # Check if function is already loaded
+    if pipe_id not in app.state.FUNCTIONS:
+        function_module, _, _ = load_function_module_by_id(pipe_id)
+        app.state.FUNCTIONS[pipe_id] = function_module
+    else:
+        function_module = app.state.FUNCTIONS[pipe_id]
+
+    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
+        valves = Functions.get_function_valves_by_id(pipe_id)
+        function_module.valves = function_module.Valves(**(valves if valves else {}))
+    return function_module
+
+
 async def get_pipe_models():
     pipes = Functions.get_functions_by_type("pipe", active_only=True)
     pipe_models = []
 
     for pipe in pipes:
-        # Check if function is already loaded
-        if pipe.id not in app.state.FUNCTIONS:
-            function_module, function_type, frontmatter = load_function_module_by_id(
-                pipe.id
-            )
-            app.state.FUNCTIONS[pipe.id] = function_module
-        else:
-            function_module = app.state.FUNCTIONS[pipe.id]
-
-        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
-            valves = Functions.get_function_valves_by_id(pipe.id)
-            function_module.valves = function_module.Valves(
-                **(valves if valves else {})
-            )
+        function_module = get_function_module(pipe.id)
 
         # Check if function is a manifold
-        if hasattr(function_module, "type"):
-            if function_module.type == "manifold":
-                manifold_pipes = []
-
-                # Check if pipes is a function or a list
-                if callable(function_module.pipes):
-                    manifold_pipes = function_module.pipes()
-                else:
-                    manifold_pipes = function_module.pipes
-
-                for p in manifold_pipes:
-                    manifold_pipe_id = f'{pipe.id}.{p["id"]}'
-                    manifold_pipe_name = p["name"]
-
-                    if hasattr(function_module, "name"):
-                        manifold_pipe_name = (
-                            f"{function_module.name}{manifold_pipe_name}"
-                        )
-
-                    pipe_flag = {"type": pipe.type}
-                    if hasattr(function_module, "ChatValves"):
-                        pipe_flag["valves_spec"] = function_module.ChatValves.schema()
-
-                    pipe_models.append(
-                        {
-                            "id": manifold_pipe_id,
-                            "name": manifold_pipe_name,
-                            "object": "model",
-                            "created": pipe.created_at,
-                            "owned_by": "openai",
-                            "pipe": pipe_flag,
-                        }
-                    )
+        if hasattr(function_module, "pipes"):
+            manifold_pipes = []
+
+            # Check if pipes is a function or a list
+            if callable(function_module.pipes):
+                manifold_pipes = function_module.pipes()
+            else:
+                manifold_pipes = function_module.pipes
+
+            for p in manifold_pipes:
+                manifold_pipe_id = f'{pipe.id}.{p["id"]}'
+                manifold_pipe_name = p["name"]
+
+                if hasattr(function_module, "name"):
+                    manifold_pipe_name = f"{function_module.name}{manifold_pipe_name}"
+
+                pipe_flag = {"type": pipe.type}
+                if hasattr(function_module, "ChatValves"):
+                    pipe_flag["valves_spec"] = function_module.ChatValves.schema()
+
+                pipe_models.append(
+                    {
+                        "id": manifold_pipe_id,
+                        "name": manifold_pipe_name,
+                        "object": "model",
+                        "created": pipe.created_at,
+                        "owned_by": "openai",
+                        "pipe": pipe_flag,
+                    }
+                )
         else:
             pipe_flag = {"type": "pipe"}
             if hasattr(function_module, "ChatValves"):
@@ -200,284 +196,211 @@ async def get_pipe_models():
     return pipe_models
 
 
-async def generate_function_chat_completion(form_data, user):
-    model_id = form_data.get("model")
-    model_info = Models.get_model_by_id(model_id)
+async def execute_pipe(pipe, params):
+    if inspect.iscoroutinefunction(pipe):
+        return await pipe(**params)
+    else:
+        return pipe(**params)
 
-    metadata = None
-    if "metadata" in form_data:
-        metadata = form_data["metadata"]
-        del form_data["metadata"]
 
-    __event_emitter__ = None
-    __event_call__ = None
-    __task__ = None
+async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
+    if isinstance(res, str):
+        return res
+    if isinstance(res, Generator):
+        return "".join(map(str, res))
+    if isinstance(res, AsyncGenerator):
+        return "".join([str(stream) async for stream in res])
 
-    if metadata:
-        if (
-            metadata.get("session_id")
-            and metadata.get("chat_id")
-            and metadata.get("message_id")
-        ):
-            __event_emitter__ = await get_event_emitter(metadata)
-            __event_call__ = await get_event_call(metadata)
 
-        if metadata.get("task"):
-            __task__ = metadata.get("task")
+def process_line(form_data: dict, line):
+    if isinstance(line, BaseModel):
+        line = line.model_dump_json()
+        line = f"data: {line}"
+    if isinstance(line, dict):
+        line = f"data: {json.dumps(line)}"
 
-    if model_info:
-        if model_info.base_model_id:
-            form_data["model"] = model_info.base_model_id
+    try:
+        line = line.decode("utf-8")
+    except Exception:
+        pass
 
-        model_info.params = model_info.params.model_dump()
+    if line.startswith("data:"):
+        return f"{line}\n\n"
+    else:
+        line = openai_chat_chunk_message_template(form_data["model"], line)
+        return f"data: {json.dumps(line)}\n\n"
+
+
+def get_pipe_id(form_data: dict) -> str:
+    pipe_id = form_data["model"]
+    if "." in pipe_id:
+        pipe_id, _ = pipe_id.split(".", 1)
+    print(pipe_id)
+    return pipe_id
+
+
+def get_function_params(function_module, form_data, user, extra_params={}):
+    pipe_id = get_pipe_id(form_data)
+    # Get the signature of the function
+    sig = inspect.signature(function_module.pipe)
+    params = {"body": form_data}
+
+    for key, value in extra_params.items():
+        if key in sig.parameters:
+            params[key] = value
+
+    if "__user__" in sig.parameters:
+        __user__ = {
+            "id": user.id,
+            "email": user.email,
+            "name": user.name,
+            "role": user.role,
+        }
+
+        try:
+            if hasattr(function_module, "UserValves"):
+                __user__["valves"] = function_module.UserValves(
+                    **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
+                )
+        except Exception as e:
+            print(e)
 
-        if model_info.params:
-            if model_info.params.get("temperature", None) is not None:
-                form_data["temperature"] = float(model_info.params.get("temperature"))
+        params["__user__"] = __user__
+    return params
 
-            if model_info.params.get("top_p", None):
-                form_data["top_p"] = int(model_info.params.get("top_p", None))
 
-            if model_info.params.get("max_tokens", None):
-                form_data["max_tokens"] = int(model_info.params.get("max_tokens", None))
+# inplace function: form_data is modified
+def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
+    if not params:
+        return form_data
 
-            if model_info.params.get("frequency_penalty", None):
-                form_data["frequency_penalty"] = int(
-                    model_info.params.get("frequency_penalty", None)
-                )
+    mappings = {
+        "temperature": float,
+        "top_p": int,
+        "max_tokens": int,
+        "frequency_penalty": int,
+        "seed": lambda x: x,
+        "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
+    }
 
-            if model_info.params.get("seed", None):
-                form_data["seed"] = model_info.params.get("seed", None)
-
-            if model_info.params.get("stop", None):
-                form_data["stop"] = (
-                    [
-                        bytes(stop, "utf-8").decode("unicode_escape")
-                        for stop in model_info.params["stop"]
-                    ]
-                    if model_info.params.get("stop", None)
-                    else None
-                )
+    for key, cast_func in mappings.items():
+        if (value := params.get(key)) is not None:
+            form_data[key] = cast_func(value)
 
-        system = model_info.params.get("system", None)
-        if system:
-            system = prompt_template(
-                system,
-                **(
-                    {
-                        "user_name": user.name,
-                        "user_location": (
-                            user.info.get("location") if user.info else None
-                        ),
-                    }
-                    if user
-                    else {}
-                ),
-            )
-            # Check if the payload already has a system message
-            # If not, add a system message to the payload
-            if form_data.get("messages"):
-                for message in form_data["messages"]:
-                    if message.get("role") == "system":
-                        message["content"] = system + message["content"]
-                        break
-                else:
-                    form_data["messages"].insert(
-                        0,
-                        {
-                            "role": "system",
-                            "content": system,
-                        },
-                    )
+    return form_data
 
-    else:
-        pass
 
-    async def job():
-        pipe_id = form_data["model"]
-        if "." in pipe_id:
-            pipe_id, sub_pipe_id = pipe_id.split(".", 1)
-        print(pipe_id)
+# inplace function: form_data is modified
+def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
+    system = params.get("system", None)
+    if not system:
+        return form_data
 
-        # Check if function is already loaded
-        if pipe_id not in app.state.FUNCTIONS:
-            function_module, function_type, frontmatter = load_function_module_by_id(
-                pipe_id
-            )
-            app.state.FUNCTIONS[pipe_id] = function_module
-        else:
-            function_module = app.state.FUNCTIONS[pipe_id]
+    if user:
+        template_params = {
+            "user_name": user.name,
+            "user_location": user.info.get("location") if user.info else None,
+        }
+    else:
+        template_params = {}
+    system = prompt_template(system, **template_params)
+    form_data["messages"] = add_or_update_system_message(
+        system, form_data.get("messages", [])
+    )
+    return form_data
 
-        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
 
-            valves = Functions.get_function_valves_by_id(pipe_id)
-            function_module.valves = function_module.Valves(
-                **(valves if valves else {})
-            )
+async def generate_function_chat_completion(form_data, user):
+    model_id = form_data.get("model")
+    model_info = Models.get_model_by_id(model_id)
+    metadata = form_data.pop("metadata", None)
 
-        pipe = function_module.pipe
+    __event_emitter__ = None
+    __event_call__ = None
+    __task__ = None
 
-        # Get the signature of the function
-        sig = inspect.signature(pipe)
-        params = {"body": form_data}
+    if metadata:
+        if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
+            __event_emitter__ = get_event_emitter(metadata)
+            __event_call__ = get_event_call(metadata)
+        __task__ = metadata.get("task", None)
 
-        if "__user__" in sig.parameters:
-            __user__ = {
-                "id": user.id,
-                "email": user.email,
-                "name": user.name,
-                "role": user.role,
-            }
+    if model_info:
+        if model_info.base_model_id:
+            form_data["model"] = model_info.base_model_id
 
-            try:
-                if hasattr(function_module, "UserValves"):
-                    __user__["valves"] = function_module.UserValves(
-                        **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
-                    )
-            except Exception as e:
-                print(e)
+        params = model_info.params.model_dump()
+        form_data = apply_model_params_to_body(params, form_data)
+        form_data = apply_model_system_prompt_to_body(params, form_data, user)
 
-            params = {**params, "__user__": __user__}
+    pipe_id = get_pipe_id(form_data)
+    function_module = get_function_module(pipe_id)
 
-        if "__event_emitter__" in sig.parameters:
-            params = {**params, "__event_emitter__": __event_emitter__}
+    pipe = function_module.pipe
+    params = get_function_params(
+        function_module,
+        form_data,
+        user,
+        {
+            "__event_emitter__": __event_emitter__,
+            "__event_call__": __event_call__,
+            "__task__": __task__,
+        },
+    )
 
-        if "__event_call__" in sig.parameters:
-            params = {**params, "__event_call__": __event_call__}
+    if form_data["stream"]:
 
-        if "__task__" in sig.parameters:
-            params = {**params, "__task__": __task__}
+        async def stream_content():
+            try:
+                res = await execute_pipe(pipe, params)
 
-        if form_data["stream"]:
+                # Directly return if the response is a StreamingResponse
+                if isinstance(res, StreamingResponse):
+                    async for data in res.body_iterator:
+                        yield data
+                    return
+                if isinstance(res, dict):
+                    yield f"data: {json.dumps(res)}\n\n"
+                    return
 
-            async def stream_content():
-                try:
-                    if inspect.iscoroutinefunction(pipe):
-                        res = await pipe(**params)
-                    else:
-                        res = pipe(**params)
+            except Exception as e:
+                print(f"Error: {e}")
+                yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
+                return
 
-                    # Directly return if the response is a StreamingResponse
-                    if isinstance(res, StreamingResponse):
-                        async for data in res.body_iterator:
-                            yield data
-                        return
-                    if isinstance(res, dict):
-                        yield f"data: {json.dumps(res)}\n\n"
-                        return
+            if isinstance(res, str):
+                message = openai_chat_chunk_message_template(form_data["model"], res)
+                yield f"data: {json.dumps(message)}\n\n"
 
-                except Exception as e:
-                    print(f"Error: {e}")
-                    yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
-                    return
+            if isinstance(res, Iterator):
+                for line in res:
+                    yield process_line(form_data, line)
 
-                if isinstance(res, str):
-                    message = stream_message_template(form_data["model"], res)
-                    yield f"data: {json.dumps(message)}\n\n"
-
-                if isinstance(res, Iterator):
-                    for line in res:
-                        if isinstance(line, BaseModel):
-                            line = line.model_dump_json()
-                            line = f"data: {line}"
-                        if isinstance(line, dict):
-                            line = f"data: {json.dumps(line)}"
-
-                        try:
-                            line = line.decode("utf-8")
-                        except:
-                            pass
-
-                        if line.startswith("data:"):
-                            yield f"{line}\n\n"
-                        else:
-                            line = stream_message_template(form_data["model"], line)
-                            yield f"data: {json.dumps(line)}\n\n"
-
-                if isinstance(res, str) or isinstance(res, Generator):
-                    finish_message = {
-                        "id": f"{form_data['model']}-{str(uuid.uuid4())}",
-                        "object": "chat.completion.chunk",
-                        "created": int(time.time()),
-                        "model": form_data["model"],
-                        "choices": [
-                            {
-                                "index": 0,
-                                "delta": {},
-                                "logprobs": None,
-                                "finish_reason": "stop",
-                            }
-                        ],
-                    }
+            if isinstance(res, AsyncGenerator):
+                async for line in res:
+                    yield process_line(form_data, line)
 
-                    yield f"data: {json.dumps(finish_message)}\n\n"
-                    yield f"data: [DONE]"
-
-                if isinstance(res, AsyncGenerator):
-                    async for line in res:
-                        if isinstance(line, BaseModel):
-                            line = line.model_dump_json()
-                            line = f"data: {line}"
-                        if isinstance(line, dict):
-                            line = f"data: {json.dumps(line)}"
-
-                        try:
-                            line = line.decode("utf-8")
-                        except:
-                            pass
-
-                        if line.startswith("data:"):
-                            yield f"{line}\n\n"
-                        else:
-                            line = stream_message_template(form_data["model"], line)
-                            yield f"data: {json.dumps(line)}\n\n"
-
-            return StreamingResponse(stream_content(), media_type="text/event-stream")
-        else:
+            if isinstance(res, str) or isinstance(res, Generator):
+                finish_message = openai_chat_chunk_message_template(
+                    form_data["model"], ""
+                )
+                finish_message["choices"][0]["finish_reason"] = "stop"
+                yield f"data: {json.dumps(finish_message)}\n\n"
+                yield "data: [DONE]"
 
-            try:
-                if inspect.iscoroutinefunction(pipe):
-                    res = await pipe(**params)
-                else:
-                    res = pipe(**params)
+        return StreamingResponse(stream_content(), media_type="text/event-stream")
+    else:
+        try:
+            res = await execute_pipe(pipe, params)
 
-                if isinstance(res, StreamingResponse):
-                    return res
-            except Exception as e:
-                print(f"Error: {e}")
-                return {"error": {"detail": str(e)}}
+        except Exception as e:
+            print(f"Error: {e}")
+            return {"error": {"detail": str(e)}}
 
-            if isinstance(res, dict):
-                return res
-            elif isinstance(res, BaseModel):
-                return res.model_dump()
-            else:
-                message = ""
-                if isinstance(res, str):
-                    message = res
-                elif isinstance(res, Generator):
-                    for stream in res:
-                        message = f"{message}{stream}"
-                elif isinstance(res, AsyncGenerator):
-                    async for stream in res:
-                        message = f"{message}{stream}"
-
-                return {
-                    "id": f"{form_data['model']}-{str(uuid.uuid4())}",
-                    "object": "chat.completion",
-                    "created": int(time.time()),
-                    "model": form_data["model"],
-                    "choices": [
-                        {
-                            "index": 0,
-                            "message": {
-                                "role": "assistant",
-                                "content": message,
-                            },
-                            "logprobs": None,
-                            "finish_reason": "stop",
-                        }
-                    ],
-                }
+        if isinstance(res, StreamingResponse) or isinstance(res, dict):
+            return res
+        if isinstance(res, BaseModel):
+            return res.model_dump()
 
-    return await job()
+        message = await get_message_content(res)
+        return openai_chat_completion_message_template(form_data["model"], message)

+ 2 - 10
backend/apps/webui/models/models.py

@@ -1,13 +1,11 @@
-import json
 import logging
-from typing import Optional
+from typing import Optional, List
 
 from pydantic import BaseModel, ConfigDict
-from sqlalchemy import String, Column, BigInteger, Text
+from sqlalchemy import Column, BigInteger, Text
 
 from apps.webui.internal.db import Base, JSONField, get_db
 
-from typing import List, Union, Optional
 from config import SRC_LOG_LEVELS
 
 import time
@@ -113,7 +111,6 @@ class ModelForm(BaseModel):
 
 
 class ModelsTable:
-
     def insert_new_model(
         self, form_data: ModelForm, user_id: str
     ) -> Optional[ModelModel]:
@@ -126,9 +123,7 @@ class ModelsTable:
             }
         )
         try:
-
             with get_db() as db:
-
                 result = Model(**model.model_dump())
                 db.add(result)
                 db.commit()
@@ -144,13 +139,11 @@ class ModelsTable:
 
     def get_all_models(self) -> List[ModelModel]:
         with get_db() as db:
-
             return [ModelModel.model_validate(model) for model in db.query(Model).all()]
 
     def get_model_by_id(self, id: str) -> Optional[ModelModel]:
         try:
             with get_db() as db:
-
                 model = db.get(Model, id)
                 return ModelModel.model_validate(model)
         except:
@@ -178,7 +171,6 @@ class ModelsTable:
     def delete_model_by_id(self, id: str) -> bool:
         try:
             with get_db() as db:
-
                 db.query(Model).filter_by(id=id).delete()
                 db.commit()
 

+ 16 - 41
backend/main.py

@@ -13,8 +13,6 @@ import aiohttp
 import requests
 import mimetypes
 import shutil
-import os
-import uuid
 import inspect
 
 from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
@@ -29,7 +27,7 @@ from starlette.middleware.sessions import SessionMiddleware
 from starlette.responses import StreamingResponse, Response, RedirectResponse
 
 
-from apps.socket.main import sio, app as socket_app, get_event_emitter, get_event_call
+from apps.socket.main import app as socket_app, get_event_emitter, get_event_call
 from apps.ollama.main import (
     app as ollama_app,
     get_all_models as get_ollama_models,
@@ -619,32 +617,15 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                     content={"detail": str(e)},
                 )
 
-            # Extract valves from the request body
-            valves = None
-            if "valves" in body:
-                valves = body["valves"]
-                del body["valves"]
-
-            # Extract session_id, chat_id and message_id from the request body
-            session_id = None
-            if "session_id" in body:
-                session_id = body["session_id"]
-                del body["session_id"]
-            chat_id = None
-            if "chat_id" in body:
-                chat_id = body["chat_id"]
-                del body["chat_id"]
-            message_id = None
-            if "id" in body:
-                message_id = body["id"]
-                del body["id"]
-
-            __event_emitter__ = await get_event_emitter(
-                {"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
-            )
-            __event_call__ = await get_event_call(
-                {"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
-            )
+            metadata = {
+                "chat_id": body.pop("chat_id", None),
+                "message_id": body.pop("id", None),
+                "session_id": body.pop("session_id", None),
+                "valves": body.pop("valves", None),
+            }
+
+            __event_emitter__ = get_event_emitter(metadata)
+            __event_call__ = get_event_call(metadata)
 
             # Initialize data_items to store additional data to be sent to the client
             data_items = []
@@ -709,13 +690,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             if len(citations) > 0:
                 data_items.append({"citations": citations})
 
-            body["metadata"] = {
-                "session_id": session_id,
-                "chat_id": chat_id,
-                "message_id": message_id,
-                "valves": valves,
-            }
-
+            body["metadata"] = metadata
             modified_body_bytes = json.dumps(body).encode("utf-8")
             # Replace the request body with the modified one
             request._body = modified_body_bytes
@@ -1191,13 +1166,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
                             status_code=r.status_code,
                             content=res,
                         )
-                except:
+                except Exception:
                     pass
 
             else:
                 pass
 
-    __event_emitter__ = await get_event_emitter(
+    __event_emitter__ = get_event_emitter(
         {
             "chat_id": data["chat_id"],
             "message_id": data["id"],
@@ -1205,7 +1180,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
         }
     )
 
-    __event_call__ = await get_event_call(
+    __event_call__ = get_event_call(
         {
             "chat_id": data["chat_id"],
             "message_id": data["id"],
@@ -1334,14 +1309,14 @@ async def chat_completed(
         )
     model = app.state.MODELS[model_id]
 
-    __event_emitter__ = await get_event_emitter(
+    __event_emitter__ = get_event_emitter(
         {
             "chat_id": data["chat_id"],
             "message_id": data["id"],
             "session_id": data["session_id"],
         }
     )
-    __event_call__ = await get_event_call(
+    __event_call__ = get_event_call(
         {
             "chat_id": data["chat_id"],
             "message_id": data["id"],

+ 39 - 32
backend/utils/misc.py

@@ -1,6 +1,5 @@
 from pathlib import Path
 import hashlib
-import json
 import re
 from datetime import timedelta
 from typing import Optional, List, Tuple
@@ -8,37 +7,39 @@ import uuid
 import time
 
 
-def get_last_user_message_item(messages: List[dict]) -> str:
+def get_last_user_message_item(messages: List[dict]) -> Optional[dict]:
     for message in reversed(messages):
         if message["role"] == "user":
             return message
     return None
 
 
-def get_last_user_message(messages: List[dict]) -> str:
-    message = get_last_user_message_item(messages)
-
-    if message is not None:
-        if isinstance(message["content"], list):
-            for item in message["content"]:
-                if item["type"] == "text":
-                    return item["text"]
+def get_content_from_message(message: dict) -> Optional[str]:
+    if isinstance(message["content"], list):
+        for item in message["content"]:
+            if item["type"] == "text":
+                return item["text"]
+    else:
         return message["content"]
     return None
 
 
-def get_last_assistant_message(messages: List[dict]) -> str:
+def get_last_user_message(messages: List[dict]) -> Optional[str]:
+    message = get_last_user_message_item(messages)
+    if message is None:
+        return None
+
+    return get_content_from_message(message)
+
+
+def get_last_assistant_message(messages: List[dict]) -> Optional[str]:
     for message in reversed(messages):
         if message["role"] == "assistant":
-            if isinstance(message["content"], list):
-                for item in message["content"]:
-                    if item["type"] == "text":
-                        return item["text"]
-            return message["content"]
+            return get_content_from_message(message)
     return None
 
 
-def get_system_message(messages: List[dict]) -> dict:
+def get_system_message(messages: List[dict]) -> Optional[dict]:
     for message in messages:
         if message["role"] == "system":
             return message
@@ -49,7 +50,7 @@ def remove_system_message(messages: List[dict]) -> List[dict]:
     return [message for message in messages if message["role"] != "system"]
 
 
-def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]:
+def pop_system_message(messages: List[dict]) -> Tuple[Optional[dict], List[dict]]:
     return get_system_message(messages), remove_system_message(messages)
 
 
@@ -87,23 +88,29 @@ def add_or_update_system_message(content: str, messages: List[dict]):
     return messages
 
 
-def stream_message_template(model: str, message: str):
+def openai_chat_message_template(model: str):
     return {
         "id": f"{model}-{str(uuid.uuid4())}",
-        "object": "chat.completion.chunk",
         "created": int(time.time()),
         "model": model,
-        "choices": [
-            {
-                "index": 0,
-                "delta": {"content": message},
-                "logprobs": None,
-                "finish_reason": None,
-            }
-        ],
+        "choices": [{"index": 0, "logprobs": None, "finish_reason": None}],
     }
 
 
+def openai_chat_chunk_message_template(model: str, message: str):
+    template = openai_chat_message_template(model)
+    template["object"] = "chat.completion.chunk"
+    template["choices"][0]["delta"] = {"content": message}
+    return template
+
+
+def openai_chat_completion_message_template(model: str, message: str):
+    template = openai_chat_message_template(model)
+    template["object"] = "chat.completion"
+    template["choices"][0]["message"] = {"content": message, "role": "assistant"}
+    template["choices"][0]["finish_reason"] = "stop"
+
+
 def get_gravatar_url(email):
     # Trim leading and trailing whitespace from
     # an email address and force all characters
@@ -174,7 +181,7 @@ def extract_folders_after_data_docs(path):
     tags = []
 
     folders = parts[index_docs:-1]
-    for idx, part in enumerate(folders):
+    for idx, _ in enumerate(folders):
         tags.append("/".join(folders[: idx + 1]))
 
     return tags
@@ -270,11 +277,11 @@ def parse_ollama_modelfile(model_text):
             value = param_match.group(1)
 
             try:
-                if param_type == int:
+                if param_type is int:
                     value = int(value)
-                elif param_type == float:
+                elif param_type is float:
                     value = float(value)
-                elif param_type == bool:
+                elif param_type is bool:
                     value = value.lower() == "true"
             except Exception as e:
                 print(e)