#include "optimizer/hexalyoptimizer.h"
#include <cerrno>
#include <climits>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

using namespace hexaly;
using namespace std;

// Prize-collecting capacitated vehicle routing model.
//
// Usage: readInstance() loads the data, solve() builds and optimizes the model,
// writeSolution() dumps the values found by the optimizer into a text file.
class Pcvrp {
public:
    // Hexaly Optimizer
    HexalyOptimizer optimizer;

    // Number of customers (the depot is not counted)
    int nbCustomers;

    // Capacity of the trucks
    int truckCapacity;

    // Minimum total demand that the routes must satisfy
    int demandsToSatisfy;

    // Demand on each customer
    vector<int> demandsData;

    // Prize on each customer
    vector<int> prizesData;

    // Distance matrix between customers
    vector<vector<int>> distMatrixData;

    // Distances between customers and depot
    vector<int> distDepotData;

    // Number of trucks
    int nbTrucks;

    // Decision variables: one list of customers per truck
    vector<HxExpression> customersSequences;

    // Are the trucks actually used
    vector<HxExpression> trucksUsed;

    // Number of trucks used in the solution
    HxExpression nbTrucksUsed;

    // Distance traveled by all the trucks
    HxExpression totalDistance;

    // Total prize of the solution
    HxExpression totalPrize;

    // Total nb demands satisfied in the solution
    HxExpression totalQuantity;

    /* Read instance data.
     * Throws std::runtime_error if the file cannot be opened or is malformed. */
    void readInstance(const string& fileName) {
        readInputPcvrp(fileName);
    }

    /* Build the optimization model from the instance data and run the optimizer.
     * `limit` is forwarded as is to the optimizer's time limit parameter. */
    void solve(int limit) {
        // Declare the optimization model
        HxModel model = optimizer.getModel();

        trucksUsed.resize(nbTrucks);
        customersSequences.resize(nbTrucks);
        vector<HxExpression> distRoutes(nbTrucks);
        vector<HxExpression> routeQuantities(nbTrucks);
        vector<HxExpression> routePrizes(nbTrucks);

        // Sequence of customers visited by each truck
        for (int k = 0; k < nbTrucks; ++k) {
            customersSequences[k] = model.listVar(nbCustomers);
        }

        // A customer is visited by at most one truck (customers may stay unvisited)
        model.constraint(model.disjoint(customersSequences.begin(), customersSequences.end()));

        // Create Hexaly arrays to be able to access them with an "at" operator
        HxExpression demands = model.array(demandsData.begin(), demandsData.end());
        HxExpression prizes = model.array(prizesData.begin(), prizesData.end());
        HxExpression distMatrix = model.array();
        for (int n = 0; n < nbCustomers; ++n) {
            distMatrix.addOperand(model.array(distMatrixData[n].begin(), distMatrixData[n].end()));
        }
        HxExpression distDepot = model.array(distDepotData.begin(), distDepotData.end());

        for (int k = 0; k < nbTrucks; ++k) {
            HxExpression sequence = customersSequences[k];
            HxExpression c = model.count(sequence);

            // A truck is used if it visits at least one customer
            trucksUsed[k] = c > 0;

            // The quantity needed in each route must not exceed the truck capacity
            HxExpression demandLambda =
                model.createLambdaFunction([&](HxExpression j) { return demands[j]; });
            routeQuantities[k] = model.sum(sequence, demandLambda);
            model.constraint(routeQuantities[k] <= truckCapacity);

            // Distance traveled by truck k: consecutive customers, plus the legs
            // from and to the depot when the route is not empty
            HxExpression distLambda = model.createLambdaFunction(
                [&](HxExpression i) { return model.at(distMatrix, sequence[i - 1], sequence[i]); });
            distRoutes[k] = model.sum(model.range(1, c), distLambda) +
                            model.iif(c > 0, distDepot[sequence[0]] + distDepot[sequence[c - 1]], 0);

            // Route prize of truck k
            HxExpression prizeLambda =
                model.createLambdaFunction([&](HxExpression j) { return prizes[j]; });
            routePrizes[k] = model.sum(sequence, prizeLambda);
        }

        // Total nb demands satisfied
        totalQuantity = model.sum(routeQuantities.begin(), routeQuantities.end());

        // At least demandsToSatisfy units of demand must be served overall
        model.constraint(totalQuantity >= demandsToSatisfy);

        // Total nb trucks used
        nbTrucksUsed = model.sum(trucksUsed.begin(), trucksUsed.end());

        // Total distance traveled
        totalDistance = model.sum(distRoutes.begin(), distRoutes.end());

        // Total prize
        totalPrize = model.sum(routePrizes.begin(), routePrizes.end());

        // Objectives, declared in this order: minimize the number of trucks used,
        // then maximize the total prize, then minimize the distance traveled
        model.minimize(nbTrucksUsed);
        model.maximize(totalPrize);
        model.minimize(totalDistance);

        model.close();

        // Parametrize the optimizer
        optimizer.getParam().setTimeLimit(limit);

        optimizer.solve();
    }

    /* Write the solution in a file with the following format:
     * - total prize, number of trucks used and total distance
     * - for each truck the customers visited (omitting the start/end at the depot)
     * - number of unvisited customers, demands satisfied
     * The stream is opened with exceptions enabled, so an I/O failure throws. */
    void writeSolution(const string& fileName) {
        ofstream outfile;
        outfile.exceptions(ofstream::failbit | ofstream::badbit);
        outfile.open(fileName.c_str());

        outfile << totalPrize.getValue() << " " << nbTrucksUsed.getValue() << " " << totalDistance.getValue() << endl;
        int nbUnvisitedCustomers = nbCustomers;
        for (int k = 0; k < nbTrucks; ++k) {
            // Unused trucks produce no line in the output
            if (trucksUsed[k].getValue() != 1)
                continue;
            // Values in sequence are in 0...nbCustomers-1. +1 is to put it back in 1...nbCustomers
            // as in the data files (0 being the depot)
            HxCollection customersCollection = customersSequences[k].getCollectionValue();
            for (int i = 0; i < customersCollection.count(); ++i) {
                outfile << customersCollection[i] + 1 << " ";
                --nbUnvisitedCustomers;
            }
            outfile << endl;
        }
        outfile << nbUnvisitedCustomers << " " << totalQuantity.getValue();
    }

private:
    // The input files follow the "longjianyu" format:
    // a header (number of trucks, truck capacity, demand to satisfy) followed by
    // records indexed from 0. Record 0 is the depot (id x y); every other record
    // is a customer (id x y demand prize).
    // Throws std::runtime_error if the file cannot be opened, if a record is
    // incomplete, if the indices are not consecutive, or if there is no depot.
    void readInputPcvrp(const string& fileName) {
        ifstream infile(fileName.c_str());
        if (!infile.is_open()) {
            throw std::runtime_error("File cannot be opened.");
        }

        // A failed extraction would otherwise leave these members unset
        if (!(infile >> nbTrucks >> truckCapacity >> demandsToSatisfy)) {
            throw std::runtime_error("Invalid instance header.");
        }

        // Start from a clean state so that reading a second instance
        // does not append to the data of the first one
        demandsData.clear();
        prizesData.clear();

        int n = 0;
        vector<int> customersX, customersY;
        int depotX = 0, depotY = 0;
        int id;
        // Reading stops at the first token that is not an integer index (normally end of file)
        while (infile >> id) {
            if (id != n) {
                throw std::runtime_error("Unexpected index");
            }
            if (n == 0) {
                if (!(infile >> depotX >> depotY)) {
                    throw std::runtime_error("Invalid depot record.");
                }
            } else {
                int x, y, demand, prize;
                if (!(infile >> x >> y >> demand >> prize)) {
                    throw std::runtime_error("Invalid customer record.");
                }
                customersX.push_back(x);
                customersY.push_back(y);
                demandsData.push_back(demand);
                prizesData.push_back(prize);
            }
            ++n;
        }

        // Without a depot record, nbCustomers would become -1
        if (n == 0) {
            throw std::runtime_error("No depot found in the instance.");
        }

        nbCustomers = n - 1;

        computeDistanceMatrix(depotX, depotY, customersX, customersY);
        // infile is closed by its destructor
    }

    // Compute the symmetric customer-to-customer distance matrix and the
    // customer-to-depot distances from the coordinates.
    // customersX and customersY must hold nbCustomers entries each.
    void computeDistanceMatrix(int depotX, int depotY, const vector<int>& customersX, const vector<int>& customersY) {
        // assign() also resets any content left by a previous instance; the diagonal stays at 0
        distMatrixData.assign(nbCustomers, vector<int>(nbCustomers, 0));
        for (int i = 0; i < nbCustomers; ++i) {
            for (int j = i + 1; j < nbCustomers; ++j) {
                const int distance = computeDist(customersX[i], customersX[j], customersY[i], customersY[j]);
                distMatrixData[i][j] = distance;
                distMatrixData[j][i] = distance;
            }
        }

        distDepotData.assign(nbCustomers, 0);
        for (int i = 0; i < nbCustomers; ++i) {
            distDepotData[i] = computeDist(depotX, customersX[i], depotY, customersY[i]);
        }
    }

    // Euclidean distance between (xi, yi) and (xj, yj), rounded to the nearest integer.
    // The differences are computed in double so that they cannot overflow int.
    static int computeDist(int xi, int xj, int yi, int yj) {
        const double dx = static_cast<double>(xi) - xj;
        const double dy = static_cast<double>(yi) - yj;
        const double exactDist = std::sqrt(dx * dx + dy * dy);
        return static_cast<int>(std::floor(exactDist + 0.5));
    }
};

// Entry point: pcvrp inputFile [outputFile] [timeLimit]
// - inputFile: instance to read
// - outputFile: optional, the solution is written there when given
// - timeLimit: optional integer forwarded to Pcvrp::solve (defaults to 20)
// Returns 0 on success, 1 on a usage error, an invalid time limit or any exception.
int main(int argc, char** argv) {
    if (argc < 2) {
        cerr << "Usage: pcvrp inputFile [outputFile] [timeLimit]" << endl;
        return 1;
    }

    const char* instanceFile = argv[1];
    const char* solFile = argc > 2 ? argv[2] : nullptr;
    const char* strTimeLimit = argc > 3 ? argv[3] : "20";

    // atoi() would silently map a non-numeric argument to 0 and has undefined
    // behavior on overflow, so parse with strtol and check the whole string
    char* parseEnd = nullptr;
    errno = 0;
    const long timeLimit = strtol(strTimeLimit, &parseEnd, 10);
    const bool parsedSomething = parseEnd != strTimeLimit;
    const bool parsedEverything = parsedSomething && *parseEnd == '\0';
    const bool fitsInInt = errno != ERANGE && timeLimit >= INT_MIN && timeLimit <= INT_MAX;
    if (!parsedEverything || !fitsInInt) {
        cerr << "Invalid time limit: " << strTimeLimit << endl;
        return 1;
    }

    try {
        Pcvrp model;
        model.readInstance(instanceFile);
        model.solve(static_cast<int>(timeLimit));
        if (solFile != nullptr)
            model.writeSolution(solFile);
        return 0;
    } catch (const exception& e) {
        cerr << "An error occurred: " << e.what() << endl;
        return 1;
    }
}
