Timothy Jaeryang Baek 4 ماه پیش
والد
کامیت
df0cdd9f3c
2 فایلهای تغییر یافته به همراه 354 افزوده شده و 397 حذف شده
  1. 232 214
      backend/open_webui/routers/audio.py
  2. 122 183
      backend/open_webui/routers/ollama.py

+ 232 - 214
backend/open_webui/routers/audio.py

@@ -11,25 +11,27 @@ from pydub.silence import split_on_silence
 import aiohttp
 import aiohttp
 import aiofiles
 import aiofiles
 import requests
 import requests
+
+from fastapi import (
+    Depends,
+    FastAPI,
+    File,
+    HTTPException,
+    Request,
+    UploadFile,
+    status,
+    APIRouter,
+)
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse
+from pydantic import BaseModel
+
+
+from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.config import (
 from open_webui.config import (
-    AUDIO_STT_ENGINE,
-    AUDIO_STT_MODEL,
-    AUDIO_STT_OPENAI_API_BASE_URL,
-    AUDIO_STT_OPENAI_API_KEY,
-    AUDIO_TTS_API_KEY,
-    AUDIO_TTS_ENGINE,
-    AUDIO_TTS_MODEL,
-    AUDIO_TTS_OPENAI_API_BASE_URL,
-    AUDIO_TTS_OPENAI_API_KEY,
-    AUDIO_TTS_SPLIT_ON,
-    AUDIO_TTS_VOICE,
-    AUDIO_TTS_AZURE_SPEECH_REGION,
-    AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
-    WHISPER_MODEL,
     WHISPER_MODEL_AUTO_UPDATE,
     WHISPER_MODEL_AUTO_UPDATE,
     WHISPER_MODEL_DIR,
     WHISPER_MODEL_DIR,
     CACHE_DIR,
     CACHE_DIR,
-    AppConfig,
 )
 )
 
 
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.constants import ERROR_MESSAGES
@@ -40,52 +42,75 @@ from open_webui.env import (
     ENABLE_FORWARD_USER_INFO_HEADERS,
     ENABLE_FORWARD_USER_INFO_HEADERS,
 )
 )
 
 
-from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import FileResponse
-from pydantic import BaseModel
-from open_webui.utils.auth import get_admin_user, get_verified_user
+
+router = APIRouter()
 
 
 # Constants
 # Constants
 MAX_FILE_SIZE_MB = 25
 MAX_FILE_SIZE_MB = 25
 MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
 MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
 
 
-
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["AUDIO"])
 log.setLevel(SRC_LOG_LEVELS["AUDIO"])
 
 
-
-# setting device type for whisper model
-whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
-log.info(f"whisper_device_type: {whisper_device_type}")
-
 SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
 SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
 SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
 SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
 
 
 
 
+##########################################
+#
+# Utility functions
+#
+##########################################
+
+from pydub import AudioSegment
+from pydub.utils import mediainfo
+
+
+def is_mp4_audio(file_path):
+    """Check if the given file is an MP4 audio file."""
+    if not os.path.isfile(file_path):
+        print(f"File not found: {file_path}")
+        return False
+
+    info = mediainfo(file_path)
+    if (
+        info.get("codec_name") == "aac"
+        and info.get("codec_type") == "audio"
+        and info.get("codec_tag_string") == "mp4a"
+    ):
+        return True
+    return False
+
+
+def convert_mp4_to_wav(file_path, output_path):
+    """Convert MP4 audio file to WAV format."""
+    audio = AudioSegment.from_file(file_path, format="mp4")
+    audio.export(output_path, format="wav")
+    print(f"Converted {file_path} to {output_path}")
+
+
 def set_faster_whisper_model(model: str, auto_update: bool = False):
 def set_faster_whisper_model(model: str, auto_update: bool = False):
-    if model and app.state.config.STT_ENGINE == "":
+    whisper_model = None
+    if model:
         from faster_whisper import WhisperModel
         from faster_whisper import WhisperModel
 
 
         faster_whisper_kwargs = {
         faster_whisper_kwargs = {
             "model_size_or_path": model,
             "model_size_or_path": model,
-            "device": whisper_device_type,
+            "device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu",
             "compute_type": "int8",
             "compute_type": "int8",
             "download_root": WHISPER_MODEL_DIR,
             "download_root": WHISPER_MODEL_DIR,
             "local_files_only": not auto_update,
             "local_files_only": not auto_update,
         }
         }
 
 
         try:
         try:
-            app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
+            whisper_model = WhisperModel(**faster_whisper_kwargs)
         except Exception:
         except Exception:
             log.warning(
             log.warning(
                 "WhisperModel initialization failed, attempting download with local_files_only=False"
                 "WhisperModel initialization failed, attempting download with local_files_only=False"
             )
             )
             faster_whisper_kwargs["local_files_only"] = False
             faster_whisper_kwargs["local_files_only"] = False
-            app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
-
-    else:
-        app.state.faster_whisper_model = None
+            whisper_model = WhisperModel(**faster_whisper_kwargs)
+    return whisper_model
 
 
 
 
 class TTSConfigForm(BaseModel):
 class TTSConfigForm(BaseModel):
@@ -113,98 +138,75 @@ class AudioConfigUpdateForm(BaseModel):
     stt: STTConfigForm
     stt: STTConfigForm
 
 
 
 
-from pydub import AudioSegment
-from pydub.utils import mediainfo
-
-
-def is_mp4_audio(file_path):
-    """Check if the given file is an MP4 audio file."""
-    if not os.path.isfile(file_path):
-        print(f"File not found: {file_path}")
-        return False
-
-    info = mediainfo(file_path)
-    if (
-        info.get("codec_name") == "aac"
-        and info.get("codec_type") == "audio"
-        and info.get("codec_tag_string") == "mp4a"
-    ):
-        return True
-    return False
-
-
-def convert_mp4_to_wav(file_path, output_path):
-    """Convert MP4 audio file to WAV format."""
-    audio = AudioSegment.from_file(file_path, format="mp4")
-    audio.export(output_path, format="wav")
-    print(f"Converted {file_path} to {output_path}")
-
-
-@app.get("/config")
-async def get_audio_config(user=Depends(get_admin_user)):
+@router.get("/config")
+async def get_audio_config(request: Request, user=Depends(get_admin_user)):
     return {
     return {
         "tts": {
         "tts": {
-            "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
-            "API_KEY": app.state.config.TTS_API_KEY,
-            "ENGINE": app.state.config.TTS_ENGINE,
-            "MODEL": app.state.config.TTS_MODEL,
-            "VOICE": app.state.config.TTS_VOICE,
-            "SPLIT_ON": app.state.config.TTS_SPLIT_ON,
-            "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
-            "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
+            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
+            "API_KEY": request.app.state.config.TTS_API_KEY,
+            "ENGINE": request.app.state.config.TTS_ENGINE,
+            "MODEL": request.app.state.config.TTS_MODEL,
+            "VOICE": request.app.state.config.TTS_VOICE,
+            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
+            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
+            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
         },
         },
         "stt": {
         "stt": {
-            "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
-            "ENGINE": app.state.config.STT_ENGINE,
-            "MODEL": app.state.config.STT_MODEL,
-            "WHISPER_MODEL": app.state.config.WHISPER_MODEL,
+            "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
+            "ENGINE": request.app.state.config.STT_ENGINE,
+            "MODEL": request.app.state.config.STT_MODEL,
+            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
         },
         },
     }
     }
 
 
 
 
-@app.post("/config/update")
+@router.post("/config/update")
 async def update_audio_config(
 async def update_audio_config(
-    form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
+    request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
 ):
 ):
-    app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
-    app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
-    app.state.config.TTS_API_KEY = form_data.tts.API_KEY
-    app.state.config.TTS_ENGINE = form_data.tts.ENGINE
-    app.state.config.TTS_MODEL = form_data.tts.MODEL
-    app.state.config.TTS_VOICE = form_data.tts.VOICE
-    app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
-    app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
-    app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
+    request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
+    request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
+    request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY
+    request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE
+    request.app.state.config.TTS_MODEL = form_data.tts.MODEL
+    request.app.state.config.TTS_VOICE = form_data.tts.VOICE
+    request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
+    request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
+    request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
         form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
         form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
     )
     )
 
 
-    app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
-    app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
-    app.state.config.STT_ENGINE = form_data.stt.ENGINE
-    app.state.config.STT_MODEL = form_data.stt.MODEL
-    app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
-    set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE)
+    request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
+    request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
+    request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
+    request.app.state.config.STT_MODEL = form_data.stt.MODEL
+    request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
+
+    if request.app.state.config.STT_ENGINE == "":
+        request.app.state.faster_whisper_model = set_faster_whisper_model(
+            form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE
+        )
 
 
     return {
     return {
         "tts": {
         "tts": {
-            "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
-            "API_KEY": app.state.config.TTS_API_KEY,
-            "ENGINE": app.state.config.TTS_ENGINE,
-            "MODEL": app.state.config.TTS_MODEL,
-            "VOICE": app.state.config.TTS_VOICE,
-            "SPLIT_ON": app.state.config.TTS_SPLIT_ON,
-            "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
-            "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
+            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
+            "API_KEY": request.app.state.config.TTS_API_KEY,
+            "ENGINE": request.app.state.config.TTS_ENGINE,
+            "MODEL": request.app.state.config.TTS_MODEL,
+            "VOICE": request.app.state.config.TTS_VOICE,
+            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
+            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
+            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
         },
         },
         "stt": {
         "stt": {
-            "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
-            "ENGINE": app.state.config.STT_ENGINE,
-            "MODEL": app.state.config.STT_MODEL,
-            "WHISPER_MODEL": app.state.config.WHISPER_MODEL,
+            "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
+            "ENGINE": request.app.state.config.STT_ENGINE,
+            "MODEL": request.app.state.config.STT_MODEL,
+            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
         },
         },
     }
     }
 
 
@@ -213,18 +215,18 @@ def load_speech_pipeline():
     from transformers import pipeline
     from transformers import pipeline
     from datasets import load_dataset
     from datasets import load_dataset
 
 
-    if app.state.speech_synthesiser is None:
-        app.state.speech_synthesiser = pipeline(
+    if request.app.state.speech_synthesiser is None:
+        request.app.state.speech_synthesiser = pipeline(
             "text-to-speech", "microsoft/speecht5_tts"
             "text-to-speech", "microsoft/speecht5_tts"
         )
         )
 
 
-    if app.state.speech_speaker_embeddings_dataset is None:
-        app.state.speech_speaker_embeddings_dataset = load_dataset(
+    if request.app.state.speech_speaker_embeddings_dataset is None:
+        request.app.state.speech_speaker_embeddings_dataset = load_dataset(
             "Matthijs/cmu-arctic-xvectors", split="validation"
             "Matthijs/cmu-arctic-xvectors", split="validation"
         )
         )
 
 
 
 
-@app.post("/speech")
+@router.post("/speech")
 async def speech(request: Request, user=Depends(get_verified_user)):
 async def speech(request: Request, user=Depends(get_verified_user)):
     body = await request.body()
     body = await request.body()
     name = hashlib.sha256(body).hexdigest()
     name = hashlib.sha256(body).hexdigest()
@@ -236,9 +238,11 @@ async def speech(request: Request, user=Depends(get_verified_user)):
     if file_path.is_file():
     if file_path.is_file():
         return FileResponse(file_path)
         return FileResponse(file_path)
 
 
-    if app.state.config.TTS_ENGINE == "openai":
+    if request.app.state.config.TTS_ENGINE == "openai":
         headers = {}
         headers = {}
-        headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
+        headers["Authorization"] = (
+            f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}"
+        )
         headers["Content-Type"] = "application/json"
         headers["Content-Type"] = "application/json"
 
 
         if ENABLE_FORWARD_USER_INFO_HEADERS:
         if ENABLE_FORWARD_USER_INFO_HEADERS:
@@ -250,7 +254,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         try:
         try:
             body = body.decode("utf-8")
             body = body.decode("utf-8")
             body = json.loads(body)
             body = json.loads(body)
-            body["model"] = app.state.config.TTS_MODEL
+            body["model"] = request.app.state.config.TTS_MODEL
             body = json.dumps(body).encode("utf-8")
             body = json.dumps(body).encode("utf-8")
         except Exception:
         except Exception:
             pass
             pass
@@ -258,7 +262,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         try:
         try:
             async with aiohttp.ClientSession() as session:
             async with aiohttp.ClientSession() as session:
                 async with session.post(
                 async with session.post(
-                    url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
+                    url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
                     data=body,
                     data=body,
                     headers=headers,
                     headers=headers,
                 ) as r:
                 ) as r:
@@ -287,7 +291,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
                 detail=error_detail,
                 detail=error_detail,
             )
             )
 
 
-    elif app.state.config.TTS_ENGINE == "elevenlabs":
+    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
         try:
         try:
             payload = json.loads(body.decode("utf-8"))
             payload = json.loads(body.decode("utf-8"))
         except Exception as e:
         except Exception as e:
@@ -305,11 +309,11 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         headers = {
         headers = {
             "Accept": "audio/mpeg",
             "Accept": "audio/mpeg",
             "Content-Type": "application/json",
             "Content-Type": "application/json",
-            "xi-api-key": app.state.config.TTS_API_KEY,
+            "xi-api-key": request.app.state.config.TTS_API_KEY,
         }
         }
         data = {
         data = {
             "text": payload["input"],
             "text": payload["input"],
-            "model_id": app.state.config.TTS_MODEL,
+            "model_id": request.app.state.config.TTS_MODEL,
             "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
             "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
         }
         }
 
 
@@ -341,21 +345,21 @@ async def speech(request: Request, user=Depends(get_verified_user)):
                 detail=error_detail,
                 detail=error_detail,
             )
             )
 
 
-    elif app.state.config.TTS_ENGINE == "azure":
+    elif request.app.state.config.TTS_ENGINE == "azure":
         try:
         try:
             payload = json.loads(body.decode("utf-8"))
             payload = json.loads(body.decode("utf-8"))
         except Exception as e:
         except Exception as e:
             log.exception(e)
             log.exception(e)
             raise HTTPException(status_code=400, detail="Invalid JSON payload")
             raise HTTPException(status_code=400, detail="Invalid JSON payload")
 
 
-        region = app.state.config.TTS_AZURE_SPEECH_REGION
-        language = app.state.config.TTS_VOICE
-        locale = "-".join(app.state.config.TTS_VOICE.split("-")[:1])
-        output_format = app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
+        region = request.app.state.config.TTS_AZURE_SPEECH_REGION
+        language = request.app.state.config.TTS_VOICE
+        locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
+        output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
         url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1"
         url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1"
 
 
         headers = {
         headers = {
-            "Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY,
+            "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
             "Content-Type": "application/ssml+xml",
             "Content-Type": "application/ssml+xml",
             "X-Microsoft-OutputFormat": output_format,
             "X-Microsoft-OutputFormat": output_format,
         }
         }
@@ -378,7 +382,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         except Exception as e:
         except Exception as e:
             log.exception(e)
             log.exception(e)
             raise HTTPException(status_code=500, detail=str(e))
             raise HTTPException(status_code=500, detail=str(e))
-    elif app.state.config.TTS_ENGINE == "transformers":
+    elif request.app.state.config.TTS_ENGINE == "transformers":
         payload = None
         payload = None
         try:
         try:
             payload = json.loads(body.decode("utf-8"))
             payload = json.loads(body.decode("utf-8"))
@@ -391,12 +395,12 @@ async def speech(request: Request, user=Depends(get_verified_user)):
 
 
         load_speech_pipeline()
         load_speech_pipeline()
 
 
-        embeddings_dataset = app.state.speech_speaker_embeddings_dataset
+        embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset
 
 
         speaker_index = 6799
         speaker_index = 6799
         try:
         try:
             speaker_index = embeddings_dataset["filename"].index(
             speaker_index = embeddings_dataset["filename"].index(
-                app.state.config.TTS_MODEL
+                request.app.state.config.TTS_MODEL
             )
             )
         except Exception:
         except Exception:
             pass
             pass
@@ -405,7 +409,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             embeddings_dataset[speaker_index]["xvector"]
             embeddings_dataset[speaker_index]["xvector"]
         ).unsqueeze(0)
         ).unsqueeze(0)
 
 
-        speech = app.state.speech_synthesiser(
+        speech = request.app.state.speech_synthesiser(
             payload["input"],
             payload["input"],
             forward_params={"speaker_embeddings": speaker_embedding},
             forward_params={"speaker_embeddings": speaker_embedding},
         )
         )
@@ -417,17 +421,19 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         return FileResponse(file_path)
         return FileResponse(file_path)
 
 
 
 
-def transcribe(file_path):
+def transcribe(request: Request, file_path):
     print("transcribe", file_path)
     print("transcribe", file_path)
     filename = os.path.basename(file_path)
     filename = os.path.basename(file_path)
     file_dir = os.path.dirname(file_path)
     file_dir = os.path.dirname(file_path)
     id = filename.split(".")[0]
     id = filename.split(".")[0]
 
 
-    if app.state.config.STT_ENGINE == "":
-        if app.state.faster_whisper_model is None:
-            set_faster_whisper_model(app.state.config.WHISPER_MODEL)
+    if request.app.state.config.STT_ENGINE == "":
+        if request.app.state.faster_whisper_model is None:
+            request.app.state.faster_whisper_model = set_faster_whisper_model(
+                request.app.state.config.WHISPER_MODEL
+            )
 
 
-        model = app.state.faster_whisper_model
+        model = request.app.state.faster_whisper_model
         segments, info = model.transcribe(file_path, beam_size=5)
         segments, info = model.transcribe(file_path, beam_size=5)
         log.info(
         log.info(
             "Detected language '%s' with probability %f"
             "Detected language '%s' with probability %f"
@@ -444,31 +450,24 @@ def transcribe(file_path):
 
 
         log.debug(data)
         log.debug(data)
         return data
         return data
-    elif app.state.config.STT_ENGINE == "openai":
+    elif request.app.state.config.STT_ENGINE == "openai":
         if is_mp4_audio(file_path):
         if is_mp4_audio(file_path):
-            print("is_mp4_audio")
             os.rename(file_path, file_path.replace(".wav", ".mp4"))
             os.rename(file_path, file_path.replace(".wav", ".mp4"))
             # Convert MP4 audio file to WAV format
             # Convert MP4 audio file to WAV format
             convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
             convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
 
 
-        headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
-
-        files = {"file": (filename, open(file_path, "rb"))}
-        data = {"model": app.state.config.STT_MODEL}
-
-        log.debug(files, data)
-
         r = None
         r = None
         try:
         try:
             r = requests.post(
             r = requests.post(
-                url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
-                headers=headers,
-                files=files,
-                data=data,
+                url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
+                headers={
+                    "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
+                },
+                files={"file": (filename, open(file_path, "rb"))},
+                data={"model": request.app.state.config.STT_MODEL},
             )
             )
 
 
             r.raise_for_status()
             r.raise_for_status()
-
             data = r.json()
             data = r.json()
 
 
             # save the transcript to a json file
             # save the transcript to a json file
@@ -476,24 +475,43 @@ def transcribe(file_path):
             with open(transcript_file, "w") as f:
             with open(transcript_file, "w") as f:
                 json.dump(data, f)
                 json.dump(data, f)
 
 
-            print(data)
             return data
             return data
         except Exception as e:
         except Exception as e:
             log.exception(e)
             log.exception(e)
-            error_detail = "Open WebUI: Server Connection Error"
+
+            detail = None
             if r is not None:
             if r is not None:
                 try:
                 try:
                     res = r.json()
                     res = r.json()
                     if "error" in res:
                     if "error" in res:
-                        error_detail = f"External: {res['error']['message']}"
+                        detail = f"External: {res['error'].get('message', '')}"
                 except Exception:
                 except Exception:
-                    error_detail = f"External: {e}"
+                    detail = f"External: {e}"
+
+            raise Exception(detail if detail else "Open WebUI: Server Connection Error")
 
 
-            raise Exception(error_detail)
 
 
+def compress_audio(file_path):
+    if os.path.getsize(file_path) > MAX_FILE_SIZE:
+        file_dir = os.path.dirname(file_path)
+        audio = AudioSegment.from_file(file_path)
+        audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio
+        compressed_path = f"{file_dir}/{id}_compressed.opus"
+        audio.export(compressed_path, format="opus", bitrate="32k")
+        log.debug(f"Compressed audio to {compressed_path}")
 
 
-@app.post("/transcriptions")
+        if (
+            os.path.getsize(compressed_path) > MAX_FILE_SIZE
+        ):  # Still larger than MAX_FILE_SIZE after compression
+            raise Exception(ERROR_MESSAGES.FILE_TOO_LARGE(size=f"{MAX_FILE_SIZE_MB}MB"))
+        return compressed_path
+    else:
+        return file_path
+
+
+@router.post("/transcriptions")
 def transcription(
 def transcription(
+    request: Request,
     file: UploadFile = File(...),
     file: UploadFile = File(...),
     user=Depends(get_verified_user),
     user=Depends(get_verified_user),
 ):
 ):
@@ -520,36 +538,22 @@ def transcription(
             f.write(contents)
             f.write(contents)
 
 
         try:
         try:
-            if os.path.getsize(file_path) > MAX_FILE_SIZE:  # file is bigger than 25MB
-                log.debug(f"File size is larger than {MAX_FILE_SIZE_MB}MB")
-                audio = AudioSegment.from_file(file_path)
-                audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio
-                compressed_path = f"{file_dir}/{id}_compressed.opus"
-                audio.export(compressed_path, format="opus", bitrate="32k")
-                log.debug(f"Compressed audio to {compressed_path}")
-                file_path = compressed_path
-
-                if (
-                    os.path.getsize(file_path) > MAX_FILE_SIZE
-                ):  # Still larger than 25MB after compression
-                    log.debug(
-                        f"Compressed file size is still larger than {MAX_FILE_SIZE_MB}MB: {os.path.getsize(file_path)}"
-                    )
-                    raise HTTPException(
-                        status_code=status.HTTP_400_BAD_REQUEST,
-                        detail=ERROR_MESSAGES.FILE_TOO_LARGE(
-                            size=f"{MAX_FILE_SIZE_MB}MB"
-                        ),
-                    )
-
-                data = transcribe(file_path)
-            else:
-                data = transcribe(file_path)
+            try:
+                file_path = compress_audio(file_path)
+            except Exception as e:
+                log.exception(e)
+
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.DEFAULT(e),
+                )
 
 
+            data = transcribe(request, file_path)
             file_path = file_path.split("/")[-1]
             file_path = file_path.split("/")[-1]
             return {**data, "filename": file_path}
             return {**data, "filename": file_path}
         except Exception as e:
         except Exception as e:
             log.exception(e)
             log.exception(e)
+
             raise HTTPException(
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
                 status_code=status.HTTP_400_BAD_REQUEST,
                 detail=ERROR_MESSAGES.DEFAULT(e),
                 detail=ERROR_MESSAGES.DEFAULT(e),
@@ -564,39 +568,41 @@ def transcription(
         )
         )
 
 
 
 
-def get_available_models() -> list[dict]:
-    if app.state.config.TTS_ENGINE == "openai":
-        return [{"id": "tts-1"}, {"id": "tts-1-hd"}]
-    elif app.state.config.TTS_ENGINE == "elevenlabs":
-        headers = {
-            "xi-api-key": app.state.config.TTS_API_KEY,
-            "Content-Type": "application/json",
-        }
-
+def get_available_models(request: Request) -> list[dict]:
+    available_models = []
+    if request.app.state.config.TTS_ENGINE == "openai":
+        available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
+    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
         try:
         try:
             response = requests.get(
             response = requests.get(
-                "https://api.elevenlabs.io/v1/models", headers=headers, timeout=5
+                "https://api.elevenlabs.io/v1/models",
+                headers={
+                    "xi-api-key": request.app.state.config.TTS_API_KEY,
+                    "Content-Type": "application/json",
+                },
+                timeout=5,
             )
             )
             response.raise_for_status()
             response.raise_for_status()
             models = response.json()
             models = response.json()
-            return [
+
+            available_models = [
                 {"name": model["name"], "id": model["model_id"]} for model in models
                 {"name": model["name"], "id": model["model_id"]} for model in models
             ]
             ]
         except requests.RequestException as e:
         except requests.RequestException as e:
             log.error(f"Error fetching voices: {str(e)}")
             log.error(f"Error fetching voices: {str(e)}")
-    return []
+    return available_models
 
 
 
 
-@app.get("/models")
-async def get_models(user=Depends(get_verified_user)):
-    return {"models": get_available_models()}
+@router.get("/models")
+async def get_models(request: Request, user=Depends(get_verified_user)):
+    return {"models": get_available_models(request)}
 
 
 
 
-def get_available_voices() -> dict:
+def get_available_voices(request) -> dict:
     """Returns {voice_id: voice_name} dict"""
     """Returns {voice_id: voice_name} dict"""
-    ret = {}
-    if app.state.config.TTS_ENGINE == "openai":
-        ret = {
+    available_voices = {}
+    if request.app.state.config.TTS_ENGINE == "openai":
+        available_voices = {
             "alloy": "alloy",
             "alloy": "alloy",
             "echo": "echo",
             "echo": "echo",
             "fable": "fable",
             "fable": "fable",
@@ -604,33 +610,38 @@ def get_available_voices() -> dict:
             "nova": "nova",
             "nova": "nova",
             "shimmer": "shimmer",
             "shimmer": "shimmer",
         }
         }
-    elif app.state.config.TTS_ENGINE == "elevenlabs":
+    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
         try:
         try:
-            ret = get_elevenlabs_voices()
+            available_voices = get_elevenlabs_voices(
+                api_key=request.app.state.config.TTS_API_KEY
+            )
         except Exception:
         except Exception:
             # Avoided @lru_cache with exception
             # Avoided @lru_cache with exception
             pass
             pass
-    elif app.state.config.TTS_ENGINE == "azure":
+    elif request.app.state.config.TTS_ENGINE == "azure":
         try:
         try:
-            region = app.state.config.TTS_AZURE_SPEECH_REGION
+            region = request.app.state.config.TTS_AZURE_SPEECH_REGION
             url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
             url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
-            headers = {"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY}
+            headers = {
+                "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY
+            }
 
 
             response = requests.get(url, headers=headers)
             response = requests.get(url, headers=headers)
             response.raise_for_status()
             response.raise_for_status()
             voices = response.json()
             voices = response.json()
+
             for voice in voices:
             for voice in voices:
-                ret[voice["ShortName"]] = (
+                available_voices[voice["ShortName"]] = (
                     f"{voice['DisplayName']} ({voice['ShortName']})"
                     f"{voice['DisplayName']} ({voice['ShortName']})"
                 )
                 )
         except requests.RequestException as e:
         except requests.RequestException as e:
             log.error(f"Error fetching voices: {str(e)}")
             log.error(f"Error fetching voices: {str(e)}")
 
 
-    return ret
+    return available_voices
 
 
 
 
 @lru_cache
 @lru_cache
-def get_elevenlabs_voices() -> dict:
+def get_elevenlabs_voices(api_key: str) -> dict:
     """
     """
     Note, set the following in your .env file to use Elevenlabs:
     Note, set the following in your .env file to use Elevenlabs:
     AUDIO_TTS_ENGINE=elevenlabs
     AUDIO_TTS_ENGINE=elevenlabs
@@ -638,13 +649,16 @@ def get_elevenlabs_voices() -> dict:
     AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL  # From https://api.elevenlabs.io/v1/voices
     AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL  # From https://api.elevenlabs.io/v1/voices
     AUDIO_TTS_MODEL=eleven_multilingual_v2
     AUDIO_TTS_MODEL=eleven_multilingual_v2
     """
     """
-    headers = {
-        "xi-api-key": app.state.config.TTS_API_KEY,
-        "Content-Type": "application/json",
-    }
+
     try:
     try:
         # TODO: Add retries
         # TODO: Add retries
-        response = requests.get("https://api.elevenlabs.io/v1/voices", headers=headers)
+        response = requests.get(
+            "https://api.elevenlabs.io/v1/voices",
+            headers={
+                "xi-api-key": api_key,
+                "Content-Type": "application/json",
+            },
+        )
         response.raise_for_status()
         response.raise_for_status()
         voices_data = response.json()
         voices_data = response.json()
 
 
@@ -659,6 +673,10 @@ def get_elevenlabs_voices() -> dict:
     return voices
     return voices
 
 
 
 
-@app.get("/voices")
-async def get_voices(user=Depends(get_verified_user)):
-    return {"voices": [{"id": k, "name": v} for k, v in get_available_voices().items()]}
+@router.get("/voices")
+async def get_voices(request: Request, user=Depends(get_verified_user)):
+    return {
+        "voices": [
+            {"id": k, "name": v} for k, v in get_available_voices(request).items()
+        ]
+    }

+ 122 - 183
backend/open_webui/routers/ollama.py

@@ -385,7 +385,7 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
     if request.app.state.config.ENABLE_OLLAMA_API:
     if request.app.state.config.ENABLE_OLLAMA_API:
         if url_idx is None:
         if url_idx is None:
             # returns lowest version
             # returns lowest version
-            tasks = [
+            request_tasks = [
                 send_get_request(
                 send_get_request(
                     f"{url}/api/version",
                     f"{url}/api/version",
                     request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
                     request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
@@ -394,7 +394,7 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
                 )
                 )
                 for url in request.app.state.config.OLLAMA_BASE_URLS
                 for url in request.app.state.config.OLLAMA_BASE_URLS
             ]
             ]
-            responses = await asyncio.gather(*tasks)
+            responses = await asyncio.gather(*request_tasks)
             responses = list(filter(lambda x: x is not None, responses))
             responses = list(filter(lambda x: x is not None, responses))
 
 
             if len(responses) > 0:
             if len(responses) > 0:
@@ -446,7 +446,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u
     List models that are currently loaded into Ollama memory, and which node they are loaded on.
     List models that are currently loaded into Ollama memory, and which node they are loaded on.
     """
     """
     if request.app.state.config.ENABLE_OLLAMA_API:
     if request.app.state.config.ENABLE_OLLAMA_API:
-        tasks = [
+        request_tasks = [
             send_get_request(
             send_get_request(
                 f"{url}/api/ps",
                 f"{url}/api/ps",
                 request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
                 request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
@@ -455,7 +455,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u
             )
             )
             for url in request.app.state.config.OLLAMA_BASE_URLS
             for url in request.app.state.config.OLLAMA_BASE_URLS
         ]
         ]
-        responses = await asyncio.gather(*tasks)
+        responses = await asyncio.gather(*request_tasks)
 
 
         return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
         return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
     else:
     else:
@@ -502,8 +502,8 @@ async def push_model(
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),
 ):
 ):
     if url_idx is None:
     if url_idx is None:
-        model_list = await get_all_models()
-        models = {model["model"]: model for model in model_list["models"]}
+        await get_all_models(request)
+        models = request.app.state.OLLAMA_MODELS
 
 
         if form_data.name in models:
         if form_data.name in models:
             url_idx = models[form_data.name]["urls"][0]
             url_idx = models[form_data.name]["urls"][0]
@@ -540,7 +540,6 @@ async def create_model(
 ):
 ):
     log.debug(f"form_data: {form_data}")
     log.debug(f"form_data: {form_data}")
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
 
 
     return await send_post_request(
     return await send_post_request(
         url=f"{url}/api/create",
         url=f"{url}/api/create",
@@ -563,8 +562,8 @@ async def copy_model(
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),
 ):
 ):
     if url_idx is None:
     if url_idx is None:
-        model_list = await get_all_models()
-        models = {model["model"]: model for model in model_list["models"]}
+        await get_all_models()
+        models = request.app.state.OLLAMA_MODELS
 
 
         if form_data.source in models:
         if form_data.source in models:
             url_idx = models[form_data.source]["urls"][0]
             url_idx = models[form_data.source]["urls"][0]
@@ -575,45 +574,37 @@ async def copy_model(
             )
             )
 
 
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-
-    parsed_url = urlparse(url)
-    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
-
-    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
-    key = api_config.get("key", None)
-
-    headers = {"Content-Type": "application/json"}
-    if key:
-        headers["Authorization"] = f"Bearer {key}"
-
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/copy",
-        headers=headers,
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
+    key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
 
 
     try:
     try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/copy",
+            headers={
+                "Content-Type": "application/json",
+                **({"Authorization": f"Bearer {key}"} if key else {}),
+            },
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()
         r.raise_for_status()
 
 
         log.debug(f"r.text: {r.text}")
         log.debug(f"r.text: {r.text}")
-
         return True
         return True
     except Exception as e:
     except Exception as e:
         log.exception(e)
         log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
+
+        detail = None
         if r is not None:
         if r is not None:
             try:
             try:
                 res = r.json()
                 res = r.json()
                 if "error" in res:
                 if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
+                    detail = f"Ollama: {res['error']}"
             except Exception:
             except Exception:
-                error_detail = f"Ollama: {e}"
+                detail = f"Ollama: {e}"
 
 
         raise HTTPException(
         raise HTTPException(
             status_code=r.status_code if r else 500,
             status_code=r.status_code if r else 500,
-            detail=error_detail,
+            detail=detail if detail else "Open WebUI: Server Connection Error",
         )
         )
 
 
 
 
@@ -626,8 +617,8 @@ async def delete_model(
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),
 ):
 ):
     if url_idx is None:
     if url_idx is None:
-        model_list = await get_all_models()
-        models = {model["model"]: model for model in model_list["models"]}
+        await get_all_models()
+        models = request.app.state.OLLAMA_MODELS
 
 
         if form_data.name in models:
         if form_data.name in models:
             url_idx = models[form_data.name]["urls"][0]
             url_idx = models[form_data.name]["urls"][0]
@@ -638,44 +629,37 @@ async def delete_model(
             )
             )
 
 
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
+    key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
 
 
-    parsed_url = urlparse(url)
-    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
-
-    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
-    key = api_config.get("key", None)
-
-    headers = {"Content-Type": "application/json"}
-    if key:
-        headers["Authorization"] = f"Bearer {key}"
-
-    r = requests.request(
-        method="DELETE",
-        url=f"{url}/api/delete",
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-        headers=headers,
-    )
     try:
     try:
+        r = requests.request(
+            method="DELETE",
+            url=f"{url}/api/delete",
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+            headers={
+                "Content-Type": "application/json",
+                **({"Authorization": f"Bearer {key}"} if key else {}),
+            },
+        )
         r.raise_for_status()
         r.raise_for_status()
 
 
         log.debug(f"r.text: {r.text}")
         log.debug(f"r.text: {r.text}")
-
         return True
         return True
     except Exception as e:
     except Exception as e:
         log.exception(e)
         log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
+
+        detail = None
         if r is not None:
         if r is not None:
             try:
             try:
                 res = r.json()
                 res = r.json()
                 if "error" in res:
                 if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
+                    detail = f"Ollama: {res['error']}"
             except Exception:
             except Exception:
-                error_detail = f"Ollama: {e}"
+                detail = f"Ollama: {e}"
 
 
         raise HTTPException(
         raise HTTPException(
             status_code=r.status_code if r else 500,
             status_code=r.status_code if r else 500,
-            detail=error_detail,
+            detail=detail if detail else "Open WebUI: Server Connection Error",
         )
         )
 
 
 
 
@@ -683,8 +667,8 @@ async def delete_model(
 async def show_model_info(
 async def show_model_info(
     request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
     request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
 ):
 ):
-    model_list = await get_all_models()
-    models = {model["model"]: model for model in model_list["models"]}
+    await get_all_models()
+    models = request.app.state.OLLAMA_MODELS
 
 
     if form_data.name not in models:
     if form_data.name not in models:
         raise HTTPException(
         raise HTTPException(
@@ -693,53 +677,41 @@ async def show_model_info(
         )
         )
 
 
     url_idx = random.choice(models[form_data.name]["urls"])
     url_idx = random.choice(models[form_data.name]["urls"])
-    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-
-    parsed_url = urlparse(url)
-    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
-
-    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
-    key = api_config.get("key", None)
 
 
-    headers = {"Content-Type": "application/json"}
-    if key:
-        headers["Authorization"] = f"Bearer {key}"
+    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
+    key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
 
 
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/show",
-        headers=headers,
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
     try:
     try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/show",
+            headers={
+                "Content-Type": "application/json",
+                **({"Authorization": f"Bearer {key}"} if key else {}),
+            },
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()
         r.raise_for_status()
 
 
         return r.json()
         return r.json()
     except Exception as e:
     except Exception as e:
         log.exception(e)
         log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
+
+        detail = None
         if r is not None:
         if r is not None:
             try:
             try:
                 res = r.json()
                 res = r.json()
                 if "error" in res:
                 if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
+                    detail = f"Ollama: {res['error']}"
             except Exception:
             except Exception:
-                error_detail = f"Ollama: {e}"
+                detail = f"Ollama: {e}"
 
 
         raise HTTPException(
         raise HTTPException(
             status_code=r.status_code if r else 500,
             status_code=r.status_code if r else 500,
-            detail=error_detail,
+            detail=detail if detail else "Open WebUI: Server Connection Error",
         )
         )
 
 
 
 
-class GenerateEmbeddingsForm(BaseModel):
-    model: str
-    prompt: str
-    options: Optional[dict] = None
-    keep_alive: Optional[Union[int, str]] = None
-
-
 class GenerateEmbedForm(BaseModel):
 class GenerateEmbedForm(BaseModel):
     model: str
     model: str
     input: list[str] | str
     input: list[str] | str
@@ -750,33 +722,17 @@ class GenerateEmbedForm(BaseModel):
 
 
 @router.post("/api/embed")
 @router.post("/api/embed")
 @router.post("/api/embed/{url_idx}")
 @router.post("/api/embed/{url_idx}")
-async def generate_embeddings(
+async def embed(
+    request: Request,
     form_data: GenerateEmbedForm,
     form_data: GenerateEmbedForm,
     url_idx: Optional[int] = None,
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
     user=Depends(get_verified_user),
 ):
 ):
-    return await generate_ollama_batch_embeddings(form_data, url_idx)
-
-
-@router.post("/api/embeddings")
-@router.post("/api/embeddings/{url_idx}")
-async def generate_embeddings(
-    form_data: GenerateEmbeddingsForm,
-    url_idx: Optional[int] = None,
-    user=Depends(get_verified_user),
-):
-    return await generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
-
-
-async def generate_ollama_embeddings(
-    form_data: GenerateEmbeddingsForm,
-    url_idx: Optional[int] = None,
-):
-    log.info(f"generate_ollama_embeddings {form_data}")
+    log.info(f"generate_ollama_batch_embeddings {form_data}")
 
 
     if url_idx is None:
     if url_idx is None:
-        model_list = await get_all_models()
-        models = {model["model"]: model for model in model_list["models"]}
+        await get_all_models()
+        models = request.app.state.OLLAMA_MODELS
 
 
         model = form_data.model
         model = form_data.model
 
 
@@ -792,61 +748,60 @@ async def generate_ollama_embeddings(
             )
             )
 
 
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
+    key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
 
 
-    parsed_url = urlparse(url)
-    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
-
-    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
-    key = api_config.get("key", None)
-
-    headers = {"Content-Type": "application/json"}
-    if key:
-        headers["Authorization"] = f"Bearer {key}"
-
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/embeddings",
-        headers=headers,
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
     try:
     try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/embed",
+            headers={
+                "Content-Type": "application/json",
+                **({"Authorization": f"Bearer {key}"} if key else {}),
+            },
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()
         r.raise_for_status()
 
 
         data = r.json()
         data = r.json()
-
-        log.info(f"generate_ollama_embeddings {data}")
-
-        if "embedding" in data:
-            return data
-        else:
-            raise Exception("Something went wrong :/")
+        return data
     except Exception as e:
     except Exception as e:
         log.exception(e)
         log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
+
+        detail = None
         if r is not None:
         if r is not None:
             try:
             try:
                 res = r.json()
                 res = r.json()
                 if "error" in res:
                 if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
+                    detail = f"Ollama: {res['error']}"
             except Exception:
             except Exception:
-                error_detail = f"Ollama: {e}"
+                detail = f"Ollama: {e}"
 
 
         raise HTTPException(
         raise HTTPException(
             status_code=r.status_code if r else 500,
             status_code=r.status_code if r else 500,
-            detail=error_detail,
+            detail=detail if detail else "Open WebUI: Server Connection Error",
         )
         )
 
 
 
 
-async def generate_ollama_batch_embeddings(
-    form_data: GenerateEmbedForm,
+class GenerateEmbeddingsForm(BaseModel):
+    model: str
+    prompt: str
+    options: Optional[dict] = None
+    keep_alive: Optional[Union[int, str]] = None
+
+
+@router.post("/api/embeddings")
+@router.post("/api/embeddings/{url_idx}")
+async def embeddings(
+    request: Request,
+    form_data: GenerateEmbeddingsForm,
     url_idx: Optional[int] = None,
     url_idx: Optional[int] = None,
+    user=Depends(get_verified_user),
 ):
 ):
-    log.info(f"generate_ollama_batch_embeddings {form_data}")
+    log.info(f"generate_ollama_embeddings {form_data}")
 
 
     if url_idx is None:
     if url_idx is None:
-        model_list = await get_all_models()
-        models = {model["model"]: model for model in model_list["models"]}
+        await get_all_models()
+        models = request.app.state.OLLAMA_MODELS
 
 
         model = form_data.model
         model = form_data.model
 
 
@@ -862,47 +817,38 @@ async def generate_ollama_batch_embeddings(
             )
             )
 
 
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-
-    parsed_url = urlparse(url)
-    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
+    key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
 
 
-    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
-    key = api_config.get("key", None)
-
-    headers = {"Content-Type": "application/json"}
-    if key:
-        headers["Authorization"] = f"Bearer {key}"
-
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/embed",
-        headers=headers,
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
     try:
     try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/embeddings",
+            headers={
+                "Content-Type": "application/json",
+                **({"Authorization": f"Bearer {key}"} if key else {}),
+            },
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()
         r.raise_for_status()
 
 
         data = r.json()
         data = r.json()
-
-        log.info(f"generate_ollama_batch_embeddings {data}")
-
-        if "embeddings" in data:
-            return data
-        else:
-            raise Exception("Something went wrong :/")
+        return data
     except Exception as e:
     except Exception as e:
         log.exception(e)
         log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
+
+        detail = None
         if r is not None:
         if r is not None:
             try:
             try:
                 res = r.json()
                 res = r.json()
                 if "error" in res:
                 if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
+                    detail = f"Ollama: {res['error']}"
             except Exception:
             except Exception:
-                error_detail = f"Ollama: {e}"
+                detail = f"Ollama: {e}"
 
 
-        raise Exception(error_detail)
+        raise HTTPException(
+            status_code=r.status_code if r else 500,
+            detail=detail if detail else "Open WebUI: Server Connection Error",
+        )
 
 
 
 
 class GenerateCompletionForm(BaseModel):
 class GenerateCompletionForm(BaseModel):
@@ -947,10 +893,10 @@ async def generate_completion(
 
 
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
     api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+
     prefix_id = api_config.get("prefix_id", None)
     prefix_id = api_config.get("prefix_id", None)
     if prefix_id:
     if prefix_id:
         form_data.model = form_data.model.replace(f"{prefix_id}.", "")
         form_data.model = form_data.model.replace(f"{prefix_id}.", "")
-    log.info(f"url: {url}")
 
 
     return await send_post_request(
     return await send_post_request(
         url=f"{url}/api/generate",
         url=f"{url}/api/generate",
@@ -975,7 +921,7 @@ class GenerateChatCompletionForm(BaseModel):
     keep_alive: Optional[Union[int, str]] = None
     keep_alive: Optional[Union[int, str]] = None
 
 
 
 
-async def get_ollama_url(url_idx: Optional[int], model: str):
+async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None):
     if url_idx is None:
     if url_idx is None:
         models = request.app.state.OLLAMA_MODELS
         models = request.app.state.OLLAMA_MODELS
         if model not in models:
         if model not in models:
@@ -1001,7 +947,6 @@ async def generate_chat_completion(
         bypass_filter = True
         bypass_filter = True
 
 
     payload = {**form_data.model_dump(exclude_none=True)}
     payload = {**form_data.model_dump(exclude_none=True)}
-    log.debug(f"generate_chat_completion() - 1.payload = {payload}")
     if "metadata" in payload:
     if "metadata" in payload:
         del payload["metadata"]
         del payload["metadata"]
 
 
@@ -1045,13 +990,9 @@ async def generate_chat_completion(
     if ":" not in payload["model"]:
     if ":" not in payload["model"]:
         payload["model"] = f"{payload['model']}:latest"
         payload["model"] = f"{payload['model']}:latest"
 
 
-    url = await get_ollama_url(url_idx, payload["model"])
-    log.debug(f"generate_chat_completion() - 2.payload = {payload}")
-
-    parsed_url = urlparse(url)
-    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
+    url = await get_ollama_url(request, payload["model"], url_idx)
+    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
 
 
-    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
     prefix_id = api_config.get("prefix_id", None)
     prefix_id = api_config.get("prefix_id", None)
     if prefix_id:
     if prefix_id:
         payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
         payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
@@ -1148,10 +1089,9 @@ async def generate_openai_completion(
     if ":" not in payload["model"]:
     if ":" not in payload["model"]:
         payload["model"] = f"{payload['model']}:latest"
         payload["model"] = f"{payload['model']}:latest"
 
 
-    url = await get_ollama_url(url_idx, payload["model"])
-    log.info(f"url: {url}")
-
+    url = await get_ollama_url(request, payload["model"], url_idx)
     api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
     api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+
     prefix_id = api_config.get("prefix_id", None)
     prefix_id = api_config.get("prefix_id", None)
 
 
     if prefix_id:
     if prefix_id:
@@ -1223,10 +1163,9 @@ async def generate_openai_chat_completion(
     if ":" not in payload["model"]:
     if ":" not in payload["model"]:
         payload["model"] = f"{payload['model']}:latest"
         payload["model"] = f"{payload['model']}:latest"
 
 
-    url = await get_ollama_url(url_idx, payload["model"])
-    log.info(f"url: {url}")
-
+    url = await get_ollama_url(request, payload["model"], url_idx)
     api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
     api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+
     prefix_id = api_config.get("prefix_id", None)
     prefix_id = api_config.get("prefix_id", None)
     if prefix_id:
     if prefix_id:
         payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
         payload["model"] = payload["model"].replace(f"{prefix_id}.", "")