#include "optimizer/hexalyoptimizer.h"
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

using namespace hexaly;
using namespace std;

/* Vehicle Routing Problem with Backhauls (VRPB).
 * Every customer is visited by exactly one truck. Along each route, all linehaul
 * (delivery) customers come before any backhaul (pickup) customer, and both the
 * delivered and the picked-up quantities must fit in the truck capacity.
 * Objectives, in lexicographic order: number of trucks used, then total distance. */
class VehicleRoutingWithBackhauls
{
public:
    // Hexaly Optimizer
    HexalyOptimizer optimizer;

    // Number of customers (the depot is not counted)
    int nbCustomers = 0;

    // Capacity of the trucks
    int truckCapacity = 0;

    // Demand on each customer. A linehaul customer only has a delivery demand and a
    // backhaul customer only has a pickup demand (the other one is 0).
    vector<int> pickupDemandsData;
    vector<int> deliveryDemandsData;

    // Type of each customer: 1 for a backhaul (pickup) customer, 0 for a linehaul one
    vector<int> isBackhaulData;

    // Distance matrix between customers
    vector<vector<int>> distMatrixData;

    // Distances between customers and depot
    vector<int> distDepotData;

    // Number of trucks
    int nbTrucks = 0;

    // Decision variables: the sequence of customers visited by each truck
    vector<HxExpression> customersSequences;

    // Are the trucks actually used
    vector<HxExpression> trucksUsed;

    // Number of trucks used in the solution
    HxExpression nbTrucksUsed;

    // Distance traveled by all the trucks
    HxExpression totalDistance;

    /* Read instance data */
    void readInstance(const string &fileName)
    {
        readInputVrpb(fileName);
    }

    /* Build the Hexaly model from the instance data and solve it.
     * limit: value passed to the optimizer's time limit parameter. */
    void solve(int limit)
    {
        // Declare the optimization model
        HxModel model = optimizer.getModel();

        // Sequence of customers visited by each truck
        customersSequences.resize(nbTrucks);
        for (int k = 0; k < nbTrucks; ++k)
        {
            customersSequences[k] = model.listVar(nbCustomers);
        }

        // All customers must be visited by exactly one truck
        model.constraint(model.partition(customersSequences.begin(), customersSequences.end()));

        // Create Hexaly arrays to be able to access them with an "at" operator
        HxExpression deliveryDemands = model.array(deliveryDemandsData.begin(), deliveryDemandsData.end());
        HxExpression pickupDemands = model.array(pickupDemandsData.begin(), pickupDemandsData.end());
        HxExpression isBackhaul = model.array(isBackhaulData.begin(), isBackhaulData.end());
        HxExpression distMatrix = model.array();
        for (int n = 0; n < nbCustomers; ++n)
        {
            distMatrix.addOperand(model.array(distMatrixData[n].begin(), distMatrixData[n].end()));
        }
        HxExpression distDepot = model.array(distDepotData.begin(), distDepotData.end());

        trucksUsed.resize(nbTrucks);
        vector<HxExpression> distRoutes(nbTrucks);

        for (int k = 0; k < nbTrucks; ++k)
        {
            HxExpression sequence = customersSequences[k];
            HxExpression c = model.count(sequence);

            // A truck is used if it visits at least one customer
            trucksUsed[k] = c > 0;

            // A pickup cannot be followed by a delivery
            HxExpression precedencyLambda =
                model.createLambdaFunction([&](HxExpression i)
                                           { return model.leq(model.at(isBackhaul, sequence[i - 1]), model.at(isBackhaul, sequence[i])); });
            model.constraint(model.and_(model.range(1, c), precedencyLambda));

            // The quantity needed in each route must not exceed the truck capacity
            HxExpression deliveryDemandLambda =
                model.createLambdaFunction([&](HxExpression j)
                                           { return deliveryDemands[j]; });
            HxExpression routeDeliveryQuantity = model.sum(sequence, deliveryDemandLambda);
            model.constraint(routeDeliveryQuantity <= truckCapacity);
            HxExpression pickupDemandLambda =
                model.createLambdaFunction([&](HxExpression j)
                                           { return pickupDemands[j]; });
            HxExpression routePickupQuantity = model.sum(sequence, pickupDemandLambda);
            model.constraint(routePickupQuantity <= truckCapacity);

            // Distance traveled by truck k: consecutive customer legs plus, for a used
            // truck, the legs from and back to the depot
            HxExpression distLambda = model.createLambdaFunction(
                [&](HxExpression i)
                { return model.at(distMatrix, sequence[i - 1], sequence[i]); });
            distRoutes[k] = model.sum(model.range(1, c), distLambda) +
                            model.iif(c > 0, distDepot[sequence[0]] + distDepot[sequence[c - 1]], 0);
        }

        // Total number of trucks used
        nbTrucksUsed = model.sum(trucksUsed.begin(), trucksUsed.end());

        // Total distance traveled
        totalDistance = model.sum(distRoutes.begin(), distRoutes.end());

        // Objective: minimize the number of trucks used, then minimize the distance traveled
        model.minimize(nbTrucksUsed);
        model.minimize(totalDistance);

        model.close();

        // Parametrize the optimizer
        optimizer.getParam().setTimeLimit(limit);

        optimizer.solve();
    }

    /* Write the solution in a file with the following format:
     *  - number of trucks used and total distance
     *  - for each truck the customers visited (omitting the start/end at the depot) */
    void writeSolution(const string &fileName)
    {
        ofstream outfile;
        outfile.exceptions(ofstream::failbit | ofstream::badbit);
        outfile.open(fileName.c_str());

        outfile << nbTrucksUsed.getValue() << " " << totalDistance.getValue() << '\n';
        for (int k = 0; k < nbTrucks; ++k)
        {
            if (trucksUsed[k].getValue() != 1)
                continue;
            // Values in sequence are in 0...nbCustomers-1. +2 puts them back in
            // 2...nbCustomers+1 as in the data files (1 being the depot)
            HxCollection customersCollection = customersSequences[k].getCollectionValue();
            for (int i = 0; i < customersCollection.count(); ++i)
            {
                outfile << customersCollection[i] + 2 << " ";
            }
            outfile << '\n';
        }
        // Close explicitly so that a failure while flushing raises an exception (the
        // failbit exception mask is set) instead of being silently lost in the destructor.
        outfile.close();
    }

private:
    // The input files follow the "CVRPLib" format
    void readInputVrpb(const string &fileName)
    {
        ifstream infile(fileName.c_str());
        if (!infile.is_open())
        {
            throw std::runtime_error("File cannot be opened.");
        }

        int nbNodes = readSpecification(infile);

        // NODE_COORD_SECTION: one "id x y" line per node, node 1 being the depot
        vector<int> customersX(nbCustomers);
        vector<int> customersY(nbCustomers);
        int depotX = 0;
        int depotY = 0;
        for (int n = 1; n <= nbNodes; ++n)
        {
            int id = 0;
            if (!(infile >> id) || id != n)
            {
                throw std::runtime_error("Unexpected index");
            }
            // -2 because original customer indices are in 2..nbNodes
            int &x = (n == 1) ? depotX : customersX[n - 2];
            int &y = (n == 1) ? depotY : customersY[n - 2];
            if (!(infile >> x >> y))
            {
                throw std::runtime_error("Invalid node coordinates");
            }
        }

        computeDistanceMatrix(depotX, depotY, customersX, customersY);

        expectSection(infile, "DEMAND_SECTION");
        vector<int> demandsData = readDemands(infile, nbNodes);

        expectSection(infile, "BACKHAUL_SECTION");
        readBackhauls(infile, nbNodes);

        // A customer's demand is either a delivery (linehaul) or a pickup (backhaul)
        deliveryDemandsData.assign(nbCustomers, 0);
        pickupDemandsData.assign(nbCustomers, 0);
        for (int i = 0; i < nbCustomers; ++i)
        {
            if (isBackhaulData[i])
                pickupDemandsData[i] = demandsData[i];
            else
                deliveryDemandsData[i] = demandsData[i];
        }

        expectSection(infile, "DEPOT_SECTION");
        readDepot(infile);
    }

    // Split a line into tokens separated by spaces, colons, tabs or carriage returns,
    // so that "DIMENSION : 42" yields {"DIMENSION", "42"}. Carriage returns are
    // delimiters so that files with Windows (CRLF) line endings are accepted.
    static vector<string> tokenize(const string &line)
    {
        const char *const delimiters = " :\t\r";
        vector<string> tokens;
        string::size_type start = line.find_first_not_of(delimiters);
        while (start != string::npos)
        {
            string::size_type end = line.find_first_of(delimiters, start);
            tokens.push_back(line.substr(start, end == string::npos ? string::npos : end - start));
            start = line.find_first_not_of(delimiters, end);
        }
        return tokens;
    }

    // Read lines until one contains at least one token and return its tokens.
    // Returns an empty vector if the end of the stream is reached first.
    static vector<string> readTokenLine(std::istream &infile)
    {
        string str;
        while (getline(infile, str))
        {
            vector<string> tokens = tokenize(str);
            if (!tokens.empty())
                return tokens;
        }
        return vector<string>();
    }

    // Consume the end of the line on which the previous numeric reads stopped, then
    // check that the next non-blank line starts with the given section keyword.
    static void expectSection(std::istream &infile, const string &keyword)
    {
        string rest;
        getline(infile, rest); // End the last line
        vector<string> tokens = readTokenLine(infile);
        if (tokens.empty() || tokens[0] != keyword)
        {
            throw std::runtime_error("Expected keyword " + keyword);
        }
    }

    // Parse the specification part (every line before NODE_COORD_SECTION), setting
    // nbCustomers, nbTrucks and truckCapacity. Returns the number of nodes, depot included.
    int readSpecification(std::istream &infile)
    {
        int nbNodes = 0;
        while (true)
        {
            vector<string> tokens = readTokenLine(infile);
            if (tokens.empty())
            {
                // End of file reached before the node coordinates
                throw std::runtime_error("Expected keyword NODE_COORD_SECTION");
            }
            const string &keyword = tokens[0];
            if (keyword == "NODE_COORD_SECTION")
                break;

            bool isUsedKeyword = keyword == "DIMENSION" || keyword == "VEHICLES" ||
                                 keyword == "CAPACITY" || keyword == "EDGE_WEIGHT_TYPE";
            if (!isUsedKeyword)
                continue; // Other specification entries are ignored
            if (tokens.size() < 2)
            {
                throw std::runtime_error("Missing value for keyword " + keyword);
            }

            const string &value = tokens[1];
            if (keyword == "DIMENSION")
            {
                nbNodes = atoi(value.c_str());
                nbCustomers = nbNodes - 1;
            }
            else if (keyword == "VEHICLES")
            {
                nbTrucks = atoi(value.c_str());
            }
            else if (keyword == "CAPACITY")
            {
                truckCapacity = atoi(value.c_str());
            }
            else if (value != "EXACT_2D") // keyword == "EDGE_WEIGHT_TYPE"
            {
                throw std::runtime_error("Only Edge Weight Type EXACT_2D is supported");
            }
        }

        if (nbNodes < 1)
        {
            throw std::runtime_error("Missing or invalid DIMENSION");
        }
        if (nbTrucks < 1)
        {
            throw std::runtime_error("Missing or invalid VEHICLES");
        }
        return nbNodes;
    }

    // Read DEMAND_SECTION: one "id demand" pair per node, in node order. The depot
    // (id 1) must have a zero demand. Returns the demands of the customers only.
    static vector<int> readDemands(std::istream &infile, int nbNodes)
    {
        vector<int> demands(nbNodes - 1);
        for (int n = 1; n <= nbNodes; ++n)
        {
            int id = 0;
            if (!(infile >> id) || id != n)
            {
                throw std::runtime_error("Unexpected index");
            }
            int demand = 0;
            if (!(infile >> demand))
            {
                throw std::runtime_error("Invalid demand");
            }
            if (n == 1)
            {
                if (demand != 0)
                {
                    throw std::runtime_error("Demand for depot should be 0");
                }
            }
            else
            {
                // -2 because original customer indices are in 2..nbNodes
                demands[n - 2] = demand;
            }
        }
        return demands;
    }

    // Read BACKHAUL_SECTION: the ids of the backhaul customers, terminated by -1.
    void readBackhauls(std::istream &infile, int nbNodes)
    {
        isBackhaulData.assign(nbCustomers, 0);
        while (true)
        {
            int id = 0;
            if (!(infile >> id))
            {
                throw std::runtime_error("Expected -1 at the end of BACKHAUL_SECTION");
            }
            if (id == -1)
                break;
            if (id < 2 || id > nbNodes)
            {
                throw std::runtime_error("Invalid backhaul customer index");
            }
            // -2 because original customer indices are in 2..nbNodes
            isBackhaulData[id - 2] = 1;
        }
    }

    // Read DEPOT_SECTION, which must contain the single depot id 1 followed by -1.
    static void readDepot(std::istream &infile)
    {
        int depotId = 0;
        infile >> depotId;
        if (depotId != 1)
        {
            throw std::runtime_error("Depot id is supposed to be 1");
        }

        int endOfDepotSection = 0;
        infile >> endOfDepotSection;
        if (endOfDepotSection != -1)
        {
            throw std::runtime_error("Expecting only one depot, more than one found");
        }
    }

    // Compute the symmetric distance matrix between customers and the customer-depot distances
    void computeDistanceMatrix(int depotX, int depotY, const vector<int> &customersX, const vector<int> &customersY)
    {
        distMatrixData.assign(nbCustomers, vector<int>(nbCustomers, 0));
        for (int i = 0; i < nbCustomers; ++i)
        {
            for (int j = i + 1; j < nbCustomers; ++j)
            {
                int distance = computeDist(customersX[i], customersX[j], customersY[i], customersY[j]);
                distMatrixData[i][j] = distance;
                distMatrixData[j][i] = distance;
            }
        }

        distDepotData.assign(nbCustomers, 0);
        for (int i = 0; i < nbCustomers; ++i)
        {
            distDepotData[i] = computeDist(depotX, customersX[i], depotY, customersY[i]);
        }
    }

    // Euclidean distance between (xi, yi) and (xj, yj), rounded to the nearest integer.
    // Note the parameter order: both x coordinates first, then both y coordinates.
    static int computeDist(int xi, int xj, int yi, int yj)
    {
        double dx = static_cast<double>(xi) - xj;
        double dy = static_cast<double>(yi) - yj;
        double exactDist = sqrt(dx * dx + dy * dy);
        return static_cast<int>(floor(exactDist + 0.5));
    }
};

int main(int argc, char **argv)
{
    if (argc < 2)
    {
        cerr << "Usage: vehicle_routing_backhauls inputFile [outputFile] [timeLimit]" << endl;
        return 1;
    }

    const char *instanceFile = argv[1];
    const char *solFile = argc > 2 ? argv[2] : NULL;
    const char *strTimeLimit = argc > 3 ? argv[3] : "20";

    try
    {
        VehicleRoutingWithBackhauls model;
        model.readInstance(instanceFile);
        model.solve(atoi(strTimeLimit));
        if (solFile != NULL)
            model.writeSolution(solFile);
        return 0;
    }
    catch (const exception &e)
    {
        cerr << "An error occurred: " << e.what() << endl;
        return 1;
    }
}