package com.xforceplus.ultraman.adapter.elasticsearch.query.utils;

import com.xforceplus.tech.base.core.context.ContextService;
import com.xforceplus.ultraman.adapter.elasticsearch.query.ElasticCustomShuttle;
import com.xforceplus.ultraman.metadata.engine.EntityClassEngine;
import com.xforceplus.ultraman.metadata.engine.EntityClassGroup;
import com.xforceplus.ultraman.metadata.entity.FieldType;
import com.xforceplus.ultraman.metadata.entity.IEntityClass;
import com.xforceplus.ultraman.metadata.entity.IEntityField;
import com.xforceplus.ultraman.metadata.entity.IRelation;
import com.xforceplus.ultraman.metadata.entity.legacy.impl.ColumnField;
import com.xforceplus.ultraman.oqsengine.plus.master.mysql.MysqlSqlDialectEx;
import com.xforceplus.ultraman.oqsengine.plus.storage.pojo.dto.select.SelectConfig;
import io.vavr.Tuple;
import io.vavr.Tuple2;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelShuttle;
import org.apache.calcite.rel.RelShuttleImpl;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.TableFunctionScan;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.logical.*;
import org.apache.calcite.rel.rel2sql.RelToSqlConverter;
import org.apache.calcite.rel.type.StructKind;
import org.apache.calcite.sql.*;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.util.SqlBasicVisitor;
import org.apache.calcite.sql.util.SqlShuttle;
import org.apache.calcite.sql.util.SqlVisitor;
import org.apache.commons.lang3.StringUtils;
import org.checkerframework.checker.nullness.qual.Nullable;

import java.util.*;
import java.util.stream.Collectors;


/**
 * Converts an OQS {@link RelNode} tree into SQL that Elasticsearch supports.
 *
 * @author wanyi
 */
public class ElasticSearchSqlConverter {

    /** Context key under which the converter records whether the query was handled as a join query. */
    private static final String JOIN_QUERY_KEY = "join_query";

    /** Column that {@link #queryMasterCondition} requires to be non-null. */
    private static final String CREATE_TIME_COLUMN = "create_time";

    private final EntityClassEngine engine;
    private final ContextService contextService;
    private final ElasticCustomShuttle elasticCustomShuttle;

    public ElasticSearchSqlConverter(EntityClassEngine engine, ContextService contextService, ElasticCustomShuttle elasticCustomShuttle) {
        this.engine = engine;
        this.contextService = contextService;
        this.elasticCustomShuttle = elasticCustomShuttle;
    }

    /**
     * Converts the raw OQS RelNode of {@code selectConfig} into a query statement for the
     * Elasticsearch child-parent model.
     *
     * <p>When {@code elasticCustomShuttle.getJoinsCounter()} is true and the converted statement is a
     * SELECT, the SELECT is rebuilt over {@code appCode.code} with the join conditions collected by
     * {@link ParseSqlNodeUtils} as its WHERE clause. Otherwise the converted statement is adjusted in
     * place by {@link #queryMasterCondition}. Either way, the {@code join_query} flag is stored in the
     * context.
     *
     * @param entityClass  entity class being queried
     * @param selectConfig query configuration (raw tree, profile, offset, fetch)
     * @return the SQL node to execute against Elasticsearch
     **/
    public SqlNode oqsRelNodeConverterElasticSql(IEntityClass entityClass,
                                                 SelectConfig selectConfig) {
        EntityClassGroup group = engine.describe(entityClass, selectConfig.getProfile());
        ParseSqlNodeUtils parseSqlNodeUtils = new ParseSqlNodeUtils(group, engine);
        SqlNode sqlNode = new ElasticSearchRelToSqlConverter(MysqlSqlDialectEx.DEFAULT, group)
                .visitRoot(selectConfig.getRawTree())
                .asStatement();
        if (elasticCustomShuttle.getJoinsCounter()) {
            if (sqlNode instanceof SqlSelect) {
                sqlNode = rebuildJoinSelect(sqlNode, entityClass, selectConfig, parseSqlNodeUtils);
            }
            contextService.getAll().put(JOIN_QUERY_KEY, true);
        } else {
            contextService.getAll().put(JOIN_QUERY_KEY, false);
            queryMasterCondition(sqlNode, selectConfig.getRawTree(), group);
        }
        //renameArrayTypeFileds(sqlNode, selectConfig, entityClass);
        return sqlNode;
    }

    /**
     * Rebuilds a converted join query as a single SELECT over {@code appCode.code}.
     *
     * <p>The SELECT keeps the original select and order lists. Its WHERE clause is built from the
     * join conditions that {@code parseSqlNodeUtils} collects. OFFSET and FETCH come from
     * {@code selectConfig}; either is omitted when it is null.
     *
     * @param sqlNode           converted statement; must be a {@link SqlSelect}
     * @param entityClass       entity class supplying the target table name
     * @param selectConfig      supplies offset and fetch
     * @param parseSqlNodeUtils helper that collects join conditions and post-processes the result
     * @return the rebuilt SELECT
     */
    private SqlSelect rebuildJoinSelect(SqlNode sqlNode, IEntityClass entityClass,
                                        SelectConfig selectConfig, ParseSqlNodeUtils parseSqlNodeUtils) {
        SqlSelect converted = (SqlSelect) sqlNode;
        SqlIdentifier from = new SqlIdentifier(
                Arrays.asList(entityClass.appCode(), entityClass.code()), SqlParserPos.ZERO);
        SqlNodeList selectList = converted.getSelectList();
        SqlNodeList orderList = converted.getOrderList();

        // Assemble the WHERE filter from the collected join conditions.
        parseSqlNodeUtils.findSqlJoinConditions(sqlNode);
        SqlNode condition = andAll(parseSqlNodeUtils.getConditions());

        // A null fetch/offset means "no LIMIT/OFFSET"; previously a null fetch crashed with
        // NumberFormatException on createExactNumeric("null").
        SqlSelect sqlSelect = new SqlSelect(SqlParserPos.ZERO, SqlNodeList.EMPTY,
                selectList, from, condition, null, null, null, orderList,
                exactNumericOrNull(selectConfig.getOffset()), exactNumericOrNull(selectConfig.getFetch()), null);
        parseSqlNodeUtils.recursion(sqlSelect);
        return sqlSelect;
    }

    /**
     * Combines conditions into a single predicate.
     *
     * @param conditions conditions to combine; may be null or empty
     * @return null when there are no conditions, the sole condition itself (of any SqlNode kind)
     *         when there is one, otherwise an AND call over all of them
     */
    private static SqlNode andAll(List<SqlNode> conditions) {
        if (conditions == null || conditions.isEmpty()) {
            return null;
        }
        if (conditions.size() == 1) {
            return conditions.get(0);
        }
        return new SqlBasicCall(SqlStdOperatorTable.AND, conditions, SqlParserPos.ZERO);
    }

    /**
     * Returns an exact numeric literal for {@code value}, or null when {@code value} is null.
     */
    private static SqlNumericLiteral exactNumericOrNull(Integer value) {
        return value == null ? null : SqlLiteral.createExactNumeric(String.valueOf(value), SqlParserPos.ZERO);
    }

    /**
     * Filters out detail data so that only parent (master) data is queried.
     *
     * <p>This method mutates {@code sqlNode} in place. It makes three changes:
     * <ul>
     *     <li>It normalizes the ORDER BY identifiers to the group's column names.</li>
     *     <li>If {@code rawTree} has no Project reachable through {@link ProjectExpand}, it replaces
     *     the select list with every field of the entity class.</li>
     *     <li>It adds {@code create_time IS NOT NULL} to the WHERE clause of one SELECT.
     *     NOTE(review): presumably only parent documents carry create_time — confirm.</li>
     * </ul>
     * Non-SELECT nodes are left untouched.
     *
     * @param sqlNode converted statement
     * @param rawTree relational plan the statement was converted from
     * @param group   describes the queried entity class and its columns
     **/
    private void queryMasterCondition(SqlNode sqlNode, RelNode rawTree, EntityClassGroup group) {
        if (!(sqlNode instanceof SqlSelect)) {
            return;
        }
        SqlSelect select = (SqlSelect) sqlNode;

        AppendSql appendSql = new AppendSql();
        ProjectExpand projectExpand = new ProjectExpand();
        NormalizeFields normalizeFields = new NormalizeFields(group);

        // Record every SELECT with its WHERE clause, and detect whether the plan has a Project.
        select.accept(appendSql);
        rawTree.accept(projectExpand);

        SqlNodeList orderList = select.getOrderList();
        if (orderList != null) {
            select.setOrderBy((SqlNodeList) orderList.accept(normalizeFields));
        }

        if (projectExpand.noProject) {
            select.setSelectList(allFieldsSelectList(group));
        }

        appendCreateTimeFilter(appendSql.node, normalizeFields);
    }

    /**
     * Builds a select list with one item per field of the group's entity class.
     *
     * <p>Dots in field names become underscores. STRINGS fields are selected as
     * {@code <name>@raw AS <name>}; all other fields are selected as {@code <name>}.
     */
    private static SqlNodeList allFieldsSelectList(EntityClassGroup group) {
        List<SqlNode> items = group.getEntityClass().fields().stream().map(field -> {
            String realName = field.name().replace(".", "_");
            SqlNode item;
            if (field.type() == FieldType.STRINGS) {
                List<SqlNode> as = new ArrayList<>();
                as.add(new SqlIdentifier(realName.concat("@raw"), SqlParserPos.ZERO));
                as.add(new SqlIdentifier(realName, SqlParserPos.ZERO));
                item = new SqlBasicCall(new SqlAsOperator(), as, SqlParserPos.ZERO);
            } else {
                item = new SqlIdentifier(realName, SqlParserPos.ZERO);
            }
            return item;
        }).collect(Collectors.toList());
        return new SqlNodeList(items, SqlParserPos.ZERO);
    }

    /**
     * Adds {@code create_time IS NOT NULL} to the first recorded SELECT.
     *
     * <p>The filter replaces the WHERE clause when that SELECT has none. Otherwise it is ANDed with
     * the normalized existing WHERE. Both operands are normalized by {@code normalizeFields}.
     *
     * <p>NOTE(review): {@link AppendSql} records SELECTs in post-order, so the first entry is the
     * innermost nested SELECT when subqueries exist, or the root SELECT otherwise. Confirm that the
     * innermost SELECT is the intended target.
     *
     * @param selects         SELECT/WHERE pairs recorded by {@link AppendSql}; nothing happens if empty
     * @param normalizeFields shuttle used to normalize column names
     */
    private static void appendCreateTimeFilter(List<Tuple2<SqlSelect, SqlNode>> selects,
                                               NormalizeFields normalizeFields) {
        if (selects.isEmpty()) {
            return;
        }
        List<SqlNode> isNotNullOperands = new ArrayList<>();
        isNotNullOperands.add(new SqlIdentifier(CREATE_TIME_COLUMN, SqlParserPos.ZERO));
        SqlNode createTimeNotNull = new SqlBasicCall(
                SqlStdOperatorTable.IS_NOT_NULL,
                isNotNullOperands,
                SqlParserPos.ZERO).accept(normalizeFields);

        Tuple2<SqlSelect, SqlNode> target = selects.get(0);
        // Check the SELECT's real WHERE, not "is a literal": a genuine literal WHERE such as FALSE
        // must be kept and ANDed, not replaced by the create_time filter.
        if (target._1.getWhere() == null) {
            target._1.setWhere(createTimeNotNull);
        } else {
            List<SqlNode> conditions = new ArrayList<>();
            conditions.add(target._2.accept(normalizeFields));
            conditions.add(createTimeNotNull);
            target._1.setWhere(new SqlBasicCall(
                    SqlStdOperatorTable.AND,
                    conditions,
                    SqlParserPos.ZERO));
        }
    }

    // Disabled helpers, kept for reference. Their only call site, in oqsRelNodeConverterElasticSql,
    // is commented out.
    //  - renameArrayTypeFileds: renames array-typed (STRINGS) fields in the select list.
    //  - converterFieldsName: checks each selected field against the array-typed field names and
    //    converts the ones that match.
//    private void renameArrayTypeFileds(SqlNode queryNode, SelectConfig selectConfig,
//                                       IEntityClass iEntityClass) {
//        Map<String, IEntityField> rowType = getRowType(selectConfig.getProfile(), iEntityClass);
//        if (queryNode instanceof SqlOrderBy) {
//            queryNode = ((SqlOrderBy) queryNode).query;
//        }
//        List<@Nullable SqlNode> listFields = ((SqlSelect) queryNode).getSelectList().getList();
//        List<SqlNode> transfromArrayFields = new ArrayList<>();
//        if (listFields.size() == 1) {
//            SqlNode sqlIdentity = listFields.get(0);
//            if (sqlIdentity instanceof SqlIdentifier) {
//                if (StringUtils.isEmpty(((SqlIdentifier) sqlIdentity).getSimple())) {
//                    converterFieldsName(selectConfig.getFields(), rowType, transfromArrayFields);
//                }
//            }
//        } else {
//            converterFieldsName(selectConfig.getFields(), rowType, transfromArrayFields);
//        }
//        if (transfromArrayFields.size() >= 1) {
//            SqlNodeList sqlNodes = new SqlNodeList(transfromArrayFields, SqlParserPos.ZERO);
//            ((SqlSelect) queryNode).setSelectList(sqlNodes);
//        }
//    }
//
//    private void converterFieldsName(List<Map.Entry<String, Tuple2<StructKind, Class>>> fields, Map<String, IEntityField> rowType,
//                                     List<SqlNode> transfromArrayFields) {
//        try {
//            fields.forEach(field -> {
//                String fieldsName = field.getKey().replace(".", "_").toLowerCase(Locale.ROOT);
//                IEntityField identityFields = rowType.get(fieldsName);
//                if (identityFields != null && identityFields.type() == FieldType.STRINGS) {
//                    List<SqlNode> as = new ArrayList<>();
//                    as.add(new SqlIdentifier(fieldsName.concat("@raw"), SqlParserPos.ZERO));
//                    as.add(new SqlIdentifier(fieldsName, SqlParserPos.ZERO));
//                    transfromArrayFields.add(new SqlBasicCall(new SqlAsOperator(), as, SqlParserPos.ZERO));
//                } else {
//                    transfromArrayFields.add(new SqlIdentifier(fieldsName, SqlParserPos.ZERO));
//                }
//            });
//        } catch (Exception e) {
//            throw e;
//        }
//    }

    /**
     * Gets all fields of the table.
     *
     * <p>The map holds every field of {@code iEntityClass} (as described for {@code profile}) plus
     * the fields of each related entity class. Keys are lower-cased (Locale.ROOT) with '.' replaced
     * by '_'. A later field with the same key overwrites an earlier one.
     *
     * @param profile      profile used to describe and load entity classes
     * @param iEntityClass entity class whose fields (and related classes' fields) are collected
     * @return field map keyed by normalized name
     * @throws NoSuchElementException if a related entity class cannot be loaded for {@code profile}
     **/
    private Map<String, IEntityField> getRowType(String profile, IEntityClass iEntityClass) {
        Map<String, IEntityField> entityFieldSet = new HashMap<>();
        engine.describe(iEntityClass, profile).getAllFields().stream().forEach(
                x -> entityFieldSet.put(x.name().toLowerCase(Locale.ROOT).replace(".", "_"), x));
        for (IRelation iRelation : iEntityClass.relations()) {
            IEntityClass iRelationEntityClass = engine
                    .load(String.valueOf(iRelation.getEntityClassId()), profile)
                    .orElseThrow(() -> new NoSuchElementException(
                            "Related entity class " + iRelation.getEntityClassId()
                                    + " not found for profile " + profile));
            engine.describe(iRelationEntityClass, profile).getAllFields().stream().forEach(
                    x -> entityFieldSet.put(x.name().toLowerCase(Locale.ROOT).replace(".", "_"), x));
        }
        return entityFieldSet;
    }

    /**
     * Shuttle that rewrites simple identifiers to the group's canonical column name.
     *
     * <p>Names are matched case-insensitively against {@code group.columns()}. Dots are then
     * replaced with underscores.
     */
    class NormalizeFields extends SqlShuttle {

        private final EntityClassGroup group;

        public NormalizeFields(EntityClassGroup group) {
            this.group = group;
        }

        @Override
        public SqlNode visit(SqlIdentifier identifier) {
            if (!identifier.isSimple()) {
                // getSimple() asserts a single-part name. NOTE(review): qualified identifiers are left
                // untouched; decide whether their last component should be normalized too.
                return identifier;
            }
            String simpleName = identifier.getSimple();
            String canonicalName = group.columns().stream()
                    .filter(column -> column.name().equalsIgnoreCase(simpleName))
                    .findFirst()
                    .map(ColumnField::name)
                    .orElse(simpleName);
            return identifier.setName(0, canonicalName.replace(".", "_"));
        }
    }

    /**
     * Rel shuttle that records whether a {@link Project} is reached through {@link #visit(RelNode)}.
     */
    class ProjectExpand extends RelShuttleImpl {
        /** Stays {@code true} until a {@link Project} is reached through {@link #visit(RelNode)}. */
        private boolean noProject = true;

        // NOTE(review): RelShuttleImpl routes LogicalProject (and other Logical* nodes) to dedicated
        // visit overloads, not to visit(RelNode), so a LogicalProject never clears noProject here.
        // Confirm which RelNode types the raw tree contains before relying on this flag.
        @Override
        public RelNode visit(RelNode other) {
            if (other instanceof Project) {
                noProject = false;
            }
            // super.visit(...) already visits every input. The former explicit pre-visit of the inputs
            // walked each subtree twice, which is exponential in plan depth.
            return super.visit(other);
        }
    }

    /**
     * Shuttle that records every SELECT together with its WHERE clause.
     */
    class AppendSql extends SqlShuttle {
        /**
         * Visited SELECTs paired with their WHERE clause, or with a TRUE literal placeholder when
         * they have none. Entries are pushed in post-order: nested SELECTs come before the SELECT
         * that contains them.
         */
        Stack<Tuple2<SqlSelect, SqlNode>> node = new Stack<>();

        @Override
        public SqlNode visit(SqlCall call) {
            if (!(call instanceof SqlSelect)) {
                return super.visit(call);
            }
            SqlSelect select = (SqlSelect) call;
            for (SqlNode operand : select.getOperandList()) {
                if (operand != null) {
                    operand.accept(this);
                }
            }
            SqlNode where = select.getWhere();
            node.push(Tuple.of(select, where != null ? where : SqlLiteral.createBoolean(true, SqlParserPos.ZERO)));
            // The operands were already visited above. Delegating to super.visit(call) would traverse
            // them again and push every nested SELECT a second time.
            return select;
        }
    }
}
