/*
 * Decompiled with CFR 0.152.
 */
package bt.net.crypto;

import bt.metainfo.TorrentId;
import bt.net.BigIntegers;
import bt.net.ByteChannelReader;
import bt.net.Peer;
import bt.net.crypto.EncryptedChannel;
import bt.net.crypto.MSEKeyPairGenerator;
import bt.protocol.DecodingContext;
import bt.protocol.Handshake;
import bt.protocol.Message;
import bt.protocol.Protocols;
import bt.protocol.crypto.EncryptionPolicy;
import bt.protocol.crypto.MSECipher;
import bt.protocol.handler.MessageHandler;
import bt.runtime.Config;
import bt.torrent.TorrentDescriptor;
import bt.torrent.TorrentRegistry;
import java.io.IOException;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.ReadableByteChannel;
import java.nio.charset.StandardCharsets;
import java.security.KeyPair;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.util.Arrays;
import java.util.Optional;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Negotiates Message Stream Encryption (MSE) for outgoing and incoming peer connections.
 *
 * <p>Both {@code negotiate*} methods return the cipher to use for the remainder of the connection
 * when the negotiated policy requires encryption, or {@link Optional#empty()} when the connection
 * continues in plaintext. They also return empty when MSE is disabled because the configured
 * private key size is not supported by the current JDK.
 */
public class MSEHandshakeProcessor {
    private static final Logger LOGGER = LoggerFactory.getLogger(MSEHandshakeProcessor.class);
    private static final Duration RECEIVE_TIMEOUT = Duration.ofSeconds(10L);
    private static final Duration WAIT_BETWEEN_READS = Duration.ofSeconds(1L);
    /** Upper bound on the length of every padding field sent or accepted by this class. */
    private static final int PADDING_MAX_LENGTH = 512;
    /** Verification constant: 8 zero bytes, sent and expected in encrypted form. */
    private static final byte[] VC_RAW_BYTES = new byte[8];
    private static final byte[] REQ1 = "req1".getBytes(StandardCharsets.US_ASCII);
    private static final byte[] REQ2 = "req2".getBytes(StandardCharsets.US_ASCII);
    private static final byte[] REQ3 = "req3".getBytes(StandardCharsets.US_ASCII);

    private final MSEKeyPairGenerator keyGenerator;
    private final TorrentRegistry torrentRegistry;
    private final MessageHandler<Message> protocol;
    private final EncryptionPolicy localEncryptionPolicy;
    private final boolean mseDisabled;

    /**
     * @param torrentRegistry used by incoming negotiation to find the torrent the peer asks for
     * @param protocol decoder used to detect a plaintext handshake on incoming connections
     * @param config source of the local encryption policy and the MSE private key size
     * @throws IllegalStateException if the key size is unsupported and the policy is
     *     REQUIRE_ENCRYPTED (no connection could ever be established), or the policy is unknown
     */
    public MSEHandshakeProcessor(TorrentRegistry torrentRegistry, MessageHandler<Message> protocol, Config config) {
        this.localEncryptionPolicy = config.getEncryptionPolicy();
        int msePrivateKeySize = config.getMsePrivateKeySize();
        boolean mseDisabled = !MSECipher.isKeySizeSupported(msePrivateKeySize);
        if (mseDisabled) {
            String message = String.format("Current Bt runtime is configured to use private key size of %d bytes for Message Stream Encryption (MSE), and the preferred encryption policy is %s. The aforementioned key size is not allowed in the current JDK configuration. Hence, MSE encryption negotiation procedure will NOT be used", msePrivateKeySize, this.localEncryptionPolicy.name());
            String postfix = " To fix this problem, please do one of the following: (a) update your JDK or Java runtime environment settings for unlimited cryptography support; (b) specify a different private key size (not recommended)";
            switch (this.localEncryptionPolicy) {
                case REQUIRE_PLAINTEXT:
                case PREFER_PLAINTEXT:
                case PREFER_ENCRYPTED: {
                    message = message + ", and all peer connections will be established in plaintext by using the standard BitTorrent handshake. This may negatively affect the number of peers, which can be connected to." + postfix;
                    LOGGER.warn(message);
                    break;
                }
                case REQUIRE_ENCRYPTED: {
                    message = message + ", and considering the requirement for mandatory encryption, this effectively means, that no peer connections will ever be established." + postfix + "; (c) choose a more permissive encryption policy";
                    throw new IllegalStateException(message);
                }
                default: {
                    throw new IllegalStateException("Unknown encryption policy: " + this.localEncryptionPolicy.name());
                }
            }
        }
        this.mseDisabled = mseDisabled;
        this.keyGenerator = new MSEKeyPairGenerator(msePrivateKeySize);
        this.torrentRegistry = torrentRegistry;
        this.protocol = protocol;
    }

    /**
     * Runs the initiator side of the MSE handshake.
     *
     * <p>On return, {@code in} has been compacted so that any bytes received after the peer's
     * padding are kept at the start of the buffer; {@code out} is cleared.
     *
     * @return the cipher to use if the negotiated policy requires encryption, otherwise empty
     * @throws IllegalStateException if the policies cannot be reconciled or the peer's padding
     *     length exceeds the maximum
     */
    public Optional<MSECipher> negotiateOutgoing(Peer peer, ByteChannel channel, TorrentId torrentId, ByteBuffer in, ByteBuffer out) throws IOException {
        if (this.mseDisabled) {
            return Optional.empty();
        }
        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace("Negotiating encryption for outgoing connection: {}", peer);
        }
        ByteChannelReader reader = this.reader(channel);

        // Send our public key followed by random padding.
        KeyPair keys = this.keyGenerator.generateKeyPair();
        out.put(keys.getPublic().getEncoded());
        out.put(this.getPadding(PADDING_MAX_LENGTH));
        out.flip();
        channel.write(out);
        out.clear();

        // Receive the peer's public key; any padding after it is discarded by clear().
        int phase1Min = this.keyGenerator.getPublicKeySize();
        int phase1Limit = phase1Min + PADDING_MAX_LENGTH;
        int phase1Read = reader.readBetween(phase1Min, phase1Limit).read(in);
        in.flip();
        BigInteger peerPublicKey = BigIntegers.decodeUnsigned(in, phase1Min);
        in.clear();
        BigInteger sharedSecret = this.keyGenerator.calculateSharedSecret(peerPublicKey, keys.getPrivate());
        byte[] encodedSecret = BigIntegers.encodeUnsigned(sharedSecret, this.keyGenerator.getPublicKeySize());

        // Send HASH('req1', S) and HASH('req2', torrent id) xor HASH('req3', S).
        MessageDigest digest = this.getDigest("SHA-1");
        digest.update(REQ1);
        digest.update(encodedSecret);
        out.put(digest.digest());
        digest.update(REQ2);
        digest.update(torrentId.getBytes());
        byte[] b1 = digest.digest();
        digest.update(REQ3);
        digest.update(encodedSecret);
        byte[] b2 = digest.digest();
        out.put(this.xor(b1, b2));
        out.flip();
        channel.write(out);
        out.clear();

        // Send encrypted VC, crypto_provide, len(padding), zero padding and a zero-length initial payload.
        byte[] sharedSecretBytes = BigIntegers.encodeUnsigned(sharedSecret, 96);
        MSECipher cipher = MSECipher.forInitiator(sharedSecretBytes, torrentId);
        EncryptedChannel encryptedChannel = new EncryptedChannel(channel, cipher.getDecryptionCipher(), cipher.getEncryptionCipher());
        out.put(VC_RAW_BYTES);
        out.put(this.getCryptoProvideBitfield(this.localEncryptionPolicy));
        byte[] padding = this.getZeroPadding(PADDING_MAX_LENGTH);
        out.put(Protocols.getShortBytes(padding.length));
        out.put(padding);
        out.putShort((short)0);
        out.flip();
        encryptedChannel.write(out);
        out.clear();

        // The peer's reply starts with the VC in encrypted form. A second cipher instance with the
        // same keys produces that form, so it can be searched for in the raw stream.
        MSECipher throwawayCipher = MSECipher.forInitiator(sharedSecretBytes, torrentId);
        byte[] encryptedVC;
        try {
            encryptedVC = throwawayCipher.getDecryptionCipher().doFinal(VC_RAW_BYTES);
        }
        catch (Exception e) {
            throw new RuntimeException("Failed to encrypt VC", e);
        }
        int phase2Min = encryptedVC.length + 4 + 2; // VC + crypto_select + len(padding)
        int phase2Limit = phase1Limit - phase1Read + phase2Min + PADDING_MAX_LENGTH;
        int initpos = in.position();
        int phase2Read = reader.readBetween(phase2Min, phase2Limit).sync(in, encryptedVC);
        int matchpos = in.position();
        ByteChannelReader encryptedReader = this.reader(encryptedChannel);
        in.limit(initpos + phase2Read);
        // The VC was consumed in encrypted form; advance the real decryption cipher past it.
        cipher.getDecryptionCipher().update(new byte[VC_RAW_BYTES.length]);
        // Decrypt in place whatever has already arrived after the VC.
        byte[] encryptedData = new byte[in.remaining()];
        in.get(encryptedData);
        in.position(matchpos);
        if (encryptedData.length > 0) {
            // Cipher.update returns null for empty input, so it is only called when there is data.
            in.put(cipher.getDecryptionCipher().update(encryptedData));
        }
        in.position(matchpos);

        // Make sure crypto_select and len(padding) are buffered.
        int headerLength = phase2Min - encryptedVC.length;
        int buffered = in.remaining();
        if (buffered < headerLength) {
            int missingHeader = headerLength - buffered;
            int lim = in.limit();
            in.limit(in.capacity());
            // Append after the bytes already decrypted instead of overwriting them.
            in.position(lim);
            int read = encryptedReader.readAtLeast(missingHeader).readNoMoreThan(missingHeader + PADDING_MAX_LENGTH).read(in);
            in.position(matchpos);
            in.limit(lim + read);
        }
        byte[] cryptoSelect = new byte[4];
        in.get(cryptoSelect);
        EncryptionPolicy negotiatedEncryptionPolicy = this.selectPolicy(cryptoSelect, this.localEncryptionPolicy);
        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace("Negotiated encryption policy: {}, peer: {}", negotiatedEncryptionPolicy, peer);
        }

        // Skip the peer's padding, reading more of it if needed.
        int theirPadding = in.getShort() & 0xFFFF;
        if (theirPadding > PADDING_MAX_LENGTH) {
            throw new IllegalStateException("Padding is too long: " + theirPadding);
        }
        int missing = theirPadding - in.remaining();
        if (missing > 0) {
            int pos = in.position();
            int lim = in.limit();
            in.limit(in.capacity());
            // Append after the buffered padding bytes so they still count towards theirPadding.
            in.position(lim);
            encryptedReader.readAtLeast(missing).read(in);
            in.flip();
            in.position(pos);
        }
        in.position(in.position() + theirPadding);
        in.compact();
        out.clear();
        return this.toResult(negotiatedEncryptionPolicy, cipher);
    }

    /**
     * Runs the receiver side of the MSE handshake.
     *
     * <p>If the peer starts with a plaintext BitTorrent handshake, no MSE negotiation is done.
     * Empty is returned, provided the local policy accepts plaintext.
     *
     * @return the cipher to use if the negotiated policy requires encryption, otherwise empty
     * @throws IllegalStateException if the requested torrent is unknown or inactive, the VC is
     *     invalid, the padding is too long, or the policies cannot be reconciled
     * @throws RuntimeException if a plaintext handshake arrives but the local policy forbids it
     */
    public Optional<MSECipher> negotiateIncoming(Peer peer, ByteChannel channel, ByteBuffer in, ByteBuffer out) throws IOException {
        if (this.mseDisabled) {
            return Optional.empty();
        }
        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace("Negotiating encryption for incoming connection: {}", peer);
        }
        ByteChannelReader reader = this.reader(channel);

        // Read a few bytes and check whether they form a plaintext protocol handshake.
        int phase0Min = 20;
        int phase0Read = reader.readAtLeast(phase0Min).read(in);
        in.flip();
        DecodingContext context = new DecodingContext(peer);
        int consumed = 0;
        try {
            consumed = this.protocol.decode(context, in);
        }
        catch (Exception ignored) {
            // Best effort: failing to decode just means this is not a plaintext message
            // (presumably the start of the peer's public key), so MSE negotiation continues.
        }
        if (consumed > 0 && context.getMessage() instanceof Handshake) {
            this.assertPolicyIsCompatible(EncryptionPolicy.REQUIRE_PLAINTEXT);
            return Optional.empty();
        }

        // Read the rest of the peer's public key (plus optional padding) after the probed bytes.
        int phase1Min = this.keyGenerator.getPublicKeySize();
        int phase1Limit = phase1Min + PADDING_MAX_LENGTH;
        in.limit(in.capacity());
        in.position(phase0Read);
        int phase1Read = phase0Read < phase1Min
                ? reader.readAtLeast(phase1Min - phase0Read).readNoMoreThan(phase1Limit - phase0Read).read(in)
                : 0;
        in.flip();
        BigInteger peerPublicKey = BigIntegers.decodeUnsigned(in, phase1Min);
        in.clear();

        // Send our public key followed by random padding.
        KeyPair keys = this.keyGenerator.generateKeyPair();
        out.put(keys.getPublic().getEncoded());
        out.put(this.getPadding(PADDING_MAX_LENGTH));
        out.flip();
        channel.write(out);
        out.clear();
        BigInteger sharedSecret = this.keyGenerator.calculateSharedSecret(peerPublicKey, keys.getPrivate());
        byte[] encodedSecret = BigIntegers.encodeUnsigned(sharedSecret, this.keyGenerator.getPublicKeySize());

        // 56 = 20 (req1 hash) + 20 (req2 xor req3) + 8 (VC) + 4 (crypto_provide)
        //      + 2 (len(padding)) + 2 (len(initial payload))
        int phase2Min = 56;
        int phase2Limit = phase2Min + PADDING_MAX_LENGTH;
        MessageDigest digest = this.getDigest("SHA-1");
        digest.update(REQ1);
        digest.update(encodedSecret);
        byte[] req1hash = digest.digest();
        int phase2Read = reader.readAtLeast(phase2Min).readNoMoreThan(phase2Limit + (phase1Limit - (phase0Read + phase1Read))).sync(in, req1hash);
        // NOTE(review): the outgoing path bounds this by (initial position + bytes read); confirm
        // why phase1Read is added here.
        in.limit(phase1Read + phase2Read);
        byte[] obfuscatedTorrentHash = new byte[20];
        in.get(obfuscatedTorrentHash);

        // Find the torrent whose HASH('req2', id) xor HASH('req3', S) matches what the peer sent.
        digest.update(REQ3);
        digest.update(encodedSecret);
        byte[] b2 = digest.digest();
        TorrentId requestedTorrent = null;
        for (TorrentId torrentId : this.torrentRegistry.getTorrentIds()) {
            digest.update(REQ2);
            digest.update(torrentId.getBytes());
            if (Arrays.equals(this.xor(digest.digest(), b2), obfuscatedTorrentHash)) {
                requestedTorrent = torrentId;
                break;
            }
        }
        if (requestedTorrent == null) {
            throw new IllegalStateException("Unsupported torrent requested");
        }
        Optional<TorrentDescriptor> descriptor = this.torrentRegistry.getDescriptor(requestedTorrent);
        if (descriptor.isPresent() && !descriptor.get().isActive()) {
            throw new IllegalStateException("Inactive torrent requested: " + requestedTorrent);
        }

        byte[] sharedSecretBytes = BigIntegers.encodeUnsigned(sharedSecret, 96);
        MSECipher cipher = MSECipher.forReceiver(sharedSecretBytes, requestedTorrent);
        EncryptedChannel encryptedChannel = new EncryptedChannel(channel, cipher.getDecryptionCipher(), cipher.getEncryptionCipher());
        ByteChannelReader encryptedReader = this.reader(encryptedChannel);

        // Decrypt in place everything buffered after the obfuscated torrent hash.
        int pos = in.position();
        byte[] leftovers = new byte[in.remaining()];
        in.get(leftovers);
        in.position(pos);
        try {
            in.put(cipher.getDecryptionCipher().update(leftovers));
        }
        catch (Exception e) {
            throw new RuntimeException("Failed to decrypt leftover bytes: " + leftovers.length, e);
        }
        in.position(pos);

        byte[] theirVC = new byte[8];
        in.get(theirVC);
        if (!Arrays.equals(VC_RAW_BYTES, theirVC)) {
            throw new IllegalStateException("Invalid VC: " + Arrays.toString(theirVC));
        }
        byte[] cryptoProvide = new byte[4];
        in.get(cryptoProvide);
        EncryptionPolicy negotiatedEncryptionPolicy = this.selectPolicy(cryptoProvide, this.localEncryptionPolicy);
        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace("Negotiated encryption policy: {}, peer: {}", negotiatedEncryptionPolicy, peer);
        }

        int theirPadding = in.getShort() & 0xFFFF;
        if (theirPadding > PADDING_MAX_LENGTH) {
            throw new IllegalStateException("Padding is too long: " + theirPadding);
        }
        int position = in.position();
        // The padding is followed by the 2-byte initial payload length, so both must be buffered.
        int missing = theirPadding + 2 - in.remaining();
        if (missing > 0) {
            int lim = in.limit();
            in.limit(in.capacity());
            // Append after the bytes already buffered.
            in.position(lim);
            encryptedReader.readAtLeast(missing).read(in);
            in.flip();
            in.position(position);
        }
        in.position(position + theirPadding);
        // Skip len(initial payload); any payload bytes stay in the buffer after compact().
        in.getShort();
        in.compact();

        // Reply with encrypted VC, crypto_select, len(padding), zero padding.
        out.put(VC_RAW_BYTES);
        out.put(this.getCryptoProvideBitfield(negotiatedEncryptionPolicy));
        byte[] padding = this.getZeroPadding(PADDING_MAX_LENGTH);
        out.putShort((short)padding.length);
        out.put(padding);
        out.flip();
        encryptedChannel.write(out);
        out.clear();
        return this.toResult(negotiatedEncryptionPolicy, cipher);
    }

    /** Maps the negotiated policy to the negotiation result: a cipher for encrypted policies, empty for plaintext. */
    private Optional<MSECipher> toResult(EncryptionPolicy negotiatedEncryptionPolicy, MSECipher cipher) {
        switch (negotiatedEncryptionPolicy) {
            case REQUIRE_PLAINTEXT:
            case PREFER_PLAINTEXT: {
                return Optional.empty();
            }
            case PREFER_ENCRYPTED:
            case REQUIRE_ENCRYPTED: {
                return Optional.of(cipher);
            }
            default: {
                throw new IllegalStateException("Unknown encryption policy: " + negotiatedEncryptionPolicy.name());
            }
        }
    }

    /** Wraps a channel in a reader configured with this class's receive timeout and read interval. */
    private ByteChannelReader reader(ReadableByteChannel channel) {
        return ByteChannelReader.forChannel(channel).withTimeout(RECEIVE_TIMEOUT).waitBetweenReads(WAIT_BETWEEN_READS);
    }

    /** Throws if the local policy is not compatible with the given peer policy. */
    private void assertPolicyIsCompatible(EncryptionPolicy peerEncryptionPolicy) {
        if (!this.localEncryptionPolicy.isCompatible(peerEncryptionPolicy)) {
            throw new RuntimeException("Encryption policies are incompatible: peer's (" + peerEncryptionPolicy.name() + "), local (" + this.localEncryptionPolicy.name() + ")");
        }
    }

    /** Returns between 0 and {@code length} (inclusive) pseudo-random bytes. */
    private byte[] getPadding(int length) {
        Random r = new Random();
        byte[] padding = new byte[r.nextInt(length + 1)];
        for (int i = 0; i < padding.length; ++i) {
            padding[i] = (byte)r.nextInt(256);
        }
        return padding;
    }

    /** Returns between 0 and {@code length} (inclusive) zero bytes. */
    private byte[] getZeroPadding(int length) {
        Random r = new Random();
        return new byte[r.nextInt(length + 1)];
    }

    /** Returns a digest for {@code algorithm}, rethrowing a missing algorithm as unchecked. */
    private MessageDigest getDigest(String algorithm) {
        try {
            return MessageDigest.getInstance(algorithm);
        }
        catch (NoSuchAlgorithmException e) {
            throw new RuntimeException(e);
        }
    }

    /** Byte-wise XOR of two arrays of equal length. */
    private byte[] xor(byte[] b1, byte[] b2) {
        if (b1.length != b2.length) {
            throw new IllegalStateException("Lengths do not match: " + b1.length + ", " + b2.length);
        }
        byte[] result = new byte[b1.length];
        for (int i = 0; i < b1.length; ++i) {
            result[i] = (byte)(b1[i] ^ b2[i]);
        }
        return result;
    }

    /**
     * Encodes a policy as a 4-byte crypto_provide/crypto_select field. Only the last byte is set:
     * bit 1 means plaintext, bit 2 means encryption (the same bits {@code selectPolicy} reads).
     */
    private byte[] getCryptoProvideBitfield(EncryptionPolicy encryptionPolicy) {
        byte[] crypto_provide = new byte[4];
        switch (encryptionPolicy) {
            case REQUIRE_PLAINTEXT: {
                crypto_provide[3] = 1;
                break;
            }
            case PREFER_PLAINTEXT:
            case PREFER_ENCRYPTED: {
                crypto_provide[3] = 3;
                break;
            }
            case REQUIRE_ENCRYPTED: {
                crypto_provide[3] = 2;
                break;
            }
        }
        return crypto_provide;
    }

    /**
     * Chooses a REQUIRE_* policy from the peer's offered bits and the local policy.
     *
     * @throws IllegalStateException if no offered method is acceptable locally
     */
    private EncryptionPolicy selectPolicy(byte[] crypto_provide, EncryptionPolicy localEncryptionPolicy) {
        boolean plaintextProvided = (crypto_provide[3] & 1) == 1;
        boolean encryptionProvided = (crypto_provide[3] & 2) == 2;
        EncryptionPolicy selected = null;
        if (plaintextProvided || encryptionProvided) {
            switch (localEncryptionPolicy) {
                case REQUIRE_PLAINTEXT: {
                    if (!plaintextProvided) break;
                    selected = EncryptionPolicy.REQUIRE_PLAINTEXT;
                    break;
                }
                case PREFER_PLAINTEXT: {
                    selected = plaintextProvided ? EncryptionPolicy.REQUIRE_PLAINTEXT : EncryptionPolicy.REQUIRE_ENCRYPTED;
                    break;
                }
                case PREFER_ENCRYPTED: {
                    selected = encryptionProvided ? EncryptionPolicy.REQUIRE_ENCRYPTED : EncryptionPolicy.REQUIRE_PLAINTEXT;
                    break;
                }
                case REQUIRE_ENCRYPTED: {
                    if (!encryptionProvided) break;
                    selected = EncryptionPolicy.REQUIRE_ENCRYPTED;
                    break;
                }
                default: {
                    throw new IllegalStateException("Unknown encryption policy: " + localEncryptionPolicy.name());
                }
            }
        }
        if (selected == null) {
            throw new IllegalStateException("Failed to negotiate the encryption policy: local policy (" + localEncryptionPolicy.name() + "), peer's policy (" + Arrays.toString(crypto_provide) + ")");
        }
        return selected;
    }
}

