#include "handle_image.h"

#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include <jpeglib.h>
#include <png.h>
#include <zlib.h>

#ifdef _WIN32
#include <windows.h>

#include "specific_os.h"
#endif

#include "calculator.h"


// Loads the image at imagePath into imageData and, when requested, resizes and
// re-encodes it.
//
//  - ".jpg"/".jpeg" (case-insensitive): the file bytes are loaded as-is (type "jpeg").
//  - ".png": the file is decoded and reduced to raw RGB pixels (type "png").
//  - Other extensions: nothing is loaded.
//
// outputWidth   : target width in millimetres; <= 0 keeps the current size.
// outputQuality : 0 -> JPEG at quality 95, 100 -> zlib-compressed raw pixels
//                 (type "png"), any other value -> JPEG at that quality.
//                 Re-encoding happens only when a size/quality was requested
//                 or the pixels are raw (type "png").
// paperSize     : not used by this function.
void HandleImage::handleImage(
    const std::filesystem::path& imagePath,
    ImageData& imageData,
    const std::array<int, 2>& paperSize,
    const int& outputWidth,
    const int& outputQuality
    ) {
    // Compare extensions case-insensitively.
    std::string ext = imagePath.extension().string();
    std::ranges::transform(ext, ext.begin(), [](unsigned char ch) { return std::tolower(ch); });

    const bool isJpegFile = (ext == ".jpg") || (ext == ".jpeg");
    const bool isPngFile = (ext == ".png");

    if (isJpegFile) {
        readJPEG(imagePath, imageData);
    } else if (isPngFile) {
        readPNG(imagePath, imageData);
        readRGBFromPNG(imageData);
    }

    const bool resizeRequested = outputWidth > 0;
    const bool reencodeRequested = resizeRequested || outputQuality > 0;

    // A JPEG stays as its original file bytes unless its pixels must be touched.
    if (reencodeRequested && imageData.type == "jpeg") {
        decodeJPEG(imageData);
    }

    if (resizeRequested) {
        const int targetPixels = Calculator::convertMMToPixels(outputWidth, imageData.xDpi);
        resizeImage(imageData, targetPixels);
    }

    // Untouched JPEG bytes need no re-encoding; raw PNG pixels always do.
    if (!reencodeRequested && imageData.type != "png") {
        return;
    }

    if (outputQuality == 100) {
        compressPNG(imageData);
        imageData.type = "png";
        return;
    }

    compressJPEG(imageData, outputQuality == 0 ? 95 : outputQuality);
    imageData.type = "jpeg";
}

// Rotates the raw, interleaved, row-major pixel buffer in imageData by 90, 180
// or 270 degrees. Any other angle leaves imageData untouched.
//
// Mapping (image y axis pointing down):
//   90  : input (x, y) -> output (y, H' - 1 - x)  -- the top-right corner moves to top-left
//         (counter-clockwise).
//   270 : input (x, y) -> output (W' - 1 - y, x)  -- the top-left corner moves to top-right
//         (clockwise).
//   180 : input (x, y) -> output (W - 1 - x, H - 1 - y).
// For 90/270 the width and height are swapped.
//
// Throws std::invalid_argument if imageData.data holds fewer than
// width * height * channels bytes (e.g. it still contains compressed JPEG bytes).
// Offsets are computed in std::size_t so very large images do not overflow int.
void HandleImage::rotateImage(ImageData& imageData, const int& angle) {
    if (angle != 90 && angle != 180 && angle != 270) {
        return;
    }

    const auto inputWidth = imageData.width;
    const auto inputHeight = imageData.height;
    const auto channels = imageData.channels;

    const std::size_t w = static_cast<std::size_t>(inputWidth);
    const std::size_t h = static_cast<std::size_t>(inputHeight);
    const std::size_t pixelBytes = static_cast<std::size_t>(channels);

    if (imageData.data.size() < w * h * pixelBytes) {
        throw std::invalid_argument("rotateImage: pixel buffer is smaller than width * height * channels");
    }

    const bool swapsAxes = (angle == 90 || angle == 270);
    const std::size_t outW = swapsAxes ? h : w;
    const std::size_t outH = swapsAxes ? w : h;

    std::vector<unsigned char> outputData(outW * outH * pixelBytes);

    for (std::size_t y = 0; y < h; ++y) {
        for (std::size_t x = 0; x < w; ++x) {
            std::size_t outX = 0;
            std::size_t outY = 0;

            if (angle == 90) {
                outX = y;
                outY = outH - 1 - x;
            } else if (angle == 270) {
                outX = outW - 1 - y;
                outY = x;
            } else {
                outX = outW - 1 - x;
                outY = outH - 1 - y;
            }

            const unsigned char* src = imageData.data.data() + (y * w + x) * pixelBytes;
            unsigned char* dst = outputData.data() + (outY * outW + outX) * pixelBytes;
            std::copy_n(src, pixelBytes, dst);
        }
    }

    imageData.data = std::move(outputData);
    imageData.width = swapsAxes ? inputHeight : inputWidth;
    imageData.height = swapsAxes ? inputWidth : inputHeight;
}

// Resamples the raw, interleaved, row-major pixel buffer in imageData to
// newWidth pixels wide. The height follows the original aspect ratio (at least
// 1 row). Interpolation is bicubic (Keys kernel, a = -0.5), with
// edge pixels clamped.
//
// Works for any channel count stored in imageData.channels (unchanged on exit).
// Each output sample is rounded to the nearest integer, not truncated. The
// kernel weights sum to 1 only up to floating-point error, and truncation would
// turn e.g. 199.9999 into 199.
//
// Throws std::invalid_argument if newWidth <= 0, if the source has no pixels,
// or if imageData.data is smaller than width * height * channels.
//
// NOTE(review): the kernel always uses 4 taps regardless of scale, so large
// downscales are not low-pass filtered and may alias.
void HandleImage::resizeImage(ImageData& imageData, const int& newWidth) {
    const auto cubicWeight = [](double x) {
        constexpr double a = -0.5;
        x = std::fabs(x);

        if (x <= 1.0) {
            return (a + 2) * x * x * x - (a + 3) * x * x + 1;
        }

        if (x < 2.0) {
            return a * x * x * x - 5 * a * x * x + 8 * a * x - 4 * a;
        }

        return 0.0;
    };

    const int originalWidth = imageData.width;
    const int originalHeight = imageData.height;
    const int channels = imageData.channels;

    if (newWidth <= 0) {
        throw std::invalid_argument("resizeImage: target width must be positive");
    }
    if (originalWidth <= 0 || originalHeight <= 0 || channels <= 0) {
        throw std::invalid_argument("resizeImage: source image is empty");
    }

    const std::size_t pixelBytes = static_cast<std::size_t>(channels);
    const std::size_t sourceStride = static_cast<std::size_t>(originalWidth) * pixelBytes;
    if (imageData.data.size() < sourceStride * static_cast<std::size_t>(originalHeight)) {
        throw std::invalid_argument("resizeImage: pixel buffer is smaller than width * height * channels");
    }

    // Preserve the aspect ratio, but never collapse to zero rows (very wide images).
    const double heightToWidth = static_cast<double>(originalHeight) / static_cast<double>(originalWidth);
    const int newHeight = std::max(1, static_cast<int>(std::round(newWidth * heightToWidth)));

    const double scaleX = static_cast<double>(originalWidth) / static_cast<double>(newWidth);
    const double scaleY = static_cast<double>(originalHeight) / static_cast<double>(newHeight);

    // Four source indices and their kernel weights for one output coordinate.
    struct Taps {
        int idx[4];
        double w[4];
    };

    // Maps output coordinate `dst` to its 4 clamped source taps (pixel-centre aligned).
    const auto makeTaps = [&cubicWeight](int dst, double scale, int limit) {
        Taps taps{};
        const double centre = (dst + 0.5) * scale - 0.5;
        const int base = static_cast<int>(std::floor(centre));
        const double frac = centre - base;

        for (int k = -1; k <= 2; ++k) {
            taps.idx[k + 1] = std::clamp(base + k, 0, limit - 1);
            taps.w[k + 1] = cubicWeight(k - frac);
        }
        return taps;
    };

    // Horizontal taps are the same for every row, so compute them once.
    std::vector<Taps> xTaps(static_cast<std::size_t>(newWidth));
    for (int dx = 0; dx < newWidth; ++dx) {
        xTaps[static_cast<std::size_t>(dx)] = makeTaps(dx, scaleX, originalWidth);
    }

    const unsigned char* source = imageData.data.data();
    const std::size_t outputStride = static_cast<std::size_t>(newWidth) * pixelBytes;
    std::vector<unsigned char> outputData(outputStride * static_cast<std::size_t>(newHeight));
    std::vector<double> acc(pixelBytes);

    for (int dy = 0; dy < newHeight; ++dy) {
        const Taps yTaps = makeTaps(dy, scaleY, originalHeight);
        unsigned char* outputRow = outputData.data() + static_cast<std::size_t>(dy) * outputStride;

        for (int dx = 0; dx < newWidth; ++dx) {
            const Taps& xt = xTaps[static_cast<std::size_t>(dx)];
            std::fill(acc.begin(), acc.end(), 0.0);

            for (int m = 0; m < 4; ++m) {
                const unsigned char* sourceRow = source + static_cast<std::size_t>(yTaps.idx[m]) * sourceStride;
                for (int n = 0; n < 4; ++n) {
                    const double weight = yTaps.w[m] * xt.w[n];
                    const unsigned char* p = sourceRow + static_cast<std::size_t>(xt.idx[n]) * pixelBytes;
                    for (std::size_t c = 0; c < pixelBytes; ++c) {
                        acc[c] += weight * p[c];
                    }
                }
            }

            unsigned char* q = outputRow + static_cast<std::size_t>(dx) * pixelBytes;
            for (std::size_t c = 0; c < pixelBytes; ++c) {
                // Clamp the cubic overshoot, then round to nearest.
                q[c] = static_cast<unsigned char>(std::lround(std::clamp(acc[c], 0.0, 255.0)));
            }
        }
    }

    imageData.data = std::move(outputData);
    imageData.width = newWidth;
    imageData.height = newHeight;
}

// Loads a JPEG file into imageData without decoding its pixels.
//
// imageData.data receives the complete, still-compressed file bytes. width,
// height and channels come from the JPEG header, and type is set to "jpeg".
// xDpi/yDpi come from the JFIF density fields when they are in dots per inch
// (unit 1) or dots per centimetre (unit 2). Otherwise, or when a density is 0,
// they default to 96.
//
// Throws std::runtime_error if the file cannot be opened, sized or fully read,
// or is empty. The file handle is closed on every path.
//
// NOTE(review): libjpeg's standard error manager (jpeg_std_error) terminates the
// process on fatal decode errors. Recovering from a corrupt header would need a
// custom error_exit.
void HandleImage::readJPEG(const std::filesystem::path& imagePath, ImageData& imageData) {
#ifdef _WIN32
    auto imagePathWcharString = SpecificOS::convertUTF8ToWchar(imagePath.string());
    FILE* file = _wfopen(imagePathWcharString.c_str(), L"rb");
#else
    FILE* file = fopen(imagePath.string().c_str(), "rb");
#endif

    if (file == nullptr) {
        throw std::runtime_error("Cannot open JPEG file: " + imagePath.string());
    }

    // Close the handle before reporting a failure so it is never leaked.
    const auto fail = [&](const char* reason) {
        fclose(file);
        throw std::runtime_error(std::string(reason) + ": " + imagePath.string());
    };

    if (fseek(file, 0, SEEK_END) != 0) {
        fail("Cannot seek JPEG file");
    }
    const long fileSize = ftell(file);
    if (fileSize < 0) {
        fail("Cannot determine size of JPEG file");
    }
    if (fileSize == 0) {
        fail("JPEG file is empty");
    }
    if (fseek(file, 0, SEEK_SET) != 0) {
        fail("Cannot seek JPEG file");
    }

    imageData.data.resize(static_cast<std::size_t>(fileSize));
    if (fread(imageData.data.data(), 1, imageData.data.size(), file) != imageData.data.size()) {
        fail("Cannot read JPEG file");
    }

    // Rewind so libjpeg can parse the header from the same stream.
    if (fseek(file, 0, SEEK_SET) != 0) {
        fail("Cannot rewind JPEG file");
    }

    jpeg_decompress_struct cinfo{};
    jpeg_error_mgr jerr{};

    cinfo.err = jpeg_std_error(&jerr);
    jpeg_create_decompress(&cinfo);
    jpeg_stdio_src(&cinfo, file);
    jpeg_read_header(&cinfo, TRUE);

    imageData.width = static_cast<int>(cinfo.image_width);
    imageData.height = static_cast<int>(cinfo.image_height);
    imageData.channels = cinfo.num_components;
    imageData.type = "jpeg";

    int xDpi = 96;
    int yDpi = 96;

    // density_unit 1 = dots per inch, 2 = dots per centimetre (1 inch = 2.54 cm).
    // Any other unit keeps the 96 dpi default.
    double unitsToInch = 0.0;
    if (cinfo.density_unit == 1) {
        unitsToInch = 1.0;
    } else if (cinfo.density_unit == 2) {
        unitsToInch = 2.54;
    }

    if (unitsToInch > 0.0) {
        // A density of 0 is not a usable resolution, so keep the default for that axis.
        if (cinfo.X_density > 0) {
            xDpi = static_cast<int>(std::round(cinfo.X_density * unitsToInch));
        }
        if (cinfo.Y_density > 0) {
            yDpi = static_cast<int>(std::round(cinfo.Y_density * unitsToInch));
        }
    }

    imageData.xDpi = xDpi;
    imageData.yDpi = yDpi;

    jpeg_destroy_decompress(&cinfo);
    fclose(file);
}

// Decodes the compressed JPEG bytes held in imageData.data into raw pixels.
//
// libjpeg is asked for JCS_RGB output. On return imageData.data holds
// width * height * channels bytes, row-major and interleaved, and
// width/height/channels reflect the decoded output.
//
// The decoder reads directly from imageData.data via jpeg_mem_src. That buffer
// is therefore only replaced after jpeg_finish_decompress, which still consumes
// source bytes up to the EOI marker. Replacing it earlier (as before) freed the
// buffer while libjpeg was still reading it.
//
// Throws std::runtime_error if a scanline read makes no progress.
void HandleImage::decodeJPEG(ImageData& imageData)
{
    jpeg_decompress_struct cinfo{};
    jpeg_error_mgr jerr{};
    cinfo.err = jpeg_std_error(&jerr);
    jpeg_create_decompress(&cinfo);

    jpeg_mem_src(&cinfo, imageData.data.data(), imageData.data.size());
    jpeg_read_header(&cinfo, TRUE);

    cinfo.out_color_space = JCS_RGB;

    jpeg_start_decompress(&cinfo);

    const auto width = cinfo.output_width;
    const auto height = cinfo.output_height;
    const auto channels = cinfo.output_components;
    const std::size_t rowStride = static_cast<std::size_t>(width) * static_cast<std::size_t>(channels);

    std::vector<unsigned char> pixels(rowStride * static_cast<std::size_t>(height));

    while (cinfo.output_scanline < height) {
        unsigned char* rowPtr = pixels.data() + static_cast<std::size_t>(cinfo.output_scanline) * rowStride;
        // Guard against an endless loop if the decoder stops producing rows.
        if (jpeg_read_scanlines(&cinfo, &rowPtr, 1) != 1) {
            jpeg_destroy_decompress(&cinfo);
            throw std::runtime_error("JPEG decoding stalled before the last scanline");
        }
    }

    // Finish while the compressed source buffer is still alive in imageData.data.
    jpeg_finish_decompress(&cinfo);
    jpeg_destroy_decompress(&cinfo);

    imageData.width = static_cast<int>(width);
    imageData.height = static_cast<int>(height);
    imageData.channels = channels;
    imageData.data = std::move(pixels);
}

// Encodes the raw pixel buffer in imageData.data as an in-memory JPEG and
// replaces imageData.data with the encoded bytes.
//
// The input is expected to be width * height * channels bytes, row-major and
// interleaved. 3 channels are encoded as RGB and 1 channel as grayscale. Any
// other count leaves in_color_space at its zeroed default
// (NOTE(review): likely unintended).
// imageQuality is passed to jpeg_set_quality with force_baseline = true.
// The JFIF header records xDpi/yDpi as dots per inch.
void HandleImage::compressJPEG(ImageData& imageData, const int& imageQuality) {
    jpeg_error_mgr errorManager{};
    jpeg_compress_struct encoder{};
    encoder.err = jpeg_std_error(&errorManager);
    jpeg_create_compress(&encoder);

    // libjpeg allocates the destination buffer; it is released with free() below.
    unsigned char* encoded = nullptr;
    unsigned long encodedSize = 0;
    jpeg_mem_dest(&encoder, &encoded, &encodedSize);

    encoder.image_width = imageData.width;
    encoder.image_height = imageData.height;
    encoder.input_components = imageData.channels;

    switch (imageData.channels) {
        case 3:
            encoder.in_color_space = JCS_RGB;
            break;
        case 1:
            encoder.in_color_space = JCS_GRAYSCALE;
            break;
        default:
            break;
    }

    // Defaults must be applied after in_color_space is known; density is set afterwards
    // because jpeg_set_defaults would overwrite it.
    jpeg_set_defaults(&encoder);
    jpeg_set_quality(&encoder, imageQuality, true);

    encoder.density_unit = 1;
    encoder.X_density = imageData.xDpi;
    encoder.Y_density = imageData.yDpi;

    jpeg_start_compress(&encoder, true);

    const int stride = imageData.width * imageData.channels;
    unsigned char* pixels = imageData.data.data();

    while (encoder.next_scanline < imageData.height) {
        JSAMPROW row = pixels + encoder.next_scanline * stride;
        jpeg_write_scanlines(&encoder, &row, 1);
    }

    jpeg_finish_compress(&encoder);

    imageData.data.assign(encoded, encoded + encodedSize);

    jpeg_destroy_compress(&encoder);
    free(encoded);
}

void HandleImage::readPNG(const std::filesystem::path& imagePath, ImageData& imageData) {
#ifdef _WIN32
    auto imagePathWcharString = SpecificOS::convertUTF8ToWchar(imagePath.string());
    FILE* file = _wfopen(imagePathWcharString.c_str(), L"rb");
#else
    FILE* file = fopen(imagePath.string().c_str(), "rb");
#endif

    png_structp png = png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
    png_infop info = png_create_info_struct(png);

    png_init_io(png, file);
    png_read_info(png, info);

    png_uint_32 pixelsPerMeterX;
    png_uint_32 pixelsPerMeterY;
    int densityUnit;
    int xDpi = 96;
    int yDpi = 96;

    if (png_get_pHYs(png, info, &pixelsPerMeterX, &pixelsPerMeterY, &densityUnit)) {
        if (densityUnit == PNG_RESOLUTION_METER) {
            /* The density value is pixels per meter. 1 inch = 0.0254 meter. */
            xDpi = static_cast<int>(std::round(pixelsPerMeterX * 0.0254));
            yDpi = static_cast<int>(std::round(pixelsPerMeterY * 0.0254));
        }
    }

    imageData.xDpi = xDpi;
    imageData.yDpi = yDpi;

    imageData.width = static_cast<int>(png_get_image_width(png, info));
    imageData.height = static_cast<int>(png_get_image_height(png, info));
    png_byte color_type = png_get_color_type(png, info);
    png_byte bit_depth = png_get_bit_depth(png, info);

    if (bit_depth == 16)
        png_set_strip_16(png);

    if (color_type == PNG_COLOR_TYPE_PALETTE)
        png_set_palette_to_rgb(png);

    if (color_type == PNG_COLOR_TYPE_GRAY && bit_depth < 8)
        png_set_expand_gray_1_2_4_to_8(png);

    if (png_get_valid(png, info, PNG_INFO_tRNS))
        png_set_tRNS_to_alpha(png);

    if (color_type == PNG_COLOR_TYPE_RGB ||
        color_type == PNG_COLOR_TYPE_GRAY ||
        color_type == PNG_COLOR_TYPE_PALETTE)
        png_set_filler(png, 0xFF, PNG_FILLER_AFTER);

    if (color_type == PNG_COLOR_TYPE_GRAY ||
        color_type == PNG_COLOR_TYPE_GRAY_ALPHA)
        png_set_gray_to_rgb(png);

    png_read_update_info(png, info);

    imageData.channels = 4;
    imageData.data.resize(imageData.width * imageData.height * 4);

    std::vector<png_bytep> row_pointers(imageData.height);
    for (int y = 0; y < imageData.height; ++y) {
        row_pointers[y] = imageData.data.data() + y * imageData.width * 4;
    }

    png_read_image(png, row_pointers.data());
    imageData.type = "png";

    png_destroy_read_struct(&png, &info, nullptr);
    fclose(file);
}

// Converts the interleaved RGBA buffer in imageData.data (as produced by
// readPNG) into packed RGB by dropping the fourth byte of every pixel, then
// sets channels to 3. Width, height and type are left unchanged.
void HandleImage::readRGBFromPNG(ImageData& imageData) {
    const auto& rgba = imageData.data;

    std::vector<unsigned char> rgb;
    rgb.reserve(imageData.width * imageData.height * 3);

    for (std::size_t offset = 0; offset < rgba.size(); offset += 4) {
        const unsigned char* pixel = rgba.data() + offset;
        rgb.insert(rgb.end(), pixel, pixel + 3);
    }

    imageData.data = std::move(rgb);
    imageData.channels = 3;
}

// Deflates imageData.data with zlib's compress2 at level 9 (Z_BEST_COMPRESSION)
// and replaces the buffer with the compressed stream.
//
// NOTE(review): despite the name, the result is a bare zlib stream of the pixel
// bytes, not a PNG file (no signature or chunks). Presumably the consumer embeds
// it as Flate-compressed data; confirm against the caller.
//
// Throws std::runtime_error("Compression failed") if compress2 does not return Z_OK.
void HandleImage::compressPNG(ImageData& imageData) {
    constexpr int kCompressionLevel = 9;

    const auto& raw = imageData.data;
    uLongf packedSize = compressBound(raw.size());
    std::vector<unsigned char> packed(packedSize);

    const int status = compress2(packed.data(), &packedSize, raw.data(), raw.size(), kCompressionLevel);
    if (status != Z_OK) {
        throw std::runtime_error("Compression failed");
    }

    // compress2 wrote the actual compressed length back into packedSize.
    packed.resize(packedSize);
    imageData.data = std::move(packed);
}
