#include "optimizer/hexalyoptimizer.h"
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

using namespace hexaly;
using namespace std;

class MultiTripVRP {
public:
    // Hexaly Optimizer
    HexalyOptimizer optimizer;

    // Number of customers
    int nbCustomers;

    // Number of depots
    int nbDepots;

    // Number of depot copies
    int nbDepotCopies;

    // Total number of locations (customers, depots, depots copies)
    int nbTotalLocations;

    // Capacity of the trucks
    int truckCapacity;

    // Demand on each customer
    vector<int> demandsData;

    // Distance matrix between customers
    vector<vector<int>> distMatrixData;

    // Number of trucks
    int nbTrucks;

    // Maximum distance traveled by a truck
    int maxDist;

    // Decision variables
    vector<HxExpression> visitOrders;

    // Are the trucks actually used
    vector<HxExpression> trucksUsed;

    // Distance traveled by all the trucks
    HxExpression totalDistance;

    void readInstance(const char* fileName) { readInputMultiTripVRP(fileName); }

    void solve(int limit) {
        // Declare the optimization model
        HxModel model = optimizer.getModel();

        // Locations visited by each truck (Customers and Depots)
        // Add copies of the depots (so that they can be visited multiple times)
        // Add an extra fictive truck (who will visit every depot that will not be visited by real trucks)
        visitOrders.resize(nbTrucks + 1);
        for (int k = 0; k < nbTrucks + 1; ++k) {
            visitOrders[k] = model.listVar(nbTotalLocations);
        }

        // The fictive truck cannot visit customers
        for(int i=0; i < nbCustomers; i++){
            model.constraint(!(model.contains(visitOrders[nbTrucks],i)));
        }

        // All customers must be visited by exactly one truck
        model.constraint(model.partition(visitOrders.begin(), visitOrders.end()));

        // Create Hexaly arrays to be able to access them with an "at" operator
        HxExpression demands = model.array(demandsData.begin(), demandsData.end());
        HxExpression distMatrix = model.array();
        for (int n = 0; n < nbCustomers + nbDepots * nbDepotCopies; ++n) {
            distMatrix.addOperand(model.array(distMatrixData[n].begin(), distMatrixData[n].end()));
        }

        trucksUsed.resize(nbTrucks);
        vector<HxExpression> distRoutes(nbTrucks);

        for (int k = 0; k < nbTrucks; ++k) {
            HxExpression sequence = visitOrders[k];
            HxExpression c = model.count(sequence);

            // A truck is used if it visits at least one customer
            trucksUsed[k] = c > 0;

            // Compute the quantity in the truck at each step
            HxExpression routeQuantityLambda =
                model.createLambdaFunction([&](HxExpression i, HxExpression prev) { return model.iif(sequence[i] < nbCustomers, prev+demands[sequence[i]],0); });
            HxExpression routeQuantity = model.array(model.range(0, c), routeQuantityLambda, 0);
            // Trucks cannot carry more than their capacity
            HxExpression quantityLambda = model.createLambdaFunction([&](HxExpression i){ return routeQuantity[i]<=truckCapacity; });
            model.constraint(model.and_(model.range(0,c), quantityLambda));

            // Distance traveled by truck k
            HxExpression distLambda = model.createLambdaFunction(
                [&](HxExpression i) { return model.at(distMatrix, sequence[i - 1], sequence[i]); });
            distRoutes[k] = model.sum(model.range(1, c), distLambda) +
                            model.iif(c > 0,
                                model.at(distMatrix,nbCustomers,sequence[0]) +
                                model.at(distMatrix,sequence[c-1],nbCustomers),
                                0);
            
            // Limit distance traveled
            model.constraint(distRoutes[k] <= maxDist);
        }

        // Total distance traveled
        totalDistance = model.sum(distRoutes.begin(), distRoutes.end());

        // Objective: minimize the distance traveled
        model.minimize(totalDistance);

        model.close();

        // Parametrize the optimizer
        optimizer.getParam().setTimeLimit(limit);

        optimizer.solve();
    }

    /* Write the solution in a file */
    void writeSolution(const char* inFile, const string& solFile) {
        ofstream file;
        file.exceptions(ofstream::failbit | ofstream::badbit);
        file.open(solFile.c_str());
        file << "File name: " << inFile << "; total distance = " << totalDistance.getValue() << endl;
        for (int r = 0; r < nbTrucks; ++r) {
            if (trucksUsed[r].getValue()) {
                HxCollection visitCollection = visitOrders[r].getCollectionValue();
                    file << "Truck " << r << " : " ;
                for (hxint i = 0; i < visitCollection.count(); ++i) {
                    file << (visitCollection[i]<nbCustomers ? visitCollection[i] : -(floor((visitCollection[i]-nbCustomers)/nbDepotCopies) + 1)) << " ";
                }
                file << endl;
            }
        }
    }

private:

    void readInputMultiTripVRP(const char* fileName) {
        vector<int> customersX, customersY, depotsX, depotsY;
        string file = fileName;
        ifstream infile(file.c_str());
        if (!infile.is_open()) {
            throw std::runtime_error("File cannot be opened.");
        }
        infile >> nbCustomers;
        customersX.resize(nbCustomers);
        customersY.resize(nbCustomers);
        demandsData.resize(nbCustomers);
        distMatrixData.resize(nbTotalLocations);
        infile >> nbDepots;
        depotsX.resize(nbDepots);
        depotsY.resize(nbDepots);
        for (int i = 0; i < nbDepots; ++i) {
            infile >> depotsX[i];
            infile >> depotsY[i];
        }
        for (int i = 0; i < nbCustomers; ++i) {
            infile >> customersX[i];
            infile >> customersY[i];
        }
        infile >> truckCapacity;
        truckCapacity /= 2;
        // Ignore depot infos
        int temp;
        for (int i = 0; i < nbDepots; ++i) {
            infile>>temp;
        }
        for (int i = 0; i < nbCustomers; ++i) {
            infile >> demandsData[i];
        }

        nbDepotCopies = 20;

        nbTotalLocations = nbCustomers + nbDepots*nbDepotCopies;

        maxDist = 400;
            
        nbTrucks = 3;

        computeDistanceMatrix(depotsX, depotsY, customersX, customersY);
    }

    // Compute the distance matrix
    void computeDistanceMatrix(const vector<int>& depotsX, const vector<int>& depotsY, const vector<int>& customersX, const vector<int>& customersY) {
        distMatrixData.resize(nbTotalLocations);
        for (int i = 0; i < nbTotalLocations; ++i) {
            distMatrixData[i].resize(nbTotalLocations);
        }
        for (int i = 0; i < nbCustomers; ++i) {
            distMatrixData[i][i] = 0;
            for (int j = i + 1; j < nbCustomers; ++j) {
                int distance = computeDist(customersX[i], customersX[j], customersY[i], customersY[j]);
                distMatrixData[i][j] = distance;
                distMatrixData[j][i] = distance;
            }

            for (int d=0; d<nbDepots; ++d){
                int distance = computeDist(customersX[i], depotsX[d], customersY[i], depotsY[d]);
                for (int c=0; c<nbDepotCopies; ++c){
                    int j = nbCustomers + d*nbDepotCopies + c;
                    distMatrixData[i][j] = distance;
                    distMatrixData[j][i] = distance;
                }
            }
        }

        for(int i=nbCustomers; i<nbTotalLocations; i++){
            for(int j=nbCustomers; j<nbTotalLocations;j++){
                // Going from one depot to an other is never worth it
                distMatrixData[i][j] = 100000;
            }
        }
    }

    int computeDist(int xi, int xj, int yi, int yj) {
        double exactDist = sqrt(pow((double)xi - xj, 2) + pow((double)yi - yj, 2));
        return floor(exactDist + 0.5);
    }
};



/**
 * Entry point: ./program inputFile [outputFile] [timeLimit]
 *
 * Reads the instance, solves it (default time limit "20") and optionally writes
 * the solution. Returns 1 on bad usage or when any std::exception is raised,
 * 0 otherwise.
 */
int main(int argc, char** argv) {
    if (argc < 2) {
        cerr << "Usage: ./lrp inputFile [outputFile] [timeLimit]" << endl;
        return 1;
    }
    const char* instanceFile = argv[1];
    const char* solFile = argc > 2 ? argv[2] : nullptr;
    const char* strTimeLimit = argc > 3 ? argv[3] : "20";

    try {
        MultiTripVRP model;
        model.readInstance(instanceFile);
        model.solve(atoi(strTimeLimit));
        if (solFile != nullptr)
            model.writeSolution(instanceFile, solFile);
    } catch (const std::exception& e) {
        std::cerr << "An error occured: " << e.what() << endl;
        // Report the failure to the caller instead of exiting successfully
        return 1;
    }

    return 0;
}