package website.dachuan.migration.tool;

import lombok.SneakyThrows;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.*;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.update.Update;
import net.sf.jsqlparser.statement.update.UpdateSet;

import java.io.StringReader;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * Rewrites SQL statements so that they only touch the rows of a single tenant.
 *
 * <p>SELECT, UPDATE and DELETE statements get a {@code <tenantCol> = '<tenantVal>'} predicate
 * AND-ed into their WHERE clause (and into JOIN ... ON conditions and EXISTS / IN sub-queries),
 * unless the condition already references the tenant column. INSERT statements get the tenant
 * column appended to their column list and the tenant value appended to the inserted rows.
 * Other statement types are returned unchanged.
 *
 * @author yqb22
 */
public class SqlParserTenantUtil {

    /**
     * Parses {@code sql} and returns it with tenant restrictions added.
     *
     * @param sql       SQL statement to rewrite; must not be {@code null} or empty
     * @param tenantCol name of the tenant column
     * @param tenantVal tenant value, written into the SQL as a string literal
     * @return the rewritten SQL, or {@code sql} itself for statement types other than
     *         INSERT / UPDATE / SELECT / DELETE
     * @throws RuntimeException if {@code sql} is empty, cannot be parsed, or uses a form this
     *                          class does not support
     */
    public static String addTenantInfo(String sql, String tenantCol, String tenantVal) {
        if (sql == null || sql.length() == 0) {
            throw new RuntimeException("jsql parse sql 不能为空！");
        }
        Statement statement;
        try {
            statement = CCJSqlParserUtil.parse(new StringReader(sql));
        } catch (JSQLParserException e) {
            // Keep the parser exception as the cause so the parse-error details are not lost.
            throw new RuntimeException(e.getMessage(), e);
        }
        if (statement instanceof Insert) {
            return processInsertStatement((Insert) statement, tenantCol, tenantVal).toString();
        } else if (statement instanceof Update) {
            return processUpdateStatement((Update) statement, tenantCol, tenantVal).toString();
        } else if (statement instanceof Select) {
            return processSelectStatement((Select) statement, tenantCol, tenantVal).toString();
        } else if (statement instanceof Delete) {
            return processDeleteStatement((Delete) statement, tenantCol, tenantVal).toString();
        }
        return sql;
    }

    /**
     * Adds the tenant predicate to the WHERE clause of a DELETE statement (in place).
     *
     * @return the same {@code delete} instance
     */
    public static Delete processDeleteStatement(Delete delete, String tenantCol, String tenantVal) {
        delete.setWhere(processExpression(delete.getTable().getAlias(), delete.getWhere(), tenantCol, tenantVal));
        return delete;
    }

    /**
     * Adds the tenant column and value to an INSERT statement (in place).
     *
     * <p>Supported forms:
     * <ol>
     *   <li>INSERT ... SELECT with an explicit insert column list and explicit select items, e.g.
     *       {@code insert into i_table (c_1, c_2) select c_1, c_2 from o_table} or
     *       {@code insert into i_table (c_1, c_2) with t as (select * from user where user.user_name = 'test') select t.c_1, t.c_2 from t}.
     *       A missing column list or a wildcard select item is rejected. The source query is
     *       restricted to the tenant as well.</li>
     *   <li>{@code insert into t_table (c_1, c_2) value (c_1, c_2)}</li>
     *   <li>{@code insert into t_table (c_1, c_2) values (c_1, c_2), (c_11, c_21), (c_21, c_22)}</li>
     * </ol>
     * {@code insert into t_table set c_1 = 'c_1', c_2 = 'c_2'} is not supported.
     *
     * <p>If the column list already contains the tenant column, the column list and the inserted
     * rows are left unchanged.
     *
     * @return the same {@code insert} instance
     * @throws RuntimeException for unsupported INSERT forms
     */
    @SuppressWarnings(value = "unchecked")
    public static Insert processInsertStatement(Insert insert, String tenantCol, String tenantVal) {
        List<Column> columns = insert.getColumns();
        boolean hasColumnList = columns != null && !columns.isEmpty();
        boolean includeTenantIdColumn = hasColumnList
                && columns.stream().anyMatch(column -> isTenantColumn(column.getColumnName(), tenantCol));
        if (hasColumnList && !includeTenantIdColumn) {
            // Unqualified on purpose: a table-qualified name ("t.col") is not accepted in an
            // INSERT column list by every database, and user-written columns are unqualified too.
            columns.add(new Column(tenantCol));
        }

        Select select = insert.getSelect();
        if (select == null) {
            throw new RuntimeException("sql:" + insert + "，当前系统不支持解析！");
        }
        if (select instanceof PlainSelect) {
            // Form 1: INSERT ... SELECT. Insert columns and select items must both be explicit.
            PlainSelect ps = (PlainSelect) select;
            if (!hasColumnList) {
                throw new RuntimeException("insert select 语句需要明确插入列！");
            }
            // Only real wildcards ("*", "t.*") make the column mapping unclear; "a * b" or count(*) do not.
            boolean unClear = ps.getSelectItems().stream()
                    .map(item -> item.toString().trim())
                    .anyMatch(text -> text.equals("*") || text.endsWith(".*"));
            if (unClear) {
                throw new RuntimeException("insert select 语句需要明确查询列！");
            }
            if (!includeTenantIdColumn) {
                // Select the tenant value itself (as the VALUES branch does). This works with joins,
                // aggregates and FROM-less selects, and cannot be ambiguous.
                ps.getSelectItems().add(SelectItem.from(new StringValue(tenantVal), null));
            }
            if (columns.size() != ps.getSelectItems().size()) {
                throw new RuntimeException("insert select 语句需要明确列且查询列与插入列需要个数相同！");
            }
            // Restrict the source query to the tenant as well.
            processSelectStatement(ps, tenantCol, tenantVal);
        } else if (select instanceof Values) {
            // Forms 2 and 3: VALUE / VALUES rows.
            ExpressionList<?> rows = ((Values) select).getExpressions();
            if (!includeTenantIdColumn && rows != null) {
                // NOTE(review): without a column list nothing was added to the columns, yet the value is
                // still appended; presumably this relies on the tenant column being the table's last column.
                if (rows instanceof ParenthesedExpressionList) {
                    // A single parenthesised list is treated as one row.
                    ((ParenthesedExpressionList<Expression>) rows).addExpression(new StringValue(tenantVal));
                } else {
                    // Several rows: append the tenant value to every parenthesised row.
                    for (Expression row : rows) {
                        if (row instanceof ParenthesedExpressionList) {
                            ((ParenthesedExpressionList<Expression>) row).addExpression(new StringValue(tenantVal));
                        }
                    }
                }
            }
        } else {
            throw new RuntimeException("insert select 语句中的查询方式，当前系统不支持！");
        }
        return insert;
    }

    /**
     * Adds tenant restrictions to an UPDATE statement (in place).
     *
     * <p>Sub-selects used as SET values are restricted to the tenant, and the WHERE clause
     * (including EXISTS / IN sub-queries) gets the tenant predicate.
     *
     * <p>Examples:
     * <pre>
     * 1.1 single table, one column
     *     update stu t set t.NAME = 'mike' where t.ID = '1';
     * 1.2 single table, several columns
     *     update stu t set t.NAME = 'mike', t.SEX = '1' where t.ID = '2';
     * 2.1 update from another table, one column
     *     update stu t set t.NAME = (select t1.NAME from stu1 t1 where t1.ID = t.ID)
     *     where exists(select 1 from stu1 t2 where t2.ID = t.ID);
     * 2.2 update from another table, several columns
     *     update stu t set (t.NAME, t.SEX) = (select t1.NAME, t1.SEX from stu1 t1 where t1.ID = t.ID)
     *     where exists(select 1 from stu1 t2 where t2.ID = t.ID);
     * </pre>
     * When updating from another table, remember the exists(...) condition. Otherwise rows of
     * stu that have no match in stu1 are updated to NULL.
     *
     * @return the same {@code update} instance
     */
    public static Update processUpdateStatement(Update update, String tenantCol, String tenantVal) {
        Table table = update.getTable();
        Alias alias = table.getAlias();
        // Sub-selects on the right-hand side of SET assignments must be restricted too.
        for (UpdateSet updateSet : update.getUpdateSets()) {
            processUpdateSet(updateSet, tenantCol, tenantVal);
        }
        // Tenant predicate on WHERE; EXISTS / IN sub-queries are handled by processExpression.
        update.setWhere(processExpression(alias, update.getWhere(), tenantCol, tenantVal));
        return update;
    }

    /**
     * Adds tenant predicates to a SELECT (in place), including its WITH items, every branch of a
     * set operation (UNION, ...) and the inner query of a parenthesised select.
     *
     * @return the same {@code select} instance
     */
    public static Select processSelectStatement(Select select, String tenantCol, String tenantVal) {
        List<WithItem> withItems = select.getWithItemsList();
        if (withItems != null) {
            for (WithItem withItem : withItems) {
                Select withSelect = withItem.getSelect();
                if (withSelect instanceof ParenthesedSelect) {
                    processSelectStatement(((ParenthesedSelect) withSelect).getSelect(), tenantCol, tenantVal);
                }
            }
        }
        if (select instanceof PlainSelect) {
            processPlainSelect((PlainSelect) select, tenantCol, tenantVal);
        } else if (select instanceof SetOperationList) {
            // UNION / INTERSECT / EXCEPT: every branch must be restricted, otherwise rows of
            // other tenants leak through the unfiltered branches.
            for (Select branch : ((SetOperationList) select).getSelects()) {
                processSelectStatement(branch, tenantCol, tenantVal);
            }
        } else if (select instanceof ParenthesedSelect) {
            processSelectStatement(((ParenthesedSelect) select).getSelect(), tenantCol, tenantVal);
        }
        return select;
    }

    /** Adds tenant predicates to the JOIN conditions and the WHERE clause of one plain SELECT. */
    private static void processPlainSelect(PlainSelect ps, String tenantCol, String tenantVal) {
        if (ps.getFromItem() == null) {
            // e.g. "SELECT 1": there is no table to restrict.
            return;
        }
        // Comma joins ("FROM a, b") are simple joins, which JSqlParser prints without an ON clause,
        // so their tenant predicates have to go into WHERE instead.
        List<Expression> commaJoinConditions = new ArrayList<>();
        if (ps.getJoins() != null) {
            for (Join join : ps.getJoins()) {
                Alias joinAlias = join.getFromItem().getAlias();
                if (join.isSimple()) {
                    commaJoinConditions.add(tenantCondition(joinAlias, tenantCol, tenantVal));
                    continue;
                }
                Collection<Expression> onExpressions = join.getOnExpressions();
                if (onExpressions != null && onExpressions.size() > 1) {
                    // More than one ON expression per join is not supported.
                    throw new RuntimeException("异常！！！");
                }
                Expression on = onExpressions == null || onExpressions.isEmpty()
                        ? null
                        : onExpressions.iterator().next();
                List<Expression> rewritten = new ArrayList<>(1);
                rewritten.add(processExpression(joinAlias, on, tenantCol, tenantVal));
                // Write the condition back explicitly rather than relying on getOnExpressions()
                // returning a live, non-null collection.
                join.setOnExpressions(rewritten);
            }
        }
        Expression where = processExpression(ps.getFromItem().getAlias(), ps.getWhere(), tenantCol, tenantVal);
        for (Expression condition : commaJoinConditions) {
            where = andCondition(where, condition);
        }
        ps.setWhere(where);
    }

    /**
     * Restricts sub-selects used as SET values of one UPDATE assignment to the tenant.
     *
     * @throws RuntimeException if a plain sub-select yields a different number of columns than
     *                          the assignment sets
     */
    private static void processUpdateSet(UpdateSet us, String tenantCol, String tenantVal) {
        ExpressionList<Column> columns = us.getColumns();
        ExpressionList<?> values = us.getValues();
        for (Expression value : values) {
            if (!(value instanceof ParenthesedSelect)) {
                continue;
            }
            Select select = ((ParenthesedSelect) value).getSelect();
            // Only a plain SELECT has a single select list to compare against; set operations
            // (UNION, ...) are still restricted but not arity-checked here.
            if (select instanceof PlainSelect
                    && columns.size() != ((PlainSelect) select).getSelectItems().size()) {
                throw new RuntimeException("update select 需要保持set 与 select 字段数一致！");
            }
            processSelectStatement(select, tenantCol, tenantVal);
        }
    }

    /**
     * Returns {@code expression} with the tenant predicate AND-ed to it.
     *
     * <p>EXISTS / IN sub-selects reached by the visitor are restricted in place first. If
     * {@code expression} already references the tenant column it is returned unchanged. If it is
     * {@code null}, the bare tenant predicate is returned.
     *
     * @param alias alias used to qualify the tenant column; {@code null} leaves it unqualified
     */
    private static Expression processExpression(Alias alias, Expression expression, String tenantCol, String tenantVal) {
        Expression tenantPredicate = tenantCondition(alias, tenantCol, tenantVal);
        if (expression == null) {
            return tenantPredicate;
        }
        expression.accept(new ExpressionVisitorAdapter() {
            @Override
            public void visit(ExistsExpression e) {
                Expression right = e.getRightExpression();
                if (right instanceof ParenthesedSelect) {
                    processSelectStatement(((ParenthesedSelect) right).getSelect(), tenantCol, tenantVal);
                }
            }

            @Override
            public void visit(InExpression expr) {
                Expression right = expr.getRightExpression();
                if (right instanceof ParenthesedSelect) {
                    processSelectStatement(((ParenthesedSelect) right).getSelect(), tenantCol, tenantVal);
                }
            }
        });
        if (containTenantColByExpression(expression, tenantCol)) {
            return expression;
        }
        // Parenthesise the original condition so an OR inside it cannot swallow the tenant predicate.
        return new AndExpression(new Parenthesis(expression), tenantPredicate);
    }

    /** Builds {@code [alias.]tenantCol = 'tenantVal'}. */
    private static Expression tenantCondition(Alias alias, String tenantCol, String tenantVal) {
        String columnName = alias == null ? tenantCol : alias.getName() + "." + tenantCol;
        return new EqualsTo(new Column(columnName), new StringValue(tenantVal));
    }

    /** AND-s {@code right} onto {@code left}, parenthesising {@code left} unless that is redundant. */
    private static Expression andCondition(Expression left, Expression right) {
        Expression safeLeft = left instanceof AndExpression || left instanceof EqualsTo
                ? left
                : new Parenthesis(left);
        return new AndExpression(safeLeft, right);
    }

    /** Returns whether any column visited in {@code expression} is the tenant column. */
    private static boolean containTenantColByExpression(Expression expression, String tenantCol) {
        if (expression == null) {
            return false;
        }
        AtomicBoolean exist = new AtomicBoolean(false);
        expression.accept(new ExpressionVisitorAdapter() {
            @Override
            public void visit(Column column) {
                if (isTenantColumn(column.getColumnName(), tenantCol)) {
                    exist.set(true);
                }
            }
        });
        return exist.get();
    }

    /**
     * Returns whether {@code columnName} names the tenant column. The match is exact,
     * case-insensitive and ignores identifier quotes. A substring match would let e.g.
     * {@code sub_tenant_id} disable tenant filtering.
     */
    private static boolean isTenantColumn(String columnName, String tenantCol) {
        if (columnName == null || tenantCol == null) {
            return false;
        }
        return unquote(columnName).equalsIgnoreCase(unquote(tenantCol));
    }

    /** Strips one pair of surrounding identifier quotes ({@code `x`}, {@code "x"} or {@code [x]}). */
    private static String unquote(String identifier) {
        int last = identifier.length() - 1;
        if (last > 0) {
            char open = identifier.charAt(0);
            char close = identifier.charAt(last);
            if ((open == '`' && close == '`') || (open == '"' && close == '"') || (open == '[' && close == ']')) {
                return identifier.substring(1, last);
            }
        }
        return identifier;
    }
}
