models.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import json
  2. import logging
  3. from typing import Optional
  4. import peewee as pw
  5. from playhouse.shortcuts import model_to_dict
  6. from pydantic import BaseModel, ConfigDict
  7. from apps.web.internal.db import DB, JSONField
  8. from config import SRC_LOG_LEVELS
  9. log = logging.getLogger(__name__)
  10. log.setLevel(SRC_LOG_LEVELS["MODELS"])
  11. ####################
  12. # Models DB Schema
  13. ####################
  14. # ModelParams is a model for the data stored in the params field of the Model table
  15. # It isn't currently used in the backend, but it's here as a reference
  16. class ModelParams(BaseModel):
  17. model_config = ConfigDict(extra="allow")
  18. pass
  19. # ModelMeta is a model for the data stored in the meta field of the Model table
  20. # It isn't currently used in the backend, but it's here as a reference
  21. class ModelMeta(BaseModel):
  22. description: Optional[str] = None
  23. """
  24. User-facing description of the model.
  25. """
  26. vision_capable: Optional[bool] = None
  27. """
  28. A flag indicating if the model is capable of vision and thus image inputs
  29. """
  30. model_config = ConfigDict(extra="allow")
  31. pass
class Model(pw.Model):
    # Peewee table definition for stored models. params/meta are stored via
    # the project's JSONField.
    # NOTE(review): no field sets primary_key=True; peewee documents that it
    # then adds an implicit auto-incrementing primary key named "id", which
    # clashes with the TextField below — confirm the resulting schema.
    id = pw.TextField(unique=True)
    """
    The model's id as used in the API. If set to an existing model, it will override the model.
    """
    # Declared without null=True, so peewee treats the column as NOT NULL.
    # NOTE(review): ModelModel has no user_id field, so rows inserted from a
    # ModelModel dump will not set this column — confirm how it is populated.
    user_id = pw.TextField()
    base_model_id = pw.TextField(null=True)
    """
    An optional pointer to the actual model that should be used when proxying requests.
    Currently unused, but will be used to support Modelfile-like behaviour in the future.
    """
    name = pw.TextField()
    """
    The human-readable display name of the model.
    """
    params = JSONField()
    """
    Holds a JSON encoded blob of parameters, see `ModelParams`.
    """
    meta = JSONField()
    """
    Holds a JSON encoded blob of metadata, see `ModelMeta`.
    """
    # NOTE(review): the two lines below are bare type annotations, not peewee
    # Field assignments. Peewee therefore creates no columns for them and the
    # timestamps are never stored. Presumably they are meant to be integer
    # fields (e.g. pw.BigIntegerField()); changing that needs a schema
    # migration — confirm.
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch

    class Meta:
        # Bind this table to the shared project database handle.
        database = DB
  59. class ModelModel(BaseModel):
  60. id: str
  61. base_model_id: Optional[str] = None
  62. name: str
  63. params: ModelParams
  64. meta: ModelMeta
  65. ####################
  66. # Forms
  67. ####################
  68. class ModelsTable:
  69. def __init__(
  70. self,
  71. db: pw.SqliteDatabase | pw.PostgresqlDatabase,
  72. ):
  73. self.db = db
  74. self.db.create_tables([Model])
  75. def get_all_models(self) -> list[ModelModel]:
  76. return [ModelModel(**model_to_dict(model)) for model in Model.select()]
  77. def update_all_models(self, models: list[ModelModel]) -> bool:
  78. try:
  79. with self.db.atomic():
  80. # Fetch current models from the database
  81. current_models = self.get_all_models()
  82. current_model_dict = {model.id: model for model in current_models}
  83. # Create a set of model IDs from the current models and the new models
  84. current_model_keys = set(current_model_dict.keys())
  85. new_model_keys = set(model.id for model in models)
  86. # Determine which models need to be created, updated, or deleted
  87. models_to_create = [
  88. model for model in models if model.id not in current_model_keys
  89. ]
  90. models_to_update = [
  91. model for model in models if model.id in current_model_keys
  92. ]
  93. models_to_delete = current_model_keys - new_model_keys
  94. # Perform the necessary database operations
  95. for model in models_to_create:
  96. Model.create(**model.model_dump())
  97. for model in models_to_update:
  98. Model.update(**model.model_dump()).where(
  99. Model.id == model.id
  100. ).execute()
  101. for model_id, model_source in models_to_delete:
  102. Model.delete().where(Model.id == model_id).execute()
  103. return True
  104. except Exception as e:
  105. log.exception(e)
  106. return False
  107. Models = ModelsTable(DB)