package com.xforceplus.ultraman.oqsengine.plus.master.mysql.query;

import com.google.common.collect.Range;
import com.google.common.collect.RangeSet;
import com.google.common.collect.TreeRangeSet;
import com.xforceplus.ultraman.metadata.domain.vo.dto.DictItem;
import com.xforceplus.ultraman.metadata.engine.EntityClassGroup;
import com.xforceplus.ultraman.metadata.entity.FieldType;
import com.xforceplus.ultraman.metadata.entity.IEntityClass;
import com.xforceplus.ultraman.metadata.entity.IEntityField;
import com.xforceplus.ultraman.metadata.entity.legacy.impl.ColumnField;
import com.xforceplus.ultraman.metadata.service.DictService;
import com.xforceplus.ultraman.oqsengine.plus.master.mysql.utils.RexNodeHelper;
import io.vavr.Tuple;
import io.vavr.Tuple2;
import io.vavr.Tuple3;
import io.vavr.control.Either;
import org.apache.calcite.rel.BiRel;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.SingleRel;
import org.apache.calcite.rel.core.*;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.*;
import org.apache.calcite.sql.SqlBinaryOperator;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlJsonValueFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.NlsString;
import org.apache.calcite.util.Sarg;
import org.apache.tinkerpop.gremlin.process.traversal.P;

import java.util.*;
import java.util.stream.Collectors;

import static org.apache.calcite.sql.type.OperandTypes.*;

/**
 * TODO add custom filter
 */
/**
 * Copies a {@link RexNode} tree taken from a node of the original plan into the context of the
 * supplied {@link RelBuilder}, rewriting it on the way:
 *
 * <ul>
 *   <li>input references are re-resolved through {@link RexNodeHelper#convert};</li>
 *   <li>every dynamic parameter encountered is collected (see {@link #getDynamic()});</li>
 *   <li>boolean input references used as call operands are expanded to {@code ref = TRUE};</li>
 *   <li>comparisons of a multi-value field ({@code STRINGS}, or {@code ENUMS} when strict enums
 *       are off) with a literal become {@code JSON_CONTAINS} / {@code JSON_OVERLAPS} calls;</li>
 *   <li>comparisons of a strict {@code ENUMS} field with a literal become bit-mask comparisons
 *       built from the dictionary indexes returned by {@link DictService}.</li>
 * </ul>
 *
 * <p>TODO add custom filter
 */
public class CopyVisitor extends RexVisitorImpl<RexNode> {

    private final RelBuilder builder;
    private final List<EntityClassGroup> involvedEntityClasses;

    /**
     * Node of the original (pre-rewrite) plan whose expressions are being copied; when it is a
     * join it has a left and a right input.
     */
    private final RelNode currentNode;

    /** Every dynamic parameter visited, in visiting order. */
    private final List<RexDynamicParam> dynamic = new ArrayList<>();

    private final boolean useStrictEnum;

    /** May be null, in which case strict-enum predicates are copied unchanged. */
    private final DictService dictService;

    /**
     * Entity code -> field names whose multi-value predicates use JSON_OVERLAPS instead of
     * JSON_CONTAINS. May be null.
     */
    private final Map<String, List<String>> rewriteMapping;

    private final static String JSON_CONTAINS = "JSON_CONTAINS";

    private final static String JSON_OVERLAPS = "JSON_OVERLAPS";

    /** Binary {@code &} operator used to test enum bit masks. */
    private static final SqlBinaryOperator BIT_AND =
            new SqlBinaryOperator(
                    "&",
                    SqlKind.OTHER,
                    60,
                    true,
                    ReturnTypes.QUOTIENT_NULLABLE,
                    InferTypes.FIRST_KNOWN,
                    NUMERIC_NUMERIC
                            .or(INTERVAL_NUMERIC)
                            .or(NUMERIC_INTERVAL));

    /**
     * @param builder               builder of the new plan that copied expressions are resolved against
     * @param involvedEntityClasses entity classes of the inputs of {@code currentNode}
     * @param currentNode           node of the original plan whose expressions are copied
     * @param useStrictEnum         when true, ENUMS fields are treated as bit-indexed enums;
     *                              when false, as multi-value JSON fields
     * @param dictService           dictionary lookup for enum indexes (may be null)
     * @param rewriteMapping        entity code -> fields that use JSON_OVERLAPS (may be null)
     */
    public CopyVisitor(
            RelBuilder builder
            , List<EntityClassGroup> involvedEntityClasses
            , RelNode currentNode
            , boolean useStrictEnum
            , DictService dictService
            , Map<String, List<String>> rewriteMapping
    ) {

        super(true);
        this.builder = builder;
        this.involvedEntityClasses = involvedEntityClasses;
        // currentNode belongs to the old plan; for a join, it has left and right inputs
        this.currentNode = currentNode;
        this.useStrictEnum = useStrictEnum;
        this.dictService = dictService;
        this.rewriteMapping = rewriteMapping;
    }

    /**
     * Records the parameter and returns it unchanged.
     */
    @Override
    public RexNode visitDynamicParam(RexDynamicParam param) {
        dynamic.add(param);
        return param;
    }

    /**
     * @return the dynamic parameters visited so far (live list, in visiting order)
     */
    public List<RexDynamicParam> getDynamic() {
        return dynamic;
    }

    /**
     * Re-resolves an input reference of the original node against the new plan.
     *
     * @param inputRef reference into the row type of {@code currentNode}
     * @return the equivalent expression for the new plan; an input count of 2 is passed to
     * {@link RexNodeHelper#convert} for {@link BiRel} nodes, 1 otherwise
     */
    @Override
    public RexNode visitInputRef(RexInputRef inputRef) {
        int inputCount = currentNode instanceof BiRel ? 2 : 1;
        return RexNodeHelper.convert(builder, inputRef.getIndex(), involvedEntityClasses, currentNode, true, inputCount);
    }

    /**
     * Copies the literal through the target builder's {@code RexBuilder}.
     */
    @Override
    public RexNode visitLiteral(RexLiteral literal) {
        return builder.getRexBuilder().copy(literal);
    }

    /**
     * Copies a call, rewriting comparisons on multi-value and strict-enum fields.
     *
     * <p>Only the first {@link RexInputRef} operand and the first literal operand are considered
     * for a rewrite; when no rewrite applies, the call is copied with visited operands.
     */
    @Override
    public RexNode visitCall(RexCall call) {
        // Visit the operands exactly once: visiting records dynamic parameters in `dynamic`,
        // so re-visiting per branch would register the same parameter several times.
        List<RexNode> operands = customVisitList(call.getOperands());

        Optional<RexNode> first = call.getOperands().stream().filter(x -> x instanceof RexInputRef).findFirst();
        Optional<RexLiteral> literal = operands.stream()
                .filter(x -> x instanceof RexLiteral)
                .map(x -> (RexLiteral) x)
                .findFirst();

        if (first.isPresent() && literal.isPresent()) {
            RexInputRef inputRef = (RexInputRef) first.get();

            Either<String, Tuple2<String, String>> multiValue = checkIfMultiValue(inputRef, currentNode);
            if (multiValue.isRight()) {
                RexNode rewritten = rewriteMultiValue(call, inputRef, literal.get(), multiValue.get());
                if (rewritten != null) {
                    return rewritten;
                }
            }

            if (dictService != null) {
                Tuple2<Boolean, String> enumsIndexed = checkIfEnumsIndexed(inputRef, currentNode);
                if (enumsIndexed._1) {
                    RexNode rewritten = rewriteIndexedEnum(call, inputRef, literal.get(), enumsIndexed._2);
                    if (rewritten != null) {
                        return rewritten;
                    }
                }
            }
        }

        return builder.getRexBuilder().makeCall(call.getType(), call.getOperator(), operands);
    }

    /**
     * Rewrites a comparison on a multi-value field into {@code JSON_CONTAINS(field, '[...]')}
     * (or {@code JSON_OVERLAPS} for {@code org_tree} / fields listed in {@code rewriteMapping}),
     * negated for NOT IN / NOT_EQUALS / complemented point sets.
     *
     * @return the rewritten expression, or null when the operator or literal is not supported
     */
    private RexNode rewriteMultiValue(RexCall call, RexInputRef inputRef, RexLiteral literal,
                                      Tuple2<String, String> entityAndField) {
        SqlOperator operator = call.getOperator();
        boolean positive = isPositiveMatch(operator);
        if (!positive && !isNegativeMatch(operator)) {
            return null;
        }
        Tuple2<Boolean, List<String>> extracted = extractValues(literal);
        if (extracted == null) {
            return null;
        }

        String fieldName = entityAndField._2;
        List<String> rewriteFields = rewriteMapping == null ? null : rewriteMapping.get(entityAndField._1);
        boolean useOverlaps = fieldName.equalsIgnoreCase("org_tree")
                || (rewriteFields != null && rewriteFields.stream().anyMatch(x -> x.equalsIgnoreCase(fieldName)));

        RexLiteral jsonValues = builder.literal(toJsonArray(extracted._2));
        List<RexNode> jsonArgs = Arrays.asList(inputRef.accept(this), jsonValues);
        RexNode json = builder.getRexBuilder().makeCall(
                call.getType(),
                new SqlJsonValueFunction(useOverlaps ? JSON_OVERLAPS : JSON_CONTAINS),
                jsonArgs);

        // a positive operator over a complemented point set (NOT IN) is negated as well
        if (positive && extracted._1) {
            return json;
        }
        return builder.getRexBuilder().makeCall(SqlStdOperatorTable.NOT, json);
    }

    /**
     * Rewrites a comparison on a strict enum field into {@code (field & mask) = mask}
     * ({@code <>} for NOT IN / NOT_EQUALS / complemented point sets), where {@code mask} ORs the
     * dictionary indexes of the compared values.
     *
     * @return the rewritten expression, or null when the operator or literal is not supported
     */
    private RexNode rewriteIndexedEnum(RexCall call, RexInputRef inputRef, RexLiteral literal, String dictId) {
        SqlOperator operator = call.getOperator();
        boolean positive = isPositiveMatch(operator);
        if (!positive && !isNegativeMatch(operator)) {
            return null;
        }
        Tuple2<Boolean, List<String>> extracted = extractValues(literal);
        if (extracted == null) {
            return null;
        }

        RexLiteral mask = builder.literal(getEnumBit(dictId, extracted._2));
        List<RexNode> bitAndArgs = Arrays.asList(inputRef.accept(this), mask);
        RexNode masked = builder.getRexBuilder().makeCall(call.getType(), BIT_AND, bitAndArgs);

        SqlOperator comparison = positive && extracted._1 ? SqlStdOperatorTable.EQUALS : SqlStdOperatorTable.NOT_EQUALS;
        List<RexNode> comparisonArgs = Arrays.asList(masked, mask);
        return builder.getRexBuilder().makeCall(call.getType(), comparison, comparisonArgs);
    }

    /** Operators whose rewrite keeps the match positive (unless the Sarg is complemented). */
    private static boolean isPositiveMatch(SqlOperator operator) {
        return operator == SqlStdOperatorTable.IN
                || operator == SqlStdOperatorTable.SEARCH
                || operator == SqlStdOperatorTable.EQUALS;
    }

    /** Operators whose rewrite is always negated. */
    private static boolean isNegativeMatch(SqlOperator operator) {
        return operator == SqlStdOperatorTable.NOT_IN || operator == SqlStdOperatorTable.NOT_EQUALS;
    }

    /**
     * Extracts the compared values of a literal.
     *
     * @return (isIn, values): a single value for String/Number literals, the point values for a
     * point or complemented-point Sarg (isIn is false for the complemented case); null for any
     * other literal, including Sargs describing real ranges
     */
    private static Tuple2<Boolean, List<String>> extractValues(RexLiteral literal) {
        Object value = literal.getValue2();
        if (value instanceof Sarg) {
            return sargPointValues((Sarg) value);
        }
        if (value instanceof String || value instanceof Number) {
            return Tuple.of(true, Collections.singletonList(value.toString()));
        }
        return null;
    }

    /**
     * @return (isIn, point values) for a point / complemented-point Sarg, or null otherwise
     */
    private static Tuple2<Boolean, List<String>> sargPointValues(Sarg sarg) {
        RangeSet rangeSet;
        boolean isIn;
        if (sarg.isPoints()) {
            rangeSet = sarg.rangeSet;
            isIn = true;
        } else if (sarg.isComplementedPoints()) {
            // NOT IN: the complement of the range set is the set of excluded points
            rangeSet = TreeRangeSet.create(sarg.rangeSet).complement();
            isIn = false;
        } else {
            // a genuine range (e.g. BETWEEN / >): no point list to rewrite into
            return null;
        }

        List<String> values = new ArrayList<>();
        for (Object range : rangeSet.asRanges()) {
            values.add(endpointToString(((Range) range).upperEndpoint()));
        }
        return Tuple.of(isIn, values);
    }

    /** Character strings are unwrapped; other endpoints (e.g. numbers) use {@code toString()}. */
    private static String endpointToString(Object endpoint) {
        return endpoint instanceof NlsString ? ((NlsString) endpoint).getValue() : String.valueOf(endpoint);
    }

    /** Renders values as a JSON array of strings, e.g. {@code ["a","b"]}. */
    private static String toJsonArray(List<String> values) {
        return values.stream().map(CopyVisitor::quoteJson).collect(Collectors.joining(",", "[", "]"));
    }

    /** Quotes a value as a JSON string literal, escaping quotes, backslashes and control chars. */
    private static String quoteJson(String value) {
        StringBuilder out = new StringBuilder(value.length() + 2).append('"');
        for (int i = 0; i < value.length(); i++) {
            char c = value.charAt(i);
            switch (c) {
                case '"':
                    out.append("\\\"");
                    break;
                case '\\':
                    out.append("\\\\");
                    break;
                case '\n':
                    out.append("\\n");
                    break;
                case '\r':
                    out.append("\\r");
                    break;
                case '\t':
                    out.append("\\t");
                    break;
                default:
                    if (c < 0x20) {
                        out.append(String.format("\\u%04x", (int) c));
                    } else {
                        out.append(c);
                    }
            }
        }
        return out.append('"').toString();
    }

    private boolean isBoolean(RexInputRef inputRef) {
        return currentNode.getRowType().getFieldList().get(inputRef.getIndex()).getType().getSqlTypeName() == SqlTypeName.BOOLEAN;
    }

    /**
     * Visits each operand; boolean input references are expanded to {@code ref = TRUE}.
     */
    private List<RexNode> customVisitList(List<RexNode> inputNodes) {
        List<RexNode> result = new ArrayList<>(inputNodes.size());
        for (RexNode source : inputNodes) {
            if (source instanceof RexInputRef && isBoolean((RexInputRef) source)) {
                result.add(builder.getRexBuilder().makeCall(SqlStdOperatorTable.EQUALS, source.accept(this), builder.literal(true)));
            } else {
                result.add(source.accept(this));
            }
        }
        return result;
    }

    /**
     * ORs together the dictionary indexes of {@code values} for dictionary {@code enumId}.
     *
     * @return the combined mask, 0 when {@code values} is empty
     * NOTE(review): a mask of 0 makes {@code (field & 0) = 0} always true — confirm intended.
     */
    private int getEnumBit(String enumId, List<String> values) {
        //TODO profile
        List<DictItem> dictItems = dictService.findDictItems(enumId, null, Collections.emptyMap());
        Optional<Integer> reduce = values.stream().map(x -> dictService.findEnumIndex(dictItems, x)).reduce((a, b) -> a | b);
        return reduce.orElse(0);
    }

    /**
     * Traces an input reference back to a column of an involved entity class.
     *
     * <p>Starting from the node returned by {@link RexNodeHelper#findSourceWithNode}, the name is
     * looked up in the entity class; if absent, the lookup walks down first inputs, following
     * plain input references through projections, until it finds the column, reaches a
     * TableScan/leaf, or hits a computed projection expression.
     *
     * @return (entity class group, column), or empty when it cannot be resolved
     */
    private Optional<Tuple2<EntityClassGroup, ColumnField>> resolveColumnField(RexInputRef rexInputRef, RelNode node) {
        Queue<Boolean> footprint = new LinkedList<>();
        Tuple3<EntityClassGroup, RelDataTypeField, RelNode> source = RexNodeHelper.findSourceWithNode(
                rexInputRef.getIndex(), involvedEntityClasses, node, false, footprint);
        EntityClassGroup entityClassGroup = source._1;

        // presumably no backing entity class (e.g. a VALUES source)
        if (entityClassGroup == null) {
            return Optional.empty();
        }

        RelDataTypeField relDataTypeField = source._2;
        Optional<ColumnField> field = entityClassGroup.column(relDataTypeField.getName());
        RelNode cNode = source._3;

        //TODO optimize this; Aggregate inputs are not mapped yet
        while (!field.isPresent() && cNode != null && !(cNode instanceof TableScan) && !cNode.getInputs().isEmpty()) {
            if (cNode instanceof Project) {
                Project project = (Project) cNode;
                RexNode projected = project.getProjects().get(relDataTypeField.getIndex());
                if (!(projected instanceof RexInputRef)) {
                    // computed expression: not a plain column
                    break;
                }
                int target = ((RexInputRef) projected).getIndex();
                relDataTypeField = project.getInput().getRowType().getFieldList().get(target);
                field = entityClassGroup.column(relDataTypeField.getName());
            }
            cNode = cNode.getInput(0);
        }

        return field.map(f -> Tuple.of(entityClassGroup, f));
    }

    /**
     * @return right(entity code, field name) when the referenced column is STRINGS, or ENUMS
     * while strict enums are off; left("") otherwise
     */
    private Either<String, Tuple2<String, String>> checkIfMultiValue(RexInputRef rexInputRef, RelNode currentNode) {
        Optional<Tuple2<EntityClassGroup, ColumnField>> resolved = resolveColumnField(rexInputRef, currentNode);
        if (resolved.isPresent()) {
            ColumnField field = resolved.get()._2;
            boolean isMultiValue = field.type() == FieldType.STRINGS
                    || (field.type() == FieldType.ENUMS && !useStrictEnum);
            if (isMultiValue) {
                return Either.right(Tuple.of(resolved.get()._1.getEntityClass().code(), field.name()));
            }
        }
        return Either.left("");
    }

    /**
     * @return (true, dictId) when the referenced column is ENUMS and strict enums are on;
     * (false, "") otherwise
     */
    private Tuple2<Boolean, String> checkIfEnumsIndexed(RexInputRef rexInputRef, RelNode currentNode) {
        Optional<ColumnField> field = resolveColumnField(rexInputRef, currentNode).map(t -> t._2);
        if (field.isPresent() && field.get().type() == FieldType.ENUMS && useStrictEnum) {
            return Tuple.of(true, field.get().dictId());
        }
        return Tuple.of(false, "");
    }
}
