package com.xforceplus.query;

import com.xforceplus.api.model.CompanyModel.Request.*;
import com.xforceplus.entity.*;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.data.jpa.domain.Specification;

import javax.persistence.Tuple;
import javax.persistence.criteria.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

@SuppressWarnings("all")
public class CompanyQueryHelper {

    /**
     * Builds a Spring Data {@link Specification} that filters companies by {@code query}.
     *
     * <p>All filters are applied directly to the {@link CriteriaQuery} by {@code toPredicate}
     * (WHERE, and when the {@code orgs} relation is joined also GROUP BY / HAVING / DISTINCT).
     * The specification then returns the query's current WHERE restriction. Spring Data passes
     * the returned predicate to {@code CriteriaQuery.where(...)}. Returning the HAVING
     * restriction instead (as {@code toPredicate} does once {@code orgs} is joined) would replace
     * the WHERE filters with an aggregate condition whenever {@code multipleTenants} is set.
     *
     * @param query the company search criteria
     * @return a specification usable with repository {@code findAll}/{@code count}
     */
    public static Specification<Company> querySpecification(Query query) {
        return (root, criteriaQuery, builder) -> {
            // Invoked for its side effects on criteriaQuery; see the Javadoc for why its result is ignored.
            toPredicate(query, root, criteriaQuery, builder);
            return criteriaQuery.getRestriction();
        };
    }

    /**
     * Applies the company filters to a {@link Tuple} query.
     *
     * <p>The WHERE clause is always already set on {@code criteriaQuery} when this returns.
     * When the {@code orgs} relation was joined, the returned value is the HAVING restriction
     * (which is {@code null} unless {@code multipleTenants} is set). Otherwise it is the WHERE
     * restriction.
     */
    public static Predicate queryTuplePredicate(Query query, Root<Company> root, CriteriaQuery<Tuple> criteriaQuery, CriteriaBuilder builder) {
        return toPredicate(query, root, criteriaQuery, builder);
    }

    /**
     * Applies the company filters to a count ({@code Long}) query.
     *
     * <p>The return value follows the same contract as {@code queryTuplePredicate}.
     */
    public static Predicate queryCountPredicate(Query query, Root<Company> root, CriteriaQuery<Long> criteriaQuery, CriteriaBuilder builder) {
        return toPredicate(query, root, criteriaQuery, builder);
    }

    /**
     * Applies every filter in {@code query} to {@code criteriaQuery} and returns a restriction.
     *
     * <p>Side effects on {@code criteriaQuery}:
     * <ul>
     *   <li>WHERE is set to the conjunction of all active filters. It is left untouched when no
     *       filter applies.</li>
     *   <li>When {@code multipleTenants} is non-null, {@code orgs} is left-joined and HAVING is
     *       set to {@code count(orgs.orgId) > 1} ({@code true}) or {@code = 1} ({@code false}).
     *       A tuple query is also restricted to selecting the company root aliased
     *       {@code "company"}.</li>
     *   <li>When {@code orgs} is joined, a count query is made DISTINCT. Any other query is
     *       grouped by {@code companyId}, presumably to collapse the duplicate company rows that
     *       the join produces.</li>
     * </ul>
     *
     * @return the HAVING restriction (may be {@code null}) when {@code orgs} was joined,
     *         otherwise the WHERE restriction (may be {@code null})
     */
    private static <T> Predicate toPredicate(Query query, Root<Company> root, CriteriaQuery<T> criteriaQuery, CriteriaBuilder builder) {
        List<Predicate> predicates = new ArrayList<>();
        Class<T> resultType = criteriaQuery.getResultType();
        boolean isCount = resultType.isAssignableFrom(Long.class);
        boolean joinTable = false;
        ListJoin<Company, OrgStruct> joinOrgs = null;

        Boolean multipleTenants = query.getMultipleTenants();
        boolean hasTenantId = query.getTenantId() != null && query.getTenantId() > 0;
        boolean hasTenantName = StringUtils.isNotBlank(query.getTenantName());
        boolean statusIsOne = query.getStatus() != null && query.getStatus() == 1;

        if (multipleTenants != null) {
            if (resultType.isAssignableFrom(Tuple.class)) {
                criteriaQuery = criteriaQuery.multiselect(root.alias("company"));
            }
            joinOrgs = root.joinList("orgs", JoinType.LEFT);
            Expression<Long> orgCount = builder.count(joinOrgs.<Long>get("orgId"));
            criteriaQuery.having(multipleTenants ? builder.gt(orgCount, 1) : builder.equal(orgCount, 1));
            joinTable = true;
        }

        if (hasTenantId || hasTenantName || multipleTenants != null) {
            if (joinOrgs == null) {
                joinOrgs = root.joinList("orgs", JoinType.LEFT);
            }
            if (hasTenantId) {
                predicates.add(builder.equal(joinOrgs.<Long>get("tenantId"), query.getTenantId()));
            }
            if (hasTenantName) {
                Join<OrgStruct, Tenant> joinTenant = joinOrgs.join("tenant", JoinType.LEFT);
                // Prefix match on the tenant name.
                predicates.add(builder.like(joinTenant.<String>get("tenantName"), query.getTenantName() + "%"));
                if (statusIsOne) {
                    predicates.add(builder.equal(joinTenant.get("status"), 1));
                }
            }
            if (statusIsOne) {
                predicates.add(builder.equal(joinOrgs.get("status"), 1));
            }
            joinTable = true;
        }

        if (StringUtils.isNotBlank(query.getCompanyCode())) {
            addEqualOrIn(predicates, builder, root.<String>get("companyCode"), splitNonBlank(query.getCompanyCode()));
        } else if (ArrayUtils.isNotEmpty(query.getCompanyCodes())) {
            Set<String> companyCodes = Arrays.stream(query.getCompanyCodes())
                    .filter(StringUtils::isNotBlank)
                    .collect(Collectors.toSet());
            // Every entry may be blank; addEqualOrIn then adds nothing rather than an empty IN list.
            addEqualOrIn(predicates, builder, root.<String>get("companyCode"), companyCodes);
        }

        if (query.getCompanyId() != null && query.getCompanyId() > 0) {
            predicates.add(builder.equal(root.<Long>get("companyId"), query.getCompanyId()));
        } else if (ArrayUtils.isNotEmpty(query.getCompanyIds())) {
            // companyIds was added for the export feature.
            addEqualOrIn(predicates, builder, root.<Long>get("companyId"),
                    Arrays.stream(query.getCompanyIds()).collect(Collectors.toSet()));
        }

        if (StringUtils.isNotBlank(query.getTaxNum())) {
            addEqualOrIn(predicates, builder, root.<String>get("taxNum"), splitNonBlank(query.getTaxNum()));
        }
        if (StringUtils.isNotBlank(query.getCompanyName())) {
            // Prefix match on the company name.
            predicates.add(builder.like(root.get("companyName"), query.getCompanyName() + "%"));
        }

        // Service flags are only filtered on when a positive value is supplied.
        if (query.getInspectionServiceFlag() != null && query.getInspectionServiceFlag() > 0) {
            predicates.add(builder.equal(root.<Integer>get("inspectionServiceFlag"), query.getInspectionServiceFlag()));
        }
        if (query.getSpeedInspectionChannelFlag() != null && query.getSpeedInspectionChannelFlag() > 0) {
            predicates.add(builder.equal(root.<Integer>get("speedInspectionChannelFlag"), query.getSpeedInspectionChannelFlag()));
        }
        if (query.getTraditionAuthenFlag() != null && query.getTraditionAuthenFlag() > 0) {
            predicates.add(builder.equal(root.<Integer>get("traditionAuthenFlag"), query.getTraditionAuthenFlag()));
        }
        if (query.getStatus() != null) {
            predicates.add(builder.equal(root.<Integer>get("status"), query.getStatus()));
        }

        if (!predicates.isEmpty()) {
            criteriaQuery.where(predicates.toArray(new Predicate[0]));
        }

        if (!joinTable) {
            return criteriaQuery.getRestriction();
        }
        if (isCount) {
            criteriaQuery.distinct(true);
        } else {
            criteriaQuery.groupBy(root.<Long>get("companyId"));
        }
        // NOTE(review): for count queries with multipleTenants set, HAVING is applied without a
        // GROUP BY, i.e. over the whole result set rather than per company. Confirm this is intended.
        return criteriaQuery.getGroupRestriction();
    }

    /**
     * Splits a comma-separated value into its distinct, non-blank parts.
     *
     * @param csv a non-blank comma-separated string
     * @return the distinct non-blank parts (not trimmed); possibly empty
     */
    private static Set<String> splitNonBlank(String csv) {
        return Arrays.stream(StringUtils.split(csv, ","))
                .filter(StringUtils::isNotBlank)
                .collect(Collectors.toSet());
    }

    /**
     * Adds {@code path = value} when {@code values} has one element, and {@code path IN (values)}
     * when it has several. It adds nothing when {@code values} is empty, because an empty IN list
     * renders invalid SQL.
     */
    private static void addEqualOrIn(List<Predicate> predicates, CriteriaBuilder builder, Expression<?> path, Set<?> values) {
        if (values.isEmpty()) {
            return;
        }
        if (values.size() == 1) {
            predicates.add(builder.equal(path, values.iterator().next()));
        } else {
            predicates.add(path.in(values));
        }
    }

    /**
     * Builds a specification intended to locate a single company.
     *
     * <p>A tenant id and/or exact tenant name joins {@code orgs} (and {@code orgs.tenant}) and
     * groups by {@code companyId}. Company code, tax number and company name are matched exactly
     * and OR-ed together, so any one of them identifies the company. Company id and status are
     * AND-ed with the rest.
     *
     * <p>Validation happens when the specification is evaluated, not when this method is called.
     * An {@link IllegalArgumentException} is thrown at that point if no criterion applies.
     *
     * @param query the lookup criteria
     * @return a specification whose evaluation throws {@link IllegalArgumentException} when
     *         {@code query} carries no usable criteria
     */
    public static Specification<Company> queryOneSpecification(Query query) {
        return (root, criteriaQuery, builder) -> {
            List<Predicate> predicates = new ArrayList<>();
            boolean hasTenantId = query.getTenantId() != null && query.getTenantId() > 0;
            boolean hasTenantName = StringUtils.isNotBlank(query.getTenantName());
            if (hasTenantId || hasTenantName) {
                ListJoin<Company, OrgStruct> joinOrgs = root.joinList("orgs", JoinType.LEFT);
                if (hasTenantId) {
                    predicates.add(builder.equal(joinOrgs.<Long>get("tenantId"), query.getTenantId()));
                }
                if (hasTenantName) {
                    Join<OrgStruct, Tenant> joinTenant = joinOrgs.join("tenant", JoinType.LEFT);
                    predicates.add(builder.equal(joinTenant.<String>get("tenantName"), query.getTenantName()));
                }
                criteriaQuery.groupBy(root.<Long>get("companyId"));
            }
            if (query.getCompanyId() != null && query.getCompanyId() > 0) {
                predicates.add(builder.equal(root.<Long>get("companyId"), query.getCompanyId()));
            }
            if (StringUtils.isNotBlank(query.getCompanyCode()) || StringUtils.isNotBlank(query.getTaxNum()) || StringUtils.isNotBlank(query.getCompanyName())) {
                // disjunction() is the neutral element for OR, so each identifier can be OR-ed onto it.
                Predicate identifier = builder.disjunction();
                if (StringUtils.isNotBlank(query.getCompanyCode())) {
                    identifier = builder.or(identifier, builder.equal(root.<String>get("companyCode"), query.getCompanyCode()));
                }
                if (StringUtils.isNotBlank(query.getTaxNum())) {
                    identifier = builder.or(identifier, builder.equal(root.<String>get("taxNum"), query.getTaxNum()));
                }
                //region TODO delete it
                if (StringUtils.isNotBlank(query.getCompanyName())) {
                    identifier = builder.or(identifier, builder.equal(root.<String>get("companyName"), query.getCompanyName()));
                }
                //endregion
                predicates.add(identifier);
            }
            if (query.getStatus() != null) {
                predicates.add(builder.equal(root.<Integer>get("status"), query.getStatus()));
            }
            if (predicates.isEmpty()) {
                // Message means "invalid query parameters".
                throw new IllegalArgumentException("查询参数不合法");
            }
            criteriaQuery.where(predicates.toArray(new Predicate[0]));
            return criteriaQuery.getRestriction();
        };
    }
}
