package com.xforceplus.ultraman.adapter.elasticsearch.query.utils;

import com.google.common.collect.ImmutableList;
import com.xforceplus.ultraman.metadata.engine.EntityClassEngine;
import com.xforceplus.ultraman.metadata.engine.EntityClassGroup;
import com.xforceplus.ultraman.metadata.entity.IEntityClass;
import com.xforceplus.ultraman.metadata.entity.IEntityField;
import com.xforceplus.ultraman.metadata.entity.IRelation;
import com.xforceplus.ultraman.oqsengine.plus.master.mysql.MysqlSqlDialectEx;
import com.xforceplus.ultraman.oqsengine.plus.storage.pojo.dto.select.SelectConfig;
import io.vavr.Tuple2;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.DataContext;
import org.apache.calcite.config.Lex;
import org.apache.calcite.jdbc.CalciteConnection;
import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.rel.RelFieldCollation.Direction;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rel.rel2sql.RelToSqlConverter;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.StructKind;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.calcite.sql.SqlAsOperator;
import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlCharStringLiteral;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlJoin;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlOrderBy;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.SqlTableRef;
import org.apache.calcite.sql.dialect.CalciteSqlDialect;
import org.apache.calcite.sql.fun.SqlCountAggFunction;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.tools.FrameworkConfig;
import org.apache.calcite.tools.Frameworks;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;
import org.apache.commons.lang.StringUtils;

import static com.xforceplus.ultraman.oqsengine.plus.master.mysql.MysqlSqlDialectEx.DEFAULT_CONTEXT;

/**
 * Parses SQL into Calcite {@link SqlNode} trees and rewrites table/field identifiers so that the
 * resulting SQL follows the naming rules of the elastic query engine.
 *
 * <p>Instance methods carry per-query state: the relation-alias mapping built from the entity
 * class, the set of identifiers already renamed, and the WHERE conditions collected by
 * {@link #findSqlJoinConditions(SqlNode)}. Static methods hold no state.
 *
 * @author WanYi
 * @since 1.0 (created 2023-07-14)
 */
@Slf4j
public class ParseSqlNodeUtils {

  /** System field name; a WHERE clause that references it is not collected as a join condition. */
  private static final String SYSTEM_FIELDS = "__sys_count";

  /**
   * Maps {@code "_" + lower-cased relation name} to the code of the related entity class. Filled
   * once by the constructor; {@link #reNameFields} only checks that a non-null value is mapped.
   */
  private final Map<String, String> relationAsMapping = new HashMap<>();

  private final EntityClassGroup classGroup;
  private final EntityClassEngine classEngine;

  /** Identifiers already rewritten by {@link #reNameFields}; prevents renaming a node twice. */
  private final Map<SqlNode, Boolean> hasModified = new ConcurrentHashMap<>();

  /** WHERE clauses collected by {@link #findSqlJoinConditions(SqlNode)}. */
  private final List<SqlNode> conditions = new ArrayList<>();

  /**
   * Creates a rewriter bound to one entity class group.
   *
   * @param classGroup group whose entity class supplies the relations used for alias mapping
   * @param engine engine used to load the related entity classes (with the group's profile)
   */
  public ParseSqlNodeUtils(EntityClassGroup classGroup, EntityClassEngine engine) {
    this.classGroup = classGroup;
    this.classEngine = engine;
    initRelationAsMapping(classGroup.getEntityClass(), engine);
  }

  /**
   * Registers every relation of {@code entityClass} whose target entity class can be loaded.
   *
   * @param entityClass entity class whose relations are registered
   * @param engine engine used to load each relation's target entity class
   */
  private void initRelationAsMapping(IEntityClass entityClass, EntityClassEngine engine) {
    entityClass.relations().forEach(relation -> {
      // Locale.ROOT: must match the lookup in reNameFields regardless of the JVM default locale.
      String key = "_".concat(relation.getName().toLowerCase(Locale.ROOT));
      engine.load(String.valueOf(relation.getEntityClassId()), classGroup.profile())
          .ifPresent(related -> relationAsMapping.put(key, related.code()));
    });
  }

  /**
   * Rewrites, in place, the identifiers reachable from {@code sqlNode} so that they follow the
   * elastic engine's field naming (see {@link #reNameFields}): dots inside a field name are
   * replaced with underscores.
   *
   * <p>Visits a SELECT's FROM clause (unless it is a plain table identifier or table ref),
   * HAVING, ORDER BY, WHERE, select list and GROUP BY. It also visits an ORDER BY wrapper's query
   * and sort keys, and every operand of a basic call. Other node kinds are ignored and
   * {@code null} is a no-op.
   *
   * @param sqlNode root of the (sub)tree to rewrite; may be {@code null}
   */
  public void recursion(SqlNode sqlNode) {
    if (sqlNode instanceof SqlSelect) {
      SqlSelect select = (SqlSelect) sqlNode;
      SqlNode from = select.getFrom();
      // Plain table references keep their names; joins, subqueries and aliased tables are walked.
      if (!(from instanceof SqlIdentifier) && !(from instanceof SqlTableRef)) {
        recursion(from);
      }
      recursion(select.getHaving());
      if (select.hasOrderBy()) {
        select.getOrderList().getList().forEach(order -> recursion(order));
      }
      if (select.hasWhere()) {
        recursion(select.getWhere());
      }
      select.getSelectList().getList().forEach(field -> recursion(field));
      SqlNodeList groupBy = select.getGroup();
      if (groupBy != null) {
        groupBy.getList().forEach(group -> recursion(group));
      }
    } else if (sqlNode instanceof SqlOrderBy) {
      SqlOrderBy orderBy = (SqlOrderBy) sqlNode;
      recursion(orderBy.query);
      // For parsed statements the ORDER BY keys live here rather than in the inner SqlSelect;
      // they must be rewritten too, otherwise they keep referring to the original names.
      if (orderBy.orderList != null) {
        orderBy.orderList.getList().forEach(order -> recursion(order));
      }
    } else if (sqlNode instanceof SqlBasicCall) {
      // Identifiers among the operands are renamed by the SqlIdentifier branch below.
      ((SqlBasicCall) sqlNode).getOperandList().forEach(operand -> recursion(operand));
    } else if (sqlNode instanceof SqlIdentifier) {
      modificationOperand(sqlNode);
    }
  }

  /**
   * Converts a relational expression (e.g. an elastic {@link RelNode} tree) to a single-line SQL
   * string in the MySQL-compatible dialect.
   *
   * @param relNode relational expression to convert
   * @return SQL text with line breaks replaced by spaces
   */
  public static String relNodeToSqlStr(RelNode relNode) {
    // The former HepPlanner pass used an empty program and its result was never used, so it only
    // cost time; the tree is converted directly.
    SqlNode sqlNode = new RelToSqlConverter(MysqlSqlDialectEx.DEFAULT).visitRoot(relNode)
        .asStatement();
    return converterSqlString(sqlNode);
  }

  /**
   * Turns {@code query} into a row-count query, in place: paging (OFFSET/FETCH) is removed and
   * the select list is replaced with {@code COUNT(*) AS 'c'}.
   *
   * @param query SELECT to rewrite; it is mutated
   * @return the same {@code query} instance
   */
  public static SqlSelect buildCountSql(SqlSelect query) {
    // Paging does not apply to a total count.
    query.setFetch(null);
    query.setOffset(null);

    // Build the operand lists independently of the select list; the select list is cleared below.
    SqlNode countStar = new SqlBasicCall(new SqlCountAggFunction("COUNT"),
        ImmutableList.of(SqlIdentifier.star(SqlParserPos.ZERO)), SqlParserPos.ZERO);
    SqlNode aliasName = SqlLiteral.createCharString("c", SqlParserPos.ZERO);
    SqlNode countAs = new SqlBasicCall(new SqlAsOperator(),
        ImmutableList.of(countStar, aliasName), SqlParserPos.ZERO);

    SqlNodeList selectList = query.getSelectList();
    selectList.clear();
    selectList.add(countAs);
    query.setSelectList(selectList);
    return query;
  }

  /**
   * Parses one SQL statement using MySQL lexical rules.
   *
   * @param sql SQL text
   * @return the parsed statement, or {@code null} if the text cannot be parsed (the error is
   *     logged)
   */
  public static SqlNode getSqlNode(String sql) {
    try {
      return SqlParser.create(sql, SqlParser.config().withLex(Lex.MYSQL)).parseStmt();
    } catch (SqlParseException e) {
      // Callers expect null on failure; keep the stack trace in the log for diagnosis.
      log.error("Failed to parse sql: {}", e.getMessage(), e);
      return null;
    }
  }

  /**
   * Prefixes the qualified identifiers in {@code sqlNode}, in place (see
   * {@link #recursion(SqlNode, String)}).
   *
   * @param sqlNode tree to rewrite; may be {@code null}
   * @param prefix text prepended to the column part of each qualified identifier
   * @return the same {@code sqlNode}
   */
  public static SqlNode formatSqlNode(SqlNode sqlNode, String prefix) {
    recursion(sqlNode, prefix);
    return sqlNode;
  }

  /**
   * Builds a Calcite framework config whose default schema is the root schema of the given Calcite
   * connection, using case-insensitive MySQL lexical rules.
   *
   * @param esConnection a JDBC connection that wraps a {@link CalciteConnection}
   * @return the framework config
   * @throws SQLException if the connection cannot be unwrapped to a {@link CalciteConnection}
   */
  public static FrameworkConfig buildFrameworkConfig(Connection esConnection) throws SQLException {
    CalciteConnection calciteConn = esConnection.unwrap(CalciteConnection.class);
    SchemaPlus rootSchema = calciteConn.getRootSchema();
    return Frameworks.newConfigBuilder()
        .defaultSchema(rootSchema)
        .parserConfig(SqlParser.config().withLex(Lex.MYSQL).withCaseSensitive(false))
        .build();
  }

  /**
   * Renames a single identifier according to the elastic engine's naming rules. For a two-part
   * identifier only the second part is considered; otherwise the first part is used.
   *
   * @param operand node to rename; anything other than a {@link SqlIdentifier} is ignored
   */
  private void modificationOperand(SqlNode operand) {
    if (operand instanceof SqlIdentifier) {
      ImmutableList<String> identifierNames = ((SqlIdentifier) operand).names;
      String name = identifierNames.size() == 2 ? identifierNames.get(1) : identifierNames.get(0);
      reNameFields(operand, name);
    }
  }

  /**
   * Replaces {@code operand}'s names with a single name derived from {@code name}. Each node is
   * renamed at most once.
   *
   * <ul>
   *   <li>{@code _relation.field}, where {@code _relation} (lower-cased) is a known relation key:
   *       the leading underscore is dropped, giving {@code relation.field}.
   *   <li>{@code _unknown.field}: kept as is.
   *   <li>any other name containing dots: every dot becomes an underscore.
   * </ul>
   *
   * @param operand the identifier to rename (must be a {@link SqlIdentifier})
   * @param name the name part selected by the caller
   */
  private void reNameFields(SqlNode operand, String name) {
    if (hasModified.containsKey(operand)) {
      return;
    }

    String newName = name;
    if (name.startsWith("_") && name.contains(".")) {
      String relationCode = name.substring(0, name.indexOf('.')).toLowerCase(Locale.ROOT);
      if (relationAsMapping.get(relationCode) != null) {
        newName = name.substring(1);
      }
    } else if (name.contains(".")) {
      newName = name.replace(".", "_");
    }

    List<String> names = new ArrayList<>();
    names.add(newName);
    List<SqlParserPos> poses = new ArrayList<>();
    poses.add(operand.getParserPosition());
    ((SqlIdentifier) operand).setNames(names, poses);
    hasModified.put(operand, true);
  }

  /**
   * Collapses a qualified identifier ({@code a.b...}) into the single name
   * {@code prefix + <second part>}. Single-part identifiers are left unchanged.
   *
   * @param operand the identifier to rewrite (must be a {@link SqlIdentifier})
   * @param prefix text prepended to the second name part
   */
  private static void modificationOperand(SqlNode operand, String prefix) {
    ImmutableList<String> identifierNames = ((SqlIdentifier) operand).names;
    if (identifierNames.size() > 1) {
      List<String> names = new ArrayList<>();
      names.add(prefix.concat(identifierNames.get(1)));
      List<SqlParserPos> poses = new ArrayList<>();
      poses.add(operand.getParserPosition());
      ((SqlIdentifier) operand).setNames(names, poses);
    }
  }

  /**
   * Returns the WHERE clauses gathered by {@link #findSqlJoinConditions(SqlNode)}. This is the
   * live internal list, not a copy.
   *
   * @return collected conditions, in discovery order
   */
  public List<SqlNode> getConditions() {
    return conditions;
  }

  /**
   * Recursively collects the filter (WHERE) conditions of a multi-table join query into
   * {@link #getConditions()}.
   *
   * <p>For each SELECT, its WHERE clause is collected when it is a basic call none of whose
   * direct identifier operands names {@value #SYSTEM_FIELDS}; then its FROM clause is searched.
   * Both sides of joins and all operands of basic calls (e.g. aliased subqueries) are searched.
   *
   * @param sqlNode tree to search; may be {@code null}
   */
  public void findSqlJoinConditions(SqlNode sqlNode) {
    if (sqlNode instanceof SqlSelect) {
      SqlSelect select = (SqlSelect) sqlNode;
      SqlNode where = select.getWhere();
      if (where instanceof SqlBasicCall && !referencesSystemField((SqlBasicCall) where)) {
        conditions.add(where);
      }
      findSqlJoinConditions(select.getFrom());
    } else if (sqlNode instanceof SqlJoin) {
      SqlJoin join = (SqlJoin) sqlNode;
      findSqlJoinConditions(join.getLeft());
      findSqlJoinConditions(join.getRight());
    } else if (sqlNode instanceof SqlBasicCall) {
      ((SqlBasicCall) sqlNode).getOperandList().forEach(operand -> findSqlJoinConditions(operand));
    }
  }

  /**
   * Tells whether any direct identifier operand of {@code call} names the system field (compared
   * case-insensitively, using the second part of a two-part identifier).
   */
  private static boolean referencesSystemField(SqlBasicCall call) {
    for (SqlNode operand : call.getOperandList()) {
      if (operand instanceof SqlIdentifier) {
        SqlIdentifier identifier = (SqlIdentifier) operand;
        String fieldName = identifier.names.size() == 2
            ? identifier.names.get(1)
            : identifier.getSimple();
        if (StringUtils.equalsIgnoreCase(fieldName, SYSTEM_FIELDS)) {
          return true;
        }
      }
    }
    return false;
  }

  /**
   * Recursively prefixes qualified identifiers, in place (see
   * {@link #modificationOperand(SqlNode, String)}).
   *
   * <p>Visits a SELECT's FROM, HAVING, ORDER BY, WHERE, select list and GROUP BY. It also visits
   * an ORDER BY wrapper's query and sort keys, and every operand of a basic call. Other node kinds
   * are ignored and {@code null} is a no-op.
   *
   * @param sqlNode tree to rewrite; may be {@code null}
   * @param prefix text prepended to the column part of each qualified identifier
   */
  private static void recursion(SqlNode sqlNode, String prefix) {
    if (sqlNode instanceof SqlSelect) {
      SqlSelect select = (SqlSelect) sqlNode;
      recursion(select.getFrom(), prefix);
      recursion(select.getHaving(), prefix);
      if (select.hasOrderBy()) {
        select.getOrderList().getList().forEach(order -> recursion(order, prefix));
      }
      if (select.hasWhere()) {
        recursion(select.getWhere(), prefix);
      }
      select.getSelectList().getList().forEach(field -> recursion(field, prefix));
      SqlNodeList groupBy = select.getGroup();
      if (groupBy != null) {
        groupBy.getList().forEach(group -> recursion(group, prefix));
      }
    } else if (sqlNode instanceof SqlOrderBy) {
      SqlOrderBy orderBy = (SqlOrderBy) sqlNode;
      recursion(orderBy.query, prefix);
      // Parsed ORDER BY keys are held by the wrapper; prefix them like the select list.
      if (orderBy.orderList != null) {
        orderBy.orderList.getList().forEach(order -> recursion(order, prefix));
      }
    } else if (sqlNode instanceof SqlBasicCall) {
      // Identifiers among the operands are handled by the SqlIdentifier branch below.
      ((SqlBasicCall) sqlNode).getOperandList().forEach(operand -> recursion(operand, prefix));
    } else if (sqlNode instanceof SqlIdentifier) {
      modificationOperand(sqlNode, prefix);
    }
  }

  /**
   * Renders {@code sqlNode} as single-line SQL in the MySQL-compatible dialect.
   *
   * @param sqlNode node to render
   * @return SQL text with line breaks replaced by spaces
   */
  public static String converterSqlString(SqlNode sqlNode) {
    return Util.toLinux(sqlNode.toSqlString(MysqlSqlDialectEx.DEFAULT).getSql())
        .replaceAll("\n", " ");
  }

  /**
   * Builds a {@link SelectConfig} from the pieces of a pushed-down query.
   *
   * @param profile tenant/profile code
   * @param type row type of the query result
   * @param ops filter expressions
   * @param fields result fields with their struct kind and Java class
   * @param projects projected expressions with their names
   * @param sort sort keys and directions
   * @param offset rows to skip; {@code null} means 0
   * @param fetch maximum rows to return; {@code null} means 20
   * @param groupBy group-by field names
   * @param aggs aggregate calls
   * @param hints relational hints
   * @param rawTree original relational tree
   * @param dataContext Calcite data context
   * @return the populated config, with an empty mutable context map
   */
  public static SelectConfig getSelectConfig(String profile, RelDataType type, List<RexNode> ops,
      List<Entry<String, Tuple2<StructKind, Class>>> fields,
      List<Pair<RexNode, String>> projects, List<Entry<String, Direction>> sort, Long offset, Long fetch, List<String> groupBy,
      List<AggregateCall> aggs, List<RelHint> hints, RelNode rawTree, DataContext dataContext) {
    SelectConfig selectConfig = new SelectConfig();
    selectConfig.setRexNodes(ops);
    selectConfig.setFields(fields);
    selectConfig.setRelDataType(type);
    selectConfig.setSorts(sort);
    selectConfig.setOffset(Optional.ofNullable(offset).map(Long::intValue).orElse(0));
    selectConfig.setFetch(Optional.ofNullable(fetch).map(Long::intValue).orElse(20));
    selectConfig.setProjects(projects);
    selectConfig.setAggs(aggs);
    selectConfig.setGroupBy(groupBy);
    selectConfig.setHints(hints);
    selectConfig.setContext(new HashMap<>());
    selectConfig.setRawTree(rawTree);
    selectConfig.setProfile(profile);
    selectConfig.setDataContext(dataContext);
    return selectConfig;
  }
}
