ValidationBitStreamReader.java

package zserio.runtime.validation;

import java.io.IOException;

import zserio.runtime.io.ByteArrayBitStreamReader;

/**
 * The bit stream reader used in validate() method generated by Zserio.
 * <p>
 * In addition to reading, the validation bit stream reader records two masks over the read buffer:</p>
 * <ul>
 * <li>a special 'AND' bit mask (invalidMaskBuffer) which contains '1' where the corresponding bit in the read
 *     buffer must not be used for validation, and</li>
 * <li>a special 'OR' bit mask (nanMaskBuffer) which contains normalized NaN values (the NaNs used by the
 *     writer).</li>
 * </ul>
 * <p>
 * Bits in the reader are excluded from validation in the following situations:</p>
 * <ol>
 * <li>The bits have been skipped (not read), which can happen, for example, with the Zserio 'align'
 *     command.</li>
 * <li>A NaN has been read. NaNs do not have a unique binary representation, so they must be normalized to
 *     the NaN used by the writer.</li>
 * </ol>
 */
public final class ValidationBitStreamReader extends ByteArrayBitStreamReader
{
    /**
     * Constructs a new ValidationBitStreamReader.
     *
     * @param bytes     Array of bytes to construct from.
     */
    public ValidationBitStreamReader(final byte[] bytes)
    {
        super(bytes);
        invalidMaskBuffer = new byte[bytes.length];
        invalidMaskBufferEraser = new MaskBufferEraser(invalidMaskBuffer);
        invalidMaskBufferSetter = new MaskBufferSetter(invalidMaskBuffer);

        nanMaskBuffer = new byte[bytes.length];
        nanMaskBufferEraser = new MaskBufferEraser(nanMaskBuffer);
    }

    /**
     * Returns the read byte array with unused bits set to zero and with 'normalized' NaNs.
     *
     * 'Normalized' NaNs means NaNs with the same binary format as it is used by Java BitStreamWriter.
     * Only the bytes up to the current reader position are copied; the rest of the returned array is zero.
     *
     * @return Read byte array with masked unused bits.
     */
    public byte[] toMaskedByteArray()
    {
        final byte[] buffer = getBuffer();
        final byte[] maskedByteArray = new byte[invalidMaskBuffer.length];

        // fully read bytes: drop bits marked invalid, then overlay the normalized NaN bits
        final int endBytePosition = getBytePosition();
        for (int i = 0; i < endBytePosition; ++i)
            maskedByteArray[i] = (byte)((buffer[i] & ~invalidMaskBuffer[i]) | nanMaskBuffer[i]);

        final int endBitOffset = getBitOffset();
        if (endBitOffset > 0)
        {
            // the bits after the current position in the partially read last byte have not been read,
            // so they are treated as invalid as well
            // endBitOffset |  1   |  2   |  3   |  4   |  5   |  6   |  7   |
            // -------------|------|------|------|------|------|------|------|
            // unreadMask   | 0x7F | 0x3F | 0x1F | 0x0F | 0x07 | 0x03 | 0x01 |
            final int unreadMask = (1 << (8 - endBitOffset)) - 1;
            final int invalidMask = invalidMaskBuffer[endBytePosition] | unreadMask;
            maskedByteArray[endBytePosition] =
                    (byte)((buffer[endBytePosition] & ~invalidMask) | nanMaskBuffer[endBytePosition]);
        }

        return maskedByteArray;
    }

    /**
     * Sets the bit position and keeps the validation masks in sync.
     * <p>
     * Moving forward marks the skipped bits as invalid. Moving backward erases both masks over the range
     * moved across, so the bits can be read (and recorded) again.</p>
     *
     * @param bitPosition New bit position.
     *
     * @throws IOException If the new position lies beyond the end of the stream. In that case the reader
     *                     position and the masks are left unchanged.
     */
    @Override
    public void setBitPosition(long bitPosition) throws IOException
    {
        // validate before moving, so that a failed call does not leave the reader at an invalid position
        // (equivalent to: number of bytes needed to hold bitPosition bits > buffer length)
        if (bitPosition > (long)invalidMaskBuffer.length * 8)
            throw new IOException("ValidationBitStreamReader: Unable to set bit position to " + bitPosition +
                    ". It is beyond end of stream.");

        final int startBytePosition = getBytePosition();
        final int startBitOffset = getBitOffset();
        final long startBitPosition = getBitPosition();
        super.setBitPosition(bitPosition);

        final int endBytePosition = getBytePosition();
        final int endBitOffset = getBitOffset();
        final long endBitPosition = getBitPosition();

        // equal positions (e.g. both at the end of stream) require no mask update
        if (endBitPosition > startBitPosition)
        {
            // skipped bits have not been read and must not be validated
            modifyMaskBuffer(startBytePosition, startBitOffset, endBytePosition, endBitOffset,
                    invalidMaskBufferSetter);
        }
        else if (endBitPosition < startBitPosition)
        {
            // bits moved back across will be read again, so forget everything recorded for them
            modifyMaskBuffer(endBytePosition, endBitOffset, startBytePosition, startBitOffset,
                    invalidMaskBufferEraser);
            modifyMaskBuffer(endBytePosition, endBitOffset, startBytePosition, startBitOffset,
                    nanMaskBufferEraser);
        }
    }

    /**
     * Reads a 16-bit float and, when it is a NaN, replaces its bits in the masked output by the writer's NaN.
     *
     * @return Read float value.
     *
     * @throws IOException If the reading failed.
     */
    @Override
    public float readFloat16() throws IOException
    {
        final int startBytePosition = getBytePosition();
        final int startBitOffset = getBitOffset();
        final float readFloat = super.readFloat16();
        if (Float.isNaN(readFloat))
        {
            // the original NaN bits are ignored and the normalized NaN is ORed in instead
            modifyMaskBuffer(startBytePosition, startBitOffset, getBytePosition(), getBitOffset(),
                    invalidMaskBufferSetter);
            setNanMaskBuffer(startBytePosition, startBitOffset);
        }

        return readFloat;
    }

    /**
     * Applies the given action to all bits in the half-open range [start, end).
     *
     * @param startBytePosition Byte position of the first bit.
     * @param startBitOffset    Bit offset (0 = MSB) of the first bit within its byte.
     * @param endBytePosition   Byte position of the end of the range.
     * @param endBitOffset      Bit offset of the end of the range (exclusive).
     * @param action            Action to apply to the bits.
     */
    private void modifyMaskBuffer(int startBytePosition, int startBitOffset, int endBytePosition,
            int endBitOffset, MaskBufferAction action)
    {
        if (endBytePosition == startBytePosition)
        {
            action.modifyBits(startBytePosition, startBitOffset, endBitOffset - startBitOffset);
        }
        else
        {
            // tail of the first byte
            action.modifyBits(startBytePosition, startBitOffset, 8 - startBitOffset);

            // whole bytes in between
            for (int bytePosition = startBytePosition + 1; bytePosition < endBytePosition; ++bytePosition)
                action.modifyByte(bytePosition);

            // head of the last byte (nothing to do if the range ends on a byte boundary)
            if (endBitOffset > 0)
                action.modifyBits(endBytePosition, 0, endBitOffset);
        }
    }

    /**
     * ORs the high byte of the writer's 16-bit NaN (0x7E00) into nanMaskBuffer at the given bit position.
     * The low byte of that NaN is zero and therefore contributes nothing to the OR mask.
     *
     * @param startBytePosition Byte position where the NaN starts.
     * @param startBitOffset    Bit offset where the NaN starts.
     */
    private void setNanMaskBuffer(int startBytePosition, int startBitOffset)
    {
        final int nanHighByte = 0x7E;
        nanMaskBuffer[startBytePosition] |= (byte)(nanHighByte >> startBitOffset);
        if (startBitOffset != 0)
        {
            // the remaining low bits of the high byte spill over into the next byte
            nanMaskBuffer[startBytePosition + 1] |= (byte)(nanHighByte << (8 - startBitOffset));
        }
    }

    /**
     * Returns an int mask with '1' exactly at the bits [bitOffset, bitOffset + numBits) of a byte,
     * where bit offset 0 is the most significant bit.
     *
     * @param bitOffset Bit offset of the first bit (0-7).
     * @param numBits   Number of bits (0-8), bitOffset + numBits must not exceed 8.
     *
     * @return Mask with the selected bits set.
     */
    private static int bitRangeMask(int bitOffset, int numBits)
    {
        // numBits |  0   |  1   |  2   |  3   |  4   |  5   |  6   |  7   |  8   |
        // --------|------|------|------|------|------|------|------|------|------|
        // low     | 0x00 | 0x01 | 0x03 | 0x07 | 0x0F | 0x1F | 0x3F | 0x7F | 0xFF |
        final int lowBits = (1 << numBits) - 1;
        return lowBits << (8 - bitOffset - numBits);
    }

    /** Operation applied to a bit range of a mask buffer. */
    private interface MaskBufferAction
    {
        void modifyBits(int bytePosition, int bitOffset, int numBits);
        void modifyByte(int bytePosition);
    }

    /** Sets the selected bits of a mask buffer to '1'. */
    private static final class MaskBufferSetter implements MaskBufferAction
    {
        public MaskBufferSetter(byte[] maskBuffer)
        {
            this.maskBuffer = maskBuffer;
        }

        @Override
        public void modifyBits(int bytePosition, int bitOffset, int numBits)
        {
            maskBuffer[bytePosition] |= (byte)bitRangeMask(bitOffset, numBits);
        }

        @Override
        public void modifyByte(int bytePosition)
        {
            maskBuffer[bytePosition] = (byte)0xFF;
        }

        private final byte[] maskBuffer;
    }

    /** Clears the selected bits of a mask buffer to '0', leaving all other bits untouched. */
    private static final class MaskBufferEraser implements MaskBufferAction
    {
        public MaskBufferEraser(byte[] maskBuffer)
        {
            this.maskBuffer = maskBuffer;
        }

        @Override
        public void modifyBits(int bytePosition, int bitOffset, int numBits)
        {
            // complement AFTER shifting; shifting a complemented mask would shift zeros in at the low end
            // and wrongly clear the bits following the requested range
            maskBuffer[bytePosition] &= (byte)~bitRangeMask(bitOffset, numBits);
        }

        @Override
        public void modifyByte(int bytePosition)
        {
            maskBuffer[bytePosition] = (byte)0x00;
        }

        private final byte[] maskBuffer;
    }

    private final byte[] invalidMaskBuffer;
    private final MaskBufferEraser invalidMaskBufferEraser;
    private final MaskBufferSetter invalidMaskBufferSetter;

    private final byte[] nanMaskBuffer;
    private final MaskBufferEraser nanMaskBufferEraser;
}