#include "optimizer/hexalyoptimizer.h"
#include <fstream>
#include <iostream>
#include <json.hpp>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

using namespace hexaly;
using namespace std;
using json = nlohmann::json;

class Darp {
public:
    HexalyOptimizer optimizer;

    int nbClients;
    int nbNodes;
    int nbVehicles;
    double factor;
    double depotTwEnd;
    int capacity;
    double scale;

    vector<int> quantitiesData;
    vector<vector<double>> distances;
    vector<double> startsData;
    vector<double> endsData;
    vector<double> loadingTimesData;
    vector<double> maxTravelTimes;
    vector<double> distanceWarehouseData;
    vector<double> timeWarehouseData;
    vector<vector<double>> distanceMatrixData;
    vector<vector<double>> timeMatrixData;

    vector<HxExpression> routes;
    vector<HxExpression> depotStarts;
    vector<HxExpression> waiting;

    HxExpression totalLateness;
    HxExpression totalClientLateness;
    HxExpression totalDistance;

    void readInstance(const string& fileName) {
        ifstream infile;
        infile.exceptions(ifstream::failbit | ifstream::badbit);
        infile.open(fileName.c_str());

        json instance;
        instance << infile;
        infile.close();

        nbClients = instance["nbClients"];
        nbNodes = instance["nbNodes"];
        nbVehicles = instance["nbVehicles"];
        depotTwEnd = instance["depot"]["twEnd"];
        capacity = instance["capacity"];
        scale = instance["scale"];

        quantitiesData.resize(2 * nbClients);
        startsData.resize(2 * nbClients);
        endsData.resize(2 * nbClients);
        loadingTimesData.resize(2 * nbClients);
        maxTravelTimes.resize(2 * nbClients);

        for (int k = 0; k < nbClients; ++k) {
            quantitiesData[k] = instance["clients"][k]["nbClients"].get<int>();
            quantitiesData[k + nbClients] = -instance["clients"][k]["nbClients"].get<int>();

            startsData[k] = instance["clients"][k]["pickup"]["start"].get<double>();
            startsData[k + nbClients] = instance["clients"][k]["delivery"]["start"].get<double>();

            endsData[k] = instance["clients"][k]["pickup"]["end"].get<double>();
            endsData[k + nbClients] = instance["clients"][k]["delivery"]["end"].get<double>();

            loadingTimesData[k] = instance["clients"][k]["pickup"]["loadingTime"].get<double>();
            loadingTimesData[k + nbClients] = instance["clients"][k]["delivery"]["loadingTime"].get<double>();

            maxTravelTimes[k] = instance["clients"][k]["pickup"]["maxTravelTime"].get<double>();
            maxTravelTimes[k + nbClients] = instance["clients"][k]["delivery"]["maxTravelTime"].get<double>();
        }

        distances.resize(nbNodes + 1, vector<double>(nbNodes + 1));
        for (int i = 0; i < nbNodes + 1; ++i) {
            for (int j = 0; j < nbNodes + 1; ++j) {
                distances[i][j] = instance["distanceMatrix"][i][j].get<double>();
            }
        }

        factor = 1.0 / (instance["scale"].get<double>() * instance["speed"].get<double>());

        distanceWarehouseData.resize(nbNodes);
        timeWarehouseData.resize(nbNodes);
        for (int k = 0; k < nbNodes; ++k) {
            distanceWarehouseData[k] = distances[0][k+1];
            timeWarehouseData[k] = distanceWarehouseData[k] * factor;
        }

        distanceMatrixData.resize(nbNodes, vector<double>(nbNodes));
        timeMatrixData.resize(nbNodes, vector<double>(nbNodes));
        for (int i = 0; i < nbNodes; ++i) {
            for (int j = 0; j < nbNodes; ++j){
                distanceMatrixData[i][j] = distances[i+1][j+1];
                timeMatrixData[i][j] = distanceMatrixData[i][j] * factor;
            }
        }
    }
    
    void solve(int limit) {
        HxModel model = optimizer.getModel();

        // routes[k] represents the nodes visited by vehicle k
        routes.resize(nbVehicles);
        depotStarts.resize(nbVehicles);
        for (int k = 0; k < nbVehicles; ++k) {
            routes[k] = model.listVar(nbNodes);
            depotStarts[k] = model.floatVar(0.0, depotTwEnd);
        }
        // waiting[k] is the waiting time at node k
        waiting.resize(nbNodes);
        for (int k = 0; k < nbNodes; ++k) {
            waiting[k] = model.floatVar(0.0, depotTwEnd);
        }
        // Each node is taken by one vehicle
        model.constraint(model.partition(routes.begin(), routes.end()));

        HxExpression quantities = model.array(quantitiesData.begin(), quantitiesData.end());
        HxExpression timeWarehouse = model.array(timeWarehouseData.begin(), timeWarehouseData.end());
        HxExpression timeMatrix = model.array();
        for (int i = 0; i < nbNodes; ++i) {
            timeMatrix.addOperand(model.array(timeMatrixData[i].begin(), timeMatrixData[i].end()));
        }
        HxExpression loadingTimes = model.array(loadingTimesData.begin(), loadingTimesData.end());
        HxExpression starts = model.array(startsData.begin(), startsData.end());
        HxExpression ends = model.array(endsData.begin(), endsData.end());
        HxExpression waitingArray = model.array(waiting.begin(), waiting.end());
        HxExpression distanceMatrix = model.array();
        for (int i = 0; i < nbNodes; ++i) {
            distanceMatrix.addOperand(model.array(distanceMatrixData[i].begin(), distanceMatrixData[i].end()));
        }
        HxExpression distanceWarehouse = model.array(distanceWarehouseData.begin(), distanceWarehouseData.end());
        vector<HxExpression> times(nbVehicles);
        vector<HxExpression> lateness(nbVehicles);
        vector<HxExpression> homeLateness(nbVehicles);
        vector<HxExpression> routeDistances(nbVehicles);

        for (int k = 0; k < nbVehicles; ++k) {
            HxExpression route = routes[k];
            HxExpression c = model.count(route);

            HxExpression demandLambda = model.createLambdaFunction(
                [&](HxExpression i, HxExpression prev) { return prev + quantities[route[i]]; });
            // routeQuantities[k][i] indicates the number of clients in vehicle k
            // at its i-th taken node
            HxExpression routeQuantities = model.array(model.range(0, c), demandLambda);
            HxExpression quantityLambda = model.createLambdaFunction(
                [&](HxExpression i) { return (routeQuantities[i] <= capacity); });
            // Vehicles have a maximum capacity
            model.constraint(model.and_(model.range(0, c), quantityLambda));

            HxExpression timesLambda = model.createLambdaFunction(
                [&](HxExpression i, HxExpression prev) {
                    return model.max(starts[route[i]], model.iif(
                        i == 0,
                        depotStarts[k] + timeWarehouse[route[0]],
                        prev + timeMatrix[route[i-1]][route[i]]
                    )) + waitingArray[route[i]] + loadingTimes[route[i]];
                }
            );
            // times[k][i] is the time at which vehicle k leaves the i-th node
            // (after waiting and loading time at node i)
            times[k] = model.array(model.range(0, c), timesLambda);

            HxExpression latenessLambda = model.createLambdaFunction(
                [&](HxExpression i) {
                    return model.max(
                        0,
                        times[k][i] - loadingTimes[route[i]] - ends[route[i]]
                    );
                }
            );
            // Total lateness of the k-th route
            lateness[k] = model.sum(model.range(0, c), latenessLambda);

            homeLateness[k] = model.iif(
                c > 0,
                model.max(0, times[k][c-1] + timeWarehouse[route[c-1]] - depotTwEnd),
                0
            );

            HxExpression routeDistLambda = model.createLambdaFunction(
                [&](HxExpression i) { return distanceMatrix[route[i-1]][route[i]]; });
            routeDistances[k] = model.sum(model.range(1, c), routeDistLambda)
                + model.iif(
                    c > 0,
                    distanceWarehouse[route[0]] + distanceWarehouse[route[c-1]],
                    0
                );
        }

        HxExpression routesArray = model.array(routes.begin(), routes.end());
        HxExpression timesArray = model.array(times.begin(), times.end());
        vector<HxExpression> clientLateness(nbClients);

        for (int k = 0; k < nbClients; ++k) {
            // For each pickup node k, its associated delivery node is k + instance.nbClients
            HxExpression pickupListIndex = model.find(routesArray, k);
            HxExpression deliveryListIndex = model.find(routesArray, k + nbClients);
            // A client picked up in route i is delivered in route i
            model.constraint(pickupListIndex == deliveryListIndex);

            HxExpression clientList = routesArray[pickupListIndex];
            HxExpression pickupIndex = model.indexOf(clientList, k);
            HxExpression deliveryList = routesArray[deliveryListIndex];
            HxExpression deliveryIndex = model.indexOf(deliveryList, k + nbClients);
            // Pickup before delivery
            model.constraint(pickupIndex < deliveryIndex);

            HxExpression pickupTime = timesArray[pickupListIndex][pickupIndex];
            HxExpression deliveryTime =
                timesArray[deliveryListIndex][deliveryIndex] - loadingTimes[k + nbClients];
            HxExpression travelTime = deliveryTime - pickupTime;
            clientLateness[k] = model.max(travelTime - maxTravelTimes[k], 0);
        }

        vector<HxExpression> latenessPlusHomeLateness(nbVehicles);
        for (int k = 0; k < nbVehicles; ++k) {
            latenessPlusHomeLateness[k] = lateness[k] + homeLateness[k];
        }

        totalLateness = model.sum(latenessPlusHomeLateness.begin(), latenessPlusHomeLateness.end());
        totalClientLateness = model.sum(clientLateness.begin(), clientLateness.end());
        totalDistance = model.sum(routeDistances.begin(), routeDistances.end());

        model.minimize(totalLateness);
        model.minimize(totalClientLateness);
        model.minimize(totalDistance / scale);

        model.close();

        optimizer.getParam().setTimeLimit(limit);

        optimizer.getParam().setSeed(5);
        optimizer.solve();
    }

    /* Write the solution in a file with the following format:
     *  - total lateness on the routes, total client lateness, total distance
     *  - for each vehicle, the depot start time, the nodes visited (omitting the start/end at the
     * depot), and the waiting time at each node */
    void writeSolution(const string& fileName) {
        ofstream outfile;
        outfile.exceptions(ofstream::failbit | ofstream::badbit);
        outfile.open(fileName.c_str());

        outfile << totalLateness.getDoubleValue() << " " << totalClientLateness.getDoubleValue()
            << " " << totalDistance.getDoubleValue() << endl;
        for (int k = 0; k < nbVehicles; ++k) {
            HxCollection route = routes[k].getCollectionValue();
            outfile << "Vehicle " << k << " (" << depotStarts[k].getDoubleValue() << "): ";
            for (int i = 0; i < route.count(); ++i) {
                outfile << route[i] << " (" << waiting[route[i]].getDoubleValue() << "), ";
            }
            outfile << endl;
        }
    }
};

int main(int argc, char** argv) {
    if (argc < 2) {
        cerr << "Usage: darp inputFile [outputFile] [timeLimit] " << endl;
        return 1;
    }

    const char* instanceFile = argv[1];
    const char* solFile = argc > 2 ? argv[2] : NULL;
    const char* strTimeLimit = argc > 3 ? argv[3] : "20";
    try {
        Darp model;
        model.readInstance(instanceFile);
        model.solve(atoi(strTimeLimit));
        if (solFile != NULL){
            model.writeSolution(solFile);
        }
        return 0;
    } catch (const exception& e) {
        cerr << "An error occurred: " << e.what() << endl;
        return 1;
    }

}