models.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. import json
  2. import logging
  3. from typing import Optional
  4. from pydantic import BaseModel, ConfigDict
  5. from sqlalchemy import String, Column, BigInteger
  6. from sqlalchemy.orm import Session
  7. from apps.webui.internal.db import Base, JSONField
  8. from typing import List, Union, Optional
  9. from config import SRC_LOG_LEVELS
  10. import time
# Module-level logger; its level is read from the "MODELS" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
  13. ####################
  14. # Models DB Schema
  15. ####################
  16. # ModelParams is a model for the data stored in the params field of the Model table
  17. class ModelParams(BaseModel):
  18. model_config = ConfigDict(extra="allow")
  19. pass
  20. # ModelMeta is a model for the data stored in the meta field of the Model table
  21. class ModelMeta(BaseModel):
  22. profile_image_url: Optional[str] = "/favicon.png"
  23. description: Optional[str] = None
  24. """
  25. User-facing description of the model.
  26. """
  27. capabilities: Optional[dict] = None
  28. model_config = ConfigDict(extra="allow")
  29. pass
  30. class Model(Base):
  31. __tablename__ = "model"
  32. id = Column(String, primary_key=True)
  33. """
  34. The model's id as used in the API. If set to an existing model, it will override the model.
  35. """
  36. user_id = Column(String)
  37. base_model_id = Column(String, nullable=True)
  38. """
  39. An optional pointer to the actual model that should be used when proxying requests.
  40. """
  41. name = Column(String)
  42. """
  43. The human-readable display name of the model.
  44. """
  45. params = Column(JSONField)
  46. """
  47. Holds a JSON encoded blob of parameters, see `ModelParams`.
  48. """
  49. meta = Column(JSONField)
  50. """
  51. Holds a JSON encoded blob of metadata, see `ModelMeta`.
  52. """
  53. updated_at = Column(BigInteger)
  54. created_at = Column(BigInteger)
  55. class ModelModel(BaseModel):
  56. model_config = ConfigDict(from_attributes=True)
  57. id: str
  58. user_id: str
  59. base_model_id: Optional[str] = None
  60. name: str
  61. params: ModelParams
  62. meta: ModelMeta
  63. updated_at: int # timestamp in epoch
  64. created_at: int # timestamp in epoch
  65. ####################
  66. # Forms
  67. ####################
  68. class ModelResponse(BaseModel):
  69. id: str
  70. name: str
  71. meta: ModelMeta
  72. updated_at: int # timestamp in epoch
  73. created_at: int # timestamp in epoch
  74. class ModelForm(BaseModel):
  75. id: str
  76. base_model_id: Optional[str] = None
  77. name: str
  78. meta: ModelMeta
  79. params: ModelParams
  80. class ModelsTable:
  81. def insert_new_model(
  82. self, db: Session, form_data: ModelForm, user_id: str
  83. ) -> Optional[ModelModel]:
  84. model = ModelModel(
  85. **{
  86. **form_data.model_dump(),
  87. "user_id": user_id,
  88. "created_at": int(time.time()),
  89. "updated_at": int(time.time()),
  90. }
  91. )
  92. try:
  93. result = Model(**model.dict())
  94. db.add(result)
  95. db.commit()
  96. db.refresh(result)
  97. if result:
  98. return ModelModel.model_validate(result)
  99. else:
  100. return None
  101. except Exception as e:
  102. print(e)
  103. return None
  104. def get_all_models(self, db: Session) -> List[ModelModel]:
  105. return [ModelModel.model_validate(model) for model in db.query(Model).all()]
  106. def get_model_by_id(self, db: Session, id: str) -> Optional[ModelModel]:
  107. try:
  108. model = db.get(Model, id)
  109. return ModelModel.model_validate(model)
  110. except:
  111. return None
  112. def update_model_by_id(
  113. self, db: Session, id: str, model: ModelForm
  114. ) -> Optional[ModelModel]:
  115. try:
  116. # update only the fields that are present in the model
  117. model = db.query(Model).get(id)
  118. model.update(**model.model_dump())
  119. db.commit()
  120. db.refresh(model)
  121. return ModelModel.model_validate(model)
  122. except Exception as e:
  123. print(e)
  124. return None
  125. def delete_model_by_id(self, db: Session, id: str) -> bool:
  126. try:
  127. db.query(Model).filter_by(id=id).delete()
  128. return True
  129. except:
  130. return False
  131. Models = ModelsTable()