models.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. import json
  2. import logging
  3. from typing import Optional
  4. from pydantic import BaseModel, ConfigDict
  5. from sqlalchemy import String, Column, BigInteger
  6. from sqlalchemy.orm import Session
  7. from apps.webui.internal.db import Base, JSONField, get_session
  8. from typing import List, Union, Optional
  9. from config import SRC_LOG_LEVELS
  10. import time
  log = logging.getLogger(__name__)
  # This module's logger takes its level from the "MODELS" entry of the
  # project-wide SRC_LOG_LEVELS mapping.
  _models_log_level = SRC_LOG_LEVELS["MODELS"]
  log.setLevel(_models_log_level)
  13. ####################
  14. # Models DB Schema
  15. ####################
  # Schema for the JSON blob stored in the `params` column of the Model table.
  # No fields are declared; with extra="allow" every key supplied is accepted
  # and kept as-is.
  # (Documented with comments instead of a docstring so the generated JSON
  # schema stays unchanged.)
  class ModelParams(BaseModel):
      model_config = ConfigDict(extra="allow")
  # Schema for the JSON blob stored in the `meta` column of the Model table.
  # The typed fields below are optional, and extra="allow" keeps any other
  # keys the client sends.
  # (Documented with comments instead of a docstring so the generated JSON
  # schema stays unchanged.)
  class ModelMeta(BaseModel):
      model_config = ConfigDict(extra="allow")

      # Image URL for the model; defaults to "/favicon.png".
      profile_image_url: Optional[str] = "/favicon.png"
      # User-facing description of the model.
      description: Optional[str] = None
      # Free-form dict; its expected keys are not defined in this file.
      capabilities: Optional[dict] = None
  class Model(Base):
      """ORM mapping for the ``model`` table.

      ``params`` and ``meta`` hold JSON blobs (stored through ``JSONField``)
      shaped like ``ModelParams`` and ``ModelMeta`` respectively.
      """

      __tablename__ = "model"

      # The model's id as used in the API. If set to an existing model, it
      # will override the model.
      id = Column(String, primary_key=True)
      # Owner of the entry; insert_new_model fills it from its user_id argument.
      user_id = Column(String)
      # Optional pointer to the actual model that should be used when
      # proxying requests.
      base_model_id = Column(String, nullable=True)
      # Human-readable display name of the model.
      name = Column(String)
      # JSON-encoded parameters, see ModelParams.
      params = Column(JSONField)
      # JSON-encoded metadata, see ModelMeta.
      meta = Column(JSONField)
      # Unix epoch seconds (insert_new_model writes int(time.time())).
      updated_at = Column(BigInteger)
      created_at = Column(BigInteger)
  # Pydantic view of a full Model row. from_attributes=True lets
  # ModelModel.model_validate() read directly from an ORM Model instance.
  class ModelModel(BaseModel):
      model_config = ConfigDict(from_attributes=True)

      id: str
      user_id: str
      base_model_id: Optional[str] = None
      name: str
      params: ModelParams
      meta: ModelMeta
      # Both timestamps are Unix epoch seconds.
      updated_at: int
      created_at: int
  65. ####################
  66. # Forms
  67. ####################
  # Reduced representation of a model: omits params, user_id and
  # base_model_id, and keeps only identity, display data and timestamps.
  class ModelResponse(BaseModel):
      id: str
      name: str
      meta: ModelMeta
      # Unix epoch timestamps.
      updated_at: int
      created_at: int
  # Client-supplied payload for creating or updating a model entry. The
  # owner (user_id) and the timestamps are not part of the form; the
  # table helpers fill them in.
  class ModelForm(BaseModel):
      id: str
      base_model_id: Optional[str] = None
      name: str
      meta: ModelMeta
      params: ModelParams
  class ModelsTable:
      """Data-access helpers for the ``model`` table.

      Each method opens its own session with ``get_session()``. Apart from
      ``get_all_models``, database failures are logged and reported to the
      caller as ``None`` / ``False`` instead of being raised.
      """

      def insert_new_model(
          self, form_data: ModelForm, user_id: str
      ) -> Optional[ModelModel]:
          """Insert a new model row owned by ``user_id``.

          Returns the stored row as a ``ModelModel``, or ``None`` if the
          database write failed (e.g. a row with the same primary-key ``id``
          already exists).
          """
          # Take one timestamp so created_at and updated_at agree exactly.
          now = int(time.time())
          model = ModelModel(
              **{
                  **form_data.model_dump(),
                  "user_id": user_id,
                  "created_at": now,
                  "updated_at": now,
              }
          )
          try:
              with get_session() as db:
                  result = Model(**model.model_dump())
                  db.add(result)
                  db.commit()
                  db.refresh(result)
                  return ModelModel.model_validate(result)
          except Exception:
              log.exception("Failed to insert model %s", form_data.id)
              return None

      def get_all_models(self) -> List[ModelModel]:
          """Return every stored model. Database errors propagate to the caller."""
          with get_session() as db:
              return [
                  ModelModel.model_validate(row) for row in db.query(Model).all()
              ]

      def get_model_by_id(self, id: str) -> Optional[ModelModel]:
          """Return the model whose primary key is ``id``.

          Returns ``None`` when no such row exists or when loading it failed.
          """
          try:
              with get_session() as db:
                  record = db.get(Model, id)
                  if record is None:
                      return None
                  # Validate while the session is still open so lazy
                  # attribute loads cannot hit a closed session.
                  return ModelModel.model_validate(record)
          except Exception:
              log.exception("Failed to load model %s", id)
              return None

      def update_model_by_id(
          self, id: str, model: ModelForm
      ) -> Optional[ModelModel]:
          """Overwrite the editable fields of model ``id`` with ``model``.

          ``base_model_id``, ``name``, ``meta`` and ``params`` are taken from
          the form, and ``updated_at`` is set to the current time. The primary
          key, ``user_id`` and ``created_at`` are left untouched. Returns the
          updated row, or ``None`` if no row has that id or the update failed.
          """
          try:
              with get_session() as db:
                  # The form's own ``id`` is excluded so this call can never
                  # rewrite the primary key of the targeted row.
                  values = model.model_dump(exclude={"id"})
                  values["updated_at"] = int(time.time())
                  matched = db.query(Model).filter_by(id=id).update(values)
                  if not matched:
                      return None
                  db.commit()
                  record = db.get(Model, id)
                  return ModelModel.model_validate(record)
          except Exception:
              log.exception("Failed to update model %s", id)
              return None

      def delete_model_by_id(self, id: str) -> bool:
          """Delete the model whose primary key is ``id``.

          Returns ``True`` once the delete is committed (also when no row
          matched), and ``False`` if the database operation failed.
          """
          try:
              with get_session() as db:
                  db.query(Model).filter_by(id=id).delete()
                  # Commit explicitly, like the insert/update paths, instead of
                  # relying on get_session() to commit on exit.
                  db.commit()
                  return True
          except Exception:
              log.exception("Failed to delete model %s", id)
              return False


  # Module-level instance of the table helpers.
  Models = ModelsTable()