#include <climits>
#include <cmath>
#include <cstring>
#include <fstream>
#include <iostream>
#include <numeric>
#include <vector>
#include <algorithm>

#include "optimizer/hexalyoptimizer.h"

using namespace hexaly;
using namespace std;

class Vrptf {
public:
    // Hexaly Optimizer
    HexalyOptimizer optimizer;

    // Number of customers
    int nbCustomers;

    // Customers coordinates
    vector<int> xCustomers;
    vector<int> yCustomers;

    // Customers demands
    vector<int> demandsData;

    // Number of depots
    int nbFacilities;

    // Depots coordinates
    vector<int> xFacilities;
    vector<int> yFacilities;

    // Number of points
    int nbPoints;

    // Depot coordinates
    int xDepot;
    int yDepot;

    // Number of trucks
    int nbTrucks;

    // Capacity of trucks
    int truckCapacity;

    // Distance matrix
    vector<vector<int>> distMatrixData;

    // Distance to depot array
    vector<int> distDepotsData;

    // Assignement costs
    vector<int> assignmentCostsData;

    // Decision variables
    vector<HxExpression> routesSequences;

    // Objective value
    HxExpression totalDistanceCost;
    HxExpression totalAssignementCost;
    HxExpression totalCost;

    void readInstance(const char* fileName) {
        readInput(fileName);
    }

    void solve(const char* limit) {
        // Declare the optimization model
        HxModel m = optimizer.getModel();

        double totalDemand = accumulate(demandsData.begin(), demandsData.begin() + nbCustomers, 0);
        int minNbTrucks = ceil(totalDemand / truckCapacity);
        nbTrucks = ceil(1.5 * minNbTrucks);

        // A route is represented as a list containing the points in the order they are visited
        routesSequences.resize(nbTrucks);
        for (int r = 0; r < nbTrucks; ++r) {
            routesSequences[r] = m.listVar(nbPoints);
        }

        HxExpression routes = m.array(routesSequences.begin(), routesSequences.end());

        // Each customer must be visited at most once
        m.constraint(m.disjoint(routesSequences.begin(), routesSequences.end()));

        // Create Hexaly arrays to be able to access them with "at" operators
        HxExpression demands = m.array(demandsData.begin(), demandsData.end());
        HxExpression distMatrix = m.array();
        for (int i = 0; i < nbPoints; ++i) {
            distMatrix.addOperand(m.array(distMatrixData[i].begin(), distMatrixData[i].end()));
        }
        HxExpression distDepots = m.array(distDepotsData.begin(), distDepotsData.end());
        HxExpression assignementCosts = m.array(assignmentCostsData.begin(), assignmentCostsData.end());

        vector<HxExpression> distRoutes(nbTrucks);
        vector<HxExpression> assignmentCostRoutes(nbTrucks);

        for (int c = 0; c < nbCustomers; ++c) {
            int startFacilities = nbCustomers + c * nbFacilities;
            int endFacilities = startFacilities + nbFacilities;

            // Each customer is either contained in a route or assigned to a facility
            HxExpression facilityUsedSum = m.sum();
            for (int f = startFacilities; f < endFacilities; ++f) {
                facilityUsedSum.addOperand(m.contains(routes, f));
            }
            m.constraint(m.contains(routes, c) + facilityUsedSum == 1);
        }

        for (int r = 0; r < nbTrucks; ++r) {
            HxExpression route = routesSequences[r];
            HxExpression c = m.count(route);

            // Each truck cannot carry more than its capacity
            HxExpression demandLambda = m.lambdaFunction([&](HxExpression j) { return demands[j]; });
            HxExpression quantityServed = m.sum(route, demandLambda);
            m.constraint(quantityServed <= truckCapacity);

            HxExpression distLambda =
                m.lambdaFunction([&](HxExpression i) { return m.at(distMatrix, route[i], route[i + 1]); });

            // Truck is used if it visits at least one point
            HxExpression truckUsed = c > 0;

            // Distance traveled by each truck
            distRoutes[r] = m.sum(m.range(0, c - 1), distLambda) + 
                m.iif(truckUsed,
                    m.at(distDepots, route[0]) +
                    m.at(distDepots, route[c - 1]),
                    0);

            // The cost to assign customers to their facility
            HxExpression assignementCostLambda =
                m.lambdaFunction([&](HxExpression i) { return assignementCosts[i]; });
            assignmentCostRoutes[r] = m.sum(route, assignementCostLambda);
        }

        // The total distance travelled
        totalDistanceCost = m.sum(distRoutes.begin(), distRoutes.end());
        // The total assignement cost
        totalAssignementCost = m.sum(assignmentCostRoutes.begin(), assignmentCostRoutes.end());

        // Objective: minimize the sum of the total distance travelled and the total assignement cost
        totalCost = totalDistanceCost + totalAssignementCost;

        m.minimize(totalCost);
        m.close();

        optimizer.getParam().setTimeLimit(atoi(limit));
        optimizer.solve();
    }

    /* Write the solution in a file */
    void writeSolution(const char* inFile, const string& solFile) {
        ofstream file;
        file.exceptions(ofstream::failbit | ofstream::badbit);
        file.open(solFile.c_str());
        file << "File name: " << inFile << "; totalCost = " << totalCost.getIntValue()
            << "; totalDistance = " << totalDistanceCost.getIntValue()
            << "; totalAssignementCost = " << totalAssignementCost.getIntValue() << endl;
        for (int r = 0; r < nbTrucks; ++r) {
            HxCollection route = routesSequences[r].getCollectionValue();
            if (route.count() == 0) continue;
            file << "Route " << r << " [";
            for (int i = 0; i < route.count(); ++i) {
                long point = route.get(i);
                if (point < nbCustomers) {
                    file << "Customer " << point;
                }
                else {
                    file << "Facility " << point % nbCustomers << " assigned to Customer " << (point - nbCustomers) / nbFacilities;
                }
                if (i < route.count() - 1) {
                    file << ", ";
                }
            }
            file << "]" << endl;
        }
    }

private:
    void readInput(const char* fileName) {
        string file = fileName;
        ifstream infile(file.c_str());
        if (!infile.is_open()) {
            throw std::runtime_error("File cannot be opened.");
        }
        infile >> nbCustomers;
        xCustomers.resize(nbCustomers);
        yCustomers.resize(nbCustomers);
        infile >> nbFacilities;
        xFacilities.resize(nbFacilities);
        yFacilities.resize(nbFacilities);

        // A point is either a customer or a facility
        // Facilities are duplicated for each customer
        nbPoints = nbCustomers + nbCustomers * nbFacilities;
        demandsData.resize(nbPoints);
        distMatrixData.resize(nbPoints);
        for (int i = 0; i < nbPoints; ++i) {
            distMatrixData[i].resize(nbPoints);
        }
        distDepotsData.resize(nbPoints);

        for (int f = 0; f < nbFacilities; ++f) {
            infile >> xFacilities[f];
            infile >> yFacilities[f];
        }
        for (int c = 0; c < nbCustomers; ++c) {
            infile >> xCustomers[c];
            infile >> yCustomers[c];
        }

        infile >> truckCapacity;

        // Facility capacities : skip
        infile.ignore(1000, '\n');
        infile.ignore(1000, '\n');
        for (int f = 0; f < nbFacilities; ++f) {
            infile.ignore(1000, '\n');
        }

        for (int c = 0; c < nbCustomers; ++c) {
            infile >> demandsData[c];
            for (int f = 0; f < nbFacilities; ++f) {
                demandsData[nbCustomers + c * nbFacilities + f] = demandsData[c];
            }
        }

        computeDepotCoordinates();
        computeDistances();
        computeAssignmentCosts();
    }

    void computeDepotCoordinates() {
        // Compute the coordinates of the bounding box containing all of the points
        int xMin = INT_MAX;
        int xMax = INT_MIN;
        int yMin = INT_MAX;
        int yMax = INT_MIN;

        for (int c = 0; c < nbCustomers; ++c) {
            xMin = min(xMin, xCustomers[c]);
        }
        for (int f = 0; f < nbFacilities; ++f) {
            xMin = min(xMin, xFacilities[f]);
        }
        for (int c = 0; c < nbCustomers; ++c) {
            xMax = max(xMax, xCustomers[c]);
        }
        for (int f = 0; f < nbFacilities; ++f) {
            xMax = max(xMax, xFacilities[f]);
        }
        for (int c = 0; c < nbCustomers; ++c) {
            yMin = min(yMin, yCustomers[c]);
        }
        for (int f = 0; f < nbFacilities; ++f) {
            yMin = min(yMin, yFacilities[f]);
        }
        for (int c = 0; c < nbCustomers; ++c) {
            yMax = max(yMax, yCustomers[c]);
        }
        for (int f = 0; f < nbFacilities; ++f) {
            yMax = max(yMax, yFacilities[f]);
        }

        // We assume that the depot is at the center of the bounding box
        xDepot = xMin + (xMax - xMin) / 2;
        yDepot = yMin + (yMax - yMin) / 2;
    }

    void computeDistances() {
        // Customer to depot
        for (int c = 0; c < nbCustomers; ++c) {
            distDepotsData[c] = computeDist(xCustomers[c], xDepot, yCustomers[c], yDepot);
        }

        // Facility to depot
        for (int c = 0; c < nbCustomers; ++c) {
            for (int f = 0; f < nbFacilities; ++f) {
                distDepotsData[nbCustomers + c * nbFacilities + f] = computeDist(xFacilities[f], xDepot, yFacilities[f],
                    yDepot);
            }
        }

        // Distances between customers
        for (int c1 = 0; c1 < nbCustomers; ++c1) {
            for (int c2 = 0; c2 < nbCustomers; ++c2) {
                long dist = computeDist(xCustomers[c1], xCustomers[c2], yCustomers[c1], yCustomers[c2]);
                distMatrixData[c1][c2] = dist;
                distMatrixData[c2][c1] = dist;
            }
        }

        // Distances between customers and facilities
        for (int c1 = 0; c1 < nbCustomers; ++c1) {
            for (int f = 0; f < nbFacilities; ++f) {
                long dist = computeDist(xFacilities[f], xCustomers[c1], yFacilities[f], yCustomers[c1]);
                for (int c2 = 0; c2 < nbCustomers; ++c2) {
                     // Index representing serving c2 through facility f
                    int facilityIndex = nbCustomers + c2 * nbFacilities + f;
                    distMatrixData[facilityIndex][c1] = dist;
                    distMatrixData[c1][facilityIndex] = dist;
                }
            }
        }

        // Distances between facilities
        for (int f1 = 0; f1 < nbFacilities; ++f1) {
            for (int f2 = 0; f2 < nbFacilities; ++f2) {
                long dist = computeDist(xFacilities[f1], xFacilities[f2], yFacilities[f1], yFacilities[f2]);
                for (int c1 = 0; c1 < nbCustomers; ++c1) {
                    // Index representing serving c1 through facility f1
                    int index1 = nbCustomers + c1 * nbFacilities + f1;
                    for (int c2 = 0; c2 < nbCustomers; ++c2) {
                        // Index representing serving c2 through facility f2
                        int index2 = nbCustomers + c2 * nbFacilities + f2;
                        distMatrixData[index1][index2] = dist;
                    }
                }
            }
        }
    }

    void computeAssignmentCosts() {
        // Compute assignment cost for each point
        assignmentCostsData.resize(nbPoints);
        for (int c = 0; c < nbCustomers; ++c) {
            assignmentCostsData[c] = 0;
            for (int f = 0; f < nbFacilities; ++f) {
                // Cost of serving customer c through facility f
                assignmentCostsData[nbCustomers + c * nbFacilities + f] =
                    distMatrixData[c][nbCustomers + c * nbFacilities + f];
            }
        }
    }

    int computeDist(int xi, int xj, int yi, int yj) {
        double exactDist = sqrt(pow(xi - xj, 2) + pow(yi - yj, 2));
        return round(exactDist);
    }
};

/*
 * Entry point: ./vrptf inputFile [outputFile] [timeLimit]
 * Reads the instance, solves it (default time limit "20") and, when an output file is given,
 * writes the solution there. Returns 1 on bad usage or if any std::exception is raised.
 */
int main(int argc, char** argv) {
    if (argc < 2) {
        cerr << "Usage: ./vrptf inputFile [outputFile] [timeLimit]" << endl;
        return 1;
    }
    const char* instanceFile = argv[1];
    const char* solFile = argc > 2 ? argv[2] : nullptr;
    const char* strTimeLimit = argc > 3 ? argv[3] : "20";

    try {
        Vrptf model;
        model.readInstance(instanceFile);
        model.solve(strTimeLimit);
        if (solFile != nullptr)
            model.writeSolution(instanceFile, solFile);
    } catch (const std::exception& e) {
        std::cerr << "An error occured: " << e.what() << endl;
        // Report failure to the caller instead of exiting with a success status
        return 1;
    }

    return 0;
}