import com.hexaly.optimizer.*;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.*;

import com.google.gson.JsonObject;
import com.google.gson.JsonArray;
import com.google.gson.JsonParser;
import com.google.gson.stream.JsonReader;

/**
 * Pickup-and-delivery routing model (DARP) built with Hexaly Optimizer.
 *
 * <p>Workflow, as driven by {@link #main(String[])}: read a JSON instance, build and solve
 * the model under a time limit, then optionally write the solution to a file.
 *
 * <p>Node numbering used throughout: node {@code k} (with {@code 0 <= k < nbClients}) is
 * the pickup of client {@code k} and node {@code k + nbClients} is its delivery. The
 * per-node arrays are sized {@code 2 * nbClients} while the list variables range over
 * {@code nbNodes}; the instance is therefore assumed to satisfy
 * {@code nbNodes == 2 * nbClients} (not checked here).
 */
public class Darp {
    private final HexalyOptimizer optimizer;

    // Scalar instance parameters, read from the JSON file
    int nbClients;
    int nbNodes;
    int nbVehicles;
    double depotTwEnd;
    int capacity;
    double scale;
    // Distance-to-time conversion factor: 1 / (scale * speed)
    double factor;

    // Per-node data, indexed by node id (pickups first, then deliveries)
    int[] quantitiesData;
    double[] startsData;
    double[] endsData;
    double[] loadingTimesData;
    double[] maxTravelTimes;

    // Raw (nbNodes + 1) x (nbNodes + 1) matrix from the file; index 0 is the depot
    double[][] distances;
    // Depot <-> node distances and times, indexed by node id
    double[] distanceWarehouseData;
    double[] timeWarehouseData;
    // Node <-> node distances and times, indexed by node id
    double[][] distanceMatrixData;
    double[][] timeMatrixData;

    // Decision variables
    HxExpression[] routes;
    HxExpression[] depotStarts;
    HxExpression[] waiting;

    // Objective expressions
    HxExpression totalLateness;
    HxExpression totalClientLateness;
    HxExpression totalDistance;

    /** Creates the model holder together with its own optimizer instance. */
    private Darp() {
        optimizer = new HexalyOptimizer();
    }

    /**
     * Reads the instance data from a JSON file and fills every data field of this object.
     *
     * @param fileName path of the JSON instance file
     * @throws IOException if the file cannot be opened, read or closed
     */
    private void readInstance(String fileName) throws IOException {
        JsonObject instance;
        // try-with-resources: the reader (and the underlying file stream) is always closed
        try (JsonReader reader = new JsonReader(
                new InputStreamReader(new FileInputStream(fileName), StandardCharsets.UTF_8))) {
            instance = JsonParser.parseReader(reader).getAsJsonObject();
        }

        nbClients = instance.get("nbClients").getAsInt();
        nbNodes = instance.get("nbNodes").getAsInt();
        nbVehicles = instance.get("nbVehicles").getAsInt();
        JsonObject depot = (JsonObject) instance.get("depot");
        depotTwEnd = depot.get("twEnd").getAsDouble();
        capacity = instance.get("capacity").getAsInt();
        scale = instance.get("scale").getAsDouble();
        JsonArray clients = (JsonArray) instance.get("clients");

        quantitiesData = new int[2 * nbClients];
        startsData = new double[2 * nbClients];
        endsData = new double[2 * nbClients];
        loadingTimesData = new double[2 * nbClients];
        maxTravelTimes = new double[2 * nbClients];

        for (int k = 0; k < nbClients; ++k) {
            JsonObject clientk = (JsonObject) clients.get(k);

            // The pickup adds the client's quantity to the load, the delivery removes it
            int quantity = clientk.get("nbClients").getAsInt();
            quantitiesData[k] = quantity;
            quantitiesData[k + nbClients] = -quantity;

            JsonObject clientkPickup = (JsonObject) clientk.get("pickup");
            JsonObject clientkDelivery = (JsonObject) clientk.get("delivery");

            startsData[k] = clientkPickup.get("start").getAsDouble();
            startsData[k + nbClients] = clientkDelivery.get("start").getAsDouble();

            endsData[k] = clientkPickup.get("end").getAsDouble();
            endsData[k + nbClients] = clientkDelivery.get("end").getAsDouble();

            loadingTimesData[k] = clientkPickup.get("loadingTime").getAsDouble();
            loadingTimesData[k + nbClients] = clientkDelivery.get("loadingTime").getAsDouble();

            maxTravelTimes[k] = clientkPickup.get("maxTravelTime").getAsDouble();
            maxTravelTimes[k + nbClients] = clientkDelivery.get("maxTravelTime").getAsDouble();
        }

        distances = new double[nbNodes + 1][nbNodes + 1];
        JsonArray distanceMatrixJson = (JsonArray) instance.get("distanceMatrix");
        for (int i = 0; i < nbNodes + 1; ++i) {
            JsonArray distanceMatrixJsoni = (JsonArray) distanceMatrixJson.get(i);
            for (int j = 0; j < nbNodes + 1; ++j) {
                distances[i][j] = distanceMatrixJsoni.get(j).getAsDouble();
            }
        }

        // Travel time = distance * factor
        factor = 1.0 / (scale * instance.get("speed").getAsDouble());

        // Row 0 of the raw matrix is the depot; node k sits at raw index k + 1
        distanceWarehouseData = new double[nbNodes];
        timeWarehouseData = new double[nbNodes];
        for (int k = 0; k < nbNodes; ++k) {
            distanceWarehouseData[k] = distances[0][k + 1];
            timeWarehouseData[k] = distanceWarehouseData[k] * factor;
        }

        // Node-to-node sub-matrix, shifted so that it is indexed directly by node id
        distanceMatrixData = new double[nbNodes][nbNodes];
        timeMatrixData = new double[nbNodes][nbNodes];
        for (int i = 0; i < nbNodes; ++i) {
            for (int j = 0; j < nbNodes; ++j) {
                distanceMatrixData[i][j] = distances[i + 1][j + 1];
                timeMatrixData[i][j] = distanceMatrixData[i][j] * factor;
            }
        }
    }

    /** Builds a two-dimensional array expression holding exactly one operand per row. */
    private static HxExpression buildMatrix(HxModel model, double[][] rows) {
        HxExpression matrix = model.array();
        for (double[] row : rows) {
            matrix.addOperand(model.array(row));
        }
        return matrix;
    }

    /**
     * Declares the optimization model (variables, constraints, objectives), closes it and
     * runs the optimizer.
     *
     * <p>Objectives are declared in this order: total route lateness (time-window lateness
     * at nodes plus lateness when returning to the depot), total client lateness (ride time
     * beyond the client's maximum travel time), then total distance divided by
     * {@code scale}.
     *
     * @param limit value passed to the optimizer's time-limit parameter
     */
    private void solve(int limit) {
        HxModel model = optimizer.getModel();

        // routes[k] represents the nodes visited by vehicle k
        routes = new HxExpression[nbVehicles];
        depotStarts = new HxExpression[nbVehicles];
        for (int k = 0; k < nbVehicles; ++k) {
            routes[k] = model.listVar(nbNodes);
            depotStarts[k] = model.floatVar(0.0, depotTwEnd);
        }
        // waiting[k] is the waiting time at node k
        waiting = new HxExpression[nbNodes];
        for (int k = 0; k < nbNodes; ++k) {
            waiting[k] = model.floatVar(0.0, depotTwEnd);
        }
        // Each node is taken by one vehicle
        model.constraint(model.partition(routes));

        HxExpression quantities = model.array(quantitiesData);
        HxExpression timeWarehouse = model.array(timeWarehouseData);
        // Both matrices are built row by row so that each holds exactly nbNodes rows
        HxExpression timeMatrix = buildMatrix(model, timeMatrixData);
        HxExpression loadingTimes = model.array(loadingTimesData);
        HxExpression starts = model.array(startsData);
        HxExpression ends = model.array(endsData);
        HxExpression waitingArray = model.array(waiting);
        HxExpression distanceMatrix = buildMatrix(model, distanceMatrixData);
        HxExpression distanceWarehouse = model.array(distanceWarehouseData);

        HxExpression[] times = new HxExpression[nbVehicles];
        HxExpression[] lateness = new HxExpression[nbVehicles];
        HxExpression[] homeLateness = new HxExpression[nbVehicles];
        HxExpression[] routeDistances = new HxExpression[nbVehicles];

        for (int k = 0; k < nbVehicles; ++k) {
            HxExpression route = routes[k];
            HxExpression c = model.count(route);

            // Running load: previous load plus the (signed) quantity of the i-th node
            HxExpression demandLambda = model.lambdaFunction(
                (i, prev) -> model.sum(prev, model.at(quantities, model.at(route, i))));
            // routeQuantities[k][i] indicates the number of clients in vehicle k
            // at its i-th taken node
            HxExpression routeQuantities = model.array(model.range(0, c), demandLambda);
            HxExpression quantityLambda = model.lambdaFunction(
                i -> model.leq(model.at(routeQuantities, i), capacity));
            // Vehicles have a maximum capacity
            model.constraint(model.and(model.range(0, c), quantityLambda));

            HxExpression depotStartsk = depotStarts[k];
            // Departure time from the i-th node:
            //   max(window start, arrival time) + waiting time + loading time
            // where the arrival comes from the depot for i == 0 and from the previous
            // node's departure time otherwise.
            HxExpression timesLambda = model.lambdaFunction(
                (i, prev) -> model.sum(
                    model.max(
                        model.at(starts, model.at(route, i)),
                        model.iif(
                            model.eq(i, 0),
                            model.sum(
                                depotStartsk,
                                model.at(timeWarehouse, model.at(route, 0))
                            ),
                            model.sum(
                                prev,
                                model.at(
                                    timeMatrix,
                                    model.at(route, model.sub(i, 1)),
                                    model.at(route, i)
                                )
                            )
                        )
                    ),
                    model.sum(
                        model.at(waitingArray, model.at(route, i)),
                        model.at(loadingTimes, model.at(route, i))
                    )
                )
            );
            HxExpression timesk = model.array(model.range(0, c), timesLambda);
            // times[k][i] is the time at which vehicle k leaves the i-th node
            // (after waiting and loading time at node i)
            times[k] = timesk;
            // Lateness at a node: how far (departure time - loading time) exceeds the
            // end of the node's time window, floored at 0
            HxExpression latenessLambda = model.lambdaFunction(
                i -> model.max(
                    0,
                    model.sub(
                        model.sub(
                            model.at(timesk, i),
                            model.at(loadingTimes, model.at(route, i))
                        ),
                        model.at(ends, model.at(route, i))
                    )
                )
            );
            // Total lateness of the k-th route
            lateness[k] = model.sum(model.range(0, c), latenessLambda);

            // Lateness when returning to the depot after the last node (0 for an empty route)
            homeLateness[k] = model.iif(
                model.gt(c, 0),
                model.max(
                    0,
                    model.sum(
                        model.at(timesk, model.sub(c, 1)),
                        model.sub(
                            model.at(timeWarehouse, model.at(route, model.sub(c, 1))),
                            depotTwEnd
                        )
                    )
                ),
                0
            );

            HxExpression routeDistLambda = model.lambdaFunction(
                i -> model.at(
                    distanceMatrix,
                    model.at(route, model.sub(i, 1)),
                    model.at(route, i)
                )
            );
            // Distance between consecutive nodes, plus the depot legs to the first node
            // and from the last node when the route is not empty
            routeDistances[k] = model.sum(
                model.sum(model.range(1, c), routeDistLambda),
                model.iif(
                    model.gt(c, 0),
                    model.sum(
                        model.at(distanceWarehouse, model.at(route, 0)),
                        model.at(distanceWarehouse, model.at(route, model.sub(c, 1)))
                    ),
                    0
                )
            );
        }

        HxExpression routesArray = model.array(routes);
        HxExpression timesArray = model.array(times);
        HxExpression[] clientLateness = new HxExpression[nbClients];

        for (int k = 0; k < nbClients; ++k) {
            // For each pickup node k, its associated delivery node is k + nbClients
            HxExpression pickupListIndex = model.find(routesArray, k);
            HxExpression deliveryListIndex = model.find(routesArray, k + nbClients);
            // A client picked up in route i is delivered in route i
            model.constraint(model.eq(pickupListIndex, deliveryListIndex));

            HxExpression clientList = model.at(routesArray, pickupListIndex);
            HxExpression pickupIndex = model.indexOf(clientList, k);
            HxExpression deliveryList = model.at(routesArray, deliveryListIndex);
            HxExpression deliveryIndex = model.indexOf(deliveryList, k + nbClients);
            // Pickup before delivery
            model.constraint(model.lt(pickupIndex, deliveryIndex));

            // Ride time runs from the departure at the pickup node to the delivery node's
            // departure time minus its loading time
            HxExpression pickupTime = model.at(timesArray, pickupListIndex, pickupIndex);
            HxExpression deliveryTime = model.sub(
                model.at(timesArray, deliveryListIndex, deliveryIndex),
                model.at(loadingTimes, k + nbClients)
            );
            HxExpression travelTime = model.sub(deliveryTime, pickupTime);
            clientLateness[k] = model.max(model.sub(travelTime, maxTravelTimes[k]), 0);
        }

        HxExpression[] latenessPlusHomeLateness = new HxExpression[nbVehicles];
        for (int k = 0; k < nbVehicles; ++k) {
            latenessPlusHomeLateness[k] = model.sum(lateness[k], homeLateness[k]);
        }

        totalLateness = model.sum(latenessPlusHomeLateness);
        totalClientLateness = model.sum(clientLateness);
        totalDistance = model.sum(routeDistances);

        model.minimize(totalLateness);
        model.minimize(totalClientLateness);
        model.minimize(model.div(totalDistance, scale));

        model.close();

        optimizer.getParam().setTimeLimit(limit);

        optimizer.solve();
    }

    /**
     * Writes the solution in a file with the following format:
     * <ul>
     *   <li>first line: total lateness on the routes, total client lateness, total distance
     *       (the unscaled {@code totalDistance} expression);</li>
     *   <li>then one line per vehicle: the depot start time, the nodes visited (omitting the
     *       start/end at the depot), and the waiting time at each node.</li>
     * </ul>
     *
     * @param fileName path of the output file
     * @throws IOException if the output file cannot be created
     */
    private void writeSolution(String fileName) throws IOException {
        try (PrintWriter output = new PrintWriter(fileName)) {
            output.println(
                totalLateness.getDoubleValue()
                + " "
                + totalClientLateness.getDoubleValue()
                + " "
                + totalDistance.getDoubleValue()
            );
            for (int k = 0; k < nbVehicles; ++k) {
                HxCollection route = routes[k].getCollectionValue();
                output.print("Vehicle " + (k + 1) + " (" + depotStarts[k].getDoubleValue() + "): ");
                for (int i = 0; i < route.count(); ++i) {
                    // route.get(i) is narrowed to int so it can index the waiting array
                    int routei = (int) route.get(i);
                    output.print(routei + " (" + waiting[routei].getDoubleValue() + "), ");
                }
                output.println();
            }
        }
    }

    /**
     * Command-line entry point.
     *
     * @param args {@code inputFile [outputFile] [timeLimit]}; the time limit defaults to
     *             {@code 20} and the solution is only written when an output file is given
     */
    public static void main(String[] args) {

        if (args.length < 1) {
            System.err.println("Usage: java Darp inputFile [outputFile] [timeLimit]");
            System.exit(1);
        }

        try {
            String instanceFile = args[0];
            String outputFile = args.length > 1 ? args[1] : null;
            String strTimeLimit = args.length > 2 ? args[2] : "20";

            Darp model = new Darp();
            // Close the optimizer once done (it is AutoCloseable per the Hexaly Java API)
            // so that it is released on both the success and the failure path; the close
            // happens before the catch block below runs.
            try (HexalyOptimizer ownedOptimizer = model.optimizer) {
                model.readInstance(instanceFile);
                model.solve(Integer.parseInt(strTimeLimit));
                if (outputFile != null) {
                    model.writeSolution(outputFile);
                }
            }
        } catch (Exception ex) {
            System.err.println(ex);
            ex.printStackTrace();
            System.exit(1);
        }
    }
}