/*
 * Decompiled with CFR 0.152.
 */
package com.xforceplus.ultraman.adapter.elasticsearch.query.utils;

import com.google.common.collect.ImmutableList;
import com.xforceplus.ultraman.metadata.engine.EntityClassEngine;
import com.xforceplus.ultraman.metadata.engine.EntityClassGroup;
import com.xforceplus.ultraman.metadata.entity.IEntityClass;
import com.xforceplus.ultraman.oqsengine.plus.master.mysql.MysqlSqlDialectEx;
import com.xforceplus.ultraman.oqsengine.plus.storage.pojo.dto.select.SelectConfig;
import io.vavr.Tuple2;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.calcite.DataContext;
import org.apache.calcite.config.Lex;
import org.apache.calcite.jdbc.CalciteConnection;
import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rel.rel2sql.RelToSqlConverter;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.StructKind;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.calcite.sql.SqlAsOperator;
import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlCharStringLiteral;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlJoin;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlOrderBy;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.SqlTableRef;
import org.apache.calcite.sql.fun.SqlCountAggFunction;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.tools.FrameworkConfig;
import org.apache.calcite.tools.Frameworks;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Helpers for inspecting and rewriting Calcite {@link SqlNode} trees used by the Elasticsearch
 * query adapter.
 *
 * <p>Instance methods rewrite identifiers relative to one {@link EntityClassGroup} and collect
 * WHERE clauses. They accumulate per-instance state ({@code hasModified}, {@code conditions})
 * that is never reset, so an instance presumably serves a single statement — confirm with
 * callers. Static methods parse SQL, render SQL text, build COUNT queries and assemble
 * {@link SelectConfig} objects.
 */
public class ParseSqlNodeUtils {
    private static final Logger log = LoggerFactory.getLogger(ParseSqlNodeUtils.class);

    /**
     * Field name checked by {@link #findSqlJoinConditions(SqlNode)}: a WHERE clause with a direct
     * identifier operand of this name (case-insensitive) is not collected.
     */
    private static final String SYSTEM_FIELDS = "__sys_count";

    /** {@code "_" + relation name lower-cased with Locale.ROOT} to the related entity class code. */
    private final Map<String, String> relationAsMapping = new HashMap<String, String>();
    private final EntityClassGroup classGroup;
    private final EntityClassEngine calssEngine;
    /** Identifier nodes already renamed by {@link #reNameFields}; a node is renamed at most once. */
    private final Map<SqlNode, Boolean> hasModified = new ConcurrentHashMap<SqlNode, Boolean>();
    /** WHERE clauses gathered by {@link #findSqlJoinConditions(SqlNode)}, in traversal order. */
    private final List<SqlNode> conditions = new ArrayList<SqlNode>();

    /**
     * Creates a rewriter bound to {@code classGroup} and pre-computes the relation alias mapping
     * from the group's entity class.
     *
     * @param classGroup entity class group whose relations drive identifier rewriting
     * @param engine     engine used to load the related entity classes
     */
    public ParseSqlNodeUtils(EntityClassGroup classGroup, EntityClassEngine engine) {
        this.classGroup = classGroup;
        this.calssEngine = engine;
        this.initRelationAsMapping(classGroup.getEntityClass(), this.calssEngine);
    }

    /**
     * Registers {@code "_" + relationName} (lower-cased) mapped to the related class code for
     * every relation of {@code entityClass} whose target class loads under the group's profile.
     * Relations whose target cannot be loaded are skipped.
     */
    private void initRelationAsMapping(IEntityClass entityClass, EntityClassEngine calssEngine) {
        entityClass.relations().forEach(iRelation -> {
            // Locale.ROOT must match the lookup in reNameFields; default-locale lower-casing
            // (e.g. Turkish dotless i) would produce keys that are never found.
            String key = "_".concat(iRelation.getName().toLowerCase(Locale.ROOT));
            calssEngine.load(String.valueOf(iRelation.getEntityClassId()), this.classGroup.profile())
                    .ifPresent(iEntityClass -> this.relationAsMapping.put(key, iEntityClass.code()));
        });
    }

    /**
     * Walks {@code sqlNode} and rewrites, in place, every identifier it reaches via
     * {@link #modificationOperand(SqlNode)}.
     *
     * <p>For a {@link SqlSelect} it visits FROM (unless it is a plain table identifier or
     * {@link SqlTableRef}), HAVING, ORDER BY, WHERE, the select list and GROUP BY. For a
     * {@link SqlOrderBy} it visits the wrapped query and its ORDER BY items. Operands of a
     * {@link SqlBasicCall} are visited recursively. Other node kinds (joins, literals,
     * {@code null}) are ignored.
     *
     * @param sqlNode node to rewrite; may be {@code null}
     */
    public void recursion(SqlNode sqlNode) {
        if (sqlNode instanceof SqlSelect) {
            SqlSelect select = (SqlSelect) sqlNode;
            SqlNodeList selectList = select.getSelectList();
            SqlNode from = select.getFrom();
            // Plain table names are left untouched; only derived tables/expressions are rewritten.
            if (!(from instanceof SqlIdentifier) && !(from instanceof SqlTableRef)) {
                this.recursion(from);
            }
            this.recursion(select.getHaving());
            if (select.hasOrderBy()) {
                this.recursionAll(select.getOrderList());
            }
            if (select.hasWhere()) {
                this.recursion(select.getWhere());
            }
            SqlNodeList groupBy = select.getGroup();
            this.recursionAll(selectList);
            this.recursionAll(groupBy);
        } else if (sqlNode instanceof SqlOrderBy) {
            SqlOrderBy orderBy = (SqlOrderBy) sqlNode;
            this.recursion(orderBy.query);
            // The parser stores a top-level ORDER BY here, not on the inner SqlSelect.
            this.recursionAll(orderBy.orderList);
        } else if (sqlNode instanceof SqlBasicCall) {
            for (SqlNode operand : ((SqlBasicCall) sqlNode).getOperandList()) {
                this.recursion(operand);
            }
        } else if (sqlNode instanceof SqlIdentifier) {
            this.modificationOperand(sqlNode);
        }
    }

    /** Applies {@link #recursion(SqlNode)} to each element of {@code nodes}; {@code null} is a no-op. */
    private void recursionAll(SqlNodeList nodes) {
        if (nodes == null) {
            return;
        }
        for (SqlNode node : nodes.getList()) {
            this.recursion(node);
        }
    }

    /**
     * Converts a relational plan to single-line MySQL-dialect SQL text.
     *
     * <p>The plan is first passed through a {@link HepPlanner} with an empty program. That
     * planner output is what gets converted; previously it was computed and then discarded.
     *
     * @param relNode plan to convert
     * @return SQL text with newlines replaced by spaces
     */
    public static String relNodeToSqlStr(RelNode relNode) {
        HepPlanner planner = new HepPlanner(HepProgram.builder().build());
        planner.setRoot(relNode);
        RelNode optimizedNode = planner.findBestExp();
        SqlNode sqlNode = new RelToSqlConverter(MysqlSqlDialectEx.DEFAULT).visitRoot(optimizedNode).asStatement();
        return converterSqlString(sqlNode);
    }

    /**
     * Rewrites {@code query} in place into a row-count query. It clears FETCH and OFFSET and
     * replaces the select list with a single zero-operand COUNT call, aliased by the string
     * literal {@code 'c'}. WHERE, GROUP BY, HAVING and ORDER BY are left as they are.
     *
     * @param query SELECT to rewrite; it is mutated
     * @return the same {@code query} instance
     */
    public static SqlSelect buildCountSql(SqlSelect query) {
        query.setFetch(null);
        query.setOffset(null);
        // NOTE(review): the previous code cleared FETCH twice; if clearing ORDER BY was the
        // intent, add query.setOrderBy(null) — confirm with callers.
        // Zero-operand COUNT: presumably rendered as COUNT(*) by Calcite's FUNCTION_STAR syntax.
        SqlBasicCall countCall = new SqlBasicCall(new SqlCountAggFunction("COUNT"),
                ImmutableList.<SqlNode>of(), SqlParserPos.ZERO);
        SqlCharStringLiteral aliasName = SqlLiteral.createCharString("c", SqlParserPos.ZERO);
        // Explicit operand lists instead of reusing the live select list as a scratch buffer.
        SqlBasicCall aliasedCount = new SqlBasicCall(new SqlAsOperator(),
                ImmutableList.<SqlNode>of(countCall, aliasName), SqlParserPos.ZERO);
        SqlNodeList selectList = query.getSelectList();
        selectList.clear();
        selectList.add(aliasedCount);
        query.setSelectList(selectList);
        return query;
    }

    /**
     * Parses one SQL statement using MySQL lexical rules.
     *
     * @param sql SQL text
     * @return the parsed statement, or {@code null} if parsing fails (the failure is logged)
     */
    public static SqlNode getSqlNode(String sql) {
        try {
            SqlParser parser = SqlParser.create(sql, SqlParser.config().withLex(Lex.MYSQL));
            return parser.parseStmt();
        } catch (SqlParseException e) {
            // Keep the lenient null-return contract, but log the SQL and the full stack trace.
            log.error("Failed to parse SQL [{}]: {}", sql, e.getMessage(), e);
            return null;
        }
    }

    /**
     * Rewrites, in place, every multi-part identifier reachable from {@code sqlNode} to
     * {@code prefix + secondPart}.
     *
     * @param sqlNode node to rewrite; may be {@code null}
     * @param prefix  text prepended to the second name part of each qualified identifier
     * @return {@code sqlNode} itself
     */
    public static SqlNode formatSqlNode(SqlNode sqlNode, String prefix) {
        ParseSqlNodeUtils.recursion(sqlNode, prefix);
        return sqlNode;
    }

    /**
     * Builds a {@link FrameworkConfig} whose default schema is the root schema of the Calcite
     * connection behind {@code esConnection}, parsing with MySQL lexical rules, case-insensitive.
     *
     * @param esConnection JDBC connection that must unwrap to {@link CalciteConnection}
     * @return the framework configuration
     * @throws SQLException if the connection cannot be unwrapped
     */
    public static FrameworkConfig buildFrameworkConfig(Connection esConnection) throws SQLException {
        CalciteConnection calciteConn = esConnection.unwrap(CalciteConnection.class);
        SchemaPlus rootSchema = calciteConn.getRootSchema();
        return Frameworks.newConfigBuilder()
                .defaultSchema(rootSchema)
                .parserConfig(SqlParser.config().withLex(Lex.MYSQL).withCaseSensitive(false))
                .build();
    }

    /**
     * Renames {@code operand} when it is a {@link SqlIdentifier}. A two-part name passes its
     * second part to {@link #reNameFields}; any other name passes its first part.
     */
    private void modificationOperand(SqlNode operand) {
        if (!(operand instanceof SqlIdentifier)) {
            return;
        }
        ImmutableList<String> identifierNames = ((SqlIdentifier) operand).names;
        String name = identifierNames.size() == 2 ? identifierNames.get(1) : identifierNames.get(0);
        this.reNameFields(operand, name);
    }

    /**
     * Replaces the names of identifier {@code operand} with a single derived name, at most once
     * per node.
     *
     * <ul>
     *   <li>{@code _rel.x}, where {@code _rel} (lower-cased) is a known relation key, becomes
     *       {@code rel.x} (leading underscore dropped).</li>
     *   <li>{@code _other.x} with an unknown relation key is kept unchanged.</li>
     *   <li>Any other dotted name has its dots replaced by underscores.</li>
     *   <li>Undotted names are kept.</li>
     * </ul>
     */
    private void reNameFields(SqlNode operand, String name) {
        if (this.hasModified.containsKey(operand)) {
            return;
        }
        String newName = name;
        if (name.startsWith("_") && name.contains(".")) {
            String relationCode = name.split("\\.")[0].toLowerCase(Locale.ROOT);
            if (this.relationAsMapping.get(relationCode) != null) {
                newName = name.substring(1);
            }
        } else if (name.contains(".")) {
            newName = name.replace(".", "_");
        }
        List<String> names = new ArrayList<String>();
        List<SqlParserPos> poses = new ArrayList<SqlParserPos>();
        names.add(newName);
        poses.add(operand.getParserPosition());
        ((SqlIdentifier) operand).setNames(names, poses);
        this.hasModified.put(operand, true);
    }

    /**
     * Replaces a multi-part identifier with the single name {@code prefix + secondPart}. Any
     * parts beyond the second are dropped; single-part identifiers are left unchanged.
     * {@code operand} must be a {@link SqlIdentifier}.
     */
    private static void modificationOperand(SqlNode operand, String prefix) {
        SqlIdentifier identifier = (SqlIdentifier) operand;
        ImmutableList<String> identifierNames = identifier.names;
        if (identifierNames.size() > 1) {
            List<String> names = new ArrayList<String>();
            List<SqlParserPos> poses = new ArrayList<SqlParserPos>();
            poses.add(operand.getParserPosition());
            names.add(prefix.concat(identifierNames.get(1)));
            identifier.setNames(names, poses);
        }
    }

    /**
     * Returns the WHERE clauses collected so far by {@link #findSqlJoinConditions(SqlNode)}.
     * The internal mutable list is returned, not a copy.
     *
     * @return collected conditions, in traversal order
     */
    public List<SqlNode> getConditions() {
        return this.conditions;
    }

    /**
     * Appends to {@link #getConditions()} the WHERE clause of every {@link SqlSelect} reachable
     * from {@code sqlNode}. The traversal follows FROM clauses, both sides of joins, and call
     * operands such as aliased sub-queries.
     *
     * <p>A WHERE clause is collected only if it is a {@link SqlBasicCall} and none of its direct
     * identifier operands names {@value #SYSTEM_FIELDS} (case-insensitive; for two-part
     * identifiers the second part is compared). JOIN ... ON conditions are not collected.
     *
     * @param sqlNode node to inspect; may be {@code null}
     */
    public void findSqlJoinConditions(SqlNode sqlNode) {
        if (sqlNode instanceof SqlSelect) {
            SqlSelect select = (SqlSelect) sqlNode;
            SqlNode where = select.getWhere();
            if (where instanceof SqlBasicCall && !hasSystemFieldOperand((SqlBasicCall) where)) {
                this.conditions.add(where);
            }
            this.findSqlJoinConditions(select.getFrom());
        } else if (sqlNode instanceof SqlJoin) {
            SqlJoin join = (SqlJoin) sqlNode;
            this.findSqlJoinConditions(join.getLeft());
            this.findSqlJoinConditions(join.getRight());
        } else if (sqlNode instanceof SqlBasicCall) {
            for (SqlNode operand : ((SqlBasicCall) sqlNode).getOperandList()) {
                this.findSqlJoinConditions(operand);
            }
        }
    }

    /** Returns true if a direct identifier operand of {@code call} names {@value #SYSTEM_FIELDS}. */
    private static boolean hasSystemFieldOperand(SqlBasicCall call) {
        for (SqlNode operand : call.getOperandList()) {
            if (!(operand instanceof SqlIdentifier)) {
                continue;
            }
            SqlIdentifier identifier = (SqlIdentifier) operand;
            String fieldName = identifier.names.size() == 2 ? identifier.names.get(1) : identifier.getSimple();
            if (StringUtils.equalsIgnoreCase(fieldName, SYSTEM_FIELDS)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Static counterpart of {@link #recursion(SqlNode)} used by {@link #formatSqlNode}. It
     * rewrites reachable identifiers with {@link #modificationOperand(SqlNode, String)}.
     *
     * <p>Unlike the instance variant it always visits FROM, so a qualified table name
     * ({@code schema.table}) is rewritten too — NOTE(review): confirm that is intended.
     */
    private static void recursion(SqlNode sqlNode, String prefix) {
        if (sqlNode instanceof SqlSelect) {
            SqlSelect select = (SqlSelect) sqlNode;
            SqlNodeList selectList = select.getSelectList();
            ParseSqlNodeUtils.recursion(select.getFrom(), prefix);
            ParseSqlNodeUtils.recursion(select.getHaving(), prefix);
            if (select.hasOrderBy()) {
                ParseSqlNodeUtils.recursionAll(select.getOrderList(), prefix);
            }
            if (select.hasWhere()) {
                ParseSqlNodeUtils.recursion(select.getWhere(), prefix);
            }
            SqlNodeList groupBy = select.getGroup();
            ParseSqlNodeUtils.recursionAll(selectList, prefix);
            ParseSqlNodeUtils.recursionAll(groupBy, prefix);
        } else if (sqlNode instanceof SqlOrderBy) {
            SqlOrderBy orderBy = (SqlOrderBy) sqlNode;
            ParseSqlNodeUtils.recursion(orderBy.query, prefix);
            // The parser stores a top-level ORDER BY here, not on the inner SqlSelect.
            ParseSqlNodeUtils.recursionAll(orderBy.orderList, prefix);
        } else if (sqlNode instanceof SqlBasicCall) {
            for (SqlNode operand : ((SqlBasicCall) sqlNode).getOperandList()) {
                ParseSqlNodeUtils.recursion(operand, prefix);
            }
        } else if (sqlNode instanceof SqlIdentifier) {
            ParseSqlNodeUtils.modificationOperand(sqlNode, prefix);
        }
    }

    /** Applies {@link #recursion(SqlNode, String)} to each element of {@code nodes}; {@code null} is a no-op. */
    private static void recursionAll(SqlNodeList nodes, String prefix) {
        if (nodes == null) {
            return;
        }
        for (SqlNode node : nodes.getList()) {
            ParseSqlNodeUtils.recursion(node, prefix);
        }
    }

    /**
     * Renders {@code sqlNode} as MySQL-dialect SQL on a single line.
     *
     * @param sqlNode node to render
     * @return SQL text with line endings normalized and newlines replaced by spaces
     */
    public static String converterSqlString(SqlNode sqlNode) {
        String sql = sqlNode.toSqlString(MysqlSqlDialectEx.DEFAULT).getSql();
        return Util.toLinux(sql).replace("\n", " ");
    }

    /**
     * Assembles a {@link SelectConfig} from the given query pieces.
     *
     * @param profile     profile stored on the config
     * @param type        row type of the query
     * @param ops         filter expressions
     * @param fields      output fields
     * @param projects    projection expressions with their names
     * @param sort        sort fields and directions
     * @param offset      row offset; {@code null} means 0
     * @param fetch       row limit; {@code null} means 20
     * @param groupBy     GROUP BY field names
     * @param aggs        aggregate calls
     * @param hints       relational hints
     * @param rawTree     original relational tree
     * @param dataContext Calcite data context
     * @return a new config whose context is a fresh, empty map
     */
    public static SelectConfig getSelectConfig(String profile, RelDataType type, List<RexNode> ops, List<Map.Entry<String, Tuple2<StructKind, Class>>> fields, List<Pair<RexNode, String>> projects, List<Map.Entry<String, RelFieldCollation.Direction>> sort, Long offset, Long fetch, List<String> groupBy, List<AggregateCall> aggs, List<RelHint> hints, RelNode rawTree, DataContext dataContext) {
        SelectConfig selectConfig = new SelectConfig();
        selectConfig.setRexNodes(ops);
        selectConfig.setFields(fields);
        selectConfig.setRelDataType(type);
        selectConfig.setSorts(sort);
        selectConfig.setOffset(offset == null ? 0 : offset.intValue());
        selectConfig.setFetch(fetch == null ? 20 : fetch.intValue());
        selectConfig.setProjects(projects);
        selectConfig.setAggs(aggs);
        selectConfig.setGroupBy(groupBy);
        selectConfig.setHints(hints);
        selectConfig.setContext(new HashMap<>());
        selectConfig.setRawTree(rawTree);
        selectConfig.setProfile(profile);
        selectConfig.setDataContext(dataContext);
        return selectConfig;
    }
}

