#include "optimizer/hexalyoptimizer.h"
#include <cerrno>
#include <climits>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

using namespace hexaly;
using namespace std;

class Irp {
public:
    // Hexaly Optimizer
    HexalyOptimizer optimizer;

    // Number of customers
    int nbCustomers;

    // Horizon length
    int horizonLength;

    // Capacity
    int capacity;

    // Start level at the supplier
    hxint startLevelSupplier;

    // Production rate of the supplier
    int productionRateSupplier;

    // Holding costs of the supplier
    double holdingCostSupplier;

    // Start level of the customers
    vector<hxint> startLevel;

    // Max level of the customers
    vector<int> maxLevel;

    // Demand rate of the customers
    vector<int> demandRate;

    // Holding costs of the customers
    vector<double> holdingCost;

    // Distance matrix
    vector<vector<int>> distMatrixData;

    // Distance to depot
    vector<int> distSupplierData;

    // Decision variables
    vector<vector<HxExpression>> delivery;

    // Decision variables
    vector<HxExpression> route;

    // Are the customers receiving products
    vector<vector<HxExpression>> isDelivered;

    // Total inventory cost at the supplier
    HxExpression totalCostInventorySupplier;

    // Total inventory cost at customers
    HxExpression totalCostInventory;

    // Total transportation cost
    HxExpression totalCostRoute;

    // Objective
    HxExpression objective;

    Irp() {}

    /* Read instance data */
    void readInstance(const string& fileName) { readInputIrp(fileName); }

    void solve(int limit) {
        // Declare the optimization model
        HxModel model = optimizer.getModel();

        // Quantity of product delivered at each discrete time instant of
        // the planning time horizon to each customer
        delivery.resize(horizonLength);
        for (int t = 0; t < horizonLength; ++t) {
            delivery[t].resize(nbCustomers);
            for (int i = 0; i < nbCustomers; ++i) {
                delivery[t][i] = model.floatVar(0, capacity);
            }
        }

        // Sequence of customers visited at each discrete time instant of
        // the planning time horizon
        route.resize(horizonLength);
        for (int t = 0; t < horizonLength; ++t) {
            route[t] = model.listVar(nbCustomers);
        }

        // Create distances as arrays to be able to access them with an "at" operator
        HxExpression distMatrix = model.array();
        for (int i = 0; i < nbCustomers; ++i) {
            distMatrix.addOperand(model.array(distMatrixData[i].begin(), distMatrixData[i].end()));
        }
        HxExpression distSupplier = model.array(distSupplierData.begin(), distSupplierData.end());

        isDelivered.resize(horizonLength);
        vector<HxExpression> distRoutes(horizonLength);

        for (int t = 0; t < horizonLength; ++t) {
            HxExpression sequence = route[t];
            HxExpression c = model.count(sequence);
            isDelivered[t].resize(nbCustomers);

            // Customers receive products only if they are visited
            for (int i = 0; i < nbCustomers; ++i) {
                isDelivered[t][i] = model.contains(sequence, i);
            }

            // Distance traveled at instant t
            HxExpression distLambda = model.createLambdaFunction(
                [&](HxExpression i) { return model.at(distMatrix, sequence[i - 1], sequence[i]); });
            distRoutes[t] = model.iif(c > 0,
                                     distSupplier[sequence[0]] + model.sum(model.range(1, c), distLambda) +
                                         distSupplier[sequence[c - 1]],
                                     0);
        }

        // Stockout constraints at the supplier
        vector<HxExpression> inventorySupplier(horizonLength + 1);
        inventorySupplier[0] = model.createConstant(startLevelSupplier);
        for (int t = 0; t < horizonLength; ++t) {
            inventorySupplier[t + 1] = inventorySupplier[t] -
                                   model.sum(delivery[t].begin(), delivery[t].end()) + productionRateSupplier;
            model.constraint(inventorySupplier[t] >= model.sum(delivery[t].begin(), delivery[t].end()));
        }

        // Stockout constraints at the customers
        vector<vector<HxExpression>> inventory(nbCustomers);
        for (int i = 0; i < nbCustomers; ++i) {
            inventory[i].resize(horizonLength + 1);
            inventory[i][0] = model.createConstant(startLevel[i]);
            for (int t = 0; t < horizonLength; ++t) {
                inventory[i][t + 1] = inventory[i][t] + delivery[t][i] - demandRate[i];
                model.constraint(inventory[i][t + 1] >= 0);
            }
        }

        for (int t = 0; t < horizonLength; ++t) {
            // Capacity constraints
            model.constraint(model.sum(delivery[t].begin(), delivery[t].end()) <= capacity);

            // Maximum level constraints
            for (int i = 0; i < nbCustomers; ++i) {
                model.constraint(delivery[t][i] <= maxLevel[i] - inventory[i][t]);
                model.constraint(delivery[t][i] <= maxLevel[i] * isDelivered[t][i]);
            }
        }

        // Total inventory cost at the supplier
        totalCostInventorySupplier =
            holdingCostSupplier * model.sum(inventorySupplier.begin(), inventorySupplier.end());

        // Total inventory cost at customers
        vector<HxExpression> costInventoryCustomer(nbCustomers);
        for (int i = 0; i < nbCustomers; ++i) {
            costInventoryCustomer[i] = holdingCost[i] * model.sum(inventory[i].begin(), inventory[i].end());
        }
        totalCostInventory = model.sum(costInventoryCustomer.begin(), costInventoryCustomer.end());

        // Total transportation cost
        totalCostRoute = model.sum(distRoutes.begin(), distRoutes.end());

        // Objective: minimize the sum of all costs
        objective = totalCostInventorySupplier + totalCostInventory + totalCostRoute;
        model.minimize(objective);

        model.close();

        // Parametrize the optimizer
        optimizer.getParam().setTimeLimit(limit);

        optimizer.solve();
    }

    /* Write the solution in a file with the following format:
     * - total distance run by the vehicle
     * - the nodes visited at each time step (omitting the start/end at the supplier) */
    void writeSolution(const string& fileName) {
        ofstream outfile;
        outfile.exceptions(ofstream::failbit | ofstream::badbit);
        outfile.open(fileName.c_str());

        outfile << totalCostRoute.getValue() << endl;
        for (int t = 0; t < horizonLength; ++t) {
            HxCollection routeCollection = route[t].getCollectionValue();
            for (int i = 0; i < routeCollection.count(); ++i) {
                outfile << routeCollection[i] + 1 << " ";
            }
            outfile << endl;
        }
    }

private:
    // The input files follow the "Archetti" format
    void readInputIrp(const string& fileName) {
        ifstream infile(fileName.c_str());
        if (!infile.is_open()) {
            throw std::runtime_error("File cannot be opened.");
        }

        infile >> nbCustomers;
        nbCustomers -= 1;
        infile >> horizonLength;
        infile >> capacity;
        int idSupplier;
        double xCoordSupplier, yCoordSupplier;
        vector<int> id;
        vector<double> xCoord, yCoord;
        infile >> idSupplier;
        infile >> xCoordSupplier;
        infile >> yCoordSupplier;
        infile >> startLevelSupplier;
        infile >> productionRateSupplier;
        infile >> holdingCostSupplier;
        vector<int> minLevel;
        id.resize(nbCustomers);
        xCoord.resize(nbCustomers);
        yCoord.resize(nbCustomers);
        startLevel.resize(nbCustomers);
        maxLevel.resize(nbCustomers);
        minLevel.resize(nbCustomers);
        demandRate.resize(nbCustomers);
        holdingCost.resize(nbCustomers);
        for (int i = 0; i < nbCustomers; ++i) {
            infile >> id[i];
            infile >> xCoord[i];
            infile >> yCoord[i];
            infile >> startLevel[i];
            infile >> maxLevel[i];
            infile >> minLevel[i];
            infile >> demandRate[i];
            infile >> holdingCost[i];
        }

        printf("%lf", holdingCost[3]);

        computeDistanceMatrix(xCoordSupplier, yCoordSupplier, xCoord, yCoord);

        infile.close();
    }

    // Compute the distance matrix
    void computeDistanceMatrix(double xCoordSupplier, double yCoordSupplier, const vector<double>& xCoord,
                               const vector<double>& yCoord) {
        distMatrixData.resize(nbCustomers);
        for (int i = 0; i < nbCustomers; ++i) {
            distMatrixData[i].resize(nbCustomers);
        }
        for (int i = 0; i < nbCustomers; ++i) {
            distMatrixData[i][i] = 0;
            for (int j = i + 1; j < nbCustomers; ++j) {
                int distance = computeDist(xCoord[i], xCoord[j], yCoord[i], yCoord[j]);
                distMatrixData[i][j] = distance;
                distMatrixData[j][i] = distance;
            }
        }

        distSupplierData.resize(nbCustomers);
        for (int i = 0; i < nbCustomers; ++i) {
            distSupplierData[i] = computeDist(xCoordSupplier, xCoord[i], yCoordSupplier, yCoord[i]);
        }
    }

    int computeDist(double xi, double xj, double yi, double yj) {
        return round(sqrt(pow(xi - xj, 2) + pow(yi - yj, 2)));
    }
};

/* Entry point: irp inputFile [outputFile] [timeLimit]
 * Reads the instance, solves it within timeLimit seconds (default "20") and, when an
 * output file is given, writes the solution to it. Returns 0 on success, 1 on error. */
int main(int argc, char** argv) {
    if (argc < 2) {
        cerr << "Usage: irp inputFile [outputFile] [timeLimit]" << endl;
        return 1;
    }

    const char* instanceFile = argv[1];
    const char* solFile = argc > 2 ? argv[2] : nullptr;
    const char* strTimeLimit = argc > 3 ? argv[3] : "20";

    // Parse the time limit strictly: atoi() would silently turn "abc" into 0 and
    // ignore trailing garbage instead of reporting the mistake.
    char* parseEnd = nullptr;
    errno = 0;
    const long parsedLimit = strtol(strTimeLimit, &parseEnd, 10);
    if (parseEnd == strTimeLimit || *parseEnd != '\0' || errno == ERANGE || parsedLimit < 0 ||
        parsedLimit > INT_MAX) {
        cerr << "Invalid time limit: " << strTimeLimit << endl;
        return 1;
    }

    try {
        Irp model;
        model.readInstance(instanceFile);
        model.solve(static_cast<int>(parsedLimit));
        if (solFile != nullptr) {
            model.writeSolution(solFile);
        }
        return 0;
    } catch (const exception& e) {
        cerr << "An error occurred: " << e.what() << endl;
        return 1;
    }
}
