package com.xforceplus.xlog.mybatis.sqlpretty;

import com.alibaba.fastjson.JSON;
import com.xforceplus.xlog.core.exception.XlogException;
import com.xforceplus.xlog.core.utils.ExceptionUtil;
import com.xforceplus.xlog.mybatis.model.SqlPrettyResult;
import lombok.SneakyThrows;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.StatementVisitorAdapter;
import net.sf.jsqlparser.statement.Statements;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.update.Update;
import net.sf.jsqlparser.util.TablesNamesFinder;
import org.apache.commons.lang3.time.DateFormatUtils;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.ognl.DefaultClassResolver;
import org.apache.ibatis.ognl.DefaultTypeConverter;
import org.apache.ibatis.ognl.Ognl;
import org.apache.ibatis.ognl.OgnlContext;
import org.springframework.util.CollectionUtils;

import java.math.BigDecimal;
import java.util.*;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * SQL美化工具
 *
 * 美化指的是把 PreStatement 中的 ? 用传入参数准确的替换掉，非格式上的优化
 *
 * @author gulei
 * @date 2023/01/19
 */
public class SqlPrettyUtil {
    private SqlPrettyUtil() {
        throw new IllegalStateException("Utility class");
    }

    /**
     * 美化SQL
     *
     * @param mappedStatement 映射后的语句
     * @param parameterObject 参数对象
     * @return 美化后的SQL字符串
     */
    @SneakyThrows
    public static SqlPrettyResult prettify(final MappedStatement mappedStatement, final Object parameterObject) {
        final SqlPrettyResult result = new SqlPrettyResult();

        final BoundSql boundSql = mappedStatement.getBoundSql(parameterObject);
        final List<ParameterMapping> parameterMappingList = boundSql.getParameterMappings();
        final Class<?> parameterClazz = parameterObject.getClass();
        String sql = boundSql.getSql().replaceAll("\\s+", " ");

        result.setSql(sql);

        // 语句分析
        final Statements statements = CCJSqlParserUtil.parseStatements(sql);
        if (statements != null && !CollectionUtils.isEmpty(statements.getStatements())) {

            // 分析使用了哪些表
            final Set<String> tableNameSet = statements.getStatements().stream().map(t -> new TablesNamesFinder().getTableList(t))
                    .flatMap(Collection::stream)
                    .collect(Collectors.toSet());
            result.setTables(tableNameSet);

            // 分析where语句使用了哪些字段 - 取第一条语句，因为多语句时，往往都是一样的
            result.setColumns(calcWhereColumns(mappedStatement.getSqlCommandType(), statements.getStatements().get(0)));
        }

        // 对项目中用到的批处理插件做适配
        if (isCoderBeeBatchPlugin(parameterClazz.getName())) {
            result.setParameters(JSON.toJSONString(parameterObject));
            result.getPlugins().add("net.coderbee.mybatis.batch");
            return result;
        }

        // 替换SQL中的?为具体的参数值
        for (ParameterMapping parameterMapping : parameterMappingList) {
            final Class<?> javaType = parameterMapping.getJavaType();
            final String property = parameterMapping.getProperty();

            final int index = findFirstRealQuestionMark(sql);
            if (index == -1) {
                throw XlogException.create("findFirstRealQuestionMark 索引值为-1");
            }

            try {
                if (parameterClazz == javaType) {
                    sql = replaceAtIndex(sql, index, handleType(parameterObject));
                } else if (boundSql.hasAdditionalParameter(property)) {
                    sql = replaceAtIndex(sql, index, handleType(boundSql.getAdditionalParameter(property)));
                } else {
                    final Map context = new OgnlContext(new DefaultClassResolver(), new DefaultTypeConverter(), new DefaultMemberAccess(true));
                    sql = replaceAtIndex(sql, index, handleType(Ognl.getValue(property, context, parameterObject)));
                }
            } catch (Throwable e) {
                throw new RuntimeException("PrettySQL Exception: " + ExceptionUtil.toDesc(e));
            }
        }

        result.setSql(sql);

        return result;
    }

    public static Set<String> calcWhereColumns(final SqlCommandType sqlCommandType, final Statement statement) {
        final Set<String> result = new HashSet<>();

        if (SqlCommandType.SELECT != sqlCommandType && SqlCommandType.UPDATE != sqlCommandType && SqlCommandType.DELETE != sqlCommandType) {
            return result;
        }

        final ExpressionVisitorAdapter expressionVisitorAdapter = new ExpressionVisitorAdapter() {
            public void visit(Column column) {
                result.add(column.getColumnName());
            }
        };

        final SelectVisitorAdapter selectVisitorAdapter = new SelectVisitorAdapter() {
            @Override
            public void visit(PlainSelect plainSelect) {
                if (plainSelect.getWhere() != null) {
                    plainSelect.getWhere().accept(expressionVisitorAdapter);
                }

                final SelectVisitorAdapter self = this;

                if (plainSelect.getFromItem() != null) {
                    plainSelect.getFromItem().accept(new FromItemVisitorAdapter() {

                        @Override
                        public void visit(SubSelect subSelect) {
                            subSelect.getSelectBody().accept(self);
                        }

                        @Override
                        public void visit(ParenthesisFromItem parenthesisFromItem) {
                            parenthesisFromItem.accept(this);
                        }
                    });
                }
            }
        };

        statement.accept(new StatementVisitorAdapter() {
            public void visit(Delete delete) {
                if (delete.getWhere() != null) {
                    delete.getWhere().accept(expressionVisitorAdapter);
                }
            }

            public void visit(Update update) {
                if (update.getWhere() != null) {
                    update.getWhere().accept(expressionVisitorAdapter);
                }
            }

            public void visit(Select select) {
                select.getSelectBody().accept(selectVisitorAdapter);
            }
        });

        return result;
    }

    private static boolean isCoderBeeBatchPlugin(final String className) {
        return "net.coderbee.mybatis.batch.BatchParameter".equals(className);
    }

    private static int findFirstRealQuestionMark(final String sql) {
        boolean quoteFlag = false;

        for (int i = 0; i < sql.length(); i++) {
            final char ch = sql.charAt(i);

            if (ch == '\'') {
                quoteFlag = !quoteFlag;
            }

            if (ch == '?' && !quoteFlag) {
                return i;
            }
        }

        return -1;
    }

    private static String replaceAtIndex(final String text, final int index, final String replacement) {
        return text.substring(0, index) + replacement + text.substring(index + 1);
    }

    private static String handleType(final Object data) {
        if (data == null) {
            return "NULL";
        }

        final Class<?> clazz = data.getClass();

        if (clazz == String.class) {
            return String.format("'%s'", data);
        } else if (clazz == Date.class) {
            return String.format("'%s'", DateFormatUtils.format((Date) data, "yyyy-MM-dd HH:mm:ss"));
        } else if (clazz == BigDecimal.class) {
            return String.format("%s", ((BigDecimal) data).toPlainString());
        }

        return data + "";
    }
}
