Timothy Jaeryang Baek 4 months ago
parent
commit
df0cdd9f3c
2 changed files with 354 additions and 397 deletions
  1. 232 214
      backend/open_webui/routers/audio.py
  2. 122 183
      backend/open_webui/routers/ollama.py

+ 232 - 214
backend/open_webui/routers/audio.py

@@ -11,25 +11,27 @@ from pydub.silence import split_on_silence
 import aiohttp
 import aiofiles
 import requests
+
+from fastapi import (
+    Depends,
+    FastAPI,
+    File,
+    HTTPException,
+    Request,
+    UploadFile,
+    status,
+    APIRouter,
+)
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse
+from pydantic import BaseModel
+
+
+from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.config import (
-    AUDIO_STT_ENGINE,
-    AUDIO_STT_MODEL,
-    AUDIO_STT_OPENAI_API_BASE_URL,
-    AUDIO_STT_OPENAI_API_KEY,
-    AUDIO_TTS_API_KEY,
-    AUDIO_TTS_ENGINE,
-    AUDIO_TTS_MODEL,
-    AUDIO_TTS_OPENAI_API_BASE_URL,
-    AUDIO_TTS_OPENAI_API_KEY,
-    AUDIO_TTS_SPLIT_ON,
-    AUDIO_TTS_VOICE,
-    AUDIO_TTS_AZURE_SPEECH_REGION,
-    AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
-    WHISPER_MODEL,
     WHISPER_MODEL_AUTO_UPDATE,
     WHISPER_MODEL_DIR,
     CACHE_DIR,
-    AppConfig,
 )
 
 from open_webui.constants import ERROR_MESSAGES
@@ -40,52 +42,75 @@ from open_webui.env import (
     ENABLE_FORWARD_USER_INFO_HEADERS,
 )
 
-from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import FileResponse
-from pydantic import BaseModel
-from open_webui.utils.auth import get_admin_user, get_verified_user
+
+router = APIRouter()
 
 # Constants
 MAX_FILE_SIZE_MB = 25
 MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
 
-
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["AUDIO"])
 
-
-# setting device type for whisper model
-whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
-log.info(f"whisper_device_type: {whisper_device_type}")
-
 SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
 SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
 
 
+##########################################
+#
+# Utility functions
+#
+##########################################
+
+from pydub import AudioSegment
+from pydub.utils import mediainfo
+
+
+def is_mp4_audio(file_path):
+    """Check if the given file is an MP4 audio file."""
+    if not os.path.isfile(file_path):
+        print(f"File not found: {file_path}")
+        return False
+
+    info = mediainfo(file_path)
+    if (
+        info.get("codec_name") == "aac"
+        and info.get("codec_type") == "audio"
+        and info.get("codec_tag_string") == "mp4a"
+    ):
+        return True
+    return False
+
+
+def convert_mp4_to_wav(file_path, output_path):
+    """Convert MP4 audio file to WAV format."""
+    audio = AudioSegment.from_file(file_path, format="mp4")
+    audio.export(output_path, format="wav")
+    print(f"Converted {file_path} to {output_path}")
+
+
 def set_faster_whisper_model(model: str, auto_update: bool = False):
-    if model and app.state.config.STT_ENGINE == "":
+    whisper_model = None
+    if model:
         from faster_whisper import WhisperModel
 
         faster_whisper_kwargs = {
             "model_size_or_path": model,
-            "device": whisper_device_type,
+            "device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu",
             "compute_type": "int8",
             "download_root": WHISPER_MODEL_DIR,
             "local_files_only": not auto_update,
         }
 
         try:
-            app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
+            whisper_model = WhisperModel(**faster_whisper_kwargs)
         except Exception:
             log.warning(
                 "WhisperModel initialization failed, attempting download with local_files_only=False"
             )
             faster_whisper_kwargs["local_files_only"] = False
-            app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
-
-    else:
-        app.state.faster_whisper_model = None
+            whisper_model = WhisperModel(**faster_whisper_kwargs)
+    return whisper_model
 
 
 class TTSConfigForm(BaseModel):
@@ -113,98 +138,75 @@ class AudioConfigUpdateForm(BaseModel):
     stt: STTConfigForm
 
 
-from pydub import AudioSegment
-from pydub.utils import mediainfo
-
-
-def is_mp4_audio(file_path):
-    """Check if the given file is an MP4 audio file."""
-    if not os.path.isfile(file_path):
-        print(f"File not found: {file_path}")
-        return False
-
-    info = mediainfo(file_path)
-    if (
-        info.get("codec_name") == "aac"
-        and info.get("codec_type") == "audio"
-        and info.get("codec_tag_string") == "mp4a"
-    ):
-        return True
-    return False
-
-
-def convert_mp4_to_wav(file_path, output_path):
-    """Convert MP4 audio file to WAV format."""
-    audio = AudioSegment.from_file(file_path, format="mp4")
-    audio.export(output_path, format="wav")
-    print(f"Converted {file_path} to {output_path}")
-
-
-@app.get("/config")
-async def get_audio_config(user=Depends(get_admin_user)):
+@router.get("/config")
+async def get_audio_config(request: Request, user=Depends(get_admin_user)):
     return {
         "tts": {
-            "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
-            "API_KEY": app.state.config.TTS_API_KEY,
-            "ENGINE": app.state.config.TTS_ENGINE,
-            "MODEL": app.state.config.TTS_MODEL,
-            "VOICE": app.state.config.TTS_VOICE,
-            "SPLIT_ON": app.state.config.TTS_SPLIT_ON,
-            "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
-            "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
+            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
+            "API_KEY": request.app.state.config.TTS_API_KEY,
+            "ENGINE": request.app.state.config.TTS_ENGINE,
+            "MODEL": request.app.state.config.TTS_MODEL,
+            "VOICE": request.app.state.config.TTS_VOICE,
+            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
+            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
+            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
         },
         "stt": {
-            "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
-            "ENGINE": app.state.config.STT_ENGINE,
-            "MODEL": app.state.config.STT_MODEL,
-            "WHISPER_MODEL": app.state.config.WHISPER_MODEL,
+            "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
+            "ENGINE": request.app.state.config.STT_ENGINE,
+            "MODEL": request.app.state.config.STT_MODEL,
+            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
         },
     }
 
 
-@app.post("/config/update")
+@router.post("/config/update")
 async def update_audio_config(
-    form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
+    request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
 ):
-    app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
-    app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
-    app.state.config.TTS_API_KEY = form_data.tts.API_KEY
-    app.state.config.TTS_ENGINE = form_data.tts.ENGINE
-    app.state.config.TTS_MODEL = form_data.tts.MODEL
-    app.state.config.TTS_VOICE = form_data.tts.VOICE
-    app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
-    app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
-    app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
+    request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
+    request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
+    request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY
+    request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE
+    request.app.state.config.TTS_MODEL = form_data.tts.MODEL
+    request.app.state.config.TTS_VOICE = form_data.tts.VOICE
+    request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
+    request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
+    request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
         form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
     )
 
-    app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
-    app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
-    app.state.config.STT_ENGINE = form_data.stt.ENGINE
-    app.state.config.STT_MODEL = form_data.stt.MODEL
-    app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
-    set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE)
+    request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
+    request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
+    request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
+    request.app.state.config.STT_MODEL = form_data.stt.MODEL
+    request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
+
+    if request.app.state.config.STT_ENGINE == "":
+        request.app.state.faster_whisper_model = set_faster_whisper_model(
+            form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE
+        )
 
     return {
         "tts": {
-            "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
-            "API_KEY": app.state.config.TTS_API_KEY,
-            "ENGINE": app.state.config.TTS_ENGINE,
-            "MODEL": app.state.config.TTS_MODEL,
-            "VOICE": app.state.config.TTS_VOICE,
-            "SPLIT_ON": app.state.config.TTS_SPLIT_ON,
-            "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
-            "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
+            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
+            "API_KEY": request.app.state.config.TTS_API_KEY,
+            "ENGINE": request.app.state.config.TTS_ENGINE,
+            "MODEL": request.app.state.config.TTS_MODEL,
+            "VOICE": request.app.state.config.TTS_VOICE,
+            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
+            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
+            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
         },
         "stt": {
-            "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
-            "ENGINE": app.state.config.STT_ENGINE,
-            "MODEL": app.state.config.STT_MODEL,
-            "WHISPER_MODEL": app.state.config.WHISPER_MODEL,
+            "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
+            "ENGINE": request.app.state.config.STT_ENGINE,
+            "MODEL": request.app.state.config.STT_MODEL,
+            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
         },
     }
 
@@ -213,18 +215,18 @@ def load_speech_pipeline():
     from transformers import pipeline
     from datasets import load_dataset
 
-    if app.state.speech_synthesiser is None:
-        app.state.speech_synthesiser = pipeline(
+    if request.app.state.speech_synthesiser is None:
+        request.app.state.speech_synthesiser = pipeline(
             "text-to-speech", "microsoft/speecht5_tts"
         )
 
-    if app.state.speech_speaker_embeddings_dataset is None:
-        app.state.speech_speaker_embeddings_dataset = load_dataset(
+    if request.app.state.speech_speaker_embeddings_dataset is None:
+        request.app.state.speech_speaker_embeddings_dataset = load_dataset(
             "Matthijs/cmu-arctic-xvectors", split="validation"
         )
 
 
-@app.post("/speech")
+@router.post("/speech")
 async def speech(request: Request, user=Depends(get_verified_user)):
     body = await request.body()
     name = hashlib.sha256(body).hexdigest()
@@ -236,9 +238,11 @@ async def speech(request: Request, user=Depends(get_verified_user)):
     if file_path.is_file():
         return FileResponse(file_path)
 
-    if app.state.config.TTS_ENGINE == "openai":
+    if request.app.state.config.TTS_ENGINE == "openai":
         headers = {}
-        headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
+        headers["Authorization"] = (
+            f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}"
+        )
         headers["Content-Type"] = "application/json"
 
         if ENABLE_FORWARD_USER_INFO_HEADERS:
@@ -250,7 +254,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         try:
             body = body.decode("utf-8")
             body = json.loads(body)
-            body["model"] = app.state.config.TTS_MODEL
+            body["model"] = request.app.state.config.TTS_MODEL
             body = json.dumps(body).encode("utf-8")
         except Exception:
             pass
@@ -258,7 +262,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         try:
             async with aiohttp.ClientSession() as session:
                 async with session.post(
-                    url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
+                    url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
                     data=body,
                     headers=headers,
                 ) as r:
@@ -287,7 +291,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
                 detail=error_detail,
             )
 
-    elif app.state.config.TTS_ENGINE == "elevenlabs":
+    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
         try:
             payload = json.loads(body.decode("utf-8"))
         except Exception as e:
@@ -305,11 +309,11 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         headers = {
             "Accept": "audio/mpeg",
             "Content-Type": "application/json",
-            "xi-api-key": app.state.config.TTS_API_KEY,
+            "xi-api-key": request.app.state.config.TTS_API_KEY,
         }
         data = {
             "text": payload["input"],
-            "model_id": app.state.config.TTS_MODEL,
+            "model_id": request.app.state.config.TTS_MODEL,
             "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
         }
 
@@ -341,21 +345,21 @@ async def speech(request: Request, user=Depends(get_verified_user)):
                 detail=error_detail,
             )
 
-    elif app.state.config.TTS_ENGINE == "azure":
+    elif request.app.state.config.TTS_ENGINE == "azure":
         try:
             payload = json.loads(body.decode("utf-8"))
         except Exception as e:
             log.exception(e)
             raise HTTPException(status_code=400, detail="Invalid JSON payload")
 
-        region = app.state.config.TTS_AZURE_SPEECH_REGION
-        language = app.state.config.TTS_VOICE
-        locale = "-".join(app.state.config.TTS_VOICE.split("-")[:1])
-        output_format = app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
+        region = request.app.state.config.TTS_AZURE_SPEECH_REGION
+        language = request.app.state.config.TTS_VOICE
+        locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
+        output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
         url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1"
 
         headers = {
-            "Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY,
+            "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
             "Content-Type": "application/ssml+xml",
             "X-Microsoft-OutputFormat": output_format,
         }
@@ -378,7 +382,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         except Exception as e:
             log.exception(e)
             raise HTTPException(status_code=500, detail=str(e))
-    elif app.state.config.TTS_ENGINE == "transformers":
+    elif request.app.state.config.TTS_ENGINE == "transformers":
         payload = None
         try:
             payload = json.loads(body.decode("utf-8"))
@@ -391,12 +395,12 @@ async def speech(request: Request, user=Depends(get_verified_user)):
 
         load_speech_pipeline()
 
-        embeddings_dataset = app.state.speech_speaker_embeddings_dataset
+        embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset
 
         speaker_index = 6799
         try:
             speaker_index = embeddings_dataset["filename"].index(
-                app.state.config.TTS_MODEL
+                request.app.state.config.TTS_MODEL
             )
         except Exception:
             pass
@@ -405,7 +409,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             embeddings_dataset[speaker_index]["xvector"]
         ).unsqueeze(0)
 
-        speech = app.state.speech_synthesiser(
+        speech = request.app.state.speech_synthesiser(
             payload["input"],
             forward_params={"speaker_embeddings": speaker_embedding},
         )
@@ -417,17 +421,19 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         return FileResponse(file_path)
 
 
-def transcribe(file_path):
+def transcribe(request: Request, file_path):
     print("transcribe", file_path)
     filename = os.path.basename(file_path)
     file_dir = os.path.dirname(file_path)
     id = filename.split(".")[0]
 
-    if app.state.config.STT_ENGINE == "":
-        if app.state.faster_whisper_model is None:
-            set_faster_whisper_model(app.state.config.WHISPER_MODEL)
+    if request.app.state.config.STT_ENGINE == "":
+        if request.app.state.faster_whisper_model is None:
+            request.app.state.faster_whisper_model = set_faster_whisper_model(
+                request.app.state.config.WHISPER_MODEL
+            )
 
-        model = app.state.faster_whisper_model
+        model = request.app.state.faster_whisper_model
         segments, info = model.transcribe(file_path, beam_size=5)
         log.info(
             "Detected language '%s' with probability %f"
@@ -444,31 +450,24 @@ def transcribe(file_path):
 
         log.debug(data)
         return data
-    elif app.state.config.STT_ENGINE == "openai":
+    elif request.app.state.config.STT_ENGINE == "openai":
         if is_mp4_audio(file_path):
-            print("is_mp4_audio")
             os.rename(file_path, file_path.replace(".wav", ".mp4"))
             # Convert MP4 audio file to WAV format
             convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
 
-        headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
-
-        files = {"file": (filename, open(file_path, "rb"))}
-        data = {"model": app.state.config.STT_MODEL}
-
-        log.debug(files, data)
-
         r = None
         try:
             r = requests.post(
-                url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
-                headers=headers,
-                files=files,
-                data=data,
+                url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
+                headers={
+                    "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
+                },
+                files={"file": (filename, open(file_path, "rb"))},
+                data={"model": request.app.state.config.STT_MODEL},
             )
 
             r.raise_for_status()
-
             data = r.json()
 
             # save the transcript to a json file
@@ -476,24 +475,43 @@ def transcribe(file_path):
             with open(transcript_file, "w") as f:
                 json.dump(data, f)
 
-            print(data)
             return data
         except Exception as e:
             log.exception(e)
-            error_detail = "Open WebUI: Server Connection Error"
+
+            detail = None
             if r is not None:
                 try:
                     res = r.json()
                     if "error" in res:
-                        error_detail = f"External: {res['error']['message']}"
+                        detail = f"External: {res['error'].get('message', '')}"
                 except Exception:
-                    error_detail = f"External: {e}"
+                    detail = f"External: {e}"
+
+            raise Exception(detail if detail else "Open WebUI: Server Connection Error")
 
-            raise Exception(error_detail)
 
+def compress_audio(file_path):
+    if os.path.getsize(file_path) > MAX_FILE_SIZE:
+        file_dir = os.path.dirname(file_path)
+        audio = AudioSegment.from_file(file_path)
+        audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio
+        compressed_path = f"{file_dir}/{id}_compressed.opus"
+        audio.export(compressed_path, format="opus", bitrate="32k")
+        log.debug(f"Compressed audio to {compressed_path}")
 
-@app.post("/transcriptions")
+        if (
+            os.path.getsize(compressed_path) > MAX_FILE_SIZE
+        ):  # Still larger than MAX_FILE_SIZE after compression
+            raise Exception(ERROR_MESSAGES.FILE_TOO_LARGE(size=f"{MAX_FILE_SIZE_MB}MB"))
+        return compressed_path
+    else:
+        return file_path
+
+
+@router.post("/transcriptions")
 def transcription(
+    request: Request,
     file: UploadFile = File(...),
     user=Depends(get_verified_user),
 ):
@@ -520,36 +538,22 @@ def transcription(
             f.write(contents)
 
         try:
-            if os.path.getsize(file_path) > MAX_FILE_SIZE:  # file is bigger than 25MB
-                log.debug(f"File size is larger than {MAX_FILE_SIZE_MB}MB")
-                audio = AudioSegment.from_file(file_path)
-                audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio
-                compressed_path = f"{file_dir}/{id}_compressed.opus"
-                audio.export(compressed_path, format="opus", bitrate="32k")
-                log.debug(f"Compressed audio to {compressed_path}")
-                file_path = compressed_path
-
-                if (
-                    os.path.getsize(file_path) > MAX_FILE_SIZE
-                ):  # Still larger than 25MB after compression
-                    log.debug(
-                        f"Compressed file size is still larger than {MAX_FILE_SIZE_MB}MB: {os.path.getsize(file_path)}"
-                    )
-                    raise HTTPException(
-                        status_code=status.HTTP_400_BAD_REQUEST,
-                        detail=ERROR_MESSAGES.FILE_TOO_LARGE(
-                            size=f"{MAX_FILE_SIZE_MB}MB"
-                        ),
-                    )
-
-                data = transcribe(file_path)
-            else:
-                data = transcribe(file_path)
+            try:
+                file_path = compress_audio(file_path)
+            except Exception as e:
+                log.exception(e)
+
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.DEFAULT(e),
+                )
 
+            data = transcribe(request, file_path)
             file_path = file_path.split("/")[-1]
             return {**data, "filename": file_path}
         except Exception as e:
             log.exception(e)
+
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
                 detail=ERROR_MESSAGES.DEFAULT(e),
@@ -564,39 +568,41 @@ def transcription(
         )
 
 
-def get_available_models() -> list[dict]:
-    if app.state.config.TTS_ENGINE == "openai":
-        return [{"id": "tts-1"}, {"id": "tts-1-hd"}]
-    elif app.state.config.TTS_ENGINE == "elevenlabs":
-        headers = {
-            "xi-api-key": app.state.config.TTS_API_KEY,
-            "Content-Type": "application/json",
-        }
-
+def get_available_models(request: Request) -> list[dict]:
+    available_models = []
+    if request.app.state.config.TTS_ENGINE == "openai":
+        available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
+    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
         try:
             response = requests.get(
-                "https://api.elevenlabs.io/v1/models", headers=headers, timeout=5
+                "https://api.elevenlabs.io/v1/models",
+                headers={
+                    "xi-api-key": request.app.state.config.TTS_API_KEY,
+                    "Content-Type": "application/json",
+                },
+                timeout=5,
             )
             response.raise_for_status()
             models = response.json()
-            return [
+
+            available_models = [
                 {"name": model["name"], "id": model["model_id"]} for model in models
             ]
         except requests.RequestException as e:
             log.error(f"Error fetching voices: {str(e)}")
-    return []
+    return available_models
 
 
-@app.get("/models")
-async def get_models(user=Depends(get_verified_user)):
-    return {"models": get_available_models()}
+@router.get("/models")
+async def get_models(request: Request, user=Depends(get_verified_user)):
+    return {"models": get_available_models(request)}
 
 
-def get_available_voices() -> dict:
+def get_available_voices(request) -> dict:
     """Returns {voice_id: voice_name} dict"""
-    ret = {}
-    if app.state.config.TTS_ENGINE == "openai":
-        ret = {
+    available_voices = {}
+    if request.app.state.config.TTS_ENGINE == "openai":
+        available_voices = {
             "alloy": "alloy",
             "echo": "echo",
             "fable": "fable",
@@ -604,33 +610,38 @@ def get_available_voices() -> dict:
             "nova": "nova",
             "shimmer": "shimmer",
         }
-    elif app.state.config.TTS_ENGINE == "elevenlabs":
+    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
         try:
-            ret = get_elevenlabs_voices()
+            available_voices = get_elevenlabs_voices(
+                api_key=request.app.state.config.TTS_API_KEY
+            )
         except Exception:
             # Avoided @lru_cache with exception
             pass
-    elif app.state.config.TTS_ENGINE == "azure":
+    elif request.app.state.config.TTS_ENGINE == "azure":
         try:
-            region = app.state.config.TTS_AZURE_SPEECH_REGION
+            region = request.app.state.config.TTS_AZURE_SPEECH_REGION
             url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
-            headers = {"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY}
+            headers = {
+                "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY
+            }
 
             response = requests.get(url, headers=headers)
             response.raise_for_status()
             voices = response.json()
+
             for voice in voices:
-                ret[voice["ShortName"]] = (
+                available_voices[voice["ShortName"]] = (
                     f"{voice['DisplayName']} ({voice['ShortName']})"
                 )
         except requests.RequestException as e:
             log.error(f"Error fetching voices: {str(e)}")
 
-    return ret
+    return available_voices
 
 
 @lru_cache
-def get_elevenlabs_voices() -> dict:
+def get_elevenlabs_voices(api_key: str) -> dict:
     """
     Note, set the following in your .env file to use Elevenlabs:
     AUDIO_TTS_ENGINE=elevenlabs
@@ -638,13 +649,16 @@ def get_elevenlabs_voices() -> dict:
     AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL  # From https://api.elevenlabs.io/v1/voices
     AUDIO_TTS_MODEL=eleven_multilingual_v2
     """
-    headers = {
-        "xi-api-key": app.state.config.TTS_API_KEY,
-        "Content-Type": "application/json",
-    }
+
     try:
         # TODO: Add retries
-        response = requests.get("https://api.elevenlabs.io/v1/voices", headers=headers)
+        response = requests.get(
+            "https://api.elevenlabs.io/v1/voices",
+            headers={
+                "xi-api-key": api_key,
+                "Content-Type": "application/json",
+            },
+        )
         response.raise_for_status()
         voices_data = response.json()
 
@@ -659,6 +673,10 @@ def get_elevenlabs_voices() -> dict:
     return voices
 
 
-@app.get("/voices")
-async def get_voices(user=Depends(get_verified_user)):
-    return {"voices": [{"id": k, "name": v} for k, v in get_available_voices().items()]}
+@router.get("/voices")
+async def get_voices(request: Request, user=Depends(get_verified_user)):
+    return {
+        "voices": [
+            {"id": k, "name": v} for k, v in get_available_voices(request).items()
+        ]
+    }

+ 122 - 183
backend/open_webui/routers/ollama.py

@@ -385,7 +385,7 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
     if request.app.state.config.ENABLE_OLLAMA_API:
         if url_idx is None:
             # returns lowest version
-            tasks = [
+            request_tasks = [
                 send_get_request(
                     f"{url}/api/version",
                     request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
@@ -394,7 +394,7 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
                 )
                 for url in request.app.state.config.OLLAMA_BASE_URLS
             ]
-            responses = await asyncio.gather(*tasks)
+            responses = await asyncio.gather(*request_tasks)
             responses = list(filter(lambda x: x is not None, responses))
 
             if len(responses) > 0:
@@ -446,7 +446,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u
     List models that are currently loaded into Ollama memory, and which node they are loaded on.
     """
     if request.app.state.config.ENABLE_OLLAMA_API:
-        tasks = [
+        request_tasks = [
             send_get_request(
                 f"{url}/api/ps",
                 request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
@@ -455,7 +455,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u
             )
             for url in request.app.state.config.OLLAMA_BASE_URLS
         ]
-        responses = await asyncio.gather(*tasks)
+        responses = await asyncio.gather(*request_tasks)
 
         return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
     else:
@@ -502,8 +502,8 @@ async def push_model(
     user=Depends(get_admin_user),
 ):
     if url_idx is None:
-        model_list = await get_all_models()
-        models = {model["model"]: model for model in model_list["models"]}
+        await get_all_models(request)
+        models = request.app.state.OLLAMA_MODELS
 
         if form_data.name in models:
             url_idx = models[form_data.name]["urls"][0]
@@ -540,7 +540,6 @@ async def create_model(
 ):
     log.debug(f"form_data: {form_data}")
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
 
     return await send_post_request(
         url=f"{url}/api/create",
@@ -563,8 +562,8 @@ async def copy_model(
     user=Depends(get_admin_user),
 ):
     if url_idx is None:
-        model_list = await get_all_models()
-        models = {model["model"]: model for model in model_list["models"]}
+        await get_all_models()
+        models = request.app.state.OLLAMA_MODELS
 
         if form_data.source in models:
             url_idx = models[form_data.source]["urls"][0]
@@ -575,45 +574,37 @@ async def copy_model(
             )
 
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-
-    parsed_url = urlparse(url)
-    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
-
-    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
-    key = api_config.get("key", None)
-
-    headers = {"Content-Type": "application/json"}
-    if key:
-        headers["Authorization"] = f"Bearer {key}"
-
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/copy",
-        headers=headers,
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
+    key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
 
     try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/copy",
+            headers={
+                "Content-Type": "application/json",
+                **({"Authorization": f"Bearer {key}"} if key else {}),
+            },
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()
 
         log.debug(f"r.text: {r.text}")
-
         return True
     except Exception as e:
         log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
+
+        detail = None
         if r is not None:
             try:
                 res = r.json()
                 if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
+                    detail = f"Ollama: {res['error']}"
             except Exception:
-                error_detail = f"Ollama: {e}"
+                detail = f"Ollama: {e}"
 
         raise HTTPException(
             status_code=r.status_code if r else 500,
-            detail=error_detail,
+            detail=detail if detail else "Open WebUI: Server Connection Error",
         )
 
 
@@ -626,8 +617,8 @@ async def delete_model(
     user=Depends(get_admin_user),
 ):
     if url_idx is None:
-        model_list = await get_all_models()
-        models = {model["model"]: model for model in model_list["models"]}
+        await get_all_models()
+        models = request.app.state.OLLAMA_MODELS
 
         if form_data.name in models:
             url_idx = models[form_data.name]["urls"][0]
@@ -638,44 +629,37 @@ async def delete_model(
             )
 
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
+    key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
 
-    parsed_url = urlparse(url)
-    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
-
-    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
-    key = api_config.get("key", None)
-
-    headers = {"Content-Type": "application/json"}
-    if key:
-        headers["Authorization"] = f"Bearer {key}"
-
-    r = requests.request(
-        method="DELETE",
-        url=f"{url}/api/delete",
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-        headers=headers,
-    )
     try:
+        r = requests.request(
+            method="DELETE",
+            url=f"{url}/api/delete",
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+            headers={
+                "Content-Type": "application/json",
+                **({"Authorization": f"Bearer {key}"} if key else {}),
+            },
+        )
         r.raise_for_status()
 
         log.debug(f"r.text: {r.text}")
-
         return True
     except Exception as e:
         log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
+
+        detail = None
         if r is not None:
             try:
                 res = r.json()
                 if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
+                    detail = f"Ollama: {res['error']}"
             except Exception:
-                error_detail = f"Ollama: {e}"
+                detail = f"Ollama: {e}"
 
         raise HTTPException(
             status_code=r.status_code if r else 500,
-            detail=error_detail,
+            detail=detail if detail else "Open WebUI: Server Connection Error",
         )
 
 
@@ -683,8 +667,8 @@ async def delete_model(
 async def show_model_info(
     request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
 ):
-    model_list = await get_all_models()
-    models = {model["model"]: model for model in model_list["models"]}
+    await get_all_models()
+    models = request.app.state.OLLAMA_MODELS
 
     if form_data.name not in models:
         raise HTTPException(
@@ -693,53 +677,41 @@ async def show_model_info(
         )
 
     url_idx = random.choice(models[form_data.name]["urls"])
-    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-
-    parsed_url = urlparse(url)
-    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
-
-    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
-    key = api_config.get("key", None)
 
-    headers = {"Content-Type": "application/json"}
-    if key:
-        headers["Authorization"] = f"Bearer {key}"
+    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
+    key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
 
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/show",
-        headers=headers,
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
     try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/show",
+            headers={
+                "Content-Type": "application/json",
+                **({"Authorization": f"Bearer {key}"} if key else {}),
+            },
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()
 
         return r.json()
     except Exception as e:
         log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
+
+        detail = None
         if r is not None:
             try:
                 res = r.json()
                 if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
+                    detail = f"Ollama: {res['error']}"
             except Exception:
-                error_detail = f"Ollama: {e}"
+                detail = f"Ollama: {e}"
 
         raise HTTPException(
             status_code=r.status_code if r else 500,
-            detail=error_detail,
+            detail=detail if detail else "Open WebUI: Server Connection Error",
         )
 
 
-class GenerateEmbeddingsForm(BaseModel):
-    model: str
-    prompt: str
-    options: Optional[dict] = None
-    keep_alive: Optional[Union[int, str]] = None
-
-
 class GenerateEmbedForm(BaseModel):
     model: str
     input: list[str] | str
@@ -750,33 +722,17 @@ class GenerateEmbedForm(BaseModel):
 
 @router.post("/api/embed")
 @router.post("/api/embed/{url_idx}")
-async def generate_embeddings(
+async def embed(
+    request: Request,
     form_data: GenerateEmbedForm,
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    return await generate_ollama_batch_embeddings(form_data, url_idx)
-
-
-@router.post("/api/embeddings")
-@router.post("/api/embeddings/{url_idx}")
-async def generate_embeddings(
-    form_data: GenerateEmbeddingsForm,
-    url_idx: Optional[int] = None,
-    user=Depends(get_verified_user),
-):
-    return await generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
-
-
-async def generate_ollama_embeddings(
-    form_data: GenerateEmbeddingsForm,
-    url_idx: Optional[int] = None,
-):
-    log.info(f"generate_ollama_embeddings {form_data}")
+    log.info(f"generate_ollama_batch_embeddings {form_data}")
 
     if url_idx is None:
-        model_list = await get_all_models()
-        models = {model["model"]: model for model in model_list["models"]}
+        await get_all_models()
+        models = request.app.state.OLLAMA_MODELS
 
         model = form_data.model
 
@@ -792,61 +748,60 @@ async def generate_ollama_embeddings(
             )
 
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
+    key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
 
-    parsed_url = urlparse(url)
-    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
-
-    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
-    key = api_config.get("key", None)
-
-    headers = {"Content-Type": "application/json"}
-    if key:
-        headers["Authorization"] = f"Bearer {key}"
-
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/embeddings",
-        headers=headers,
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
     try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/embed",
+            headers={
+                "Content-Type": "application/json",
+                **({"Authorization": f"Bearer {key}"} if key else {}),
+            },
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()
 
         data = r.json()
-
-        log.info(f"generate_ollama_embeddings {data}")
-
-        if "embedding" in data:
-            return data
-        else:
-            raise Exception("Something went wrong :/")
+        return data
     except Exception as e:
         log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
+
+        detail = None
         if r is not None:
             try:
                 res = r.json()
                 if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
+                    detail = f"Ollama: {res['error']}"
             except Exception:
-                error_detail = f"Ollama: {e}"
+                detail = f"Ollama: {e}"
 
         raise HTTPException(
             status_code=r.status_code if r else 500,
-            detail=error_detail,
+            detail=detail if detail else "Open WebUI: Server Connection Error",
         )
 
 
-async def generate_ollama_batch_embeddings(
-    form_data: GenerateEmbedForm,
+class GenerateEmbeddingsForm(BaseModel):
+    model: str
+    prompt: str
+    options: Optional[dict] = None
+    keep_alive: Optional[Union[int, str]] = None
+
+
+@router.post("/api/embeddings")
+@router.post("/api/embeddings/{url_idx}")
+async def embeddings(
+    request: Request,
+    form_data: GenerateEmbeddingsForm,
     url_idx: Optional[int] = None,
+    user=Depends(get_verified_user),
 ):
-    log.info(f"generate_ollama_batch_embeddings {form_data}")
+    log.info(f"generate_ollama_embeddings {form_data}")
 
     if url_idx is None:
-        model_list = await get_all_models()
-        models = {model["model"]: model for model in model_list["models"]}
+        await get_all_models()
+        models = request.app.state.OLLAMA_MODELS
 
         model = form_data.model
 
@@ -862,47 +817,38 @@ async def generate_ollama_batch_embeddings(
             )
 
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-
-    parsed_url = urlparse(url)
-    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
+    key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
 
-    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
-    key = api_config.get("key", None)
-
-    headers = {"Content-Type": "application/json"}
-    if key:
-        headers["Authorization"] = f"Bearer {key}"
-
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/embed",
-        headers=headers,
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
     try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/embeddings",
+            headers={
+                "Content-Type": "application/json",
+                **({"Authorization": f"Bearer {key}"} if key else {}),
+            },
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()
 
         data = r.json()
-
-        log.info(f"generate_ollama_batch_embeddings {data}")
-
-        if "embeddings" in data:
-            return data
-        else:
-            raise Exception("Something went wrong :/")
+        return data
     except Exception as e:
         log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
+
+        detail = None
         if r is not None:
             try:
                 res = r.json()
                 if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
+                    detail = f"Ollama: {res['error']}"
             except Exception:
-                error_detail = f"Ollama: {e}"
+                detail = f"Ollama: {e}"
 
-        raise Exception(error_detail)
+        raise HTTPException(
+            status_code=r.status_code if r else 500,
+            detail=detail if detail else "Open WebUI: Server Connection Error",
+        )
 
 
 class GenerateCompletionForm(BaseModel):
@@ -947,10 +893,10 @@ async def generate_completion(
 
     url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+
     prefix_id = api_config.get("prefix_id", None)
     if prefix_id:
         form_data.model = form_data.model.replace(f"{prefix_id}.", "")
-    log.info(f"url: {url}")
 
     return await send_post_request(
         url=f"{url}/api/generate",
@@ -975,7 +921,7 @@ class GenerateChatCompletionForm(BaseModel):
     keep_alive: Optional[Union[int, str]] = None
 
 
-async def get_ollama_url(url_idx: Optional[int], model: str):
+async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None):
     if url_idx is None:
         models = request.app.state.OLLAMA_MODELS
         if model not in models:
@@ -1001,7 +947,6 @@ async def generate_chat_completion(
         bypass_filter = True
 
     payload = {**form_data.model_dump(exclude_none=True)}
-    log.debug(f"generate_chat_completion() - 1.payload = {payload}")
     if "metadata" in payload:
         del payload["metadata"]
 
@@ -1045,13 +990,9 @@ async def generate_chat_completion(
     if ":" not in payload["model"]:
         payload["model"] = f"{payload['model']}:latest"
 
-    url = await get_ollama_url(url_idx, payload["model"])
-    log.debug(f"generate_chat_completion() - 2.payload = {payload}")
-
-    parsed_url = urlparse(url)
-    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
+    url = await get_ollama_url(request, payload["model"], url_idx)
+    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
 
-    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
     prefix_id = api_config.get("prefix_id", None)
     if prefix_id:
         payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
@@ -1148,10 +1089,9 @@ async def generate_openai_completion(
     if ":" not in payload["model"]:
         payload["model"] = f"{payload['model']}:latest"
 
-    url = await get_ollama_url(url_idx, payload["model"])
-    log.info(f"url: {url}")
-
+    url = await get_ollama_url(request, payload["model"], url_idx)
     api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+
     prefix_id = api_config.get("prefix_id", None)
 
     if prefix_id:
@@ -1223,10 +1163,9 @@ async def generate_openai_chat_completion(
     if ":" not in payload["model"]:
         payload["model"] = f"{payload['model']}:latest"
 
-    url = await get_ollama_url(url_idx, payload["model"])
-    log.info(f"url: {url}")
-
+    url = await get_ollama_url(request, payload["model"], url_idx)
     api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+
     prefix_id = api_config.get("prefix_id", None)
     if prefix_id:
         payload["model"] = payload["model"].replace(f"{prefix_id}.", "")