prompts.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
import json
import logging
import time
from typing import List, Optional

from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text

from apps.webui.internal.db import Base, get_db
  7. ####################
  8. # Prompts DB Schema
  9. ####################
class Prompt(Base):
    """SQLAlchemy ORM mapping for the ``prompt`` table.

    Each row stores one saved prompt, uniquely identified by its ``command``.

    NOTE(review): the non-key columns do not set ``nullable=False``, so the
    database accepts NULLs, while ``PromptModel`` requires non-None values;
    a NULL row would fail ``PromptModel.model_validate``.
    """

    __tablename__ = "prompt"

    # Primary key: the command string that identifies this prompt, so two
    # prompts cannot share a command.
    command = Column(String, primary_key=True)
    # ID of the user who created the prompt.
    user_id = Column(String)
    title = Column(Text)
    content = Column(Text)
    # Unix epoch time in whole seconds (written as int(time.time()) on
    # insert and on update).
    timestamp = Column(BigInteger)
class PromptModel(BaseModel):
    """Pydantic representation of a stored prompt, returned by PromptsTable.

    ``from_attributes=True`` lets ``PromptModel.model_validate`` build an
    instance directly from a ``Prompt`` ORM object.
    """

    command: str
    user_id: str
    title: str
    content: str
    timestamp: int  # Unix epoch time in whole seconds

    # Read fields from object attributes (ORM rows), not only from dicts.
    model_config = ConfigDict(from_attributes=True)
  24. ####################
  25. # Forms
  26. ####################
class PromptForm(BaseModel):
    """Client-supplied fields for creating or updating a prompt.

    ``user_id`` and ``timestamp`` are not part of the form; PromptsTable
    fills them in when inserting. When updating, ``command`` is not used to
    locate or rename the row (PromptsTable.update_prompt_by_command takes
    the command as a separate argument and only writes title/content).
    """

    command: str
    title: str
    content: str
  31. class PromptsTable:
  32. def insert_new_prompt(
  33. self, user_id: str, form_data: PromptForm
  34. ) -> Optional[PromptModel]:
  35. prompt = PromptModel(
  36. **{
  37. "user_id": user_id,
  38. "command": form_data.command,
  39. "title": form_data.title,
  40. "content": form_data.content,
  41. "timestamp": int(time.time()),
  42. }
  43. )
  44. try:
  45. with get_db() as db:
  46. result = Prompt(**prompt.dict())
  47. db.add(result)
  48. db.commit()
  49. db.refresh(result)
  50. if result:
  51. return PromptModel.model_validate(result)
  52. else:
  53. return None
  54. except Exception as e:
  55. return None
  56. def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
  57. try:
  58. with get_db() as db:
  59. prompt = db.query(Prompt).filter_by(command=command).first()
  60. return PromptModel.model_validate(prompt)
  61. except:
  62. return None
  63. def get_prompts(self) -> List[PromptModel]:
  64. with get_db() as db:
  65. return [
  66. PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
  67. ]
  68. def update_prompt_by_command(
  69. self, command: str, form_data: PromptForm
  70. ) -> Optional[PromptModel]:
  71. try:
  72. with get_db() as db:
  73. prompt = db.query(Prompt).filter_by(command=command).first()
  74. prompt.title = form_data.title
  75. prompt.content = form_data.content
  76. prompt.timestamp = int(time.time())
  77. db.commit()
  78. return PromptModel.model_validate(prompt)
  79. except:
  80. return None
  81. def delete_prompt_by_command(self, command: str) -> bool:
  82. try:
  83. with get_db() as db:
  84. db.query(Prompt).filter_by(command=command).delete()
  85. db.commit()
  86. return True
  87. except:
  88. return False
  89. Prompts = PromptsTable()