Browse source

Merge pull request #1653 from open-webui/litellm-as-subprocess

fix: litellm as subprocess
Timothy Jaeryang Baek 1 year ago
parent
commit
5997774ab8

+ 280 - 54
backend/apps/litellm/main.py

@@ -1,100 +1,326 @@
-import logging
-
-from litellm.proxy.proxy_server import ProxyConfig, initialize
-from litellm.proxy.proxy_server import app
+from fastapi import FastAPI, Depends, HTTPException
+from fastapi.routing import APIRoute
+from fastapi.middleware.cors import CORSMiddleware
 
 
+import logging
 from fastapi import FastAPI, Request, Depends, status, Response
 from fastapi.responses import JSONResponse

 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
 from starlette.responses import StreamingResponse
 import json
+import time
+import requests
+
+from pydantic import BaseModel
+from typing import Optional, List
 
 
-from utils.utils import get_http_authorization_cred, get_current_user
+from utils.utils import get_verified_user, get_current_user, get_admin_user
 from config import SRC_LOG_LEVELS, ENV
+from constants import MESSAGES

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["LITELLM"])


-from config import (
-    MODEL_FILTER_ENABLED,
-    MODEL_FILTER_LIST,
+from config import MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, DATA_DIR
+
+
+import asyncio
+import subprocess
+import yaml
+
+app = FastAPI()
+
+origins = ["*"]
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
 )


-proxy_config = ProxyConfig()
+LITELLM_CONFIG_DIR = f"{DATA_DIR}/litellm/config.yaml"
+
+with open(LITELLM_CONFIG_DIR, "r") as file:
+    litellm_config = yaml.safe_load(file)
+
+app.state.CONFIG = litellm_config
+
+# Global variable to store the subprocess reference
+background_process = None


-async def config():
-    router, model_list, general_settings = await proxy_config.load_config(
-        router=None, config_file_path="./data/litellm/config.yaml"
+async def run_background_process(command):
+    global background_process
+    log.info("run_background_process")
+
+    try:
+        # Log the command to be executed
+        log.info(f"Executing command: {command}")
+        # Execute the command and create a subprocess
+        process = await asyncio.create_subprocess_exec(
+            *command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE
+        )
+        background_process = process
+        log.info("Subprocess started successfully.")
+
+        # Capture STDERR for debugging purposes
+        stderr_output = await process.stderr.read()
+        stderr_text = stderr_output.decode().strip()
+        if stderr_text:
+            log.info(f"Subprocess STDERR: {stderr_text}")
+
+        # log.info output line by line
+        async for line in process.stdout:
+            log.info(line.decode().strip())
+
+        # Wait for the process to finish
+        returncode = await process.wait()
+        log.info(f"Subprocess exited with return code {returncode}")
+    except Exception as e:
+        log.error(f"Failed to start subprocess: {e}")
+        raise  # Optionally re-raise the exception if you want it to propagate
+
+
+async def start_litellm_background():
+    log.info("start_litellm_background")
+    # Command to run in the background
+    command = (
+        "litellm --port 14365 --telemetry False --config ./data/litellm/config.yaml"
     )

-    await initialize(config="./data/litellm/config.yaml", telemetry=False)
+    await run_background_process(command)


-async def startup():
-    await config()
+async def shutdown_litellm_background():
+    log.info("shutdown_litellm_background")
+    global background_process
+    if background_process:
+        background_process.terminate()
+        await background_process.wait()  # Ensure the process has terminated
+        log.info("Subprocess terminated")
+        background_process = None


 @app.on_event("startup")
-async def on_startup():
-    await startup()
+async def startup_event():
+
+    log.info("startup_event")
+    # TODO: Check config.yaml file and create one
+    asyncio.create_task(start_litellm_background())


 app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
 app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST


-@app.middleware("http")
-async def auth_middleware(request: Request, call_next):
-    auth_header = request.headers.get("Authorization", "")
-    request.state.user = None
+@app.get("/")
+async def get_status():
+    return {"status": True}
+
 
 
+async def restart_litellm():
+    """
+    Restart the litellm background service.
+    """
+    log.info("Requested restart of litellm service.")
     try:
-        user = get_current_user(get_http_authorization_cred(auth_header))
-        log.debug(f"user: {user}")
-        request.state.user = user
+        # Shut down the existing process if it is running
+        await shutdown_litellm_background()
+        log.info("litellm service shutdown complete.")
+
+        # Restart the background service
+        asyncio.create_task(start_litellm_background())
+        log.info("litellm service restart complete.")
+
+        return {
+            "status": "success",
+            "message": "litellm service restarted successfully.",
+        }
     except Exception as e:
-        return JSONResponse(status_code=400, content={"detail": str(e)})
+        log.error(f"Error restarting litellm service: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
+        )
 
 
-    response = await call_next(request)
-    return response
 
 
+@app.get("/restart")
+async def restart_litellm_handler(user=Depends(get_admin_user)):
+    return await restart_litellm()
 
 
-class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
-    async def dispatch(
-        self, request: Request, call_next: RequestResponseEndpoint
-    ) -> Response:
 
 
-        response = await call_next(request)
-        user = request.state.user
+@app.get("/config")
+async def get_config(user=Depends(get_admin_user)):
+    return app.state.CONFIG
 
 
-        if "/models" in request.url.path:
-            if isinstance(response, StreamingResponse):
-                # Read the content of the streaming response
-                body = b""
-                async for chunk in response.body_iterator:
-                    body += chunk
 
 
-                data = json.loads(body.decode("utf-8"))
+class LiteLLMConfigForm(BaseModel):
+    general_settings: Optional[dict] = None
+    litellm_settings: Optional[dict] = None
+    model_list: Optional[List[dict]] = None
+    router_settings: Optional[dict] = None
 
 
-                if app.state.MODEL_FILTER_ENABLED:
-                    if user and user.role == "user":
-                        data["data"] = list(
-                            filter(
-                                lambda model: model["id"]
-                                in app.state.MODEL_FILTER_LIST,
-                                data["data"],
-                            )
-                        )
 
 
-                # Modified Flag
-                data["modified"] = True
-                return JSONResponse(content=data)
+@app.post("/config/update")
+async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)):
+    app.state.CONFIG = form_data.model_dump(exclude_none=True)
 
 
-        return response
+    with open(LITELLM_CONFIG_DIR, "w") as file:
+        yaml.dump(app.state.CONFIG, file)
 
 
+    await restart_litellm()
+    return app.state.CONFIG
 
 
-app.add_middleware(ModifyModelsResponseMiddleware)
+
+@app.get("/models")
+@app.get("/v1/models")
+async def get_models(user=Depends(get_current_user)):
+    while not background_process:
+        await asyncio.sleep(0.1)
+
+    url = "http://localhost:14365/v1"
+    r = None
+    try:
+        r = requests.request(method="GET", url=f"{url}/models")
+        r.raise_for_status()
+
+        data = r.json()
+
+        if app.state.MODEL_FILTER_ENABLED:
+            if user and user.role == "user":
+                data["data"] = list(
+                    filter(
+                        lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
+                        data["data"],
+                    )
+                )
+
+        return data
+    except Exception as e:
+        log.exception(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"External: {res['error']}"
+            except Exception:
+                error_detail = f"External: {e}"
+
+        return {
+            "data": [
+                {
+                    "id": model["model_name"],
+                    "object": "model",
+                    "created": int(time.time()),
+                    "owned_by": "openai",
+                }
+                for model in app.state.CONFIG["model_list"]
+            ],
+            "object": "list",
+        }
+
+
+@app.get("/model/info")
+async def get_model_list(user=Depends(get_admin_user)):
+    return {"data": app.state.CONFIG["model_list"]}
+
+
+class AddLiteLLMModelForm(BaseModel):
+    model_name: str
+    litellm_params: dict
+
+
+@app.post("/model/new")
+async def add_model_to_config(
+    form_data: AddLiteLLMModelForm, user=Depends(get_admin_user)
+):
+    # TODO: Validate model form
+
+    app.state.CONFIG["model_list"].append(form_data.model_dump())
+
+    with open(LITELLM_CONFIG_DIR, "w") as file:
+        yaml.dump(app.state.CONFIG, file)
+
+    await restart_litellm()
+
+    return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)}
+
+
+class DeleteLiteLLMModelForm(BaseModel):
+    id: str
+
+
+@app.post("/model/delete")
+async def delete_model_from_config(
+    form_data: DeleteLiteLLMModelForm, user=Depends(get_admin_user)
+):
+    app.state.CONFIG["model_list"] = [
+        model
+        for model in app.state.CONFIG["model_list"]
+        if model["model_name"] != form_data.id
+    ]
+
+    with open(LITELLM_CONFIG_DIR, "w") as file:
+        yaml.dump(app.state.CONFIG, file)
+
+    await restart_litellm()
+
+    return {"message": MESSAGES.MODEL_DELETED(form_data.id)}
+
+
+@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
+async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
+    body = await request.body()
+
+    url = "http://localhost:14365"
+
+    target_url = f"{url}/{path}"
+
+    headers = {}
+    # headers["Authorization"] = f"Bearer {key}"
+    headers["Content-Type"] = "application/json"
+
+    r = None
+
+    try:
+        r = requests.request(
+            method=request.method,
+            url=target_url,
+            data=body,
+            headers=headers,
+            stream=True,
+        )
+
+        r.raise_for_status()
+
+        # Check if response is SSE
+        if "text/event-stream" in r.headers.get("Content-Type", ""):
+            return StreamingResponse(
+                r.iter_content(chunk_size=8192),
+                status_code=r.status_code,
+                headers=dict(r.headers),
+            )
+        else:
+            response_data = r.json()
+            return response_data
+    except Exception as e:
+        log.exception(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
+            except Exception:
+                error_detail = f"External: {e}"
+
+        raise HTTPException(
+            status_code=r.status_code if r is not None else 500, detail=error_detail
+        )
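
Note: this module now runs the LiteLLM proxy as a managed asyncio subprocess instead of importing its FastAPI app in-process. A minimal sketch of that lifecycle, assuming a litellm binary on PATH; the names mirror the diff, but this is an illustration, not the shipped code:

import asyncio

background_process = None  # module-global handle, as in the diff

async def start_proxy():
    global background_process
    command = "litellm --port 14365 --telemetry False --config ./data/litellm/config.yaml"
    # create_subprocess_exec takes an argv list, so no shell is involved
    background_process = await asyncio.create_subprocess_exec(
        *command.split(),
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )

async def stop_proxy():
    global background_process
    if background_process:
        background_process.terminate()   # send SIGTERM ...
        await background_process.wait()  # ... then reap, so no zombie is left behind
        background_process = None

async def restart_proxy():
    await stop_proxy()                  # stop synchronously
    asyncio.create_task(start_proxy())  # then relaunch without blocking the caller

This restart path is what /config/update, /model/new, and /model/delete all invoke after rewriting config.yaml, so model edits take effect without restarting Open WebUI itself.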

+ 4 - 0
backend/constants.py

@@ -3,6 +3,10 @@ from enum import Enum
 
 
 class MESSAGES(str, Enum):
     DEFAULT = lambda msg="": f"{msg if msg else ''}"
+    MODEL_ADDED = lambda model="": f"The model '{model}' has been added successfully."
+    MODEL_DELETED = (
+        lambda model="": f"The model '{model}' has been deleted successfully."
+    )


 class WEBHOOK_MESSAGES(str, Enum):
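
A note on the pattern above: function-valued attributes in an Enum body are treated as methods rather than enum members, so MESSAGES.MODEL_ADDED resolves to the lambda itself and call sites can invoke it directly. A standalone illustration (not part of the commit):

from enum import Enum

class MESSAGES(str, Enum):
    MODEL_ADDED = lambda model="": f"The model '{model}' has been added successfully."

print(MESSAGES.MODEL_ADDED("gpt-4"))
# The model 'gpt-4' has been added successfully.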

+ 12 - 2
backend/main.py

@@ -20,12 +20,17 @@ from starlette.middleware.base import BaseHTTPMiddleware
 from apps.ollama.main import app as ollama_app
 from apps.openai.main import app as openai_app

-from apps.litellm.main import app as litellm_app, startup as litellm_app_startup
+from apps.litellm.main import (
+    app as litellm_app,
+    start_litellm_background,
+    shutdown_litellm_background,
+)
 from apps.audio.main import app as audio_app
 from apps.images.main import app as images_app
 from apps.rag.main import app as rag_app
 from apps.web.main import app as webui_app

+import asyncio
+import asyncio
 from pydantic import BaseModel
 from typing import List

@@ -170,7 +175,7 @@ async def check_url(request: Request, call_next):
 
 
 @app.on_event("startup")
 async def on_startup():
-    await litellm_app_startup()
+    asyncio.create_task(start_litellm_background())


 app.mount("/api/v1", webui_app)
@@ -315,3 +320,8 @@ app.mount(
     SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),
     name="spa-static-files",
 )
+
+
+@app.on_event("shutdown")
+async def shutdown_event():
+    await shutdown_litellm_background()
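
With this wiring the parent app owns the proxy lifecycle: startup schedules the subprocess without blocking, and shutdown terminates it before the server exits. On newer FastAPI versions, where @app.on_event is deprecated, the same wiring could use a lifespan context; a sketch under that assumption, reusing the helpers imported above:

import asyncio
from contextlib import asynccontextmanager

from fastapi import FastAPI

from apps.litellm.main import start_litellm_background, shutdown_litellm_background

@asynccontextmanager
async def lifespan(app: FastAPI):
    asyncio.create_task(start_litellm_background())  # fire-and-forget, as in the diff
    yield
    await shutdown_litellm_background()  # runs once the server begins shutting down

app = FastAPI(lifespan=lifespan)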

+ 3 - 1
backend/requirements.txt

@@ -17,7 +17,9 @@ peewee
 peewee-migrate
 bcrypt

-litellm==1.30.7
+litellm==1.35.17
+litellm[proxy]==1.35.17
+
 boto3

 argon2-cffi

+ 6 - 6
src/lib/components/chat/Settings/Models.svelte

@@ -35,7 +35,7 @@
 	let liteLLMRPM = '';
 	let liteLLMMaxTokens = '';

-	let deleteLiteLLMModelId = '';
+	let deleteLiteLLMModelName = '';
 
 
 	$: liteLLMModelName = liteLLMModel;
 
 
@@ -472,7 +472,7 @@
 	};

 	const deleteLiteLLMModelHandler = async () => {
-		const res = await deleteLiteLLMModel(localStorage.token, deleteLiteLLMModelId).catch(
+		const res = await deleteLiteLLMModel(localStorage.token, deleteLiteLLMModelName).catch(
 			(error) => {
 				toast.error(error);
 				return null;
@@ -485,7 +485,7 @@
 			}
 		}

-		deleteLiteLLMModelId = '';
+		deleteLiteLLMModelName = '';
 		liteLLMModelInfo = await getLiteLLMModelInfo(localStorage.token);
 		models.set(await getModels());
 	};
@@ -1099,14 +1099,14 @@
 								<div class="flex-1 mr-2">
 									<select
 										class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
-										bind:value={deleteLiteLLMModelId}
+										bind:value={deleteLiteLLMModelName}
 										placeholder={$i18n.t('Select a model')}
 									>
-										{#if !deleteLiteLLMModelId}
+										{#if !deleteLiteLLMModelName}
 											<option value="" disabled selected>{$i18n.t('Select a model')}</option>
 										{/if}
 										{#each liteLLMModelInfo as model}
-											<option value={model.model_info.id} class="bg-gray-100 dark:bg-gray-700"
+											<option value={model.model_name} class="bg-gray-100 dark:bg-gray-700"
 												>{model.model_name}</option
 											>
 										{/each}
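
The frontend change lines up with the backend: /model/delete filters model_list by model_name, so the <option> values must be model names rather than litellm's internal model_info.id. A sketch of the request the UI now effectively issues; the /litellm/api mount point and the token are assumptions for illustration:

import requests

BASE = "http://localhost:8080/litellm/api"  # assumed mount point of litellm_app
TOKEN = "<admin token>"  # placeholder

r = requests.post(
    f"{BASE}/model/delete",
    json={"id": "my-model"},  # "id" is matched against model_name server-side
    headers={"Authorization": f"Bearer {TOKEN}"},
)
print(r.json())  # {"message": "The model 'my-model' has been deleted successfully."}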