package cn.bestwu.generator.database

import java.sql.DatabaseMetaData
import java.sql.ResultSet

/**
 * Iterates over every remaining row of this [ResultSet], invoking [rs] once per row with the
 * result set itself as receiver (so the row is read via `getString(...)`, `getInt(...)`, ...),
 * and closes the result set afterwards — also when [rs] throws.
 *
 * @param rs per-row callback; the receiver is positioned on the current row
 * @author Peter Wu
 */
fun ResultSet.use(rs: ResultSet.() -> Unit) {
    try {
        while (next()) {
            rs(this)
        }
    } finally {
        // Like the stdlib `use`, always release the result set's JDBC resources, even on failure.
        close()
    }
}

/**
 * Database metadata reader: wraps a JDBC [java.sql.DatabaseMetaData], delegating every standard
 * call to it, and adds helpers that load table, column and primary-key structures.
 *
 * @param metaData the JDBC metadata to read from and delegate to
 * @param catalog catalog passed to every lookup; `null` means the catalog is not used to narrow the search (JDBC convention)
 * @param schema default schema pattern, used when a table name is not qualified as `schema.table`
 * @author Peter Wu
 */
class DatabaseMetaData(private val metaData: DatabaseMetaData, private val catalog: String? = null, private val schema: String? = null) : DatabaseMetaData by metaData {

    /**
     * Lists all tables under the configured catalog and schema.
     *
     * @return table names, in the order returned by the JDBC driver
     */
    fun tableNames(): List<String> {
        val tableNames = mutableListOf<String>()
        getTables(catalog, schema, null, null).use { tableNames.add(getString("TABLE_NAME")) }
        return tableNames
    }

    /**
     * Loads table structures (type, remarks, primary keys and columns).
     *
     * Each name may be qualified as `schema.table`; the given schema then replaces the configured
     * one for that lookup. With no names, every table under the configured catalog and schema is loaded.
     *
     * @param tableNames table names (optionally `schema.table`); empty means all tables
     * @return the loaded tables
     * @throws RuntimeException if no table was found
     */
    fun tables(vararg tableNames: String): List<Table> {
        val tables = mutableListOf<Table>()
        println("查询：${if (tableNames.isEmpty()) "所有表" else tableNames.joinToString(",")} 数据结构")
        if (tableNames.isNotEmpty()) {
            tableNames.forEach { requestedName ->
                val qualified = requestedName.contains('.')
                val (tableSchema, tablePattern) = splitQualifiedName(requestedName)
                getTables(catalog, tableSchema, tablePattern, null).use {
                    val tableName = getString("TABLE_NAME")
                    // getTables treats the requested name as a LIKE pattern (`_`, `%`) and may match
                    // several tables, so look up keys/columns by the name the driver actually returned,
                    // keeping the requested schema prefix.
                    val lookupName = if (qualified) "$tableSchema.$tableName" else tableName
                    tables.add(readTable(tableName, lookupName))
                }
            }
        } else {
            getTables(catalog, schema, null, null).use {
                val tableName = getString("TABLE_NAME")
                tables.add(readTable(tableName, tableName))
            }
        }

        if (tables.isEmpty()) {
            throw RuntimeException("未在数据库${databaseProductName}中找到${if (tableNames.isEmpty()) "所有表" else tableNames.joinToString(",")}")
        }
        return tables
    }

    /**
     * Loads the columns of a table.
     *
     * @param tableName table name, optionally qualified as `schema.table`
     * @param columnNames column name patterns to load; empty means all columns
     * @return the columns, in the order returned by the JDBC driver
     */
    fun columns(tableName: String, vararg columnNames: String): List<Column> {
        val columns = mutableListOf<Column>()
        val (tableSchema, bareTableName) = splitQualifiedName(tableName)
        if (columnNames.isEmpty()) {
            getColumns(catalog, tableSchema, bareTableName, null).use { fillColumn(columns) }
        } else {
            columnNames.forEach { columnName ->
                getColumns(catalog, tableSchema, bareTableName, columnName).use { fillColumn(columns) }
            }
        }
        return columns
    }

    /**
     * Reads the current row of a `getColumns` result set and appends it to [fields] as a [Column].
     */
    private fun ResultSet.fillColumn(fields: MutableList<Column>) {
        // In this member extension `this` is the ResultSet, so `this.metaData` is the result set's own
        // metadata (not the class's DatabaseMetaData). IS_AUTOINCREMENT / IS_GENERATEDCOLUMN are only
        // read when the driver actually returns those columns.
        val resultMeta = this.metaData
        val returnedColumns = (1..resultMeta.columnCount).map { resultMeta.getColumnName(it) }
        val supportsIsAutoIncrement = "IS_AUTOINCREMENT" in returnedColumns
        val supportsIsGeneratedColumn = "IS_GENERATEDCOLUMN" in returnedColumns

        val columnName = getString("COLUMN_NAME")
        val typeName = getString("TYPE_NAME")
        val dataType = getInt("DATA_TYPE")
        // 1 == java.sql.DatabaseMetaData.columnNullable
        val nullable = getInt("NULLABLE") == 1
        val scale = getInt("DECIMAL_DIGITS")
        val defaultVal = getString("COLUMN_DEF")?.trim('\'')?.trim()
        val length = getInt("COLUMN_SIZE")
        // Strip tabs/line breaks (must be a Regex: the String overload of replace is literal),
        // then turn double quotes into single quotes.
        val comment = getString("REMARKS")?.replace(LINE_BREAKS_AND_TABS, "")?.replace("\"", "'")?.trim()
                ?: ""
        val tableCat = getString("TABLE_CAT")
        val tableSchem = getString("TABLE_SCHEM")
        val column = Column(tableCat, tableSchem, columnName, typeName, dataType, scale, length, comment, nullable, defaultVal)

        if (supportsIsAutoIncrement) {
            column.autoIncrement = "YES" == getString("IS_AUTOINCREMENT")
        }
        if (supportsIsGeneratedColumn) {
            column.generatedColumn = "YES" == getString("IS_GENERATEDCOLUMN")
        }
        fields.add(column)
    }

    /**
     * Loads the primary-key columns of a table.
     *
     * @param tableName table name, optionally qualified as `schema.table`
     * @return the primary-key columns, in the order returned by `getPrimaryKeys`
     */
    fun primaryKeys(tableName: String): List<Column> {
        val primaryKeys = mutableListOf<Column>()
        val (tableSchema, bareTableName) = splitQualifiedName(tableName)
        getPrimaryKeys(catalog, tableSchema, bareTableName).use {
            primaryKeys.addAll(columns(tableName, getString("COLUMN_NAME")))
        }
        return primaryKeys
    }

    /**
     * Builds a [Table] from the current `getTables` row; keys and columns are loaded via [lookupName].
     */
    private fun ResultSet.readTable(tableName: String, lookupName: String): Table =
            Table(tableName, getString("TABLE_TYPE"), getString("REMARKS")
                    ?: "", primaryKeys(lookupName), columns(lookupName))

    /**
     * Splits `schema.table` into (schema, table) using the first two dot-separated parts; an
     * unqualified name is paired with the configured [schema].
     */
    private fun splitQualifiedName(name: String): Pair<String?, String> {
        if (!name.contains('.')) {
            return schema to name
        }
        val parts = name.split('.')
        return parts[0] to parts[1]
    }

    private companion object {
        /** Tab, line feed and carriage return; stripped from column remarks. */
        val LINE_BREAKS_AND_TABS = Regex("[\t\n\r]")
    }
}