package com.xforceplus.ultraman.oqsengine.sdk.query.transformer.optimizer.planner;

import com.google.common.collect.Lists;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;

import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Planner rule that rewrites a {@link LogicalFilter} whose condition is a disjunction (OR) into a
 * conjunction (AND) by distributing the top-level OR over its AND operands, e.g.
 * {@code (a AND b) OR c} becomes {@code (a OR c) AND (b OR c)}.
 *
 * <p>Equality comparisons inside an OR that compare the same input reference are folded into a
 * single {@code IN} call whose first operand is the input reference, e.g.
 * {@code $0 = 1 OR 2 = $0} becomes {@code IN($0, 1, 2)}.
 *
 * <p>The distribution is not re-applied to the ORs it produces, so the result is not guaranteed to
 * be in full conjunctive normal form.
 */
public class ExtractCommonToAndRules extends RelOptRule {


    public ExtractCommonToAndRules() {
        super(operand(LogicalFilter.class, any()), RelFactories.LOGICAL_BUILDER, "OrToAnd");
    }

    /**
     * Rewrites the matched filter's condition and registers the rewritten filter with the planner.
     *
     * <p>Conditions that are not a {@link RexCall} (for example a bare literal) have nothing to
     * restructure and are left alone. Nothing is registered when the rewrite yields a condition
     * equal to the original one.
     *
     * @param call rule call whose first relational expression is the matched {@link Filter}
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        final Filter filter = call.rel(0);
        final RexNode originalCondition = filter.getCondition();

        if (!(originalCondition instanceof RexCall)) {
            return;
        }

        final RexBuilder rexBuilder = filter.getCluster().getRexBuilder();

        RexNode condition = optimizeOrToIn(originalCondition, rexBuilder);
        if (condition.isA(SqlKind.OR)) {
            // Distribute the IN-folded condition (not the original), so the folding is kept.
            condition = transform(condition, rexBuilder);
        }

        // Avoid registering an equivalent filter when the rewrite changed nothing.
        if (condition.equals(originalCondition)) {
            return;
        }

        RelNode build = call.builder()
                .push(filter.getInput())
                .filter(condition)
                .build();
        call.transformTo(build);
    }

    /**
     * Turns an AND/OR call with more than two operands into a binary call of the same operator:
     * {@code OP(x1, ..., xn)} becomes {@code OP(OP(x1, ..., xn-1), xn)}. For OR the result is also
     * passed through {@link #optimizeOrToIn}.
     *
     * @param rootNode   expression to rewrite; returned unchanged if it is not an AND/OR with more
     *                   than two operands
     * @param rexBuilder builder used to create the new calls
     * @return the binary form of {@code rootNode}, or {@code rootNode} itself
     */
    private RexNode transformDown(RexNode rootNode, RexBuilder rexBuilder) {

        if (rootNode.isA(SqlKind.AND) || rootNode.isA(SqlKind.OR)) {
            List<RexNode> rexNodes = ((RexCall) rootNode).getOperands();
            if (rexNodes.size() > 2) {
                List<RexNode> headList = rexNodes.subList(0, rexNodes.size() - 1);
                RexNode tail = rexNodes.get(rexNodes.size() - 1);

                if (rootNode.isA(SqlKind.AND)) {
                    return rexBuilder.makeCall(SqlStdOperatorTable.AND, rexBuilder.makeCall(SqlStdOperatorTable.AND, headList), tail);
                } else {
                    return optimizeOrToIn(rexBuilder.makeCall(SqlStdOperatorTable.OR, rexBuilder.makeCall(SqlStdOperatorTable.OR, headList), tail), rexBuilder);
                }
            } else {
                return rootNode;
            }
        } else {
            return rootNode;
        }
    }

    /**
     * Returns the first two operands of an AND/OR call as a {@link LeftRight} pair.
     *
     * @param rootNode expression to split
     * @return the first and second operands, or {@code null} if {@code rootNode} is not AND/OR
     */
    private LeftRight getLeftRight(RexNode rootNode) {
        if (rootNode.isA(SqlKind.AND) || rootNode.isA(SqlKind.OR)) {
            List<RexNode> rexNodes = ((RexCall) rootNode).getOperands();
            return new LeftRight(rexNodes.get(0), rexNodes.get(1));
        }

        return null;
    }

    /** Builds {@code left AND right}. */
    private RexNode and(RexNode left, RexNode right, RexBuilder rexBuilder) {
        return rexBuilder.makeCall(SqlStdOperatorTable.AND, left, right);
    }

    /**
     * Folds the equality operands of an OR into {@code IN} calls.
     *
     * <p>Equality operands are grouped by their first input-reference operand, or by their first
     * operand when neither side is an input reference. Each group keyed by an input reference is
     * replaced by {@code IN(ref, v1, v2, ...)}, where the {@code vi} are the distinct values
     * compared with {@code ref}, regardless of which side of {@code =} the reference appeared on.
     * Other groups and all non-equality operands are kept unchanged. Folded groups come first, in
     * order of first appearance, followed by the non-equality operands.
     *
     * @param orNode     expression to optimize; returned unchanged if it is not an OR
     * @param rexBuilder builder used to create the new calls
     * @return the rewritten OR, or its single operand if only one remains after folding
     */
    private RexNode optimizeOrToIn(RexNode orNode, RexBuilder rexBuilder) {
        if (!orNode.isA(SqlKind.OR)) {
            return orNode;
        }
        List<RexNode> operands = ((RexCall) orNode).getOperands();

        // LinkedHashMap keeps the rewritten operand order deterministic.
        Map<RexNode, List<RexNode>> equalsByKey = operands.stream()
                .filter(x -> x.isA(SqlKind.EQUALS))
                .collect(Collectors.groupingBy(
                        ExtractCommonToAndRules::equalsKey, LinkedHashMap::new, Collectors.toList()));

        List<RexNode> nonEquals = operands.stream()
                .filter(x -> !x.isA(SqlKind.EQUALS))
                .collect(Collectors.toList());

        List<RexNode> folded = equalsByKey.entrySet().stream().flatMap(e -> {
            RexNode key = e.getKey();
            if (!key.isA(SqlKind.INPUT_REF)) {
                return e.getValue().stream();
            }
            // The column must be the first IN operand; the compared values follow it.
            List<RexNode> inOperands = Stream.concat(
                    Stream.of(key),
                    e.getValue().stream()
                            .map(eq -> otherOperand((RexCall) eq, key))
                            .distinct())
                    .collect(Collectors.toList());
            return Stream.of(rexBuilder.makeCall(SqlStdOperatorTable.IN, inOperands));
        }).collect(Collectors.toList());

        List<RexNode> rewritten = Stream.concat(folded.stream(), nonEquals.stream())
                .collect(Collectors.toList());

        if (rewritten.size() > 1) {
            return rexBuilder.makeCall(SqlStdOperatorTable.OR, rewritten);
        } else {
            return rewritten.get(0);
        }
    }

    /** Grouping key of an equality: its first input-reference operand, else its first operand. */
    private static RexNode equalsKey(RexNode equalsNode) {
        List<RexNode> ops = ((RexCall) equalsNode).getOperands();
        return ops.stream()
                .filter(op -> op.isA(SqlKind.INPUT_REF))
                .findFirst()
                .orElseGet(() -> ops.get(0));
    }

    /** Returns the operand of a binary equality that is not {@code key} ({@code key} if both are). */
    private static RexNode otherOperand(RexCall equalsCall, RexNode key) {
        List<RexNode> ops = equalsCall.getOperands();
        return ops.get(0).equals(key) ? ops.get(1) : ops.get(0);
    }

    /** Builds {@code left OR right} and folds its equalities via {@link #optimizeOrToIn}. */
    private RexNode or(RexNode left, RexNode right, RexBuilder rexBuilder) {
        return optimizeOrToIn(rexBuilder.makeCall(SqlStdOperatorTable.OR, left, right), rexBuilder);
    }

    /**
     * Distributes an OR over AND operands.
     *
     * <p>The OR is viewed as binary (see {@link #transformDown}). Both sides are transformed
     * recursively. If either side is an AND, OR is distributed over it. Otherwise the original
     * node is returned.
     *
     * @param node       condition to rewrite; returned unchanged if it is not an OR
     * @param rexBuilder builder used to create the new calls
     * @return an AND of ORs when distribution applies, otherwise {@code node}
     */
    private RexNode transform(RexNode node, RexBuilder rexBuilder) {
        if (!node.isA(SqlKind.OR)) {
            return node;
        }

        RexNode binaryNode = transformDown(node, rexBuilder);
        List<RexNode> rexNodes = ((RexCall) binaryNode).getOperands();

        RexNode left = transform(rexNodes.get(0), rexBuilder);
        RexNode right = transform(rexNodes.get(1), rexBuilder);

        if (left.isA(SqlKind.AND) && right.isA(SqlKind.AND)) {
            // (a AND b) OR (c AND d) = (a OR c) AND (b OR d) AND (a OR d) AND (b OR c)
            LeftRight andLeft = getLeftRight(transformDown(left, rexBuilder));
            LeftRight andRight = getLeftRight(transformDown(right, rexBuilder));

            RexNode llOr = or(andLeft.getLeft(), andRight.getLeft(), rexBuilder);
            RexNode lrOr = or(andLeft.getLeft(), andRight.getRight(), rexBuilder);
            RexNode rlOr = or(andLeft.getRight(), andRight.getLeft(), rexBuilder);
            RexNode rrOr = or(andLeft.getRight(), andRight.getRight(), rexBuilder);

            return and(and(llOr, rrOr, rexBuilder), and(lrOr, rlOr, rexBuilder), rexBuilder);
        } else if (left.isA(SqlKind.AND)) {
            // (a AND b) OR c = (a OR c) AND (b OR c)
            LeftRight andLeft = getLeftRight(transformDown(left, rexBuilder));

            RexNode lor = or(andLeft.getLeft(), right, rexBuilder);
            RexNode ror = or(andLeft.getRight(), right, rexBuilder);

            return and(lor, ror, rexBuilder);
        } else if (right.isA(SqlKind.AND)) {
            // a OR (b AND c) = (a OR b) AND (a OR c)
            LeftRight andRight = getLeftRight(transformDown(right, rexBuilder));

            RexNode lor = or(left, andRight.getLeft(), rexBuilder);
            RexNode ror = or(left, andRight.getRight(), rexBuilder);

            return and(lor, ror, rexBuilder);
        }
        // Neither side is an AND: nothing to distribute.
        return node;
    }
}
