package com.xforceplus.janus.framework.record.interceptor;

import com.xforceplus.janus.framework.record.domain.AccessContentDto;
import com.xforceplus.janus.framework.record.domain.AccessRecord;
import com.xforceplus.janus.framework.record.cache.AccessRecordCache;
import com.xforceplus.janus.framework.util.IPUtils;

import org.apache.commons.lang3.time.DateFormatUtils;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.HandlerInterceptor;

import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import lombok.extern.slf4j.Slf4j;

/**
 * 请求记录拦击记录履历
 *
 * @Author: xuchuanhou
 * @Date:2022/2/19下午3:46
 */
@Slf4j
@Component
@Order(1)
public class RequestInterceptor implements HandlerInterceptor {

    private static final String CONTENT_TYPE_MULTIPART = "multipart/form-data";

    private AccessRecordCache accessRecordCache;


    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        request.setAttribute("startTime", System.currentTimeMillis());
        return true;
    }

    private static Set<String> excludeheaderKey = new HashSet<String>() {{
        add("Cache-Control");
    }};

    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
        long startTime = (long) request.getAttribute("startTime");
        long cost = System.currentTimeMillis() - startTime;
        AccessRecord.AccessRecordBuilder builder = AccessRecord.builder()
                .costTime(cost)
                .requestTime(DateFormatUtils.format(startTime, "yyyyMMddHHmmssSSS"))
                .action(request.getRequestURI())
                .requestMethod(request.getMethod()).status(response.getStatus())
                .sourceIp(IPUtils.getIpAddr(request));

        AccessContentDto accessContentDto = new AccessContentDto();

        if (response instanceof CustomHttpServletResponseWrapper) {
            String responseBody = new String(((CustomHttpServletResponseWrapper) response).getBytes());
            accessContentDto.setResponseBody(responseBody);
            builder.reqDataLen(responseBody.getBytes(StandardCharsets.UTF_8).length);
        }

        Map<String, String> headerMap = new HashMap<>();
        Enumeration<String> headerEnu = request.getHeaderNames();
        while (headerEnu.hasMoreElements()) {
            String key = headerEnu.nextElement();
            if (excludeheaderKey.contains(key)) {
                continue;
            }
            headerMap.put(key, request.getHeader(key));
        }

        Map<String, String> paramMap = new HashMap<>();

        if (request.getContentType() != null && request.getContentType().contains(CONTENT_TYPE_MULTIPART)) {
            Enumeration<String> paramEnu = request.getParameterNames();
            while (paramEnu.hasMoreElements()) {
                String key = paramEnu.nextElement();
                paramMap.put(key, request.getParameter(key));
            }
        }

        accessContentDto.setRequestHeader(headerMap);
        accessContentDto.setRequestParam(paramMap);

        if (request.getInputStream() != null && !request.getInputStream().isFinished()) {
            StringBuilder reqBody = getRequestBody(request);
            if (reqBody != null && reqBody.length() > 0) {
                String reqBodyStr = reqBody.toString();
                accessContentDto.setRequestBody(reqBodyStr);
                builder.reqDataLen(reqBodyStr.getBytes(StandardCharsets.UTF_8).length);
            }
        }

        AccessRecord accessRecord = builder.build();
        accessRecord.setAccessContent(accessContentDto);

        try {
//            log.info("{} cost:{},method:{} headers:{}", accessRecord.getAction(), accessRecord.getCostTime(), accessRecord.getRequestMethod());
            AccessRecordCache.pushRecord(accessRecord);
        } catch (Exception exception) {
            log.error("record error:{}", accessRecord.toString());
        }
    }

    private StringBuilder getRequestBody(HttpServletRequest request) {
        StringBuilder requestBodySB = new StringBuilder();
        try (BufferedReader br = request.getReader()) {
            String line = null;
            while ((line = br.readLine()) != null) {
                requestBodySB.append(line);
            }
            br.close();
            return requestBodySB;
        } catch (IOException ex) {
        }

        return requestBodySB;
    }

}
