Pārlūkot izejas kodu

fix: openai issue

Timothy J. Baek 1 gadu atpakaļ
vecāks
revīzija
7bc0c09b25

+ 9 - 3
backend/apps/web/main.py

@@ -1,7 +1,7 @@
 from fastapi import FastAPI, Depends
 from fastapi.routing import APIRoute
 from fastapi.middleware.cors import CORSMiddleware
-from apps.web.routers import auths, users, chats, modelfiles, utils
+from apps.web.routers import auths, users, chats, modelfiles, configs, utils
 from config import WEBUI_VERSION, WEBUI_AUTH
 
 app = FastAPI()
@@ -9,6 +9,7 @@ app = FastAPI()
 origins = ["*"]
 
 app.state.ENABLE_SIGNUP = True
+app.state.DEFAULT_MODELS = "llava:13b"
 
 app.add_middleware(
     CORSMiddleware,
@@ -19,13 +20,18 @@ app.add_middleware(
 )
 
 app.include_router(auths.router, prefix="/auths", tags=["auths"])
-
 app.include_router(users.router, prefix="/users", tags=["users"])
 app.include_router(chats.router, prefix="/chats", tags=["chats"])
 app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
+app.include_router(configs.router, prefix="/configs", tags=["configs"])
 app.include_router(utils.router, prefix="/utils", tags=["utils"])
 
 
 @app.get("/")
 async def get_status():
-    return {"status": True, "version": WEBUI_VERSION, "auth": WEBUI_AUTH}
+    return {
+        "status": True,
+        "version": WEBUI_VERSION,
+        "auth": WEBUI_AUTH,
+        "default_models": app.state.DEFAULT_MODELS,
+    }

+ 41 - 0
backend/apps/web/routers/configs.py

@@ -0,0 +1,41 @@
+from fastapi import Response, Request
+from fastapi import Depends, FastAPI, HTTPException, status
+from datetime import datetime, timedelta
+from typing import List, Union
+
+from fastapi import APIRouter
+from pydantic import BaseModel
+import time
+import uuid
+
+from apps.web.models.users import Users
+
+
+from utils.utils import get_password_hash, get_current_user, create_token
+from utils.misc import get_gravatar_url, validate_email_format
+from constants import ERROR_MESSAGES
+
+router = APIRouter()
+
+
+class SetDefaultModelsForm(BaseModel):
+    models: str
+
+
+############################
+# SetDefaultModels
+############################
+
+
+@router.post("/default/models", response_model=str)
+async def set_global_default_models(
+    request: Request, form_data: SetDefaultModelsForm, user=Depends(get_current_user)
+):
+    if user.role == "admin":
+        request.app.state.DEFAULT_MODELS = form_data.models
+        return request.app.state.DEFAULT_MODELS
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )

+ 31 - 0
src/lib/apis/configs/index.ts

@@ -0,0 +1,31 @@
+import { WEBUI_API_BASE_URL } from '$lib/constants';
+
+export const setDefaultModels = async (token: string, models: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/default/models`, {
+		method: 'GET',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			models: models
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};

+ 20 - 11
src/routes/(app)/+page.svelte

@@ -6,7 +6,7 @@
 	import { goto } from '$app/navigation';
 	import { page } from '$app/stores';
 
-	import { models, modelfiles, user, settings, chats, chatId } from '$lib/stores';
+	import { models, modelfiles, user, settings, chats, chatId, config } from '$lib/stores';
 	import { OLLAMA_API_BASE_URL } from '$lib/constants';
 
 	import { generateChatCompletion, generateTitle } from '$lib/apis/ollama';
@@ -90,9 +90,18 @@
 			messages: {},
 			currentId: null
 		};
-		selectedModels = $page.url.searchParams.get('models')
-			? $page.url.searchParams.get('models')?.split(',')
-			: $settings.models ?? [''];
+
+		console.log($config);
+
+		if ($page.url.searchParams.get('models')) {
+			selectedModels = $page.url.searchParams.get('models')?.split(',');
+		} else if ($settings?.models) {
+			selectedModels = $settings?.models;
+		} else if ($config?.default_models) {
+			selectedModels = $config?.default_models.split(',');
+		} else {
+			selectedModels = [''];
+		}
 
 		let _settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
 		settings.set({
@@ -383,13 +392,13 @@
 										  }
 										: { content: message.content })
 								})),
-							seed: $settings.options.seed ?? undefined,
-							stop: $settings.options.stop ?? undefined,
-							temperature: $settings.options.temperature ?? undefined,
-							top_p: $settings.options.top_p ?? undefined,
-							num_ctx: $settings.options.num_ctx ?? undefined,
-							frequency_penalty: $settings.options.repeat_penalty ?? undefined,
-							max_tokens: $settings.options.num_predict ?? undefined
+							seed: $settings?.options?.seed ?? undefined,
+							stop: $settings?.options?.stop ?? undefined,
+							temperature: $settings?.options?.temperature ?? undefined,
+							top_p: $settings?.options?.top_p ?? undefined,
+							num_ctx: $settings?.options?.num_ctx ?? undefined,
+							frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
+							max_tokens: $settings?.options?.num_predict ?? undefined
 						})
 					}
 				).catch((err) => {

+ 7 - 7
src/routes/(app)/c/[id]/+page.svelte

@@ -409,13 +409,13 @@
 										  }
 										: { content: message.content })
 								})),
-							seed: $settings.options.seed ?? undefined,
-							stop: $settings.options.stop ?? undefined,
-							temperature: $settings.options.temperature ?? undefined,
-							top_p: $settings.options.top_p ?? undefined,
-							num_ctx: $settings.options.num_ctx ?? undefined,
-							frequency_penalty: $settings.options.repeat_penalty ?? undefined,
-							max_tokens: $settings.options.num_predict ?? undefined
+							seed: $settings?.options?.seed ?? undefined,
+							stop: $settings?.options?.stop ?? undefined,
+							temperature: $settings?.options?.temperature ?? undefined,
+							top_p: $settings?.options?.top_p ?? undefined,
+							num_ctx: $settings?.options?.num_ctx ?? undefined,
+							frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
+							max_tokens: $settings?.options?.num_predict ?? undefined
 						})
 					}
 				).catch((err) => {