/*
 * Decompiled with CFR 0.152.
 */
package org.apache.shardingsphere.proxy.frontend.opengauss.authentication;

import com.google.common.base.Strings;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedList;
import java.util.Locale;
import javax.crypto.Mac;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.SecretKeySpec;
import lombok.Generated;
import org.apache.shardingsphere.db.protocol.postgresql.constant.PostgreSQLErrorCode;
import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLPasswordMessagePacket;
import org.apache.shardingsphere.infra.executor.check.SQLCheckEngine;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.user.Grantee;
import org.apache.shardingsphere.infra.metadata.user.ShardingSphereUser;
import org.apache.shardingsphere.infra.rule.ShardingSphereRule;
import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
import org.apache.shardingsphere.proxy.frontend.postgresql.authentication.PostgreSQLLoginResult;

/**
 * Authentication handler for openGauss clients using a SCRAM-SHA-256 style password exchange.
 *
 * <p>Verification works as follows:</p>
 * <ul>
 *     <li>The server derives {@code storedKey = SHA256(HMAC_SHA256(PBKDF2(password, salt, iterations), "Client Key"))}
 *     from the configured user password.</li>
 *     <li>It computes {@code h2 = HMAC_SHA256(storedKey, nonce)}.</li>
 *     <li>It accepts the client digest {@code h3} when {@code SHA256(h3 XOR h2)} equals {@code storedKey}.</li>
 * </ul>
 */
public final class OpenGaussAuthenticationHandler {
    
    private static final String PBKDF2_WITH_HMAC_SHA1_ALGORITHM = "PBKDF2WithHmacSHA1";
    
    private static final String HMAC_SHA256_ALGORITHM = "HmacSHA256";
    
    private static final String SHA256_ALGORITHM = "SHA-256";
    
    private static final String CLIENT_KEY = "Client Key";
    
    /** Length, in bits, of the key derived by PBKDF2. */
    private static final int PBKDF2_KEY_LENGTH_BITS = 256;
    
    /** Upper-case hex alphabet; a character's index is its nibble value. */
    private static final String HEX_DIGITS = "0123456789ABCDEF";
    
    /**
     * Authenticates a user with the SCRAM-SHA-256 digest carried in the password message.
     *
     * @param username user name sent by the client
     * @param databaseName database requested by the client; may be {@code null} or empty
     * @param salt hex-encoded salt used for the PBKDF2 derivation
     * @param nonce hex-encoded nonce mixed into the expected client proof
     * @param serverIteration PBKDF2 iteration count
     * @param passwordMessagePacket packet carrying the hex-encoded client digest
     * @return login result:
     *         {@code INVALID_CATALOG_NAME} if a named database does not exist;
     *         {@code INVALID_AUTHORIZATION_SPECIFICATION} if the user is unknown;
     *         {@code INVALID_PASSWORD} if the digest does not verify;
     *         {@code PRIVILEGE_NOT_GRANTED} if the user may not access the database;
     *         otherwise {@code SUCCESSFUL_COMPLETION}
     * @throws IllegalStateException if a required JCA algorithm is unavailable
     */
    public static PostgreSQLLoginResult loginWithSCRAMSha256Password(final String username, final String databaseName, final String salt, final String nonce,
                                                                     final int serverIteration, final PostgreSQLPasswordMessagePacket passwordMessagePacket) {
        String clientDigest = passwordMessagePacket.getDigest();
        Grantee grantee = new Grantee(username, "%");
        if (!Strings.isNullOrEmpty(databaseName) && !ProxyContext.getInstance().databaseExists(databaseName)) {
            return new PostgreSQLLoginResult(PostgreSQLErrorCode.INVALID_CATALOG_NAME, String.format("database \"%s\" does not exist", databaseName));
        }
        // Resolve the rule set once so all three checks below see the same snapshot.
        Collection<ShardingSphereRule> rules = getRules(databaseName);
        if (!SQLCheckEngine.check(grantee, rules)) {
            return new PostgreSQLLoginResult(PostgreSQLErrorCode.INVALID_AUTHORIZATION_SPECIFICATION, String.format("unknown username: %s", username));
        }
        // Layout must match what isPasswordRight unpacks: {digest, salt, nonce, iteration}.
        Object[] credentials = {clientDigest, salt, nonce, serverIteration};
        if (!SQLCheckEngine.check(grantee, (user, args) -> isPasswordRight((ShardingSphereUser) user, (Object[]) args), credentials, rules)) {
            return new PostgreSQLLoginResult(PostgreSQLErrorCode.INVALID_PASSWORD, String.format("password authentication failed for user \"%s\"", username));
        }
        if (null == databaseName || SQLCheckEngine.check(databaseName, rules, grantee)) {
            return new PostgreSQLLoginResult(PostgreSQLErrorCode.SUCCESSFUL_COMPLETION, null);
        }
        return new PostgreSQLLoginResult(PostgreSQLErrorCode.PRIVILEGE_NOT_GRANTED, String.format("Access denied for user '%s' to database '%s'", username, databaseName));
    }
    
    /**
     * Collects the rules relevant to a login.
     *
     * <p>The result holds the database's own rules (only when {@code databaseName} is non-empty and
     * the database exists), followed by the global rules.</p>
     *
     * @param databaseName requested database; may be {@code null} or empty
     * @return mutable collection of rules
     */
    private static Collection<ShardingSphereRule> getRules(final String databaseName) {
        Collection<ShardingSphereRule> result = new LinkedList<>();
        if (!Strings.isNullOrEmpty(databaseName) && ProxyContext.getInstance().databaseExists(databaseName)) {
            result.addAll(((ShardingSphereDatabase) ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getDatabases().get(databaseName)).getRuleMetaData().getRules());
        }
        result.addAll(ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getGlobalRuleMetaData().getRules());
        return result;
    }
    
    /**
     * Verifies the client digest against the user's configured password.
     *
     * @param user user whose configured password is checked
     * @param args {@code {String clientDigestHex, String saltHex, String nonceHex, Integer iteration}}
     * @return {@code true} if {@code SHA256(h3 XOR h2)} equals the server-side stored key
     */
    private static boolean isPasswordRight(final ShardingSphereUser user, final Object[] args) {
        String h3HexString = (String) args[0];
        String salt = (String) args[1];
        String nonce = (String) args[2];
        int serverIteration = (Integer) args[3];
        // PBKDF2 is the expensive step; derive the stored key once and reuse it for h2.
        byte[] serverStoredKey = calculatedStoredKey(user.getPassword(), salt, serverIteration);
        byte[] h2 = calculateH2(serverStoredKey, nonce);
        byte[] h3 = hexStringToBytes(h3HexString);
        // A digest of the wrong length is a failed login, not an internal error.
        if (h3.length != h2.length) {
            return false;
        }
        byte[] clientCalculatedStoredKey = sha256(xor(h3, h2));
        // Constant-time comparison avoids leaking how many leading bytes matched.
        return MessageDigest.isEqual(clientCalculatedStoredKey, serverStoredKey);
    }
    
    /**
     * Derives {@code SHA256(HMAC_SHA256(PBKDF2(password, salt, iterations), "Client Key"))}.
     *
     * @param password plain-text password
     * @param salt hex-encoded salt
     * @param serverIteration PBKDF2 iteration count
     * @return 32-byte stored key
     */
    private static byte[] calculatedStoredKey(final String password, final String salt, final int serverIteration) {
        byte[] k = generateKFromPBKDF2(password, salt, serverIteration);
        byte[] clientKey = getKeyFromHmac(k, CLIENT_KEY.getBytes(StandardCharsets.UTF_8));
        return sha256(clientKey);
    }
    
    /**
     * Computes {@code h2 = HMAC_SHA256(storedKey, nonce)}.
     *
     * @param storedKey key produced by {@link #calculatedStoredKey}
     * @param nonce hex-encoded nonce
     * @return 32-byte HMAC value
     */
    private static byte[] calculateH2(final byte[] storedKey, final String nonce) {
        return getKeyFromHmac(storedKey, hexStringToBytes(nonce));
    }
    
    /**
     * Runs PBKDF2WithHmacSHA1 over the password.
     *
     * @param password plain-text password
     * @param saltString hex-encoded salt
     * @param serverIteration iteration count
     * @return derived key of {@link #PBKDF2_KEY_LENGTH_BITS} bits
     * @throws IllegalStateException if the algorithm is unavailable or the key spec is rejected
     */
    private static byte[] generateKFromPBKDF2(final String password, final String saltString, final int serverIteration) {
        char[] passwordChars = password.toCharArray();
        PBEKeySpec spec = new PBEKeySpec(passwordChars, hexStringToBytes(saltString), serverIteration, PBKDF2_KEY_LENGTH_BITS);
        // PBEKeySpec keeps its own clone of the password, so the local copy can be wiped now.
        Arrays.fill(passwordChars, '\0');
        try {
            return SecretKeyFactory.getInstance(PBKDF2_WITH_HMAC_SHA1_ALGORITHM).generateSecret(spec).getEncoded();
        } catch (final GeneralSecurityException ex) {
            throw new IllegalStateException(ex);
        } finally {
            spec.clearPassword();
        }
    }
    
    /**
     * Decodes a hex string (case-insensitive) into bytes.
     *
     * <p>A {@code null} or empty input yields an empty array. For odd-length input, the trailing
     * character is ignored. Non-hex characters are not rejected; they decode to unspecified byte
     * values.</p>
     *
     * @param rawHexString hex text
     * @return decoded bytes
     */
    private static byte[] hexStringToBytes(final String rawHexString) {
        if (null == rawHexString || rawHexString.isEmpty()) {
            return new byte[0];
        }
        char[] hexChars = rawHexString.toUpperCase(Locale.ENGLISH).toCharArray();
        byte[] result = new byte[hexChars.length / 2];
        for (int i = 0; i < result.length; i++) {
            int pos = i * 2;
            result[i] = (byte) (charToByte(hexChars[pos]) << 4 | charToByte(hexChars[pos + 1]));
        }
        return result;
    }
    
    /**
     * Maps an upper-case hex digit to its nibble value.
     *
     * @param c hex digit
     * @return nibble value, or {@code -1} if {@code c} is not an upper-case hex digit
     */
    private static byte charToByte(final char c) {
        return (byte) HEX_DIGITS.indexOf(c);
    }
    
    /**
     * Computes the SHA-256 digest of the input.
     *
     * @param str bytes to digest
     * @return 32-byte digest
     * @throws IllegalStateException if SHA-256 is unavailable
     */
    private static byte[] sha256(final byte[] str) {
        try {
            return MessageDigest.getInstance(SHA256_ALGORITHM).digest(str);
        } catch (final GeneralSecurityException ex) {
            throw new IllegalStateException(ex);
        }
    }
    
    /**
     * Computes {@code HMAC_SHA256(key, data)}.
     *
     * @param key HMAC key
     * @param data message to authenticate
     * @return 32-byte MAC
     * @throws IllegalStateException if HmacSHA256 is unavailable or the key is rejected
     */
    private static byte[] getKeyFromHmac(final byte[] key, final byte[] data) {
        try {
            Mac mac = Mac.getInstance(HMAC_SHA256_ALGORITHM);
            mac.init(new SecretKeySpec(key, HMAC_SHA256_ALGORITHM));
            return mac.doFinal(data);
        } catch (final GeneralSecurityException ex) {
            throw new IllegalStateException(ex);
        }
    }
    
    /**
     * XORs two equal-length byte arrays.
     *
     * @param password1 first operand
     * @param password2 second operand
     * @return element-wise XOR
     * @throws IllegalArgumentException if the arrays differ in length
     */
    private static byte[] xor(final byte[] password1, final byte[] password2) {
        if (password1.length != password2.length) {
            throw new IllegalArgumentException("Xor values with different length");
        }
        byte[] result = new byte[password1.length];
        for (int i = 0; i < result.length; i++) {
            result[i] = (byte) (password1[i] ^ password2[i]);
        }
        return result;
    }
    
    @Generated
    private OpenGaussAuthenticationHandler() {
    }
}

