package com.xforceplus.ultraman.sdk.infra.base.id.node;

import com.xforceplus.ultraman.sdk.infra.base.timerwheel.ITimerWheel;
import com.xforceplus.ultraman.sdk.infra.base.timerwheel.TimeoutNotification;
import com.xforceplus.ultraman.sdk.infra.base.timerwheel.TimerWheel;
import com.xforceplus.ultraman.sdk.infra.lifecycle.SimpleLifecycle;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.sql.DataSource;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.StringJoiner;
import java.util.concurrent.ThreadLocalRandom;

/**
 * A node-number generator backed by MySQL.<br>
 * It relies on the following table definition.
 * <pre>
 *  CREATE TABLE nodeid (
 *   `id`       int   NOT NULL COMMENT '从0开始的编号',
 *   `heartbeat` bigint NOT NULL DEFAULT 0 COMMENT '最后心跳时间,毫秒值',
 *   PRIMARY KEY (`id`)
 * ) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='结点编号记录';
 * </pre>
 * (Column comments: {@code id} is "the number, starting from 0"; {@code heartbeat} is "last heartbeat
 * time, in milliseconds"; the table is "node number records".)
 * <p>
 * On initialization all {@code maxId + 1} rows are loaded. Rows whose {@code heartbeat} is older than the
 * configured heartbeat interval are candidates. A candidate is claimed with a compare-and-set style
 * {@code UPDATE ... SET heartbeat=? WHERE id=? AND heartbeat=?}. An update count of 1 means the claim
 * succeeded; otherwise the candidate is discarded and another is tried until one succeeds or none remain.
 * After a successful claim the heartbeat is refreshed periodically through a timer wheel.
 * <p>
 * NOTE(review): the DDL comment says ids start at 0, while {@link #next()} subtracts 1 on the assumption
 * that stored ids start at 1. Confirm which numbering the table is actually populated with.
 * <p>
 * NOTE(review): the heartbeat is refreshed every {@code heartbeatIntervalMs}. That is also the age after
 * which other instances treat the node as abandoned, so there is no margin for timer or database latency.
 * Consider a refresh period shorter than the timeout.
 * <p>
 * The table name is interpolated into SQL text and must come from trusted configuration.
 *
 * @author dongbin
 * @version 0.1 2022/12/9 17:37
 * @since 1.8
 */
public class MysqlNodeIdGenerator implements NodeIdGenerator, SimpleLifecycle {

    private static final Logger logger = LoggerFactory.getLogger(MysqlNodeIdGenerator.class);

    /*
     * Default heartbeat interval (also used as the timeout threshold), in milliseconds: one hour.
     */
    private static final long DEFAULT_HEARTBEAT_INTERVAL_MS = 1000 * 60 * 60;

    /*
     * Default table name.
     */
    private static final String DEFAULT_TABLE_NAME = "nodeid";

    /*
     * Default maximum id; maxId + 1 rows are expected in the table.
     */
    private static final int DEFAULT_MAX_ID = 1023;

    private static final String ID_FIELD = "id";
    private static final String HEARTBEAT_FIELD = "heartbeat";

    private DataSource ds;

    private ITimerWheel<NodeId> heartbeatWheel;

    private String tableName;

    // Heartbeat interval in milliseconds (default one hour); also the timeout threshold used by isTimeout.
    private long heartbeatIntervalMs;

    private int maxId = DEFAULT_MAX_ID;

    // The node claimed by init(); null until init() succeeds.
    private NodeId nodeId;

    public DataSource getDs() {
        return ds;
    }

    public String getTableName() {
        return tableName;
    }

    public long getHeartbeatIntervalMs() {
        return heartbeatIntervalMs;
    }

    public int getMaxId() {
        return maxId;
    }

    /**
     * Loads all node rows, claims one that has timed out and starts the periodic heartbeat.
     *
     * @throws IllegalStateException if the table does not hold exactly maxId + 1 rows or no node can be claimed.
     * @throws Exception on database errors.
     */
    @Override
    public void init() throws Exception {
        NodeId[] nodeIds = initNodeInfo();
        NodeId found = findNodeInfo(nodeIds);

        if (found == null) {
            throw new IllegalStateException("A valid node number could not be found.");
        }

        this.nodeId = found;

        logger.info("Initialization successful, current node number is {}.", found.getId());

        heartbeatWheel = new TimerWheel(new HeartbeatTimeoutNotification());
        heartbeatWheel.add(this.nodeId, this.heartbeatIntervalMs);
    }

    /**
     * Stops the heartbeat timer, if it was started.
     */
    @Override
    public void destroy() throws Exception {
        // init() may have failed (or never run) before the wheel was created.
        if (heartbeatWheel != null) {
            heartbeatWheel.destroy();
        }
    }

    /**
     * Returns the claimed node number.
     *
     * @return the stored id minus one.
     * @throws IllegalStateException if no node has been claimed yet.
     */
    @Override
    public Integer next() {
        if (this.nodeId == null) {
            throw new IllegalStateException("The node number generator has not yet generated the number.");
        }

        /*
         * The stored minimum is assumed to be 1, but the exposed range must start at 0.
         * NOTE(review): this contradicts the DDL comment ("starts from 0"); verify the stored numbering.
         */
        return this.nodeId.getId() - 1;
    }

    // Loads node rows; fails unless exactly maxId + 1 rows are available.
    private NodeId[] initNodeInfo() throws SQLException {
        int size = maxId + 1;
        String sql = String.format(
            "SELECT %s, %s FROM %s LIMIT 0, %d", ID_FIELD, HEARTBEAT_FIELD, tableName, size);
        List<NodeId> nodeIds = new ArrayList<>(size);
        try (Connection conn = ds.getConnection();
             Statement st = conn.createStatement();
             ResultSet rs = st.executeQuery(sql)) {
            while (rs.next()) {
                nodeIds.add(new NodeId(rs.getInt(ID_FIELD), rs.getLong(HEARTBEAT_FIELD)));
            }
        }

        if (nodeIds.size() != size) {
            throw new IllegalStateException(
                String.format("The expected number of available ids was %d, but it is now %d.", size, nodeIds.size()));
        }

        return nodeIds.toArray(new NodeId[0]);
    }

    /*
     * Randomly tries candidates, discarding each one that cannot be claimed, until one is claimed.
     * Returns null once every candidate has been tried and rejected.
     */
    private NodeId findNodeInfo(NodeId[] nodeIds) throws SQLException {
        ThreadLocalRandom random = ThreadLocalRandom.current();
        NodeId[] candidates = nodeIds;
        while (candidates.length > 0) {
            // Pick from the remaining candidates, not the original array, so discards accumulate.
            NodeId candidate = candidates[random.nextInt(candidates.length)];
            if (takeUp(candidate, false)) {
                return candidate;
            }

            candidates = eliminate(candidates, candidate);

            if (logger.isInfoEnabled()) {
                logger.info(
                    "The use of the node information {} was abandoned because the occupation failed.", candidate);
            }
        }
        return null;
    }

    /*
     * Tries to claim (or refresh) a node. Returns true on success, false on failure.
     * When force is false, a node that has not timed out is rejected without touching the database.
     * When force is true, that local check is skipped. The UPDATE is still conditional on the last
     * known heartbeat, so it fails if another instance changed the row.
     * The in-memory heartbeat is advanced only when the database update succeeded, so it keeps
     * mirroring the stored value.
     */
    private boolean takeUp(NodeId selectNodeId, boolean force) throws SQLException {
        if (!force && !isTimeout(selectNodeId)) {
            return false;
        }

        long newHeartbeatTimeMs = System.currentTimeMillis();
        boolean ok;
        try (Connection conn = ds.getConnection();
             // UPDATE table SET heartbeat=? WHERE id=? AND heartbeat=?
             PreparedStatement ps = conn.prepareStatement(
                 String.format(
                     "UPDATE %s SET %s=? WHERE %s=? AND %s=?",
                     tableName, HEARTBEAT_FIELD, ID_FIELD, HEARTBEAT_FIELD))) {
            ps.setLong(1, newHeartbeatTimeMs);
            ps.setInt(2, selectNodeId.getId());
            ps.setLong(3, selectNodeId.getHeartbeat());

            // Success means exactly one row was updated.
            ok = ps.executeUpdate() == 1;
        }

        if (ok) {
            selectNodeId.setHeartbeat(newHeartbeatTimeMs);
        }

        return ok;
    }

    /*
     * True when the heartbeat is older than heartbeatIntervalMs.
     * A heartbeat in the future (or equal to now) is treated as not timed out. It may have been edited
     * by hand, and stealing such a node would be unsafe.
     */
    private boolean isTimeout(NodeId selectNodeId) {
        long dur = System.currentTimeMillis() - selectNodeId.getHeartbeat();

        if (dur <= 0) {
            return false;
        }

        return dur > heartbeatIntervalMs;
    }

    // Returns a new array holding every node except the target (compared by id).
    private NodeId[] eliminate(NodeId[] nodeIds, NodeId targetInfo) {
        List<NodeId> surviveNodeIds = new ArrayList<>(Math.max(0, nodeIds.length - 1));

        for (NodeId candidate : nodeIds) {
            if (!candidate.equals(targetInfo)) {
                surviveNodeIds.add(candidate);
            }
        }

        return surviveNodeIds.toArray(new NodeId[0]);
    }

    // A node row: its id and last known heartbeat (ms). Equality is by id only.
    private static class NodeId {
        private final int id;
        private long heartbeat;

        public NodeId(int id, long heartbeat) {
            this.id = id;
            this.heartbeat = heartbeat;
        }

        public int getId() {
            return id;
        }

        public long getHeartbeat() {
            return heartbeat;
        }

        public void setHeartbeat(long heartbeat) {
            this.heartbeat = heartbeat;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof NodeId)) {
                return false;
            }
            NodeId other = (NodeId) o;
            return getId() == other.getId();
        }

        @Override
        public int hashCode() {
            return Objects.hash(getId());
        }

        @Override
        public String toString() {
            return new StringJoiner(", ", NodeId.class.getSimpleName() + "[", "]")
                .add("heartbeat=" + heartbeat)
                .add("id=" + id)
                .toString();
        }
    }

    /**
     * Instance builder. {@link #build()} also runs {@link MysqlNodeIdGenerator#init()}.
     */
    public static final class Builder {
        private DataSource ds;
        private String tableName = DEFAULT_TABLE_NAME;
        private long heartbeatIntervalMs = DEFAULT_HEARTBEAT_INTERVAL_MS;
        private int maxId = DEFAULT_MAX_ID;

        private Builder() {}

        public static Builder anMysqlNodeIdGenerator() {
            return new Builder();
        }

        public Builder withDataSource(DataSource ds) {
            this.ds = ds;
            return this;
        }

        public Builder withTableName(String tableName) {
            this.tableName = tableName;
            return this;
        }

        public Builder withHeartbeatIntervalMs(long heartbeatIntervalMs) {
            this.heartbeatIntervalMs = heartbeatIntervalMs;
            return this;
        }

        public Builder withMaxId(int maxId) {
            this.maxId = maxId;
            return this;
        }

        /**
         * Validates the configuration, builds the instance and initializes it (claims a node).
         *
         * @throws NullPointerException if the data source or table name is null.
         * @throws IllegalArgumentException if maxId is negative or the heartbeat interval is not positive.
         * @throws Exception if initialization fails.
         */
        public MysqlNodeIdGenerator build() throws Exception {
            Objects.requireNonNull(this.ds, "ds");
            Objects.requireNonNull(this.tableName, "tableName");
            if (this.maxId < 0 || this.maxId == Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Invalid maxId: " + this.maxId);
            }
            // A non-positive interval would make every node look timed out, so live nodes would be stolen.
            if (this.heartbeatIntervalMs <= 0) {
                throw new IllegalArgumentException("Invalid heartbeatIntervalMs: " + this.heartbeatIntervalMs);
            }

            MysqlNodeIdGenerator mysqlNodeIdGenerator = new MysqlNodeIdGenerator();
            mysqlNodeIdGenerator.heartbeatIntervalMs = this.heartbeatIntervalMs;
            mysqlNodeIdGenerator.tableName = this.tableName;
            mysqlNodeIdGenerator.ds = this.ds;
            mysqlNodeIdGenerator.maxId = this.maxId;
            mysqlNodeIdGenerator.init();
            return mysqlNodeIdGenerator;
        }
    }

    /*
     * Heartbeat handler, invoked by the timer wheel when the heartbeat is due.
     * It refreshes the claimed node's heartbeat and logs when the refresh fails.
     */
    private class HeartbeatTimeoutNotification implements TimeoutNotification<NodeId> {

        @Override
        public long notice(NodeId nodeId) {
            try {
                if (!takeUp(nodeId, true)) {
                    logger.warn(
                        "Heartbeat update failed for node {}; the node number may have been taken over by another instance.",
                        nodeId);
                }
            } catch (SQLException e) {
                logger.error(e.getMessage(), e);
            }

            // Presumably the delay until the next notification (same value init() passes to add()) — confirm.
            return heartbeatIntervalMs;
        }
    }
}
