package com.xforceplus.phoenix.split.service.dataflow;

import com.google.common.base.Objects;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
import com.xforceplus.phoenix.split.constant.Tuple;
import com.xforceplus.phoenix.split.domain.ItemGroup;
import com.xforceplus.phoenix.split.domain.SplitGroupLimit;
import com.xforceplus.phoenix.split.exception.SplitBizException;
import com.xforceplus.phoenix.split.model.BillItem;
import org.springframework.util.CollectionUtils;

import java.math.BigDecimal;
import java.util.*;

/**
 * Backtracking implementation that packs bill items (or pre-built invoices) into as few invoices as possible.
 *
 * <p>Both public entry points work round by round: each round runs an exhaustive, memoized search for the
 * largest subset of the remaining candidates that fits on one invoice (amount limit, line limit and
 * accumulated tax error), emits that subset as one invoice and repeats on the leftovers.
 *
 * <p>Instances keep only the limits (final fields); all search state is created per call.
 */
public class MinInvoiceService {

    /**
     * Amount limit of a single invoice.
     */
    private final BigDecimal limitAmount;

    /**
     * Whether {@link #limitAmount} applies to amounts with tax ({@code true}) or without tax ({@code false}).
     */
    private final boolean limitIsAmountWithTax;

    /**
     * Maximum number of lines on a single invoice.
     */
    private final int limitLine;

    /**
     * Creates a service with explicit limits.
     *
     * @param limitAmount          amount limit of a single invoice
     * @param limitIsAmountWithTax whether the amount limit applies to amounts with tax
     * @param limitLine            maximum number of lines on a single invoice
     */
    protected MinInvoiceService(BigDecimal limitAmount,
                                boolean limitIsAmountWithTax,
                                int limitLine) {
        this.limitAmount = limitAmount;
        this.limitIsAmountWithTax = limitIsAmountWithTax;
        this.limitLine = limitLine;
    }

    /**
     * Creates a service using the limits of the given split group.
     *
     * @param splitGroupLimit source of the amount limit, its tax basis and the line limit
     */
    public MinInvoiceService(SplitGroupLimit splitGroupLimit) {
        this(splitGroupLimit.getLimitAmount(),
                splitGroupLimit.isLimitIsAmountWithTax(),
                splitGroupLimit.getLimitLine());
    }

    /**
     * Merges pre-built invoices into as few invoices as possible.
     *
     * <p>Each round merges the largest number of remaining invoices that fit together on one invoice;
     * an input invoice is never split.
     *
     * @param invoices item groups to merge; each group's bill items are treated as one indivisible unit
     * @return the merged invoices; each holds the bill items of its groups in input order
     * @throws SplitBizException if, with two or more invoices remaining, none of them fits within the limits
     */
    public List<List<BillItem>> minMergeInvoices(List<ItemGroup> invoices) {
        List<List<BillItem>> remaining = new ArrayList<>(invoices.size());
        for (ItemGroup itemGroup : invoices) {
            remaining.add(itemGroup.getBillItems());
        }
        List<List<BillItem>> result = new ArrayList<>(remaining.size());
        while (remaining.size() > 1) {
            Tuple<List<BillItem>, List<List<BillItem>>> tuple = mergeInvoices(remaining);
            result.add(tuple.getFirst());
            remaining = tuple.getSecond();
        }
        if (remaining.size() == 1) {
            // A lone leftover invoice is emitted as-is, without being checked against the limits.
            result.add(remaining.get(0));
        }
        return result;
    }

    /**
     * Picks the largest set of invoices that fit on one invoice and merges them.
     *
     * @param invoiceList invoices still to be merged
     * @return first: the merged bill items; second: the invoices that were not picked, in input order
     */
    private Tuple<List<BillItem>, List<List<BillItem>>> mergeInvoices(List<List<BillItem>> invoiceList) {
        Set<Integer> picked = findLargestFittingSubset(invoiceList);
        List<BillItem> invoice = new ArrayList<>();
        List<List<BillItem>> leftInvoice = new ArrayList<>(invoiceList.size() - picked.size());
        for (int i = 0; i < invoiceList.size(); i++) {
            if (picked.contains(i)) {
                invoice.addAll(invoiceList.get(i));
            } else {
                leftInvoice.add(invoiceList.get(i));
            }
        }
        return new Tuple<>(invoice, leftInvoice);
    }

    /**
     * Groups bill items into as few invoices as possible.
     *
     * <p>Items with positive deductions (differential taxation) each get an invoice of their own. The other
     * items are packed round by round, each round building the invoice that holds the most items fitting
     * within the limits. A single leftover item gets an invoice of its own. The input list is not modified.
     *
     * @param toBeClassifiedBillItem items to group; may be {@code null} or empty
     * @return the invoices: deduction items first, then the packed invoices; items keep their input order
     * @throws SplitBizException if, with two or more items remaining, none of them fits within the limits
     */
    public List<List<BillItem>> createMinInvoice(List<BillItem> toBeClassifiedBillItem) {
        if (CollectionUtils.isEmpty(toBeClassifiedBillItem)) {
            return Lists.newArrayList();
        }
        List<List<BillItem>> result = new ArrayList<>();

        // 1. Differential-taxation items each form their own invoice; everything else is packed below.
        List<BillItem> remaining = new ArrayList<>(toBeClassifiedBillItem.size());
        for (BillItem billItem : toBeClassifiedBillItem) {
            if (billItem.getDeductions().compareTo(BigDecimal.ZERO) > 0) {
                result.add(Lists.newArrayList(billItem));
            } else {
                remaining.add(billItem);
            }
        }

        // 2. Repeatedly build the invoice holding the most items.
        while (remaining.size() > 1) {
            Tuple<List<BillItem>, List<BillItem>> tuple = createMaxItemInvoice(remaining);
            result.add(tuple.getFirst());
            remaining = tuple.getSecond();
        }

        if (remaining.size() == 1) {
            result.add(Lists.newArrayList(remaining.get(0)));
        }
        return result;
    }

    /**
     * Builds the invoice holding the most bill items that fit within the limits.
     *
     * @param toBeClassifiedBillItem items still to be placed
     * @return first: the items on the new invoice; second: the remaining items; both in input order
     */
    private Tuple<List<BillItem>, List<BillItem>> createMaxItemInvoice(List<BillItem> toBeClassifiedBillItem) {
        // Each item is searched as a one-item "invoice": the per-invoice checks then reduce exactly
        // to the per-item checks, so one search implementation serves both entry points.
        List<List<BillItem>> candidates = new ArrayList<>(toBeClassifiedBillItem.size());
        for (BillItem billItem : toBeClassifiedBillItem) {
            candidates.add(Collections.singletonList(billItem));
        }
        Set<Integer> picked = findLargestFittingSubset(candidates);

        List<BillItem> invoiceItems = new ArrayList<>(picked.size());
        List<BillItem> leftItems = new ArrayList<>(toBeClassifiedBillItem.size() - picked.size());
        for (int i = 0; i < toBeClassifiedBillItem.size(); i++) {
            if (picked.contains(i)) {
                invoiceItems.add(toBeClassifiedBillItem.get(i));
            } else {
                leftItems.add(toBeClassifiedBillItem.get(i));
            }
        }
        return new Tuple<>(invoiceItems, leftItems);
    }

    /**
     * Finds the largest subset of candidates whose bill items all fit on one empty invoice.
     *
     * @param candidates indivisible groups of bill items
     * @return indexes into {@code candidates} of the chosen groups; never empty
     * @throws SplitBizException if no candidate fits
     */
    private Set<Integer> findLargestFittingSubset(List<List<BillItem>> candidates) {
        SubsetSearch search = new SubsetSearch(candidates);
        search.explore(new InvoiceLimitState(limitAmount, limitIsAmountWithTax, limitLine),
                0, 0, new ArrayList<>());
        if (search.bestCount == 0) {
            throw new SplitBizException("最少张数发票创建失败,请联系中台排查");
        }
        return new HashSet<>(search.bestIndexes);
    }

    /**
     * Depth-first include/exclude search over the candidates, memoized on (index, invoice state).
     *
     * <p>The best count reachable from an (index, state) pair does not depend on how the pair was reached,
     * so a revisit may be skipped only when it arrives with no more chosen candidates than an earlier visit.
     * Skipping every revisit regardless of count would discard the better of two paths reaching the same
     * totals, e.g. one two-line item versus two one-line items with the same amount and tax error.
     */
    private static final class SubsetSearch {

        private final List<List<BillItem>> candidates;

        /**
         * Per index: states already explored there, mapped to the highest chosen-count they were explored with.
         */
        private final Map<Integer, Map<InvoiceLimitState, Integer>> explored = new HashMap<>();

        private int bestCount;

        private List<Integer> bestIndexes = Collections.emptyList();

        SubsetSearch(List<List<BillItem>> candidates) {
            this.candidates = candidates;
        }

        /**
         * Explores all decisions for candidates {@code index..end}.
         *
         * @param state  totals of the candidates chosen so far; never mutated, since it also serves as a memo key
         * @param index  next candidate to decide on
         * @param count  number of candidates chosen so far
         * @param chosen indexes chosen so far; restored to its entry contents before returning
         */
        void explore(InvoiceLimitState state, int index, int count, List<Integer> chosen) {
            if (index == candidates.size() || !state.canAdd()) {
                // Strict '>' keeps the first best subset found in include-first order.
                if (count > bestCount) {
                    bestCount = count;
                    bestIndexes = new ArrayList<>(chosen);
                }
                return;
            }

            Map<InvoiceLimitState, Integer> seen = explored.computeIfAbsent(index, k -> new HashMap<>());
            Integer previousCount = seen.get(state);
            if (previousCount != null && previousCount >= count) {
                return;
            }
            seen.put(state, count);

            // 1. Include the current candidate when it still fits.
            List<BillItem> candidate = candidates.get(index);
            if (state.canAddInvoice(candidate)) {
                InvoiceLimitState withCandidate = state.copy();
                withCandidate.addInvoice(candidate);
                chosen.add(index);
                explore(withCandidate, index + 1, count + 1, chosen);
                chosen.remove(chosen.size() - 1);
            }

            // 2. Skip the current candidate.
            explore(state, index + 1, count, chosen);
        }
    }

    /**
     * Running totals of one invoice under construction, checked against the invoice limits.
     *
     * <p>Equality and hashing consider only the running totals, not the limits: all states compared within
     * one search are created with the same limits.
     */
    private static final class InvoiceLimitState {

        /**
         * Upper bound for the accumulated tax error of one invoice.
         * NOTE(review): the origin of 1.27 is not visible here; presumably a tax-system tolerance — confirm.
         */
        private static final BigDecimal ERROR_AMOUNT = new BigDecimal("1.27");

        /**
         * Amount limit of the invoice.
         */
        private final BigDecimal limitAmount;

        /**
         * Whether the amount limit applies to amounts with tax.
         */
        private final boolean limitIsAmountWithTax;

        /**
         * Line limit of the invoice.
         */
        private final int limitLine;

        private BigDecimal currentTotalAmount = BigDecimal.ZERO;

        private BigDecimal currentTotalError = BigDecimal.ZERO;

        private int currentLine;

        InvoiceLimitState(BigDecimal limitAmount, boolean limitIsAmountWithTax, int limitLine) {
            this.limitAmount = limitAmount;
            this.limitIsAmountWithTax = limitIsAmountWithTax;
            this.limitLine = limitLine;
        }

        /**
         * Whether the invoice may still take more: the amount does not exceed the limit, there are fewer
         * lines than the limit, and the accumulated error does not exceed {@link #ERROR_AMOUNT}.
         */
        boolean canAdd() {
            return currentTotalAmount.compareTo(limitAmount) <= 0
                    && currentLine < limitLine
                    && currentTotalError.compareTo(ERROR_AMOUNT) <= 0;
        }

        /**
         * Whether all given items can be added: after each item, in list order, the amount, line count and
         * accumulated error must each stay within their limits.
         */
        boolean canAddInvoice(List<BillItem> billItems) {
            BigDecimal amount = currentTotalAmount;
            BigDecimal error = currentTotalError;
            int lines = currentLine;
            for (BillItem billItem : billItems) {
                amount = amount.add(getBillItemAmount(billItem));
                error = error.add(getBillItemErrorAmount(billItem));
                lines += getBillItemLine(billItem);
                if (amount.compareTo(limitAmount) > 0
                        || lines > limitLine
                        || error.compareTo(ERROR_AMOUNT) > 0) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Adds all given items to the running totals.
         */
        void addInvoice(List<BillItem> billItems) {
            billItems.forEach(this::addItem);
        }

        private void addItem(BillItem billItem) {
            currentTotalAmount = currentTotalAmount.add(getBillItemAmount(billItem));
            currentTotalError = currentTotalError.add(getBillItemErrorAmount(billItem));
            currentLine += getBillItemLine(billItem);
        }

        /**
         * Lines an item occupies: two when it has a positive discount (presumably item line plus discount
         * line), otherwise one.
         */
        private int getBillItemLine(BillItem billItem) {
            return billItem.getDiscountWithoutTax().compareTo(BigDecimal.ZERO) > 0 ? 2 : 1;
        }

        /**
         * Absolute difference between the item's tax amount and its amount without tax times the tax rate,
         * both taken net of discount.
         */
        private BigDecimal getBillItemErrorAmount(BillItem billItem) {
            BigDecimal netTax = billItem.getTaxAmount().subtract(billItem.getDiscountTax());
            BigDecimal netAmountWithoutTax =
                    billItem.getAmountWithoutTax().subtract(billItem.getDiscountWithoutTax());
            return netTax.subtract(netAmountWithoutTax.multiply(billItem.getTaxRate())).abs();
        }

        /**
         * Item amount net of discount, with or without tax according to the limit's basis.
         */
        private BigDecimal getBillItemAmount(BillItem billItem) {
            if (limitIsAmountWithTax) {
                return billItem.getAmountWithTax().subtract(billItem.getDiscountWithTax());
            }
            return billItem.getAmountWithoutTax().subtract(billItem.getDiscountWithoutTax());
        }

        /**
         * Returns an independent state with the same limits and totals.
         */
        InvoiceLimitState copy() {
            InvoiceLimitState copy = new InvoiceLimitState(limitAmount, limitIsAmountWithTax, limitLine);
            // BigDecimal is immutable, so the totals can be shared instead of re-parsed.
            copy.currentTotalAmount = currentTotalAmount;
            copy.currentTotalError = currentTotalError;
            copy.currentLine = currentLine;
            return copy;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            InvoiceLimitState that = (InvoiceLimitState) o;
            return currentLine == that.currentLine
                    && currentTotalError.compareTo(that.currentTotalError) == 0
                    && currentTotalAmount.compareTo(that.currentTotalAmount) == 0;
        }

        /**
         * Consistent with {@link #equals}: the totals are normalized first because equals compares them with
         * compareTo (scale-insensitive) while BigDecimal#hashCode depends on scale (1.0 vs 1.00).
         */
        @Override
        public int hashCode() {
            return java.util.Objects.hash(normalized(currentTotalAmount), normalized(currentTotalError), currentLine);
        }

        private static BigDecimal normalized(BigDecimal value) {
            // Zero is mapped explicitly because older JDKs left zero unchanged in stripTrailingZeros.
            return value.signum() == 0 ? BigDecimal.ZERO : value.stripTrailingZeros();
        }
    }

}
