浏览代码

refac: switch to meta and params, remove source

Jun Siang Cheah 11 月之前
父节点
当前提交
f21c8626d6

+ 1 - 3
backend/apps/litellm/main.py

@@ -78,9 +78,7 @@ with open(LITELLM_CONFIG_DIR, "r") as file:
 
 
 app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER.value
 app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER.value
 app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST.value
 app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST.value
-app.state.MODEL_CONFIG = [
-    model.to_form() for model in Models.get_all_models_by_source("litellm")
-]
+app.state.MODEL_CONFIG = Models.get_all_models()
 
 
 app.state.ENABLE = ENABLE_LITELLM
 app.state.ENABLE = ENABLE_LITELLM
 app.state.CONFIG = litellm_config
 app.state.CONFIG = litellm_config

+ 1 - 3
backend/apps/ollama/main.py

@@ -66,9 +66,7 @@ app.state.config = AppConfig()
 
 
 app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
-app.state.MODEL_CONFIG = [
-    model.to_form() for model in Models.get_all_models_by_source("ollama")
-]
+app.state.MODEL_CONFIG = Models.get_all_models()
 
 
 app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
 app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
 app.state.MODELS = {}
 app.state.MODELS = {}

+ 1 - 3
backend/apps/openai/main.py

@@ -52,9 +52,7 @@ app.state.config = AppConfig()
 
 
 app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
-app.state.MODEL_CONFIG = [
-    model.to_form() for model in Models.get_all_models_by_source("openai")
-]
+app.state.MODEL_CONFIG = Models.get_all_models()
 
 
 
 
 app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
 app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API

+ 12 - 0
backend/apps/web/internal/db.py

@@ -1,3 +1,5 @@
+import json
+
 from peewee import *
 from peewee import *
 from peewee_migrate import Router
 from peewee_migrate import Router
 from playhouse.db_url import connect
 from playhouse.db_url import connect
@@ -8,6 +10,16 @@ import logging
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["DB"])
 log.setLevel(SRC_LOG_LEVELS["DB"])
 
 
+
+class JSONField(TextField):
+    def db_value(self, value):
+        return json.dumps(value)
+
+    def python_value(self, value):
+        if value is not None:
+            return json.loads(value)
+
+
 # Check if the file exists
 # Check if the file exists
 if os.path.exists(f"{DATA_DIR}/ollama.db"):
 if os.path.exists(f"{DATA_DIR}/ollama.db"):
     # Rename the file
     # Rename the file

+ 4 - 9
backend/apps/web/internal/migrations/008_add_models.py → backend/apps/web/internal/migrations/009_add_models.py

@@ -1,4 +1,4 @@
-"""Peewee migrations -- 008_add_models.py.
+"""Peewee migrations -- 009_add_models.py.
 
 
 Some examples (model - class or model name)::
 Some examples (model - class or model name)::
 
 
@@ -39,20 +39,15 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
 
 
     @migrator.create_model
     @migrator.create_model
     class Model(pw.Model):
     class Model(pw.Model):
-        id = pw.TextField()
-        source = pw.TextField()
-        base_model = pw.TextField(null=True)
+        id = pw.TextField(unique=True)
+        meta = pw.TextField()
+        base_model_id = pw.TextField(null=True)
         name = pw.TextField()
         name = pw.TextField()
         params = pw.TextField()
         params = pw.TextField()
 
 
         class Meta:
         class Meta:
             table_name = "model"
             table_name = "model"
 
 
-            indexes = (
-                # Create a unique index on the id, source columns
-                (("id", "source"), True),
-            )
-
 
 
 def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
 def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your rollback migrations here."""
     """Write your rollback migrations here."""

+ 29 - 56
backend/apps/web/models/models.py

@@ -6,7 +6,7 @@ import peewee as pw
 from playhouse.shortcuts import model_to_dict
 from playhouse.shortcuts import model_to_dict
 from pydantic import BaseModel
 from pydantic import BaseModel
 
 
-from apps.web.internal.db import DB
+from apps.web.internal.db import DB, JSONField
 
 
 from config import SRC_LOG_LEVELS
 from config import SRC_LOG_LEVELS
 
 
@@ -22,6 +22,12 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 # ModelParams is a model for the data stored in the params field of the Model table
 # ModelParams is a model for the data stored in the params field of the Model table
 # It isn't currently used in the backend, but it's here as a reference
 # It isn't currently used in the backend, but it's here as a reference
 class ModelParams(BaseModel):
 class ModelParams(BaseModel):
+    pass
+
+
+# ModelMeta is a model for the data stored in the meta field of the Model table
+# It isn't currently used in the backend, but it's here as a reference
+class ModelMeta(BaseModel):
     description: str
     description: str
     """
     """
         User-facing description of the model.
         User-facing description of the model.
@@ -34,50 +40,42 @@ class ModelParams(BaseModel):
 
 
 
 
 class Model(pw.Model):
 class Model(pw.Model):
-    id = pw.TextField()
+    id = pw.TextField(unique=True)
     """
     """
         The model's id as used in the API. If set to an existing model, it will override the model.
         The model's id as used in the API. If set to an existing model, it will override the model.
     """
     """
 
 
-    source = pw.TextField()
+    meta = JSONField()
     """
     """
-    The source of the model, e.g., ollama, openai, or litellm.
+        Holds a JSON encoded blob of metadata, see `ModelMeta`.
     """
     """
 
 
-    base_model = pw.TextField(null=True)
+    base_model_id = pw.TextField(null=True)
     """
     """
-    An optional pointer to the actual model that should be used when proxying requests.
-    Currently unused - but will be used to support Modelfile like behaviour in the future
+        An optional pointer to the actual model that should be used when proxying requests.
+        Currently unused - but will be used to support Modelfile like behaviour in the future
     """
     """
 
 
     name = pw.TextField()
     name = pw.TextField()
     """
     """
-    The human-readable display name of the model.
+        The human-readable display name of the model.
     """
     """
 
 
-    params = pw.TextField()
+    params = JSONField()
     """
     """
-    Holds a JSON encoded blob of parameters, see `ModelParams`.
+        Holds a JSON encoded blob of parameters, see `ModelParams`.
     """
     """
 
 
     class Meta:
     class Meta:
         database = DB
         database = DB
 
 
-        indexes = (
-            # Create a unique index on the id, source columns
-            (("id", "source"), True),
-        )
-
 
 
 class ModelModel(BaseModel):
 class ModelModel(BaseModel):
     id: str
     id: str
-    source: str
-    base_model: Optional[str] = None
+    meta: ModelMeta
+    base_model_id: Optional[str] = None
     name: str
     name: str
-    params: str
-
-    def to_form(self) -> "ModelForm":
-        return ModelForm(**{**self.model_dump(), "params": json.loads(self.params)})
+    params: ModelParams
 
 
 
 
 ####################
 ####################
@@ -85,17 +83,6 @@ class ModelModel(BaseModel):
 ####################
 ####################
 
 
 
 
-class ModelForm(BaseModel):
-    id: str
-    source: str
-    base_model: Optional[str] = None
-    name: str
-    params: dict
-
-    def to_db_model(self) -> ModelModel:
-        return ModelModel(**{**self.model_dump(), "params": json.dumps(self.params)})
-
-
 class ModelsTable:
 class ModelsTable:
 
 
     def __init__(
     def __init__(
@@ -108,51 +95,37 @@ class ModelsTable:
     def get_all_models(self) -> list[ModelModel]:
     def get_all_models(self) -> list[ModelModel]:
         return [ModelModel(**model_to_dict(model)) for model in Model.select()]
         return [ModelModel(**model_to_dict(model)) for model in Model.select()]
 
 
-    def get_all_models_by_source(self, source: str) -> list[ModelModel]:
-        return [
-            ModelModel(**model_to_dict(model))
-            for model in Model.select().where(Model.source == source)
-        ]
-
-    def update_all_models(self, models: list[ModelForm]) -> bool:
+    def update_all_models(self, models: list[ModelModel]) -> bool:
         try:
         try:
             with self.db.atomic():
             with self.db.atomic():
                 # Fetch current models from the database
                 # Fetch current models from the database
                 current_models = self.get_all_models()
                 current_models = self.get_all_models()
-                current_model_dict = {
-                    (model.id, model.source): model for model in current_models
-                }
+                current_model_dict = {model.id: model for model in current_models}
 
 
-                # Create a set of model IDs and sources from the current models and the new models
+                # Create a set of model IDs from the current models and the new models
                 current_model_keys = set(current_model_dict.keys())
                 current_model_keys = set(current_model_dict.keys())
-                new_model_keys = set((model.id, model.source) for model in models)
+                new_model_keys = set(model.id for model in models)
 
 
                 # Determine which models need to be created, updated, or deleted
                 # Determine which models need to be created, updated, or deleted
                 models_to_create = [
                 models_to_create = [
-                    model
-                    for model in models
-                    if (model.id, model.source) not in current_model_keys
+                    model for model in models if model.id not in current_model_keys
                 ]
                 ]
                 models_to_update = [
                 models_to_update = [
-                    model
-                    for model in models
-                    if (model.id, model.source) in current_model_keys
+                    model for model in models if model.id in current_model_keys
                 ]
                 ]
                 models_to_delete = current_model_keys - new_model_keys
                 models_to_delete = current_model_keys - new_model_keys
 
 
                 # Perform the necessary database operations
                 # Perform the necessary database operations
                 for model in models_to_create:
                 for model in models_to_create:
-                    Model.create(**model.to_db_model().model_dump())
+                    Model.create(**model.model_dump())
 
 
                 for model in models_to_update:
                 for model in models_to_update:
-                    Model.update(**model.to_db_model().model_dump()).where(
-                        (Model.id == model.id) & (Model.source == model.source)
+                    Model.update(**model.model_dump()).where(
+                        Model.id == model.id
                     ).execute()
                     ).execute()
 
 
                 for model_id, model_source in models_to_delete:
                 for model_id, model_source in models_to_delete:
-                    Model.delete().where(
-                        (Model.id == model_id) & (Model.source == model_source)
-                    ).execute()
+                    Model.delete().where(Model.id == model_id).execute()
 
 
             return True
             return True
         except Exception as e:
         except Exception as e:

+ 7 - 16
backend/main.py

@@ -37,7 +37,7 @@ import asyncio
 from pydantic import BaseModel
 from pydantic import BaseModel
 from typing import List, Optional
 from typing import List, Optional
 
 
-from apps.web.models.models import Models, ModelModel, ModelForm
+from apps.web.models.models import Models, ModelModel
 from utils.utils import get_admin_user
 from utils.utils import get_admin_user
 from apps.rag.utils import rag_messages
 from apps.rag.utils import rag_messages
 
 
@@ -112,7 +112,7 @@ app.state.config = AppConfig()
 app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 
 
-app.state.MODEL_CONFIG = [model.to_form() for model in Models.get_all_models()]
+app.state.MODEL_CONFIG = Models.get_all_models()
 
 
 app.state.config.WEBHOOK_URL = WEBHOOK_URL
 app.state.config.WEBHOOK_URL = WEBHOOK_URL
 
 
@@ -320,7 +320,7 @@ async def update_model_filter_config(
 
 
 
 
 class SetModelConfigForm(BaseModel):
 class SetModelConfigForm(BaseModel):
-    models: List[ModelForm]
+    models: List[ModelModel]
 
 
 
 
 @app.post("/api/config/models")
 @app.post("/api/config/models")
@@ -333,19 +333,10 @@ async def update_model_config(
             detail=ERROR_MESSAGES.DEFAULT("Failed to update model config"),
             detail=ERROR_MESSAGES.DEFAULT("Failed to update model config"),
         )
         )
 
 
-    ollama_app.state.MODEL_CONFIG = [
-        model for model in form_data.models if model.source == "ollama"
-    ]
-
-    openai_app.state.MODEL_CONFIG = [
-        model for model in form_data.models if model.source == "openai"
-    ]
-
-    litellm_app.state.MODEL_CONFIG = [
-        model for model in form_data.models if model.source == "litellm"
-    ]
-
-    app.state.MODEL_CONFIG = [model for model in form_data.models]
+    ollama_app.state.MODEL_CONFIG = form_data.models
+    openai_app.state.MODEL_CONFIG = form_data.models
+    litellm_app.state.MODEL_CONFIG = form_data.models
+    app.state.MODEL_CONFIG = form_data.models
 
 
     return {"models": app.state.MODEL_CONFIG}
     return {"models": app.state.MODEL_CONFIG}
 
 

+ 5 - 3
src/lib/apis/index.ts

@@ -227,16 +227,18 @@ export const getModelConfig = async (token: string): Promise<GlobalModelConfig>
 export interface ModelConfig {
 export interface ModelConfig {
 	id: string;
 	id: string;
 	name: string;
 	name: string;
-	source: string;
-	base_model?: string;
+	meta: ModelMeta;
+	base_model_id?: string;
 	params: ModelParams;
 	params: ModelParams;
 }
 }
 
 
-export interface ModelParams {
+export interface ModelMeta {
 	description?: string;
 	description?: string;
 	vision_capable?: boolean;
 	vision_capable?: boolean;
 }
 }
 
 
+export interface ModelParams {}
+
 export type GlobalModelConfig = ModelConfig[];
 export type GlobalModelConfig = ModelConfig[];
 
 
 export const updateModelConfig = async (token: string, config: GlobalModelConfig) => {
 export const updateModelConfig = async (token: string, config: GlobalModelConfig) => {

+ 1 - 1
src/lib/components/chat/Chat.svelte

@@ -343,7 +343,7 @@
 					const hasImages = messages.some((message) =>
 					const hasImages = messages.some((message) =>
 						message.files?.some((file) => file.type === 'image')
 						message.files?.some((file) => file.type === 'image')
 					);
 					);
-					if (hasImages && !(model.custom_info?.params.vision_capable ?? true)) {
+					if (hasImages && !(model.custom_info?.meta.vision_capable ?? true)) {
 						toast.error(
 						toast.error(
 							$i18n.t('Model {{modelName}} is not vision capable', {
 							$i18n.t('Model {{modelName}} is not vision capable', {
 								modelName: model.custom_info?.name ?? model.name ?? model.id
 								modelName: model.custom_info?.name ?? model.name ?? model.id

+ 1 - 1
src/lib/components/chat/MessageInput.svelte

@@ -359,7 +359,7 @@
 			if (!model) {
 			if (!model) {
 				continue;
 				continue;
 			}
 			}
-			if (model.custom_info?.params.vision_capable ?? true) {
+			if (model.custom_info?.meta.vision_capable ?? true) {
 				visionCapableCount++;
 				visionCapableCount++;
 			}
 			}
 		}
 		}

+ 2 - 2
src/lib/components/chat/ModelSelector/Selector.svelte

@@ -307,10 +307,10 @@
 									</div>
 									</div>
 								</Tooltip>
 								</Tooltip>
 							{/if}
 							{/if}
-							{#if item.info?.custom_info?.params.description}
+							{#if item.info?.custom_info?.meta.description}
 								<Tooltip
 								<Tooltip
 									content={`${sanitizeResponseContent(
 									content={`${sanitizeResponseContent(
-										item.info.custom_info?.params.description
+										item.info.custom_info?.meta.description
 									).replaceAll('\n', '<br>')}`}
 									).replaceAll('\n', '<br>')}`}
 								>
 								>
 									<div class="">
 									<div class="">

+ 6 - 10
src/lib/components/chat/Settings/Models.svelte

@@ -80,8 +80,8 @@
 		const model = $models.find((m) => m.id === selectedModelId);
 		const model = $models.find((m) => m.id === selectedModelId);
 		if (model) {
 		if (model) {
 			modelName = model.custom_info?.name ?? model.name;
 			modelName = model.custom_info?.name ?? model.name;
-			modelDescription = model.custom_info?.params.description ?? '';
-			modelIsVisionCapable = model.custom_info?.params.vision_capable ?? false;
+			modelDescription = model.custom_info?.meta.description ?? '';
+			modelIsVisionCapable = model.custom_info?.meta.vision_capable ?? false;
 		}
 		}
 	};
 	};
 
 
@@ -518,18 +518,16 @@
 		if (!model) {
 		if (!model) {
 			return;
 			return;
 		}
 		}
-		const modelSource =
-			'details' in model ? 'ollama' : model.source === 'LiteLLM' ? 'litellm' : 'openai';
 		// Remove any existing config
 		// Remove any existing config
 		modelConfig = modelConfig.filter(
 		modelConfig = modelConfig.filter(
-			(m) => !(m.id === selectedModelId && m.source === modelSource)
+			(m) => !(m.id === selectedModelId)
 		);
 		);
 		// Add new config
 		// Add new config
 		modelConfig.push({
 		modelConfig.push({
 			id: selectedModelId,
 			id: selectedModelId,
 			name: modelName,
 			name: modelName,
-			source: modelSource,
-			params: {
+			params: {},
+			meta: {
 				description: modelDescription,
 				description: modelDescription,
 				vision_capable: modelIsVisionCapable
 				vision_capable: modelIsVisionCapable
 			}
 			}
@@ -549,10 +547,8 @@
 		if (!model) {
 		if (!model) {
 			return;
 			return;
 		}
 		}
-		const modelSource =
-			'details' in model ? 'ollama' : model.source === 'LiteLLM' ? 'litellm' : 'openai';
 		modelConfig = modelConfig.filter(
 		modelConfig = modelConfig.filter(
-			(m) => !(m.id === selectedModelId && m.source === modelSource)
+			(m) => !(m.id === selectedModelId)
 		);
 		);
 		await updateModelConfig(localStorage.token, modelConfig);
 		await updateModelConfig(localStorage.token, modelConfig);
 		toast.success(
 		toast.success(