# The MIT License
# 
# Copyright (c) 2008 Randall Smith
# 
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# 
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

"""Provides interface for extracting database schema information.

"""

import functools

import sqlalchemy as sa
from sqlalchemy.databases import information_schema as ischema

def decorator_with_args(decorator):
    """Turn ``decorator(func, *args, **kwargs)`` into a decorator factory.

    The returned factory is used as ``@factory(*args, **kwargs)``; applying
    it to ``func`` calls ``decorator(func, *args, **kwargs)``.
    """
    def factory(*args, **kwargs):
        def apply(func):
            return decorator(func, *args, **kwargs)
        return apply
    return factory


@decorator_with_args
def executable(meth, single_field=False):
    """Make a statement-building method optionally execute its statement.

    ``meth(self, ...)`` must return a statement. The wrapped method:

    * returns the statement unchanged when the caller passes
      ``return_statement=True`` or when ``self._db_con`` is None;
    * otherwise passes it to ``self._execute`` and returns the result,
      or, with ``single_field=True``, a list of each row's first field.

    ``return_statement`` is consumed here and never forwarded to ``meth``.
    """
    # Preserve the wrapped method's name and docstring; without this every
    # decorated method would be reported as ``new``.
    @functools.wraps(meth)
    def new(self, *args, **kwargs):
        return_stmt = kwargs.pop('return_statement', False)
        if self._db_con is None:
            # Nothing to execute against; hand the statement back instead.
            return_stmt = True
        stmt = meth(self, *args, **kwargs)
        if return_stmt:
            return stmt
        rp = self._execute(stmt)
        if single_field:
            return [r[0] for r in rp]
        return rp
    return new

# Hand-written definitions of three information_schema views, registered on
# the ``ischema.ischema`` metadata. NOTE(review): presumably these exist
# because the bundled information_schema module lacks them or some of these
# columns; ``useexisting=True`` lets a same-named table already on that
# metadata be reused instead of raising — confirm against the SQLAlchemy
# version in use.

key_column_usage = sa.Table(
    "key_column_usage",
    ischema.ischema,
    sa.Column("table_schema", sa.String),
    sa.Column("table_name", sa.String),
    sa.Column("table_catalog", sa.String),
    sa.Column("column_name", sa.String),
    sa.Column("constraint_name", sa.String),
    sa.Column("constraint_catalog", sa.String),
    sa.Column("ordinal_position", sa.Integer),
    schema="information_schema",
    useexisting=True,
)

table_constraints = sa.Table(
    "table_constraints",
    ischema.ischema,
    sa.Column("table_schema", sa.String),
    sa.Column("table_name", sa.String),
    sa.Column("table_catalog", sa.String),
    sa.Column("constraint_name", sa.String),
    sa.Column("constraint_type", sa.String),
    sa.Column("constraint_schema", sa.String),
    sa.Column("constraint_catalog", sa.String),
    sa.Column("is_deferrable", sa.String),
    sa.Column("initially_deferred", sa.String),
    schema="information_schema",
    useexisting=True,
)

constraint_column_usage = sa.Table(
    "constraint_column_usage",
    ischema.ischema,
    sa.Column("table_schema", sa.String),
    sa.Column("table_name", sa.String),
    sa.Column("column_name", sa.String),
    sa.Column("constraint_name", sa.String),
    sa.Column("constraint_schema", sa.String),
    sa.Column("constraint_catalog", sa.String),
    schema="information_schema",
    useexisting=True,
)

class BaseInfoFetcher(object):
    """Base class for InfoFetchers that uses the information_schema standard.

    Each ``get*`` method builds a SQLAlchemy select against the
    ``information_schema`` views. Through the ``executable`` decorator the
    statement is executed on ``db_con`` and the result returned, or the
    statement itself is returned when ``db_con`` is None or the caller passes
    ``return_statement=True``.
    """

    def __init__(self, db_con=None, db_version=None, catalog_name=None):
        self._db_con = db_con
        self._db_version = db_version
        self._catalog_name = catalog_name

    def _execute(self, stmt):
        """Execute stmt with self._db_con if it is not None.

        :raises Exception: if no connection has been configured.
        """
        if self._db_con is not None:
            return self._db_con.execute(stmt)
        else:
            raise Exception('db_con is not set.')

    def _setCatalogName(self, name):
        self._catalog_name = name

    def _getCatalogName(self):
        """Return the configured catalog name, or look it up in the database.

        The lookup is repeated on every access while no name has been set;
        its result is not cached.
        """
        if self._catalog_name is None:
            tbl = ischema.schemata
            s = sa.select([tbl.c.catalog_name],
                distinct=True
            )
            # Assuming only one exists.
            result = self._execute(s).fetchall()
            name = result[0].catalog_name
            return name
        return self._catalog_name

    catalog_name = property(_getCatalogName, _setCatalogName)

    @executable(single_field=True)
    def getSchemaNames(self):
        """Return the schema names in the current catalog."""
        tbl = ischema.schemata
        s = sa.select([tbl.c.schema_name, ],
            tbl.c.catalog_name == self.catalog_name
        )
        return s

    @executable(single_field=True)
    def getTableNames(self, schema_name=None):
        """Return the names of base tables in ``schema_name``."""
        tbl = ischema.tables
        s = sa.select([tbl.c.table_name, ],
            sa.and_(
                tbl.c.table_schema == schema_name,
                tbl.c.table_type == 'BASE TABLE',
                tbl.c.table_catalog == self.catalog_name,
            )
        )
        return s

    # NOTE(review): unlike getTableNames this is not ``single_field=True``,
    # so it returns the raw result rather than a list of names. Left as-is
    # because changing the return shape could break callers; confirm intent.
    @executable()
    def getViewNames(self, schema_name=None):
        """Return the rows naming the views in ``schema_name``."""
        tbl = ischema.tables
        s = sa.select([tbl.c.table_name, ],
            sa.and_(
                tbl.c.table_schema == schema_name,
                tbl.c.table_type == 'VIEW',
                tbl.c.table_catalog == self.catalog_name,
            )
        )
        return s

    @executable()
    def getColumnInfo(self, table_name, schema_name=None):
        """Return column metadata for ``table_name``, ordered by position."""
        tbl = ischema.columns
        # NOTE(review): data_type is selected twice. Kept so positional
        # indexes in existing callers stay valid; the second entry may have
        # been meant to be a different column — confirm.
        fields = [tbl.c.column_name, tbl.c.is_nullable, tbl.c.data_type,
                  tbl.c.data_type, tbl.c.ordinal_position,
                  tbl.c.character_maximum_length, tbl.c.numeric_precision,
                  tbl.c.numeric_scale, tbl.c.column_default]
        s = sa.select(fields,
            sa.and_(
                tbl.c.table_schema == schema_name,
                tbl.c.table_name == table_name,
            )
        ).order_by(tbl.c.ordinal_position)
        return s

    @executable()
    def getConstraintInfo(self, table_name, schema_name=None,
                          constraint_type=None):
        """Return constraint/column rows for ``table_name``.

        Rows come from key_column_usage, outer-joined to table_constraints,
        referential_constraints, and the constraint_column_usage of the
        referenced unique constraint (``f_table_name`` / ``f_column_name``).
        They are ordered by constraint name, then column position.

        :param schema_name: when given, only rows whose table is in this
            schema are returned.
        :param constraint_type: when given, only constraints of this type
            (e.g. ``'FOREIGN KEY'``) are returned.
        """
        kcu = key_column_usage
        tc = table_constraints
        rc = ischema.ref_constraints
        ccu = constraint_column_usage
        # key_column_usage -> table_constraints (type, deferrability).
        join = kcu.outerjoin(tc,
            sa.and_(
                kcu.c.table_name == tc.c.table_name,
                kcu.c.table_schema == tc.c.table_schema,
                kcu.c.table_catalog == tc.c.table_catalog,
                kcu.c.constraint_name == tc.c.constraint_name,
                kcu.c.constraint_catalog == tc.c.constraint_catalog,
            )
        )
        # -> referential_constraints (match option, update/delete rules).
        join = join.outerjoin(rc,
            sa.and_(
                rc.c.constraint_schema == tc.c.constraint_schema,
                rc.c.constraint_name == tc.c.constraint_name,
                # Previously ``tc == tc`` (always true), so the catalog was
                # never actually matched.
                rc.c.constraint_catalog == tc.c.constraint_catalog,
            )
        )
        # -> columns of the unique constraint the foreign key references.
        # NOTE(review): for a multi-column foreign key this join is not
        # correlated by position, so each local column pairs with every
        # referenced column — confirm callers handle that.
        join = join.outerjoin(ccu,
            sa.and_(
                rc.c.unique_constraint_schema == ccu.c.constraint_schema,
                rc.c.unique_constraint_name == ccu.c.constraint_name,
                tc.c.constraint_catalog == ccu.c.constraint_catalog,
            )
        )
        fields = [tc.c.constraint_name,
                  kcu.c.column_name, tc.c.constraint_type,
                  tc.c.is_deferrable, tc.c.initially_deferred,
                  rc.c.match_option, rc.c.update_rule, rc.c.delete_rule,
                  ccu.c.constraint_name,
                  ccu.c.table_name.label('f_table_name'),
                  ccu.c.column_name.label('f_column_name'),
                  kcu.c.ordinal_position
                  ]
        s = sa.select(fields,
            sa.and_(
                kcu.c.table_name == table_name,
                kcu.c.constraint_catalog == self.catalog_name,
            ),
            from_obj=join
        ).order_by(kcu.c.constraint_name, kcu.c.ordinal_position)
        # schema_name used to be accepted but ignored, so same-named tables
        # in other schemas leaked into the result.
        if schema_name is not None:
            s = s.where(kcu.c.table_schema == schema_name)
        if constraint_type is not None:
            s = s.where(tc.c.constraint_type == constraint_type)
        return s

    @executable()
    def getForeignKeyInfo(self, table_name, schema_name=None):
        """Return getConstraintInfo rows restricted to FOREIGN KEY constraints."""
        # Ask for the bare statement; this method's own decorator decides
        # whether to execute it.
        return self.getConstraintInfo(table_name, schema_name,
                                      constraint_type='FOREIGN KEY',
                                      return_statement=True)
        

class PostgresInfoFetcher(BaseInfoFetcher):
    """InfoFetcher for PostgreSQL; currently adds nothing to the base class."""

    def __init__(self, *args, **kwargs):
        super(PostgresInfoFetcher, self).__init__(*args, **kwargs)


class MSSQLInfoFetcher(BaseInfoFetcher):
    """InfoFetcher for MS SQL Server; currently adds nothing to the base class."""

    def __init__(self, *args, **kwargs):
        super(MSSQLInfoFetcher, self).__init__(*args, **kwargs)


def createInfoFetcher(db_con, db_type=None, db_version=None):
    """Return the proper InfoFetcher instance.

    Currently always returns a ``BaseInfoFetcher``; ``db_type`` is accepted
    but not yet used to choose a database-specific subclass.

    :raises ValueError: if neither ``db_type`` nor ``db_version`` is given
        and ``db_con`` is not an SQLAlchemy ``Connectable``.
    """
    if db_type is None and db_version is None:
        # Determine from db_con.
        if not isinstance(db_con, sa.engine.Connectable):
            # Call syntax works on both Python 2 and 3; the old
            # ``raise ValueError, msg`` form is a SyntaxError on Python 3.
            raise ValueError('db_con must be type Connectable')
    # TODO: presumably pick a database-specific fetcher (e.g.
    # PostgresInfoFetcher) here once detection is implemented.
    return BaseInfoFetcher(db_con, db_version, catalog_name=None)
