package com.xforceplus.xlog.mybatis;

import com.alibaba.fastjson.JSON;
import com.xforceplus.xlog.core.model.LogContext;
import com.xforceplus.xlog.core.model.impl.MyBatisLogEvent;
import com.xforceplus.xlog.core.utils.ExceptionUtil;
import com.xforceplus.xlog.logsender.model.LogSender;
import com.xforceplus.xlog.mybatis.sqlpretty.SqlPrettyUtil;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import java.nio.charset.StandardCharsets;
import java.util.Properties;

/**
 * MyBatis plugin that records a log event for every intercepted SQL execution.
 * <p>
 * Reference: https://mybatis.org/mybatis-3/configuration.html#plugins
 *
 * @author gulei
 * @date 2023/01/19
 */
@Intercepts({
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})
})
public class XlogMyBatisExecutorInterceptor implements Interceptor {
    /** Store name written into every emitted event. */
    private final String storeName;
    /** Sender that emits the collected events. */
    private final LogSender logSender;
    /** Whether the prettified SQL statement (and its UTF-8 byte size) is recorded. */
    private final boolean sqlEnabled;
    /** Whether the execution result is serialized to JSON (and its UTF-8 byte size) recorded. */
    private final boolean sqlResultEnabled;

    /**
     * Creates the interceptor.
     *
     * @param storeName        name of the log store
     * @param logSender        sender used to emit log events
     * @param sqlEnabled       whether to record the SQL statement
     * @param sqlResultEnabled whether to record the SQL execution result
     */
    public XlogMyBatisExecutorInterceptor(
            final String storeName,
            final LogSender logSender,
            final boolean sqlEnabled,
            final boolean sqlResultEnabled
    ) {
        this.storeName = storeName;
        this.logSender = logSender;
        this.sqlEnabled = sqlEnabled;
        this.sqlResultEnabled = sqlResultEnabled;
    }

    /**
     * Intercepts an {@code Executor.query}/{@code Executor.update} call, collects instrumentation
     * data around it and sends one log event per call.
     *
     * @param invocation the intercepted invocation
     * @return the result of the underlying invocation
     * @throws Throwable whatever the underlying invocation throws; it is always rethrown unchanged,
     *                   with any failure from sending the log event attached as a suppressed exception
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        final MyBatisLogEvent event = new MyBatisLogEvent();
        event.setStoreName(this.storeName);
        event.setTraceId(LogContext.getTraceId());
        event.setParentTraceId(LogContext.getParentTraceId());

        // Collect instrumentation data (before execution)
        this.beforeExecute(event, invocation);

        final Object result;
        try {
            result = invocation.proceed();
        } catch (Throwable throwable) {
            event.setThrowable(throwable);

            // Sending the event must not replace the original execution failure.
            this.sendPreservingFailure(event, throwable);

            throw throwable;
        }

        // Collect instrumentation data (after execution)
        this.afterExecute(event, invocation, result);

        // Emit the log event
        this.logSender.send(event);

        return result;
    }

    /**
     * Sends the event on the failure path. If sending throws, the send failure is attached to
     * {@code primary} as a suppressed exception instead of masking it.
     *
     * @param event   the event to send
     * @param primary the exception thrown by the intercepted invocation
     */
    private void sendPreservingFailure(final MyBatisLogEvent event, final Throwable primary) {
        try {
            this.logSender.send(event);
        } catch (Throwable sendFailure) {
            // Guard against self-suppression, which Throwable.addSuppressed rejects.
            if (sendFailure != primary) {
                primary.addSuppressed(sendFailure);
            }
        }
    }

    /**
     * Wraps MyBatis {@link Executor} instances with this interceptor; other targets are returned as-is.
     *
     * @param o the target object
     * @return a proxy for executors, otherwise the original object
     */
    @Override
    public Object plugin(Object o) {
        if (o instanceof Executor) {
            return Plugin.wrap(o, this);
        } else {
            return o;
        }
    }

    /**
     * No-op: this plugin is configured entirely through its constructor.
     *
     * @param properties plugin properties (ignored)
     */
    @Override
    public void setProperties(Properties properties) {
        // Intentionally empty; all settings are supplied via the constructor.
    }

    /**
     * Collects instrumentation data before execution: statement id, resource file, command type,
     * and (if enabled) the prettified SQL and its UTF-8 byte size.
     * <p>
     * Any failure is recorded in the event message instead of being propagated, so collection
     * problems never break the SQL call itself.
     */
    private void beforeExecute(final MyBatisLogEvent event, final Invocation invocation) {
        try {
            final Object[] args = invocation.getArgs();
            final MappedStatement mappedStatement = (MappedStatement) args[0];
            final Object parameterObject = args[1];

            event.setName(mappedStatement.getId());
            event.setResourceFile(mappedStatement.getResource());
            event.setSqlCommandType(mappedStatement.getSqlCommandType().name());

            if (this.sqlEnabled) {
                event.setSql(SqlPrettyUtil.prettify(mappedStatement, parameterObject));
                // Read back through the getter in case the setter normalizes the value.
                final String sql = event.getSql();
                if (sql != null) {
                    event.setSqlSize(sql.getBytes(StandardCharsets.UTF_8).length);
                }
            }
        } catch (Throwable throwable) {
            event.setMessage("(前)MyBatisExecutor埋点数据收集发生异常: " + ExceptionUtil.toDesc(throwable));
        }
    }

    /**
     * Collects instrumentation data after a successful execution: if enabled, the JSON-serialized
     * result and its UTF-8 byte size.
     * <p>
     * Any failure is recorded in the event message instead of being propagated.
     */
    private void afterExecute(MyBatisLogEvent event, Invocation invocation, Object result) {
        try {
            if (this.sqlResultEnabled) {
                event.setSqlResult(JSON.toJSONString(result));
                // Read back through the getter in case the setter normalizes the value.
                final String sqlResult = event.getSqlResult();
                if (sqlResult != null) {
                    event.setSqlResultSize(sqlResult.getBytes(StandardCharsets.UTF_8).length);
                }
            }
        } catch (Throwable throwable) {
            event.setMessage("(后)MyBatisExecutor埋点数据收集发生异常: " + ExceptionUtil.toDesc(throwable));
        }
    }
}
