#include "optimizer/hexalyoptimizer.h"
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

using namespace hexaly;
using namespace std;

class LocationRoutingProblem {
public:
    // Hexaly Optimizer
    HexalyOptimizer optimizer;

    // Number of customers
    int nbCustomers;

    // Customers coordinates
    vector<int> xCustomers;
    vector<int> yCustomers;

    // Customers demands
    vector<double> demandsData;

    // Number of depots
    int nbDepots;

    // Depots coordinates
    vector<double> xDepots;
    vector<double> yDepots;

    // Capacity of depots
    vector<double> depotsCapacity;

    // Cost of opening a depot
    vector<double> openingDepotsCost;

    // Number of trucks
    int nbTrucks;

    // Capacity of trucks
    int truckCapacity;

    // Cost of opening a route
    int openingRouteCost;

    // Is the route used ?
    vector<HxExpression> sequenceUsed;

    // What is the depot of the route ?
    vector<HxExpression> associatedDepot;

    // Distance matrixes
    vector<vector<double>> distMatrixData;
    vector<vector<double>> distDepotsData;

    int areCostDouble;

    // Decision variables
    vector<HxExpression> customersSequences;
    vector<HxExpression> depots;

    // Sum of all the costs
    HxExpression totalCost;

    void readInstance(const char* fileName) { readInputLrp(fileName); }

    void solve(const char* limit) {
        // Declare the optimization model
        HxModel m = optimizer.getModel();
        int minNbTrucks = ceil(accumulate(demandsData.begin(), demandsData.end(), 0) / truckCapacity);
        nbTrucks = ceil(1.5 * minNbTrucks);

        // A sequence is represented as a list containing the customers in the order they are visited
        customersSequences.resize(nbTrucks);
        for (int i = 0; i < nbTrucks; ++i) {
            customersSequences[i] = m.listVar(nbCustomers);
        }
        // All customers should be assigned to a sequence
        m.constraint(m.partition(customersSequences.begin(), customersSequences.end()));

        // A depot is represented as a set containing the associated customersSequences
        depots.resize(nbDepots);
        for (int d = 0; d < nbDepots; ++d) {
            depots[d] = m.setVar(nbTrucks);
        }
        // All the customersSequences should be assigned to a depot
        m.constraint(m.partition(depots.begin(), depots.end()));

        vector<HxExpression> distRoutes;
        vector<HxExpression> routeCosts;
        distRoutes.resize(nbTrucks);
        sequenceUsed.resize(nbTrucks);
        routeCosts.resize(nbTrucks);
        associatedDepot.resize(nbTrucks);

        // Create Hexaly arrays to be able to access them with "at" operators
        HxExpression quantityServed = m.array();
        HxExpression demands = m.array(demandsData.begin(), demandsData.end());
        HxExpression distMatrix = m.array();
        HxExpression distDepots = m.array();
        for (int i = 0; i < nbCustomers; ++i) {
            distMatrix.addOperand(m.array(distMatrixData[i].begin(), distMatrixData[i].end()));
            distDepots.addOperand(m.array(distDepotsData[i].begin(), distDepotsData[i].end()));
        }

        for (int r = 0; r < nbTrucks; ++r) {
            HxExpression sequence = customersSequences[r];
            HxExpression c = m.count(sequence);

            // A sequence is used if it serves at least one customer
            sequenceUsed[r] = c > 0;
            // The "find" function gets the depot assigned to the sequence
            associatedDepot[r] = m.find(m.array(depots.begin(), depots.end()), r);

            HxExpression demandLambda = m.lambdaFunction([&](HxExpression j) { return demands[j]; });
            quantityServed.addOperand(m.sum(sequence, demandLambda));
            // The quantity needed in each sequence must not exceed the vehicle capacity
            m.constraint(quantityServed[r] <= truckCapacity);

            HxExpression distLambda =
                m.lambdaFunction([&](HxExpression i) { return m.at(distMatrix, sequence[i], sequence[i + 1]); });

            distRoutes[r] = m.iif(sequenceUsed[r],
                                  m.at(distDepots, sequence[0], associatedDepot[r]) +
                                      m.at(distDepots, sequence[c - 1], associatedDepot[r]),
                                  0) +
                            m.sum(m.range(0, c - 1), distLambda);

            // The sequence cost is the sum of the opening cost and the sequence length
            routeCosts[r] = sequenceUsed[r] * openingRouteCost + distRoutes[r];
        }

        vector<HxExpression> depotCost;
        depotCost.resize(nbDepots);
        for (int d = 0; d < nbDepots; ++d) {
            // A depot is open if at least a sequence starts from there
            depotCost[d] = openingDepotsCost[d] * (m.count(depots[d]) > 0);

            HxExpression depotLambda = m.lambdaFunction([&](HxExpression r) { return quantityServed[r]; });
            HxExpression depotQuantity = m.sum(depots[d], depotLambda);
            // The total demand served by a depot must not exceed its capacity
            m.constraint(depotQuantity <= depotsCapacity[d]);
        }
        HxExpression depotsCost = m.sum(depotCost.begin(), depotCost.end());
        HxExpression routingCost = m.sum(routeCosts.begin(), routeCosts.end());
        totalCost = routingCost + depotsCost;

        m.minimize(totalCost);
        m.close();

        optimizer.getParam().setTimeLimit(atoi(limit));
        optimizer.solve();
    }

    /* Write the solution in a file */
    void writeSolution(const char* inFile, const string& solFile) {
        ofstream file;
        file.exceptions(ofstream::failbit | ofstream::badbit);
        file.open(solFile.c_str());
        file << "File name: " << inFile << "; total cost = " << totalCost.getDoubleValue() << endl;
        for (int r = 0; r < nbTrucks; ++r) {
            if (sequenceUsed[r].getValue()) {
                file << "Sequence " << r << ", assigned to depot " << associatedDepot[r].getValue() << " : ";
                HxCollection customersCollection = customersSequences[r].getCollectionValue();
                for (hxint i = 0; i < customersCollection.count(); ++i) {
                    file << customersCollection[i] << " ";
                }
                file << endl;
            }
        }
    }

private:
    void readInputLrp(const char* fileName) {
        string file = fileName;
        ifstream infile(file.c_str());
        if (!infile.is_open()) {
            throw std::runtime_error("File cannot be opened.");
        }
        infile >> nbCustomers;
        xCustomers.resize(nbCustomers);
        yCustomers.resize(nbCustomers);
        demandsData.resize(nbCustomers);
        distMatrixData.resize(nbCustomers);
        distDepotsData.resize(nbCustomers);
        infile >> nbDepots;
        xDepots.resize(nbDepots);
        yDepots.resize(nbDepots);
        depotsCapacity.resize(nbDepots);
        openingDepotsCost.resize(nbDepots);
        for (int i = 0; i < nbDepots; ++i) {
            infile >> xDepots[i];
            infile >> yDepots[i];
        }
        for (int i = 0; i < nbCustomers; ++i) {
            infile >> xCustomers[i];
            infile >> yCustomers[i];
        }
        infile >> truckCapacity;
        for (int i = 0; i < nbDepots; ++i) {
            infile >> depotsCapacity[i];
        }
        for (int i = 0; i < nbCustomers; ++i) {
            infile >> demandsData[i];
        }
        vector<double> tempOpeningCostDepots;
        tempOpeningCostDepots.resize(nbDepots);
        for (int i = 0; i < nbDepots; ++i) {
            infile >> tempOpeningCostDepots[i];
        }
        int tempOpeningCostRoute;
        infile >> tempOpeningCostRoute;
        infile >> areCostDouble;
        infile.close();
        if (areCostDouble == 1) {
            openingRouteCost = tempOpeningCostRoute;
            for (int i = 0; i < nbDepots; ++i) {
                openingDepotsCost[i] = tempOpeningCostDepots[i];
            }
        } else {
            openingRouteCost = round(tempOpeningCostRoute);
            for (int i = 0; i < nbDepots; ++i) {
                openingDepotsCost[i] = round(tempOpeningCostDepots[i]);
            }
        }
        computeDistanceMatrix();
    }

    void computeDistanceMatrix() {
        for (int i = 0; i < nbCustomers; ++i) {
            distMatrixData[i].resize(nbCustomers);
            distDepotsData[i].resize(nbDepots);
            for (int j = 0; j < nbCustomers; ++j) {
                distMatrixData[i][j] =
                    computeDist(xCustomers[i], yCustomers[i], xCustomers[j], yCustomers[j], areCostDouble);
            }
            for (int d = 0; d < nbDepots; ++d) {
                distDepotsData[i][d] = computeDist(xCustomers[i], yCustomers[i], xDepots[d], yDepots[d], areCostDouble);
            }
        }
    }

    double computeDist(int xi, int yi, int xj, int yj, int areCostDouble) {
        double dist = sqrt(pow(xi - xj, 2) + pow(yi - yj, 2));
        if (areCostDouble == 0) {
            dist = ceil(dist * 100);
        }
        return dist;
    }
};

/*
 * Entry point: ./lrp inputFile [outputFile] [timeLimit]
 * Reads the instance, solves it with the given time limit in seconds (default "20"),
 * and writes the solution if an output file is given.
 * Returns 0 on success, and 1 on bad usage or if any error is raised while solving.
 */
int main(int argc, char** argv) {
    if (argc < 2) {
        cerr << "Usage: ./lrp inputFile [outputFile] [timeLimit]" << endl;
        return 1;
    }
    const char* instanceFile = argv[1];
    const char* solFile = argc > 2 ? argv[2] : nullptr;
    const char* strTimeLimit = argc > 3 ? argv[3] : "20";

    try {
        LocationRoutingProblem model;
        model.readInstance(instanceFile);
        model.solve(strTimeLimit);
        if (solFile != nullptr) {
            model.writeSolution(instanceFile, solFile);
        }
    } catch (const std::exception& e) {
        std::cerr << "An error occured: " << e.what() << endl;
        // Report the failure to the caller instead of exiting successfully
        return 1;
    }

    return 0;
}