#include "optimizer/hexalyoptimizer.h"
#include <cmath>
#include <cstddef>
#include <iostream>
#include <random>
#include <stdexcept>
#include <vector>

using namespace hexaly;

/* Stochastic bin packing: assign items to bins so that the 9th decile (over
 * randomly sampled weight scenarios) of the heaviest bin load is minimized. */
class StochasticPacking {
private:
    // Number of items
    int nbItems;
    // Number of bins
    int nbBins;
    // Number of scenarios
    int nbScenarios;
    // For each scenario, the weight of each item:
    // scenarioItemWeights[scenario][item], filled by generateScenarios()
    std::vector<std::vector<int>> scenarioItemWeights;

    // Hexaly Optimizer
    HexalyOptimizer optimizer;
    // Decision variable for the assignment of items: bins[k] is a set variable
    // over the item indices 0..nbItems-1 (the items placed in bin k)
    std::vector<HxExpression> bins;
    // For each scenario, the corresponding maximum weight
    std::vector<HxExpression> scenarioMaxWeight;
    // Objective = minimize the 9th decile of all possible max weights
    HxExpression stochasticMaxWeight;

    /* Index of the 9th decile in an ascending array of `count` values, i.e.
     * ceil(0.9 * (count - 1)). Since 9 * (count - 1) + 9 == 9 * count, that
     * ceiling is exactly floor(9 * count / 10): integer arithmetic gives the
     * same index without floating-point rounding or a C-style cast.
     * Requires count >= 1; the result then lies in [0, count - 1].
     * The product is widened to long long so 9 * count cannot overflow int. */
    static int ninthDecileIndex(int count) {
        return static_cast<int>((9LL * count) / 10);
    }

    /* Fill scenarioItemWeights with random integer weights.
     * Each item j gets its own uniform distribution on [lo_j, lo_j + delta_j],
     * with lo_j drawn from [10, 100] and delta_j from [0, 50]; every scenario
     * then samples each item's weight independently from that distribution.
     * @param rngSeed  seed of the std::mt19937 engine driving all draws */
    void generateScenarios(unsigned int rngSeed) {
        std::mt19937 rng(rngSeed);
        std::uniform_int_distribution<int> distMin(10, 100);
        std::uniform_int_distribution<int> distDelta(0, 50);

        // Pick random parameters for each item distribution
        std::vector<std::uniform_int_distribution<int>> itemsDists;
        itemsDists.reserve(nbItems);
        for (int i = 0; i < nbItems; ++i) {
            const int lo = distMin(rng);
            const int hi = lo + distDelta(rng);
            itemsDists.emplace_back(lo, hi);
        }

        // Sample the distributions to generate the scenarios
        for (int i = 0; i < nbScenarios; ++i) {
            for (int j = 0; j < nbItems; ++j) {
                scenarioItemWeights[i][j] = itemsDists[j](rng);
            }
        }
    }

public:
    /* Build an instance with randomly generated weight scenarios.
     * @param nbItems      number of items to pack (must be >= 0)
     * @param nbBins       number of bins (must be >= 1)
     * @param nbScenarios  number of sampled weight scenarios (must be >= 1,
     *                     otherwise the decile objective has nothing to index)
     * @param seed         seed used by generateScenarios()
     * @throws std::invalid_argument if a count is out of range */
    StochasticPacking(int nbItems, int nbBins, int nbScenarios, unsigned int seed)
        : nbItems(nbItems), nbBins(nbBins), nbScenarios(nbScenarios), optimizer() {
        // Validate before allocating: a negative count would otherwise surface
        // as an opaque length_error/bad_alloc from the vector constructor.
        if (nbItems < 0) {
            throw std::invalid_argument("StochasticPacking: nbItems must be non-negative");
        }
        if (nbBins < 1) {
            throw std::invalid_argument("StochasticPacking: nbBins must be at least 1");
        }
        if (nbScenarios < 1) {
            throw std::invalid_argument("StochasticPacking: nbScenarios must be at least 1");
        }
        scenarioItemWeights.assign(nbScenarios, std::vector<int>(nbItems));
        generateScenarios(seed);
    }

    /* Build the Hexaly model and solve it.
     * Decisions: one set variable per bin, constrained to partition the items.
     * For each scenario, a bin's load is the sum of that scenario's item weights
     * over the bin's items, and the scenario cost is the maximum bin load.
     * The objective minimizes the 9th-decile entry of the sorted scenario costs.
     * @param timeLimit  time limit forwarded to the optimizer's parameters */
    void solve(int timeLimit) {
        // Declare the optimization model
        HxModel model = optimizer.getModel();

        bins.resize(nbBins);
        scenarioMaxWeight.resize(nbScenarios);

        // Set decisions: bins[k] represents the items in bin k
        for (int k = 0; k < nbBins; ++k) {
            bins[k] = model.setVar(nbItems);
        }

        // Each item must be in one bin and one bin only
        model.constraint(model.partition(bins.begin(), bins.end()));

        // Compute max weight for each scenario
        for (int m = 0; m < nbScenarios; ++m) {
            HxExpression scenario = model.array(scenarioItemWeights[m].begin(), scenarioItemWeights[m].end());
            // NOTE(review): `scenario` is captured by reference. This assumes
            // createLambdaFunction uses the functor while building the model
            // during this iteration rather than storing it; confirm against
            // the Hexaly API docs before changing the capture.
            HxExpression weightLambda = model.createLambdaFunction([&](HxExpression i) { return scenario[i]; });
            std::vector<HxExpression> binWeights(nbBins);

            for (int k = 0; k < nbBins; ++k) {
                binWeights[k] = model.sum(bins[k], weightLambda);
            }
            scenarioMaxWeight[m] = model.max(binWeights.begin(), binWeights.end());
        }

        // Compute the 9th decile of scenario max weights
        HxExpression scenarioMaxWeightArray = model.array(scenarioMaxWeight.begin(), scenarioMaxWeight.end());
        HxExpression sortedScenarioMaxWeight = model.sort(scenarioMaxWeightArray);
        stochasticMaxWeight = sortedScenarioMaxWeight[ninthDecileIndex(nbScenarios)];

        model.minimize(stochasticMaxWeight);
        model.close();

        // Parametrize the optimizer
        optimizer.getParam().setTimeLimit(timeLimit);

        optimizer.solve();
    }

    /* Write the solution: the sampled item weights of every scenario, then the
     * items assigned to each bin. Must be called after solve(), which is where
     * `bins` is sized and populated. */
    void writeSolution(std::ostream& os) const {
        os << "\nScenario item weights:\n";
        for (int i = 0; i < nbScenarios; ++i) {
            const std::vector<int>& weights = scenarioItemWeights[i];
            os << i << ": [";
            // size_t index: no signed/unsigned comparison against size()
            for (std::size_t j = 0; j < weights.size(); ++j) {
                os << weights[j] << (j + 1 == weights.size() ? "" : ", ");
            }
            os << "]\n";
        }

        os << "\nBins:\n";
        for (int m = 0; m < nbBins; ++m) {
            os << m << ": { ";
            HxCollection items = bins[m].getCollectionValue();
            for (int i = 0; i < items.count(); ++i) {
                os << items[i] << (i == items.count() - 1 ? " " : ", ");
            }
            os << "}\n";
        }
    }
};

int main(int argc, char** argv) {
    int nbItems = 10;
    int nbBins = 2;
    int nbScenarios = 3;
    int rngSeed = 42;
    int timeLimit = 2;

    try {
        StochasticPacking model(nbItems, nbBins, nbScenarios, rngSeed);
        model.solve(timeLimit);
        model.writeSolution(std::cout);
        return 0;
    } catch (const std::exception& e) {
        std::cerr << "An error occurred: " << e.what() << std::endl;
        return 1;
    }
}