package com.xforceplus.phoenix.split.service.dataflow.impl;

import com.google.common.base.Stopwatch;
import com.xforceplus.phoenix.split.domain.SplitGroupLimit;
import com.xforceplus.phoenix.split.model.BillItem;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.Getter;
import org.apache.commons.collections.CollectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.annotation.Scope;
import org.springframework.stereotype.Service;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.*;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

/**
 * "Minimum invoice count" packing rule: combines bill items into as few invoice groups as
 * possible while keeping every group within the limits of a {@link SplitGroupLimit}
 * (amount limit, optional quantity limit and maximum tax-difference amount).
 *
 * <p>Items are first bucketed by amount (relative to half the amount limit) and by the sign of
 * their tax difference, then greedily matched bucket against bucket. Whatever cannot be matched
 * that way is packed by {@link #processLast}.
 *
 * <p>The bean is prototype-scoped and keeps mutable matching state ({@link #isturn}), so an
 * instance must not be shared between concurrent callers.
 *
 * <p>Originally authored by ZZW, 2021/12/25.
 */
@Service
@Scope("prototype")
public class MinPackagePlugin2 {
    private static final Logger logger = LoggerFactory.getLogger(MinPackagePlugin2.class);

    /**
     * For each bucket index, the two further buckets {@link #packageItem} hands to
     * {@link #doMatch} after the bucket itself (first entry, then second entry).
     */
    static final Map<Integer, List<Integer>> POSITION_MAP;

    static {
        Map<Integer, List<Integer>> positions = new HashMap<>();
        positions.put(3, Arrays.asList(1, 0));
        positions.put(2, Arrays.asList(0, 1));
        positions.put(1, Arrays.asList(0, 1));
        positions.put(0, Arrays.asList(1, 0));
        POSITION_MAP = Collections.unmodifiableMap(positions);
    }

    // All items are split into 4 buckets by (amount, sign of tax difference); see getGroupIndex:
    //   0  amount not above half the limit, tax difference positive
    //   1  amount not above half the limit, tax difference zero or negative
    //   2  amount above half the limit,     tax difference positive
    //   3  amount above half the limit,     tax difference zero or negative
    // NOTE(review): the original design notes described bucket 2 as "negative" and bucket 3 as
    // "positive" (the opposite of what getGroupIndex assigns) and named 3+1 / 2+0 as the preferred
    // pairings, i.e. pairing opposite signs so tax differences cancel. Confirm which mapping is
    // intended before relying on the pairing order.
    // Intended strategy (per the original notes): scan 3 and match against 1 then 0, letting the
    // rest of 3 form groups on their own; scan 2 and match against 0 then 1; scan 1 against 0 and
    // then itself; 0 only matches within 0.

    /**
     * Buckets items by net amount excluding tax (amount minus discount) and by the sign of their
     * tax difference, for use when the limit applies to the amount excluding tax.
     *
     * @param list        bill items to distribute
     * @param limitAmount amount limit; items above half of it go to the "large" buckets 2/3
     * @return buckets 0-3, each sorted ascending by amount excluding tax
     */
    public Map<Integer, Group> distributeGroup(List<BillItem> list, BigDecimal limitAmount) {
        Map<Integer, Group> map = newEmptyGroups();

        final BigDecimal median = limitAmount.divide(BigDecimal.valueOf(2), RoundingMode.DOWN);
        list.forEach(item -> {
            BigDecimal amount = item.getAmountWithoutTax().subtract(item.getDiscountWithoutTax());
            BigDecimal amountWithTax = item.getAmountWithTax().subtract(item.getDiscountWithTax());
            boolean overMedian = amount.compareTo(median) > 0;
            // Tax difference = (amount * rate - recorded tax) + (discount * rate - recorded discount tax).
            // NOTE(review): the discount term is added although the amount subtracts the discount;
            // confirm the sign convention of the discount fields. The product is not rounded.
            BigDecimal diff = item.getAmountWithoutTax().multiply(item.getTaxRate()).subtract(item.getTaxAmount())
                    .add(item.getDiscountWithoutTax().multiply(item.getTaxRate()).subtract(item.getDiscountTax()));
            boolean taxDiffPositive = diff.compareTo(BigDecimal.ZERO) > 0;

            Group group = map.get(getGroupIndex(overMedian, taxDiffPositive));
            group.elementList.add(new Element(item, amount, diff, amountWithTax));
        });
        map.values().forEach(group -> group.elementList.sort(Comparator.comparing(Element::getAmountWithoutTax)));
        return map;
    }

    /**
     * Buckets items by net amount including tax (amount minus discount) and by the sign of their
     * tax difference, for use when the limit applies to the amount including tax.
     *
     * @param list        bill items to distribute
     * @param limitAmount amount limit; items above half of it go to the "large" buckets 2/3
     * @return buckets 0-3, each sorted ascending by amount including tax
     */
    public Map<Integer, Group> distributeGroupWithTax(List<BillItem> list, BigDecimal limitAmount) {
        Map<Integer, Group> map = newEmptyGroups();

        final BigDecimal median = limitAmount.divide(BigDecimal.valueOf(2), RoundingMode.DOWN);
        list.forEach(item -> {
            BigDecimal amount = item.getAmountWithoutTax().subtract(item.getDiscountWithoutTax());
            BigDecimal amountWithTax = item.getAmountWithTax().subtract(item.getDiscountWithTax());
            boolean overMedian = amountWithTax.compareTo(median) > 0;
            // Tax difference = (tax derived from the tax-inclusive amount - recorded tax)
            //                + (tax derived from the tax-inclusive discount - recorded discount tax).
            // NOTE(review): as in distributeGroup, the discount term is added; confirm the sign convention.
            BigDecimal diff = this.calTaxAmountByAmountWithTax(item.getAmountWithTax(), item.getTaxRate()).subtract(item.getTaxAmount())
                    .add(this.calTaxAmountByAmountWithTax(item.getDiscountWithTax(), item.getTaxRate()).subtract(item.getDiscountTax()));
            boolean taxDiffPositive = diff.compareTo(BigDecimal.ZERO) > 0;

            Group group = map.get(getGroupIndex(overMedian, taxDiffPositive));
            group.elementList.add(new Element(item, amount, diff, amountWithTax));
        });
        map.values().forEach(group -> group.elementList.sort(Comparator.comparing(Element::getAmountWithTax)));
        return map;
    }

    /** Creates the four empty buckets (keys 0-3) filled by the distribute methods. */
    private Map<Integer, Group> newEmptyGroups() {
        Map<Integer, Group> map = new HashMap<>();
        for (int index = 0; index <= 3; index++) {
            map.put(index, new Group(new CopyOnWriteArrayList<>()));
        }
        return map;
    }

    /**
     * Derives the tax contained in a tax-inclusive amount: {@code amountWithTax * rate / (1 + rate)},
     * rounded HALF_UP to 2 decimal places.
     */
    private BigDecimal calTaxAmountByAmountWithTax(BigDecimal amountWithTax, BigDecimal taxRate) {
        return amountWithTax.multiply(taxRate).divide(BigDecimal.ONE.add(taxRate), 2, RoundingMode.HALF_UP);
    }

    /**
     * Splits bill items into invoice groups that respect the given limits.
     *
     * @param billItems       items to split
     * @param splitGroupLimit amount / quantity / tax-difference limits; {@code isLimitIsAmountWithTax()}
     *                        selects whether amounts are measured with or without tax
     * @return the item groups: leftovers packed by {@link #processLast} first, then every non-empty
     * group produced by the bucket matching
     * @throws IllegalArgumentException if the leftover packing cannot place some item
     */
    public List<List<BillItem>> processData(List<BillItem> billItems, SplitGroupLimit splitGroupLimit) {
        // isturn is instance state shared by the recursive doMatch calls; clear it so a previous run
        // (or one aborted by an exception) cannot leak into this one.
        isturn = false;
        List<ResGroup> RES = new ArrayList<>();

        Stopwatch stopwatch = Stopwatch.createStarted();
        logger.info("MinPackagePlugin2 处理 开始");

        Map<Integer, MinPackagePlugin2.Group> groupMap = splitGroupLimit.isLimitIsAmountWithTax()
                ? distributeGroupWithTax(billItems, splitGroupLimit.getLimitAmount())
                : distributeGroup(billItems, splitGroupLimit.getLimitAmount());
        MinPackagePlugin2.ResGroup resGroup = new MinPackagePlugin2.ResGroup(new ArrayList<>(), BigDecimal.ZERO, BigDecimal.ZERO, BigDecimal.ZERO);

        // Scan the buckets in the order 3, 2, 1, 0. Each scan matches the bucket itself and then the
        // two buckets listed in POSITION_MAP; matched elements are removed from their buckets.
        resGroup = this.packageItem(groupMap, 3, splitGroupLimit, resGroup, RES);
        resGroup = this.packageItem(groupMap, 2, splitGroupLimit, resGroup, RES);
        resGroup = this.packageItem(groupMap, 1, splitGroupLimit, resGroup, RES);
        this.packageItem(groupMap, 0, splitGroupLimit, resGroup, RES);

        List<List<BillItem>> splitGroup = new ArrayList<>(processLast(groupMap,
                splitGroupLimit.getLimitAmount(),
                splitGroupLimit.getInvoiceMaxErrorAmount(),
                splitGroupLimit.isLimitIsAmountWithTax()));
        // Keep every group that holds elements. Filtering on a positive total amount would silently
        // drop items whose net amount is zero (they were already removed from their buckets).
        for (ResGroup tmpGroup : RES) {
            if (CollectionUtils.isNotEmpty(tmpGroup.elementList)) {
                splitGroup.add(tmpGroup.elementList.stream().map(element -> element.billItem).collect(Collectors.toList()));
            }
        }

        logger.info("MinPackagePlugin2 处理 完成 耗时{}  条数{}", stopwatch.elapsed(TimeUnit.MILLISECONDS), splitGroup.size());
        return splitGroup;
    }

    /**
     * Packs every element still left in the buckets after bucket matching.
     *
     * <p>TODO (from the original author): this is a stop-gap for leftovers; ideally the matching
     * should not leave any.
     *
     * <p>Each pass walks the remaining elements (buckets in the map's iteration order) and accepts
     * an element when the group's running amount stays within {@code limit} and its running
     * |tax difference| stays within {@code taxDiffLimit}. Only accepted elements count towards the
     * running totals. An element whose own amount already exceeds {@code limit} is put into the
     * current pass's group without counting towards its totals.
     *
     * @param groupMap             buckets produced by a distribute method
     * @param limit                amount limit per group
     * @param taxDiffLimit         maximum absolute tax difference per group
     * @param limitIsAmountWithTax whether {@code limit} applies to the amount including tax
     * @return the leftover groups, one list of bill items per group
     * @throws IllegalArgumentException if a pass cannot place any element (e.g. an element whose
     *                                  tax difference alone exceeds {@code taxDiffLimit})
     */
    private List<List<BillItem>> processLast(Map<Integer, MinPackagePlugin2.Group> groupMap, BigDecimal limit,
                                             BigDecimal taxDiffLimit, boolean limitIsAmountWithTax) {
        List<Element> remaining = new ArrayList<>();
        groupMap.values().stream()
                .filter(tmpGroup -> CollectionUtils.isNotEmpty(tmpGroup.elementList))
                .forEach(group -> remaining.addAll(group.getElementList()));

        List<List<BillItem>> splitGroup = new ArrayList<>();
        while (!remaining.isEmpty()) {
            List<BillItem> tmpRes = new ArrayList<>();
            BigDecimal totalAmount = BigDecimal.ZERO;
            BigDecimal totalDiff = BigDecimal.ZERO;
            Iterator<Element> iterator = remaining.iterator();
            while (iterator.hasNext()) {
                Element element = iterator.next();
                BigDecimal amount = limitIsAmountWithTax ? element.amountWithTax : element.amountWithoutTax;
                if (amount.compareTo(limit) > 0) {
                    tmpRes.add(element.billItem);
                    iterator.remove();
                    continue;
                }
                BigDecimal candidateAmount = totalAmount.add(amount);
                BigDecimal candidateDiff = totalDiff.add(element.taxAmountDiff);
                if (candidateDiff.abs().compareTo(taxDiffLimit) > 0 || candidateAmount.compareTo(limit) > 0) {
                    // Skipped elements must not affect the totals judged for later elements.
                    continue;
                }
                totalAmount = candidateAmount;
                totalDiff = candidateDiff;
                tmpRes.add(element.billItem);
                iterator.remove();
            }
            // A pass that places nothing would repeat forever.
            if (tmpRes.isEmpty()) {
                throw new IllegalArgumentException("miniPackage data exception");
            }
            splitGroup.add(tmpRes);
        }

        return splitGroup;
    }

    /**
     * Maps an item to its bucket (see the bucket table above).
     *
     * @param overMedian      whether the item's amount is above half the limit
     * @param taxDiffPositive whether the item's tax difference is strictly positive
     * @return bucket index 0-3
     */
    private Integer getGroupIndex(Boolean overMedian, Boolean taxDiffPositive) {
        if (overMedian) {
            return taxDiffPositive ? 2 : 3;
        }
        return taxDiffPositive ? 0 : 1;
    }

    /**
     * Matches bucket {@code mapPosition} and then the two buckets listed for it in
     * {@link #POSITION_MAP}, filling {@code resGroup} and appending finished groups to {@code RES}.
     * Does nothing when bucket {@code mapPosition} is empty.
     *
     * @return the group subsequent matching should continue filling
     */
    public ResGroup packageItem(Map<Integer, MinPackagePlugin2.Group> map, Integer mapPosition, SplitGroupLimit splitGroupLimit, ResGroup resGroup, List<ResGroup> RES) {
        MinPackagePlugin2.Group targetGroup = map.get(mapPosition);
        List<Integer> matchPositions = POSITION_MAP.get(mapPosition);
        if (CollectionUtils.isEmpty(targetGroup.elementList)) {
            return resGroup;
        }
        resGroup = this.doMatch(map, targetGroup, mapPosition, resGroup, splitGroupLimit, RES);
        Integer matchPosition = matchPositions.get(0);
        MinPackagePlugin2.Group matchGroup = map.get(matchPosition);
        // The third argument is the toggle seed doMatch uses when it switches buckets,
        // not the index of matchGroup.
        resGroup = this.doMatch(map, matchGroup, 0, resGroup, splitGroupLimit, RES);
        matchGroup = map.get(matchPositions.get(1));
        // NOTE(review): unlike the two calls above, this result is discarded, so the next
        // packageItem continues filling the group returned by the previous call. Confirm intended.
        this.doMatch(map, matchGroup, 1, resGroup, splitGroupLimit, RES);
        return resGroup;
    }

    /**
     * Set after {@link #doMatch} has switched buckets for the current group; a second limit breach
     * while set closes the group. Reset at the start of {@link #processData}.
     */
    Boolean isturn = false;

    /**
     * Greedily adds elements of {@code group} to {@code resGroup} while the amount, quantity and
     * tax-difference limits hold, removing every added element from {@code group}. On a limit breach
     * it switches to bucket 0 or 1 (toggled from {@code matchPosition}) or closes the current group
     * (adds it to {@code RES}) and continues with a fresh one, recursively.
     *
     * <p>NOTE(review): {@link ResGroup} uses Lombok {@code @Data} value equality, so
     * {@code RES.contains(...)} treats distinct empty groups as equal.
     *
     * @return the group matching should continue with
     */
    private ResGroup doMatch(
            Map<Integer, MinPackagePlugin2.Group> map,
            MinPackagePlugin2.Group group, Integer matchPosition,
            ResGroup resGroup,
            SplitGroupLimit splitGroupLimit,
            List<ResGroup> RES
    ) {

        // The buckets built by the distribute methods are CopyOnWriteArrayLists; iterating one works
        // on a snapshot, which is what makes removing `element` from the live list below safe.
        for (MinPackagePlugin2.Element element : group.getElementList()) {
            // Adding this element would push |tax difference| over the limit: switch bucket.
            // NOTE(review): this branch does not consult isturn, so if the head elements of buckets
            // 0 and 1 both break the tax limit the recursion alternates between them until the stack
            // overflows. Confirm whether such input can reach this method.
            if (resGroup.getTaxDiff().add(element.getTaxAmountDiff()).abs().compareTo(splitGroupLimit.getInvoiceMaxErrorAmount()) > 0) {
                matchPosition = matchPosition > 0 ? 0 : 1;
                group = map.get(matchPosition);
                if (CollectionUtils.isEmpty(group.elementList)) {
                    if (!RES.contains(resGroup)) {
                        RES.add(resGroup);
                    }
                    return doMatch(map, group, matchPosition, new ResGroup(new ArrayList<>(), BigDecimal.ZERO, BigDecimal.ZERO, BigDecimal.ZERO), splitGroupLimit, RES);
                }
                isturn = true;
                return doMatch(map, group, matchPosition, resGroup, splitGroupLimit, RES);
            }


            BigDecimal totalAmount;
            if (splitGroupLimit.isLimitIsAmountWithTax()) {
                totalAmount = resGroup.getTotalAmount().add(element.getAmountWithTax());
            } else {
                totalAmount = resGroup.getTotalAmount().add(element.getAmountWithoutTax());
            }
            BigDecimal totalQuantity = resGroup.getTotalQuantity().add(element.getQuantity());

            if (totalAmount.compareTo(splitGroupLimit.getLimitAmount()) > 0 || (Objects.nonNull(splitGroupLimit.getLimitQuantity()) && totalQuantity.compareTo(splitGroupLimit.getLimitQuantity()) > 0)) {
                // Amount or quantity limit exceeded and no switch yet: switch bucket.
                if (!isturn) {
                    matchPosition = matchPosition > 0 ? 0 : 1;
                    group = map.get(matchPosition);
                    if (CollectionUtils.isEmpty(group.elementList)) {
                        if (!RES.contains(resGroup)) {
                            RES.add(resGroup);
                        }
                        return doMatch(map, group, matchPosition, new ResGroup(new ArrayList<>(), BigDecimal.ZERO, BigDecimal.ZERO, BigDecimal.ZERO), splitGroupLimit, RES);
                    }
                    isturn = true;
                    return doMatch(map, group, matchPosition, resGroup, splitGroupLimit, RES);
                }
                // Still over the limit after switching: close this group and start a new one from the
                // current bucket.
                if (isturn) {
                    isturn = false;
                    if (!RES.contains(resGroup)) {
                        RES.add(resGroup);
                    }
                    return doMatch(map, group, matchPosition, new ResGroup(new ArrayList<>(), BigDecimal.ZERO, BigDecimal.ZERO, BigDecimal.ZERO), splitGroupLimit, RES);
                }
            }
            resGroup.elementList.add(element);
            resGroup.setTotalAmount(totalAmount);
            resGroup.setTotalQuantity(totalQuantity);
            if (resGroup.getTaxDiff().add(element.getTaxAmountDiff()).abs().compareTo(splitGroupLimit.getInvoiceMaxErrorAmount()) >= 0) {
                logger.debug("MinPackagePlugin2 tax difference reached the limit: {}", resGroup);
            }
            resGroup.setTaxDiff(resGroup.getTaxDiff().add(element.getTaxAmountDiff()));

            group.getElementList().remove(element);
        }
        // Bucket exhausted: record the group and, if no switch happened yet, continue in the other
        // of buckets 0/1 with the same group.
        if (!RES.contains(resGroup)) {
            RES.add(resGroup);
            if (!isturn) {
                matchPosition = matchPosition > 0 ? 0 : 1;
                group = map.get(matchPosition);
                if (CollectionUtils.isEmpty(group.elementList)) {
                    if (!RES.contains(resGroup)) {
                        RES.add(resGroup);
                    }
                    return resGroup;
                }
                isturn = true;
                return doMatch(map, group, matchPosition, resGroup, splitGroupLimit, RES);
            }
        }
        return resGroup;
    }

    /** A group being built: its elements plus running amount, tax difference and quantity. */
    @Data
    @AllArgsConstructor
    static class ResGroup {
        List<MinPackagePlugin2.Element> elementList;
        BigDecimal totalAmount;
        BigDecimal taxDiff;
        BigDecimal totalQuantity;
    }

    /** One bucket of elements produced by the distribute methods. */
    class Group {
        public Group(List<Element> elementList) {
            this.elementList = elementList;
        }

        List<Element> elementList;

        public void setElementList(List<Element> elementList) {
            this.elementList = elementList;
        }

        public List<Element> getElementList() {
            return elementList;
        }
    }

    /**
     * A bill item with its precomputed net amounts (after discount) and tax difference.
     * Natural order is by amount excluding tax.
     */
    @Getter
    @AllArgsConstructor
    class Element implements Comparable<Element> {
        BillItem billItem;
        BigDecimal amountWithoutTax;
        BigDecimal taxAmountDiff;
        BigDecimal amountWithTax;

        public BigDecimal getQuantity() {
            return billItem.getQuantity();
        }

        @Override
        public int compareTo(Element o) {
            return this.getAmountWithoutTax().compareTo(o.getAmountWithoutTax());
        }

    }

}
