package com.xforceplus.ultraman.test.utils;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.sql.CallableStatement;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import javax.sql.DataSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Test helper that executes SQL script resources against a database and can truncate every table
 * reported by {@code SHOW TABLES}.
 *
 * <p>Created by justin.xu on 10/2021.
 *
 * @since 1.8
 */
public class SqlInitUtils {
    private static final Logger LOGGER = LoggerFactory.getLogger(SqlInitUtils.class);

    private static final String DELIMITER_COMMAND = "DELIMITER";
    private static final String PROCEDURE_COMMAND = "CALL";

    /** Statement terminator in effect at the start of every script file. */
    private static final String DEFAULT_LINE_END = ";";
    /** Prefix of a full-line SQL comment; such lines are skipped. */
    private static final String COMMENT_PREFIX = "--";

    static {
        // Loading the Connector/J driver class registers it with DriverManager, which the
        // property-based overloads use. A missing driver fails class initialization.
        try {
            Class.forName("com.mysql.cj.jdbc.Driver");
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /**
     * Reads all SQL statements from a classpath resource.
     *
     * <p>If the resource is a directory, every file in it named {@code <name>.sql} is read, in
     * file-name order. If the resource is a single file, only that file is read. A resource that
     * cannot be found yields an empty list.
     *
     * @param resource classpath resource path, resolved via {@link Class#getResource(String)}.
     * @return the statements in execution order, with their terminators removed.
     * @throws IOException if the resource is not on the local file system or cannot be read.
     */
    private static List<String> readSqls(String resource) throws IOException {
        URL url = SqlInitUtils.class.getResource(resource);
        if (url == null) {
            return Collections.emptyList();
        }

        List<String> sqls = new ArrayList<>();
        for (File file : listSqlFiles(toFile(url))) {
            LOGGER.info("Reader sql file: {}", file.getAbsolutePath());
            sqls.addAll(parseSqlFile(file));
        }
        return sqls;
    }

    // Converts a "file:" resource URL into a File, decoding escapes such as %20 via the URI form.
    private static File toFile(URL url) throws IOException {
        if (!"file".equalsIgnoreCase(url.getProtocol())) {
            throw new IOException("Only file-system resources are supported, got: " + url);
        }
        try {
            return new File(url.toURI());
        } catch (URISyntaxException | IllegalArgumentException e) {
            throw new IOException("Cannot resolve resource location: " + url, e);
        }
    }

    // Returns the file itself, or the directory's SQL files sorted by name so that execution
    // order is deterministic (File.listFiles guarantees no order).
    private static List<File> listSqlFiles(File path) throws IOException {
        if (path.isFile()) {
            return Collections.singletonList(path);
        }
        File[] files = path.listFiles((dir, name) -> isSqlFileName(dir, name));
        if (files == null) {
            throw new IOException("Cannot list directory: " + path.getAbsolutePath());
        }
        Arrays.sort(files, Comparator.comparing(File::getName));
        return Arrays.asList(files);
    }

    // Accepts regular files whose name is exactly "<name>.sql" (a single '.'-separated suffix).
    private static boolean isSqlFileName(File dir, String name) {
        String[] names = name.split("\\.");
        return names.length == 2 && names[1].equals("sql") && new File(dir, name).isFile();
    }

    /**
     * Splits one UTF-8 SQL script into statements.
     *
     * <p>Blank lines and lines whose first non-blank characters are {@code --} are skipped. A line
     * of the form {@code DELIMITER <token>} switches the terminator for the rest of this file; it
     * starts as {@code ;}. Lines are joined with a newline until one ends with the current
     * terminator, which is then removed. Text left over at end of file is returned as a final
     * statement. Empty statements are dropped.
     */
    private static List<String> parseSqlFile(File file) throws IOException {
        List<String> sqls = new ArrayList<>();
        String lineEnd = DEFAULT_LINE_END;
        StringBuilder buff = new StringBuilder();
        try (BufferedReader in = new BufferedReader(
            new InputStreamReader(new FileInputStream(file), StandardCharsets.UTF_8))) {
            String rawLine;
            while ((rawLine = in.readLine()) != null) {
                // Trailing whitespace would hide the terminator from the endsWith check below.
                String line = trimTrailing(rawLine);
                String trimmed = line.trim();
                if (trimmed.isEmpty() || trimmed.startsWith(COMMENT_PREFIX)) {
                    continue;
                }
                if (isDelimter(trimmed)) {
                    lineEnd = parseEndString(trimmed);
                    continue;
                }
                if (buff.length() > 0) {
                    // Keep a separator so tokens on adjacent lines are not glued together.
                    buff.append('\n');
                }
                buff.append(line);
                if (line.endsWith(lineEnd)) {
                    buff.setLength(buff.length() - lineEnd.length());
                    addStatement(sqls, buff.toString());
                    buff.setLength(0);
                }
            }
        }
        // A final statement without a terminator is still executed rather than silently lost.
        addStatement(sqls, buff.toString());
        return sqls;
    }

    // Adds a non-blank statement to the list and logs it.
    private static void addStatement(List<String> sqls, String sql) {
        if (sql.trim().isEmpty()) {
            return;
        }
        sqls.add(sql);
        LOGGER.info(sql);
    }

    // Removes trailing characters <= ' ' (the same set String.trim removes).
    private static String trimTrailing(String text) {
        int end = text.length();
        while (end > 0 && text.charAt(end - 1) <= ' ') {
            end--;
        }
        return text.substring(0, end);
    }

    /**
     * Executes the SQL statements found in a resource using a connection from the data source.
     *
     * @param resource   classpath resource (a directory of {@code .sql} files or a single file).
     * @param dataSource data source to execute against; the connection is closed afterwards.
     * @throws Exception if the resource cannot be read or a statement fails.
     */
    public static void execute(String resource, DataSource dataSource) throws Exception {
        try (Connection conn = dataSource.getConnection()) {
            doExecute(resource, conn);
        }
    }

    /**
     * Executes the SQL statements found in a resource using a JDBC URL from a system property.
     *
     * @param resource     classpath resource (a directory of {@code .sql} files or a single file).
     * @param propertyName name of the system property whose value is the JDBC URL.
     * @throws Exception if the property is unset, the resource cannot be read or a statement fails.
     */
    public static void execute(String resource, String propertyName) throws Exception {
        try (Connection conn = openConnection(propertyName)) {
            doExecute(resource, conn);
        }
    }

    // Opens a DriverManager connection to the URL stored in the given system property.
    private static Connection openConnection(String propertyName) throws SQLException {
        String url = System.getProperty(propertyName);
        if (url == null) {
            // The URL itself is not logged or reported: it may carry credentials.
            throw new SQLException("System property '" + propertyName + "' (JDBC url) is not set.");
        }
        return DriverManager.getConnection(url);
    }

    // Runs each statement in order: CALLs via CallableStatement, everything else via one Statement.
    private static void doExecute(String resource, Connection conn) throws Exception {
        List<String> sqls = readSqls(resource);

        try (Statement st = conn.createStatement()) {
            for (String sql : sqls) {
                try {
                    if (isCall(sql)) {
                        try (CallableStatement callSt = conn.prepareCall(sql)) {
                            // execute() accepts procedures with or without a result set.
                            callSt.execute();
                        }
                    } else {
                        st.execute(sql);
                    }
                } catch (SQLException e) {
                    LOGGER.error("Failed to execute sql: {}", sql);
                    throw e;
                }
            }
        }
    }

    // Whether the statement is a stored-procedure call (starts with the CALL keyword, any case).
    private static boolean isCall(String sql) {
        return startsWithKeyword(sql, PROCEDURE_COMMAND);
    }

    // Whether the line is a "DELIMITER <token>" command with a non-empty token.
    private static boolean isDelimter(String command) {
        String trimmed = command.trim();
        return startsWithKeyword(trimmed, DELIMITER_COMMAND)
            && trimmed.length() > DELIMITER_COMMAND.length();
    }

    // Extracts the token from a "DELIMITER <token>" command.
    private static String parseEndString(String delimterCommand) {
        return delimterCommand.trim().substring(DELIMITER_COMMAND.length()).trim();
    }

    // Case-insensitive check that the trimmed text starts with the whole-word keyword (upper case).
    private static boolean startsWithKeyword(String text, String keyword) {
        String trimmed = text.trim();
        if (trimmed.length() < keyword.length()
            || !trimmed.substring(0, keyword.length()).toUpperCase(Locale.ROOT).equals(keyword)) {
            return false;
        }
        return trimmed.length() == keyword.length()
            || Character.isWhitespace(trimmed.charAt(keyword.length()));
    }

    /**
     * Truncates every table returned by {@code SHOW TABLES}.
     *
     * @param ds data source to clean; the connection is closed afterwards.
     * @throws SQLException if listing or truncating fails.
     */
    public static void clean(DataSource ds) throws SQLException {
        try (Connection conn = ds.getConnection()) {
            doClean(conn);
        }
    }

    /**
     * Truncates every table returned by {@code SHOW TABLES}.
     *
     * @param propertyName name of the system property whose value is the JDBC URL.
     * @throws SQLException if the property is unset, or listing or truncating fails.
     */
    public static void clean(String propertyName) throws SQLException {
        try (Connection conn = openConnection(propertyName)) {
            doClean(conn);
        }
    }

    // Collects table names first, then truncates them, reusing one Statement.
    private static void doClean(Connection connection) throws SQLException {
        try (Statement st = connection.createStatement()) {
            List<String> tables = new ArrayList<>();
            try (ResultSet rs = st.executeQuery("SHOW TABLES")) {
                while (rs.next()) {
                    tables.add(rs.getString(1));
                }
            }

            for (String table : tables) {
                st.executeUpdate("TRUNCATE TABLE " + quoteIdentifier(table));
            }
        }
    }

    // Back-quotes an identifier, doubling embedded back-quotes as MySQL requires.
    private static String quoteIdentifier(String name) {
        return "`" + name.replace("`", "``") + "`";
    }
}
