#include "optimizer/hexalyoptimizer.h"

#include <cmath>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>

using namespace hexaly;
using namespace std;

class Mdvrp {
public:
    HexalyOptimizer optimizer;

    // Number of customers
    int nbCustomers;

    // Number of depots/warehouses
    int nbDepots;

    // Number of trucks per depot
    int nbTrucksPerDepot;

    // Capacity of the trucks per depot
    vector<int> truckCapacity;

    // Duration capacity of the trucks per depot
    vector<int> routeDurationCapacity;

    // Service time per customer
    vector<int> serviceTimeData;

    // Demand per customer
    vector<int> demandsData;

    // Distance matrix between customers
    vector<vector<double>> distanceMatrixCustomersData;

    // Distances between customers and depots
    vector<vector<double>> distanceWarehouseData;

    // Decision variables
    vector<vector<HxExpression>> customersSequences;

    // Distance traveled by all the trucks
    HxExpression totalDistance;

    /* Read instance data */
    void readInstance(const string& fileName) { readInputMdvrp(fileName); }

    void solve(int limit) {
        // Declare the optimization model
        HxModel model = optimizer.getModel();

        // Sequence of customers visited by each truck
        customersSequences.resize(nbDepots);

        // Vectorization for partition constraint
        vector<HxExpression> customersSequencesConstraint(nbDepots * nbTrucksPerDepot);

        for (int d = 0; d < nbDepots; ++d) {
            customersSequences[d].resize(nbTrucksPerDepot);
            for (int k = 0; k < nbTrucksPerDepot; ++k) {
                customersSequences[d][k] = model.listVar(nbCustomers);
                customersSequencesConstraint[d * nbTrucksPerDepot + k] = customersSequences[d][k];
            }
        }

        // All customers must be visited by exactly one truck
        model.constraint(model.partition(customersSequencesConstraint.begin(), customersSequencesConstraint.end()));

        // Create Hexaly arrays to be able to access them with an "at" operator
        HxExpression demands = model.array(demandsData.begin(), demandsData.end());
        HxExpression serviceTime = model.array(serviceTimeData.begin(), serviceTimeData.end());

        HxExpression distMatrix = model.array();
        for (int n = 0; n < nbCustomers; ++n) {
            distMatrix.addOperand(
                model.array(distanceMatrixCustomersData[n].begin(), distanceMatrixCustomersData[n].end()));
        }

        // Distances traveled by each truck from each depot
        vector<vector<HxExpression>> routeDistances;
        routeDistances.resize(nbDepots);

        // Total distance traveled
        totalDistance = model.sum();

        for (int d = 0; d < nbDepots; ++d) {
            routeDistances[d].resize(nbTrucksPerDepot);
            HxExpression distDepot = model.array(distanceWarehouseData[d].begin(), distanceWarehouseData[d].end());

            for (int k = 0; k < nbTrucksPerDepot; ++k) {
                HxExpression sequence = customersSequences[d][k];
                HxExpression c = model.count(sequence);

                // The quantity needed in each route must not exceed the truck capacity
                HxExpression demandLambda = model.createLambdaFunction([&](HxExpression j) { return demands[j]; });

                HxExpression routeQuantity = model.sum(sequence, demandLambda);

                model.constraint(routeQuantity <= truckCapacity[d]);

                // Distance traveled by truck k of depot d
                HxExpression distLambda = model.createLambdaFunction(
                    [&](HxExpression i) { return model.at(distMatrix, sequence[i - 1], sequence[i]); });

                routeDistances[d][k] = model.sum(model.range(1, c), distLambda) +
                                       model.iif(c > 0, distDepot[sequence[0]] + distDepot[sequence[c - 1]], 0);

                totalDistance.addOperand(routeDistances[d][k]);

                // We add service time
                HxExpression serviceLambda = model.createLambdaFunction([&](HxExpression j) { return serviceTime[j]; });
                HxExpression routeServiceTime = model.sum(sequence, serviceLambda);

                // The total distance should not exceed the duration capacity of the truck
                // (only if we define such a capacity)
                if (routeDurationCapacity[d] > 0) {
                    model.constraint(routeDistances[d][k] + routeServiceTime <= routeDurationCapacity[d]);
                }
            }
        }

        // Objective: minimize the total distance traveled
        model.minimize(totalDistance);

        model.close();

        // Parametrize the optimizer
        optimizer.getParam().setTimeLimit(limit);

        optimizer.solve();
    }

    // Write the solution in a file with the following format:
    //  - instance, time_limit, total distance
    //  - for each depot and for each truck in this depot, the customers visited
    void writeSolution(const string& fileName, const string& instanceFile, const string& timeLimit) {
        ofstream outfile;
        outfile.open(fileName.c_str());

        outfile << "Instance: " << instanceFile << " ; time_limit: " << timeLimit
                << " ; Objective value: " << totalDistance.getDoubleValue() << endl;
        for (int d = 0; d < nbDepots; ++d) {
            vector<int> trucksUsed;
            for (int k = 0; k < nbTrucksPerDepot; ++k) {
                if (customersSequences[d][k].getCollectionValue().count() > 0) {
                    trucksUsed.push_back(k);
                }
            }
            if (trucksUsed.size() > 0) {
                outfile << "Depot " << d + 1 << endl;
                for (int k = 0; k < trucksUsed.size(); ++k) {
                    outfile << "Truck " << (k + 1) << " : ";
                    HxCollection customersCollection = customersSequences[d][trucksUsed[k]].getCollectionValue();
                    for (int p = 0; p < customersCollection.count(); ++p) {
                        outfile << customersCollection[p] + 1 << " ";
                    }
                    outfile << endl;
                }
                outfile << endl;
            }
        }
    }

private:
    // Input files following "Cordeau"'s format
    void readInputMdvrp(const string& fileName) {
        ifstream infile;
        infile.open(fileName.c_str());

        infile.ignore(); // We ignore the first int of the instance

        // Numbers of trucks per depot, customers and depots
        infile >> nbTrucksPerDepot;
        infile >> nbCustomers;
        infile >> nbDepots;

        routeDurationCapacity.resize(nbDepots);
        truckCapacity.resize(nbDepots);

        for (int d = 0; d < nbDepots; ++d) {
            infile >> routeDurationCapacity[d];
            infile >> truckCapacity[d];
        }

        // Coordinates X and Y, service time and demand for customers
        vector<double> nodesX(nbCustomers);
        vector<double> nodesY(nbCustomers);
        serviceTimeData.resize(nbCustomers);
        demandsData.resize(nbCustomers);

        for (int n = 0; n < nbCustomers; ++n) {
            int id;
            int bin;
            infile >> id;
            infile >> nodesX[id - 1];
            infile >> nodesY[id - 1];
            infile >> serviceTimeData[id - 1];
            infile >> demandsData[id - 1];

            // Ignore the end of the line
            infile.ignore(numeric_limits<streamsize>::max(), '\n');
        }
        // Coordinates X and Y for depots
        vector<double> depotX(nbDepots);
        vector<double> depotY(nbDepots);

        for (int d = 0; d < nbDepots; ++d) {
            int id;
            int bin;
            infile >> id;
            infile >> depotX[id - nbCustomers - 1];
            infile >> depotY[id - nbCustomers - 1];

            // Ignore the end of the line
            infile.ignore(numeric_limits<streamsize>::max(), '\n');
        }

        // Compute the distance matrices
        computeDistanceMatrixCustomers(nodesX, nodesY);
        computeDistanceWarehouse(depotX, depotY, nodesX, nodesY);

        infile.close();
    }

    // Compute the distance matrix for customers
    void computeDistanceMatrixCustomers(const vector<double>& nodesX, const vector<double>& nodesY) {
        distanceMatrixCustomersData.resize(nbCustomers);
        for (int i = 0; i < nbCustomers; ++i) {
            distanceMatrixCustomersData[i].resize(nbCustomers);
        }

        for (int i = 0; i < nbCustomers; ++i) {
            distanceMatrixCustomersData[i][i] = 0;
            for (int j = i + 1; j < nbCustomers; ++j) {
                double distance = computeDist(nodesX[i], nodesX[j], nodesY[i], nodesY[j]);
                distanceMatrixCustomersData[i][j] = distance;
                distanceMatrixCustomersData[j][i] = distance;
            }
        }
    }

    // Compute the distance matrix for warehouses
    void computeDistanceWarehouse(const vector<double>& depotX, const vector<double>& depotY,
                                  const vector<double>& nodesX, const vector<double>& nodesY) {
        distanceWarehouseData.resize(nbDepots);

        for (int d = 0; d < nbDepots; ++d) {
            distanceWarehouseData[d].resize(nbCustomers);
            for (int i = 0; i < nbCustomers; ++i) {
                distanceWarehouseData[d][i] = computeDist(nodesX[i], depotX[d], nodesY[i], depotY[d]);
            }
        }
    }

    // Compute the distance between two points
    double computeDist(double xi, double xj, double yi, double yj) { return sqrt(pow(xi - xj, 2) + pow(yi - yj, 2)); }
};

int main(int argc, char** argv) {
    if (argc < 2) {
        cerr << "Usage: mdvrp inputFile [outputFile] [timeLimit]" << endl;
        return 1;
    }
    const char* instanceFile = argv[1];
    const char* outputFile = argc > 2 ? argv[2] : NULL;
    const char* strTimeLimit = argc > 3 ? argv[3] : "20";

    try {
        Mdvrp model;
        model.readInstance(instanceFile);
        model.solve(atoi(strTimeLimit));

        // If we want to write the solution
        if (outputFile != NULL)
            model.writeSolution(outputFile, instanceFile, strTimeLimit);
    } catch (const exception& e) {
        cerr << "An error occured: " << e.what() << endl;
    }

    return 0;
}