74 lines
1.8 KiB
Python
74 lines
1.8 KiB
Python
import os
|
|
import sqlite3
|
|
from pathlib import Path
|
|
|
|
try:
|
|
import psycopg
|
|
except ImportError:
|
|
psycopg = None
|
|
|
|
# Backend label.
# NOTE(review): this is always 'sqlite3', even when DATABASE_URL selects
# PostgreSQL; is_postgres() is what actually decides the backend. Confirm
# whether anything still reads DB_TYPE before relying on it.
DB_TYPE = 'sqlite3'

# Directory containing this module; the schema files and the SQLite database
# path are resolved relative to it.
BASE_DIR = Path(__file__).resolve().parent

# Connection URL taken from the environment (None when the variable is unset).
DATABASE_URL = os.environ.get("DATABASE_URL")

# On-disk location of the SQLite database used when PostgreSQL is not selected.
SQLITE_PATH = BASE_DIR.joinpath("instance", "dev.db")
|
|
|
|
|
|
def is_postgres():
    """Return True when DATABASE_URL selects a PostgreSQL backend.

    The check is a case-insensitive prefix match on "postgres", so both
    ``postgres://...`` and ``postgresql://...`` URLs qualify. An unset (None)
    or empty DATABASE_URL yields False.

    Returns:
        bool: always a real boolean, never None or "".
    """
    # bool() stops None / "" from leaking out of the short-circuiting `and`,
    # so callers comparing or serializing the result get a genuine bool.
    return bool(DATABASE_URL) and DATABASE_URL.lower().startswith("postgres")
|
|
|
|
# Each backend has its own schema file. The choice is made once, at import
# time, from the current DATABASE_URL.
SCHEMA_PATH = BASE_DIR / ("schema.sql" if is_postgres() else "schema_sqlite.sql")
|
|
|
|
def get_connection():
    """Open a connection to the configured database.

    For PostgreSQL (see is_postgres), returns a psycopg connection to
    DATABASE_URL opened in autocommit mode. Otherwise it creates the parent
    directory of SQLITE_PATH if needed. It then returns an sqlite3 connection
    to that file with foreign-key enforcement switched on.

    Raises:
        RuntimeError: PostgreSQL is selected but psycopg could not be imported.
    """
    if not is_postgres():
        db_file = SQLITE_PATH
        db_file.parent.mkdir(parents=True, exist_ok=True)
        sqlite_conn = sqlite3.connect(str(db_file))
        # SQLite leaves foreign-key enforcement off per connection by default.
        sqlite_conn.execute("PRAGMA foreign_keys = ON;")
        return sqlite_conn

    if psycopg is None:
        raise RuntimeError("psycopg is required for PostgreSQL but not installed.")
    return psycopg.connect(DATABASE_URL, autocommit=True)
|
|
|
|
|
|
def _iter_sql_statements(script: str):
|
|
buffer = []
|
|
for line in script.splitlines():
|
|
buffer.append(line)
|
|
if line.strip().endswith(";"):
|
|
stmt = "\n".join(buffer).strip()
|
|
buffer = []
|
|
if stmt:
|
|
yield stmt
|
|
tail = "\n".join(buffer).strip()
|
|
if tail:
|
|
yield tail
|
|
|
|
|
|
def db_init():
    """Create the database schema by running the SQL in SCHEMA_PATH.

    The schema file is read as UTF-8.
    - On PostgreSQL, each statement produced by _iter_sql_statements is
      executed separately on one cursor.
    - On SQLite, the whole file is handed to ``executescript``.

    The connection is committed, a success message is printed, and the
    connection is always closed, even on failure.

    NOTE(review): get_connection opens PostgreSQL in autocommit mode. A
    statement that fails midway therefore presumably leaves the earlier ones
    applied. Confirm that this is acceptable.

    Raises:
        FileNotFoundError: SCHEMA_PATH does not point to an existing file.
    """
    schema_file = SCHEMA_PATH
    if not schema_file.is_file():
        raise FileNotFoundError(f"Schema file not found: {schema_file}")

    script_text = schema_file.read_text(encoding="utf-8")

    connection = get_connection()
    try:
        if not is_postgres():
            connection.executescript(script_text)
        else:
            with connection.cursor() as cursor:
                for statement in _iter_sql_statements(script_text):
                    cursor.execute(statement)
        connection.commit()
        print("Database initialized successfully.")
    finally:
        connection.close()
|
|
|
|
|
|
# Running this module as a script initializes the database schema.
if __name__ == "__main__":
    db_init()