modelfiles.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. from pydantic import BaseModel
  2. from peewee import *
  3. from playhouse.shortcuts import model_to_dict
  4. from typing import List, Union, Optional
  5. import time
  6. from utils.utils import decode_token
  7. from utils.misc import get_gravatar_url
  8. from apps.web.internal.db import DB
  9. import json
  10. ####################
  11. # Modelfile DB Schema
  12. ####################
  13. class Modelfile(Model):
  14. tag_name = CharField(unique=True)
  15. user_id = CharField()
  16. modelfile = TextField()
  17. timestamp = DateField()
  18. class Meta:
  19. database = DB
# Pydantic mirror of a Modelfile table row. `modelfile` is the raw JSON string
# exactly as stored (ModelfilesTable writes it with json.dumps).
# Only `#` comments are used here: a class docstring would become the model's
# JSON-schema description.
class ModelfileModel(BaseModel):
    tag_name: str
    user_id: str
    modelfile: str  # JSON-encoded modelfile dict
    timestamp: int # timestamp in epoch (seconds, from int(time.time()))
  25. ####################
  26. # Forms
  27. ####################
# Input form carrying a modelfile as a decoded dict. ModelfilesTable
# .insert_new_modelfile requires this dict to contain a "tagName" key.
class ModelfileForm(BaseModel):
    modelfile: dict
# Input form identifying a modelfile only by its tag name.
class ModelfileTagNameForm(BaseModel):
    tag_name: str
# Union of the two forms above via multiple inheritance: carries both
# `modelfile: dict` and `tag_name: str`.
class ModelfileUpdateForm(ModelfileForm, ModelfileTagNameForm):
    pass
# Outgoing representation of a stored modelfile. Same fields as ModelfileModel,
# except `modelfile` is the decoded dict rather than the stored JSON string.
class ModelfileResponse(BaseModel):
    tag_name: str
    user_id: str
    modelfile: dict  # json.loads of the stored JSON text
    timestamp: int # timestamp in epoch
  39. class ModelfilesTable:
  40. def __init__(self, db):
  41. self.db = db
  42. self.db.create_tables([Modelfile])
  43. def insert_new_modelfile(
  44. self, user_id: str,
  45. form_data: ModelfileForm) -> Optional[ModelfileModel]:
  46. if "tagName" in form_data.modelfile:
  47. modelfile = ModelfileModel(
  48. **{
  49. "user_id": user_id,
  50. "tag_name": form_data.modelfile["tagName"],
  51. "modelfile": json.dumps(form_data.modelfile),
  52. "timestamp": int(time.time()),
  53. })
  54. try:
  55. result = Modelfile.create(**modelfile.model_dump())
  56. if result:
  57. return modelfile
  58. else:
  59. return None
  60. except:
  61. return None
  62. else:
  63. return None
  64. def get_modelfile_by_tag_name(self,
  65. tag_name: str) -> Optional[ModelfileModel]:
  66. try:
  67. modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
  68. return ModelfileModel(**model_to_dict(modelfile))
  69. except:
  70. return None
  71. def get_modelfiles(self,
  72. skip: int = 0,
  73. limit: int = 50) -> List[ModelfileResponse]:
  74. return [
  75. ModelfileResponse(
  76. **{
  77. **model_to_dict(modelfile),
  78. "modelfile":
  79. json.loads(modelfile.modelfile),
  80. }) for modelfile in Modelfile.select()
  81. # .limit(limit).offset(skip)
  82. ]
  83. def update_modelfile_by_tag_name(
  84. self, tag_name: str, modelfile: dict) -> Optional[ModelfileModel]:
  85. try:
  86. query = Modelfile.update(
  87. modelfile=json.dumps(modelfile),
  88. timestamp=int(time.time()),
  89. ).where(Modelfile.tag_name == tag_name)
  90. query.execute()
  91. modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
  92. return ModelfileModel(**model_to_dict(modelfile))
  93. except:
  94. return None
  95. def delete_modelfile_by_tag_name(self, tag_name: str) -> bool:
  96. try:
  97. query = Modelfile.delete().where((Modelfile.tag_name == tag_name))
  98. query.execute() # Remove the rows, return number of rows removed.
  99. return True
  100. except:
  101. return False
# Shared module-level instance. Constructing it runs ModelfilesTable.__init__,
# which calls create_tables([Modelfile]) on DB at import time.
Modelfiles = ModelfilesTable(DB)