123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163 |
- import json
- import logging
- from typing import Optional
- import peewee as pw
- from playhouse.shortcuts import model_to_dict
- from pydantic import BaseModel
- from apps.web.internal.db import DB
- from config import SRC_LOG_LEVELS
# Module-level logger; its threshold comes from the "MODELS" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
- ####################
- # Models DB Schema
- ####################
# ModelParams describes the data kept in the `params` field of the Model table.
# The backend does not use it at the moment; it is kept here as a reference for
# the expected shape of that data.
class ModelParams(BaseModel):
    description: str
    """
    Description of the model that is shown to users.
    """

    vision_capable: bool
    """
    Whether the model is capable of vision, i.e. accepts image inputs.
    """
class Model(pw.Model):
    """Peewee table holding one row per model; rows are unique on (id, source)."""

    id = pw.TextField()
    """
    Identifier of the model as used in the API. When it matches an existing
    model, this entry overrides that model.
    """

    source = pw.TextField()
    """
    Where the model comes from, e.g., ollama, openai, or litellm.
    """

    base_model = pw.TextField(null=True)
    """
    Optional pointer to the actual model to use when proxying requests.
    Not used yet; intended to support Modelfile-like behaviour in the future.
    """

    name = pw.TextField()
    """
    Human-readable display name of the model.
    """

    params = pw.TextField()
    """
    JSON-encoded blob of parameters, see `ModelParams`.
    """

    class Meta:
        database = DB
        # One unique composite index over the (id, source) columns.
        indexes = ((("id", "source"), True),)
# Pydantic counterpart of a `Model` row: same fields, with `params` still held
# as the JSON-encoded string.
class ModelModel(BaseModel):
    id: str
    source: str
    base_model: Optional[str] = None
    name: str
    params: str

    def to_form(self) -> "ModelForm":
        """Return the equivalent `ModelForm`, with `params` decoded from JSON."""
        fields = self.model_dump()
        fields["params"] = json.loads(self.params)
        return ModelForm(**fields)
- ####################
- # Forms
- ####################
# Form-side representation of a model: same fields as `ModelModel`, but with
# `params` held as a decoded dict rather than a JSON string.
class ModelForm(BaseModel):
    id: str
    source: str
    base_model: Optional[str] = None
    name: str
    params: dict

    def to_db_model(self) -> ModelModel:
        """Return the equivalent `ModelModel`, with `params` encoded as JSON."""
        fields = self.model_dump()
        fields["params"] = json.dumps(self.params)
        return ModelModel(**fields)
class ModelsTable:
    """Query and synchronisation helpers for the `Model` table."""

    def __init__(
        self,
        db: pw.SqliteDatabase | pw.PostgresqlDatabase,
    ):
        """Store *db* and ask it to create the table for `Model`."""
        self.db = db
        self.db.create_tables([Model])

    def get_all_models(self) -> list[ModelModel]:
        """Return every row of the `Model` table as a `ModelModel`."""
        return [ModelModel(**model_to_dict(row)) for row in Model.select()]

    def get_all_models_by_source(self, source: str) -> list[ModelModel]:
        """Return the rows whose `source` column equals *source*."""
        query = Model.select().where(Model.source == source)
        return [ModelModel(**model_to_dict(row)) for row in query]

    def update_all_models(self, models: list[ModelForm]) -> bool:
        """Make the table match *models*, keyed on (id, source).

        Entries whose key is not stored yet are created, entries whose key is
        already stored are updated, and stored rows whose key is absent from
        *models* are deleted. When *models* contains several entries with the
        same (id, source), the last one wins.

        All statements run inside ``self.db.atomic()``.

        Returns:
            True when every statement succeeded; False when any exception was
            raised (the exception is logged, not propagated).
        """
        try:
            with self.db.atomic():
                # Collapse the input by key first. Without this, two entries
                # for the same new (id, source) would both be inserted and the
                # second insert would violate the unique index, aborting the
                # whole sync. Later entries overwrite earlier ones, matching
                # the outcome of applying successive updates in order.
                desired = {(form.id, form.source): form for form in models}

                # Only the stored keys are needed to classify each entry.
                stored_keys = {
                    (row.id, row.source) for row in self.get_all_models()
                }

                for key, form in desired.items():
                    values = form.to_db_model().model_dump()
                    if key in stored_keys:
                        Model.update(**values).where(
                            (Model.id == form.id) & (Model.source == form.source)
                        ).execute()
                    else:
                        Model.create(**values)

                # Remove stored rows that are no longer wanted.
                for model_id, model_source in stored_keys.difference(desired):
                    Model.delete().where(
                        (Model.id == model_id) & (Model.source == model_source)
                    ).execute()

                return True
        except Exception as e:
            # Report failure through the return value; callers get False
            # instead of an exception.
            log.exception(e)
            return False
# Shared module-level instance bound to DB; constructing it runs
# ModelsTable.__init__ at import time.
Models = ModelsTable(DB)
|