prompts.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
import json
import logging
import time
from typing import List, Optional

from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String
from sqlalchemy.orm import Session

from apps.webui.internal.db import Base
  8. ####################
  9. # Prompts DB Schema
  10. ####################
  11. class Prompt(Base):
  12. __tablename__ = "prompt"
  13. command = Column(String, primary_key=True)
  14. user_id = Column(String)
  15. title = Column(String)
  16. content = Column(String)
  17. timestamp = Column(BigInteger)
  18. class PromptModel(BaseModel):
  19. command: str
  20. user_id: str
  21. title: str
  22. content: str
  23. timestamp: int # timestamp in epoch
  24. model_config = ConfigDict(from_attributes=True)
  25. ####################
  26. # Forms
  27. ####################
  28. class PromptForm(BaseModel):
  29. command: str
  30. title: str
  31. content: str
  32. class PromptsTable:
  33. def insert_new_prompt(
  34. self, db: Session, user_id: str, form_data: PromptForm
  35. ) -> Optional[PromptModel]:
  36. prompt = PromptModel(
  37. **{
  38. "user_id": user_id,
  39. "command": form_data.command,
  40. "title": form_data.title,
  41. "content": form_data.content,
  42. "timestamp": int(time.time()),
  43. }
  44. )
  45. try:
  46. result = Prompt(**prompt.dict())
  47. db.add(result)
  48. db.commit()
  49. db.refresh(result)
  50. if result:
  51. return PromptModel.model_validate(result)
  52. else:
  53. return None
  54. except Exception as e:
  55. return None
  56. def get_prompt_by_command(self, db: Session, command: str) -> Optional[PromptModel]:
  57. try:
  58. prompt = db.query(Prompt).filter_by(command=command).first()
  59. return PromptModel.model_validate(prompt)
  60. except:
  61. return None
  62. def get_prompts(self, db: Session) -> List[PromptModel]:
  63. return [PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()]
  64. def update_prompt_by_command(
  65. self, db: Session, command: str, form_data: PromptForm
  66. ) -> Optional[PromptModel]:
  67. try:
  68. db.query(Prompt).filter_by(command=command).update(
  69. {
  70. "title": form_data.title,
  71. "content": form_data.content,
  72. "timestamp": int(time.time()),
  73. }
  74. )
  75. return self.get_prompt_by_command(db, command)
  76. except:
  77. return None
  78. def delete_prompt_by_command(self, db: Session, command: str) -> bool:
  79. try:
  80. db.query(Prompt).filter_by(command=command).delete()
  81. return True
  82. except:
  83. return False
  84. Prompts = PromptsTable()