models.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. import json
  2. import logging
  3. from typing import Optional
  4. import peewee as pw
  5. from playhouse.shortcuts import model_to_dict
  6. from pydantic import BaseModel
  7. from apps.web.internal.db import DB
  8. from config import SRC_LOG_LEVELS
  9. log = logging.getLogger(__name__)
  10. log.setLevel(SRC_LOG_LEVELS["MODELS"])
  11. ####################
  12. # Models DB Schema
  13. ####################
  14. # ModelParams is a model for the data stored in the params field of the Model table
  15. # It isn't currently used in the backend, but it's here as a reference
  16. class ModelParams(BaseModel):
  17. description: str
  18. """
  19. User-facing description of the model.
  20. """
  21. vision_capable: bool
  22. """
  23. A flag indicating if the model is capable of vision and thus image inputs
  24. """
class Model(pw.Model):
    """Peewee table definition for model entries, unique per (id, source) pair."""

    # NOTE(review): no field below is declared with primary_key=True. Peewee
    # normally adds an implicit auto-increment primary key named `id` in that
    # case, which collides with the `id` TextField declared here — confirm the
    # generated schema is what is intended.
    id = pw.TextField()
    """
    The model's id as used in the API. If set to an existing model, it will override the model.
    """
    source = pw.TextField()
    """
    The source of the model, e.g., ollama, openai, or litellm.
    """
    base_model = pw.TextField(null=True)
    """
    An optional pointer to the actual model that should be used when proxying requests.
    Currently unused - but will be used to support Modelfile like behaviour in the future
    """
    name = pw.TextField()
    """
    The human-readable display name of the model.
    """
    params = pw.TextField()
    """
    Holds a JSON encoded blob of parameters, see `ModelParams`.
    """

    class Meta:
        # Bind the table to the `DB` database imported from apps.web.internal.db.
        database = DB
        indexes = (
            # Create a unique index on the id, source columns
            (("id", "source"), True),
        )
  53. class ModelModel(BaseModel):
  54. id: str
  55. source: str
  56. base_model: Optional[str] = None
  57. name: str
  58. params: str
  59. def to_form(self) -> "ModelForm":
  60. return ModelForm(**{**self.model_dump(), "params": json.loads(self.params)})
  61. ####################
  62. # Forms
  63. ####################
  64. class ModelForm(BaseModel):
  65. id: str
  66. source: str
  67. base_model: Optional[str] = None
  68. name: str
  69. params: dict
  70. def to_db_model(self) -> ModelModel:
  71. return ModelModel(**{**self.model_dump(), "params": json.dumps(self.params)})
  72. class ModelsTable:
  73. def __init__(
  74. self,
  75. db: pw.SqliteDatabase | pw.PostgresqlDatabase,
  76. ):
  77. self.db = db
  78. self.db.create_tables([Model])
  79. def get_all_models(self) -> list[ModelModel]:
  80. return [ModelModel(**model_to_dict(model)) for model in Model.select()]
  81. def get_all_models_by_source(self, source: str) -> list[ModelModel]:
  82. return [
  83. ModelModel(**model_to_dict(model))
  84. for model in Model.select().where(Model.source == source)
  85. ]
  86. def update_all_models(self, models: list[ModelForm]) -> bool:
  87. try:
  88. with self.db.atomic():
  89. # Fetch current models from the database
  90. current_models = self.get_all_models()
  91. current_model_dict = {
  92. (model.id, model.source): model for model in current_models
  93. }
  94. # Create a set of model IDs and sources from the current models and the new models
  95. current_model_keys = set(current_model_dict.keys())
  96. new_model_keys = set((model.id, model.source) for model in models)
  97. # Determine which models need to be created, updated, or deleted
  98. models_to_create = [
  99. model
  100. for model in models
  101. if (model.id, model.source) not in current_model_keys
  102. ]
  103. models_to_update = [
  104. model
  105. for model in models
  106. if (model.id, model.source) in current_model_keys
  107. ]
  108. models_to_delete = current_model_keys - new_model_keys
  109. # Perform the necessary database operations
  110. for model in models_to_create:
  111. Model.create(**model.to_db_model().model_dump())
  112. for model in models_to_update:
  113. Model.update(**model.to_db_model().model_dump()).where(
  114. (Model.id == model.id) & (Model.source == model.source)
  115. ).execute()
  116. for model_id, model_source in models_to_delete:
  117. Model.delete().where(
  118. (Model.id == model_id) & (Model.source == model_source)
  119. ).execute()
  120. return True
  121. except Exception as e:
  122. log.exception(e)
  123. return False
  124. Models = ModelsTable(DB)