package com.xforceplus.ultraman.oqsengine.plus.master.mysql.query;

import com.xforceplus.ultraman.metadata.engine.EntityClassGroup;
import com.xforceplus.ultraman.metadata.entity.IEntityField;
import com.xforceplus.ultraman.oqsengine.plus.common.StringUtils;
import com.xforceplus.ultraman.oqsengine.plus.master.mysql.utils.RexNodeHelper;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.rel.*;
import org.apache.calcite.rel.core.*;
import org.apache.calcite.rel.logical.*;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.*;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;

import java.util.*;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import static com.xforceplus.ultraman.oqsengine.plus.meta.pojo.dto.table.SystemColumn.DYNAMIC_FIELD;
import static com.xforceplus.ultraman.oqsengine.plus.meta.pojo.dto.table.SystemColumn.SYS_DELETED;

public class CopyCustomShuttle implements RelShuttle {

  /**
   * Builder the visited plan is replayed onto; every visit method pushes its result here.
   */
  private RelBuilder builder;

  /**
   * Rel type names of the nodes whose inputs are currently being visited. The top of the stack
   * is the nearest ancestor of the node being visited; it is queried through
   * {@code currentStackHas}.
   */
  private Stack<String> termStack = new Stack<>();

  /**
   * Root entity class group; table scans resolve their own entity class group through it.
   */
  private EntityClassGroup entityClass;

  /**
   * Field-name mapping supplied by the caller; {@code expandProject} records every expanded
   * field in it, mapped to its own name.
   */
  private Map<String, String> projectMapping;

  /**
   * Entity class group of the most recently visited table scan.
   */
  private EntityClassGroup currentEntityClass;

  /**
   * Entity class groups of the visited table scans. A join pops its right and left groups and
   * pushes the left one back.
   */
  private Stack<EntityClassGroup> involvedEntityClasses = new Stack<>();

  /**
   * Every entity class group resolved by a table scan, in visit order.
   */
  private List<EntityClassGroup> allRelatedEntityClasses = new ArrayList<>();

  /**
   * Dynamic parameters collected while copying filter conditions; exposed by {@code getParams()}.
   */
  private List<RexDynamicParam> params = new ArrayList<>();

  /**
   * Creates a shuttle that replays a visited plan onto the given builder.
   *
   * @param builder          builder receiving the copied plan
   * @param entityClassGroup root entity class group used to resolve scanned tables
   * @param projectMapping   mapping that is filled with the names of expanded fields
   */
  public CopyCustomShuttle(RelBuilder builder, EntityClassGroup entityClassGroup, Map<String, String> projectMapping) {
    this.entityClass = entityClassGroup;
    this.projectMapping = projectMapping;
    this.builder = builder;
  }

  /**
   * Replays a table scan onto the builder as a scan of the entity's master query table in the
   * {@code oqs} schema.
   *
   * <p>The entity class group is resolved from the second element of the scan's qualified name.
   * It is then pushed on {@link #involvedEntityClasses}, appended to
   * {@link #allRelatedEntityClasses} and stored as {@link #currentEntityClass}.
   *
   * <p>Depending on the ancestors recorded in {@link #termStack}:
   * <ul>
   *   <li>below a join, with no filter between the scan and that join, a
   *       {@code SYS_DELETED = false} filter is added and the projection is expanded if needed;</li>
   *   <li>below a join, with such a filter, nothing more is added here;</li>
   *   <li>without a join, only the projection is expanded if needed.</li>
   * </ul>
   *
   * @param scan the scan to copy
   * @return the original {@code scan}
   * @throws RuntimeException if the qualified name has fewer than two parts or no entity class
   *                          group can be resolved for the table
   */
  @Override
  public RelNode visit(TableScan scan) {
    List<String> qualifiedName = scan.getTable().getQualifiedName();
    if (qualifiedName.size() < 2) {
      // Fail with a descriptive message rather than a bare IndexOutOfBoundsException.
      throw new RuntimeException("No related EntityClass:" + qualifiedName);
    }

    // The table name is assumed to be the second element of the qualified name.
    String schema = qualifiedName.get(1);
    EntityClassGroup group = entityClass.relatedEntityClassWithRawName(schema);
    // The first lookup may come back empty as well; guard it so the caller gets the intended
    // "No related EntityClass" error instead of a NullPointerException.
    EntityClassGroup entityClassGroup =
        group == null ? null : group.relatedEntityClassWithRawName(schema.toLowerCase());

    if (entityClassGroup == null) {
      throw new RuntimeException("No related EntityClass:" + schema);
    }

    // Record the involved entity class.
    involvedEntityClasses.push(entityClassGroup);
    allRelatedEntityClasses.add(entityClassGroup);
    currentEntityClass = entityClassGroup;
    builder = builder.scan("oqs", entityClassGroup.getEntityClass().masterQueryTable().toLowerCase());

    if (currentStackHas("OqsengineJoin")) {
      if (!currentStackHas("OqsengineFilter", "OqsengineJoin")) {
        // No filter sits between this scan and the join, so add the "not deleted" filter here.
        RexNode notDeleted = builder.getRexBuilder()
            .makeCall(SqlStdOperatorTable.EQUALS, builder.field(SYS_DELETED), builder.literal(false));
        builder.filter(notDeleted);
        expandScanProjectIfMissing(scan);
      }
    } else {
      expandScanProjectIfMissing(scan);
    }

    return scan;
  }

  /**
   * Expands the scan into an explicit projection unless an ancestor project already covers it.
   *
   * <p>When an aggregate is nearer to the scan than any join, only a project between the scan
   * and that aggregate counts; otherwise any project on the stack counts.
   *
   * @param scan the scan whose fields are expanded
   */
  private void expandScanProjectIfMissing(TableScan scan) {
    boolean projected;
    if (currentStackHas("OqsengineAggregate", "OqsengineJoin")) {
      projected = currentStackHas("OqsengineProject", "OqsengineAggregate");
    } else {
      projected = currentStackHas("OqsengineProject");
    }
    if (!projected) {
      expandProject(scan, true);
    }
  }

  // The Logical* overloads below are not implemented: each returns null without visiting the
  // node's inputs or touching the builder.
  // NOTE(review): presumably the plans handed to this shuttle only contain planner-specific
  // nodes that dispatch to visit(RelNode) — confirm.
  @Override
  public RelNode visit(TableFunctionScan scan) {
    return null;
  }

  @Override
  public RelNode visit(LogicalValues values) {
    return null;
  }

  @Override
  public RelNode visit(LogicalFilter filter) {
    return null;
  }

  @Override
  public RelNode visit(LogicalCalc calc) {
    return null;
  }

  @Override
  public RelNode visit(LogicalProject project) {
    return null;
  }

  @Override
  public RelNode visit(LogicalJoin join) {
    return null;
  }

  @Override
  public RelNode visit(LogicalCorrelate correlate) {
    return null;
  }

  @Override
  public RelNode visit(LogicalUnion union) {
    return null;
  }

  @Override
  public RelNode visit(LogicalIntersect intersect) {
    return null;
  }

  @Override
  public RelNode visit(LogicalMinus minus) {
    return null;
  }

  @Override
  public RelNode visit(LogicalAggregate aggregate) {
    return null;
  }

  @Override
  public RelNode visit(LogicalMatch match) {
    return null;
  }

  /**
   * Visits the first input of the sort, then sorts the builder's current top by all of its
   * fields.
   *
   * @param sort the logical sort to copy
   * @return always {@code null}
   */
  @Override
  public RelNode visit(LogicalSort sort) {
    RelNode child = sort.getInput(0);
    child.accept(this);

    List<RexNode> everyField = builder.fields();
    builder.sort(everyField);
    return null;
  }

  // Not implemented: returns null without visiting inputs or touching the builder.
  @Override
  public RelNode visit(LogicalExchange exchange) {
    return null;
  }

  // Not implemented: returns null without visiting inputs or touching the builder.
  @Override
  public RelNode visit(LogicalTableModify modify) {
    return null;
  }

  /**
   * Tells whether a node type is currently on the ancestor stack.
   *
   * @param term rel type name to look for
   * @return {@code true} when {@code term} is found on the term stack
   */
  private boolean currentStackHas(String term) {
    // Stack.search returns a 1-based distance from the top, or -1 when absent.
    return termStack.search(term) > 0;
  }

  /**
   * Tells whether {@code term} is on the ancestor stack and nearer to the top than
   * {@code before}.
   *
   * @param term   rel type name to look for
   * @param before rel type name that {@code term} must be nearer to the top than
   * @return {@code false} when {@code term} is absent; {@code true} when {@code term} is present
   *         and {@code before} is absent; otherwise whether {@code term} is strictly nearer to
   *         the top of the stack than {@code before}
   */
  private boolean currentStackHas(String term, String before) {
    int termDistance = termStack.search(term);
    if (termDistance <= 0) {
      return false;
    }

    int beforeDistance = termStack.search(before);
    return beforeDistance <= 0 || termDistance < beforeDistance;
  }

  /**
   * Copies a planner-specific node (anything not dispatched to one of the typed overloads) onto
   * {@link #builder}.
   *
   * <p>The node's type name is kept on {@link #termStack} while its inputs are visited, so that
   * descendants can inspect their ancestors. It is removed again before the node itself is
   * replayed, whatever the node type, so unsupported node types no longer leave a stale entry
   * behind. Aggregate, Project, Sort, Filter and Join nodes are replayed; any other node only
   * has its inputs visited.
   *
   * @param other the node to copy
   * @return always {@code null}
   */
  @Override
  public RelNode visit(RelNode other) {
    termStack.push(other.getRelTypeName());
    other.getInputs().forEach(input -> input.accept(this));
    termStack.pop();

    if (other instanceof Aggregate) {
      copyAggregate((Aggregate) other);
    } else if (other instanceof Project) {
      copyProject((Project) other);
    } else if (other instanceof Sort) {
      copySort((Sort) other);
    } else if (other instanceof Filter) {
      copyFilter((Filter) other);
    } else if (other instanceof Join) {
      copyJoin((Join) other);
    }
    return null;
  }

  /**
   * Replays an aggregate: group keys are converted to builder field names and the aggregate
   * calls are re-targeted at the builder's current row type.
   */
  private void copyAggregate(Aggregate aggregate) {
    RelDataType inputRowType = aggregate.getInput(0).getRowType();

    List<String> keys = new ArrayList<>();
    for (int group : aggregate.getGroupSet()) {
      String rawKey = inputRowType.getFieldList().get(group).getKey();
      keys.add(nameConvertString(inputRowType, rawKey, entityClass));
    }

    List<RelBuilder.AggCall> newAgg = aggregate.getNamedAggCalls().stream()
        .map(named -> aggregateConvert(named.getKey(), builder, inputRowType, builder.peek().getRowType(),
            currentEntityClass, named.getValue()))
        .collect(Collectors.toList());
    builder = builder.aggregate(builder.groupKey(keys.toArray(new String[]{})), newAgg);
  }

  /**
   * Replays a projection. An expression that converts to {@code null} is dropped together with
   * its name, so the names handed to the builder always line up with the expressions.
   */
  private void copyProject(Project project) {
    RelNode input = project.getInput();

    List<EntityClassGroup> involved;
    if (input instanceof Join) {
      // Resolve the entity class on each side of the join (throws when one cannot be found).
      EntityClassGroup leftEntityClass = findRelatedEntityClass(input.getInput(0));
      EntityClassGroup rightEntityClass = findRelatedEntityClass(input.getInput(1));
      involved = Arrays.asList(leftEntityClass, rightEntityClass);
    } else {
      involved = Collections.singletonList(currentEntityClass);
    }

    List<String> nameList = new ArrayList<>();
    List<RexNode> refs = new ArrayList<>();
    for (Pair<RexNode, String> named : project.getNamedProjects()) {
      RexNode ref = transformProject(named.getKey(), project, involved, builder);
      if (ref != null) {
        refs.add(ref);
        nameList.add(named.getValue());
      }
    }
    builder = builder.project(refs, nameList, true);
  }

  /**
   * Replays a sort, including its offset and fetch.
   *
   * <p>All sort keys are handed to the builder in a single call, because every
   * {@code builder.sort} call with keys stacks a new Sort node. A missing fetch is passed as
   * {@code -1}, which the builder treats as "no limit"; {@code 0} would ask for zero rows.
   */
  private void copySort(Sort sort) {
    List<RexNode> sortKeys = new ArrayList<>();
    for (RelFieldCollation fieldCollation : sort.getCollation().getFieldCollations()) {
      RexNode key = RexNodeHelper.convert(builder, fieldCollation.getFieldIndex(), allRelatedEntityClasses,
          sort, true, 1);
      if (fieldCollation.getDirection() == RelFieldCollation.Direction.DESCENDING) {
        sortKeys.add(builder.desc(key));
      } else {
        sortKeys.add(key);
      }
    }
    if (!sortKeys.isEmpty()) {
      builder = builder.sort(sortKeys);
    }

    RexNode offset = sort.offset;
    RexNode fetch = sort.fetch;
    if (offset == null && fetch == null) {
      return;
    }

    // Offset and fetch are expected to be literals here.
    int offsetInt = offset == null ? 0 : ((RexLiteral) offset).getValueAs(Integer.class);
    int size = fetch == null ? -1 : ((RexLiteral) fetch).getValueAs(Integer.class);
    builder.limit(offsetInt, size);
  }

  /**
   * Replays a filter, collects the dynamic parameters found in its condition, and expands the
   * projection when no project ancestor exists. A filter directly above a table scan also gets
   * a {@code SYS_DELETED = false} condition.
   */
  private void copyFilter(Filter filter) {
    RelNode node = filter;
    CopyVisitor copyVisitor = new CopyVisitor(builder, allRelatedEntityClasses, node);
    RexNode transformed = filter.getCondition().accept(copyVisitor);

    params.addAll(copyVisitor.getDynamic());

    // TODO add custom filter
    if (filter.getInput() instanceof TableScan) {
      RexNode notDeleted = builder.getRexBuilder()
          .makeCall(SqlStdOperatorTable.EQUALS, builder.field(SYS_DELETED), builder.literal(false));
      builder.filter(builder.and(transformed, notDeleted));
    } else {
      builder.filter(transformed);
    }

    if (!currentStackHas("OqsengineProject")) {
      expandProject(node, false);
    }
  }

  /**
   * Replays a join over the two most recently recorded entity class groups. The right group is
   * consumed and the left group is pushed back.
   */
  private void copyJoin(Join join) {
    EntityClassGroup rightEntityGroup = involvedEntityClasses.pop();
    EntityClassGroup leftEntityGroup = involvedEntityClasses.pop();

    // Re-stack: the joined result is represented by its left entity class.
    involvedEntityClasses.push(leftEntityGroup);

    RelNode node = join;
    CopyVisitor copyVisitor = new CopyVisitor(builder, Arrays.asList(leftEntityGroup, rightEntityGroup), node);
    // NOTE(review): dynamic parameters found in the join condition are not added to params,
    // unlike filter conditions — confirm this is intended.
    RexNode transformed = join.getCondition().accept(copyVisitor);
    builder.join(join.getJoinType(), transformed);
  }

  /**
   * Returns the dynamic parameters collected from filter conditions during the visit.
   *
   * @return the internal parameter list itself (not a copy)
   */
  public List<RexDynamicParam> getParams() {
    return params;
  }

  /**
   * Finds the entity class group behind a rel node by following first inputs down to a table
   * scan and matching the scan's table name (second element of its qualified name, compared
   * case-insensitively) against the entity class groups recorded so far.
   *
   * <p>TODO: a combined node is treated as a single source, so only the left-most input is
   * followed, for two-input and single-input nodes alike.
   *
   * @param relNode node to start from
   * @return the first recorded entity class group whose code matches the scanned table
   * @throws RuntimeException if the left-most leaf is not a table scan, or no recorded entity
   *                          class group matches it
   */
  private EntityClassGroup findRelatedEntityClass(RelNode relNode) {
    RelNode cursor = relNode;
    while (cursor != null) {
      if (cursor instanceof TableScan) {
        String entityCode = ((TableScan) cursor).getTable().getQualifiedName().get(1);
        for (EntityClassGroup candidate : allRelatedEntityClasses) {
          if (candidate.getEntityClass().code().equalsIgnoreCase(entityCode)) {
            return candidate;
          }
        }
        // A scan is a leaf: there is nothing further to descend into.
        break;
      }

      // Stop at a leaf that is not a scan so the error below is reported, rather than an
      // IndexOutOfBoundsException from asking a leaf for its first input.
      List<RelNode> inputs = cursor.getInputs();
      cursor = inputs.isEmpty() ? null : inputs.get(0);
    }

    throw new RuntimeException("No Related EntityClass");
  }

  /**
   * Maps a raw field name to the name it carries in {@code currentType}.
   *
   * <p>TODO: join handling.
   *
   * @param currentType row type whose field names are searched
   * @param originName  raw field name; it is lower-cased before the entity field lookup
   * @param entityClass entity class group used to look the field up
   * @return {@code originName} when the entity class has no such field; otherwise the first
   *         field name of {@code currentType} that equals the entity field's name ignoring
   *         case, or the entity field's name when none matches
   */
  private String nameConvertString(RelDataType currentType, String originName, EntityClassGroup entityClass) {
    Optional<IEntityField> fieldOp = entityClass.field(originName.toLowerCase());
    if (!fieldOp.isPresent()) {
      return originName;
    }

    String name = fieldOp.get().name();
    for (String candidate : currentType.getFieldNames()) {
      if (candidate.equalsIgnoreCase(name)) {
        return candidate;
      }
    }
    return name;
  }

  /**
   * Pushes an explicit projection of every field of {@code current} onto the builder.
   *
   * <p>A filter is replaced by its first input before the fields are read. For a two-input
   * node, the fields of both inputs are projected, with the right-hand indexes shifted by the
   * number of left-hand fields. Every field name is recorded in {@link #projectMapping}, mapped
   * to itself. A field whose conversion yields {@code null} is skipped together with its name,
   * so the names passed to the builder always line up with the projected expressions.
   *
   * <p>TODO: join handling is known to be incomplete.
   *
   * @param current    node whose row type is expanded
   * @param withSystem when {@code true}, the {@code SYS_DELETED} column is appended to the
   *                   projection (without an explicit name)
   */
  private void expandProject(RelNode current, boolean withSystem) {
    // Find the node that provides the source row type.
    if (current instanceof Filter) {
      // TODO: currently only the direct input is considered as the base node.
      current = current.getInput(0);
    }

    List<RelDataTypeField> leftFields;
    List<RelDataTypeField> rightFields;
    if (current instanceof BiRel) {
      leftFields = current.getInput(0).getRowType().getFieldList();
      rightFields = current.getInput(1).getRowType().getFieldList();
    } else {
      leftFields = current.getRowType().getFieldList();
      rightFields = Collections.emptyList();
    }

    List<String> nameList = new ArrayList<>();
    List<RexNode> targetRefs = new ArrayList<>();
    appendExpandedFields(leftFields, 0, current, nameList, targetRefs);
    appendExpandedFields(rightFields, leftFields.size(), current, nameList, targetRefs);

    if (withSystem) {
      // The system column goes last and has no entry in nameList.
      targetRefs.add(builder.field(SYS_DELETED));
    }
    builder = builder.project(targetRefs, nameList);
  }

  /**
   * Converts each field to a builder expression and appends the expression and its name as a
   * pair; a field that converts to {@code null} contributes neither.
   *
   * @param fields      fields to convert
   * @param indexOffset value added to each field's own index before conversion
   * @param source      node the field indexes refer to
   * @param names       receives the names of the converted fields
   * @param refs        receives the converted expressions
   */
  private void appendExpandedFields(List<RelDataTypeField> fields, int indexOffset, RelNode source,
      List<String> names, List<RexNode> refs) {
    for (RelDataTypeField field : fields) {
      String originField = field.getName();
      projectMapping.put(originField, originField);

      RexNode ref = RexNodeHelper.convert(builder, field.getIndex() + indexOffset, allRelatedEntityClasses,
          source, false, 1);
      if (ref != null) {
        names.add(originField);
        refs.add(ref);
      }
    }
  }

  /**
   * Re-targets an aggregate call at the builder's current row type.
   *
   * <p>Each argument is looked up by name in {@code older}, converted through
   * {@code RexNodeHelper.simpleNameConvert} and cast to {@link RexInputRef}, so the conversion
   * must yield a plain field reference.
   *
   * @param aggCall     aggregate call to convert
   * @param relBuilder  builder used to resolve fields and to create the new call
   * @param older       row type the original argument indexes refer to
   * @param newer       row type of the builder's current top; must have at least one field
   * @param entityClass entity class group used for the field name conversion
   * @param alise       alias for the call; when empty, the original call name is kept
   * @return the converted aggregate call
   * @throws RuntimeException if {@code newer} has no field names
   */
  private RelBuilder.AggCall aggregateConvert(AggregateCall aggCall, RelBuilder relBuilder, RelDataType older, RelDataType newer,
      EntityClassGroup entityClass, String alise) {

    if (null == newer.getFieldNames() || newer.getFieldNames().isEmpty()) {
      throw new RuntimeException("agg convert failed, table relDataType is invalid.");
    }

    // Walk the arguments and collect their new positions.
    List<Integer> newArgList = new ArrayList<>();
    for (Integer arg : aggCall.getArgList()) {
      RelDataTypeField argField = older.getFieldList().get(arg);
      RexNode converted = RexNodeHelper.simpleNameConvert(relBuilder, entityClass
          , argField.getName(), argField.getName().toLowerCase()
          , 1, 1, false);
      newArgList.add(((RexInputRef) converted).getIndex());
    }

    String name = StringUtils.isEmpty(alise) ? aggCall.getName() : alise;

    AggregateCall retargeted = AggregateCall.create(aggCall.getAggregation()
        , aggCall.isDistinct(), aggCall.isApproximate(), aggCall.ignoreNulls()
        , newArgList, aggCall.filterArg
        , aggCall.distinctKeys, aggCall.getCollation(), newer, name);
    return relBuilder.aggregateCall(retargeted);
  }

  /**
   * Returns the field names of a row type, made unique (case-sensitively).
   *
   * <p>A name starting with {@code $} has its first two characters replaced by {@code _}
   * (for example {@code $f0} becomes {@code _0}) before the names are made unique.
   *
   * @param rowType row type to read the field names from
   * @return unique field names
   */
  static List<String> fieldNames(final RelDataType rowType) {
    final List<String> sanitized = new AbstractList<String>() {
      @Override
      public String get(int index) {
        final String raw = rowType.getFieldList().get(index).getName();
        if (raw.startsWith("$")) {
          return "_" + raw.substring(2);
        }
        return raw;
      }

      @Override
      public int size() {
        return rowType.getFieldCount();
      }
    };
    return SqlValidatorUtil.uniquify(sanitized, SqlValidatorUtil.EXPR_SUGGESTER, true);
  }

  /**
   * Tells whether the row type contains the dynamic-field column.
   *
   * @param dataType row type to inspect
   * @return {@code true} when one of its field names equals {@code DYNAMIC_FIELD}
   */
  private static boolean hasDynamic(RelDataType dataType) {
    List<String> columnNames = dataType.getFieldNames();
    return columnNames.contains(DYNAMIC_FIELD);
  }

//    public static RexNode nameConvert(RelBuilder relBuilder, String originName, EntityClassGroup entityClass) {
//        return nameConvert(relBuilder, originName, entityClass, 1, 1, false);
//    }

//    public static RexNode nameConvert(RelBuilder relBuilder, String originName, EntityClassGroup entityClass, int inputIndex, int totalNum, boolean isRight) {
//        String targetName = originName.toLowerCase();
//        Optional<IEntityField> fieldOp = entityClass.field(targetName);
//        if (fieldOp.isPresent()) {
//            IEntityField field = fieldOp.get();
//
//            //  表示查询了一个dynamic字段 && !field.isIndex()
//            if (field.isDynamic() && hasDynamic(relBuilder.peek().getRowType())) {
//                /**
//                 * TODO check join in dynamic
//                 */
//                RexNode call = relBuilder.call(new SqlJsonValueFunction("JSON_EXTRACT"),
//                        relBuilder.field(totalNum, totalNum > 1 ? 2 - inputIndex : 0, DYNAMIC_FIELD), relBuilder.literal("$.".concat(originName.toLowerCase())));
//                return relBuilder.call(new JsonUnquote(), call);
//            } else {
//                /**
//                 * only when dataType has a _related code we should do the match
//                 * if in a project already done this we return the raw name
//                 */
//                String finalTargetName = targetName;
//
//                /**
//                 * filter join
//                 */
//
//                RelNode peeked = relBuilder.peek(inputIndex - 1);
//                Optional<RelDataTypeField> first = null;
//                if (peeked instanceof Join) {
//                    if (isRight) {
//                        first = ((Join) peeked).getRight().getRowType()
//                                .getFieldList().stream().filter(x -> {
//                                            return isFitName(x.getName(), finalTargetName);
//                                        }
//                                ).findFirst();
//                    } else {
//                        first = ((Join) peeked).getLeft().getRowType()
//                                .getFieldList().stream().filter(x -> {
//                                            return isFitName(x.getName(), finalTargetName);
//                                        }
//                                ).findFirst();
//                    }
//                } else {
//                    first = relBuilder.peek(inputIndex - 1).getRowType()
//                            .getFieldList().stream().filter(x -> {
//                                        return isFitName(x.getName(), finalTargetName);
//                                    }
//                            ).findFirst();
//                }
//
//                /**
//                 * ignore case to find the field
//                 */
//                if (first.isPresent()) {
//                    //build origin
//                    /**
//                     * join
//                     */
//                    //find field position
//                    RelDataTypeField relDataTypeField = first.get();
//                    int index = relDataTypeField.getIndex();
//
//                    if (!isRight) {
//                        //in left
//                        return relBuilder.field(totalNum, totalNum > 1 ? 2 - inputIndex : 0, index);
//                    } else {
//                        if(peeked instanceof Join) {
//                            return relBuilder.field(totalNum, totalNum > 1 ? 2 - inputIndex : 0, index + ((Join) peeked).getLeft().getRowType().getFieldCount());
//                        } else {
//                            return relBuilder.field(totalNum, totalNum > 1 ? 2 - inputIndex : 0, index);
//                        }
//                    }
//                } else {
//                    //throw exception
//                    return relBuilder.field(totalNum, totalNum > 1 ? 2 - inputIndex : 0, targetName);
//                }
//            }
//        }
//
//        //in case duplicate we should find the mapping from NAME to RefIndex
//        List<String> fieldNames = relBuilder.peek().getRowType().getFieldNames();
//        int i = fieldNames.indexOf(originName);
//        if(i > 0) {
//            return relBuilder.field(totalNum, totalNum > 1 ? 2 - inputIndex : 0, i);
//        } else {
//            return relBuilder.field(totalNum, totalNum > 1 ? 2 - inputIndex : 0, originName);
//        }
//    }

  /**
   * Builds the table name {@code oqs_<appCode>_<code>} for an entity class group.
   *
   * @param group entity class group to name
   * @return the generated table name
   */
  private String genTableName(EntityClassGroup group) {
    String code = group.getEntityClass().code();
    String appCode = group.getEntityClass().ref().getAppCode();
    String suffix = appCode.concat("_").concat(code);
    return "oqs_".concat(suffix);
  }

  /**
   * Tells whether a column name matches a target field name.
   *
   * <p>Two forms match, both compared ignoring case: the target name as given, and the target
   * name with every {@code .} replaced by {@code _}.
   *
   * @param name            column name to test
   * @param finalTargetName target field name
   * @return {@code true} when either form matches
   */
  private static boolean isFitName(String name, String finalTargetName) {
    return name.equalsIgnoreCase(finalTargetName)
        || name.equalsIgnoreCase(finalTargetName.replaceAll("\\.", "_"));
  }

  /**
   * Rewrites a projection expression against the builder by running it through a
   * {@link ProjectVisitor}.
   *
   * @param rexNode             expression to rewrite
   * @param relNode             node the expression belongs to
   * @param involvedEntityClass entity class groups handed to the visitor
   * @param relBuilder          builder the rewritten expression refers to
   * @return the rewritten expression, as returned by the visitor
   */
  private RexNode transformProject(RexNode rexNode, RelNode relNode, List<EntityClassGroup> involvedEntityClass, RelBuilder relBuilder) {
    ProjectVisitor visitor = new ProjectVisitor(relNode, relBuilder, involvedEntityClass);
    return rexNode.accept(visitor);
  }

  /**
   * Rewrites the expressions of a projection so that they refer to the builder's fields.
   * Input references are converted through {@code RexNodeHelper.convert}, literals are kept
   * as they are, and calls are cloned over their rewritten operands.
   */
  class ProjectVisitor extends RexVisitorImpl<RexNode> {

    /** Entity class groups supplied by the caller. */
    private List<EntityClassGroup> entityClass;
    /** Node whose expressions are being rewritten. */
    private RelNode currentNode;
    /** Builder the rewritten expressions refer to. */
    private RelBuilder builder;

    /**
     * @param currentNode node whose expressions are being rewritten
     * @param builder     builder the rewritten expressions refer to
     * @param entityClass entity class groups involved in the node
     */
    public ProjectVisitor(RelNode currentNode, RelBuilder builder, List<EntityClassGroup> entityClass) {
      super(true);
      this.currentNode = currentNode;
      this.builder = builder;
      this.entityClass = entityClass;
    }

    /**
     * Converts an input reference into a builder expression, resolving it against every entity
     * class group recorded by the enclosing shuttle.
     */
    @Override
    public RexNode visitInputRef(RexInputRef inputRef) {
      int position = inputRef.getIndex();
      return RexNodeHelper.convert(builder, position, allRelatedEntityClasses, currentNode, true, 1);
    }

    /** Literals need no rewriting. */
    @Override
    public RexNode visitLiteral(RexLiteral literal) {
      return literal;
    }

    /** Rewrites every operand, then clones the call with its original type. */
    @Override
    public RexNode visitCall(RexCall call) {
      List<RexNode> rewritten = call.operands.stream()
          .map(operand -> operand.accept(this))
          .collect(Collectors.toList());
      return call.clone(call.type, rewritten);
    }
  }
}
