diff --git a/.gitignore b/.gitignore index 2b66062..d459a5f 100644 --- a/.gitignore +++ b/.gitignore @@ -163,5 +163,8 @@ cython_debug/ # Logger /pilot/logger/debug.log +#sqlite +/pilot/gpt-pilot + # workspace workspace \ No newline at end of file diff --git a/pilot/database/config.py b/pilot/database/config.py new file mode 100644 index 0000000..9e4fec9 --- /dev/null +++ b/pilot/database/config.py @@ -0,0 +1,8 @@ +import os + +DATABASE_TYPE = os.getenv("DATABASE_TYPE", "sqlite") +DB_NAME = os.getenv("DB_NAME") +DB_HOST = os.getenv("DB_HOST") +DB_PORT = os.getenv("DB_PORT") +DB_USER = os.getenv("DB_USER") +DB_PASSWORD = os.getenv("DB_PASSWORD") diff --git a/pilot/database/connection/postgres.py b/pilot/database/connection/postgres.py new file mode 100644 index 0000000..1d1635a --- /dev/null +++ b/pilot/database/connection/postgres.py @@ -0,0 +1,22 @@ +import psycopg2 +from peewee import PostgresqlDatabase +from psycopg2.extensions import quote_ident +from database.config import DB_NAME, DB_HOST, DB_PORT, DB_USER, DB_PASSWORD + +def get_postgres_database(): + return PostgresqlDatabase(DB_NAME, user=DB_USER, password=DB_PASSWORD, host=DB_HOST, port=DB_PORT) + +def create_postgres_database(): + conn = psycopg2.connect( + dbname='postgres', + user=DB_USER, + password=DB_PASSWORD, + host=DB_HOST, + port=DB_PORT + ) + conn.autocommit = True + cursor = conn.cursor() + safe_db_name = quote_ident(DB_NAME, conn) + cursor.execute(f"CREATE DATABASE {safe_db_name}") + cursor.close() + conn.close() diff --git a/pilot/database/connection/sqlite.py b/pilot/database/connection/sqlite.py new file mode 100644 index 0000000..b0d2cfd --- /dev/null +++ b/pilot/database/connection/sqlite.py @@ -0,0 +1,5 @@ +from peewee import SqliteDatabase +from database.config import DB_NAME + +def get_sqlite_database(): + return SqliteDatabase(DB_NAME) diff --git a/pilot/database/database.py b/pilot/database/database.py index 60bb18d..52d9711 100644 --- a/pilot/database/database.py +++ b/pilot/database/database.py @@ -4,12 +4,12 @@ from termcolor import colored from functools import reduce import operator import psycopg2 -import os from const.common import PROMPT_DATA_TO_IGNORE from logger.logger import logger from psycopg2.extensions import quote_ident from utils.utils import hash_data +from database.config import DB_NAME, DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, DATABASE_TYPE from database.models.components.base_models import database from database.models.user import User from database.models.app import App @@ -26,12 +26,6 @@ from database.models.command_runs import CommandRuns from database.models.user_inputs import UserInputs from database.models.files import File -DB_NAME = os.getenv("DB_NAME") -DB_HOST = os.getenv("DB_HOST") -DB_PORT = os.getenv("DB_PORT") -DB_USER = os.getenv("DB_USER") -DB_PASSWORD = os.getenv("DB_PASSWORD") - def save_user(user_id, email, password): try: @@ -383,7 +377,14 @@ def drop_tables(): UserInputs, File, ]: - database.execute_sql(f'DROP TABLE IF EXISTS "{table._meta.table_name}" CASCADE') + if DATABASE_TYPE == "postgresql": + sql = f'DROP TABLE IF EXISTS "{table._meta.table_name}" CASCADE' + elif DATABASE_TYPE == "sqlite": + sql = f'DROP TABLE IF EXISTS "{table._meta.table_name}"' + else: + raise ValueError(f"Unsupported DATABASE_TYPE: {DATABASE_TYPE}") + + database.execute_sql(sql) def database_exists(): @@ -396,35 +397,42 @@ def database_exists(): def create_database(): - # Connect to the default 'postgres' database to create a new database - conn = psycopg2.connect( - dbname='postgres', - user=DB_USER, - password=DB_PASSWORD, - host=DB_HOST, - port=DB_PORT - ) - conn.autocommit = True - cursor = conn.cursor() + if DATABASE_TYPE == "postgres": + # Connect to the default 'postgres' database to create a new database + conn = psycopg2.connect( + dbname='postgres', + user=DB_USER, + password=DB_PASSWORD, + host=DB_HOST, + port=DB_PORT + ) + conn.autocommit = True + cursor = conn.cursor() - # Safely quote the database name - safe_db_name = quote_ident(DB_NAME, conn) + # Safely quote the database name + safe_db_name = quote_ident(DB_NAME, conn) - # Use the safely quoted database name in the SQL query - cursor.execute(f"CREATE DATABASE {safe_db_name}") + # Use the safely quoted database name in the SQL query + cursor.execute(f"CREATE DATABASE {safe_db_name}") - cursor.close() - conn.close() + cursor.close() + conn.close() + else: + pass def tables_exist(): tables = [User, App, ProjectDescription, UserStories, UserTasks, Architecture, DevelopmentPlanning, DevelopmentSteps, EnvironmentSetup, Development, FileSnapshot, CommandRuns, UserInputs, File] - for table in tables: - try: - database.get_tables().index(table._meta.table_name) - except ValueError: - return False + + if DATABASE_TYPE == "postgres": + for table in tables: + try: + database.get_tables().index(table._meta.table_name) + except ValueError: + return False + else: + pass return True diff --git a/pilot/database/models/architecture.py b/pilot/database/models/architecture.py index d530dba..261acb9 100644 --- a/pilot/database/models/architecture.py +++ b/pilot/database/models/architecture.py @@ -1,10 +1,15 @@ from peewee import * - +from database.config import DATABASE_TYPE from database.models.components.progress_step import ProgressStep +from database.models.components.sqlite_middlewares import JSONField from playhouse.postgres_ext import BinaryJSONField class Architecture(ProgressStep): - architecture = BinaryJSONField() + if DATABASE_TYPE == 'postgres': + architecture = BinaryJSONField() + else: + architecture = JSONField() # Custom JSON field for SQLite + class Meta: db_table = 'architecture' diff --git a/pilot/database/models/components/base_models.py b/pilot/database/models/components/base_models.py index 22f01f3..74d4a63 100644 --- a/pilot/database/models/components/base_models.py +++ b/pilot/database/models/components/base_models.py @@ -1,23 +1,17 @@ -import os from peewee import * from datetime import datetime from uuid import uuid4 -DB_NAME = os.getenv("DB_NAME") -DB_HOST = os.getenv("DB_HOST") -DB_PORT = os.getenv("DB_PORT") -DB_USER = os.getenv("DB_USER") -DB_PASSWORD = os.getenv("DB_PASSWORD") +from database.config import DATABASE_TYPE +from database.connection.postgres import get_postgres_database +from database.connection.sqlite import get_sqlite_database # Establish connection to the database -database = PostgresqlDatabase( - DB_NAME, - user=DB_USER, - password=DB_PASSWORD, - host=DB_HOST, - port=DB_PORT -) +if DATABASE_TYPE == "postgres": + database = get_postgres_database() +else: + database = get_sqlite_database() class BaseModel(Model): diff --git a/pilot/database/models/components/progress_step.py b/pilot/database/models/components/progress_step.py index 281659e..7b7e068 100644 --- a/pilot/database/models/components/progress_step.py +++ b/pilot/database/models/components/progress_step.py @@ -1,16 +1,23 @@ from peewee import * - -from playhouse.postgres_ext import BinaryJSONField - +from database.config import DATABASE_TYPE from database.models.components.base_models import BaseModel from database.models.app import App +from database.models.components.sqlite_middlewares import JSONField +from playhouse.postgres_ext import BinaryJSONField class ProgressStep(BaseModel): app = ForeignKeyField(App, primary_key=True, on_delete='CASCADE') step = CharField() - data = BinaryJSONField(null=True) - messages = BinaryJSONField(null=True) - app_data = BinaryJSONField() + + if DATABASE_TYPE == 'postgres': + app_data = BinaryJSONField() + data = BinaryJSONField(null=True) + messages = BinaryJSONField(null=True) + else: + app_data = JSONField() + data = JSONField(null=True) + messages = JSONField(null=True) + completed = BooleanField(default=False) completed_at = DateTimeField(null=True) diff --git a/pilot/database/models/components/sqlite_middlewares.py b/pilot/database/models/components/sqlite_middlewares.py new file mode 100644 index 0000000..69c188c --- /dev/null +++ b/pilot/database/models/components/sqlite_middlewares.py @@ -0,0 +1,14 @@ +import json +from peewee import TextField + + +class JSONField(TextField): + def python_value(self, value): + if value is not None: + return json.loads(value) + return value + + def db_value(self, value): + if value is not None: + return json.dumps(value) + return value diff --git a/pilot/database/models/development_planning.py b/pilot/database/models/development_planning.py index da8d6f9..8fe7a55 100644 --- a/pilot/database/models/development_planning.py +++ b/pilot/database/models/development_planning.py @@ -1,11 +1,15 @@ from peewee import * - +from database.config import DATABASE_TYPE from database.models.components.progress_step import ProgressStep +from database.models.components.sqlite_middlewares import JSONField from playhouse.postgres_ext import BinaryJSONField class DevelopmentPlanning(ProgressStep): - development_plan = BinaryJSONField() + if DATABASE_TYPE == 'postgres': + development_plan = BinaryJSONField() + else: + development_plan = JSONField() # Custom JSON field for SQLite class Meta: db_table = 'development_planning' diff --git a/pilot/database/models/development_steps.py b/pilot/database/models/development_steps.py index b0414bd..6492a4d 100644 --- a/pilot/database/models/development_steps.py +++ b/pilot/database/models/development_steps.py @@ -1,21 +1,26 @@ from peewee import * - -from playhouse.postgres_ext import BinaryJSONField - +from database.config import DATABASE_TYPE from database.models.components.base_models import BaseModel from database.models.app import App - +from database.models.components.sqlite_middlewares import JSONField +from playhouse.postgres_ext import BinaryJSONField class DevelopmentSteps(BaseModel): id = AutoField() # This will serve as the primary key app = ForeignKeyField(App, on_delete='CASCADE') hash_id = CharField(null=False) - messages = BinaryJSONField(null=True) - llm_response = BinaryJSONField(null=False) + + if DATABASE_TYPE == 'postgres': + messages = BinaryJSONField(null=True) + llm_response = BinaryJSONField(null=False) + else: + messages = JSONField(null=True) # Custom JSON field for SQLite + llm_response = JSONField(null=False) # Custom JSON field for SQLite + previous_step = ForeignKeyField('self', null=True, column_name='previous_step') class Meta: db_table = 'development_steps' indexes = ( (('app', 'hash_id'), True), - ) \ No newline at end of file + ) diff --git a/pilot/database/models/project_description.py b/pilot/database/models/project_description.py index 21a5bc1..462c1a2 100644 --- a/pilot/database/models/project_description.py +++ b/pilot/database/models/project_description.py @@ -1,7 +1,4 @@ from peewee import * - -from playhouse.postgres_ext import BinaryJSONField - from database.models.components.progress_step import ProgressStep diff --git a/pilot/database/models/user_stories.py b/pilot/database/models/user_stories.py index 24ee142..025e255 100644 --- a/pilot/database/models/user_stories.py +++ b/pilot/database/models/user_stories.py @@ -1,10 +1,14 @@ from peewee import * - +from database.config import DATABASE_TYPE from database.models.components.progress_step import ProgressStep +from database.models.components.sqlite_middlewares import JSONField from playhouse.postgres_ext import BinaryJSONField class UserStories(ProgressStep): - user_stories = BinaryJSONField() + if DATABASE_TYPE == 'postgres': + user_stories = BinaryJSONField() + else: + user_stories = JSONField() # Custom JSON field for SQLite class Meta: db_table = 'user_stories' diff --git a/pilot/database/models/user_tasks.py b/pilot/database/models/user_tasks.py index 124f86b..533340a 100644 --- a/pilot/database/models/user_tasks.py +++ b/pilot/database/models/user_tasks.py @@ -1,10 +1,15 @@ from peewee import * - +from database.config import DATABASE_TYPE from database.models.components.progress_step import ProgressStep +from database.models.components.sqlite_middlewares import JSONField from playhouse.postgres_ext import BinaryJSONField class UserTasks(ProgressStep): - user_tasks = BinaryJSONField() + if DATABASE_TYPE == 'postgres': + user_tasks = BinaryJSONField() + else: + user_tasks = JSONField() # Custom JSON field for SQLite + class Meta: db_table = 'user_tasks'