/*
 * Decompiled with CFR 0.152.
 */
package com.idrsolutions.image.scale;

import com.idrsolutions.image.BitReader;
import com.idrsolutions.image.scale.ChannelImage;
import com.idrsolutions.image.scale.Holder;
import com.idrsolutions.image.scale.ImagePlane;
import com.idrsolutions.image.scale.Tables;
import java.awt.Graphics;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.awt.image.DataBufferInt;
import java.awt.image.IndexColorModel;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;

/**
 * Doubles the dimensions of a {@link BufferedImage} by running its colour channels through a
 * fixed stack of seven 3x3 convolution layers whose weights and biases come from {@link Tables}
 * and from bundled {@code w2.txt} .. {@code w6.txt} resources.
 *
 * <p>Pipeline implemented by {@link #scale2x(BufferedImage)}: split the image into float channels,
 * enlarge each channel 2x by pixel replication, pad it with a replicated border, normalise to
 * {@code [0, 1]}, cut it into overlapping tiles, run every tile through the seven layers, stitch
 * the tiles back together, de-normalise and rebuild a {@code BufferedImage}.
 */
public class SuperResolution {
    /** When {@code true}, per-tile progress is printed to {@code System.out}. */
    private static final boolean debug = true;
    /** Edge length, in pixels, of a full input tile. */
    private static final int BLOCK_SIZE = 128;
    /** Pixels lost per axis by the layer stack (7 layers, each output is 2 px smaller). */
    private static final int OVERLAP = 14;
    /** Distance between the origins of two neighbouring input tiles. */
    private static final int BLOCK_STEP = BLOCK_SIZE - OVERLAP;
    /** Width of the replicated border added on every side before tiling. */
    private static final int BORDER = OVERLAP / 2;
    /** Number of convolution layers applied to every tile. */
    private static final int LAYER_COUNT = 7;

    /**
     * Converts a channel holding 0..255 samples into a plane holding the same samples divided
     * by 255.
     *
     * @param image channel to convert
     * @return a new plane with the same dimensions
     */
    private static ImagePlane normalize(ChannelImage image) {
        ImagePlane plane = new ImagePlane(image.width, image.height);
        float[] samples = image.buffer;
        for (int i = 0; i < samples.length; ++i) {
            plane.setValueIndexed(i, samples[i] / 255.0f);
        }
        return plane;
    }

    /**
     * Inverse of {@link #normalize(ChannelImage)}: multiplies every sample by 255. No clamping is
     * done here; that happens in {@link #channelCompose(ChannelImage[], int)}.
     *
     * @param imagePlane plane to convert
     * @return a new channel with the same dimensions
     */
    private static ChannelImage denormalize(ImagePlane imagePlane) {
        ChannelImage channel = new ChannelImage(imagePlane.width, imagePlane.height);
        float[] values = imagePlane.buffer;
        for (int i = 0; i < values.length; ++i) {
            channel.buffer[i] = values[i] * 255.0f;
        }
        return channel;
    }

    /**
     * Splits an image into separate float channels.
     *
     * <p>Known buffer layouts are read straight from the raster's data buffer; any other image
     * type is first painted onto a {@code TYPE_3BYTE_BGR} copy.
     *
     * <p>NOTE(review): the readers index the data buffer from 0 with a stride equal to the image
     * width, so a raster that is a view into a larger buffer (for example one produced by
     * {@code getSubimage}) would presumably be read incorrectly - confirm whether callers can
     * pass such images. The {@code *_PRE} types are read as if they were not premultiplied.
     *
     * @param image source image
     * @return {@code {R, G, B, A}}; the alpha entry is {@code null} unless the image type is one
     *         of the ARGB / ABGR types
     */
    private static ChannelImage[] channelDecompose(BufferedImage image) {
        int width = image.getWidth();
        int height = image.getHeight();
        ChannelImage imageR = new ChannelImage(width, height);
        ChannelImage imageG = new ChannelImage(width, height);
        ChannelImage imageB = new ChannelImage(width, height);
        ChannelImage imageA = null;
        switch (image.getType()) {
            case BufferedImage.TYPE_BYTE_BINARY:
                decomposeBinary(image, imageR, imageG, imageB);
                break;
            case BufferedImage.TYPE_BYTE_GRAY:
                decomposeGRAY(image, imageR, imageG, imageB);
                break;
            case BufferedImage.TYPE_BYTE_INDEXED:
                decomposeBYTE_INDEXED(image, imageR, imageG, imageB);
                break;
            case BufferedImage.TYPE_3BYTE_BGR:
                decomposeBYTE_BGR(image, imageR, imageG, imageB);
                break;
            case BufferedImage.TYPE_4BYTE_ABGR:
            case BufferedImage.TYPE_4BYTE_ABGR_PRE:
                imageA = new ChannelImage(width, height);
                decomposeBYTE_ABGR(image, imageR, imageG, imageB, imageA);
                break;
            case BufferedImage.TYPE_INT_BGR:
                decomposeINT_BGR(image, imageR, imageG, imageB);
                break;
            case BufferedImage.TYPE_INT_RGB:
                decomposeINT_RGB(image, imageR, imageG, imageB);
                break;
            case BufferedImage.TYPE_INT_ARGB:
            case BufferedImage.TYPE_INT_ARGB_PRE:
                imageA = new ChannelImage(width, height);
                decomposeINT_ARGB(image, imageR, imageG, imageB, imageA);
                break;
            default: {
                // Unknown layout: let Java2D convert it to a layout we can read.
                BufferedImage converted = new BufferedImage(width, height, BufferedImage.TYPE_3BYTE_BGR);
                Graphics g2 = converted.getGraphics();
                g2.drawImage(image, 0, 0, null);
                g2.dispose();
                decomposeBYTE_BGR(converted, imageR, imageG, imageB);
            }
        }
        return new ChannelImage[]{imageR, imageG, imageB, imageA};
    }

    /**
     * Reads a bit-packed indexed image ({@code TYPE_BYTE_BINARY}) and resolves every pixel
     * through the image's {@link IndexColorModel}.
     */
    private static void decomposeBinary(BufferedImage image, ChannelImage imageR, ChannelImage imageG, ChannelImage imageB) {
        int width = imageR.width;
        int height = imageR.height;
        byte[] data = ((DataBufferByte)image.getRaster().getDataBuffer()).getData();
        int bps = image.getColorModel().getPixelSize();
        // Number of bits used in the last, partially filled byte of a row (0 = row ends on a byte).
        int trailingBits = width * bps % 8;
        BitReader reader = new BitReader(data);
        IndexColorModel model = (IndexColorModel)image.getColorModel();
        int paletteSize = 1 << bps;
        byte[] reds = new byte[paletteSize];
        byte[] greens = new byte[paletteSize];
        byte[] blues = new byte[paletteSize];
        model.getReds(reds);
        model.getGreens(greens);
        model.getBlues(blues);
        int p = 0;
        for (int y = 0; y < height; ++y) {
            for (int x = 0; x < width; ++x) {
                int index = reader.readBits(bps);
                imageR.buffer[p] = reds[index] & 0xFF;
                imageG.buffer[p] = greens[index] & 0xFF;
                imageB.buffer[p] = blues[index] & 0xFF;
                ++p;
            }
            if (trailingBits != 0) {
                // Skip the padding bits so the next row starts on a byte boundary.
                reader.readBits(8 - trailingBits);
            }
        }
    }

    /** Reads an 8-bit grey image; the grey value is copied into all three colour channels. */
    private static void decomposeGRAY(BufferedImage image, ChannelImage imageR, ChannelImage imageG, ChannelImage imageB) {
        int dim = imageR.width * imageR.height;
        byte[] data = ((DataBufferByte)image.getRaster().getDataBuffer()).getData();
        for (int d = 0; d < dim; ++d) {
            float grey = (float)(data[d] & 0xFF);
            imageR.buffer[d] = grey;
            imageG.buffer[d] = grey;
            imageB.buffer[d] = grey;
        }
    }

    /** Reads an 8-bit indexed image, resolving every pixel through its {@link IndexColorModel}. */
    private static void decomposeBYTE_INDEXED(BufferedImage image, ChannelImage imageR, ChannelImage imageG, ChannelImage imageB) {
        int dim = imageR.width * imageR.height;
        byte[] data = ((DataBufferByte)image.getRaster().getDataBuffer()).getData();
        IndexColorModel model = (IndexColorModel)image.getColorModel();
        byte[] reds = new byte[256];
        byte[] greens = new byte[256];
        byte[] blues = new byte[256];
        model.getReds(reds);
        model.getGreens(greens);
        model.getBlues(blues);
        for (int d = 0; d < dim; ++d) {
            int index = data[d] & 0xFF;
            imageR.buffer[d] = reds[index] & 0xFF;
            imageG.buffer[d] = greens[index] & 0xFF;
            imageB.buffer[d] = blues[index] & 0xFF;
        }
    }

    /** Reads interleaved byte data laid out as B, G, R per pixel. */
    private static void decomposeBYTE_BGR(BufferedImage image, ChannelImage imageR, ChannelImage imageG, ChannelImage imageB) {
        int dim = imageR.width * imageR.height;
        byte[] data = ((DataBufferByte)image.getRaster().getDataBuffer()).getData();
        int p = 0;
        for (int d = 0; d < dim; ++d) {
            imageB.buffer[d] = data[p++] & 0xFF;
            imageG.buffer[d] = data[p++] & 0xFF;
            imageR.buffer[d] = data[p++] & 0xFF;
        }
    }

    /** Reads interleaved byte data laid out as A, B, G, R per pixel. */
    private static void decomposeBYTE_ABGR(BufferedImage image, ChannelImage imageR, ChannelImage imageG, ChannelImage imageB, ChannelImage imageA) {
        int dim = imageR.width * imageR.height;
        byte[] data = ((DataBufferByte)image.getRaster().getDataBuffer()).getData();
        int p = 0;
        for (int d = 0; d < dim; ++d) {
            imageA.buffer[d] = data[p++] & 0xFF;
            imageB.buffer[d] = data[p++] & 0xFF;
            imageG.buffer[d] = data[p++] & 0xFF;
            imageR.buffer[d] = data[p++] & 0xFF;
        }
    }

    /** Reads packed ints with red in bits 0-7, green in bits 8-15 and blue in bits 16-23. */
    private static void decomposeINT_BGR(BufferedImage image, ChannelImage imageR, ChannelImage imageG, ChannelImage imageB) {
        int dim = imageR.width * imageR.height;
        int[] data = ((DataBufferInt)image.getRaster().getDataBuffer()).getData();
        for (int d = 0; d < dim; ++d) {
            int pixel = data[d];
            imageR.buffer[d] = pixel & 0xFF;
            imageG.buffer[d] = pixel >> 8 & 0xFF;
            imageB.buffer[d] = pixel >> 16 & 0xFF;
        }
    }

    /** Reads packed ints with blue in bits 0-7, green in bits 8-15 and red in bits 16-23. */
    private static void decomposeINT_RGB(BufferedImage image, ChannelImage imageR, ChannelImage imageG, ChannelImage imageB) {
        int dim = imageR.width * imageR.height;
        int[] data = ((DataBufferInt)image.getRaster().getDataBuffer()).getData();
        for (int d = 0; d < dim; ++d) {
            int pixel = data[d];
            imageR.buffer[d] = pixel >> 16 & 0xFF;
            imageG.buffer[d] = pixel >> 8 & 0xFF;
            imageB.buffer[d] = pixel & 0xFF;
        }
    }

    /** Same layout as {@link #decomposeINT_RGB} plus alpha in bits 24-31. */
    private static void decomposeINT_ARGB(BufferedImage image, ChannelImage imageR, ChannelImage imageG, ChannelImage imageB, ChannelImage imageA) {
        int dim = imageR.width * imageR.height;
        int[] data = ((DataBufferInt)image.getRaster().getDataBuffer()).getData();
        for (int d = 0; d < dim; ++d) {
            int pixel = data[d];
            imageA.buffer[d] = pixel >> 24 & 0xFF;
            imageR.buffer[d] = pixel >> 16 & 0xFF;
            imageG.buffer[d] = pixel >> 8 & 0xFF;
            imageB.buffer[d] = pixel & 0xFF;
        }
    }

    /**
     * Clamps a sample to {@code [0, 255]} and truncates it to an unsigned byte value.
     * {@code NaN} maps to 0 because both comparisons are false and the cast of NaN yields 0.
     */
    private static byte clampToByte(float value) {
        if (value < 0.0f) {
            return 0;
        }
        if (value > 255.0f) {
            return (byte)0xFF;
        }
        return (byte)value;
    }

    /**
     * Builds an image from float channels.
     *
     * @param channels  {@code {R, G, B, A}}; {@code A} may be {@code null}
     * @param imageType type of the image the channels originally came from
     * @return a {@code TYPE_BYTE_GRAY} image (red channel only) when {@code imageType} is
     *         {@code TYPE_BYTE_GRAY}; otherwise {@code TYPE_4BYTE_ABGR} when an alpha channel is
     *         present, else {@code TYPE_3BYTE_BGR}
     */
    private static BufferedImage channelCompose(ChannelImage[] channels, int imageType) {
        ChannelImage imageR = channels[0];
        int width = imageR.width;
        int height = imageR.height;
        int dim = width * height;

        if (imageType == BufferedImage.TYPE_BYTE_GRAY) {
            BufferedImage gray = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);
            byte[] grayData = ((DataBufferByte)gray.getRaster().getDataBuffer()).getData();
            for (int d = 0; d < dim; ++d) {
                grayData[d] = clampToByte(imageR.buffer[d]);
            }
            return gray;
        }

        ChannelImage imageG = channels[1];
        ChannelImage imageB = channels[2];
        ChannelImage imageA = channels[3];
        boolean hasAlpha = imageA != null;
        int outputType = hasAlpha ? BufferedImage.TYPE_4BYTE_ABGR : BufferedImage.TYPE_3BYTE_BGR;
        BufferedImage image = new BufferedImage(width, height, outputType);
        byte[] data = ((DataBufferByte)image.getRaster().getDataBuffer()).getData();
        int p = 0;
        for (int c = 0; c < dim; ++c) {
            if (hasAlpha) {
                data[p++] = clampToByte(imageA.buffer[c]);
            }
            data[p++] = clampToByte(imageB.buffer[c]);
            data[p++] = clampToByte(imageG.buffer[c]);
            data[p++] = clampToByte(imageR.buffer[c]);
        }
        return image;
    }

    /**
     * Pads a channel with {@link #BORDER} pixels on every side by repeating the nearest edge
     * pixel, so the result is {@link #OVERLAP} pixels larger on each axis.
     *
     * @param ci channel to pad
     * @return the padded channel
     */
    private static ChannelImage extrapolation(ChannelImage ci) {
        int height = ci.height;
        int width = ci.width;
        int paddedWidth = width + OVERLAP;
        int paddedHeight = height + OVERLAP;
        ChannelImage result = new ChannelImage(paddedWidth, paddedHeight);
        for (int y = 0; y < paddedHeight; ++y) {
            // Clamping the source coordinate reproduces edge replication for border pixels.
            int srcRow = Math.min(Math.max(y - BORDER, 0), height - 1) * width;
            int dstRow = y * paddedWidth;
            for (int x = 0; x < paddedWidth; ++x) {
                int srcX = Math.min(Math.max(x - BORDER, 0), width - 1);
                result.buffer[dstRow + x] = ci.buffer[srcRow + srcX];
            }
        }
        return result;
    }

    /**
     * Doubles a channel in both directions by pixel replication (every source pixel becomes a
     * 2x2 square).
     *
     * @param ci channel to enlarge
     * @return a channel twice as wide and twice as high
     */
    private static ChannelImage resize2x(ChannelImage ci) {
        int dw = ci.width * 2;
        int dh = ci.height * 2;
        ChannelImage scaled = new ChannelImage(dw, dh);
        for (int y = 0; y < dh; ++y) {
            int srcRow = (y / 2) * ci.width;
            int dstRow = y * dw;
            for (int x = 0; x < dw; ++x) {
                scaled.buffer[dstRow + x] = ci.buffer[srcRow + x / 2];
            }
        }
        return scaled;
    }

    /**
     * Cuts the padded planes into tiles of at most {@link #BLOCK_SIZE} pixels per side whose
     * origins are {@link #BLOCK_STEP} pixels apart, so neighbouring tiles share {@link #OVERLAP}
     * pixels. Tiles in the last column / row take whatever remains of the plane.
     *
     * @param initialPlanes planes of identical size, one per colour component
     * @return the tiles in row-major order together with the number of tile columns and rows
     */
    private static Holder blocking(ImagePlane[] initialPlanes) {
        int widthInput = initialPlanes[0].width;
        int heightInput = initialPlanes[0].height;
        int blocksW = (int)Math.ceil((widthInput - OVERLAP) / (double)BLOCK_STEP);
        int blocksH = (int)Math.ceil((heightInput - OVERLAP) / (double)BLOCK_STEP);
        int blocks = blocksW * blocksH;
        int planeCount = initialPlanes.length;
        ImagePlane[][] inputBlocks = new ImagePlane[blocks][];
        for (int b = 0; b < blocks; ++b) {
            int blockIndexW = b % blocksW;
            int blockIndexH = b / blocksW;
            int originX = blockIndexW * BLOCK_STEP;
            int originY = blockIndexH * BLOCK_STEP;
            int blockWidth = blockIndexW == blocksW - 1 ? widthInput - originX : BLOCK_SIZE;
            int blockHeight = blockIndexH == blocksH - 1 ? heightInput - originY : BLOCK_SIZE;
            ImagePlane[] channels = new ImagePlane[planeCount];
            for (int n = 0; n < planeCount; ++n) {
                channels[n] = new ImagePlane(blockWidth, blockHeight);
            }
            for (int w = 0; w < blockWidth; ++w) {
                for (int h = 0; h < blockHeight; ++h) {
                    for (int n = 0; n < planeCount; ++n) {
                        float v = initialPlanes[n].getValue(originX + w, originY + h);
                        channels[n].setValue(w, h, v);
                    }
                }
            }
            inputBlocks[b] = channels;
        }
        Holder hold = new Holder();
        hold.inputBlocks = inputBlocks;
        hold.blockW = blocksW;
        hold.blockH = blocksH;
        return hold;
    }

    /**
     * Stitches processed tiles back into full planes.
     *
     * <p>The column offset of a tile is derived from the width of the first tile and the row
     * offset from its height. Using the width for both (as this method previously did) misplaces
     * every row after the first whenever there is a single, narrower-than-full tile column but
     * more than one tile row.
     *
     * @param outputBlocks processed tiles in row-major order
     * @param blocksW      number of tile columns
     * @param blocksH      number of tile rows
     * @param nComp        number of planes to produce
     * @return the reassembled planes
     */
    private static ImagePlane[] deblocking(ImagePlane[][] outputBlocks, int blocksW, int blocksH, int nComp) {
        int strideX = outputBlocks[0][0].width;
        int strideY = outputBlocks[0][0].height;
        int width = getWidth(outputBlocks, blocksW);
        int height = getHeight(outputBlocks, blocksW, blocksH);
        ImagePlane[] outputPlanes = new ImagePlane[nComp];
        for (int n = 0; n < nComp; ++n) {
            outputPlanes[n] = new ImagePlane(width, height);
        }
        for (int b = 0; b < outputBlocks.length; ++b) {
            ImagePlane[] block = outputBlocks[b];
            int originX = (b % blocksW) * strideX;
            int originY = (b / blocksW) * strideY;
            for (int n = 0; n < block.length; ++n) {
                ImagePlane channelBlock = block[n];
                int p = 0;
                for (int h = 0; h < channelBlock.height; ++h) {
                    int rowStart = (originY + h) * width + originX;
                    for (int w = 0; w < channelBlock.width; ++w) {
                        outputPlanes[n].setValueIndexed(rowStart + w, channelBlock.buffer[p++]);
                    }
                }
            }
        }
        return outputPlanes;
    }

    /** Sums the heights of the first tile of every tile row. */
    private static int getHeight(ImagePlane[][] outputBlocks, int blocksW, int blocksH) {
        int height = 0;
        int total = blocksW * blocksH;
        for (int b = 0; b < total; b += blocksW) {
            height += outputBlocks[b][0].height;
        }
        return height;
    }

    /** Sums the widths of the tiles in the first tile row. */
    private static int getWidth(ImagePlane[][] outputBlocks, int blocksW) {
        int width = 0;
        for (int b = 0; b < blocksW; ++b) {
            width += outputBlocks[b][0].width;
        }
        return width;
    }

    /**
     * Applies one 3x3 convolution layer without padding, so every output plane is 2 pixels
     * smaller than the input on each axis. Negative results are multiplied by 0.1 before being
     * stored (a leaky-ReLU-style activation).
     *
     * @param inputPlanes  input planes, all of the same size
     * @param W            kernel weights; the 9 taps for output {@code o} and input {@code i}
     *                     start at {@code (o * inputPlanes.length + i) * 9}
     * @param nOutputPlane number of planes to produce
     * @param bias         one bias per output plane
     * @param sums         scratch accumulator with at least {@code nOutputPlane} entries; its
     *                     contents are overwritten
     * @return the output planes
     */
    private static ImagePlane[] convolution(ImagePlane[] inputPlanes, float[] W, int nOutputPlane, float[] bias, float[] sums) {
        int width = inputPlanes[0].width;
        int height = inputPlanes[0].height;
        int inputCount = inputPlanes.length;
        int outWidth = width - 2;
        int outHeight = height - 2;

        ImagePlane[] outputPlanes = new ImagePlane[nOutputPlane];
        for (int o = 0; o < nOutputPlane; ++o) {
            outputPlanes[o] = new ImagePlane(outWidth, outHeight);
        }

        // Regroup the weights per input plane so the innermost loop reads them sequentially.
        float[][] kernelsByInput = new float[inputCount][];
        for (int i = 0; i < inputCount; ++i) {
            float[] kernels = new float[nOutputPlane * 9];
            for (int o = 0; o < nOutputPlane; ++o) {
                System.arraycopy(W, (o * inputCount + i) * 9, kernels, o * 9, 9);
            }
            kernelsByInput[i] = kernels;
        }

        for (int h = 0; h < outHeight; ++h) {
            int row0 = h * width;
            int row1 = row0 + width;
            int row2 = row1 + width;
            for (int w = 0; w < outWidth; ++w) {
                System.arraycopy(bias, 0, sums, 0, nOutputPlane);
                int top = row0 + w;
                int mid = row1 + w;
                int bottom = row2 + w;
                for (int i = 0; i < inputCount; ++i) {
                    float[] in = inputPlanes[i].buffer;
                    float i00 = in[top];
                    float i10 = in[top + 1];
                    float i20 = in[top + 2];
                    float i01 = in[mid];
                    float i11 = in[mid + 1];
                    float i21 = in[mid + 2];
                    float i02 = in[bottom];
                    float i12 = in[bottom + 1];
                    float i22 = in[bottom + 2];
                    float[] k = kernelsByInput[i];
                    int base = 0;
                    for (int n = 0; n < nOutputPlane; ++n) {
                        // Keep this left-to-right summation order: it fixes the float rounding.
                        float v = i00 * k[base] + i10 * k[base + 1] + i20 * k[base + 2]
                                + i01 * k[base + 3] + i11 * k[base + 4] + i21 * k[base + 5]
                                + i02 * k[base + 6] + i12 * k[base + 7] + i22 * k[base + 8];
                        sums[n] = sums[n] + v;
                        base += 9;
                    }
                }
                for (int n = 0; n < nOutputPlane; ++n) {
                    float v = sums[n];
                    if (v < 0.0f) {
                        v *= 0.1f;
                    }
                    outputPlanes[n].setValue(w, h, v);
                }
            }
        }
        return outputPlanes;
    }

    /**
     * Collects the weight arrays of all layers: layers 1 and 7 come from {@link Tables}, layers
     * 2 to 6 are parsed from bundled text resources.
     *
     * @return one weight array per layer, in layer order
     * @throws IllegalStateException if a weight resource is missing or cannot be read
     */
    private static List<float[]> typeW() {
        List<float[]> weights = new ArrayList<>(LAYER_COUNT);
        for (int layer = 1; layer <= LAYER_COUNT; ++layer) {
            float[] layerWeights;
            if (layer == 1) {
                layerWeights = Tables.WEIGHT1;
            } else if (layer == LAYER_COUNT) {
                layerWeights = Tables.WEIGHT7;
            } else {
                layerWeights = loadLayerWeights(layer);
            }
            weights.add(layerWeights);
        }
        return weights;
    }

    /**
     * Parses the comma-separated weight list of one layer.
     *
     * @throws IllegalStateException if the resource cannot be read; failing here is preferable to
     *         continuing with an empty array, which would only surface later as an
     *         {@code ArrayIndexOutOfBoundsException} inside {@link #convolution}
     */
    private static float[] loadLayerWeights(int layer) {
        String text;
        try {
            text = readWidths(String.valueOf(layer));
        } catch (IOException e) {
            throw new IllegalStateException("Unable to load convolution weights for layer " + layer, e);
        }
        String[] values = text.split(",");
        float[] parsed = new float[values.length];
        for (int j = 0; j < values.length; ++j) {
            parsed[j] = (float)Double.parseDouble(values[j]);
        }
        return parsed;
    }

    /**
     * Reads the class-path resource {@code /com/idrsolutions/image/res/w<path>.txt} fully.
     *
     * @param path suffix inserted into the resource name
     * @return the resource content decoded as UTF-8
     * @throws IOException if the resource does not exist or cannot be read
     */
    private static String readWidths(String path) throws IOException {
        String resource = "/com/idrsolutions/image/res/w" + path + ".txt";
        try (InputStream is = SuperResolution.class.getResourceAsStream(resource);
             ByteArrayOutputStream buffer = new ByteArrayOutputStream()) {
            if (is == null) {
                // getResourceAsStream signals "not found" with null rather than an exception.
                throw new IOException("Weights resource not found: " + resource);
            }
            byte[] chunk = new byte[4096];
            int nRead;
            while ((nRead = is.read(chunk, 0, chunk.length)) != -1) {
                buffer.write(chunk, 0, nRead);
            }
            // Decode with a fixed charset instead of the platform default.
            return buffer.toString("UTF-8");
        }
    }

    /**
     * Returns a copy of {@code image} with twice the width and twice the height.
     *
     * <p>The result is {@code TYPE_BYTE_GRAY} for grey input, {@code TYPE_4BYTE_ABGR} for input
     * with an alpha channel that {@link #channelDecompose} extracts, and {@code TYPE_3BYTE_BGR}
     * otherwise. The alpha channel, when present, is enlarged by a second pass through this
     * method as a grey image.
     *
     * @param image image to enlarge; must not be {@code null}
     * @return the enlarged image
     * @throws NullPointerException  if {@code image} is {@code null}
     * @throws IllegalStateException if the bundled layer weights cannot be loaded
     */
    public static BufferedImage scale2x(BufferedImage image) {
        if (image == null) {
            throw new NullPointerException("image must not be null");
        }
        List<float[]> weights = typeW();
        int iw = image.getWidth();
        int ih = image.getHeight();
        int nComp = 3;

        ChannelImage[] rgba = channelDecompose(image);
        ImagePlane[] inputPlanes = new ImagePlane[nComp];
        for (int i = 0; i < nComp; ++i) {
            rgba[i] = extrapolation(resize2x(rgba[i]));
            inputPlanes[i] = normalize(rgba[i]);
        }
        if (rgba[3] != null) {
            rgba[3] = resizeAlpha(rgba[3], iw, ih);
        }

        Holder holder = blocking(inputPlanes);
        int blockCount = holder.inputBlocks.length;
        ImagePlane[][] outputBlocks = new ImagePlane[blockCount][];

        // Scratch accumulators are shared between layers that use the same buffer size.
        float[] sum3 = new float[3];
        float[] sum32 = new float[32];
        float[] sum64 = new float[64];
        float[] sum128 = new float[128];
        float[][] sumsByLayer = {sum32, sum32, sum64, sum64, sum128, sum128, sum3};
        float[][] biasByLayer = {
            Tables.BIAS1, Tables.BIAS2, Tables.BIAS3, Tables.BIAS4,
            Tables.BIAS5, Tables.BIAS6, Tables.BIAS7
        };

        for (int b = 0; b < blockCount; ++b) {
            ImagePlane[] current = holder.inputBlocks[b];
            // Drop the holder's reference so the input tile can be collected once consumed.
            holder.inputBlocks[b] = null;
            for (int l = 0; l < LAYER_COUNT; ++l) {
                current = convolution(current, weights.get(l), Tables.NO[l], biasByLayer[l], sumsByLayer[l]);
            }
            outputBlocks[b] = current;
            if (debug) {
                System.out.println("Processed " + (int)(100.0 * (double)(b + 1) / (double)blockCount) + '%');
            }
        }

        ImagePlane[] outputPlanes = deblocking(outputBlocks, holder.blockW, holder.blockH, nComp);
        for (int i = 0; i < nComp; ++i) {
            rgba[i] = denormalize(outputPlanes[i]);
        }
        return channelCompose(rgba, image.getType());
    }

    /**
     * Enlarges an alpha channel 2x by wrapping it in a grey image and running it through
     * {@link #scale2x(BufferedImage)}.
     *
     * @param ch alpha channel of the original image
     * @param iw width of the original image
     * @param ih height of the original image
     * @return the alpha channel at {@code 2 * iw} by {@code 2 * ih}
     */
    private static ChannelImage resizeAlpha(ChannelImage ch, int iw, int ih) {
        BufferedImage alphaImage = new BufferedImage(iw, ih, BufferedImage.TYPE_BYTE_GRAY);
        byte[] source = ((DataBufferByte)alphaImage.getRaster().getDataBuffer()).getData();
        for (int i = 0; i < source.length; ++i) {
            source[i] = (byte)ch.buffer[i];
        }
        BufferedImage scaledAlpha = scale2x(alphaImage);
        byte[] scaled = ((DataBufferByte)scaledAlpha.getRaster().getDataBuffer()).getData();
        ChannelImage res = new ChannelImage(iw * 2, ih * 2);
        for (int i = 0; i < scaled.length; ++i) {
            res.buffer[i] = scaled[i] & 0xFF;
        }
        return res;
    }
}

