/*
 * Decompiled with CFR 0.152.
 */
package com.xforceplus.ultraman.oqsengine.sdk.query.transformer.optimizer.planner;

import com.google.common.collect.Lists;
import com.xforceplus.ultraman.oqsengine.sdk.query.transformer.optimizer.planner.LeftRight;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;

public class ExtractCommonToAndRules
extends RelOptRule {
    public ExtractCommonToAndRules() {
        super(ExtractCommonToAndRules.operand(LogicalFilter.class, (RelOptRuleOperandChildren)ExtractCommonToAndRules.any()), RelFactories.LOGICAL_BUILDER, "OrToAnd");
    }

    public void onMatch(RelOptRuleCall call) {
        System.out.println(call);
        Filter filter = (Filter)call.rel(0);
        RexNode condition = filter.getCondition();
        if (condition instanceof RexCall) {
            System.out.println("RexCall");
            System.out.println(condition);
            RexCall rexNode = (RexCall)condition;
            RexBuilder rexBuilder = filter.getCluster().getRexBuilder();
            condition = this.optimizeOrToIn(condition, rexBuilder);
            if (condition.isA(SqlKind.OR)) {
                condition = this.transform((RexNode)rexNode, rexBuilder);
            }
            RelNode build = call.builder().push(filter.getInput()).filter(new RexNode[]{condition}).build();
            call.transformTo(build);
        } else {
            System.out.println("not RexCall");
            System.out.println(condition);
        }
    }

    private RexNode transformDown(RexNode rootNode, RexBuilder rexBuilder) {
        if (rootNode.isA(SqlKind.AND) || rootNode.isA(SqlKind.OR)) {
            List rexNodes = ((RexCall)rootNode).getOperands();
            if (rexNodes.size() > 2) {
                List headList = rexNodes.subList(0, rexNodes.size() - 1);
                RexNode tail = (RexNode)rexNodes.get(rexNodes.size() - 1);
                if (rootNode.isA(SqlKind.AND)) {
                    return rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.AND, new RexNode[]{rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.AND, headList), tail});
                }
                return this.optimizeOrToIn(rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.OR, new RexNode[]{rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.OR, headList), tail}), rexBuilder);
            }
            return rootNode;
        }
        return rootNode;
    }

    private LeftRight getLeftRight(RexNode rootNode) {
        if (rootNode.isA(SqlKind.AND) || rootNode.isA(SqlKind.OR)) {
            List rexNodes = ((RexCall)rootNode).getOperands();
            return new LeftRight((RexNode)rexNodes.get(0), (RexNode)rexNodes.get(1));
        }
        return null;
    }

    private RexNode and(RexNode left, RexNode right, RexBuilder rexBuilder) {
        return rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.AND, new RexNode[]{left, right});
    }

    private RexNode optimizeOrToIn(RexNode orNode, RexBuilder rexBuilder) {
        if (orNode.isA(SqlKind.OR)) {
            List operands = ((RexCall)orNode).getOperands();
            Map<RexNode, List<RexNode>> groupBy = operands.stream().filter(x -> x.isA(SqlKind.EQUALS)).collect(Collectors.groupingBy(x -> {
                Optional<RexNode> first = ((RexCall)x).getOperands().stream().filter(subx -> subx.isA(SqlKind.INPUT_REF)).findFirst();
                return first.orElseGet(() -> (RexNode)((RexCall)x).getOperands().get(0));
            }));
            List nonEquals = operands.stream().filter(x -> !x.isA(SqlKind.EQUALS)).collect(Collectors.toList());
            List collect = groupBy.entrySet().stream().flatMap(e -> {
                if (((RexNode)e.getKey()).isA(SqlKind.INPUT_REF)) {
                    List values = ((List)e.getValue()).stream().flatMap(s -> ((RexCall)s).getOperands().stream()).distinct().collect(Collectors.toList());
                    return Stream.of(rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.IN, values));
                }
                return ((List)e.getValue()).stream();
            }).collect(Collectors.toList());
            LinkedList rexNodes = Lists.newLinkedList(collect);
            rexNodes.addAll(nonEquals);
            if (rexNodes.size() > 1) {
                return rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.OR, (List)rexNodes);
            }
            return (RexNode)rexNodes.get(0);
        }
        return orNode;
    }

    private RexNode or(RexNode left, RexNode right, RexBuilder rexBuilder) {
        return this.optimizeOrToIn(rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.OR, new RexNode[]{left, right}), rexBuilder);
    }

    private RexNode transform(RexNode node, RexBuilder rexBuilder) {
        if (node.isA(SqlKind.OR)) {
            RexNode binaryNode = this.transformDown(node, rexBuilder);
            List rexNodes = ((RexCall)binaryNode).getOperands();
            RexNode originLeft = (RexNode)rexNodes.get(0);
            RexNode originRight = (RexNode)rexNodes.get(1);
            RexNode left = this.transform(originLeft, rexBuilder);
            RexNode right = this.transform(originRight, rexBuilder);
            if (left.isA(SqlKind.AND) && right.isA(SqlKind.AND)) {
                LeftRight andLeft = this.getLeftRight(this.transformDown(left, rexBuilder));
                LeftRight andRight = this.getLeftRight(this.transformDown(right, rexBuilder));
                RexNode llOr = this.or(andLeft.getLeft(), andRight.getLeft(), rexBuilder);
                RexNode lrOr = this.or(andLeft.getLeft(), andRight.getRight(), rexBuilder);
                RexNode rlOr = this.or(andLeft.getRight(), andRight.getLeft(), rexBuilder);
                RexNode rrOr = this.or(andLeft.getRight(), andRight.getRight(), rexBuilder);
                return this.and(this.and(llOr, rrOr, rexBuilder), this.and(lrOr, rlOr, rexBuilder), rexBuilder);
            }
            if (left.isA(SqlKind.AND)) {
                LeftRight andLeft = this.getLeftRight(this.transformDown(left, rexBuilder));
                RexNode lor = this.or(andLeft.getLeft(), right, rexBuilder);
                RexNode ror = this.or(andLeft.getRight(), right, rexBuilder);
                return this.and(lor, ror, rexBuilder);
            }
            if (right.isA(SqlKind.AND)) {
                LeftRight andLeft = this.getLeftRight(this.transformDown(right, rexBuilder));
                RexNode lor = this.or(left, andLeft.getLeft(), rexBuilder);
                RexNode ror = this.or(left, andLeft.getRight(), rexBuilder);
                return this.and(lor, ror, rexBuilder);
            }
        }
        return node;
    }
}

