/*
 * Decompiled with CFR 0.152.
 */
package org.apache.cassandra.transport;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.handler.codec.MessageToMessageEncoder;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.EnumSet;
import java.util.List;
import org.apache.cassandra.config.DatabaseDescriptor;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.metrics.ClientMessageSizeMetrics;
import org.apache.cassandra.transport.CBUtil;
import org.apache.cassandra.transport.Connection;
import org.apache.cassandra.transport.Message;
import org.apache.cassandra.transport.ProtocolException;
import org.apache.cassandra.transport.ProtocolVersion;
import org.apache.cassandra.transport.messages.ErrorMessage;
import org.apache.cassandra.utils.ByteBufferUtil;

public class Envelope {
    public static final byte PROTOCOL_VERSION_MASK = 127;
    public final Header header;
    public final ByteBuf body;

    /**
     * Creates an envelope from an already-built header and its body buffer.
     * Both references are stored as given; the body is neither copied nor retained here.
     *
     * @param header the envelope header
     * @param body   the buffer holding the envelope body
     */
    public Envelope(Header header, ByteBuf body) {
        this.body = body;
        this.header = header;
    }

    /**
     * Takes an additional reference on the body buffer by delegating to {@code ByteBuf.retain()}.
     */
    public void retain() {
        body.retain();
    }

    /**
     * Drops one reference on the body buffer.
     *
     * @return whatever {@code ByteBuf.release()} reports for the body buffer
     */
    public boolean release() {
        final boolean result = body.release();
        return result;
    }

    /**
     * Test helper that produces a new envelope sharing this envelope's header instance, with a body
     * obtained by passing the body's NIO view through {@code ByteBufferUtil.clone} and wrapping the
     * result in an unpooled buffer.
     *
     * @return a new envelope with the same header and a cloned body
     */
    @VisibleForTesting
    public Envelope clone() {
        ByteBuffer clonedBytes = ByteBufferUtil.clone(body.nioBuffer());
        ByteBuf clonedBody = Unpooled.wrappedBuffer(clonedBytes);
        return new Envelope(header, clonedBody);
    }

    /**
     * Builds an envelope whose header is assembled from the given fields; the header's body size is
     * taken from the number of bytes currently readable in {@code body}.
     *
     * @param type     message type recorded in the header
     * @param streamId stream id recorded in the header
     * @param version  protocol version recorded in the header
     * @param flags    header flags; the set is stored as given, not copied
     * @param body     body buffer; stored as given
     * @return the new envelope
     */
    public static Envelope create(Message.Type type, int streamId, ProtocolVersion version, EnumSet<Header.Flag> flags, ByteBuf body) {
        int bodySize = body.readableBytes();
        return new Envelope(new Header(version, flags, streamId, type, bodySize), body);
    }

    /**
     * Serializes this envelope's header into a freshly allocated buffer obtained from
     * {@code CBUtil.allocator}.
     * <p>
     * Layout written, in order: version byte (with the message direction folded in), flags byte,
     * stream id, opcode byte, and the body length as a 4-byte int. The body length is taken from
     * the body's current readable byte count, not from {@code header.bodySizeInBytes}.
     *
     * @return the buffer containing the encoded header
     */
    public ByteBuf encodeHeader() {
        final ByteBuf out = CBUtil.allocator.buffer(Header.LENGTH);
        final Header h = header;

        out.writeByte(h.type.direction.addToVersion(h.version.asInt()));
        out.writeByte(Header.Flag.serialize(h.flags));

        // From V3 onwards the stream id occupies two bytes; earlier versions get a single byte.
        if (h.version.isGreaterOrEqualTo(ProtocolVersion.V3)) {
            out.writeShort(h.streamId);
        } else {
            out.writeByte(h.streamId);
        }

        out.writeByte(h.type.opcode);
        out.writeInt(body.readableBytes());
        return out;
    }

    /**
     * Writes this envelope's header into {@code buf} at its current position using relative puts,
     * so the buffer's position advances past the header.
     * <p>
     * Field order matches {@link #encodeHeader()}: version byte, flags byte, stream id (two bytes
     * from V3 onwards, one byte before), opcode byte, then the body's readable byte count as an int.
     *
     * @param buf destination buffer
     */
    public void encodeHeaderInto(ByteBuffer buf) {
        final Header h = header;

        buf.put((byte) h.type.direction.addToVersion(h.version.asInt()));
        buf.put((byte) Header.Flag.serialize(h.flags));

        final boolean twoByteStreamId = h.version.isGreaterOrEqualTo(ProtocolVersion.V3);
        if (twoByteStreamId) {
            buf.putShort((short) h.streamId);
        } else {
            buf.put((byte) h.streamId);
        }

        buf.put((byte) h.type.opcode);
        buf.putInt(body.readableBytes());
    }

    /**
     * Writes the complete envelope into {@code buf}: first the header (see
     * {@link #encodeHeaderInto(ByteBuffer)}), then the body bytes exposed by the body's NIO view.
     *
     * @param buf destination buffer
     */
    public void encodeInto(ByteBuffer buf) {
        encodeHeaderInto(buf);
        ByteBuffer bodyBytes = body.nioBuffer();
        buf.put(bodyBytes);
    }

    /**
     * Returns a new envelope that reuses this envelope's header instance but carries a different body.
     * The header is not rebuilt, so its {@code bodySizeInBytes} keeps its previous value.
     *
     * @param newBody the body for the new envelope
     * @return a new envelope pairing the existing header with {@code newBody}
     */
    public Envelope with(ByteBuf newBody) {
        final Header sameHeader = header;
        return new Envelope(sameHeader, newBody);
    }

    /**
     * Skips as many readable bytes of {@code buffer} as possible, up to {@code remainingToDiscard}.
     *
     * @param buffer             buffer to skip bytes in
     * @param remainingToDiscard number of bytes that still need to be dropped
     * @return how many bytes are still left to drop after this call
     */
    private static long discard(ByteBuf buffer, long remainingToDiscard) {
        final long readable = buffer.readableBytes();
        // Bounded by readableBytes(), which is an int, so the narrowing cast cannot overflow.
        final int skipNow = (int) (remainingToDiscard <= readable ? remainingToDiscard : readable);
        buffer.skipBytes(skipNow);
        return remainingToDiscard - skipNow;
    }

    /**
     * Outbound pipeline stage that compresses an {@link Envelope} when the channel's
     * {@link Connection} has a compressor configured.
     * <p>
     * Envelopes are forwarded untouched when the message is a {@code STARTUP}, when no connection is
     * attached to the channel, or when the connection reports no compressor. The handler has no
     * instance fields.
     */
    @ChannelHandler.Sharable
    public static class Compressor
    extends MessageToMessageEncoder<Envelope> {
        /**
         * Shared handler instance. Declared {@code final}, like {@code Encoder.instance}, so the
         * singleton cannot be reassigned once pipelines reference it.
         */
        public static final Compressor instance = new Compressor();

        private Compressor() {
        }

        /**
         * Compresses {@code source} if the channel's connection has a compressor, otherwise passes it on.
         *
         * @param ctx     channel context; the connection is looked up through {@code Connection.attributeKey}
         * @param source  envelope to encode
         * @param results receives exactly one element: the original or the compressed envelope
         * @throws IOException declared for the compressor's {@code compress} call
         */
        @Override
        public void encode(ChannelHandlerContext ctx, Envelope source, List<Object> results) throws IOException {
            Connection connection = ctx.channel().attr(Connection.attributeKey).get();

            // STARTUP messages are never compressed, and without a connection there is no compressor to consult.
            if (source.header.type == Message.Type.STARTUP || connection == null) {
                results.add(source);
                return;
            }

            org.apache.cassandra.transport.Compressor compressor = connection.getCompressor();
            if (compressor == null) {
                results.add(source);
                return;
            }

            // The flag set is mutated in place, so the header owned by the source envelope is marked as well.
            source.header.flags.add(Header.Flag.COMPRESSED);
            results.add(compressor.compress(source));
        }
    }

    /**
     * Inbound pipeline stage that decompresses an {@link Envelope} whose header carries the
     * {@code COMPRESSED} flag, using the compressor configured on the channel's {@link Connection}.
     * <p>
     * Envelopes are forwarded untouched when the flag is absent, when no connection is attached to
     * the channel, or when the connection reports no compressor. The handler has no instance fields.
     */
    @ChannelHandler.Sharable
    public static class Decompressor
    extends MessageToMessageDecoder<Envelope> {
        /**
         * Shared handler instance. Declared {@code final}, like {@code Encoder.instance}, so the
         * singleton cannot be reassigned once pipelines reference it.
         */
        public static final Decompressor instance = new Decompressor();

        private Decompressor() {
        }

        /**
         * Decompresses {@code source} when it is flagged as compressed and a compressor is available.
         *
         * @param ctx     channel context; the connection is looked up through {@code Connection.attributeKey}
         * @param source  envelope to decode
         * @param results receives exactly one element: the original or the decompressed envelope
         * @throws IOException declared for the compressor's {@code decompress} call
         */
        @Override
        public void decode(ChannelHandlerContext ctx, Envelope source, List<Object> results) throws IOException {
            Connection connection = ctx.channel().attr(Connection.attributeKey).get();

            boolean flaggedCompressed = source.header.flags.contains(Header.Flag.COMPRESSED);
            if (!flaggedCompressed || connection == null) {
                results.add(source);
                return;
            }

            org.apache.cassandra.transport.Compressor compressor = connection.getCompressor();
            if (compressor == null) {
                // NOTE(review): a flagged envelope is forwarded still compressed here; later stages see the raw body.
                results.add(source);
                return;
            }

            results.add(compressor.decompress(source));
        }
    }

    /**
     * Outbound pipeline stage that turns an {@link Envelope} into two buffers, the encoded header
     * followed by the body, and records the combined size in the client message size metrics.
     */
    @ChannelHandler.Sharable
    public static class Encoder
    extends MessageToMessageEncoder<Envelope> {
        /** Shared handler instance; the class has no instance fields. */
        public static final Encoder instance = new Encoder();

        private Encoder() {
        }

        /**
         * Emits the header buffer and then the body buffer of {@code source}.
         *
         * @param ctx     channel context (not used by this method)
         * @param source  envelope to serialize
         * @param results receives the header buffer first, then the envelope's body buffer
         */
        @Override
        public void encode(ChannelHandlerContext ctx, Envelope source, List<Object> results) {
            final ByteBuf headerBytes = source.encodeHeader();
            final ByteBuf bodyBytes = source.body;

            final int totalSize = headerBytes.readableBytes() + bodyBytes.readableBytes();
            ClientMessageSizeMetrics.bytesSent.inc(totalSize);
            ClientMessageSizeMetrics.bytesSentPerResponse.update(totalSize);

            results.add(headerBytes);
            results.add(bodyBytes);
        }
    }

    public static class Decoder
    extends ByteToMessageDecoder {
        private static final int MAX_TOTAL_LENGTH = DatabaseDescriptor.getNativeTransportMaxFrameSize();
        private boolean discardingTooLongMessage;
        private long tooLongTotalLength;
        private long bytesToDiscard;
        private int tooLongStreamId;

        /**
         * Decodes an envelope header from {@code buffer}, starting at its current position.
         * <p>
         * Only absolute reads are used, so the buffer's position is left unchanged. The body length
         * is read as a signed int; a negative value yields an error result. A
         * {@link ProtocolException} raised while decoding the version, flags or opcode is captured
         * in an error result rather than propagated.
         *
         * @param buffer buffer with at least {@code Header.LENGTH} bytes remaining
         * @return a success result carrying the header, or an error result carrying the exception,
         *         the stream id and the body length read from the buffer
         */
        HeaderExtractionResult extractHeader(ByteBuffer buffer) {
            Preconditions.checkArgument(buffer.remaining() >= Header.LENGTH, "Undersized buffer supplied. Expected %s, actual %s", Header.LENGTH, buffer.remaining());

            // Fixed offsets from the current position: version(1) flags(1) streamId(2) opcode(1) length(4).
            final int base = buffer.position();
            final byte versionByte = buffer.get(base);
            final byte flagBits = buffer.get(base + 1);
            final short streamId = buffer.getShort(base + 2);
            final byte opcode = buffer.get(base + 4);
            final long bodyLength = buffer.getInt(base + 5);

            if (bodyLength < 0L) {
                ProtocolException invalidLength = new ProtocolException("Invalid value for envelope header body length field: " + bodyLength);
                return new HeaderExtractionResult.Error(invalidLength, streamId, bodyLength);
            }

            final Message.Direction direction = Message.Direction.extractFromVersion(versionByte);
            try {
                ProtocolVersion version = ProtocolVersion.decode(versionByte & PROTOCOL_VERSION_MASK, DatabaseDescriptor.getNativeTransportAllowOlderProtocols());
                EnumSet<Header.Flag> headerFlags = decodeFlags(version, flagBits);
                Message.Type type = Message.Type.fromOpcode(opcode, direction);
                Header decoded = new Header(version, headerFlags, streamId, type, bodyLength);
                return new HeaderExtractionResult.Success(decoded);
            }
            catch (ProtocolException e) {
                return new HeaderExtractionResult.Error(e, streamId, bodyLength);
            }
        }

        /**
         * Attempts to decode one envelope from {@code buffer}.
         * <p>
         * Returns {@code null} when nothing can be produced yet: the buffer is empty, the header or
         * body is incomplete, or an oversized message is being discarded. In those cases the reader
         * index is only advanced while discarding.
         * <p>
         * When the declared total length (body plus header) exceeds {@code MAX_TOTAL_LENGTH}, the
         * decoder enters discard mode and skips incoming bytes until the whole message has been
         * dropped. It then throws through {@link #fail()}.
         *
         * @param buffer inbound bytes
         * @return the decoded envelope, whose body is a retained slice of {@code buffer}, or {@code null}
         * @throws ProtocolException if the version byte cannot be decoded (all readable bytes are
         *         skipped first) or the flags are invalid for the version; other failures surface as
         *         whatever {@code ErrorMessage.wrap} returns
         */
        @VisibleForTesting
        Envelope decode(ByteBuf buffer) {
            if (discardingTooLongMessage) {
                bytesToDiscard = discard(buffer, bytesToDiscard);
                // Once the oversized message has been skipped entirely, report the failure.
                if (bytesToDiscard <= 0L) {
                    fail();
                }
                return null;
            }

            final int available = buffer.readableBytes();
            if (available == 0) {
                return null;
            }

            // The version byte is validated before a full header has arrived.
            final int start = buffer.readerIndex();
            final byte versionByte = buffer.getByte(start);
            final Message.Direction direction = Message.Direction.extractFromVersion(versionByte);
            ProtocolVersion version;
            try {
                version = ProtocolVersion.decode(versionByte & PROTOCOL_VERSION_MASK, DatabaseDescriptor.getNativeTransportAllowOlderProtocols());
            }
            catch (ProtocolException e) {
                // Consume everything that is readable before propagating the error.
                buffer.skipBytes(available);
                throw e;
            }

            if (available < Header.LENGTH) {
                return null;
            }

            // Fixed offsets from the reader index: flags at +1, stream id at +2, opcode at +4, body length at +5.
            final EnumSet<Header.Flag> headerFlags = decodeFlags(version, buffer.getByte(start + 1));
            final short streamId = buffer.getShort(start + 2);

            Message.Type type;
            try {
                type = Message.Type.fromOpcode(buffer.getByte(start + 4), direction);
            }
            catch (ProtocolException e) {
                throw ErrorMessage.wrap(e, streamId);
            }

            final long bodyLength = buffer.getUnsignedInt(start + 5);
            final int bodyStart = start + Header.LENGTH;
            final long totalLength = bodyLength + Header.LENGTH;

            if (totalLength > MAX_TOTAL_LENGTH) {
                // Remember what is being rejected, then drop whatever part of it has already arrived.
                discardingTooLongMessage = true;
                tooLongStreamId = streamId;
                tooLongTotalLength = totalLength;
                bytesToDiscard = discard(buffer, totalLength);
                if (bytesToDiscard <= 0L) {
                    fail();
                }
                return null;
            }

            if (buffer.readableBytes() < totalLength) {
                return null;
            }

            ClientMessageSizeMetrics.bytesReceived.inc(totalLength);
            ClientMessageSizeMetrics.bytesReceivedPerRequest.update(totalLength);

            // The body is a retained slice of the inbound buffer; the reader index then moves past it.
            final ByteBuf body = buffer.slice(bodyStart, (int) bodyLength);
            body.retain();
            buffer.readerIndex((int) (bodyStart + bodyLength));

            return new Envelope(new Header(version, headerFlags, streamId, type, bodyLength), body);
        }

        /**
         * Converts the raw flag bits into a flag set and enforces that beta protocol versions are
         * only used together with the {@code USE_BETA} flag.
         *
         * @param version protocol version the header was sent with
         * @param flags   raw flag bits from the header
         * @return the decoded flag set
         * @throws ProtocolException if {@code version} is a beta version and {@code USE_BETA} is not set
         */
        private EnumSet<Header.Flag> decodeFlags(ProtocolVersion version, int flags) {
            final EnumSet<Header.Flag> parsed = Header.Flag.deserialize(flags);
            if (!version.isBeta() || parsed.contains(Header.Flag.USE_BETA)) {
                return parsed;
            }
            throw new ProtocolException(String.format("Beta version of the protocol used (%s), but USE_BETA flag is unset", version), version);
        }

        /**
         * Netty entry point: delegates to {@link #decode(ByteBuf)} and publishes the envelope, if
         * one was produced, to {@code results}.
         *
         * @param ctx     channel context (not used by this method)
         * @param buffer  inbound bytes
         * @param results receives the decoded envelope when one is available
         */
        @Override
        protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List<Object> results) {
            final Envelope decoded = decode(buffer);
            if (decoded != null) {
                results.add(decoded);
            }
        }

        /**
         * Leaves discard mode and reports the oversized request. Always throws: the exception is
         * whatever {@code ErrorMessage.wrap} returns for an {@link InvalidRequestException} tagged
         * with the stream id of the rejected message.
         */
        private void fail() {
            // Capture what the message needs before the discard state is cleared.
            final long rejectedLength = tooLongTotalLength;
            discardingTooLongMessage = false;
            tooLongTotalLength = 0L;

            final String msg = String.format("Request is too big: length %d exceeds maximum allowed length %d.", rejectedLength, MAX_TOTAL_LENGTH);
            throw ErrorMessage.wrap(new InvalidRequestException(msg), tooLongStreamId);
        }

        /**
         * Result of {@code extractHeader}: either a decoded {@link Header} ({@code Success}) or the
         * {@link ProtocolException} that prevented decoding ({@code Error}). Both variants expose the
         * stream id and body length read from the wire.
         * <p>
         * The base implementations of {@link #header()} and {@link #error()} throw
         * {@link IllegalStateException}; each variant overrides only the accessor it can serve.
         */
        public static abstract class HeaderExtractionResult {
            /** Discriminates the two result variants. */
            enum Outcome {
                SUCCESS,
                ERROR
            }

            private final Outcome outcome;
            private final int streamId;
            private final long bodyLength;

            private HeaderExtractionResult(Outcome outcome, int streamId, long bodyLength) {
                this.outcome = outcome;
                this.streamId = streamId;
                this.bodyLength = bodyLength;
            }

            /** @return {@code true} when this result carries a header rather than an error */
            boolean isSuccess() {
                return Outcome.SUCCESS == outcome;
            }

            /** @return the stream id read from the header bytes */
            int streamId() {
                return streamId;
            }

            /** @return the body length read from the header bytes */
            long bodyLength() {
                return bodyLength;
            }

            /**
             * @return the decoded header
             * @throws IllegalStateException unless overridden by a variant that has a header
             */
            Header header() {
                throw new IllegalStateException(String.format("Unable to provide header from extraction result : %s", outcome));
            }

            /**
             * @return the exception that prevented decoding
             * @throws IllegalStateException unless overridden by a variant that has an error
             */
            ProtocolException error() {
                throw new IllegalStateException(String.format("Unable to provide error from extraction result : %s", outcome));
            }

            /** Successful extraction; stream id and body length are taken from the header itself. */
            private static class Success
            extends HeaderExtractionResult {
                private final Header header;

                Success(Header header) {
                    super(Outcome.SUCCESS, header.streamId, header.bodySizeInBytes);
                    this.header = header;
                }

                @Override
                Header header() {
                    return header;
                }
            }

            /** Failed extraction; keeps the exception together with the stream id and body length that were read. */
            private static class Error
            extends HeaderExtractionResult {
                private final ProtocolException error;

                private Error(ProtocolException error, int streamId, long bodyLength) {
                    super(Outcome.ERROR, streamId, bodyLength);
                    this.error = error;
                }

                @Override
                ProtocolException error() {
                    return error;
                }
            }
        }
    }

    /**
     * Immutable holder for the fields of an envelope header. Instances are only created inside
     * {@link Envelope} and its nested classes, because the constructor is private.
     */
    public static class Header {
        /** Encoded header size in bytes: version(1) + flags(1) + stream id(2) + opcode(1) + body length(4). */
        public static final int LENGTH = 9;
        /** Size in bytes of the body length field. */
        public static final int BODY_LENGTH_SIZE = 4;

        public final ProtocolVersion version;
        /** Header flags; the set itself is mutable and is stored as given, not copied. */
        public final EnumSet<Flag> flags;
        public final int streamId;
        public final Message.Type type;
        public final long bodySizeInBytes;

        private Header(ProtocolVersion version, EnumSet<Flag> flags, int streamId, Message.Type type, long bodySizeInBytes) {
            this.version = version;
            this.flags = flags;
            this.streamId = streamId;
            this.type = type;
            this.bodySizeInBytes = bodySizeInBytes;
        }

        /**
         * Header flags. The bit position of each flag in the serialized flags value equals its
         * ordinal, so the declaration order determines the encoding and must not be changed.
         */
        public static enum Flag {
            COMPRESSED,
            TRACING,
            CUSTOM_PAYLOAD,
            WARNING,
            USE_BETA;

            /** Cached copy of {@code values()}, which would otherwise allocate a new array per call. */
            private static final Flag[] ALL_VALUES = Flag.values();

            /**
             * Decodes a flags bit field into a set.
             *
             * @param flags bit field; bits beyond the known flags are ignored
             * @return a new set containing every flag whose bit is set
             */
            public static EnumSet<Flag> deserialize(int flags) {
                EnumSet<Flag> decoded = EnumSet.noneOf(Flag.class);
                for (Flag candidate : ALL_VALUES) {
                    int mask = 1 << candidate.ordinal();
                    if ((flags & mask) != 0) {
                        decoded.add(candidate);
                    }
                }
                return decoded;
            }

            /**
             * Encodes a set of flags as a bit field.
             *
             * @param flags flags to encode
             * @return an int with bit {@code ordinal()} set for each flag in the set
             */
            public static int serialize(EnumSet<Flag> flags) {
                int bits = 0;
                for (Flag present : flags) {
                    bits |= 1 << present.ordinal();
                }
                return bits;
            }
        }
    }
}

