فهرست منبع

refac: apps/openai/main.py and utils

Michael Poluektov 9 ماه پیش
والد
کامیت
12c21fac22
8 فایلهای تغییر یافته به همراه 149 افزوده شده و 231 حذف شده
  1. 60 134
      backend/apps/openai/main.py
  2. 15 12
      backend/apps/socket/main.py
  3. 2 43
      backend/apps/webui/main.py
  4. 13 18
      backend/apps/webui/routers/tools.py
  5. 12 16
      backend/main.py
  6. 43 0
      backend/utils/misc.py
  7. 1 2
      backend/utils/task.py
  8. 3 6
      backend/utils/utils.py

+ 60 - 134
backend/apps/openai/main.py

@@ -1,6 +1,6 @@
-from fastapi import FastAPI, Request, Response, HTTPException, Depends
+from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
+from fastapi.responses import StreamingResponse, FileResponse
 
 
 import requests
 import requests
 import aiohttp
 import aiohttp
@@ -12,16 +12,12 @@ from pydantic import BaseModel
 from starlette.background import BackgroundTask
 from starlette.background import BackgroundTask
 
 
 from apps.webui.models.models import Models
 from apps.webui.models.models import Models
-from apps.webui.models.users import Users
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
 from utils.utils import (
 from utils.utils import (
-    decode_token,
-    get_verified_user,
     get_verified_user,
     get_verified_user,
     get_admin_user,
     get_admin_user,
 )
 )
-from utils.task import prompt_template
-from utils.misc import add_or_update_system_message
+from utils.misc import apply_model_params_to_body, apply_model_system_prompt_to_body
 
 
 from config import (
 from config import (
     SRC_LOG_LEVELS,
     SRC_LOG_LEVELS,
@@ -69,8 +65,6 @@ app.state.MODELS = {}
 async def check_url(request: Request, call_next):
 async def check_url(request: Request, call_next):
     if len(app.state.MODELS) == 0:
     if len(app.state.MODELS) == 0:
         await get_all_models()
         await get_all_models()
-    else:
-        pass
 
 
     response = await call_next(request)
     response = await call_next(request)
     return response
     return response
@@ -175,7 +169,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
                     res = r.json()
                     res = r.json()
                     if "error" in res:
                     if "error" in res:
                         error_detail = f"External: {res['error']}"
                         error_detail = f"External: {res['error']}"
-                except:
+                except Exception:
                     error_detail = f"External: {e}"
                     error_detail = f"External: {e}"
 
 
             raise HTTPException(
             raise HTTPException(
@@ -234,64 +228,58 @@ def merge_models_lists(model_lists):
     return merged_list
     return merged_list
 
 
 
 
-async def get_all_models(raw: bool = False):
+def is_openai_api_disabled():
+    api_keys = app.state.config.OPENAI_API_KEYS
+    no_keys = len(api_keys) == 1 and api_keys[0] == ""
+    return no_keys or not app.state.config.ENABLE_OPENAI_API
+
+
+async def get_all_models_raw() -> list:
+    if is_openai_api_disabled():
+        return []
+
+    # Check if API KEYS length is same than API URLS length
+    num_urls = len(app.state.config.OPENAI_API_BASE_URLS)
+    num_keys = len(app.state.config.OPENAI_API_KEYS)
+
+    if num_keys != num_urls:
+        # if there are more keys than urls, remove the extra keys
+        if num_keys > num_urls:
+            new_keys = app.state.config.OPENAI_API_KEYS[:num_urls]
+            app.state.config.OPENAI_API_KEYS = new_keys
+        # if there are more urls than keys, add empty keys
+        else:
+            app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
+
+    tasks = [
+        fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
+        for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
+    ]
+
+    responses = await asyncio.gather(*tasks)
+    log.debug(f"get_all_models:responses() {responses}")
+
+    return responses
+
+
+async def get_all_models() -> dict[str, list]:
     log.info("get_all_models()")
     log.info("get_all_models()")
+    if is_openai_api_disabled():
+        return {"data": []}
 
 
-    if (
-        len(app.state.config.OPENAI_API_KEYS) == 1
-        and app.state.config.OPENAI_API_KEYS[0] == ""
-    ) or not app.state.config.ENABLE_OPENAI_API:
-        models = {"data": []}
-    else:
-        # Check if API KEYS length is same than API URLS length
-        if len(app.state.config.OPENAI_API_KEYS) != len(
-            app.state.config.OPENAI_API_BASE_URLS
-        ):
-            # if there are more keys than urls, remove the extra keys
-            if len(app.state.config.OPENAI_API_KEYS) > len(
-                app.state.config.OPENAI_API_BASE_URLS
-            ):
-                app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
-                    : len(app.state.config.OPENAI_API_BASE_URLS)
-                ]
-            # if there are more urls than keys, add empty keys
-            else:
-                app.state.config.OPENAI_API_KEYS += [
-                    ""
-                    for _ in range(
-                        len(app.state.config.OPENAI_API_BASE_URLS)
-                        - len(app.state.config.OPENAI_API_KEYS)
-                    )
-                ]
+    responses = await get_all_models_raw()
 
 
-        tasks = [
-            fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
-            for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
-        ]
-
-        responses = await asyncio.gather(*tasks)
-        log.debug(f"get_all_models:responses() {responses}")
-
-        if raw:
-            return responses
-
-        models = {
-            "data": merge_models_lists(
-                list(
-                    map(
-                        lambda response: (
-                            response["data"]
-                            if (response and "data" in response)
-                            else (response if isinstance(response, list) else None)
-                        ),
-                        responses,
-                    )
-                )
-            )
-        }
+    def extract_data(response):
+        if response and "data" in response:
+            return response["data"]
+        if isinstance(response, list):
+            return response
+        return None
 
 
-        log.debug(f"models: {models}")
-        app.state.MODELS = {model["id"]: model for model in models["data"]}
+    models = {"data": merge_models_lists(map(extract_data, responses))}
+
+    log.debug(f"models: {models}")
+    app.state.MODELS = {model["id"]: model for model in models["data"]}
 
 
     return models
     return models
 
 
@@ -299,7 +287,7 @@ async def get_all_models(raw: bool = False):
 @app.get("/models")
 @app.get("/models")
 @app.get("/models/{url_idx}")
 @app.get("/models/{url_idx}")
 async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
 async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
-    if url_idx == None:
+    if url_idx is None:
         models = await get_all_models()
         models = await get_all_models()
         if app.state.config.ENABLE_MODEL_FILTER:
         if app.state.config.ENABLE_MODEL_FILTER:
             if user.role == "user":
             if user.role == "user":
@@ -340,7 +328,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
                     res = r.json()
                     res = r.json()
                     if "error" in res:
                     if "error" in res:
                         error_detail = f"External: {res['error']}"
                         error_detail = f"External: {res['error']}"
-                except:
+                except Exception:
                     error_detail = f"External: {e}"
                     error_detail = f"External: {e}"
 
 
             raise HTTPException(
             raise HTTPException(
@@ -358,8 +346,7 @@ async def generate_chat_completion(
 ):
 ):
     idx = 0
     idx = 0
     payload = {**form_data}
     payload = {**form_data}
-    if "metadata" in payload:
-        del payload["metadata"]
+    payload.pop("metadata")
 
 
     model_id = form_data.get("model")
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
     model_info = Models.get_model_by_id(model_id)
@@ -368,70 +355,9 @@ async def generate_chat_completion(
         if model_info.base_model_id:
         if model_info.base_model_id:
             payload["model"] = model_info.base_model_id
             payload["model"] = model_info.base_model_id
 
 
-        model_info.params = model_info.params.model_dump()
-
-        if model_info.params:
-            if (
-                model_info.params.get("temperature", None) is not None
-                and payload.get("temperature") is None
-            ):
-                payload["temperature"] = float(model_info.params.get("temperature"))
-
-            if model_info.params.get("top_p", None) and payload.get("top_p") is None:
-                payload["top_p"] = int(model_info.params.get("top_p", None))
-
-            if (
-                model_info.params.get("max_tokens", None)
-                and payload.get("max_tokens") is None
-            ):
-                payload["max_tokens"] = int(model_info.params.get("max_tokens", None))
-
-            if (
-                model_info.params.get("frequency_penalty", None)
-                and payload.get("frequency_penalty") is None
-            ):
-                payload["frequency_penalty"] = int(
-                    model_info.params.get("frequency_penalty", None)
-                )
-
-            if (
-                model_info.params.get("seed", None) is not None
-                and payload.get("seed") is None
-            ):
-                payload["seed"] = model_info.params.get("seed", None)
-
-            if model_info.params.get("stop", None) and payload.get("stop") is None:
-                payload["stop"] = (
-                    [
-                        bytes(stop, "utf-8").decode("unicode_escape")
-                        for stop in model_info.params["stop"]
-                    ]
-                    if model_info.params.get("stop", None)
-                    else None
-                )
-
-        system = model_info.params.get("system", None)
-        if system:
-            system = prompt_template(
-                system,
-                **(
-                    {
-                        "user_name": user.name,
-                        "user_location": (
-                            user.info.get("location") if user.info else None
-                        ),
-                    }
-                    if user
-                    else {}
-                ),
-            )
-            if payload.get("messages"):
-                payload["messages"] = add_or_update_system_message(
-                    system, payload["messages"]
-                )
-
-    else:
-        pass
+        params = model_info.params.model_dump()
+        payload = apply_model_params_to_body(params, payload)
+        payload = apply_model_system_prompt_to_body(params, payload, user)
 
 
     model = app.state.MODELS[payload.get("model")]
     model = app.state.MODELS[payload.get("model")]
     idx = model["urlIdx"]
     idx = model["urlIdx"]
@@ -506,7 +432,7 @@ async def generate_chat_completion(
                 print(res)
                 print(res)
                 if "error" in res:
                 if "error" in res:
                     error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
                     error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
-            except:
+            except Exception:
                 error_detail = f"External: {e}"
                 error_detail = f"External: {e}"
         raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
         raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
     finally:
     finally:
@@ -569,7 +495,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
                 print(res)
                 print(res)
                 if "error" in res:
                 if "error" in res:
                     error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
                     error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
-            except:
+            except Exception:
                 error_detail = f"External: {e}"
                 error_detail = f"External: {e}"
         raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
         raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
     finally:
     finally:

+ 15 - 12
backend/apps/socket/main.py

@@ -44,23 +44,26 @@ async def user_join(sid, data):
     print("user-join", sid, data)
     print("user-join", sid, data)
 
 
     auth = data["auth"] if "auth" in data else None
     auth = data["auth"] if "auth" in data else None
+    if not auth or "token" not in auth:
+        return
 
 
-    if auth and "token" in auth:
-        data = decode_token(auth["token"])
+    data = decode_token(auth["token"])
+    if data is None or "id" not in data:
+        return
 
 
-        if data is not None and "id" in data:
-            user = Users.get_user_by_id(data["id"])
+    user = Users.get_user_by_id(data["id"])
+    if not user:
+        return
 
 
-        if user:
-            SESSION_POOL[sid] = user.id
-            if user.id in USER_POOL:
-                USER_POOL[user.id].append(sid)
-            else:
-                USER_POOL[user.id] = [sid]
+    SESSION_POOL[sid] = user.id
+    if user.id in USER_POOL:
+        USER_POOL[user.id].append(sid)
+    else:
+        USER_POOL[user.id] = [sid]
 
 
-            print(f"user {user.name}({user.id}) connected with session ID {sid}")
+    print(f"user {user.name}({user.id}) connected with session ID {sid}")
 
 
-            await sio.emit("user-count", {"count": len(set(USER_POOL))})
+    await sio.emit("user-count", {"count": len(set(USER_POOL))})
 
 
 
 
 @sio.on("user-count")
 @sio.on("user-count")

+ 2 - 43
backend/apps/webui/main.py

@@ -22,9 +22,9 @@ from apps.webui.utils import load_function_module_by_id
 from utils.misc import (
 from utils.misc import (
     openai_chat_chunk_message_template,
     openai_chat_chunk_message_template,
     openai_chat_completion_message_template,
     openai_chat_completion_message_template,
-    add_or_update_system_message,
+    apply_model_params_to_body,
+    apply_model_system_prompt_to_body,
 )
 )
-from utils.task import prompt_template
 
 
 
 
 from config import (
 from config import (
@@ -269,47 +269,6 @@ def get_function_params(function_module, form_data, user, extra_params={}):
     return params
     return params
 
 
 
 
-# inplace function: form_data is modified
-def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
-    if not params:
-        return form_data
-
-    mappings = {
-        "temperature": float,
-        "top_p": int,
-        "max_tokens": int,
-        "frequency_penalty": int,
-        "seed": lambda x: x,
-        "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
-    }
-
-    for key, cast_func in mappings.items():
-        if (value := params.get(key)) is not None:
-            form_data[key] = cast_func(value)
-
-    return form_data
-
-
-# inplace function: form_data is modified
-def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
-    system = params.get("system", None)
-    if not system:
-        return form_data
-
-    if user:
-        template_params = {
-            "user_name": user.name,
-            "user_location": user.info.get("location") if user.info else None,
-        }
-    else:
-        template_params = {}
-    system = prompt_template(system, **template_params)
-    form_data["messages"] = add_or_update_system_message(
-        system, form_data.get("messages", [])
-    )
-    return form_data
-
-
 async def generate_function_chat_completion(form_data, user):
 async def generate_function_chat_completion(form_data, user):
     model_id = form_data.get("model")
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
     model_info = Models.get_model_by_id(model_id)

+ 13 - 18
backend/apps/webui/routers/tools.py

@@ -1,12 +1,8 @@
-from fastapi import Depends, FastAPI, HTTPException, status, Request
-from datetime import datetime, timedelta
-from typing import List, Union, Optional
+from fastapi import Depends, HTTPException, status, Request
+from typing import List, Optional
 
 
 from fastapi import APIRouter
 from fastapi import APIRouter
-from pydantic import BaseModel
-import json
 
 
-from apps.webui.models.users import Users
 from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
 from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
 from apps.webui.utils import load_toolkit_module_by_id
 from apps.webui.utils import load_toolkit_module_by_id
 
 
@@ -14,7 +10,6 @@ from utils.utils import get_admin_user, get_verified_user
 from utils.tools import get_tools_specs
 from utils.tools import get_tools_specs
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
 
 
-from importlib import util
 import os
 import os
 from pathlib import Path
 from pathlib import Path
 
 
@@ -69,7 +64,7 @@ async def create_new_toolkit(
     form_data.id = form_data.id.lower()
     form_data.id = form_data.id.lower()
 
 
     toolkit = Tools.get_tool_by_id(form_data.id)
     toolkit = Tools.get_tool_by_id(form_data.id)
-    if toolkit == None:
+    if toolkit is None:
         toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
         toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
         try:
         try:
             with open(toolkit_path, "w") as tool_file:
             with open(toolkit_path, "w") as tool_file:
@@ -98,7 +93,7 @@ async def create_new_toolkit(
             print(e)
             print(e)
             raise HTTPException(
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
                 status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.DEFAULT(e),
+                detail=ERROR_MESSAGES.DEFAULT(str(e)),
             )
             )
     else:
     else:
         raise HTTPException(
         raise HTTPException(
@@ -170,7 +165,7 @@ async def update_toolkit_by_id(
     except Exception as e:
     except Exception as e:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.DEFAULT(e),
+            detail=ERROR_MESSAGES.DEFAULT(str(e)),
         )
         )
 
 
 
 
@@ -210,7 +205,7 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
         except Exception as e:
         except Exception as e:
             raise HTTPException(
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
                 status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.DEFAULT(e),
+                detail=ERROR_MESSAGES.DEFAULT(str(e)),
             )
             )
     else:
     else:
         raise HTTPException(
         raise HTTPException(
@@ -233,7 +228,7 @@ async def get_toolkit_valves_spec_by_id(
         if id in request.app.state.TOOLS:
         if id in request.app.state.TOOLS:
             toolkit_module = request.app.state.TOOLS[id]
             toolkit_module = request.app.state.TOOLS[id]
         else:
         else:
-            toolkit_module, frontmatter = load_toolkit_module_by_id(id)
+            toolkit_module, _ = load_toolkit_module_by_id(id)
             request.app.state.TOOLS[id] = toolkit_module
             request.app.state.TOOLS[id] = toolkit_module
 
 
         if hasattr(toolkit_module, "Valves"):
         if hasattr(toolkit_module, "Valves"):
@@ -261,7 +256,7 @@ async def update_toolkit_valves_by_id(
         if id in request.app.state.TOOLS:
         if id in request.app.state.TOOLS:
             toolkit_module = request.app.state.TOOLS[id]
             toolkit_module = request.app.state.TOOLS[id]
         else:
         else:
-            toolkit_module, frontmatter = load_toolkit_module_by_id(id)
+            toolkit_module, _ = load_toolkit_module_by_id(id)
             request.app.state.TOOLS[id] = toolkit_module
             request.app.state.TOOLS[id] = toolkit_module
 
 
         if hasattr(toolkit_module, "Valves"):
         if hasattr(toolkit_module, "Valves"):
@@ -276,7 +271,7 @@ async def update_toolkit_valves_by_id(
                 print(e)
                 print(e)
                 raise HTTPException(
                 raise HTTPException(
                     status_code=status.HTTP_400_BAD_REQUEST,
                     status_code=status.HTTP_400_BAD_REQUEST,
-                    detail=ERROR_MESSAGES.DEFAULT(e),
+                    detail=ERROR_MESSAGES.DEFAULT(str(e)),
                 )
                 )
         else:
         else:
             raise HTTPException(
             raise HTTPException(
@@ -306,7 +301,7 @@ async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)
         except Exception as e:
         except Exception as e:
             raise HTTPException(
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
                 status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.DEFAULT(e),
+                detail=ERROR_MESSAGES.DEFAULT(str(e)),
             )
             )
     else:
     else:
         raise HTTPException(
         raise HTTPException(
@@ -324,7 +319,7 @@ async def get_toolkit_user_valves_spec_by_id(
         if id in request.app.state.TOOLS:
         if id in request.app.state.TOOLS:
             toolkit_module = request.app.state.TOOLS[id]
             toolkit_module = request.app.state.TOOLS[id]
         else:
         else:
-            toolkit_module, frontmatter = load_toolkit_module_by_id(id)
+            toolkit_module, _ = load_toolkit_module_by_id(id)
             request.app.state.TOOLS[id] = toolkit_module
             request.app.state.TOOLS[id] = toolkit_module
 
 
         if hasattr(toolkit_module, "UserValves"):
         if hasattr(toolkit_module, "UserValves"):
@@ -348,7 +343,7 @@ async def update_toolkit_user_valves_by_id(
         if id in request.app.state.TOOLS:
         if id in request.app.state.TOOLS:
             toolkit_module = request.app.state.TOOLS[id]
             toolkit_module = request.app.state.TOOLS[id]
         else:
         else:
-            toolkit_module, frontmatter = load_toolkit_module_by_id(id)
+            toolkit_module, _ = load_toolkit_module_by_id(id)
             request.app.state.TOOLS[id] = toolkit_module
             request.app.state.TOOLS[id] = toolkit_module
 
 
         if hasattr(toolkit_module, "UserValves"):
         if hasattr(toolkit_module, "UserValves"):
@@ -365,7 +360,7 @@ async def update_toolkit_user_valves_by_id(
                 print(e)
                 print(e)
                 raise HTTPException(
                 raise HTTPException(
                     status_code=status.HTTP_400_BAD_REQUEST,
                     status_code=status.HTTP_400_BAD_REQUEST,
-                    detail=ERROR_MESSAGES.DEFAULT(e),
+                    detail=ERROR_MESSAGES.DEFAULT(str(e)),
                 )
                 )
         else:
         else:
             raise HTTPException(
             raise HTTPException(

+ 12 - 16
backend/main.py

@@ -36,6 +36,7 @@ from apps.ollama.main import (
 from apps.openai.main import (
 from apps.openai.main import (
     app as openai_app,
     app as openai_app,
     get_all_models as get_openai_models,
     get_all_models as get_openai_models,
+    get_all_models_raw as get_openai_models_raw,
     generate_chat_completion as generate_openai_chat_completion,
     generate_chat_completion as generate_openai_chat_completion,
 )
 )
 
 
@@ -957,7 +958,7 @@ async def get_all_models():
 
 
     custom_models = Models.get_all_models()
     custom_models = Models.get_all_models()
     for custom_model in custom_models:
     for custom_model in custom_models:
-        if custom_model.base_model_id == None:
+        if custom_model.base_model_id is None:
             for model in models:
             for model in models:
                 if (
                 if (
                     custom_model.id == model["id"]
                     custom_model.id == model["id"]
@@ -1656,13 +1657,13 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
 
 
 @app.get("/api/pipelines/list")
 @app.get("/api/pipelines/list")
 async def get_pipelines_list(user=Depends(get_admin_user)):
 async def get_pipelines_list(user=Depends(get_admin_user)):
-    responses = await get_openai_models(raw=True)
+    responses = await get_openai_models_raw()
 
 
     print(responses)
     print(responses)
     urlIdxs = [
     urlIdxs = [
         idx
         idx
         for idx, response in enumerate(responses)
         for idx, response in enumerate(responses)
-        if response != None and "pipelines" in response
+        if response is not None and "pipelines" in response
     ]
     ]
 
 
     return {
     return {
@@ -1723,7 +1724,7 @@ async def upload_pipeline(
                 res = r.json()
                 res = r.json()
                 if "detail" in res:
                 if "detail" in res:
                     detail = res["detail"]
                     detail = res["detail"]
-            except:
+            except Exception:
                 pass
                 pass
 
 
         raise HTTPException(
         raise HTTPException(
@@ -1769,7 +1770,7 @@ async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user))
                 res = r.json()
                 res = r.json()
                 if "detail" in res:
                 if "detail" in res:
                     detail = res["detail"]
                     detail = res["detail"]
-            except:
+            except Exception:
                 pass
                 pass
 
 
         raise HTTPException(
         raise HTTPException(
@@ -1811,7 +1812,7 @@ async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_
                 res = r.json()
                 res = r.json()
                 if "detail" in res:
                 if "detail" in res:
                     detail = res["detail"]
                     detail = res["detail"]
-            except:
+            except Exception:
                 pass
                 pass
 
 
         raise HTTPException(
         raise HTTPException(
@@ -1844,7 +1845,7 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
                 res = r.json()
                 res = r.json()
                 if "detail" in res:
                 if "detail" in res:
                     detail = res["detail"]
                     detail = res["detail"]
-            except:
+            except Exception:
                 pass
                 pass
 
 
         raise HTTPException(
         raise HTTPException(
@@ -1859,7 +1860,6 @@ async def get_pipeline_valves(
     pipeline_id: str,
     pipeline_id: str,
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),
 ):
 ):
-    models = await get_all_models()
     r = None
     r = None
     try:
     try:
         url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
         url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
@@ -1898,8 +1898,6 @@ async def get_pipeline_valves_spec(
     pipeline_id: str,
     pipeline_id: str,
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),
 ):
 ):
-    models = await get_all_models()
-
     r = None
     r = None
     try:
     try:
         url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
         url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
@@ -1922,7 +1920,7 @@ async def get_pipeline_valves_spec(
                 res = r.json()
                 res = r.json()
                 if "detail" in res:
                 if "detail" in res:
                     detail = res["detail"]
                     detail = res["detail"]
-            except:
+            except Exception:
                 pass
                 pass
 
 
         raise HTTPException(
         raise HTTPException(
@@ -1938,8 +1936,6 @@ async def update_pipeline_valves(
     form_data: dict,
     form_data: dict,
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),
 ):
 ):
-    models = await get_all_models()
-
     r = None
     r = None
     try:
     try:
         url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
         url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
@@ -1967,7 +1963,7 @@ async def update_pipeline_valves(
                 res = r.json()
                 res = r.json()
                 if "detail" in res:
                 if "detail" in res:
                     detail = res["detail"]
                     detail = res["detail"]
-            except:
+            except Exception:
                 pass
                 pass
 
 
         raise HTTPException(
         raise HTTPException(
@@ -2068,7 +2064,7 @@ async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
 
 
 
 
 @app.get("/api/version")
 @app.get("/api/version")
-async def get_app_config():
+async def get_app_version():
     return {
     return {
         "version": VERSION,
         "version": VERSION,
     }
     }
@@ -2091,7 +2087,7 @@ async def get_app_latest_release_version():
                 latest_version = data["tag_name"]
                 latest_version = data["tag_name"]
 
 
                 return {"current": VERSION, "latest": latest_version[1:]}
                 return {"current": VERSION, "latest": latest_version[1:]}
-    except aiohttp.ClientError as e:
+    except aiohttp.ClientError:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
             status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
             detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED,
             detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED,

+ 43 - 0
backend/utils/misc.py

@@ -6,6 +6,8 @@ from typing import Optional, List, Tuple
 import uuid
 import uuid
 import time
 import time
 
 
+from utils.task import prompt_template
+
 
 
 def get_last_user_message_item(messages: List[dict]) -> Optional[dict]:
 def get_last_user_message_item(messages: List[dict]) -> Optional[dict]:
     for message in reversed(messages):
     for message in reversed(messages):
@@ -111,6 +113,47 @@ def openai_chat_completion_message_template(model: str, message: str):
     template["choices"][0]["finish_reason"] = "stop"
     template["choices"][0]["finish_reason"] = "stop"
 
 
 
 
# In-place helper: form_data is mutated and the same object is returned.
def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
    """Apply a model's configured system prompt to a request body.

    Args:
        params: Model parameters; only the "system" entry is read here.
        form_data: Request body whose "messages" entry is replaced with the
            result of add_or_update_system_message.
        user: Optional user object. When truthy, its ``name`` and the
            "location" entry of its ``info`` mapping (if ``info`` is set)
            are passed to prompt_template as template values.

    Returns:
        The same ``form_data`` object, unchanged when no system prompt is set.
    """
    system_prompt = params.get("system", None)
    if system_prompt:
        # Without a user, the template is rendered with no user values.
        variables = {}
        if user:
            variables = {
                "user_name": user.name,
                "user_location": user.info.get("location") if user.info else None,
            }
        rendered = prompt_template(system_prompt, **variables)
        current_messages = form_data.get("messages", [])
        form_data["messages"] = add_or_update_system_message(
            rendered, current_messages
        )
    return form_data
+
+
# In-place helper: form_data is mutated and the same object is returned.
def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
    """Copy supported model parameters into a request body, coercing types.

    Only the keys listed in the cast table below are considered; a key is
    copied when its value in ``params`` is not None, overwriting any value
    already present in ``form_data``.

    Args:
        params: Model parameters. A falsy value (None or empty dict) leaves
            ``form_data`` untouched.
        form_data: Request body to update.

    Returns:
        The same ``form_data`` object.

    Raises:
        ValueError / TypeError: if a value cannot be converted by its cast.
    """
    if not params:
        return form_data

    casts = {
        "temperature": float,
        # top_p and frequency_penalty are fractional; casting them with int
        # would truncate e.g. 0.9 to 0 and reject strings such as "0.9".
        "top_p": float,
        "max_tokens": int,
        "frequency_penalty": float,
        "seed": lambda x: x,
        # Stop sequences may contain literal escapes such as "\\n"; decode
        # them into the characters they denote.
        "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
    }

    for key, cast_func in casts.items():
        if (value := params.get(key)) is not None:
            form_data[key] = cast_func(value)

    return form_data
+
+
 def get_gravatar_url(email):
 def get_gravatar_url(email):
     # Trim leading and trailing whitespace from
     # Trim leading and trailing whitespace from
     # an email address and force all characters
     # an email address and force all characters

+ 1 - 2
backend/utils/task.py

@@ -6,7 +6,7 @@ from typing import Optional
 
 
 
 
 def prompt_template(
 def prompt_template(
-    template: str, user_name: str = None, user_location: str = None
+    template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
 ) -> str:
 ) -> str:
     # Get the current date
     # Get the current date
     current_date = datetime.now()
     current_date = datetime.now()
@@ -83,7 +83,6 @@ def title_generation_template(
 def search_query_generation_template(
 def search_query_generation_template(
     template: str, prompt: str, user: Optional[dict] = None
     template: str, prompt: str, user: Optional[dict] = None
 ) -> str:
 ) -> str:
-
     def replacement_function(match):
     def replacement_function(match):
         full_match = match.group(0)
         full_match = match.group(0)
         start_length = match.group(1)
         start_length = match.group(1)

+ 3 - 6
backend/utils/utils.py

@@ -1,15 +1,12 @@
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 from fastapi import HTTPException, status, Depends, Request
 from fastapi import HTTPException, status, Depends, Request
-from sqlalchemy.orm import Session
 
 
 from apps.webui.models.users import Users
 from apps.webui.models.users import Users
 
 
-from pydantic import BaseModel
 from typing import Union, Optional
 from typing import Union, Optional
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
 from passlib.context import CryptContext
 from passlib.context import CryptContext
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
-import requests
 import jwt
 import jwt
 import uuid
 import uuid
 import logging
 import logging
@@ -54,7 +51,7 @@ def decode_token(token: str) -> Optional[dict]:
     try:
     try:
         decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
         decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
         return decoded
         return decoded
-    except Exception as e:
+    except Exception:
         return None
         return None
 
 
 
 
@@ -71,7 +68,7 @@ def get_http_authorization_cred(auth_header: str):
     try:
     try:
         scheme, credentials = auth_header.split(" ")
         scheme, credentials = auth_header.split(" ")
         return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
         return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
-    except:
+    except Exception:
         raise ValueError(ERROR_MESSAGES.INVALID_TOKEN)
         raise ValueError(ERROR_MESSAGES.INVALID_TOKEN)
 
 
 
 
@@ -96,7 +93,7 @@ def get_current_user(
 
 
     # auth by jwt token
     # auth by jwt token
     data = decode_token(token)
     data = decode_token(token)
-    if data != None and "id" in data:
+    if data is not None and "id" in data:
         user = Users.get_user_by_id(data["id"])
         user = Users.get_user_by_id(data["id"])
         if user is None:
         if user is None:
             raise HTTPException(
             raise HTTPException(