import enum
import sqlite3

class DatabaseOrder(enum.Enum):
    """Sort direction for a single ``order by`` term built by ``Database.query``.

    ``DEFAULT`` emits the column with no direction keyword, ``ASC`` appends
    ``asc`` and ``DESC`` appends ``desc``.
    """

    DEFAULT = 0  # no explicit keyword; the database's default ordering applies
    ASC = 1      # "<column> asc"
    DESC = 2     # "<column> desc"

class DatabaseTable:
    """Empty class with no attributes or behavior of its own.

    NOTE(review): presumably intended as a base for table classes. The
    ``Database`` methods read ``name``, ``key`` and ``columns`` from a table
    class, and this class defines none of them. Confirm whether subclasses
    are expected to supply them.
    """

class DatabaseTableCursor():
    """Wrap a DB-API cursor so that fetched rows become ``table`` objects.

    ``fetchone``, ``fetchall``, ``each`` and plain iteration return mapped
    objects. Every other attribute (``execute``, ``rowcount``, ``fetchmany``,
    ...) is forwarded unchanged to the wrapped cursor, so forwarded fetch
    methods return raw rows.
    """
    __slots__ = 'cr', 'table',

    def __init__(self, table, cr):
        self.cr    = cr
        self.table = table

    def __getattr__(self, name):
        # Called only when normal lookup fails. If a slot itself is unset
        # (e.g. an instance created via __new__ by copy/pickle), forwarding
        # would re-enter this method through self.cr and recurse without end.
        # Raise a plain AttributeError instead.
        if name in ('cr', 'table'):
            raise AttributeError(name)

        return getattr(self.cr, name)

    def __iter__(self):
        # Special methods are looked up on the type and bypass __getattr__,
        # so iteration must be provided explicitly. It yields mapped objects.
        return self.each()

    def __map__(self, row):
        """Convert one fetched ``row`` into an object of ``self.table``.

        If the table class has a ``__from_row__`` attribute, it is called with
        the row and its result is returned as-is. It is fetched from the class
        itself, so it should be a staticmethod or classmethod.

        Otherwise the table class is instantiated with no arguments, and each
        name in ``table.columns`` is copied from the row. Columns missing from
        the row are set to ``None``.
        """
        fn = getattr(self.table, '__from_row__', None)

        if fn is not None:
            return fn(row)

        obj = self.table()

        for name in self.table.columns:
            try:
                value = row[name]
            except IndexError:
                # sqlite3.Row raises IndexError for a key the row lacks.
                value = None

            setattr(obj, name, value)

        return obj

    def fetchone(self):
        """Fetch the next row as a mapped object, or ``None`` when exhausted."""
        row = self.cr.fetchone()

        return self.__map__(row) if row is not None else None

    def fetchall(self):
        """Fetch all remaining rows and return them as mapped objects.

        The rows are read from the cursor immediately, but the result is a
        lazy ``map`` iterator rather than a list. Each row is converted only
        as the result is iterated, and the result can be consumed once.
        """
        return map(self.__map__, self.cr.fetchall())

    def each(self):
        """Yield mapped objects one row at a time until the cursor is exhausted."""
        while True:
            obj = self.fetchone()

            if obj is None:
                break

            yield obj

class Database():
    """Thin wrapper around a ``sqlite3.Connection`` that maps table classes to SQL.

    A *table class* is any class exposing the class attributes read below:

    * ``name``    -- SQL table name,
    * ``columns`` -- iterable of column names,
    * ``key``     -- column that ``add`` fills from the cursor's ``lastrowid``.

    Attributes not defined here (``execute``, ``commit``, ``close``, ...) are
    forwarded to the wrapped connection.

    NOTE(review): table names, column names and ``values`` keys are
    interpolated into the SQL text verbatim. Only values are bound as
    parameters, so never pass untrusted identifiers.
    """
    __slots__ = 'db',

    def __init__(self, db):
        self.db = db

    def __getattr__(self, name):
        # Called only when normal lookup fails. If the ``db`` slot itself is
        # unset (e.g. an instance created via __new__ by copy/pickle),
        # forwarding would re-enter this method through self.db and recurse
        # without end. Raise a plain AttributeError instead.
        if name == 'db':
            raise AttributeError(name)

        return getattr(self.db, name)

    @classmethod
    def connect(cls, path):
        """Open the SQLite database at ``path`` and wrap it.

        Rows are produced as ``sqlite3.Row`` so that columns can be looked up
        by name. Subclasses calling this get an instance of the subclass.
        """
        db = sqlite3.connect(path)
        db.row_factory = sqlite3.Row

        return cls(db)

    @staticmethod
    def _where(table, values):
        """Build a ``" where ..."`` clause and its bound parameters.

        Each key of ``values`` becomes a ``<table>.<column>`` test, and the
        tests are joined with ``and``. A ``None`` value is rendered as
        ``is null``, because ``column = NULL`` is never true in SQL.
        Returns ``("", [])`` when ``values`` is empty or ``None``.
        """
        if not values:
            return "", []

        terms = []
        params = []

        for column, value in values.items():
            if value is None:
                terms.append(f"{table.name}.{column} is null")
            else:
                terms.append(f"{table.name}.{column} = ?")
                params.append(value)

        return " where " + " and ".join(terms), params

    def add(self, obj):
        """Insert ``obj`` as a new row of ``type(obj)``'s table.

        Every column except the key is inserted. The values come from
        ``obj.__values__()`` when the object defines it; that method must
        return a sequence in the order of the non-key columns. Otherwise they
        come from the attributes of the same names. Afterwards the key
        attribute of ``obj`` is set to the cursor's ``lastrowid``.

        No commit is issued here. Call ``commit()``, which is forwarded to
        the connection, to persist the row.
        """
        table = type(obj)
        columns = [c for c in table.columns if c != table.key]

        if columns:
            sql = "insert into %s (%s) values (%s)" % (
                table.name,
                ", ".join(columns),
                ", ".join("?" for _ in columns),
            )
        else:
            # "insert into t () values ()" is a syntax error; a table with
            # only a key column needs the DEFAULT VALUES form.
            sql = f"insert into {table.name} default values"

        fn = getattr(obj, '__values__', None)

        if fn is not None:
            values = fn()
        else:
            values = [getattr(obj, column) for column in columns]

        cr = self.db.execute(sql, values)

        setattr(obj, table.key, cr.lastrowid)

    def query_sql(self, table, sql, values=None):
        """Execute raw ``sql`` with bound ``values`` and return a mapping cursor.

        The returned ``DatabaseTableCursor`` converts fetched rows into
        instances of ``table``. ``values`` defaults to no parameters.
        """
        cr = DatabaseTableCursor(table, self.db.cursor())
        cr.execute(sql, [] if values is None else values)

        return cr

    def query(self, table, values=None, order_by=None):
        """Select all of ``table``'s columns, optionally filtered and ordered.

        :param values: mapping of column name to required value. The
            conditions are joined with ``and``, and a ``None`` value matches
            SQL ``NULL``.
        :param order_by: sequence of ``(column, DatabaseOrder)`` pairs, or
            bare column names. An order of ``None`` or
            ``DatabaseOrder.DEFAULT`` emits no direction keyword.
        :raises ValueError: if an order is neither ``None`` nor a
            ``DatabaseOrder`` member.
        :returns: a ``DatabaseTableCursor`` over the result.
        """
        sql = "select %s from %s" % (', '.join(table.columns), table.name)

        where, params = self._where(table, values)
        sql += where

        if order_by:
            terms = []

            for item in order_by:
                column, order = (item, None) if isinstance(item, str) else item

                if order is None or order is DatabaseOrder.DEFAULT:
                    terms.append(f"{column}")
                elif order is DatabaseOrder.ASC:
                    terms.append(f"{column} asc")
                elif order is DatabaseOrder.DESC:
                    terms.append(f"{column} desc")
                else:
                    # Skipping the term would leave malformed SQL such as
                    # "order by, x"; fail loudly instead.
                    raise ValueError(
                        f"invalid sort order for {column!r}: {order!r}")

            sql += " order by " + ", ".join(terms)

        return self.query_sql(table, sql, params)

    def get(self, table, values: dict = None):
        """Return the first object fetched for ``query(table, values)``, or ``None``.

        Without an ordering, which matching row comes first is up to the
        database.
        """
        return self.query(table, values).fetchone()

    def _call(self, table, fn: str, column: str, values: dict = None) -> int:
        """Evaluate the SQL function ``fn(column)`` over rows of ``table`` matching ``values``.

        Returns the first column of the single result row. Despite the
        annotation, this is not always an ``int``: ``min``/``max`` yield the
        column's own type, and ``None`` when no rows match.
        """
        where, params = self._where(table, values)
        sql = f"select {fn}({column}) as ret from {table.name}{where}"

        row = self.db.execute(sql, params).fetchone()

        return row[0] if row is not None else None

    def min(self, table, column: str, values: dict = None) -> int:
        """Smallest value of ``column`` among rows matching ``values`` (``None`` if none)."""
        return self._call(table, 'min', column, values)

    def max(self, table, column: str, values: dict = None) -> int:
        """Largest value of ``column`` among rows matching ``values`` (``None`` if none)."""
        return self._call(table, 'max', column, values)

    def count(self, table, values: dict = None) -> int:
        """Number of rows matching ``values``, counted as ``count(<key column>)``."""
        return self._call(table, 'count', table.key, values)