瀏覽代碼

feat: add support for using postgres for the backend DB

Jun Siang Cheah 1 年之前
父節點
當前提交
e91a49c455

+ 12 - 1
backend/apps/litellm/main.py

@@ -21,6 +21,8 @@ from utils.utils import get_verified_user, get_current_user, get_admin_user
 from config import SRC_LOG_LEVELS, ENV
 from constants import MESSAGES
 
+import os
+
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["LITELLM"])
 
@@ -62,6 +64,13 @@ app.state.CONFIG = litellm_config
 # Global variable to store the subprocess reference
 background_process = None
 
+CONFLICT_ENV_VARS = [
+    # Uvicorn uses PORT, so LiteLLM might use it as well
+    "PORT",
+    # LiteLLM uses DATABASE_URL for Prisma connections
+    "DATABASE_URL",
+]
+
 
 async def run_background_process(command):
     global background_process
@@ -70,9 +79,11 @@ async def run_background_process(command):
     try:
         # Log the command to be executed
         log.info(f"Executing command: {command}")
+        # Filter environment variables known to conflict with litellm
+        env = {k: v for k, v in os.environ.items() if k not in CONFLICT_ENV_VARS}
         # Execute the command and create a subprocess
         process = await asyncio.create_subprocess_exec(
-            *command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+            *command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
         )
         background_process = process
         log.info("Subprocess started successfully.")

+ 5 - 4
backend/apps/web/internal/db.py

@@ -1,6 +1,7 @@
 from peewee import *
 from peewee_migrate import Router
-from config import SRC_LOG_LEVELS, DATA_DIR
+from playhouse.db_url import connect
+from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL
 import os
 import logging
 
@@ -11,12 +12,12 @@ log.setLevel(SRC_LOG_LEVELS["DB"])
 if os.path.exists(f"{DATA_DIR}/ollama.db"):
     # Rename the file
     os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
-    log.info("File renamed successfully.")
+    log.info("Database migrated from Ollama-WebUI successfully.")
 else:
     pass
 
-
-DB = SqliteDatabase(f"{DATA_DIR}/webui.db")
+DB = connect(DATABASE_URL)
+log.info(f"Connected to a {DB.__class__.__name__} database.")
 router = Router(DB, migrate_dir="apps/web/internal/migrations", logger=log)
 router.run()
 DB.connect(reuse_if_open=True)

+ 105 - 0
backend/apps/web/internal/migrations/001_initial_schema.py

@@ -37,6 +37,18 @@ with suppress(ImportError):
 def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your migrations here."""
 
+    # We perform different migrations for SQLite and other databases
+    # This is because SQLite is very loose with enforcing its schema, and trying to migrate other databases like SQLite
+    # will require per-database SQL queries.
+    # Instead, we assume that because external DB support was added at a later date, it is safe to assume a newer base
+    # schema instead of trying to migrate from an older schema.
+    if isinstance(database, pw.SqliteDatabase):
+        migrate_sqlite(migrator, database, fake=fake)
+    else:
+        migrate_external(migrator, database, fake=fake)
+
+
+def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
     @migrator.create_model
     class Auth(pw.Model):
         id = pw.CharField(max_length=255, unique=True)
@@ -129,6 +141,99 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
             table_name = "user"
 
 
+def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
+    @migrator.create_model
+    class Auth(pw.Model):
+        id = pw.CharField(max_length=255, unique=True)
+        email = pw.CharField(max_length=255)
+        password = pw.TextField()
+        active = pw.BooleanField()
+
+        class Meta:
+            table_name = "auth"
+
+    @migrator.create_model
+    class Chat(pw.Model):
+        id = pw.CharField(max_length=255, unique=True)
+        user_id = pw.CharField(max_length=255)
+        title = pw.TextField()
+        chat = pw.TextField()
+        timestamp = pw.BigIntegerField()
+
+        class Meta:
+            table_name = "chat"
+
+    @migrator.create_model
+    class ChatIdTag(pw.Model):
+        id = pw.CharField(max_length=255, unique=True)
+        tag_name = pw.CharField(max_length=255)
+        chat_id = pw.CharField(max_length=255)
+        user_id = pw.CharField(max_length=255)
+        timestamp = pw.BigIntegerField()
+
+        class Meta:
+            table_name = "chatidtag"
+
+    @migrator.create_model
+    class Document(pw.Model):
+        id = pw.AutoField()
+        collection_name = pw.CharField(max_length=255, unique=True)
+        name = pw.CharField(max_length=255, unique=True)
+        title = pw.TextField()
+        filename = pw.TextField()
+        content = pw.TextField(null=True)
+        user_id = pw.CharField(max_length=255)
+        timestamp = pw.BigIntegerField()
+
+        class Meta:
+            table_name = "document"
+
+    @migrator.create_model
+    class Modelfile(pw.Model):
+        id = pw.AutoField()
+        tag_name = pw.CharField(max_length=255, unique=True)
+        user_id = pw.CharField(max_length=255)
+        modelfile = pw.TextField()
+        timestamp = pw.BigIntegerField()
+
+        class Meta:
+            table_name = "modelfile"
+
+    @migrator.create_model
+    class Prompt(pw.Model):
+        id = pw.AutoField()
+        command = pw.CharField(max_length=255, unique=True)
+        user_id = pw.CharField(max_length=255)
+        title = pw.TextField()
+        content = pw.TextField()
+        timestamp = pw.BigIntegerField()
+
+        class Meta:
+            table_name = "prompt"
+
+    @migrator.create_model
+    class Tag(pw.Model):
+        id = pw.CharField(max_length=255, unique=True)
+        name = pw.CharField(max_length=255)
+        user_id = pw.CharField(max_length=255)
+        data = pw.TextField(null=True)
+
+        class Meta:
+            table_name = "tag"
+
+    @migrator.create_model
+    class User(pw.Model):
+        id = pw.CharField(max_length=255, unique=True)
+        name = pw.CharField(max_length=255)
+        email = pw.CharField(max_length=255)
+        role = pw.CharField(max_length=255)
+        profile_image_url = pw.TextField()
+        timestamp = pw.BigIntegerField()
+
+        class Meta:
+            table_name = "user"
+
+
 def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your rollback migrations here."""
 

+ 53 - 0
backend/apps/web/internal/migrations/005_add_updated_at.py

@@ -37,6 +37,13 @@ with suppress(ImportError):
 def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your migrations here."""
 
+    if isinstance(database, pw.SqliteDatabase):
+        migrate_sqlite(migrator, database, fake=fake)
+    else:
+        migrate_external(migrator, database, fake=fake)
+
+
+def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
     # Adding fields created_at and updated_at to the 'chat' table
     migrator.add_fields(
         "chat",
@@ -60,9 +67,40 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
     )
 
 
+def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
+    # Adding fields created_at and updated_at to the 'chat' table
+    migrator.add_fields(
+        "chat",
+        created_at=pw.BigIntegerField(null=True),  # Allow null for transition
+        updated_at=pw.BigIntegerField(null=True),  # Allow null for transition
+    )
+
+    # Populate the new fields from an existing 'timestamp' field
+    migrator.sql(
+        "UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
+    )
+
+    # Now that the data has been copied, remove the original 'timestamp' field
+    migrator.remove_fields("chat", "timestamp")
+
+    # Update the fields to be not null now that they are populated
+    migrator.change_fields(
+        "chat",
+        created_at=pw.BigIntegerField(null=False),
+        updated_at=pw.BigIntegerField(null=False),
+    )
+
+
 def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
     """Write your rollback migrations here."""
 
+    if isinstance(database, pw.SqliteDatabase):
+        rollback_sqlite(migrator, database, fake=fake)
+    else:
+        rollback_external(migrator, database, fake=fake)
+
+
+def rollback_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
     # Recreate the timestamp field initially allowing null values for safe transition
     migrator.add_fields("chat", timestamp=pw.DateTimeField(null=True))
 
@@ -75,3 +113,18 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
 
     # Finally, alter the timestamp field to not allow nulls if that was the original setting
     migrator.change_fields("chat", timestamp=pw.DateTimeField(null=False))
+
+
+def rollback_external(migrator: Migrator, database: pw.Database, *, fake=False):
+    # Recreate the timestamp field initially allowing null values for safe transition
+    migrator.add_fields("chat", timestamp=pw.BigIntegerField(null=True))
+
+    # Copy the earliest created_at date back into the new timestamp field
+    # This assumes created_at was originally a copy of timestamp
+    migrator.sql("UPDATE chat SET timestamp = created_at")
+
+    # Remove the created_at and updated_at fields
+    migrator.remove_fields("chat", "created_at", "updated_at")
+
+    # Finally, alter the timestamp field to not allow nulls if that was the original setting
+    migrator.change_fields("chat", timestamp=pw.BigIntegerField(null=False))

+ 130 - 0
backend/apps/web/internal/migrations/006_migrate_timestamps_and_charfields.py

@@ -0,0 +1,130 @@
+"""Peewee migrations -- 006_migrate_timestamps_and_charfields.py.
+
+Some examples (model - class or model name)::
+
+    > Model = migrator.orm['table_name']            # Return model in current state by name
+    > Model = migrator.ModelClass                   # Return model in current state by name
+
+    > migrator.sql(sql)                             # Run custom SQL
+    > migrator.run(func, *args, **kwargs)           # Run python function with the given args
+    > migrator.create_model(Model)                  # Create a model (could be used as decorator)
+    > migrator.remove_model(model, cascade=True)    # Remove a model
+    > migrator.add_fields(model, **fields)          # Add fields to a model
+    > migrator.change_fields(model, **fields)       # Change fields
+    > migrator.remove_fields(model, *field_names, cascade=True)
+    > migrator.rename_field(model, old_field_name, new_field_name)
+    > migrator.rename_table(model, new_table_name)
+    > migrator.add_index(model, *col_names, unique=False)
+    > migrator.add_not_null(model, *field_names)
+    > migrator.add_default(model, field_name, default)
+    > migrator.add_constraint(model, name, sql)
+    > migrator.drop_index(model, *col_names)
+    > migrator.drop_not_null(model, *field_names)
+    > migrator.drop_constraints(model, *constraints)
+
+"""
+
+from contextlib import suppress
+
+import peewee as pw
+from peewee_migrate import Migrator
+
+
+with suppress(ImportError):
+    import playhouse.postgres_ext as pw_pext
+
+
+def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your migrations here."""
+
+    # Alter the tables with timestamps
+    migrator.change_fields(
+        "chatidtag",
+        timestamp=pw.BigIntegerField(),
+    )
+    migrator.change_fields(
+        "document",
+        timestamp=pw.BigIntegerField(),
+    )
+    migrator.change_fields(
+        "modelfile",
+        timestamp=pw.BigIntegerField(),
+    )
+    migrator.change_fields(
+        "prompt",
+        timestamp=pw.BigIntegerField(),
+    )
+    migrator.change_fields(
+        "user",
+        timestamp=pw.BigIntegerField(),
+    )
+    # Alter the tables with varchar to text where necessary
+    migrator.change_fields(
+        "auth",
+        password=pw.TextField(),
+    )
+    migrator.change_fields(
+        "chat",
+        title=pw.TextField(),
+    )
+    migrator.change_fields(
+        "document",
+        title=pw.TextField(),
+        filename=pw.TextField(),
+    )
+    migrator.change_fields(
+        "prompt",
+        title=pw.TextField(),
+    )
+    migrator.change_fields(
+        "user",
+        profile_image_url=pw.TextField(),
+    )
+
+
+def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your rollback migrations here."""
+
+    if isinstance(database, pw.SqliteDatabase):
+        # Alter the tables with timestamps
+        migrator.change_fields(
+            "chatidtag",
+            timestamp=pw.DateField(),
+        )
+        migrator.change_fields(
+            "document",
+            timestamp=pw.DateField(),
+        )
+        migrator.change_fields(
+            "modelfile",
+            timestamp=pw.DateField(),
+        )
+        migrator.change_fields(
+            "prompt",
+            timestamp=pw.DateField(),
+        )
+        migrator.change_fields(
+            "user",
+            timestamp=pw.DateField(),
+        )
+    migrator.change_fields(
+        "auth",
+        password=pw.CharField(max_length=255),
+    )
+    migrator.change_fields(
+        "chat",
+        title=pw.CharField(),
+    )
+    migrator.change_fields(
+        "document",
+        title=pw.CharField(),
+        filename=pw.CharField(),
+    )
+    migrator.change_fields(
+        "prompt",
+        title=pw.CharField(),
+    )
+    migrator.change_fields(
+        "user",
+        profile_image_url=pw.CharField(),
+    )

+ 1 - 1
backend/apps/web/models/auths.py

@@ -23,7 +23,7 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 class Auth(Model):
     id = CharField(unique=True)
     email = CharField()
-    password = CharField()
+    password = TextField()
     active = BooleanField()
 
     class Meta:

+ 3 - 3
backend/apps/web/models/chats.py

@@ -17,11 +17,11 @@ from apps.web.internal.db import DB
 class Chat(Model):
     id = CharField(unique=True)
     user_id = CharField()
-    title = CharField()
+    title = TextField()
     chat = TextField()  # Save Chat JSON as Text
 
-    created_at = DateTimeField()
-    updated_at = DateTimeField()
+    created_at = BigIntegerField()
+    updated_at = BigIntegerField()
 
     share_id = CharField(null=True, unique=True)
     archived = BooleanField(default=False)

+ 3 - 3
backend/apps/web/models/documents.py

@@ -25,11 +25,11 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 class Document(Model):
     collection_name = CharField(unique=True)
     name = CharField(unique=True)
-    title = CharField()
-    filename = CharField()
+    title = TextField()
+    filename = TextField()
     content = TextField(null=True)
     user_id = CharField()
-    timestamp = DateField()
+    timestamp = BigIntegerField()
 
     class Meta:
         database = DB

+ 1 - 1
backend/apps/web/models/modelfiles.py

@@ -20,7 +20,7 @@ class Modelfile(Model):
     tag_name = CharField(unique=True)
     user_id = CharField()
     modelfile = TextField()
-    timestamp = DateField()
+    timestamp = BigIntegerField()
 
     class Meta:
         database = DB

+ 2 - 2
backend/apps/web/models/prompts.py

@@ -19,9 +19,9 @@ import json
 class Prompt(Model):
     command = CharField(unique=True)
     user_id = CharField()
-    title = CharField()
+    title = TextField()
     content = TextField()
-    timestamp = DateField()
+    timestamp = BigIntegerField()
 
     class Meta:
         database = DB

+ 1 - 1
backend/apps/web/models/tags.py

@@ -35,7 +35,7 @@ class ChatIdTag(Model):
     tag_name = CharField()
     chat_id = CharField()
     user_id = CharField()
-    timestamp = DateField()
+    timestamp = BigIntegerField()
 
     class Meta:
         database = DB

+ 2 - 2
backend/apps/web/models/users.py

@@ -18,8 +18,8 @@ class User(Model):
     name = CharField()
     email = CharField()
     role = CharField()
-    profile_image_url = CharField()
-    timestamp = DateField()
+    profile_image_url = TextField()
+    timestamp = BigIntegerField()
     api_key = CharField(null=True, unique=True)
 
     class Meta:

+ 2 - 0
backend/apps/web/routers/auths.py

@@ -1,3 +1,5 @@
+import logging
+
 from fastapi import Request
 from fastapi import Depends, HTTPException, status
 

+ 7 - 0
backend/config.py

@@ -534,3 +534,10 @@ LITELLM_PROXY_PORT = int(os.getenv("LITELLM_PROXY_PORT", "14365"))
 if LITELLM_PROXY_PORT < 0 or LITELLM_PROXY_PORT > 65535:
     raise ValueError("Invalid port number for LITELLM_PROXY_PORT")
 LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1")
+
+
+####################################
+# Database
+####################################
+
+DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")

+ 2 - 0
backend/requirements.txt

@@ -15,6 +15,8 @@ requests
 aiohttp
 peewee
 peewee-migrate
+psycopg2-binary
+pymysql
 bcrypt
 
 litellm==1.35.17