modelfiles.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. from pydantic import BaseModel
  2. from peewee import *
  3. from playhouse.shortcuts import model_to_dict
  4. from typing import List, Union, Optional
  5. import time
  6. from utils.utils import decode_token
  7. from utils.misc import get_gravatar_url
  8. from apps.web.internal.db import DB
  9. import json
  10. ####################
# Modelfile DB Schema
  12. ####################
  13. class Modelfile(Model):
  14. tag_name = CharField(unique=True)
  15. user_id = CharField()
  16. modelfile = TextField()
  17. timestamp = DateField()
  18. class Meta:
  19. database = DB
  20. class ModelfileModel(BaseModel):
  21. tag_name: str
  22. user_id: str
  23. modelfile: str
  24. timestamp: int # timestamp in epoch
  25. ####################
  26. # Forms
  27. ####################
  28. class ModelfileForm(BaseModel):
  29. modelfile: dict
  30. class ModelfileResponse(BaseModel):
  31. tag_name: str
  32. user_id: str
  33. modelfile: dict
  34. timestamp: int # timestamp in epoch
  35. class ModelfilesTable:
  36. def __init__(self, db):
  37. self.db = db
  38. self.db.create_tables([Modelfile])
  39. def insert_new_modelfile(
  40. self, user_id: str, form_data: ModelfileForm
  41. ) -> Optional[ModelfileModel]:
  42. if "title" in form_data.modelfile:
  43. modelfile = ModelfileModel(
  44. **{
  45. "user_id": user_id,
  46. "tag_name": form_data.modelfile["title"],
  47. "modelfile": json.dumps(form_data.modelfile),
  48. "timestamp": int(time.time()),
  49. }
  50. )
  51. result = Modelfile.create(**modelfile.model_dump())
  52. if result:
  53. return modelfile
  54. else:
  55. return None
  56. else:
  57. return None
  58. def get_modelfile_by_tag_name(self, tag_name: str) -> Optional[ModelfileModel]:
  59. try:
  60. modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
  61. return ModelfileModel(**model_to_dict(modelfile))
  62. except:
  63. return None
  64. def get_modelfiles(self, skip: int = 0, limit: int = 50) -> List[ModelfileResponse]:
  65. return [
  66. ModelfileResponse(
  67. **{
  68. **model_to_dict(modelfile),
  69. "modelfile": json.loads(modelfile.modelfile),
  70. }
  71. )
  72. for modelfile in Modelfile.select()
  73. # .limit(limit).offset(skip)
  74. ]
  75. def update_modelfile_by_tag_name(
  76. self, tag_name: str, modelfile: dict
  77. ) -> Optional[ModelfileModel]:
  78. try:
  79. query = Modelfile.update(
  80. modelfile=json.dumps(modelfile),
  81. timestamp=int(time.time()),
  82. ).where(Modelfile.tag_name == tag_name)
  83. query.execute()
  84. modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
  85. return ModelfileModel(**model_to_dict(modelfile))
  86. except:
  87. return None
  88. def delete_modelfile_by_tag_name(self, tag_name: str) -> bool:
  89. try:
  90. query = Modelfile.delete().where((Modelfile.tag_name == tag_name))
  91. query.execute() # Remove the rows, return number of rows removed.
  92. return True
  93. except:
  94. return False
  95. Modelfiles = ModelfilesTable(DB)