# prompts.py
import json
import logging
import time
from typing import List, Optional

from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String
from sqlalchemy.orm import Session

from apps.webui.internal.db import Base, get_session
####################
# Prompts DB Schema
####################
  11. class Prompt(Base):
  12. __tablename__ = "prompt"
  13. command = Column(String, primary_key=True)
  14. user_id = Column(String)
  15. title = Column(String)
  16. content = Column(String)
  17. timestamp = Column(BigInteger)
  18. class PromptModel(BaseModel):
  19. command: str
  20. user_id: str
  21. title: str
  22. content: str
  23. timestamp: int # timestamp in epoch
  24. model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
  28. class PromptForm(BaseModel):
  29. command: str
  30. title: str
  31. content: str
  32. class PromptsTable:
  33. def insert_new_prompt(
  34. self, user_id: str, form_data: PromptForm
  35. ) -> Optional[PromptModel]:
  36. with get_session() as db:
  37. prompt = PromptModel(
  38. **{
  39. "user_id": user_id,
  40. "command": form_data.command,
  41. "title": form_data.title,
  42. "content": form_data.content,
  43. "timestamp": int(time.time()),
  44. }
  45. )
  46. try:
  47. result = Prompt(**prompt.dict())
  48. db.add(result)
  49. db.commit()
  50. db.refresh(result)
  51. if result:
  52. return PromptModel.model_validate(result)
  53. else:
  54. return None
  55. except Exception as e:
  56. return None
  57. def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
  58. with get_session() as db:
  59. try:
  60. prompt = db.query(Prompt).filter_by(command=command).first()
  61. return PromptModel.model_validate(prompt)
  62. except:
  63. return None
  64. def get_prompts(self) -> List[PromptModel]:
  65. with get_session() as db:
  66. return [PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()]
  67. def update_prompt_by_command(
  68. self, command: str, form_data: PromptForm
  69. ) -> Optional[PromptModel]:
  70. with get_session() as db:
  71. try:
  72. prompt = db.query(Prompt).filter_by(command=command).first()
  73. prompt.title = form_data.title
  74. prompt.content = form_data.content
  75. prompt.timestamp = int(time.time())
  76. db.commit()
  77. return prompt
  78. # return self.get_prompt_by_command(command)
  79. except:
  80. return None
  81. def delete_prompt_by_command(self, command: str) -> bool:
  82. with get_session() as db:
  83. try:
  84. db.query(Prompt).filter_by(command=command).delete()
  85. return True
  86. except:
  87. return False
  88. Prompts = PromptsTable()