|
@@ -13,7 +13,15 @@ from aiocache import cached
|
|
|
|
|
|
import requests
|
|
import requests
|
|
|
|
|
|
-from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
|
|
|
|
|
|
+from fastapi import (
|
|
|
|
+ Depends,
|
|
|
|
+ FastAPI,
|
|
|
|
+ File,
|
|
|
|
+ HTTPException,
|
|
|
|
+ Request,
|
|
|
|
+ UploadFile,
|
|
|
|
+ APIRouter,
|
|
|
|
+)
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import StreamingResponse
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel, ConfigDict
|
|
from pydantic import BaseModel, ConfigDict
|
|
@@ -26,18 +34,15 @@ from open_webui.models.models import Models
|
|
from open_webui.config import (
|
|
from open_webui.config import (
|
|
UPLOAD_DIR,
|
|
UPLOAD_DIR,
|
|
)
|
|
)
|
|
-
|
|
|
|
-
|
|
|
|
from open_webui.env import (
|
|
from open_webui.env import (
|
|
|
|
+ ENV,
|
|
|
|
+ SRC_LOG_LEVELS,
|
|
AIOHTTP_CLIENT_TIMEOUT,
|
|
AIOHTTP_CLIENT_TIMEOUT,
|
|
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
|
|
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
|
|
BYPASS_MODEL_ACCESS_CONTROL,
|
|
BYPASS_MODEL_ACCESS_CONTROL,
|
|
)
|
|
)
|
|
|
|
|
|
-
|
|
|
|
from open_webui.constants import ERROR_MESSAGES
|
|
from open_webui.constants import ERROR_MESSAGES
|
|
-from open_webui.env import ENV, SRC_LOG_LEVELS
|
|
|
|
-
|
|
|
|
|
|
|
|
from open_webui.utils.misc import (
|
|
from open_webui.utils.misc import (
|
|
calculate_sha256,
|
|
calculate_sha256,
|
|
@@ -54,13 +59,15 @@ log = logging.getLogger(__name__)
|
|
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
|
|
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
|
|
|
|
|
|
|
|
|
|
|
|
+router = APIRouter()
|
|
|
|
+
|
|
# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
|
|
# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
|
|
# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
|
|
# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
|
|
# least connections, or least response time for better resource utilization and performance optimization.
|
|
# least connections, or least response time for better resource utilization and performance optimization.
|
|
|
|
|
|
|
|
|
|
-@app.head("/")
|
|
|
|
-@app.get("/")
|
|
|
|
|
|
+@router.head("/")
|
|
|
|
+@router.get("/")
|
|
async def get_status():
|
|
async def get_status():
|
|
return {"status": True}
|
|
return {"status": True}
|
|
|
|
|
|
@@ -70,7 +77,7 @@ class ConnectionVerificationForm(BaseModel):
|
|
key: Optional[str] = None
|
|
key: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
-@app.post("/verify")
|
|
|
|
|
|
+@router.post("/verify")
|
|
async def verify_connection(
|
|
async def verify_connection(
|
|
form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
|
|
form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
|
|
):
|
|
):
|
|
@@ -110,12 +117,12 @@ async def verify_connection(
|
|
raise HTTPException(status_code=500, detail=error_detail)
|
|
raise HTTPException(status_code=500, detail=error_detail)
|
|
|
|
|
|
|
|
|
|
-@app.get("/config")
|
|
|
|
-async def get_config(user=Depends(get_admin_user)):
|
|
|
|
|
|
+@router.get("/config")
|
|
|
|
+async def get_config(request: Request, user=Depends(get_admin_user)):
|
|
return {
|
|
return {
|
|
- "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
|
|
|
|
- "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
|
|
|
|
- "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
|
|
|
|
|
|
+ "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
|
|
|
|
+ "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
|
|
|
|
+ "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -125,23 +132,25 @@ class OllamaConfigForm(BaseModel):
|
|
OLLAMA_API_CONFIGS: dict
|
|
OLLAMA_API_CONFIGS: dict
|
|
|
|
|
|
|
|
|
|
-@app.post("/config/update")
|
|
|
|
-async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)):
|
|
|
|
- app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
|
|
|
|
- app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
|
|
|
|
|
|
+@router.post("/config/update")
|
|
|
|
+async def update_config(
|
|
|
|
+ request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user)
|
|
|
|
+):
|
|
|
|
+ request.app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
|
|
|
|
+ request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
|
|
|
|
|
|
- app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
|
|
|
|
|
|
+ request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
|
|
|
|
|
|
# Remove any extra configs
|
|
# Remove any extra configs
|
|
- config_urls = app.state.config.OLLAMA_API_CONFIGS.keys()
|
|
|
|
- for url in list(app.state.config.OLLAMA_BASE_URLS):
|
|
|
|
|
|
+ config_urls = request.app.state.config.OLLAMA_API_CONFIGS.keys()
|
|
|
|
+ for url in list(request.app.state.config.OLLAMA_BASE_URLS):
|
|
if url not in config_urls:
|
|
if url not in config_urls:
|
|
- app.state.config.OLLAMA_API_CONFIGS.pop(url, None)
|
|
|
|
|
|
+ request.app.state.config.OLLAMA_API_CONFIGS.pop(url, None)
|
|
|
|
|
|
return {
|
|
return {
|
|
- "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
|
|
|
|
- "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
|
|
|
|
- "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
|
|
|
|
|
|
+ "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
|
|
|
|
+ "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
|
|
|
|
+ "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -158,6 +167,12 @@ async def aiohttp_get(url, key=None):
|
|
return None
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
+def get_api_key(url, configs):
|
|
|
|
+ parsed_url = urlparse(url)
|
|
|
|
+ base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
|
|
|
+ return configs.get(base_url, {}).get("key", None)
|
|
|
|
+
|
|
|
|
+
|
|
async def cleanup_response(
|
|
async def cleanup_response(
|
|
response: Optional[aiohttp.ClientResponse],
|
|
response: Optional[aiohttp.ClientResponse],
|
|
session: Optional[aiohttp.ClientSession],
|
|
session: Optional[aiohttp.ClientSession],
|
|
@@ -169,7 +184,11 @@ async def cleanup_response(
|
|
|
|
|
|
|
|
|
|
async def post_streaming_url(
|
|
async def post_streaming_url(
|
|
- url: str, payload: Union[str, bytes], stream: bool = True, content_type=None
|
|
|
|
|
|
+ url: str,
|
|
|
|
+ payload: Union[str, bytes],
|
|
|
|
+ stream: bool = True,
|
|
|
|
+ key: Optional[str] = None,
|
|
|
|
+ content_type=None,
|
|
):
|
|
):
|
|
r = None
|
|
r = None
|
|
try:
|
|
try:
|
|
@@ -177,12 +196,6 @@ async def post_streaming_url(
|
|
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
|
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
|
)
|
|
)
|
|
|
|
|
|
- parsed_url = urlparse(url)
|
|
|
|
- base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
|
|
|
-
|
|
|
|
- api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
|
|
|
- key = api_config.get("key", None)
|
|
|
|
-
|
|
|
|
headers = {"Content-Type": "application/json"}
|
|
headers = {"Content-Type": "application/json"}
|
|
if key:
|
|
if key:
|
|
headers["Authorization"] = f"Bearer {key}"
|
|
headers["Authorization"] = f"Bearer {key}"
|
|
@@ -246,13 +259,13 @@ def merge_models_lists(model_lists):
|
|
@cached(ttl=3)
|
|
@cached(ttl=3)
|
|
async def get_all_models():
|
|
async def get_all_models():
|
|
log.info("get_all_models()")
|
|
log.info("get_all_models()")
|
|
- if app.state.config.ENABLE_OLLAMA_API:
|
|
|
|
|
|
+ if request.app.state.config.ENABLE_OLLAMA_API:
|
|
tasks = []
|
|
tasks = []
|
|
- for idx, url in enumerate(app.state.config.OLLAMA_BASE_URLS):
|
|
|
|
- if url not in app.state.config.OLLAMA_API_CONFIGS:
|
|
|
|
|
|
+ for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
|
|
|
|
+ if url not in request.app.state.config.OLLAMA_API_CONFIGS:
|
|
tasks.append(aiohttp_get(f"{url}/api/tags"))
|
|
tasks.append(aiohttp_get(f"{url}/api/tags"))
|
|
else:
|
|
else:
|
|
- api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
|
|
|
|
|
+ api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
|
enable = api_config.get("enable", True)
|
|
enable = api_config.get("enable", True)
|
|
key = api_config.get("key", None)
|
|
key = api_config.get("key", None)
|
|
|
|
|
|
@@ -265,8 +278,8 @@ async def get_all_models():
|
|
|
|
|
|
for idx, response in enumerate(responses):
|
|
for idx, response in enumerate(responses):
|
|
if response:
|
|
if response:
|
|
- url = app.state.config.OLLAMA_BASE_URLS[idx]
|
|
|
|
- api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
|
|
|
|
|
+ url = request.app.state.config.OLLAMA_BASE_URLS[idx]
|
|
|
|
+ api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
|
|
|
|
|
prefix_id = api_config.get("prefix_id", None)
|
|
prefix_id = api_config.get("prefix_id", None)
|
|
model_ids = api_config.get("model_ids", [])
|
|
model_ids = api_config.get("model_ids", [])
|
|
@@ -298,21 +311,21 @@ async def get_all_models():
|
|
return models
|
|
return models
|
|
|
|
|
|
|
|
|
|
-@app.get("/api/tags")
|
|
|
|
-@app.get("/api/tags/{url_idx}")
|
|
|
|
|
|
+@router.get("/api/tags")
|
|
|
|
+@router.get("/api/tags/{url_idx}")
|
|
async def get_ollama_tags(
|
|
async def get_ollama_tags(
|
|
- url_idx: Optional[int] = None, user=Depends(get_verified_user)
|
|
|
|
|
|
+ request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)
|
|
):
|
|
):
|
|
models = []
|
|
models = []
|
|
if url_idx is None:
|
|
if url_idx is None:
|
|
models = await get_all_models()
|
|
models = await get_all_models()
|
|
else:
|
|
else:
|
|
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
|
|
|
|
+ url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
|
|
|
|
parsed_url = urlparse(url)
|
|
parsed_url = urlparse(url)
|
|
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
|
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
|
|
|
|
|
- api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
|
|
|
|
|
+ api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
|
key = api_config.get("key", None)
|
|
key = api_config.get("key", None)
|
|
|
|
|
|
headers = {}
|
|
headers = {}
|
|
@@ -356,18 +369,20 @@ async def get_ollama_tags(
|
|
return models
|
|
return models
|
|
|
|
|
|
|
|
|
|
-@app.get("/api/version")
|
|
|
|
-@app.get("/api/version/{url_idx}")
|
|
|
|
-async def get_ollama_versions(url_idx: Optional[int] = None):
|
|
|
|
- if app.state.config.ENABLE_OLLAMA_API:
|
|
|
|
|
|
+@router.get("/api/version")
|
|
|
|
+@router.get("/api/version/{url_idx}")
|
|
|
|
+async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
|
|
|
|
+ if request.app.state.config.ENABLE_OLLAMA_API:
|
|
if url_idx is None:
|
|
if url_idx is None:
|
|
# returns lowest version
|
|
# returns lowest version
|
|
tasks = [
|
|
tasks = [
|
|
aiohttp_get(
|
|
aiohttp_get(
|
|
f"{url}/api/version",
|
|
f"{url}/api/version",
|
|
- app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None),
|
|
|
|
|
|
+ request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
|
|
|
|
+ "key", None
|
|
|
|
+ ),
|
|
)
|
|
)
|
|
- for url in app.state.config.OLLAMA_BASE_URLS
|
|
|
|
|
|
+ for url in request.app.state.config.OLLAMA_BASE_URLS
|
|
]
|
|
]
|
|
responses = await asyncio.gather(*tasks)
|
|
responses = await asyncio.gather(*tasks)
|
|
responses = list(filter(lambda x: x is not None, responses))
|
|
responses = list(filter(lambda x: x is not None, responses))
|
|
@@ -387,7 +402,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
|
|
detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
|
|
detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
|
|
)
|
|
)
|
|
else:
|
|
else:
|
|
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
|
|
|
|
+ url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
|
|
|
|
r = None
|
|
r = None
|
|
try:
|
|
try:
|
|
@@ -414,22 +429,24 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
|
|
return {"version": False}
|
|
return {"version": False}
|
|
|
|
|
|
|
|
|
|
-@app.get("/api/ps")
|
|
|
|
-async def get_ollama_loaded_models(user=Depends(get_verified_user)):
|
|
|
|
|
|
+@router.get("/api/ps")
|
|
|
|
+async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
|
|
"""
|
|
"""
|
|
List models that are currently loaded into Ollama memory, and which node they are loaded on.
|
|
List models that are currently loaded into Ollama memory, and which node they are loaded on.
|
|
"""
|
|
"""
|
|
- if app.state.config.ENABLE_OLLAMA_API:
|
|
|
|
|
|
+ if request.app.state.config.ENABLE_OLLAMA_API:
|
|
tasks = [
|
|
tasks = [
|
|
aiohttp_get(
|
|
aiohttp_get(
|
|
f"{url}/api/ps",
|
|
f"{url}/api/ps",
|
|
- app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None),
|
|
|
|
|
|
+ request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
|
|
|
|
+ "key", None
|
|
|
|
+ ),
|
|
)
|
|
)
|
|
- for url in app.state.config.OLLAMA_BASE_URLS
|
|
|
|
|
|
+ for url in request.app.state.config.OLLAMA_BASE_URLS
|
|
]
|
|
]
|
|
responses = await asyncio.gather(*tasks)
|
|
responses = await asyncio.gather(*tasks)
|
|
|
|
|
|
- return dict(zip(app.state.config.OLLAMA_BASE_URLS, responses))
|
|
|
|
|
|
+ return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
|
|
else:
|
|
else:
|
|
return {}
|
|
return {}
|
|
|
|
|
|
@@ -438,18 +455,25 @@ class ModelNameForm(BaseModel):
|
|
name: str
|
|
name: str
|
|
|
|
|
|
|
|
|
|
-@app.post("/api/pull")
|
|
|
|
-@app.post("/api/pull/{url_idx}")
|
|
|
|
|
|
+@router.post("/api/pull")
|
|
|
|
+@router.post("/api/pull/{url_idx}")
|
|
async def pull_model(
|
|
async def pull_model(
|
|
- form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
|
|
|
|
|
|
+ request: Request,
|
|
|
|
+ form_data: ModelNameForm,
|
|
|
|
+ url_idx: int = 0,
|
|
|
|
+ user=Depends(get_admin_user),
|
|
):
|
|
):
|
|
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
|
|
|
|
+ url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
log.info(f"url: {url}")
|
|
log.info(f"url: {url}")
|
|
|
|
|
|
# Admin should be able to pull models from any source
|
|
# Admin should be able to pull models from any source
|
|
payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
|
|
payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
|
|
|
|
|
|
- return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))
|
|
|
|
|
|
+ return await post_streaming_url(
|
|
|
|
+ url=f"{url}/api/pull",
|
|
|
|
+ payload=json.dumps(payload),
|
|
|
|
+ key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
|
|
|
|
+ )
|
|
|
|
|
|
|
|
|
|
class PushModelForm(BaseModel):
|
|
class PushModelForm(BaseModel):
|
|
@@ -458,9 +482,10 @@ class PushModelForm(BaseModel):
|
|
stream: Optional[bool] = None
|
|
stream: Optional[bool] = None
|
|
|
|
|
|
|
|
|
|
-@app.delete("/api/push")
|
|
|
|
-@app.delete("/api/push/{url_idx}")
|
|
|
|
|
|
+@router.delete("/api/push")
|
|
|
|
+@router.delete("/api/push/{url_idx}")
|
|
async def push_model(
|
|
async def push_model(
|
|
|
|
+ request: Request,
|
|
form_data: PushModelForm,
|
|
form_data: PushModelForm,
|
|
url_idx: Optional[int] = None,
|
|
url_idx: Optional[int] = None,
|
|
user=Depends(get_admin_user),
|
|
user=Depends(get_admin_user),
|
|
@@ -477,11 +502,13 @@ async def push_model(
|
|
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
|
|
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
|
|
)
|
|
)
|
|
|
|
|
|
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
|
|
|
|
+ url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
log.debug(f"url: {url}")
|
|
log.debug(f"url: {url}")
|
|
|
|
|
|
return await post_streaming_url(
|
|
return await post_streaming_url(
|
|
- f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
|
|
|
|
|
|
+ url=f"{url}/api/push",
|
|
|
|
+ payload=form_data.model_dump_json(exclude_none=True).encode(),
|
|
|
|
+ key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
@@ -492,17 +519,22 @@ class CreateModelForm(BaseModel):
|
|
path: Optional[str] = None
|
|
path: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
-@app.post("/api/create")
|
|
|
|
-@app.post("/api/create/{url_idx}")
|
|
|
|
|
|
+@router.post("/api/create")
|
|
|
|
+@router.post("/api/create/{url_idx}")
|
|
async def create_model(
|
|
async def create_model(
|
|
- form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
|
|
|
|
|
|
+ request: Request,
|
|
|
|
+ form_data: CreateModelForm,
|
|
|
|
+ url_idx: int = 0,
|
|
|
|
+ user=Depends(get_admin_user),
|
|
):
|
|
):
|
|
log.debug(f"form_data: {form_data}")
|
|
log.debug(f"form_data: {form_data}")
|
|
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
|
|
|
|
+ url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
log.info(f"url: {url}")
|
|
log.info(f"url: {url}")
|
|
|
|
|
|
return await post_streaming_url(
|
|
return await post_streaming_url(
|
|
- f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
|
|
|
|
|
|
+ url=f"{url}/api/create",
|
|
|
|
+ payload=form_data.model_dump_json(exclude_none=True).encode(),
|
|
|
|
+ key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
@@ -511,9 +543,10 @@ class CopyModelForm(BaseModel):
|
|
destination: str
|
|
destination: str
|
|
|
|
|
|
|
|
|
|
-@app.post("/api/copy")
|
|
|
|
-@app.post("/api/copy/{url_idx}")
|
|
|
|
|
|
+@router.post("/api/copy")
|
|
|
|
+@router.post("/api/copy/{url_idx}")
|
|
async def copy_model(
|
|
async def copy_model(
|
|
|
|
+ request: Request,
|
|
form_data: CopyModelForm,
|
|
form_data: CopyModelForm,
|
|
url_idx: Optional[int] = None,
|
|
url_idx: Optional[int] = None,
|
|
user=Depends(get_admin_user),
|
|
user=Depends(get_admin_user),
|
|
@@ -530,13 +563,13 @@ async def copy_model(
|
|
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
|
|
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
|
|
)
|
|
)
|
|
|
|
|
|
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
|
|
|
|
+ url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
log.info(f"url: {url}")
|
|
log.info(f"url: {url}")
|
|
|
|
|
|
parsed_url = urlparse(url)
|
|
parsed_url = urlparse(url)
|
|
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
|
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
|
|
|
|
|
- api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
|
|
|
|
|
+ api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
|
key = api_config.get("key", None)
|
|
key = api_config.get("key", None)
|
|
|
|
|
|
headers = {"Content-Type": "application/json"}
|
|
headers = {"Content-Type": "application/json"}
|
|
@@ -573,9 +606,10 @@ async def copy_model(
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
-@app.delete("/api/delete")
|
|
|
|
-@app.delete("/api/delete/{url_idx}")
|
|
|
|
|
|
+@router.delete("/api/delete")
|
|
|
|
+@router.delete("/api/delete/{url_idx}")
|
|
async def delete_model(
|
|
async def delete_model(
|
|
|
|
+ request: Request,
|
|
form_data: ModelNameForm,
|
|
form_data: ModelNameForm,
|
|
url_idx: Optional[int] = None,
|
|
url_idx: Optional[int] = None,
|
|
user=Depends(get_admin_user),
|
|
user=Depends(get_admin_user),
|
|
@@ -592,13 +626,13 @@ async def delete_model(
|
|
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
|
|
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
|
|
)
|
|
)
|
|
|
|
|
|
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
|
|
|
|
+ url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
log.info(f"url: {url}")
|
|
log.info(f"url: {url}")
|
|
|
|
|
|
parsed_url = urlparse(url)
|
|
parsed_url = urlparse(url)
|
|
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
|
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
|
|
|
|
|
- api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
|
|
|
|
|
+ api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
|
key = api_config.get("key", None)
|
|
key = api_config.get("key", None)
|
|
|
|
|
|
headers = {"Content-Type": "application/json"}
|
|
headers = {"Content-Type": "application/json"}
|
|
@@ -634,8 +668,10 @@ async def delete_model(
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
-@app.post("/api/show")
|
|
|
|
-async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
|
|
|
|
|
|
+@router.post("/api/show")
|
|
|
|
+async def show_model_info(
|
|
|
|
+ request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
|
|
|
|
+):
|
|
model_list = await get_all_models()
|
|
model_list = await get_all_models()
|
|
models = {model["model"]: model for model in model_list["models"]}
|
|
models = {model["model"]: model for model in model_list["models"]}
|
|
|
|
|
|
@@ -646,13 +682,13 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
|
|
)
|
|
)
|
|
|
|
|
|
url_idx = random.choice(models[form_data.name]["urls"])
|
|
url_idx = random.choice(models[form_data.name]["urls"])
|
|
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
|
|
|
|
+ url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
log.info(f"url: {url}")
|
|
log.info(f"url: {url}")
|
|
|
|
|
|
parsed_url = urlparse(url)
|
|
parsed_url = urlparse(url)
|
|
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
|
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
|
|
|
|
|
- api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
|
|
|
|
|
+ api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
|
key = api_config.get("key", None)
|
|
key = api_config.get("key", None)
|
|
|
|
|
|
headers = {"Content-Type": "application/json"}
|
|
headers = {"Content-Type": "application/json"}
|
|
@@ -701,8 +737,8 @@ class GenerateEmbedForm(BaseModel):
|
|
keep_alive: Optional[Union[int, str]] = None
|
|
keep_alive: Optional[Union[int, str]] = None
|
|
|
|
|
|
|
|
|
|
-@app.post("/api/embed")
|
|
|
|
-@app.post("/api/embed/{url_idx}")
|
|
|
|
|
|
+@router.post("/api/embed")
|
|
|
|
+@router.post("/api/embed/{url_idx}")
|
|
async def generate_embeddings(
|
|
async def generate_embeddings(
|
|
form_data: GenerateEmbedForm,
|
|
form_data: GenerateEmbedForm,
|
|
url_idx: Optional[int] = None,
|
|
url_idx: Optional[int] = None,
|
|
@@ -711,8 +747,8 @@ async def generate_embeddings(
|
|
return await generate_ollama_batch_embeddings(form_data, url_idx)
|
|
return await generate_ollama_batch_embeddings(form_data, url_idx)
|
|
|
|
|
|
|
|
|
|
-@app.post("/api/embeddings")
|
|
|
|
-@app.post("/api/embeddings/{url_idx}")
|
|
|
|
|
|
+@router.post("/api/embeddings")
|
|
|
|
+@router.post("/api/embeddings/{url_idx}")
|
|
async def generate_embeddings(
|
|
async def generate_embeddings(
|
|
form_data: GenerateEmbeddingsForm,
|
|
form_data: GenerateEmbeddingsForm,
|
|
url_idx: Optional[int] = None,
|
|
url_idx: Optional[int] = None,
|
|
@@ -744,13 +780,13 @@ async def generate_ollama_embeddings(
|
|
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
|
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
|
)
|
|
)
|
|
|
|
|
|
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
|
|
|
|
+ url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
log.info(f"url: {url}")
|
|
log.info(f"url: {url}")
|
|
|
|
|
|
parsed_url = urlparse(url)
|
|
parsed_url = urlparse(url)
|
|
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
|
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
|
|
|
|
|
- api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
|
|
|
|
|
+ api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
|
key = api_config.get("key", None)
|
|
key = api_config.get("key", None)
|
|
|
|
|
|
headers = {"Content-Type": "application/json"}
|
|
headers = {"Content-Type": "application/json"}
|
|
@@ -814,13 +850,13 @@ async def generate_ollama_batch_embeddings(
|
|
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
|
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
|
)
|
|
)
|
|
|
|
|
|
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
|
|
|
|
+ url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
log.info(f"url: {url}")
|
|
log.info(f"url: {url}")
|
|
|
|
|
|
parsed_url = urlparse(url)
|
|
parsed_url = urlparse(url)
|
|
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
|
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
|
|
|
|
|
- api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
|
|
|
|
|
+ api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
|
key = api_config.get("key", None)
|
|
key = api_config.get("key", None)
|
|
|
|
|
|
headers = {"Content-Type": "application/json"}
|
|
headers = {"Content-Type": "application/json"}
|
|
@@ -873,9 +909,10 @@ class GenerateCompletionForm(BaseModel):
|
|
keep_alive: Optional[Union[int, str]] = None
|
|
keep_alive: Optional[Union[int, str]] = None
|
|
|
|
|
|
|
|
|
|
-@app.post("/api/generate")
|
|
|
|
-@app.post("/api/generate/{url_idx}")
|
|
|
|
|
|
+@router.post("/api/generate")
|
|
|
|
+@router.post("/api/generate/{url_idx}")
|
|
async def generate_completion(
|
|
async def generate_completion(
|
|
|
|
+ request: Request,
|
|
form_data: GenerateCompletionForm,
|
|
form_data: GenerateCompletionForm,
|
|
url_idx: Optional[int] = None,
|
|
url_idx: Optional[int] = None,
|
|
user=Depends(get_verified_user),
|
|
user=Depends(get_verified_user),
|
|
@@ -897,15 +934,17 @@ async def generate_completion(
|
|
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
|
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
|
)
|
|
)
|
|
|
|
|
|
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
|
|
- api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
|
|
|
|
|
+ url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
|
|
+ api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
|
prefix_id = api_config.get("prefix_id", None)
|
|
prefix_id = api_config.get("prefix_id", None)
|
|
if prefix_id:
|
|
if prefix_id:
|
|
form_data.model = form_data.model.replace(f"{prefix_id}.", "")
|
|
form_data.model = form_data.model.replace(f"{prefix_id}.", "")
|
|
log.info(f"url: {url}")
|
|
log.info(f"url: {url}")
|
|
|
|
|
|
return await post_streaming_url(
|
|
return await post_streaming_url(
|
|
- f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
|
|
|
|
|
|
+ url=f"{url}/api/generate",
|
|
|
|
+ payload=form_data.model_dump_json(exclude_none=True).encode(),
|
|
|
|
+ key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
@@ -936,13 +975,14 @@ async def get_ollama_url(url_idx: Optional[int], model: str):
|
|
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
|
|
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
|
|
)
|
|
)
|
|
url_idx = random.choice(models[model]["urls"])
|
|
url_idx = random.choice(models[model]["urls"])
|
|
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
|
|
|
|
+ url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
return url
|
|
return url
|
|
|
|
|
|
|
|
|
|
-@app.post("/api/chat")
|
|
|
|
-@app.post("/api/chat/{url_idx}")
|
|
|
|
|
|
+@router.post("/api/chat")
|
|
|
|
+@router.post("/api/chat/{url_idx}")
|
|
async def generate_chat_completion(
|
|
async def generate_chat_completion(
|
|
|
|
+ request: Request,
|
|
form_data: GenerateChatCompletionForm,
|
|
form_data: GenerateChatCompletionForm,
|
|
url_idx: Optional[int] = None,
|
|
url_idx: Optional[int] = None,
|
|
user=Depends(get_verified_user),
|
|
user=Depends(get_verified_user),
|
|
@@ -1003,15 +1043,16 @@ async def generate_chat_completion(
|
|
parsed_url = urlparse(url)
|
|
parsed_url = urlparse(url)
|
|
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
|
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
|
|
|
|
|
- api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
|
|
|
|
|
+ api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
|
prefix_id = api_config.get("prefix_id", None)
|
|
prefix_id = api_config.get("prefix_id", None)
|
|
if prefix_id:
|
|
if prefix_id:
|
|
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
|
|
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
|
|
|
|
|
|
return await post_streaming_url(
|
|
return await post_streaming_url(
|
|
- f"{url}/api/chat",
|
|
|
|
- json.dumps(payload),
|
|
|
|
|
|
+ url=f"{url}/api/chat",
|
|
|
|
+ payload=json.dumps(payload),
|
|
stream=form_data.stream,
|
|
stream=form_data.stream,
|
|
|
|
+ key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
|
|
content_type="application/x-ndjson",
|
|
content_type="application/x-ndjson",
|
|
)
|
|
)
|
|
|
|
|
|
@@ -1043,10 +1084,13 @@ class OpenAICompletionForm(BaseModel):
|
|
model_config = ConfigDict(extra="allow")
|
|
model_config = ConfigDict(extra="allow")
|
|
|
|
|
|
|
|
|
|
-@app.post("/v1/completions")
|
|
|
|
-@app.post("/v1/completions/{url_idx}")
|
|
|
|
|
|
+@router.post("/v1/completions")
|
|
|
|
+@router.post("/v1/completions/{url_idx}")
|
|
async def generate_openai_completion(
|
|
async def generate_openai_completion(
|
|
- form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user)
|
|
|
|
|
|
+ request: Request,
|
|
|
|
+ form_data: dict,
|
|
|
|
+ url_idx: Optional[int] = None,
|
|
|
|
+ user=Depends(get_verified_user),
|
|
):
|
|
):
|
|
try:
|
|
try:
|
|
form_data = OpenAICompletionForm(**form_data)
|
|
form_data = OpenAICompletionForm(**form_data)
|
|
@@ -1099,22 +1143,24 @@ async def generate_openai_completion(
|
|
url = await get_ollama_url(url_idx, payload["model"])
|
|
url = await get_ollama_url(url_idx, payload["model"])
|
|
log.info(f"url: {url}")
|
|
log.info(f"url: {url}")
|
|
|
|
|
|
- api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
|
|
|
|
|
+ api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
|
prefix_id = api_config.get("prefix_id", None)
|
|
prefix_id = api_config.get("prefix_id", None)
|
|
|
|
|
|
if prefix_id:
|
|
if prefix_id:
|
|
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
|
|
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
|
|
|
|
|
|
return await post_streaming_url(
|
|
return await post_streaming_url(
|
|
- f"{url}/v1/completions",
|
|
|
|
- json.dumps(payload),
|
|
|
|
|
|
+ url=f"{url}/v1/completions",
|
|
|
|
+ payload=json.dumps(payload),
|
|
stream=payload.get("stream", False),
|
|
stream=payload.get("stream", False),
|
|
|
|
+ key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
-@app.post("/v1/chat/completions")
|
|
|
|
-@app.post("/v1/chat/completions/{url_idx}")
|
|
|
|
|
|
+@router.post("/v1/chat/completions")
|
|
|
|
+@router.post("/v1/chat/completions/{url_idx}")
|
|
async def generate_openai_chat_completion(
|
|
async def generate_openai_chat_completion(
|
|
|
|
+ request: Request,
|
|
form_data: dict,
|
|
form_data: dict,
|
|
url_idx: Optional[int] = None,
|
|
url_idx: Optional[int] = None,
|
|
user=Depends(get_verified_user),
|
|
user=Depends(get_verified_user),
|
|
@@ -1172,21 +1218,23 @@ async def generate_openai_chat_completion(
|
|
url = await get_ollama_url(url_idx, payload["model"])
|
|
url = await get_ollama_url(url_idx, payload["model"])
|
|
log.info(f"url: {url}")
|
|
log.info(f"url: {url}")
|
|
|
|
|
|
- api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
|
|
|
|
|
+ api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
|
prefix_id = api_config.get("prefix_id", None)
|
|
prefix_id = api_config.get("prefix_id", None)
|
|
if prefix_id:
|
|
if prefix_id:
|
|
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
|
|
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
|
|
|
|
|
|
return await post_streaming_url(
|
|
return await post_streaming_url(
|
|
- f"{url}/v1/chat/completions",
|
|
|
|
- json.dumps(payload),
|
|
|
|
|
|
+ url=f"{url}/v1/chat/completions",
|
|
|
|
+ payload=json.dumps(payload),
|
|
stream=payload.get("stream", False),
|
|
stream=payload.get("stream", False),
|
|
|
|
+ key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
-@app.get("/v1/models")
|
|
|
|
-@app.get("/v1/models/{url_idx}")
|
|
|
|
|
|
+@router.get("/v1/models")
|
|
|
|
+@router.get("/v1/models/{url_idx}")
|
|
async def get_openai_models(
|
|
async def get_openai_models(
|
|
|
|
+ request: Request,
|
|
url_idx: Optional[int] = None,
|
|
url_idx: Optional[int] = None,
|
|
user=Depends(get_verified_user),
|
|
user=Depends(get_verified_user),
|
|
):
|
|
):
|
|
@@ -1205,7 +1253,7 @@ async def get_openai_models(
|
|
]
|
|
]
|
|
|
|
|
|
else:
|
|
else:
|
|
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
|
|
|
|
+ url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
try:
|
|
try:
|
|
r = requests.request(method="GET", url=f"{url}/api/tags")
|
|
r = requests.request(method="GET", url=f"{url}/api/tags")
|
|
r.raise_for_status()
|
|
r.raise_for_status()
|
|
@@ -1329,9 +1377,10 @@ async def download_file_stream(
|
|
|
|
|
|
|
|
|
|
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
|
|
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
|
|
-@app.post("/models/download")
|
|
|
|
-@app.post("/models/download/{url_idx}")
|
|
|
|
|
|
+@router.post("/models/download")
|
|
|
|
+@router.post("/models/download/{url_idx}")
|
|
async def download_model(
|
|
async def download_model(
|
|
|
|
+ request: Request,
|
|
form_data: UrlForm,
|
|
form_data: UrlForm,
|
|
url_idx: Optional[int] = None,
|
|
url_idx: Optional[int] = None,
|
|
user=Depends(get_admin_user),
|
|
user=Depends(get_admin_user),
|
|
@@ -1346,7 +1395,7 @@ async def download_model(
|
|
|
|
|
|
if url_idx is None:
|
|
if url_idx is None:
|
|
url_idx = 0
|
|
url_idx = 0
|
|
- url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
|
|
|
|
+ url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
|
|
|
|
file_name = parse_huggingface_url(form_data.url)
|
|
file_name = parse_huggingface_url(form_data.url)
|
|
|
|
|
|
@@ -1360,16 +1409,17 @@ async def download_model(
|
|
return None
|
|
return None
|
|
|
|
|
|
|
|
|
|
-@app.post("/models/upload")
|
|
|
|
-@app.post("/models/upload/{url_idx}")
|
|
|
|
|
|
+@router.post("/models/upload")
|
|
|
|
+@router.post("/models/upload/{url_idx}")
|
|
def upload_model(
|
|
def upload_model(
|
|
|
|
+ request: Request,
|
|
file: UploadFile = File(...),
|
|
file: UploadFile = File(...),
|
|
url_idx: Optional[int] = None,
|
|
url_idx: Optional[int] = None,
|
|
user=Depends(get_admin_user),
|
|
user=Depends(get_admin_user),
|
|
):
|
|
):
|
|
if url_idx is None:
|
|
if url_idx is None:
|
|
url_idx = 0
|
|
url_idx = 0
|
|
- ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
|
|
|
|
+ ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
|
|
|
|
|
file_path = f"{UPLOAD_DIR}/{file.filename}"
|
|
file_path = f"{UPLOAD_DIR}/{file.filename}"
|
|
|
|
|