package com.xforceplus.security.strategy.filter.impl;

import com.xforceplus.business.reponse.code.Rep;
import com.xforceplus.security.login.context.LoginContext;
import com.xforceplus.security.login.exception.AuthenticationException;
import com.xforceplus.security.login.request.LoginCaptchaRequest;
import com.xforceplus.security.login.request.LoginRequest;
import com.xforceplus.security.strategy.filter.*;
import com.xforceplus.security.strategy.model.AccountLoginFailStrategy;
import com.xforceplus.security.strategy.model.CaptchaStrategy;
import lombok.experimental.SuperBuilder;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.ValueOperations;

import java.util.Locale;
import java.util.concurrent.TimeUnit;

/**
 * Security strategy filter that counts failed logins per login name in Redis.
 *
 * <p>Before authentication ({@link #executePostLoadPredicate}) it rejects locked accounts and, once the
 * failure count reaches the {@link CaptchaStrategy} threshold, requires a valid one-time captcha. After a
 * failed login ({@link #executePostLoginFail}) it increments the counter and always throws an
 * {@link AuthenticationException}; after a successful login ({@link #executePostLoginSuccess}) it clears
 * the counter.
 *
 * @author geewit
 */
@Slf4j
@SuperBuilder
public class FailTimesPredicationStrategyFilter extends AbstractStrategyFilter<AccountLoginFailStrategy>
        implements PostLoadPredicationFilter<AccountLoginFailStrategy>,
        PostLoginSuccessFilter<AccountLoginFailStrategy>,
        PostLoginFailFilter<AccountLoginFailStrategy> {

    /**
     * @return a new {@link AccountLoginFailStrategy} instance
     */
    @Override
    public AccountLoginFailStrategy defaultStrategy() {
        return new AccountLoginFailStrategy();
    }

    /**
     * Applies this filter only to captcha-capable login requests that carry a login name.
     *
     * @param loginContext the current login context
     * @return {@code true} when the inherited check passes, the request is a {@link LoginCaptchaRequest}
     *         and the login name is non-null
     */
    @Override
    public boolean support(LoginContext<? extends LoginRequest> loginContext) {
        log.debug("execute {}Filter.support", this.getStrategyClass().getSimpleName());
        if (!PostLoadPredicationFilter.super.support(loginContext)) {
            return false;
        }
        LoginRequest loginRequest = loginContext.getLoginRequest();
        if (loginRequest == null) {
            log.debug("{}Filter.loginContext.loginRequest = null, do nothing", this.getStrategyClass().getSimpleName());
            return false;
        }
        if (!(loginRequest instanceof LoginCaptchaRequest)) {
            log.debug("{}Filter.loginContext.loginRequest not instanceof LoginCaptchaRequest, do nothing", this.strategyClass().getSimpleName());
            return false;
        }
        if (loginContext.getLoginName() == null) {
            log.debug("{}Filter.loginContext.loginName = null, do nothing", this.getStrategyClass().getSimpleName());
            return false;
        }
        return true;
    }

    /**
     * Rejects the login while the account is locked, and verifies the captcha once the failure count
     * reaches the captcha threshold.
     *
     * @param loginContext the current login context; {@link #support} has ensured its request is a
     *                     {@link LoginCaptchaRequest} and its login name is non-null
     * @throws AuthenticationException with {@code Rep.AccountCode.FAIL} while the account is locked, or with
     *                                 {@code Rep.AccountCode.NEED_CAPTCHA} when the captcha is missing or invalid
     */
    @Override
    public void executePostLoadPredicate(LoginContext<? extends LoginRequest> loginContext) {
        log.debug("{}Filter.executePostLoadPredicate", this.strategyClass().getSimpleName());
        LoginCaptchaRequest loginRequest = (LoginCaptchaRequest) loginContext.getLoginRequest();
        String redisKey = AccountLoginFailStrategy.ACCOUNT_LOGIN_FAILS_PREFIX + loginContext.getLoginName();
        StringRedisTemplate redisTemplate = applicationContext.getBean(StringRedisTemplate.class);

        Integer storedTimes = parseFailTimes(redisTemplate.opsForValue().get(redisKey));
        int times = storedTimes == null ? 0 : storedTimes;
        log.debug("times = {}", times);

        int lockThreshold = this.resolveLockThreshold(this.loadLoginFailStrategy(loginContext));
        log.debug("lockThreshold = {}", lockThreshold);
        if (lockThreshold > 0 && times >= lockThreshold) {
            this.rejectIfLocked(redisTemplate, redisKey);
        }

        int captchaThreshold = this.resolveCaptchaThreshold(this.loadCaptchaStrategy(loginContext));
        log.debug("captchaThreshold = {}", captchaThreshold);
        if (captchaThreshold > 0 && times >= captchaThreshold) {
            this.consumeCaptcha(redisTemplate, loginRequest.getCaptcha());
        }
    }

    /**
     * Records a failed login and reports it to the caller.
     *
     * @param loginContext the current login context
     * @throws AuthenticationException always; it either announces the remaining attempts or, once the lock
     *                                 threshold is reached, the lock duration
     */
    @Override
    public void executePostLoginFail(LoginContext<? extends LoginRequest> loginContext) {
        log.debug("execute PostLoginStrategyFilter.executePostLoginFail");
        AccountLoginFailStrategy loginFailStrategy = this.loadLoginFailStrategy(loginContext);
        int lockThreshold = this.resolveLockThreshold(loginFailStrategy);
        int timeout = this.resolveLockTimeout(loginFailStrategy);
        int captchaThreshold = this.resolveCaptchaThreshold(this.loadCaptchaStrategy(loginContext));
        log.debug("captchaThreshold = {}", captchaThreshold);
        log.debug("lockThreshold = {}", lockThreshold);
        log.debug("timeout = {}", timeout);
        this.tryLockAccount(loginContext.getLoginName(), captchaThreshold, lockThreshold, timeout);
    }

    /**
     * Clears the failure counter after a successful login.
     *
     * @param loginContext the current login context
     */
    @Override
    public void executePostLoginSuccess(LoginContext<? extends LoginRequest> loginContext) {
        log.debug("execute PostLoginStrategyFilter.executePostLoginSuccess");
        try {
            StringRedisTemplate redisTemplate = applicationContext.getBean(StringRedisTemplate.class);
            redisTemplate.delete(AccountLoginFailStrategy.ACCOUNT_LOGIN_FAILS_PREFIX + loginContext.getLoginName());
        } catch (Exception e) {
            // Best effort: failing to clear the counter must not fail an otherwise successful login.
            log.warn(e.getMessage(), e);
        }
    }

    /**
     * Increments the failure counter and throws the matching authentication error.
     *
     * @param loginName        login name whose counter is incremented
     * @param captchaThreshold failure count from which a captcha is requested; {@code 0} disables it
     * @param lockThreshold    failure count from which the account is locked; {@code 0} disables locking
     * @param timeout          counter lifetime (and thus lock duration) in minutes
     * @throws AuthenticationException always
     */
    private void tryLockAccount(String loginName, int captchaThreshold, int lockThreshold, int timeout) {
        log.debug("execute PostLoginStrategyFilter.tryLockAccount");
        String redisKey = AccountLoginFailStrategy.ACCOUNT_LOGIN_FAILS_PREFIX + loginName;
        StringRedisTemplate redisTemplate = applicationContext.getBean(StringRedisTemplate.class);
        int times = this.incrementFailTimes(redisTemplate, redisKey, timeout);
        log.debug("times = {}", times);

        int code = captchaThreshold > 0 && times >= captchaThreshold
                ? Rep.AccountCode.NEED_CAPTCHA
                : Rep.AccountCode.FAIL;
        if (lockThreshold <= 0) {
            throw new AuthenticationException(code, "用户名/密码错误");
        }
        if (times < lockThreshold) {
            throw new AuthenticationException(code, "用户名/密码错误, 还有" + (lockThreshold - times) + "次机会");
        }
        throw new AuthenticationException(Rep.AccountCode.FAIL, AccountLoginFailStrategy.lockMessage(timeout));
    }

    /**
     * Atomically increments the failure counter and restarts its expiry.
     *
     * @return the failure count including this failure
     */
    private int incrementFailTimes(StringRedisTemplate redisTemplate, String redisKey, int timeout) {
        ValueOperations<String, String> valueOperations = redisTemplate.opsForValue();
        String current = valueOperations.get(redisKey);
        if (current != null && parseFailTimes(current) == null) {
            // Redis INCR rejects non-integer values; restart the count, as the read path treats them as 0.
            redisTemplate.delete(redisKey);
        }
        // INCR rather than GET + SET so that concurrent failed logins are all counted.
        Long counted = valueOperations.increment(redisKey, 1L);
        redisTemplate.expire(redisKey, timeout, TimeUnit.MINUTES);
        if (counted == null) {
            // increment() returns null only inside a pipeline/transaction; fall back to reading the value.
            Integer stored = parseFailTimes(valueOperations.get(redisKey));
            return stored == null ? 1 : stored;
        }
        return counted > Integer.MAX_VALUE ? Integer.MAX_VALUE : counted.intValue();
    }

    /**
     * Throws the lock error while the failure counter key still has a positive TTL.
     */
    private void rejectIfLocked(StringRedisTemplate redisTemplate, String redisKey) {
        // Read the TTL in seconds and round up: a TTL read in whole minutes truncates to 0 during the
        // last minute, which would leave the account unlocked for that minute.
        Long expireSeconds = redisTemplate.getExpire(redisKey, TimeUnit.SECONDS);
        if (expireSeconds == null || expireSeconds <= 0) {
            return;
        }
        long expireMinutes = (expireSeconds + 59) / 60;
        log.debug("expireMinutes = {}", expireMinutes);
        String message = AccountLoginFailStrategy.lockMessage((int) expireMinutes);
        log.info(message);
        throw new AuthenticationException(Rep.AccountCode.FAIL, message);
    }

    /**
     * Validates a one-time captcha and consumes it.
     *
     * @throws AuthenticationException with {@code Rep.AccountCode.NEED_CAPTCHA} when the captcha is blank or
     *                                 has no matching Redis key
     */
    private void consumeCaptcha(StringRedisTemplate redisTemplate, String captcha) {
        if (StringUtils.isBlank(captcha)) {
            throw new AuthenticationException(Rep.AccountCode.NEED_CAPTCHA, "请输入验证码");
        }
        String captchaKey = CaptchaStrategy.CAPTCHA_PREFIX + captcha.toUpperCase(Locale.ROOT);
        // DELETE reports whether the key existed, so validation and consumption are one atomic step and a
        // captcha cannot be reused by concurrent requests.
        if (!Boolean.TRUE.equals(redisTemplate.delete(captchaKey))) {
            String message = "验证码错误！";
            log.info(message);
            throw new AuthenticationException(Rep.AccountCode.NEED_CAPTCHA, message);
        }
    }

    /** Loads the account-login-fail strategy, falling back to a new instance when none is configured. */
    private AccountLoginFailStrategy loadLoginFailStrategy(LoginContext<? extends LoginRequest> loginContext) {
        AccountLoginFailStrategy strategy = this.loadCurrentStrategy(loginContext);
        return strategy != null ? strategy : new AccountLoginFailStrategy();
    }

    /** Loads the captcha strategy, falling back to a new instance when none is configured. */
    private CaptchaStrategy loadCaptchaStrategy(LoginContext<? extends LoginRequest> loginContext) {
        CaptchaStrategy strategy = this.loadCurrentStrategy(loginContext, CaptchaStrategy.class);
        return strategy != null ? strategy : new CaptchaStrategy();
    }

    /** Returns the lock threshold: {@code 0} when disabled, else the configured positive value or the default. */
    private int resolveLockThreshold(AccountLoginFailStrategy strategy) {
        if (!strategy.isEnabled()) {
            log.debug("{}Filter.AccountLoginFailStrategy disabled, lockThreshold = 0", this.strategyClass().getSimpleName());
            return 0;
        }
        Integer threshold = strategy.getThreshold();
        return threshold != null && threshold > 0 ? threshold : AccountLoginFailStrategy.DEFAULT_THRESHOLD;
    }

    /** Returns the counter lifetime in minutes: the configured positive value when enabled, else the default. */
    private int resolveLockTimeout(AccountLoginFailStrategy strategy) {
        Integer timeout = strategy.getTimeout();
        return strategy.isEnabled() && timeout != null && timeout > 0 ? timeout : AccountLoginFailStrategy.DEFAULT_TIMEOUT;
    }

    /** Returns the captcha threshold: {@code 0} when disabled, else the configured positive value or the default. */
    private int resolveCaptchaThreshold(CaptchaStrategy strategy) {
        if (!strategy.isEnabled()) {
            log.debug("{}Filter.CaptchaStrategy disabled, captchaThreshold = 0", this.strategyClass().getSimpleName());
            return 0;
        }
        Integer threshold = strategy.getThreshold();
        return threshold != null && threshold > 0 ? threshold : CaptchaStrategy.DEFAULT_THRESHOLD;
    }

    /**
     * Parses a stored failure count.
     *
     * @return the count, or {@code null} when the value is absent or not a valid {@code int}
     */
    private static Integer parseFailTimes(String redisValue) {
        if (redisValue == null) {
            return null;
        }
        try {
            return Integer.valueOf(redisValue);
        } catch (NumberFormatException e) {
            log.warn("ignore non-integer login fail counter value: {}", redisValue);
            return null;
        }
    }
}
