package io.gitlab.rujal_sh.db.config;

import com.zaxxer.hikari.HikariDataSource;
import io.gitlab.rujal_sh.annotation.StrategyType;
import io.gitlab.rujal_sh.annotation.TenantHolder;
import io.gitlab.rujal_sh.annotation.internal.DataSourceConfiguration;
import io.gitlab.rujal_sh.commons.CustomEntityManagerFactoryBean;
import io.gitlab.rujal_sh.commons.JpaHibernateProperties;
import io.gitlab.rujal_sh.commons.PersistenceJPAConfig;
import io.gitlab.rujal_sh.commons.utils.Constants;
import io.gitlab.rujal_sh.db.config.domain.DataSourceConfig;
import lombok.RequiredArgsConstructor;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.hibernate.cfg.Environment;
import org.hibernate.context.spi.CurrentTenantIdentifierResolver;
import org.hibernate.engine.jdbc.connections.spi.AbstractDataSourceBasedMultiTenantConnectionProviderImpl;
import org.hibernate.engine.jdbc.connections.spi.MultiTenantConnectionProvider;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.boot.jdbc.DataSourceBuilder;
import org.springframework.data.repository.CrudRepository;
import org.springframework.orm.jpa.LocalContainerEntityManagerFactoryBean;
import org.springframework.orm.jpa.vendor.HibernateJpaVendorAdapter;
import org.springframework.web.context.support.GenericWebApplicationContext;

import javax.sql.DataSource;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.*;
import java.util.stream.Collectors;


/**
 * AspectJ aspect that intercepts {@code save(..)} on every Spring Data {@link CrudRepository} and, for entities whose
 * class is annotated with {@link TenantHolder}, provisions per-tenant storage around the save.
 *
 * <p>For such an entity it:
 * <ol>
 *   <li>reads the entity's {@link DataSourceConfig} through the superclass getter of the field annotated with
 *       {@link DataSourceConfiguration};</li>
 *   <li>derives the tenant name from the entity's {@code getName()} (falling back to {@code getUsername()});</li>
 *   <li>if no entity already stored in the same repository has a config with the same username, either tests the
 *       supplied connection or, when {@code generate} is set, creates a schema or database depending on
 *       {@code Constants.strategy};</li>
 *   <li>performs the save, then briefly registers and instantiates a tenant-specific
 *       {@link CustomEntityManagerFactoryBean} (presumably so Hibernate creates the tenant tables — confirm the
 *       configured ddl-auto setting).</li>
 * </ol>
 *
 * <p>NOTE(review): {@code persistenceJPAConfig} and {@code dataSourceConfig} are instance fields rewritten on every
 * intercepted tenant save, so concurrent saves through this singleton aspect may interfere with each other.
 */
@RequiredArgsConstructor
//@Component
@Aspect
public class DBTenantHelper extends SqlExecutions {

    private final DataSource dataSource;
    private final BeanFactory beanFactory;
    private final GenericWebApplicationContext context;
    private final JpaHibernateProperties jpaHibernateProperties;
    private final CurrentTenantIdentifierResolver currentTenantIdentifierResolver;
    private final Optional<TenantDataSource> optionalTenantDataSource;
    private final MultiTenantConnectionProvider multiTenantConnectionProvider;

    private PersistenceJPAConfig persistenceJPAConfig;
    private DataSourceConfig dataSourceConfig;

    /**
     * Around advice for {@code CrudRepository.save(entity)}.
     *
     * <p>Entities that are null or lack {@link TenantHolder} are saved untouched. For tenant entities the tenant's
     * schema or database is provisioned first (unless an entity already in the same repository uses a config with the
     * same username), the entity is saved, and the tenant tables are then created.
     *
     * @param proceedingJoinPoint the intercepted {@code save} invocation
     * @param entity              the entity passed to {@code save}
     * @return the value returned by the intercepted {@code save} (the saved entity)
     * @throws Throwable anything thrown by reflection, provisioning, or the underlying {@code save}
     */
    @Around("execution(* org.springframework.data.repository.CrudRepository+.save(..)) && args(entity)")
    public Object test(ProceedingJoinPoint proceedingJoinPoint, Object entity) throws Throwable {
        if (entity == null || !entity.getClass().isAnnotationPresent(TenantHolder.class)) {
            return proceedingJoinPoint.proceed();
        }

        persistenceJPAConfig = new PersistenceJPAConfig(jpaHibernateProperties, dataSource,
                multiTenantConnectionProvider, currentTenantIdentifierResolver, context);

        DataSourceConfig dbConfig = getDataSourceConfig(entity);
        String userName = resolveTenantName(entity);

        CrudRepository<?, ?> crudRepository = (CrudRepository<?, ?>) proceedingJoinPoint.getThis();
        if (!isAlreadyConfigured(crudRepository, dbConfig)) {
            performDatabaseAction(entity, dbConfig, userName, Constants.strategy);
        }

        // Hand back what save(..) actually returned; returning the join point breaks every caller using the result.
        Object saved = proceedingJoinPoint.proceed();
        createTables(userName, getValue(entity, "dataSourceConfig", DataSourceConfig.class));
        return saved;
    }

    /**
     * Reads the tenant name from the entity's own {@code getName()} or, if that method is not declared on the
     * entity class, from {@code getUsername()}.
     *
     * @return the getter's result converted with {@link String#valueOf(Object)} (so a null value becomes "null")
     * @throws ReflectiveOperationException if neither getter is declared or the invocation fails
     */
    private String resolveTenantName(Object entity) throws ReflectiveOperationException {
        Method nameGetter;
        try {
            nameGetter = entity.getClass().getDeclaredMethod("getName");
        } catch (NoSuchMethodException noGetName) {
            nameGetter = entity.getClass().getDeclaredMethod("getUsername");
        }
        return String.valueOf(nameGetter.invoke(entity));
    }

    /**
     * Returns true if any entity already stored in {@code repository} carries a config whose username matches
     * {@code dbConfig}'s (see {@link #checkDbConfiguration}). Stops at the first match.
     */
    private boolean isAlreadyConfigured(CrudRepository<?, ?> repository, DataSourceConfig dbConfig) {
        for (Object existing : repository.findAll()) {
            if (checkDbConfiguration(existing, dbConfig)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Provisions storage for a tenant whose config was not found in the repository.
     *
     * <p>When {@code generate} is not {@code TRUE}, the supplied connection settings are only tested. Otherwise a
     * config derived from the master data source replaces the entity's config, and a schema or a database (plus a
     * registered tenant data source) is created depending on {@code strategy}.
     *
     * @throws NullPointerException  if {@code dbConfig} is null
     * @throws IllegalStateException if the connection test fails, or the DATABASE strategy is used without a
     *                               {@code TenantDataSource} bean
     * @throws SQLException          if schema/database creation fails at the connection level
     */
    private void performDatabaseAction(Object entity, DataSourceConfig dbConfig, String userName, StrategyType strategy) throws SQLException {
        Objects.requireNonNull(dbConfig, "DataSourceConfig of a @TenantHolder entity must not be null");
        if (!Boolean.TRUE.equals(dbConfig.getGenerate())) {
            testDbConnection(dbConfig);
            return;
        }

        DataSourceConfig tenantConfig = getTenantWiseDataSourceConfig(userName);
        dataSourceConfig = tenantConfig;
        switch (strategy) {
            case SCHEMA:
                createTenantSchema(userName);
                break;
            case DATABASE: {
                createDatabase(userName);
                TenantDataSource tenantDataSource = optionalTenantDataSource.orElseThrow(
                        () -> new IllegalStateException("DATABASE strategy requires a TenantDataSource bean"));
                tenantDataSource.dataSources.put(tenantConfig.getName(), createDataSource(tenantConfig));
                break;
            }
        }

        setValue(entity, "dataSourceConfig", tenantConfig, DataSourceConfig.class);
    }

    /**
     * Creates schema {@code tenantName} on the master data source if it does not exist yet.
     *
     * @throws IllegalArgumentException if {@code tenantName} is not a plain identifier
     */
    private void createTenantSchema(String tenantName) throws SQLException {
        String schema = requireSafeIdentifier(tenantName);
        // Identifiers cannot be bound as JDBC parameters, hence the validation above before formatting.
        try (Connection connection = dataSource.getConnection();
             Statement statement = connection.createStatement()) {
            statement.execute(String.format("CREATE SCHEMA IF NOT EXISTS %s", schema));
        }
    }

    /**
     * Builds a config for tenant {@code tenantId} that reuses the master Hikari data source's driver and credentials
     * and points its URL at a database named after the tenant (see {@link #generateUrl}).
     */
    private DataSourceConfig getTenantWiseDataSourceConfig(String tenantId) {
        HikariDataSource master = (HikariDataSource) dataSource;
        return DataSourceConfig.builder()
                .driverClassName(master.getDriverClassName())
                .name(tenantId)
                .username(master.getUsername())
                .generate(true)
                .password(master.getPassword())
                .initialize(true)
                .url(generateUrl(tenantId, master.getJdbcUrl()))
                .build();
    }

    /**
     * Replaces the last path segment (the database name) of {@code jdbcUrl} with {@code tenantId}, keeping any
     * {@code ?query} part. Example: {@code jdbc:mysql://h:3306/master?useSSL=false} becomes
     * {@code jdbc:mysql://h:3306/<tenantId>?useSSL=false}; a trailing {@code /} gets the tenant appended.
     *
     * @throws IllegalArgumentException if the URL (before any {@code ?}) contains no {@code /}
     */
    private String generateUrl(String tenantId, String jdbcUrl) {
        int queryStart = jdbcUrl.indexOf('?');
        String base = queryStart >= 0 ? jdbcUrl.substring(0, queryStart) : jdbcUrl;
        String query = queryStart >= 0 ? jdbcUrl.substring(queryStart) : "";
        int lastSlash = base.lastIndexOf('/');
        if (lastSlash < 0) {
            throw new IllegalArgumentException("Cannot derive a tenant database URL from " + jdbcUrl);
        }
        return base.substring(0, lastSlash + 1) + tenantId + query;
    }

    /**
     * Verifies that the driver class can be loaded and a connection can be opened with {@code dbConfig}; the probe
     * connection is closed immediately.
     *
     * @throws IllegalStateException (with the original cause) if the driver is missing or the connection fails
     */
    private void testDbConnection(DataSourceConfig dbConfig) {
        try {
            Class.forName(dbConfig.getDriverClassName());
            try (Connection probe = DriverManager.getConnection(dbConfig.getUrl(), dbConfig.getUsername(), dbConfig.getPassword())) {
                // Opening the connection is the whole test.
            }
        } catch (ClassNotFoundException | SQLException e) {
            throw new IllegalStateException(String.format("Driver connection to URL %s failed", dbConfig.getUrl()), e);
        }
    }

    /**
     * Returns true if the config carried by {@code entityRepo} (an entity read from the repository) has the same
     * username, ignoring case, as {@code dbConfig}. Missing configs or usernames never match.
     */
    private boolean checkDbConfiguration(Object entityRepo, DataSourceConfig dbConfig) {
        if (dbConfig == null) {
            return false;
        }
        DataSourceConfig existingConfig = getDataSourceConfig(entityRepo);
        return existingConfig != null
                && existingConfig.getUsername() != null
                && existingConfig.getUsername().equalsIgnoreCase(dbConfig.getUsername());
    }

    /**
     * Invokes the superclass getter of the field annotated with {@link DataSourceConfiguration} on {@code entity}.
     *
     * @return the config, or null if the getter returned null
     * @throws RuntimeException wrapping (as cause) any lookup, invocation, or cast failure
     */
    private DataSourceConfig getDataSourceConfig(Object entity) {
        try {
            String fieldName = getFieldNameWithAnnotationDataSourceConfiguration(entity);
            Method getter = entity.getClass().getSuperclass().getDeclaredMethod("get" + fieldName);
            return (DataSourceConfig) getter.invoke(entity);
        } catch (Exception e) {
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /**
     * Issues {@code create database <tenantName>} through a throw-away data source built from the master Hikari
     * settings; does nothing for {@code public}. A failing {@code create database} statement is ignored (treated as
     * "already exists"), but failing to connect is propagated. The temporary pool is closed afterwards.
     *
     * @throws IllegalArgumentException if {@code tenantName} is not a plain identifier
     */
    private void createDatabase(String tenantName) throws SQLException {
        if (tenantName.equalsIgnoreCase("public")) {
            return;
        }
        String databaseName = requireSafeIdentifier(tenantName);

        HikariDataSource master = (HikariDataSource) dataSource;
        DataSource ds = DataSourceBuilder.create()
                .driverClassName(master.getDriverClassName())
                .username(master.getUsername())
                .password(master.getPassword())
                .url(master.getJdbcUrl())
                .build();
        try (Connection connection = ds.getConnection();
             Statement statement = connection.createStatement()) {
            try {
                statement.execute(String.format("create database %s", databaseName));
            } catch (SQLException ignored) {
                // Best effort: most likely the database already exists.
            }
        } finally {
            // DataSourceBuilder creates a pool; close it so its threads/connections are not leaked.
            if (ds instanceof HikariDataSource) {
                ((HikariDataSource) ds).close();
            }
        }
    }

    /**
     * Guards DDL that must embed a name as an unquoted identifier: only letters, digits, {@code _} and {@code $} are
     * accepted, which rules out whitespace, quotes and statement separators.
     *
     * @return {@code name} unchanged
     * @throws IllegalArgumentException if {@code name} is null, empty, or contains other characters
     */
    private static String requireSafeIdentifier(String name) {
        if (name == null || name.isEmpty()
                || !name.codePoints().allMatch(cp -> Character.isLetterOrDigit(cp) || cp == '_' || cp == '$')) {
            throw new IllegalArgumentException("Invalid tenant identifier: '" + name + "'");
        }
        return name;
    }

    /**
     * Finds the single field of {@code entity}'s superclass annotated with {@link DataSourceConfiguration} and
     * returns its name with the first letter upper-cased (ready to be prefixed with "get").
     *
     * @throws RuntimeException if the superclass declares no fields, or not exactly one annotated field
     */
    private String getFieldNameWithAnnotationDataSourceConfiguration(Object entity) {
        Field[] fields = entity.getClass().getSuperclass().getDeclaredFields();
        if (fields.length == 0) {
            throw new RuntimeException("Consider extending class 'DataSourceComponent' in class annotated with @TenantHolder");
        }

        Field annotated = null;
        int count = 0;
        for (Field f : fields) {
            if (f.isAnnotationPresent(DataSourceConfiguration.class)) {
                annotated = f;
                count++;
            }
        }
        if (count == 0) {
            throw new RuntimeException("No field annotated with '@DataSourceConfiguration' found; consider extending class 'DataSourceComponent' in class annotated with @TenantHolder");
        }
        if (count > 1) {
            throw new RuntimeException("Fields cannot be annotated with '@DataSourceConfiguration' more than once");
        }
        return capitalize(annotated.getName());
    }

    /** Upper-cases the first character in a locale-independent way (Lombok/JavaBeans getter naming). */
    private static String capitalize(String name) {
        return Character.toUpperCase(name.charAt(0)) + name.substring(1);
    }

    /**
     * Registers a bean named {@code "entityManagerFactory" + tenantId} whose instance is produced by
     * {@link #customEntityManagerFactoryBean}.
     *
     * @return the registered bean name
     */
    private String createBean(String tenantId, DataSourceConfig dbConfig) {
        String beanName = "entityManagerFactory" + tenantId;
        context.registerBean(beanName, CustomEntityManagerFactoryBean.class, () -> customEntityManagerFactoryBean(dbConfig, tenantId));
        return beanName;
    }

    /**
     * For every tenant except {@code public}, registers a tenant entity-manager-factory bean, instantiates it by name
     * (presumably triggering Hibernate schema generation — confirm the ddl-auto setting) and always removes the bean
     * definition again, even if instantiation fails.
     */
    public void createTables(String userName, DataSourceConfig dbConfig) {
        if (userName.equalsIgnoreCase("public")) {
            return;
        }
        String beanName = createBean(userName, dbConfig);
        try {
            // Look up by name: a lookup by type could resolve to another CustomEntityManagerFactoryBean or fail as ambiguous.
            beanFactory.getBean(beanName);
        } finally {
            context.removeBeanDefinition(beanName);
        }
    }

    /**
     * Builds the tenant entity-manager-factory bean, starting from the JPA properties of
     * {@code persistenceJPAConfig.entityManagerFactory()} (that property map is modified in place).
     *
     * <p>SCHEMA strategy: sets {@code hibernate.default_schema} to {@code tenantId} and uses the master data source.
     * Otherwise: uses a data source built from {@code dbConfig}, sets the Hibernate URL (plus user, password and
     * driver when {@code generate} is not {@code TRUE}) and installs a multi-tenant connection provider that always
     * returns that data source.
     */
    CustomEntityManagerFactoryBean customEntityManagerFactoryBean(DataSourceConfig dbConfig, String tenantId) {
        LocalContainerEntityManagerFactoryBean template = persistenceJPAConfig.entityManagerFactory();
        Map<String, Object> jpaPropertyMap = template.getJpaPropertyMap();
        CustomEntityManagerFactoryBean factoryBean = new CustomEntityManagerFactoryBean();

        if (Constants.strategy == StrategyType.SCHEMA) {
            jpaPropertyMap.put("hibernate.default_schema", tenantId);
            factoryBean.setDataSource(dataSource);
        } else {
            DataSource tenantDataSource = DataSourceBuilder.create()
                    .driverClassName(dbConfig.getDriverClassName())
                    .username(dbConfig.getUsername())
                    .password(dbConfig.getPassword())
                    .url(dbConfig.getUrl())
                    .build();
            MultiTenantConnectionProvider tenantConnectionProvider = new AbstractDataSourceBasedMultiTenantConnectionProviderImpl() {
                @Override
                protected DataSource selectAnyDataSource() {
                    return selectDataSource(tenantId);
                }

                @Override
                protected DataSource selectDataSource(String tenantIdentifier) {
                    return tenantDataSource;
                }
            };

            if (!Boolean.TRUE.equals(dbConfig.getGenerate())) {
                jpaPropertyMap.put(Environment.USER, dbConfig.getUsername());
                jpaPropertyMap.put(Environment.PASS, dbConfig.getPassword());
                jpaPropertyMap.put(Environment.DRIVER, dbConfig.getDriverClassName());
            }
            jpaPropertyMap.put(Environment.URL, dbConfig.getUrl());
            jpaPropertyMap.put(Environment.MULTI_TENANT_CONNECTION_PROVIDER, tenantConnectionProvider);
            factoryBean.setDataSource(tenantDataSource);
        }
        factoryBean.setJpaPropertyMap(jpaPropertyMap);

        factoryBean.setJpaVendorAdapter(new HibernateJpaVendorAdapter());
        factoryBean.setPackagesToScan(Constants.basePackages);
        return factoryBean;
    }

    /**
     * Invokes the superclass getter {@code get<FieldName>} on {@code entity} and casts the result to
     * {@code classType}.
     *
     * @throws RuntimeException wrapping (as cause) any lookup, invocation, or cast failure
     */
    private <T> T getValue(Object entity, String fieldName, Class<T> classType) {
        try {
            Method getter = entity.getClass().getSuperclass().getDeclaredMethod("get" + capitalize(fieldName));
            return classType.cast(getter.invoke(entity));
        } catch (Exception e) {
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /**
     * Writes {@code value} (which must be null or a {@code classType}) directly into the superclass field
     * {@code fieldName} of {@code entity}, bypassing access checks.
     *
     * @throws RuntimeException wrapping (as cause) any lookup, type, or access failure
     */
    private <T> void setValue(Object entity, String fieldName, Object value, Class<T> classType) {
        try {
            Field field = entity.getClass().getSuperclass().getDeclaredField(fieldName);
            field.setAccessible(true);
            field.set(entity, classType.cast(value));
        } catch (Exception e) {
            throw new RuntimeException(e.getMessage(), e);
        }
    }
}
