package com.xforceplus.pscc.common.intercept;

import cn.hutool.core.collection.CollUtil;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import javax.annotation.PostConstruct;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpHeaders;
import org.springframework.stereotype.Component;

/**
 * Binds a trace id to every request and logs JSON request/response payloads.
 *
 * <p>Requests whose path (without its leading "/") starts with one of the exclude patterns are
 * skipped: they get a trace id but are not logged. The built-in patterns are "health",
 * "data-om" and "bos", plus whatever {@link LogFilterConfig#getExcludePath()} supplies. The
 * original author described these as certain internal ("Ultraman") endpoints.
 *
 * <p>NOTE(review): this class is both a Spring {@code @Component} and a {@code @WebFilter}. If
 * {@code @ServletComponentScan} is also enabled, the filter may be registered twice, causing
 * double logging and double wrapping. Confirm only one registration path is active.
 *
 * @author nathan
 * @date 20220714
 */
@Component
@WebFilter(filterName = "logFilter", urlPatterns = "/*")
@Slf4j
public class LogFilter implements Filter {

    /**
     * Media type that enables body logging. It is matched as a case-insensitive substring of the
     * request/response Content-Type.
     */
    private static final String CONTENT_TYPE = "application/json";

    /**
     * Path prefixes, without a leading "/", that are exempt from request/response logging.
     *
     * <p>Static and mutable: {@link #beanPost()} appends the configured patterns during bean
     * initialization; afterwards the set is only read.
     */
    private static final Set<String> EXCLUDE_PATTERNS = CollUtil.newHashSet("health", "data-om", "bos");

    @Autowired
    private LogFilterConfig logFilterConfig;

    /**
     * Appends the configured exclude paths to {@link #EXCLUDE_PATTERNS}.
     *
     * <p>Each entry is trimmed and one leading "/" is removed, because
     * {@link #isExcludePattern(String)} strips the leading "/" from the request URI before
     * matching. Without this, a pattern like "/foo" would never match. Null and blank entries are
     * skipped: an empty pattern is a prefix of every path and would silently disable logging for
     * all requests.
     */
    @PostConstruct
    public void beanPost() {
        if (CollectionUtils.isEmpty(logFilterConfig.getExcludePath())) {
            return;
        }
        for (String configured : logFilterConfig.getExcludePath()) {
            String pattern = StringUtils.removeStart(StringUtils.trim(configured), "/");
            if (StringUtils.isNotBlank(pattern)) {
                EXCLUDE_PATTERNS.add(pattern);
            }
        }
    }

    /**
     * Binds a trace id for the request, logs the request, delegates to the chain, then logs the
     * response.
     *
     * <p>Excluded paths are passed through unwrapped and unlogged. On every path the trace id is
     * removed in a {@code finally} block. TraceContext is presumably thread-bound (its internals
     * are not visible here), and this keeps the id from being left behind when the chain throws.
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        HttpServletRequest httpServletRequest = (HttpServletRequest) request;
        HttpServletResponse httpServletResponse = (HttpServletResponse) response;

        // Excluded paths: trace id only, no logging.
        if (isExcludePattern(httpServletRequest.getRequestURI())) {
            TraceContext.putContextTraceId();
            try {
                chain.doFilter(httpServletRequest, httpServletResponse);
            } finally {
                TraceContext.removeContextTraceId();
            }
            return;
        }

        try {
            // Binds the trace id and, for JSON requests, swaps in a body-caching RequestWrapper.
            httpServletRequest = logRequest(httpServletRequest);
            // Captures what is written via getOutputStream()/getWriter() so the body can be logged.
            ResponseWrapper responseWrapper = new ResponseWrapper(httpServletResponse);
            long dateStart = System.currentTimeMillis();
            // Spring handles the request through its DispatcherServlet further down the chain.
            chain.doFilter(httpServletRequest, responseWrapper);
            logResponse(dateStart, responseWrapper);
        } finally {
            // logResponse() already removes the trace id on success; this covers the exception path.
            // NOTE(review): assumes removeContextTraceId() is safe to call twice. Confirm.
            TraceContext.removeContextTraceId();
        }
    }

    /**
     * Binds a trace id and logs the incoming request.
     *
     * <p>When the Content-Type contains {@code application/json} (case-insensitive), the request is
     * replaced by a {@link RequestWrapper} and its body is included in the log line. Otherwise the
     * body is logged as {@code null}. Request headers are intentionally not logged;
     * {@link #getHeaders(HttpServletRequest)} exists if that is ever needed. Any failure is logged
     * and swallowed, so logging never breaks the request.
     *
     * @param httpServletRequest the incoming request
     * @return the request to pass down the chain: the wrapper if one was created, else the argument
     */
    public HttpServletRequest logRequest(HttpServletRequest httpServletRequest) {
        try {
            // Bind the trace id before anything that can fail (e.g. buffering the body), so the
            // response log and downstream code still have one.
            TraceContext.putContextTraceId();

            // Raw request body; only captured for JSON requests.
            String paramBody = null;
            if (StringUtils.containsIgnoreCase(httpServletRequest.getContentType(), CONTENT_TYPE)) {
                httpServletRequest = new RequestWrapper(httpServletRequest);
                paramBody = getRequestParamBody(httpServletRequest);
            }
            String uri = httpServletRequest.getRequestURI();
            String method = httpServletRequest.getMethod();
            String requestedIp = getRealIP(httpServletRequest);
            // The "PathVariable" slot in this message actually holds the query string.
            log.info("请求报文-TraceId:{},Method:{},Uri:{},PathVariable:{},Ip:{},RequestStr:{}",
                TraceContext.getContextTraceId(),
                method,
                uri,
                httpServletRequest.getQueryString(),
                requestedIp,
                paramBody);
        } catch (Exception e) {
            log.error("输出请求报文异常-->", e);
        }
        return httpServletRequest;
    }

    /**
     * Logs the captured response body of JSON responses and unbinds the trace id.
     *
     * <p>For non-JSON responses only the content type is logged. The trace id is removed in a
     * {@code finally} block, so it is cleared even if reading the captured body fails.
     *
     * @param dateStart       epoch millis taken just before the filter chain was invoked
     * @param responseWrapper the wrapper that captured the response body
     */
    public void logResponse(Long dateStart, ResponseWrapper responseWrapper) {
        try {
            String contentType = responseWrapper.getContentType();
            if (!StringUtils.containsIgnoreCase(contentType, CONTENT_TYPE)) {
                // NOTE(review): the captured buffer is only myFlush()-ed for JSON responses.
                // Confirm ResponseWrapper writes through to the real response in this case.
                log.info("response contentType:{}，无需打印参数", contentType);
                return;
            }
            long dateEnd = System.currentTimeMillis();
            String responseBody = null;
            if (responseWrapper.getMyOutputStream() != null) {
                responseBody = responseWrapper.getMyOutputStream().getBuffer();
                // Must flush: the captured body buffer is reused.
                responseWrapper.getMyOutputStream().myFlush();
            } else if (responseWrapper.getMyWriter() != null) {
                responseBody = responseWrapper.getMyWriter().getContent();
                // Must flush: the captured body buffer is reused.
                responseWrapper.getMyWriter().myFlush();
            }
            log.info("返回报文-TraceId:{},Time:{},ResponseStr:{}",
                TraceContext.getContextTraceId(),
                dateEnd - dateStart,
                responseBody);
        } catch (Exception e) {
            log.error("输出返回报文异常-->", e);
        } finally {
            TraceContext.removeContextTraceId();
        }
    }

    /**
     * Collects the non-standard request headers into a map.
     *
     * <p>The following headers are skipped:
     * <ul>
     *   <li>every header name declared as a public static String constant on Spring's
     *       {@link HttpHeaders};</li>
     *   <li>a few client/proxy noise headers and {@code Cookie}.</li>
     * </ul>
     * Names are compared case-insensitively, as HTTP requires; HTTP/2, for example, transmits
     * header names in lowercase.
     *
     * @param request the incoming request
     * @return header name to {@code request.getHeader(name)} for every header that is not skipped;
     *     empty if the container does not expose header names
     */
    public Map<String, String> getHeaders(HttpServletRequest request) {
        Map<String, String> headerMap = new HashMap<>();
        Enumeration<String> headerNames = request.getHeaderNames();
        if (headerNames == null) {
            // Servlet API: null when the container does not allow access to header names.
            return headerMap;
        }
        Set<String> skipped = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
        skipped.addAll(getCommonHeaders());
        skipped.add("Postman-Token");
        skipped.add("Proxy-Connection");
        skipped.add("X-Lantern-Version");
        skipped.add("Cookie");
        while (headerNames.hasMoreElements()) {
            String headerName = headerNames.nextElement();
            if (!skipped.contains(headerName)) {
                headerMap.put(headerName, request.getHeader(headerName));
            }
        }
        return headerMap;
    }

    /**
     * Returns the values of all public static String constants on {@link HttpHeaders}, i.e. the
     * standard header names Spring declares.
     */
    private List<String> getCommonHeaders() {
        List<String> headers = new ArrayList<>();
        // getFields() only returns public fields, so no setAccessible() is needed.
        for (Field field : HttpHeaders.class.getFields()) {
            if (String.class.equals(field.getType()) && Modifier.isStatic(field.getModifiers())) {
                try {
                    // Static field: the instance argument is ignored.
                    headers.add((String) field.get(null));
                } catch (IllegalAccessException e) {
                    log.error("反射获取属性值异常-->", e);
                }
            }
        }
        return headers;
    }

    /**
     * Resolves the client IP for logging.
     *
     * <p>Sources are tried in this order: the first entry of {@code X-Forwarded-For}, then
     * {@code X-Real-IP}, then {@link HttpServletRequest#getRemoteAddr()}. Header values are
     * trimmed, and blank or {@code "unknown"} values are skipped. Both headers are
     * client-controlled and can be spoofed, so the result must not be used for access control.
     *
     * @param request the request
     * @return the best-effort client IP
     */
    public static String getRealIP(HttpServletRequest request) {
        // After several reverse proxies the header reads "client, proxy1, proxy2"; the first
        // entry is the originating client.
        String forwarded = StringUtils.trim(StringUtils.substringBefore(request.getHeader("X-Forwarded-For"), ","));
        if (isUsableIp(forwarded)) {
            return forwarded;
        }
        String realIp = StringUtils.trim(request.getHeader("X-Real-IP"));
        if (isUsableIp(realIp)) {
            return realIp;
        }
        return request.getRemoteAddr();
    }

    /** Whether a header-derived IP value is present and is not the "unknown" placeholder. */
    private static boolean isUsableIp(String ip) {
        return StringUtils.isNotEmpty(ip) && !"unknown".equalsIgnoreCase(ip);
    }

    /**
     * Returns the buffered request body when the request is a {@link RequestWrapper}. Such a
     * wrapper is created, for example, by {@link #logRequest(HttpServletRequest)} for JSON requests.
     *
     * @param request the request, possibly wrapped
     * @return the buffered body, or {@code null} if the request is not a {@link RequestWrapper}
     * @throws IOException kept in the signature for callers; reading is delegated to the wrapper
     */
    public static String getRequestParamBody(HttpServletRequest request) throws IOException {
        // Only the wrapper holds a re-readable copy of the body.
        if (request instanceof RequestWrapper) {
            return ((RequestWrapper) request).getBody();
        }
        return null;
    }

    /**
     * Returns whether the URL falls under one of {@link #EXCLUDE_PATTERNS}.
     *
     * <p>The URL is trimmed and one leading "/" is removed; then plain string-prefix matching is
     * applied. As a result, pattern "bos" also matches e.g. "boss/..."; NOTE(review) confirm that
     * this is intended. {@code doFilter} passes {@code getRequestURI()}, which includes any
     * context path.
     *
     * @param url the request URI
     * @return {@code true} if logging should be skipped; {@code false} for a blank URL
     */
    public static boolean isExcludePattern(String url) {
        if (StringUtils.isBlank(url)) {
            return false;
        }
        String path = StringUtils.removeStart(url.trim(), "/");
        return EXCLUDE_PATTERNS.stream().anyMatch(path::startsWith);
    }

    /** No container-level initialization: configuration comes from Spring injection instead. */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
    }

    /** Nothing to release. */
    @Override
    public void destroy() {
    }

}
