package com.xforceplus.ultraman.oqsengine.plus.master.mysql.calcite;

import com.google.common.collect.BoundType;
import com.google.common.collect.Range;
import com.xforceplus.metadata.schema.typed.BoIndex;
import com.xforceplus.ultraman.metadata.engine.EntityClassGroup;
import com.xforceplus.ultraman.metadata.entity.legacy.impl.ColumnField;
import com.xforceplus.ultraman.metadata.service.DictService;
import com.xforceplus.ultraman.oqsengine.plus.master.mysql.MysqlSqlDialectEx;
import com.xforceplus.ultraman.oqsengine.plus.master.mysql.query.CopyVisitor;
import com.xforceplus.ultraman.sdk.core.config.ExecutionConfig;
import com.xforceplus.ultraman.sdk.core.rel.legacy.ExpValue;
import org.apache.calcite.adapter.jdbc.JdbcImplementor;
import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelShuttleImpl;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalSort;
import org.apache.calcite.rel.logical.LogicalValues;
import org.apache.calcite.rel.rel2sql.RelToSqlConverter;
import org.apache.calcite.rel.rel2sql.SqlImplementor;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.*;
import org.apache.calcite.sql.*;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.util.SqlString;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.NlsString;
import org.apache.calcite.util.Sarg;
import org.checkerframework.checker.nullness.qual.Nullable;

import java.math.BigDecimal;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.xforceplus.ultraman.oqsengine.plus.meta.pojo.dto.table.SystemColumn.INVISIBLE_SYSTEM_WORDS;
import static com.xforceplus.ultraman.oqsengine.plus.meta.pojo.dto.table.SystemColumn.SYSTEM_WORDS;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.*;


/**
 * Relational shuttle that collects the parts of a logical plan that can be pushed down to the
 * master (MySQL) store.
 *
 * <p>Only four node kinds are inspected:
 * <ul>
 *   <li>{@link LogicalSort}: a literal {@code FETCH} is recorded and exposed by {@link #getLimit()};</li>
 *   <li>{@link TableScan}: a scan of the target entity's master query table is pushed onto the
 *       {@link RelBuilder} so that filter columns can be resolved against its row type;</li>
 *   <li>{@link LogicalFilter}: the condition is rewritten by {@link CopyVisitor}, rendered in the
 *       MySQL dialect and exposed by {@link #getConditionSql()};</li>
 *   <li>{@link LogicalValues}: an empty VALUES sets the condition to an always-false predicate.</li>
 * </ul>
 *
 * <p>Results accumulate in mutable fields while the plan is visited, so use one instance per plan
 * and do not share an instance between threads.
 */
public class ConditionalSqlShuttle extends RelShuttleImpl {

    /** Builder that receives the master-table scan; its top frame supplies filter column names. */
    private RelBuilder builder;

    /** Qualifier placed in front of every column name in the generated SQL. */
    private final String targetCode;

    /** Rendered WHERE fragment; {@code null} until a filter or an empty VALUES produces one. */
    private String conditionSql;

    private final EntityClassGroup targetGroup;

    /** Literal FETCH of the last visited sort; 0 when no literal FETCH was seen. */
    private int limit;

    private final ExecutionConfig executionConfig;

    private final DictService dictService;

    public ConditionalSqlShuttle(EntityClassGroup targetGroup
            , RelBuilder builder, String targetCode
            , DictService dictService, ExecutionConfig executionConfig) {
        this.builder = builder;
        this.targetCode = targetCode;
        this.targetGroup = targetGroup;
        this.executionConfig = executionConfig;
        this.dictService = dictService;
    }

    /**
     * Returns the row limit taken from a literal {@code FETCH}.
     *
     * @return the limit, or 0 if no literal FETCH was visited
     */
    public int getLimit() {
        return limit;
    }

    /**
     * Returns the SQL condition rendered from the visited filter.
     *
     * @return the condition text, or {@code null} if none was produced
     */
    public String getConditionSql() {
        return conditionSql;
    }

    @Override
    public RelNode visit(LogicalSort sort) {
        RexNode fetch = sort.fetch;
        // Only a literal FETCH can be pushed down; dynamic parameters or expressions are ignored
        // instead of failing with a ClassCastException.
        if (fetch instanceof RexLiteral) {
            Integer fetchValue = ((RexLiteral) fetch).getValueAs(Integer.class);
            if (fetchValue != null) {
                limit = fetchValue;
            }
        }
        return super.visit(sort);
    }

    @Override
    public RelNode visit(TableScan scan) {
        // Push a scan of the target entity's master query table (schema "oqs") onto the builder so
        // the filter conversion can resolve column names via builder.peek().
        // NOTE(review): assumes the plan's scan corresponds to the target entity — confirm.
        // Locale.ROOT keeps the lower-casing independent of the JVM default locale.
        builder = builder.scan("oqs",
                targetGroup.getEntityClass().masterQueryTable().toLowerCase(Locale.ROOT));
        return super.visit(scan);
    }

    @Override
    public RelNode visit(LogicalValues values) {
        // An empty VALUES yields no rows, so the pushed-down condition must match nothing.
        if (values.getTuples().isEmpty()) {
            conditionSql = " 1 = 0 ";
        }
        // Return the node itself: a null result would make RelShuttleImpl copy the parent with a
        // null input.
        return values;
    }

    @Override
    public RelNode visit(LogicalFilter filter) {
        // Visit the input first so the table scan is pushed onto the builder before the condition
        // is rewritten and its input references are resolved.
        RelNode visited = super.visit(filter);
        // Any RexNode may be a condition (e.g. a boolean column ref), not only a RexCall.
        RexNode condition = filter.getCondition();
        CopyVisitor copyVisitor = new CopyVisitor(builder, Collections.singletonList(targetGroup)
                , filter, executionConfig.isUseStrictEnum(), dictService, executionConfig.getInClauseRewrite());
        RexNode rewritten = condition.accept(copyVisitor);
        if (rewritten == null) {
            return visited;
        }

        SqlImplementor.SimpleContext context = new InSupportedSqlImplementor(
                MysqlSqlDialectEx.DEFAULT,
                index -> columnIdentifier(targetCode, index),
                new JdbcImplementor(MysqlSqlDialectEx.DEFAULT,
                        new JavaTypeFactoryImpl(builder.getTypeFactory().getTypeSystem())));
        SqlNode sqlNode = context.toSql(null, rewritten);
        if (sqlNode != null) {
            conditionSql = sqlNode.toSqlString(MysqlSqlDialectEx.DEFAULT).getSql();
        }
        return visited;
    }

    /**
     * Builds the qualified identifier {@code qualifier.column} for a field of the builder's top frame.
     *
     * @param qualifier table alias / code placed before the column name
     * @param index     field index in the row type of {@code builder.peek()}
     * @return the quoted, two-part identifier
     */
    private SqlIdentifier columnIdentifier(String qualifier, int index) {
        RelDataTypeField field = builder.peek().getRowType().getFieldList().get(index);
        return new SqlIdentifier(Arrays.asList(qualifier, field.getName()), SqlParserPos.QUOTED_ZERO);
    }

    /**
     * Rex-to-SQL converter that:
     * <ul>
     *   <li>splits n-ary AND/OR into a left-deep chain of binary calls;</li>
     *   <li>rewrites SEARCH/Sarg into IN lists or range comparisons;</li>
     *   <li>qualifies input references with {@code targetCode}.</li>
     * </ul>
     */
    class MyConverter extends RexToSqlNodeConverterImpl {

        /** Kept for constructor compatibility; not read by this converter. */
        private final RelNode relNode;

        private final String targetCode;

        public MyConverter(String targetCode, RelNode relNode) {
            super(new RexSqlStandardConvertletTable());
            this.relNode = relNode;
            this.targetCode = targetCode;
        }

        @Override
        public @Nullable SqlNode convertCall(RexCall call) {
            SqlOperator operator = call.getOperator();
            if ((operator == OR || operator == AND) && call.getOperands().size() > 2) {
                return convertNode(foldBinary(call));
            }

            if (operator == SEARCH) {
                SqlNode search = convertSearch(call);
                if (search != null) {
                    return search;
                }
            }
            return super.convertCall(call);
        }

        /** Rebuilds an n-ary call as ((a op b) op c) ..., preserving the operand order. */
        private RexNode foldBinary(RexCall call) {
            List<RexNode> operands = call.getOperands();
            RexNode folded = operands.get(0);
            for (int i = 1; i < operands.size(); i++) {
                folded = call.clone(call.getType(), Arrays.asList(folded, operands.get(i)));
            }
            return folded;
        }

        /**
         * Converts {@code SEARCH(key, sarg)} into {@code key IN (...)} when every range is a point,
         * otherwise into an OR of per-range comparisons.
         *
         * <p>NOTE(review): the Sarg's null-handling (whether NULL keys match) is not translated.
         *
         * @return the SQL node, or {@code null} when the call cannot be translated here
         */
        private @Nullable SqlNode convertSearch(RexCall call) {
            List<RexNode> operands = call.getOperands();
            if (operands.size() != 2 || !(operands.get(1) instanceof RexLiteral)) {
                return null;
            }
            RexLiteral literal = (RexLiteral) operands.get(1);
            if (literal.getTypeName() != SqlTypeName.SARG) {
                return null;
            }
            Object value = literal.getValue();
            if (!(value instanceof Sarg)) {
                return null;
            }
            SqlNode key = convertNode(operands.get(0));
            if (key == null) {
                return null;
            }

            Sarg<?> sarg = (Sarg<?>) value;
            Set<? extends Range<?>> ranges = sarg.rangeSet.asRanges();
            if (ranges.isEmpty()) {
                // Nothing sensible to emit as IN (); defer to the default conversion.
                return null;
            }

            if (sarg.isPoints()) {
                List<SqlNode> points = new ArrayList<>(ranges.size());
                for (Range<?> range : ranges) {
                    points.add(toSqlLiteral(range.lowerEndpoint()));
                }
                return new SqlBasicCall(SqlStdOperatorTable.IN,
                        Arrays.asList(key, SqlNodeList.of(SqlParserPos.ZERO, points)), SqlParserPos.ZERO);
            }

            // Every range must be kept (e.g. "x < 1 OR x > 5", or NOT IN as complemented points).
            List<SqlNode> alternatives = new ArrayList<>(ranges.size());
            for (Range<?> range : ranges) {
                SqlNode predicate = rangeToSql(key, range);
                if (predicate == null) {
                    // An unbounded range has no comparison form; defer to the default conversion.
                    return null;
                }
                alternatives.add(predicate);
            }
            return foldSql(OR, alternatives);
        }

        /**
         * Translates one range into {@code key = v}, or into {@code key >(=) low} and/or
         * {@code key <(=) high}.
         *
         * @return the predicate, or {@code null} if the range has neither bound
         */
        private @Nullable SqlNode rangeToSql(SqlNode key, Range<?> range) {
            if (range.hasLowerBound() && range.hasUpperBound()
                    && range.lowerBoundType() == BoundType.CLOSED
                    && range.upperBoundType() == BoundType.CLOSED
                    && range.lowerEndpoint().equals(range.upperEndpoint())) {
                return new SqlBasicCall(SqlStdOperatorTable.EQUALS,
                        Arrays.asList(key, toSqlLiteral(range.lowerEndpoint())), SqlParserPos.ZERO);
            }

            List<SqlNode> bounds = new ArrayList<>(2);
            if (range.hasLowerBound()) {
                SqlOperator lowerOp = range.lowerBoundType() == BoundType.CLOSED
                        ? SqlStdOperatorTable.GREATER_THAN_OR_EQUAL
                        : SqlStdOperatorTable.GREATER_THAN;
                bounds.add(new SqlBasicCall(lowerOp,
                        Arrays.asList(key, toSqlLiteral(range.lowerEndpoint())), SqlParserPos.ZERO));
            }
            if (range.hasUpperBound()) {
                SqlOperator upperOp = range.upperBoundType() == BoundType.CLOSED
                        ? SqlStdOperatorTable.LESS_THAN_OR_EQUAL
                        : SqlStdOperatorTable.LESS_THAN;
                bounds.add(new SqlBasicCall(upperOp,
                        Arrays.asList(key, toSqlLiteral(range.upperEndpoint())), SqlParserPos.ZERO));
            }
            return bounds.isEmpty() ? null : foldSql(AND, bounds);
        }

        /** Combines non-empty {@code nodes} with a binary operator into a left-deep chain. */
        private SqlNode foldSql(SqlOperator operator, List<SqlNode> nodes) {
            SqlNode folded = nodes.get(0);
            for (int i = 1; i < nodes.size(); i++) {
                folded = new SqlBasicCall(operator, Arrays.asList(folded, nodes.get(i)), SqlParserPos.ZERO);
            }
            return folded;
        }

        /**
         * Renders a Sarg endpoint as a SQL literal.
         *
         * <p>The endpoint becomes:
         * <ul>
         *   <li>an exact numeric for numbers (plain notation for BigDecimal);</li>
         *   <li>a character string with the raw value for {@link NlsString};</li>
         *   <li>a character string of {@code toString()} otherwise.</li>
         * </ul>
         */
        private SqlLiteral toSqlLiteral(Object endpoint) {
            if (endpoint instanceof BigDecimal) {
                return SqlLiteral.createExactNumeric(((BigDecimal) endpoint).toPlainString(), SqlParserPos.ZERO);
            }
            if (endpoint instanceof Number) {
                return SqlLiteral.createExactNumeric(endpoint.toString(), SqlParserPos.ZERO);
            }
            if (endpoint instanceof NlsString) {
                return SqlLiteral.createCharString(((NlsString) endpoint).getValue(), SqlParserPos.ZERO);
            }
            return SqlLiteral.createCharString(endpoint.toString(), SqlParserPos.ZERO);
        }

        @Override
        public @Nullable SqlNode convertInputRef(RexInputRef ref) {
            return columnIdentifier(targetCode, ref.getIndex());
        }
    }
}
