#include "optimizer/hexalyoptimizer.h"
#include <cerrno>
#include <climits>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

using namespace hexaly;
using namespace std;

class Pdptw {
public:
    // Hexaly Optimizer
    HexalyOptimizer optimizer;

    // Number of customers (every node read after the depot line)
    int nbCustomers = 0;

    // Capacity of the trucks
    int truckCapacity = 0;

    // Latest allowed arrival to depot (due time of the depot line)
    int maxHorizon = 0;

    // Demand for each customer
    vector<int> demandsData;

    // Earliest arrival for each customer
    vector<int> earliestStartData;

    // Latest departure from each customer (due date + service time)
    vector<int> latestEndData;

    // Service time for each customer
    vector<int> serviceTimeData;

    // 0-based index of the pickup partner of each node; -1 marks the node itself as a pickup
    vector<int> pickUpIndex;

    // 0-based index of the delivery partner of each node; -1 when the file gives no delivery partner
    vector<int> deliveryIndex;

    // Distance matrix between customers
    vector<vector<double>> distMatrixData;

    // Distances between customers and depot
    vector<double> distDepotData;

    // Number of trucks
    int nbTrucks = 0;

    // Decision variables: one list of visited customers per truck
    vector<HxExpression> customersSequences;

    // Are the trucks actually used
    vector<HxExpression> trucksUsed;

    // Cumulated lateness in the solution (must be 0 for the solution to be valid)
    HxExpression totalLateness;

    // Number of trucks used in the solution
    HxExpression nbTrucksUsed;

    // Distance traveled by all the trucks
    HxExpression totalDistance;

    Pdptw() {}

    // Read instance data (replaces any previously loaded instance)
    void readInstance(const string& fileName) { readInputPdptw(fileName); }

    /* Build the Hexaly model for the loaded instance and solve it.
     * limit: time limit passed to the optimizer (Hexaly expresses it in seconds). */
    void solve(int limit) {
        // Declare the optimization model
        HxModel model = optimizer.getModel();

        // Sequence of customers visited by each truck
        customersSequences.resize(nbTrucks);
        for (int k = 0; k < nbTrucks; ++k) {
            customersSequences[k] = model.listVar(nbCustomers);
        }

        // All customers must be visited by exactly one truck
        model.constraint(model.partition(customersSequences.begin(), customersSequences.end()));

        // Create Hexaly arrays to be able to access them with "at" operators
        HxExpression demands = model.array(demandsData.begin(), demandsData.end());
        HxExpression earliest = model.array(earliestStartData.begin(), earliestStartData.end());
        HxExpression latest = model.array(latestEndData.begin(), latestEndData.end());
        HxExpression serviceTime = model.array(serviceTimeData.begin(), serviceTimeData.end());
        HxExpression distMatrix = model.array();
        for (int n = 0; n < nbCustomers; ++n) {
            distMatrix.addOperand(model.array(distMatrixData[n].begin(), distMatrixData[n].end()));
        }
        HxExpression distDepot = model.array(distDepotData.begin(), distDepotData.end());

        trucksUsed.resize(nbTrucks);
        vector<HxExpression> distRoutes(nbTrucks), endTime(nbTrucks), homeLateness(nbTrucks), lateness(nbTrucks);

        // Pickups and deliveries: each pickup i and its delivery must be in the same list,
        // with the pickup visited before the delivery
        HxExpression customersSequencesArray = model.array(customersSequences.begin(), customersSequences.end());
        for (int i = 0; i < nbCustomers; ++i) {
            if (pickUpIndex[i] == -1) {
                HxExpression pickUpListIndex = model.find(customersSequencesArray, i);
                HxExpression deliveryListIndex = model.find(customersSequencesArray, deliveryIndex[i]);
                model.constraint(pickUpListIndex == deliveryListIndex);
                HxExpression pickupList = model.at(customersSequencesArray, pickUpListIndex);
                HxExpression deliveryList = model.at(customersSequencesArray, deliveryListIndex);
                model.constraint(model.indexOf(pickupList, i) < model.indexOf(deliveryList, deliveryIndex[i]));
            }
        }

        for (int k = 0; k < nbTrucks; ++k) {
            HxExpression sequence = customersSequences[k];
            HxExpression c = model.count(sequence);

            // A truck is used if it visits at least one customer
            trucksUsed[k] = c > 0;

            // The quantity needed in each route must not exceed the truck capacity at any point in the sequence
            HxExpression demandLambda = model.createLambdaFunction(
                [&](HxExpression i, HxExpression prev) { return prev + demands[sequence[i]]; });
            HxExpression routeQuantity = model.array(model.range(0, c), demandLambda, 0);

            HxExpression quantityLambda =
                model.createLambdaFunction([&](HxExpression i) { return routeQuantity[i] <= truckCapacity; });
            model.constraint(model.and_(model.range(0, c), quantityLambda));

            // Distance traveled by truck k
            HxExpression distLambda = model.createLambdaFunction(
                [&](HxExpression i) { return model.at(distMatrix, sequence[i - 1], sequence[i]); });
            distRoutes[k] = model.sum(model.range(1, c), distLambda) +
                            model.iif(c > 0, distDepot[sequence[0]] + distDepot[sequence[c - 1]], 0);

            // End of each visit: start at max(earliest, arrival) then add the service time
            HxExpression endLambda = model.createLambdaFunction([&](HxExpression i, HxExpression prev) {
                return model.max(earliest[sequence[i]],
                                 model.iif(i == 0, distDepot[sequence[0]],
                                           prev + model.at(distMatrix, sequence[i - 1], sequence[i]))) +
                       serviceTime[sequence[i]];
            });

            endTime[k] = model.array(model.range(0, c), endLambda, 0);

            // Arriving home after max_horizon
            homeLateness[k] =
                model.iif(trucksUsed[k], model.max(0, endTime[k][c - 1] + distDepot[sequence[c - 1]] - maxHorizon), 0);

            // Completing visit after latest_end
            HxExpression lateLambda = model.createLambdaFunction(
                [&](HxExpression i) { return model.max(0, endTime[k][i] - latest[sequence[i]]); });
            lateness[k] = homeLateness[k] + model.sum(model.range(0, c), lateLambda);
        }

        // Total lateness
        totalLateness = model.sum(lateness.begin(), lateness.end());

        // Total number of trucks used
        nbTrucksUsed = model.sum(trucksUsed.begin(), trucksUsed.end());

        // Total distance traveled (rounded to two decimals)
        totalDistance = model.round(100 * model.sum(distRoutes.begin(), distRoutes.end())) / 100;

        // Objectives, in priority order: minimize the total lateness (0 for a valid solution),
        // then the number of trucks used, then the distance traveled
        model.minimize(totalLateness);
        model.minimize(nbTrucksUsed);
        model.minimize(totalDistance);

        model.close();

        // Parameterize the optimizer
        optimizer.getParam().setTimeLimit(limit);

        optimizer.solve();
    }

    /* Write the solution in a file with the following format:
     *  - number of trucks used and total distance
     *  - for each truck the customers visited (omitting the start/end at the depot) */
    void writeSolution(const string& fileName) {
        ofstream outfile;
        outfile.exceptions(ofstream::failbit | ofstream::badbit);
        outfile.open(fileName.c_str());

        outfile << nbTrucksUsed.getValue() << " " << totalDistance.getDoubleValue() << '\n';
        for (int k = 0; k < nbTrucks; ++k) {
            if (trucksUsed[k].getValue() != 1)
                continue;
            // Values in sequence are 0-based customer indices (0...nbCustomers-1). +1 puts them back
            // into the 1-based ids of the data file (the reader converts file ids with "- 1")
            HxCollection customersCollection = customersSequences[k].getCollectionValue();
            for (int i = 0; i < customersCollection.count(); ++i) {
                outfile << customersCollection[i] + 1 << " ";
            }
            outfile << '\n';
        }
        // Close explicitly so that a failure of the final flush raises (exceptions are enabled above)
        outfile.close();
    }

private:
    /* The input files follow the "Li & Lim" format:
     *  - header: number of trucks, truck capacity, one ignored value
     *  - depot line: id, x, y, demand, ready, due (used as horizon), service, pickup, delivery
     *  - one line per customer with the same nine fields
     * Throws std::runtime_error if the file cannot be opened or is malformed. */
    void readInputPdptw(const string& fileName) {
        ifstream infile(fileName.c_str());
        if (!infile.is_open()) {
            throw std::runtime_error("File cannot be opened.");
        }

        // Start from an empty instance so that reading twice does not append to previous data
        demandsData.clear();
        earliestStartData.clear();
        latestEndData.clear();
        serviceTimeData.clear();
        pickUpIndex.clear();
        deliveryIndex.clear();

        long dump;
        int depotX, depotY;
        vector<int> customersX;
        vector<int> customersY;

        // Header line, then depot line (only its coordinates and due time are used)
        infile >> nbTrucks >> truckCapacity >> dump;
        infile >> dump >> depotX >> depotY >> dump >> dump >> maxHorizon >> dump >> dump >> dump;
        if (!infile || nbTrucks <= 0) {
            throw std::runtime_error("Invalid instance header.");
        }

        long nodeId;
        while (infile >> nodeId) {
            int cx, cy, demand, ready, due, service, pick, delivery;
            if (!(infile >> cx >> cy >> demand >> ready >> due >> service >> pick >> delivery)) {
                throw std::runtime_error("Truncated or malformed customer line in instance file.");
            }

            customersX.push_back(cx);
            customersY.push_back(cy);
            demandsData.push_back(demand);
            earliestStartData.push_back(ready);
            latestEndData.push_back(due + service); // in input files due date is meant as latest start time
            serviceTimeData.push_back(service);
            pickUpIndex.push_back(pick - 1);
            deliveryIndex.push_back(delivery - 1);
        }
        // The loop must end at end of file, not on a non-numeric token
        if (!infile.eof()) {
            throw std::runtime_error("Unexpected token in instance file.");
        }

        nbCustomers = static_cast<int>(customersX.size());

        // solve() looks up the delivery partner of every pickup node: it must be a valid index
        for (int i = 0; i < nbCustomers; ++i) {
            if (pickUpIndex[i] == -1 && (deliveryIndex[i] < 0 || deliveryIndex[i] >= nbCustomers)) {
                throw std::runtime_error("Invalid delivery index for a pickup node in instance file.");
            }
        }

        computeDistanceMatrix(depotX, depotY, customersX, customersY);

        infile.close();
    }

    // Compute the symmetric customer distance matrix and the customer-to-depot distances
    void computeDistanceMatrix(int depotX, int depotY, const vector<int>& customersX, const vector<int>& customersY) {
        // Zero diagonal; fill the upper triangle and mirror it
        distMatrixData.assign(nbCustomers, vector<double>(nbCustomers, 0.0));
        for (int i = 0; i < nbCustomers; ++i) {
            for (int j = i + 1; j < nbCustomers; ++j) {
                double distance = computeDist(customersX[i], customersX[j], customersY[i], customersY[j]);
                distMatrixData[i][j] = distance;
                distMatrixData[j][i] = distance;
            }
        }

        distDepotData.assign(nbCustomers, 0.0);
        for (int i = 0; i < nbCustomers; ++i) {
            distDepotData[i] = computeDist(depotX, customersX[i], depotY, customersY[i]);
        }
    }

    // Euclidean distance between (xi, yi) and (xj, yj); note the argument order: both x's, then both y's
    double computeDist(int xi, int xj, int yi, int yj) {
        double dx = static_cast<double>(xi) - xj;
        double dy = static_cast<double>(yi) - yj;
        return sqrt(dx * dx + dy * dy);
    }
};

/* Entry point: pdptw inputFile [outputFile] [timeLimit]
 * timeLimit defaults to 20 and must be a non-negative integer.
 * Returns 0 on success, 1 on usage errors or when an exception is raised. */
int main(int argc, char** argv) {
    if (argc < 2) {
        cerr << "Usage: pdptw inputFile [outputFile] [timeLimit]" << endl;
        return 1;
    }

    const char* instanceFile = argv[1];
    const char* solFile = argc > 2 ? argv[2] : nullptr;
    const char* strTimeLimit = argc > 3 ? argv[3] : "20";

    // Parse the time limit strictly: atoi() would silently turn "abc" into 0 and "20s" into 20
    char* parseEnd = nullptr;
    errno = 0;
    const long timeLimit = strtol(strTimeLimit, &parseEnd, 10);
    if (parseEnd == strTimeLimit || *parseEnd != '\0' || errno == ERANGE || timeLimit < 0 ||
        timeLimit > INT_MAX) {
        cerr << "Invalid time limit: " << strTimeLimit << endl;
        cerr << "Usage: pdptw inputFile [outputFile] [timeLimit]" << endl;
        return 1;
    }

    try {
        Pdptw model;
        model.readInstance(instanceFile);
        model.solve(static_cast<int>(timeLimit));
        if (solFile != nullptr)
            model.writeSolution(solFile);
        return 0;
    } catch (const exception& e) {
        cerr << "An error occurred: " << e.what() << endl;
        return 1;
    }
}
