فهرست منبع

refac: switch to meta and params, remove source

Jun Siang Cheah 11 ماه پیش
والد
کامیت
f21c8626d6

+ 1 - 3
backend/apps/litellm/main.py

@@ -78,9 +78,7 @@ with open(LITELLM_CONFIG_DIR, "r") as file:
 
 app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER.value
 app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST.value
-app.state.MODEL_CONFIG = [
-    model.to_form() for model in Models.get_all_models_by_source("litellm")
-]
+app.state.MODEL_CONFIG = Models.get_all_models()
 
 app.state.ENABLE = ENABLE_LITELLM
 app.state.CONFIG = litellm_config

+ 1 - 3
backend/apps/ollama/main.py

@@ -66,9 +66,7 @@ app.state.config = AppConfig()
 
 app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
-app.state.MODEL_CONFIG = [
-    model.to_form() for model in Models.get_all_models_by_source("ollama")
-]
+app.state.MODEL_CONFIG = Models.get_all_models()
 
 app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
 app.state.MODELS = {}

+ 1 - 3
backend/apps/openai/main.py

@@ -52,9 +52,7 @@ app.state.config = AppConfig()
 
 app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
-app.state.MODEL_CONFIG = [
-    model.to_form() for model in Models.get_all_models_by_source("openai")
-]
+app.state.MODEL_CONFIG = Models.get_all_models()
 
 
 app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API

+ 12 - 0
backend/apps/web/internal/db.py

@@ -1,3 +1,5 @@
+import json
+
 from peewee import *
 from peewee_migrate import Router
 from playhouse.db_url import connect
@@ -8,6 +10,16 @@ import logging
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["DB"])
 
+
+class JSONField(TextField):
+    def db_value(self, value):
+        return json.dumps(value)
+
+    def python_value(self, value):
+        if value is not None:
+            return json.loads(value)
+
+
 # Check if the file exists
 if os.path.exists(f"{DATA_DIR}/ollama.db"):
     # Rename the file

+ 4 - 9
backend/apps/web/internal/migrations/008_add_models.py → backend/apps/web/internal/migrations/009_add_models.py

@@ -1,4 +1,4 @@
-"""Peewee migrations -- 008_add_models.py.
+"""Peewee migrations -- 009_add_models.py.
 
 Some examples (model - class or model name)::
 
@@ -39,20 +39,15 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
 
     @migrator.create_model
     class Model(pw.Model):
-        id = pw.TextField()
-        source = pw.TextField()
-        base_model = pw.TextField(null=True)
+        id = pw.TextField(unique=True)
+        meta = pw.TextField()
+        base_model_id = pw.TextField(null=True)
         name = pw.TextField()
         params = pw.TextField()
 
         class Meta:
             table_name = "model"
 
-            indexes = (
-                # Create a unique index on the id, source columns
-                (("id", "source"), True),
-            )
-
 
 def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your rollback migrations here."""

+ 29 - 56
backend/apps/web/models/models.py

@@ -6,7 +6,7 @@ import peewee as pw
 from playhouse.shortcuts import model_to_dict
 from pydantic import BaseModel
 
-from apps.web.internal.db import DB
+from apps.web.internal.db import DB, JSONField
 
 from config import SRC_LOG_LEVELS
 
@@ -22,6 +22,12 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 # ModelParams is a model for the data stored in the params field of the Model table
 # It isn't currently used in the backend, but it's here as a reference
 class ModelParams(BaseModel):
+    pass
+
+
+# ModelMeta is a model for the data stored in the meta field of the Model table
+# It isn't currently used in the backend, but it's here as a reference
+class ModelMeta(BaseModel):
     description: str
     """
         User-facing description of the model.
@@ -34,50 +40,42 @@ class ModelParams(BaseModel):
 
 
 class Model(pw.Model):
-    id = pw.TextField()
+    id = pw.TextField(unique=True)
     """
         The model's id as used in the API. If set to an existing model, it will override the model.
     """
 
-    source = pw.TextField()
+    meta = JSONField()
     """
-    The source of the model, e.g., ollama, openai, or litellm.
+        Holds a JSON encoded blob of metadata, see `ModelMeta`.
     """
 
-    base_model = pw.TextField(null=True)
+    base_model_id = pw.TextField(null=True)
     """
-    An optional pointer to the actual model that should be used when proxying requests.
-    Currently unused - but will be used to support Modelfile like behaviour in the future
+        An optional pointer to the actual model that should be used when proxying requests.
+        Currently unused - but will be used to support Modelfile like behaviour in the future
     """
 
     name = pw.TextField()
     """
-    The human-readable display name of the model.
+        The human-readable display name of the model.
     """
 
-    params = pw.TextField()
+    params = JSONField()
     """
-    Holds a JSON encoded blob of parameters, see `ModelParams`.
+        Holds a JSON encoded blob of parameters, see `ModelParams`.
     """
 
     class Meta:
         database = DB
 
-        indexes = (
-            # Create a unique index on the id, source columns
-            (("id", "source"), True),
-        )
-
 
 class ModelModel(BaseModel):
     id: str
-    source: str
-    base_model: Optional[str] = None
+    meta: ModelMeta
+    base_model_id: Optional[str] = None
     name: str
-    params: str
-
-    def to_form(self) -> "ModelForm":
-        return ModelForm(**{**self.model_dump(), "params": json.loads(self.params)})
+    params: ModelParams
 
 
 ####################
@@ -85,17 +83,6 @@ class ModelModel(BaseModel):
 ####################
 
 
-class ModelForm(BaseModel):
-    id: str
-    source: str
-    base_model: Optional[str] = None
-    name: str
-    params: dict
-
-    def to_db_model(self) -> ModelModel:
-        return ModelModel(**{**self.model_dump(), "params": json.dumps(self.params)})
-
-
 class ModelsTable:
 
     def __init__(
@@ -108,51 +95,37 @@ class ModelsTable:
     def get_all_models(self) -> list[ModelModel]:
         return [ModelModel(**model_to_dict(model)) for model in Model.select()]
 
-    def get_all_models_by_source(self, source: str) -> list[ModelModel]:
-        return [
-            ModelModel(**model_to_dict(model))
-            for model in Model.select().where(Model.source == source)
-        ]
-
-    def update_all_models(self, models: list[ModelForm]) -> bool:
+    def update_all_models(self, models: list[ModelModel]) -> bool:
         try:
             with self.db.atomic():
                 # Fetch current models from the database
                 current_models = self.get_all_models()
-                current_model_dict = {
-                    (model.id, model.source): model for model in current_models
-                }
+                current_model_dict = {model.id: model for model in current_models}
 
-                # Create a set of model IDs and sources from the current models and the new models
+                # Create a set of model IDs from the current models and the new models
                 current_model_keys = set(current_model_dict.keys())
-                new_model_keys = set((model.id, model.source) for model in models)
+                new_model_keys = set(model.id for model in models)
 
                 # Determine which models need to be created, updated, or deleted
                 models_to_create = [
-                    model
-                    for model in models
-                    if (model.id, model.source) not in current_model_keys
+                    model for model in models if model.id not in current_model_keys
                 ]
                 models_to_update = [
-                    model
-                    for model in models
-                    if (model.id, model.source) in current_model_keys
+                    model for model in models if model.id in current_model_keys
                 ]
                 models_to_delete = current_model_keys - new_model_keys
 
                 # Perform the necessary database operations
                 for model in models_to_create:
-                    Model.create(**model.to_db_model().model_dump())
+                    Model.create(**model.model_dump())
 
                 for model in models_to_update:
-                    Model.update(**model.to_db_model().model_dump()).where(
-                        (Model.id == model.id) & (Model.source == model.source)
+                    Model.update(**model.model_dump()).where(
+                        Model.id == model.id
                     ).execute()
 
                 for model_id, model_source in models_to_delete:
-                    Model.delete().where(
-                        (Model.id == model_id) & (Model.source == model_source)
-                    ).execute()
+                    Model.delete().where(Model.id == model_id).execute()
 
             return True
         except Exception as e:

+ 7 - 16
backend/main.py

@@ -37,7 +37,7 @@ import asyncio
 from pydantic import BaseModel
 from typing import List, Optional
 
-from apps.web.models.models import Models, ModelModel, ModelForm
+from apps.web.models.models import Models, ModelModel
 from utils.utils import get_admin_user
 from apps.rag.utils import rag_messages
 
@@ -112,7 +112,7 @@ app.state.config = AppConfig()
 app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 
-app.state.MODEL_CONFIG = [model.to_form() for model in Models.get_all_models()]
+app.state.MODEL_CONFIG = Models.get_all_models()
 
 app.state.config.WEBHOOK_URL = WEBHOOK_URL
 
@@ -320,7 +320,7 @@ async def update_model_filter_config(
 
 
 class SetModelConfigForm(BaseModel):
-    models: List[ModelForm]
+    models: List[ModelModel]
 
 
 @app.post("/api/config/models")
@@ -333,19 +333,10 @@ async def update_model_config(
             detail=ERROR_MESSAGES.DEFAULT("Failed to update model config"),
         )
 
-    ollama_app.state.MODEL_CONFIG = [
-        model for model in form_data.models if model.source == "ollama"
-    ]
-
-    openai_app.state.MODEL_CONFIG = [
-        model for model in form_data.models if model.source == "openai"
-    ]
-
-    litellm_app.state.MODEL_CONFIG = [
-        model for model in form_data.models if model.source == "litellm"
-    ]
-
-    app.state.MODEL_CONFIG = [model for model in form_data.models]
+    ollama_app.state.MODEL_CONFIG = form_data.models
+    openai_app.state.MODEL_CONFIG = form_data.models
+    litellm_app.state.MODEL_CONFIG = form_data.models
+    app.state.MODEL_CONFIG = form_data.models
 
     return {"models": app.state.MODEL_CONFIG}
 

+ 5 - 3
src/lib/apis/index.ts

@@ -227,16 +227,18 @@ export const getModelConfig = async (token: string): Promise<GlobalModelConfig>
 export interface ModelConfig {
 	id: string;
 	name: string;
-	source: string;
-	base_model?: string;
+	meta: ModelMeta;
+	base_model_id?: string;
 	params: ModelParams;
 }
 
-export interface ModelParams {
+export interface ModelMeta {
 	description?: string;
 	vision_capable?: boolean;
 }
 
+export interface ModelParams {}
+
 export type GlobalModelConfig = ModelConfig[];
 
 export const updateModelConfig = async (token: string, config: GlobalModelConfig) => {

+ 1 - 1
src/lib/components/chat/Chat.svelte

@@ -343,7 +343,7 @@
 					const hasImages = messages.some((message) =>
 						message.files?.some((file) => file.type === 'image')
 					);
-					if (hasImages && !(model.custom_info?.params.vision_capable ?? true)) {
+					if (hasImages && !(model.custom_info?.meta.vision_capable ?? true)) {
 						toast.error(
 							$i18n.t('Model {{modelName}} is not vision capable', {
 								modelName: model.custom_info?.name ?? model.name ?? model.id

+ 1 - 1
src/lib/components/chat/MessageInput.svelte

@@ -359,7 +359,7 @@
 			if (!model) {
 				continue;
 			}
-			if (model.custom_info?.params.vision_capable ?? true) {
+			if (model.custom_info?.meta.vision_capable ?? true) {
 				visionCapableCount++;
 			}
 		}

+ 2 - 2
src/lib/components/chat/ModelSelector/Selector.svelte

@@ -307,10 +307,10 @@
 									</div>
 								</Tooltip>
 							{/if}
-							{#if item.info?.custom_info?.params.description}
+							{#if item.info?.custom_info?.meta.description}
 								<Tooltip
 									content={`${sanitizeResponseContent(
-										item.info.custom_info?.params.description
+										item.info.custom_info?.meta.description
 									).replaceAll('\n', '<br>')}`}
 								>
 									<div class="">

+ 6 - 10
src/lib/components/chat/Settings/Models.svelte

@@ -80,8 +80,8 @@
 		const model = $models.find((m) => m.id === selectedModelId);
 		if (model) {
 			modelName = model.custom_info?.name ?? model.name;
-			modelDescription = model.custom_info?.params.description ?? '';
-			modelIsVisionCapable = model.custom_info?.params.vision_capable ?? false;
+			modelDescription = model.custom_info?.meta.description ?? '';
+			modelIsVisionCapable = model.custom_info?.meta.vision_capable ?? false;
 		}
 	};
 
@@ -518,18 +518,16 @@
 		if (!model) {
 			return;
 		}
-		const modelSource =
-			'details' in model ? 'ollama' : model.source === 'LiteLLM' ? 'litellm' : 'openai';
 		// Remove any existing config
 		modelConfig = modelConfig.filter(
-			(m) => !(m.id === selectedModelId && m.source === modelSource)
+			(m) => !(m.id === selectedModelId)
 		);
 		// Add new config
 		modelConfig.push({
 			id: selectedModelId,
 			name: modelName,
-			source: modelSource,
-			params: {
+			params: {},
+			meta: {
 				description: modelDescription,
 				vision_capable: modelIsVisionCapable
 			}
@@ -549,10 +547,8 @@
 		if (!model) {
 			return;
 		}
-		const modelSource =
-			'details' in model ? 'ollama' : model.source === 'LiteLLM' ? 'litellm' : 'openai';
 		modelConfig = modelConfig.filter(
-			(m) => !(m.id === selectedModelId && m.source === modelSource)
+			(m) => !(m.id === selectedModelId)
 		);
 		await updateModelConfig(localStorage.token, modelConfig);
 		toast.success(