models.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. import json
  2. import logging
  3. from typing import Optional
  4. from pydantic import BaseModel, ConfigDict
  5. from sqlalchemy import String, Column, BigInteger, Text
  6. from apps.webui.internal.db import Base, JSONField, get_db
  7. from typing import List, Union, Optional
  8. from config import SRC_LOG_LEVELS
  9. import time
  10. log = logging.getLogger(__name__)
  11. log.setLevel(SRC_LOG_LEVELS["MODELS"])
  12. ####################
  13. # Models DB Schema
  14. ####################
  15. # ModelParams is a model for the data stored in the params field of the Model table
  16. class ModelParams(BaseModel):
  17. model_config = ConfigDict(extra="allow")
  18. pass
  19. # ModelMeta is a model for the data stored in the meta field of the Model table
  20. class ModelMeta(BaseModel):
  21. profile_image_url: Optional[str] = "/static/favicon.png"
  22. description: Optional[str] = None
  23. """
  24. User-facing description of the model.
  25. """
  26. capabilities: Optional[dict] = None
  27. model_config = ConfigDict(extra="allow")
  28. pass
  29. class Model(Base):
  30. __tablename__ = "model"
  31. id = Column(Text, primary_key=True)
  32. """
  33. The model's id as used in the API. If set to an existing model, it will override the model.
  34. """
  35. user_id = Column(Text)
  36. base_model_id = Column(Text, nullable=True)
  37. """
  38. An optional pointer to the actual model that should be used when proxying requests.
  39. """
  40. name = Column(Text)
  41. """
  42. The human-readable display name of the model.
  43. """
  44. params = Column(JSONField)
  45. """
  46. Holds a JSON encoded blob of parameters, see `ModelParams`.
  47. """
  48. meta = Column(JSONField)
  49. """
  50. Holds a JSON encoded blob of metadata, see `ModelMeta`.
  51. """
  52. updated_at = Column(BigInteger)
  53. created_at = Column(BigInteger)
  54. class ModelModel(BaseModel):
  55. id: str
  56. user_id: str
  57. base_model_id: Optional[str] = None
  58. name: str
  59. params: ModelParams
  60. meta: ModelMeta
  61. updated_at: int # timestamp in epoch
  62. created_at: int # timestamp in epoch
  63. model_config = ConfigDict(from_attributes=True)
  64. ####################
  65. # Forms
  66. ####################
  67. class ModelResponse(BaseModel):
  68. id: str
  69. name: str
  70. meta: ModelMeta
  71. updated_at: int # timestamp in epoch
  72. created_at: int # timestamp in epoch
  73. class ModelForm(BaseModel):
  74. id: str
  75. base_model_id: Optional[str] = None
  76. name: str
  77. meta: ModelMeta
  78. params: ModelParams
  79. class ModelsTable:
  80. def insert_new_model(
  81. self, form_data: ModelForm, user_id: str
  82. ) -> Optional[ModelModel]:
  83. model = ModelModel(
  84. **{
  85. **form_data.model_dump(),
  86. "user_id": user_id,
  87. "created_at": int(time.time()),
  88. "updated_at": int(time.time()),
  89. }
  90. )
  91. try:
  92. with get_db() as db:
  93. result = Model(**model.model_dump())
  94. db.add(result)
  95. db.commit()
  96. db.refresh(result)
  97. if result:
  98. return ModelModel.model_validate(result)
  99. else:
  100. return None
  101. except Exception as e:
  102. print(e)
  103. return None
  104. def get_all_models(self) -> List[ModelModel]:
  105. with get_db() as db:
  106. return [ModelModel.model_validate(model) for model in db.query(Model).all()]
  107. def get_model_by_id(self, id: str) -> Optional[ModelModel]:
  108. try:
  109. with get_db() as db:
  110. model = db.get(Model, id)
  111. return ModelModel.model_validate(model)
  112. except:
  113. return None
  114. def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
  115. try:
  116. with get_db() as db:
  117. # update only the fields that are present in the model
  118. result = (
  119. db.query(Model)
  120. .filter_by(id=id)
  121. .update(model.model_dump(exclude={"id"}, exclude_none=True))
  122. )
  123. db.commit()
  124. model = db.get(Model, id)
  125. db.refresh(model)
  126. return ModelModel.model_validate(model)
  127. except Exception as e:
  128. print(e)
  129. return None
  130. def delete_model_by_id(self, id: str) -> bool:
  131. try:
  132. with get_db() as db:
  133. db.query(Model).filter_by(id=id).delete()
  134. db.commit()
  135. return True
  136. except:
  137. return False
  138. Models = ModelsTable()