#include <opencv2/opencv.hpp>
#include <opencv2/core.hpp>
#include <opencv2/features2d.hpp>  // For SIFT, FAST, ORB
#include <opencv2/xfeatures2d.hpp> // For SURF, STAR, MSD, HLF, GFFT, BRIEF?
#include <opencv2/cudafeatures2d.hpp> // For CUDA accelerated: FAST, ORB

#include <iostream>
#include <functional>
#include <memory>    

#include "detector_wrappers.hpp"

// Variant able to hold the configuration struct of any one supported detector.
// The factory and the detector constructors below read the active alternative with std::get.
using detector_config = std::variant<SIFT_CPU_config, SURF_CPU_config, ORB_CPU_config, AKAZE_CPU_config, ORB_GPU_config>;

// General detector wrapper interface class
// Factory: constructs the detector wrapper selected by `type`.
//
// type   - enumerator naming the detector implementation to build.
// config - variant that must currently hold the configuration struct matching
//          `type`; std::get throws std::bad_variant_access otherwise.
//
// Returns an owning pointer to the new detector, or nullptr when `type` is not
// one of the enumerators handled below, so callers have to check the result.
std::unique_ptr<i_detector> i_detector::create(detector_type type, detector_config &&config) {
    switch (type) {
        // CPU based feature detectors:
        case SIFT_CPU:
            return std::make_unique<sift_cpu_detector>(std::get<SIFT_CPU_config>(config));
        case SURF_CPU:
            return std::make_unique<surf_cpu_detector>(std::get<SURF_CPU_config>(config));
        case ORB_CPU:
            return std::make_unique<orb_cpu_detector>(std::get<ORB_CPU_config>(config));
        case AKAZE_CPU:
            return std::make_unique<akaze_cpu_detector>(std::get<AKAZE_CPU_config>(config));
        // CUDA accelerated detectors:
        case ORB_GPU:
            return std::make_unique<orb_cuda_detector>(std::get<ORB_GPU_config>(config));
        // TODO:
        // POPSIFT?
        // SuperGlue/SuperPoint?
        default:
            break;
    }

    // Single fallback for every detector type that has no wrapper.
    return nullptr;
}

// Factory: constructs the detector wrapper named by the "type" entry of
// `json_config` and hands the whole JSON object to that wrapper's constructor.
//
// Recognised type strings: "SURF_CPU", "SIFT_CPU", "ORB_CPU", "AKAZE_CPU", "ORB_GPU".
//
// Throws std::invalid_argument when the type string is not one of the above.
// A missing or non-string "type" entry surfaces as the JSON library's own
// conversion exception.
std::unique_ptr<i_detector> i_detector::create(json json_config) {
    const std::string det_type = json_config["type"].get<std::string>();

    // Direct dispatch: no per-call hash map of std::function objects is needed
    // to choose between five constructors.
    if (det_type == "SURF_CPU") {
        return std::make_unique<surf_cpu_detector>(json_config);
    }
    if (det_type == "SIFT_CPU") {
        return std::make_unique<sift_cpu_detector>(json_config);
    }
    if (det_type == "ORB_CPU") {
        return std::make_unique<orb_cpu_detector>(json_config);
    }
    if (det_type == "AKAZE_CPU") {
        return std::make_unique<akaze_cpu_detector>(json_config);
    }
    if (det_type == "ORB_GPU") {
        return std::make_unique<orb_cuda_detector>(json_config);
    }

    throw std::invalid_argument("Invalid feature detector type: " + det_type);
}


void i_detector::show_features(){
    cv::Mat image_keypoints;
    cv::drawKeypoints(i_detector::image1, i_detector::keypoints1, image_keypoints, cv::Scalar(0, 255, 0));
    cv::imshow("Detected features", image_keypoints);
    cv::waitKey(0);
}

// Reads `parameter` from `config` converted to T.
// When the key is absent or holds JSON null, a notice is printed to stdout
// and `default_value` is returned instead.
template<typename T>
T i_detector::get_parameter_with_default(const json &config, const std::string& parameter, const T& default_value) {
    const bool has_value = config.contains(parameter) && !config.at(parameter).is_null();

    if (!has_value) {
        std::cout << "Using default value (" << default_value << ") for parameter \"" << parameter << "\"" << std::endl;
        return default_value;
    }

    return config.at(parameter).get<T>();
}

template<typename T>
T i_detector::get_parameter_with_default(const json &config, const std::unordered_map<std::string, T> &string_enum_map, const std::string &parameter, const T& default_value) {
    if (config.contains(parameter) && !config.at(parameter).is_null()) {
        auto map_el = string_enum_map.find(config.at(parameter));
        if (map_el == string_enum_map.end()) {
            std::cout << "Using default value (" << default_value << ") for parameter \"" << parameter << "\"" << std::endl;
            return default_value;
        } else {
            return map_el->second;
        }
        
    } else {
        std::cout << "Using default value (" << default_value << ") for parameter \"" << parameter << "\"" << std::endl;
        return default_value;
    }
}

// ORB CPU detector
// Builds the OpenCV ORB detector from the ORB_CPU_config alternative of
// `config` and pairs it with a brute-force Hamming matcher.
// Throws std::bad_variant_access (from std::get) if `config` holds a different
// alternative.
orb_cpu_detector::orb_cpu_detector(detector_config &&config) {
    type = ORB_CPU;

    const auto &cfg = std::get<ORB_CPU_config>(config);
    detector = cv::ORB::create(
        cfg.nfeatures,
        cfg.scaleFactor,
        cfg.nlevels,
        cfg.edgeThreshold,
        cfg.firstLevel,
        cfg.WTA_K,
        cfg.scoreType,
        cfg.patchSize,
        cfg.fastThreshold
    );

    // Cross-checking has to stay off: match_features() asks knnMatch for the
    // two nearest neighbours, and BFMatcher only supports k == 1 when
    // crossCheck is enabled.
    matcher = cv::makePtr<cv::BFMatcher>(cv::NORM_HAMMING, false);
    /*
    // TODO: Look into using FLANN-based matchers for binary descriptors (as an alternative for BF-based matchers)
    cv::Ptr<cv::flann::IndexParams> indexParams = cv::makePtr<cv::flann::LshIndexParams>(12, 20, 2);
    cv::Ptr<cv::flann::SearchParams> searchParams = cv::makePtr<cv::flann::SearchParams>();
    matcher = cv::makePtr<cv::FlannBasedMatcher>(indexParams, searchParams); 
    */
}

// Builds the ORB detector and its brute-force matcher from a JSON object.
// Every key is optional; get_parameter_with_default supplies the value shown
// below (and prints a notice) when a key is missing or null.
//
// Detector keys: n_features, scale_factor, n_levels, edge_threshold,
//                first_level, WTA_K, score_type ("HARRIS" | "FAST"),
//                patch_size, fast_threshold.
// Matcher keys:  norm_type (one of the NORM_* names below), cross_check.
//
// NOTE(review): setting cross_check to true conflicts with match_features(),
// which requests two neighbours per query from knnMatch.
orb_cpu_detector::orb_cpu_detector(json config) {
    type = ORB_CPU;

    static const std::unordered_map<std::string, cv::ORB::ScoreType> score_type_map = {
        {"HARRIS", cv::ORB::HARRIS_SCORE},
        {"FAST",   cv::ORB::FAST_SCORE}
    };

    // Read once up front: WTA_K also decides the default matcher norm below.
    const int wta_k = get_parameter_with_default(config, "WTA_K", 2);

    detector = cv::ORB::create(
        get_parameter_with_default(config, "n_features", 1500),
        get_parameter_with_default(config, "scale_factor", 1.2f),
        get_parameter_with_default(config, "n_levels", 8),
        get_parameter_with_default(config, "edge_threshold", 31),
        get_parameter_with_default(config, "first_level", 0),
        wta_k,
        get_parameter_with_default(config, score_type_map, "score_type", cv::ORB::HARRIS_SCORE),
        get_parameter_with_default(config, "patch_size", 31),
        get_parameter_with_default(config, "fast_threshold", 15)
    );

    static const std::unordered_map<std::string, cv::NormTypes> norm_map = {
        {"NORM_INF",      cv::NORM_INF},
        {"NORM_L1",       cv::NORM_L1},
        {"NORM_L2",       cv::NORM_L2},
        {"NORM_L2SQR",    cv::NORM_L2SQR},
        {"NORM_HAMMING",  cv::NORM_HAMMING},
        {"NORM_HAMMING2", cv::NORM_HAMMING2}
    };

    // ORB descriptors are binary, so the default norm is Hamming (matching the
    // variant-based constructor). OpenCV recommends NORM_HAMMING2 when
    // WTA_K is 3 or 4. An explicit "norm_type" entry still wins.
    const cv::NormTypes default_norm = (wta_k > 2) ? cv::NORM_HAMMING2 : cv::NORM_HAMMING;

    matcher = cv::BFMatcher::create(
        get_parameter_with_default(config, norm_map, "norm_type", default_norm),
        get_parameter_with_default(config, "cross_check", false)
    );
}

// Runs ORB on image1 and stores the result in keypoints1 / descriptors1.
// Logs an error and leaves both untouched when image1 is empty.
void orb_cpu_detector::detect_features_img1() {
    if (!image1.empty()) {
        // cv::noArray(): no mask is supplied.
        detector->detectAndCompute(image1, cv::noArray(), keypoints1, descriptors1);
        return;
    }

    std::cerr << "Error: No image" << std::endl;
}

// Runs ORB on image2 and stores the result in keypoints2 / descriptors2.
// Logs an error and leaves both untouched when image2 is empty.
void orb_cpu_detector::detect_features_img2() {
    if (!image2.empty()) {
        // cv::noArray(): no mask is supplied.
        detector->detectAndCompute(image2, cv::noArray(), keypoints2, descriptors2);
        return;
    }

    std::cerr << "Error: No image" << std::endl;
}

void orb_cpu_detector::match_features(){
    std::vector<std::vector<cv::DMatch>> flann_matches;
    matcher->knnMatch(descriptors1, descriptors2, flann_matches, 2);

    

    for (const auto& match_pair : flann_matches) {
        const cv::DMatch& m1 = match_pair[0];
        const cv::DMatch& m2 = match_pair[1];
        if (m1.distance < lowes_ratio_thresh  * m2.distance) {
            matches.push_back(m1);
        }
    }

    raw_matches_count = flann_matches.size();
    matches_count = matches.size();
}



// SIFT CPU detector
// Builds the OpenCV SIFT detector from the SIFT_CPU_config alternative of
// `config` and pairs it with a default-constructed FLANN matcher.
// Throws std::bad_variant_access (from std::get) if `config` holds a different
// alternative.
sift_cpu_detector::sift_cpu_detector(detector_config &&config) {
    type = SIFT_CPU;

    const auto &cfg = std::get<SIFT_CPU_config>(config);
    detector = cv::SIFT::create(cfg.nfeatures,
                                cfg.nOctaveLayers,
                                cfg.contrastThreshold,
                                cfg.edgeThreshold,
                                cfg.sigma);

    matcher = cv::FlannBasedMatcher::create();
}

// Builds the SIFT detector and a default FLANN matcher from a JSON object.
// Every key is optional; get_parameter_with_default supplies the value shown
// below (and prints a notice) when a key is missing or null.
//
// Keys: n_features, n_octaveLayers, contrast_threshold, edge_threshold,
//       sigma, enable_precise_upscale.
//
// NOTE(review): the key is spelled "n_octaveLayers" here, while the SURF and
// AKAZE constructors in this file use "n_octave_layers" - confirm which
// spelling the configuration files are meant to use.
sift_cpu_detector::sift_cpu_detector(json config) {
    type = SIFT_CPU;

    detector = cv::SIFT::create(
        get_parameter_with_default(config, "n_features", 0),
        get_parameter_with_default(config, "n_octaveLayers", 3),
        get_parameter_with_default(config, "contrast_threshold", 0.04),
        // The default's type selects the type the JSON value is converted to.
        // SIFT's edge threshold is a double, so the default must be one too;
        // an int default would truncate values such as 12.5.
        get_parameter_with_default(config, "edge_threshold", 10.0),
        get_parameter_with_default(config, "sigma", 1.6),
        get_parameter_with_default(config, "enable_precise_upscale", false)
    );

    matcher = cv::FlannBasedMatcher::create();
}

// Runs SIFT on image1 and stores the result in keypoints1 / descriptors1.
// Logs an error and leaves both untouched when image1 is empty.
void sift_cpu_detector::detect_features_img1() {
    const bool have_image = !image1.empty();
    if (have_image) {
        // cv::noArray(): no mask is supplied.
        detector->detectAndCompute(image1, cv::noArray(), keypoints1, descriptors1);
    } else {
        std::cerr << "Error: No image" << std::endl;
    }
}

// Runs SIFT on image2 and stores the result in keypoints2 / descriptors2.
// Logs an error and leaves both untouched when image2 is empty.
void sift_cpu_detector::detect_features_img2() {
    const bool have_image = !image2.empty();
    if (have_image) {
        // cv::noArray(): no mask is supplied.
        detector->detectAndCompute(image2, cv::noArray(), keypoints2, descriptors2);
    } else {
        std::cerr << "Error: No image" << std::endl;
    }
}

void sift_cpu_detector::match_features(){
    std::vector<std::vector<cv::DMatch>> flann_matches;
    matcher->knnMatch(descriptors1, descriptors2, flann_matches, 2);
    for (const auto& match_pair : flann_matches) {
        const cv::DMatch& m1 = match_pair[0];
        const cv::DMatch& m2 = match_pair[1];
        if (m1.distance < lowes_ratio_thresh  * m2.distance) {
            matches.push_back(m1);
        }
    }

    raw_matches_count = flann_matches.size();
    matches_count = matches.size();
}



// SURF CPU detector
// Builds the OpenCV (xfeatures2d) SURF detector from the SURF_CPU_config
// alternative of `config` and pairs it with a default-constructed FLANN matcher.
// Throws std::bad_variant_access (from std::get) if `config` holds a different
// alternative.
surf_cpu_detector::surf_cpu_detector(detector_config &&config) {
    type = SURF_CPU;

    const auto &cfg = std::get<SURF_CPU_config>(config);
    detector = cv::xfeatures2d::SURF::create(cfg.hessianThreshold,
                                             cfg.nOctaves,
                                             cfg.nOctaveLayers,
                                             cfg.extended,
                                             cfg.upright);

    matcher = cv::FlannBasedMatcher::create();
}

// Builds the SURF detector and a default FLANN matcher from a JSON object.
// Every key is optional; get_parameter_with_default supplies the value shown
// below (and prints a notice) when a key is missing or null.
//
// Keys: hessian_threshold, n_octaves, n_octave_layers, extended, upright.
surf_cpu_detector::surf_cpu_detector(json config) {
    type = SURF_CPU;

    detector = cv::xfeatures2d::SURF::create(
        // The default's type selects the type the JSON value is converted to.
        // SURF's Hessian threshold is a double, so the default must be one
        // too; an int default would truncate values such as 400.5.
        get_parameter_with_default(config, "hessian_threshold", 100.0),
        get_parameter_with_default(config, "n_octaves", 4),
        get_parameter_with_default(config, "n_octave_layers", 3),
        get_parameter_with_default(config, "extended", false),
        get_parameter_with_default(config, "upright", false)
    );

    matcher = cv::FlannBasedMatcher::create();
}

// Runs SURF on image1 and stores the result in keypoints1 / descriptors1.
// Logs an error and leaves both untouched when image1 is empty.
void surf_cpu_detector::detect_features_img1() {
    if (!image1.empty()) {
        // cv::noArray(): no mask is supplied.
        detector->detectAndCompute(image1, cv::noArray(), keypoints1, descriptors1);
        return;
    }

    std::cerr << "Error: No image" << std::endl;
}

// Runs SURF on image2 and stores the result in keypoints2 / descriptors2.
// Logs an error and leaves both untouched when image2 is empty.
void surf_cpu_detector::detect_features_img2() {
    if (!image2.empty()) {
        // cv::noArray(): no mask is supplied.
        detector->detectAndCompute(image2, cv::noArray(), keypoints2, descriptors2);
        return;
    }

    std::cerr << "Error: No image" << std::endl;
}

void surf_cpu_detector::match_features(){
    std::vector<std::vector<cv::DMatch>> flann_matches;
    matcher->knnMatch(descriptors1, descriptors2, flann_matches, 2);

    for (const auto& match_pair : flann_matches) {
        const cv::DMatch& m1 = match_pair[0];
        const cv::DMatch& m2 = match_pair[1];
        if (m1.distance < lowes_ratio_thresh  * m2.distance) {
            matches.push_back(m1);
        }
    }

    raw_matches_count = flann_matches.size();
    matches_count = matches.size();
}



// AKAZE CPU detector
// Builds the OpenCV AKAZE detector from the AKAZE_CPU_config alternative of
// `config` and pairs it with a brute-force Hamming matcher (cross-check off).
// Throws std::bad_variant_access (from std::get) if `config` holds a different
// alternative.
akaze_cpu_detector::akaze_cpu_detector(detector_config &&config) {
    type = AKAZE_CPU;

    const auto &cfg = std::get<AKAZE_CPU_config>(config);
    detector = cv::AKAZE::create(cfg.descriptor_type,
                                 cfg.descriptor_size,
                                 cfg.descriptor_channels,
                                 cfg.threshold,
                                 cfg.nOctaves,
                                 cfg.nOctaveLayers,
                                 cfg.diffusivity,
                                 cfg.max_points);

    matcher = cv::makePtr<cv::BFMatcher>(cv::NORM_HAMMING, false);
}

// Builds the AKAZE detector and its brute-force matcher from a JSON object.
// Every key is optional; get_parameter_with_default supplies the value shown
// below (and prints a notice) when a key is missing or null.
//
// Detector keys: descriptor_type (DESCRIPTOR_* name), descriptor_size,
//                descriptor_channels, threshold, n_octaves, n_octave_layers,
//                diffusivity_type (DIFF_* name), max_points.
// Matcher keys:  norm_type (NORM_* name), cross_check.
//
// NOTE(review): setting cross_check to true conflicts with match_features(),
// which requests two neighbours per query from knnMatch.
akaze_cpu_detector::akaze_cpu_detector(json config) {
    type = AKAZE_CPU;

    static const std::unordered_map<std::string, cv::AKAZE::DescriptorType> desc_map = {
        {"DESCRIPTOR_KAZE_UPRIGHT", cv::AKAZE::DESCRIPTOR_KAZE_UPRIGHT},
        {"DESCRIPTOR_KAZE",         cv::AKAZE::DESCRIPTOR_KAZE},
        {"DESCRIPTOR_MLDB_UPRIGHT", cv::AKAZE::DESCRIPTOR_MLDB_UPRIGHT},
        {"DESCRIPTOR_MLDB",         cv::AKAZE::DESCRIPTOR_MLDB}
    };

    static const std::unordered_map<std::string, cv::KAZE::DiffusivityType> diff_map = {
        {"DIFF_PM_G1",       cv::KAZE::DIFF_PM_G1},
        {"DIFF_PM_G2",       cv::KAZE::DIFF_PM_G2},
        {"DIFF_WEICKERT",    cv::KAZE::DIFF_WEICKERT},
        {"DIFF_CHARBONNIER", cv::KAZE::DIFF_CHARBONNIER}
    };

    // Read once up front: the descriptor type also decides the default
    // matcher norm below.
    const cv::AKAZE::DescriptorType descriptor_type =
        get_parameter_with_default(config, desc_map, "descriptor_type", cv::AKAZE::DESCRIPTOR_MLDB);

    detector = cv::AKAZE::create(
        descriptor_type,
        get_parameter_with_default(config, "descriptor_size", 0),
        get_parameter_with_default(config, "descriptor_channels", 3),
        get_parameter_with_default(config, "threshold", 0.001f),
        get_parameter_with_default(config, "n_octaves", 4),
        get_parameter_with_default(config, "n_octave_layers", 4),
        get_parameter_with_default(config, diff_map, "diffusivity_type", cv::KAZE::DIFF_PM_G2),
        get_parameter_with_default(config, "max_points", -1)
    );

    static const std::unordered_map<std::string, cv::NormTypes> norm_map = {
        {"NORM_INF",      cv::NORM_INF},
        {"NORM_L1",       cv::NORM_L1},
        {"NORM_L2",       cv::NORM_L2},
        {"NORM_L2SQR",    cv::NORM_L2SQR},
        {"NORM_HAMMING",  cv::NORM_HAMMING},
        {"NORM_HAMMING2", cv::NORM_HAMMING2}
    };

    // MLDB descriptors are binary and are compared with the Hamming norm (as
    // the variant-based constructor does); the KAZE descriptor variants are
    // floating point and keep L2. An explicit "norm_type" entry still wins.
    const bool binary_descriptor = descriptor_type == cv::AKAZE::DESCRIPTOR_MLDB ||
                                   descriptor_type == cv::AKAZE::DESCRIPTOR_MLDB_UPRIGHT;
    const cv::NormTypes default_norm = binary_descriptor ? cv::NORM_HAMMING : cv::NORM_L2;

    matcher = cv::BFMatcher::create(
        get_parameter_with_default(config, norm_map, "norm_type", default_norm),
        get_parameter_with_default(config, "cross_check", false)
    );
}

// Runs AKAZE on image1 and stores the result in keypoints1 / descriptors1.
// Logs an error and leaves both untouched when image1 is empty.
void akaze_cpu_detector::detect_features_img1() {
    const bool have_image = !image1.empty();
    if (have_image) {
        // cv::noArray(): no mask is supplied.
        detector->detectAndCompute(image1, cv::noArray(), keypoints1, descriptors1);
    } else {
        std::cerr << "Error: No image" << std::endl;
    }
}

// Runs AKAZE on image2 and stores the result in keypoints2 / descriptors2.
// Logs an error and leaves both untouched when image2 is empty.
void akaze_cpu_detector::detect_features_img2() {
    const bool have_image = !image2.empty();
    if (have_image) {
        // cv::noArray(): no mask is supplied.
        detector->detectAndCompute(image2, cv::noArray(), keypoints2, descriptors2);
    } else {
        std::cerr << "Error: No image" << std::endl;
    }
}

void akaze_cpu_detector::match_features(){
    std::vector<std::vector<cv::DMatch>> flann_matches;
    matcher->knnMatch(descriptors1, descriptors2, flann_matches, 2);

    for (const auto& match_pair : flann_matches) {
        const cv::DMatch& m1 = match_pair[0];
        const cv::DMatch& m2 = match_pair[1];
        if (m1.distance < lowes_ratio_thresh  * m2.distance) {
            matches.push_back(m1);
        }
    }

    raw_matches_count = flann_matches.size();
    matches_count = matches.size();
}


// ORB CUDA accelerated detector
// Builds the CUDA ORB detector from the ORB_GPU_config alternative of
// `config` and pairs it with a CUDA brute-force Hamming matcher.
// Throws std::bad_variant_access (from std::get) if `config` holds a different
// alternative.
orb_cuda_detector::orb_cuda_detector(detector_config &&config){
    // This wrapper is the GPU implementation; tagging it ORB_CPU would make it
    // indistinguishable from orb_cpu_detector.
    type = ORB_GPU;

    const auto &cfg = std::get<ORB_GPU_config>(config);
    detector = cv::cuda::ORB::create(
        cfg.nfeatures,
        cfg.scaleFactor,
        cfg.nlevels,
        cfg.edgeThreshold,
        cfg.firstLevel,
        cfg.WTA_K,
        cfg.scoreType,
        cfg.patchSize,
        cfg.fastThreshold,
        cfg.blurForDescriptor
    );

    matcher = cv::cuda::DescriptorMatcher::createBFMatcher(cv::NORM_HAMMING);
}

// Builds the CUDA ORB detector and its CUDA brute-force matcher from a JSON
// object. Every key is optional; get_parameter_with_default supplies the
// value shown below (and prints a notice) when a key is missing or null.
//
// Detector keys: n_features, scale_factor, n_levels, edge_threshold,
//                first_level, WTA_K, score_type ("HARRIS" | "FAST"),
//                patch_size, fast_threshold, blur_for_descriptor.
// Matcher keys:  norm_type (one of the NORM_* names below).
orb_cuda_detector::orb_cuda_detector(json config){
    // This wrapper is the GPU implementation; tagging it ORB_CPU would make it
    // indistinguishable from orb_cpu_detector.
    type = ORB_GPU;

    static const std::unordered_map<std::string, cv::ORB::ScoreType> score_type_map = {
        {"HARRIS", cv::ORB::HARRIS_SCORE},
        {"FAST",   cv::ORB::FAST_SCORE}
    };

    detector = cv::cuda::ORB::create(
        get_parameter_with_default(config, "n_features", 1500),
        get_parameter_with_default(config, "scale_factor", 1.2f),
        get_parameter_with_default(config, "n_levels", 8),
        get_parameter_with_default(config, "edge_threshold", 31),
        get_parameter_with_default(config, "first_level", 0),
        get_parameter_with_default(config, "WTA_K", 2),
        get_parameter_with_default(config, score_type_map, "score_type", cv::ORB::HARRIS_SCORE),
        get_parameter_with_default(config, "patch_size", 31),
        get_parameter_with_default(config, "fast_threshold", 15),
        get_parameter_with_default(config, "blur_for_descriptor", false)
    );

    static const std::unordered_map<std::string, cv::NormTypes> norm_map = {
        {"NORM_INF",      cv::NORM_INF},
        {"NORM_L1",       cv::NORM_L1},
        {"NORM_L2",       cv::NORM_L2},
        {"NORM_L2SQR",    cv::NORM_L2SQR},
        {"NORM_HAMMING",  cv::NORM_HAMMING},
        {"NORM_HAMMING2", cv::NORM_HAMMING2}
    };

    matcher = cv::cuda::DescriptorMatcher::createBFMatcher(
        get_parameter_with_default(config, norm_map, "norm_type", cv::NORM_HAMMING)
    );
}

// Uploads image1 into d_image1, runs CUDA ORB on it and copies the results
// back: keypoints into keypoints1, descriptors into descriptors1 (the device
// copies stay in d_keypoints1 / d_descriptors1).
// Logs an error and returns early when image1 is empty or when d_image1 is
// still empty after the upload.
void orb_cuda_detector::detect_features_img1() {
    const bool have_host_image = !image1.empty();
    if (!have_host_image) {
        std::cerr << "Error: No image" << std::endl;
        return;
    }

    d_image1.upload(image1);
    const bool have_device_image = !d_image1.empty();
    if (!have_device_image) {
        std::cerr << "Error: Upload to GPU failed" << std::endl;
        return;
    }

    // cv::noArray(): no mask is supplied.
    detector->detectAndComputeAsync(d_image1, cv::noArray(), d_keypoints1, d_descriptors1);

    // Bring the results back to host memory.
    detector->convert(d_keypoints1, keypoints1);
    d_descriptors1.download(descriptors1);
}

// Uploads image2 into d_image2, runs CUDA ORB on it and copies the results
// back: keypoints into keypoints2, descriptors into descriptors2 (the device
// copies stay in d_keypoints2 / d_descriptors2).
// Logs an error and returns early when image2 is empty or when d_image2 is
// still empty after the upload.
void orb_cuda_detector::detect_features_img2() {
    if (image2.empty()) {
        std::cerr << "Error: No image" << std::endl;
        return;
    }

    d_image2.upload(image2);

    // Same post-upload check as detect_features_img1(), so both images are
    // validated identically before the detector runs.
    if (d_image2.empty()) {
        std::cerr << "Error: Upload to GPU failed" << std::endl;
        return;
    }

    // cv::noArray(): no mask is supplied.
    detector->detectAndComputeAsync(d_image2, cv::noArray(), d_keypoints2, d_descriptors2);
    detector->convert(d_keypoints2, keypoints2);
    d_descriptors2.download(descriptors2);
}

void orb_cuda_detector::match_features(){
    std::vector<std::vector<cv::DMatch>> flann_matches;
    matcher->knnMatch(d_descriptors1, d_descriptors2, flann_matches, 2);


    for (const auto& match_pair : flann_matches) {
        const cv::DMatch& m1 = match_pair[0];
        const cv::DMatch& m2 = match_pair[1];
        if (m1.distance < lowes_ratio_thresh  * m2.distance) {
            matches.push_back(m1);
        }
    }

    raw_matches_count = flann_matches.size();
    matches_count = matches.size();
}