#include "optimizer/hexalyoptimizer.h"
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

using namespace hexaly;
using namespace std;

class ClusteredCvrp {
public:
    // Hexaly Optimizer
    HexalyOptimizer optimizer;

    // Number of customers
    int nbCustomers;

    // Capacity of the trucks
    int truckCapacity;

    // Demand on each cluster
    vector<int> demandsData;

    // Customers in each cluster;
    vector<vector<int>> clustersData;

    // Distance matrix between customers
    vector<vector<int>> distMatrixData;

    // Distances between customers and depot
    vector<int> distDepotData;

    // Number of trucks
    int nbTrucks;

    // Number of clusters
    int nbClusters;

    // Decision variables
    vector<HxExpression> truckSequences;
    vector<HxExpression> clustersSequences;

    // Are the trucks actually used
    vector<HxExpression> trucksUsed;

    // Number of trucks used in the solution
    HxExpression nbTrucksUsed;

    // Distance traveled by all the trucks
    HxExpression totalDistance;

    /* Read instance data */
    void readInstance(const string& fileName) {
        readInputCvrp(fileName);
    }

    void solve(int limit) {
        // Declare the optimization model
        HxModel model = optimizer.getModel();

        // Create HexalyOptimizer arrays to be able to access them with an "at" operator
        HxExpression demands = model.array(demandsData.begin(), demandsData.end());
        HxExpression distMatrix = model.array();
        for (int n = 0; n < nbCustomers; ++n) {
            distMatrix.addOperand(model.array(distMatrixData[n].begin(), 
                    distMatrixData[n].end()));
        }
        HxExpression clusters = model.array();
        for (int n = 0; n < clustersData.size(); ++n) {
            clusters.addOperand(model.array(clustersData[n].begin(), clustersData[n].end()));
        }
        HxExpression distDepot = model.array(distDepotData.begin(), distDepotData.end());
        
        // A list is created for each cluster, to determine the order within the cluster
        clustersSequences.resize(nbClusters);
        for (int k = 0; k < nbClusters; ++k) {
            int c = (int) clustersData[k].size();
            clustersSequences[k] = model.listVar(c);
            // All customers in the cluster must be visited 
            model.constraint(model.count(clustersSequences[k]) == c);
        }

        HxExpression clustersDistances =  model.array();
        HxExpression initialNodes = model.array();
        HxExpression endNodes = model.array();
        for (int k = 0; k < nbClusters; ++k) {
            HxExpression sequence = clustersSequences[k];
            HxExpression c = model.count(sequence);

            // Distance traveled within clsuter k
            HxExpression clustersDistances_lambda = model.createLambdaFunction(
                    [&](HxExpression i) { return model.at(distMatrix,
                    clusters[k][sequence[i - 1]], clusters[k][sequence[i]]); });
            clustersDistances.addOperand(model.sum(model.range(1, c), clustersDistances_lambda));

            // First and last point when visiting cluster k
            initialNodes.addOperand(clusters[k][sequence[0]]);
            endNodes.addOperand(clusters[k][sequence[c - 1]]);
        }

        // Sequence of clusters visited by each truck
        truckSequences.resize(nbTrucks);
        for (int k = 0; k < nbTrucks; ++k) {
            truckSequences[k] = model.listVar(nbClusters);
        }
        // All clusters must be visited by the trucks
        model.constraint(model.partition(truckSequences.begin(), truckSequences.end()));
        
        vector<HxExpression> routeDistances(nbTrucks);
        for (int k = 0; k < nbTrucks; ++k) {
            HxExpression sequence = truckSequences[k];
            HxExpression c = model.count(sequence);

            // The quantity needed in each route must not exceed the truck capacity
            HxExpression demandLambda =
                    model.createLambdaFunction([&](HxExpression j) { return demands[j]; });
            HxExpression routeQuantity = model.sum(sequence, demandLambda);
            model.constraint(routeQuantity <= truckCapacity);

            // Distance traveled by truck k
            // = distance in each cluster + distance between clusters + distance with depot 
            // at the beginning end at the end of a route
            HxExpression routeDistances_lambda = model.createLambdaFunction(
                    [&](HxExpression i) { return model.at(clustersDistances, sequence[i]) 
                    + model.at(distMatrix, endNodes[sequence[i - 1]], initialNodes[sequence[i]]);}
            );
            
            routeDistances[k] = model.sum(model.range(1, c), routeDistances_lambda)
                    + model.iif(c > 0, model.at(clustersDistances, sequence[0]) 
                    + distDepot[initialNodes[sequence[0]]] 
                    + distDepot[endNodes[sequence[c - 1]]], 0);
        }

        // Total distance traveled
        totalDistance = model.sum(routeDistances.begin(), routeDistances.end());

        // Objective: minimize the distance traveled
        model.minimize(totalDistance);

        model.close();

        // Parametrize the optimizer
        optimizer.getParam().setTimeLimit(limit);
        
        optimizer.solve();
    }

    /* Write the solution in a file with the following format:
     * - total distance
     * - for each truck the customers visited (omitting the start/end at the depot) */
    void writeSolution(const string& fileName) {
        ofstream outfile;
        outfile.exceptions(ofstream::failbit | ofstream::badbit);
        outfile.open(fileName.c_str());

        outfile << totalDistance.getValue() << endl;
        for (int k = 0; k < nbTrucks; ++k) {
            // Values in sequence are in [0..nbCustomers-1]. +2 is to put it back in 
            // [2..nbCustomers+1] as in the data files (1 being the depot)
            HxCollection customersCollection = truckSequences[k].getCollectionValue();
            for (int i = 0; i < customersCollection.count(); ++i) {
                int cluster = customersCollection[i];
                HxCollection clustersCollection = 
                        clustersSequences[cluster].getCollectionValue();
                for (int j = 0; j < clustersCollection.count(); ++j)
                    outfile << clustersData[cluster][clustersCollection[j]] + 2 << " ";
            }
            outfile << endl;
        }
    }

private:
    // The input files follow the "Augerat" format
    void readInputCvrp(const string& fileName) {
        ifstream infile(fileName.c_str());
        if (!infile.is_open()) {
            throw std::runtime_error("File cannot be opened.");
        }
        string str;
        char* pch;
        char* line;
        int nbNodes;
        while (true) {
            getline(infile, str);
            line = strdup(str.c_str());
            pch = strtok(line, " ");
            if (strcmp(pch, "DIMENSION:") == 0) {
                pch = strtok(NULL, " ");
                nbNodes = atoi(pch);
                nbCustomers = nbNodes - 1;
            } else if (strcmp(pch, "VEHICLES:") == 0) {
                pch = strtok(NULL, " ");
                nbTrucks = atoi(pch);
            } else if (strcmp(pch, "GVRP_SETS:") == 0) {
                pch = strtok(NULL, " ");
                nbClusters = atoi(pch);
            } else if (strcmp(pch, "CAPACITY:") == 0) {
                pch = strtok(NULL, "");
                truckCapacity = atoi(pch);
            } else if (strcmp(pch, "NODE_COORD_SECTION") == 0) {
                break;
            }
        }

        vector<int> customersX(nbCustomers);
        vector<int> customersY(nbCustomers);
        int depotX, depotY;
        for (int n = 1; n <= nbNodes; ++n) {
            int id;
            infile >> id;
            if (id != n) {
                throw std::runtime_error("Unexpected index");
            }
            if (n == 1) {
                infile >> depotX;
                infile >> pch;
                infile >> depotY;
                infile >> pch;
            } else {
                // -2 because original customer indices are in 2..nbNodes
                infile >> customersX[n - 2];
                infile >> pch;
                infile >> customersY[n - 2];
                infile >> pch;
            }
        }

        computeDistanceMatrix(depotX, depotY, customersX, customersY);

        getline(infile, str); // End the last line
        getline(infile, str);
        line = strdup(str.c_str());
        pch = strtok(line, " ");
        if (strcmp(pch, "GVRP_SET_SECTION") != 0) {
            throw std::runtime_error("Expected keyword GVRP_SET_SECTION");
        }
        for (int n = 1; n <= nbClusters; ++n) {
            vector<int> cluster;
            int id;
            infile >> id;
            if (id != n) {
                throw std::runtime_error("Unexpected index");
            }
            int data;
            infile >> data;
            while (data != -1) {
                // -2 because original customer indices are in 2..nbNodes
                cluster.push_back(data - 2);
                infile >> data;
            }
            clustersData.push_back(cluster);

        };

        getline(infile, str); // End the last line
        getline(infile, str);
        line = strdup(str.c_str());
        pch = strtok(line, " ");
        if (strcmp(pch, "DEMAND_SECTION") != 0) {
            throw std::runtime_error("Expected keyword DEMAND_SECTION");
        }
        demandsData.resize(nbClusters);
        for (int n = 1; n <= nbClusters; ++n) {
            int id;
            infile >> id;
            if (id != n) {
                throw std::runtime_error("Unexpected index");
            }
            int demand;
            infile >> demand;
            demandsData[n - 1] = demand;
        }
        infile.close();

    }

    // Compute the distance matrix
    void computeDistanceMatrix(int depotX, int depotY, const vector<int>& customersX, const vector<int>& customersY) {
        distMatrixData.resize(nbCustomers);
        for (int i = 0; i < nbCustomers; ++i) {
            distMatrixData[i].resize(nbCustomers);
        }
        for (int i = 0; i < nbCustomers; ++i) {
            distMatrixData[i][i] = 0;
            for (int j = i + 1; j < nbCustomers; ++j) {
                int distance = computeDist(customersX[i], customersX[j], customersY[i], customersY[j]);
                distMatrixData[i][j] = distance;
                distMatrixData[j][i] = distance;
            }
        }

        distDepotData.resize(nbCustomers);
        for (int i = 0; i < nbCustomers; ++i) {
            distDepotData[i] = computeDist(depotX, customersX[i], depotY, customersY[i]);
        }
    }

    int computeDist(int xi, int xj, int yi, int yj) {
        double exactDist = sqrt(pow((double)xi - xj, 2) + pow((double)yi - yj, 2));
        return floor(exactDist + 0.5);
    }

};

/* Command-line entry point.
 * Arguments (all optional): inputFile, outputFile, timeLimit.
 * Without an input file the bundled default instance path is used; without an
 * output file no solution is written; the default time limit is "5".
 * Returns 0 on success, 1 when an exception is reported. */
int main(int argc, char** argv) {
    if (argc < 1) {
        cerr << "Usage: clustered-vehicle-routing inputFile [outputFile] [timeLimit] " << endl;
        return 1;
    }

    // Positional arguments with their fallbacks
    const char* instanceFile = "instances/A-n32-k5-C11-V2.gvrp";
    if (argc > 1) {
        instanceFile = argv[1];
    }
    const char* solFile = nullptr;
    if (argc > 2) {
        solFile = argv[2];
    }
    const char* strTimeLimit = "5";
    if (argc > 3) {
        strTimeLimit = argv[3];
    }

    int exitCode = 0;
    try {
        ClusteredCvrp model;
        model.readInstance(instanceFile);
        model.solve(atoi(strTimeLimit));
        if (solFile != nullptr) {
            model.writeSolution(solFile);
        }
    } catch (const exception& e) {
        cerr << "An error occurred: " << e.what() << endl;
        exitCode = 1;
    }
    return exitCode;
}