import hexaly.optimizer
import sys
import json

def read_data(filename):
    """Parse the JSON instance file at ``filename`` and return the decoded object.

    The file is opened in binary mode so that ``json.load`` detects the text
    encoding itself (UTF-8, UTF-8 with BOM, UTF-16 or UTF-32). Text mode would
    use the platform's locale encoding, which differs between systems (e.g.
    cp1252 on many Windows setups) and rejects a UTF-8 BOM.
    """
    with open(filename, "rb") as f:
        return json.load(f)
    
def read_input_darp(instance_file):
    """Load a DARP instance and unpack it into plain Python values and lists.

    Node numbering used by every per-node list: index k (0 <= k < nbClients)
    is the pickup of client k and index k + nbClients is its delivery. In the
    instance's distance matrix the depot occupies row/column 0, so node i sits
    at row/column i + 1. Travel times are distances multiplied by
    1 / (scale * speed).

    Returns a 15-tuple, in this order:
        nb_clients, nb_nodes, nb_vehicles, depot_tw_end, capacity, scale,
        quantities (+load at pickups, -load at deliveries), starts, ends,
        loading_times, max_travel_times, distance_warehouse (depot -> node),
        time_warehouse, distance_matrix (node -> node), time_matrix.
    """
    instance = read_data(instance_file)

    nb_clients = instance["nbClients"]
    nb_nodes = instance["nbNodes"]
    nb_vehicles = instance["nbVehicles"]
    depot_tw_end = instance["depot"]["twEnd"]
    capacity = instance["capacity"]
    scale = instance["scale"]
    distances = instance["distanceMatrix"]

    clients = [instance["clients"][c] for c in range(nb_clients)]
    pickups = [client["pickup"] for client in clients]
    deliveries = [client["delivery"] for client in clients]

    def per_node(field):
        # All pickup values first (nodes 0..n-1), then delivery values (n..2n-1).
        return [p[field] for p in pickups] + [d[field] for d in deliveries]

    loads = [client["nbClients"] for client in clients]
    quantities = loads + [-load for load in loads]
    starts = per_node("start")
    ends = per_node("end")
    loading_times = per_node("loadingTime")
    max_travel_times = per_node("maxTravelTime")

    factor = 1.0 / (scale * instance["speed"])

    node_ids = range(nb_nodes)
    distance_warehouse = [distances[0][node + 1] for node in node_ids]
    time_warehouse = [dist * factor for dist in distance_warehouse]
    distance_matrix = [
        [distances[row + 1][col + 1] for col in node_ids] for row in node_ids
    ]
    time_matrix = [[dist * factor for dist in line] for line in distance_matrix]

    return (
        nb_clients, nb_nodes, nb_vehicles, depot_tw_end, capacity, scale,
        quantities, starts, ends, loading_times, max_travel_times,
        distance_warehouse, time_warehouse, distance_matrix, time_matrix,
    )

def main(instance_file, str_time_limit, sol_file):
    """Build and solve the dial-a-ride (DARP) model for one instance with Hexaly.

    Decisions: one list per vehicle (the ordered nodes it serves), one depot
    start time per vehicle, and one waiting time per node. Nodes
    0..nb_clients-1 are pickups; node k + nb_clients is the delivery of pickup k.

    Constraints: every node is served by exactly one vehicle, vehicle load never
    exceeds ``capacity``, and each client is picked up before being delivered,
    by the same vehicle.

    Objectives, declared in this order (Hexaly optimizes multiple objectives
    lexicographically in declaration order):
      1. time-window lateness on the routes plus lateness returning to the depot,
      2. client ride time in excess of the pickup's maximum travel time,
      3. total travelled distance divided by ``scale``.

    Args:
        instance_file: path of the JSON instance, parsed by ``read_input_darp``.
        str_time_limit: value for Hexaly's ``time_limit`` parameter, as a
            string; converted with ``int()`` only after the model is built.
        sol_file: path of the solution file to write, or ``None`` to skip it.

    Returns:
        0 after solving (and writing ``sol_file`` when one was given).
    """

    nb_clients, nb_nodes, nb_vehicles, depot_tw_end, capacity, scale, quantities_data, \
        starts_data, ends_data, loading_times_data, max_travel_times, distance_warehouse_data, \
        time_warehouse_data, distance_matrix_data, time_matrix_data = read_input_darp(instance_file)

    with hexaly.optimizer.HexalyOptimizer() as optimizer:
        model = optimizer.model

        # routes[k] is the ordered list of nodes visited by vehicle k
        routes = [model.list(nb_nodes) for k in range(nb_vehicles)]
        # depot_starts[k] is the time vehicle k leaves the depot, in [0, depot_tw_end]
        depot_starts = [model.float(0, depot_tw_end) for k in range(nb_vehicles)]
        # Every node appears in exactly one vehicle's list
        model.constraint(model.partition(routes))

        # Constant data wrapped as Hexaly arrays so it can be indexed by
        # decision-dependent expressions such as route[i]
        quantities = model.array(quantities_data)
        time_warehouse = model.array(time_warehouse_data)
        time_matrix = model.array(time_matrix_data)
        loading_times = model.array(loading_times_data)
        starts = model.array(starts_data)
        ends = model.array(ends_data)
        # waiting[k] is the waiting time at node k (a decision variable, so the
        # solver may delay service at a node)
        waiting = [model.float(0, depot_tw_end) for k in range(nb_nodes)]
        waiting_array = model.array(waiting)
        distance_matrix = model.array(distance_matrix_data)
        distance_warehouse = model.array(distance_warehouse_data)

        times = [None] * nb_vehicles
        lateness = [None] * nb_vehicles
        home_lateness = [None] * nb_vehicles
        route_distances = [None] * nb_vehicles

        for k in range(nb_vehicles):
            route = routes[k]
            c = model.count(route)

            # Recursive array: each element is the previous one plus the
            # quantity (+ at pickup, - at delivery) of the node at position i
            demand_lambda = model.lambda_function(lambda i, prev: prev + quantities[route[i]])
            # route_quantities[k][i] indicates the number of clients in vehicle k
            # at its i-th taken node
            route_quantities = model.array(model.range(0, c), demand_lambda)
            quantity_lambda = model.lambda_function(lambda i: route_quantities[i] <= capacity)
            # Vehicles have a maximum capacity
            model.constraint(model.and_(model.range(0, c), quantity_lambda))

            # Departure from position i = max(window start, arrival) + waiting
            # + loading, where arrival is depot start + depot travel time for the
            # first node, else the previous departure + travel time from it
            times_lambda = model.lambda_function(
                lambda i, prev: model.max(
                    starts[route[i]],
                    model.iif(
                        i == 0,
                        depot_starts[k] + time_warehouse[route[0]],
                        prev + time_matrix[route[i-1]][route[i]]
                    )
                ) + waiting_array[route[i]] + loading_times[route[i]]
            )
            # times[k][i] is the time at which vehicle k leaves the i-th node
            # (after waiting and loading time at node i)
            times[k] = model.array(model.range(0, c), times_lambda)

            # Service begins at departure minus loading time (i.e. after any
            # waiting); lateness is how far that exceeds the node's window end
            lateness_lambda = model.lambda_function(
                lambda i: model.max(
                    0,
                    times[k][i] - loading_times[route[i]] - ends[route[i]]
                )
            )
            # Total lateness of the k-th route
            lateness[k] = model.sum(model.range(0, c), lateness_lambda)

            # Lateness of the return to the depot after the last node, beyond
            # depot_tw_end; an empty route contributes 0.
            # NOTE(review): the return leg reuses time_warehouse (built from row 0
            # of the distance matrix, depot -> node), i.e. it assumes depot
            # distances are symmetric.
            home_lateness[k] = model.iif(
                c > 0,
                model.max(0, times[k][c-1] + time_warehouse[route[c-1]] - depot_tw_end),
                0
            )

            # Distance between consecutive nodes of the route (positions 1..c-1)
            route_dist_lambda = model.lambda_function(
                lambda i: distance_matrix[route[i-1]][route[i]]
            )
            # Plus depot -> first node and last node -> depot (same symmetric
            # depot-distance assumption as above); 0 for an empty route
            route_distances[k] = model.sum(
                model.range(1, c),
                route_dist_lambda
            ) + model.iif(
                c > 0,
                distance_warehouse[route[0]] + distance_warehouse[route[c-1]],
                0
            )

        routes_array = model.array(routes)
        times_array = model.array(times)
        client_lateness = [None] * nb_clients
        for k in range(nb_clients):
            # For each pickup node k, its associated delivery node is k + nb_clients
            # find() gives the index of the vehicle list containing the node
            pickup_list_index = model.find(routes_array, k)
            delivery_list_index = model.find(routes_array, k + nb_clients)
            # A client picked up in route i is delivered in route i
            model.constraint(pickup_list_index == delivery_list_index)

            # index() gives the node's position inside that vehicle's list
            client_list = routes_array[pickup_list_index]
            pickup_index = model.index(client_list, k)
            delivery_list = routes_array[delivery_list_index]
            delivery_index = model.index(delivery_list, k + nb_clients)
            # Pickup before delivery
            model.constraint(pickup_index < delivery_index)

            # Ride time: from departure at the pickup (after loading) until
            # service starts at the delivery (departure minus its loading time)
            pickup_time = times_array[pickup_list_index][pickup_index]
            delivery_time = times_array[delivery_list_index][delivery_index] \
                - loading_times[k + nb_clients]
            travel_time = delivery_time - pickup_time
            # Only the pickup's maxTravelTime is used as the client's limit
            client_lateness[k] = model.max(travel_time - max_travel_times[k], 0)

        total_lateness = model.sum(lateness + home_lateness)
        total_client_lateness = model.sum(client_lateness)
        total_distance = model.sum(route_distances)

        # Objectives in priority order (lexicographic)
        model.minimize(total_lateness)
        model.minimize(total_client_lateness)
        model.minimize(total_distance / scale)

        model.close()
        optimizer.param.time_limit = int(str_time_limit)
        optimizer.solve()

        #
        # Write the solution in a file with the following format:
        #  - first line: total lateness on the routes (including the return to the
        #    depot), total client lateness, and total distance
        #  - then one line per vehicle: the depot start time, followed by each
        #    visited node (the start/end at the depot are omitted) with the
        #    waiting time at that node
        #
        # NOTE(review): the two lateness totals are float expressions but are
        # written with %d (truncated); total_distance is written unscaled while
        # the objective minimizes total_distance / scale.
        #
        if sol_file is not None:
            with open(sol_file, 'w') as f:
                f.write("%d %d %.2f\n" % (
                    total_lateness.value,
                    total_client_lateness.value,
                    total_distance.value
                ))
                for k in range(nb_vehicles):
                    f.write("Vehicle %d (%.2f): " %(k + 1, depot_starts[k].value))
                    for node in routes[k].value:
                        f.write("%d (%.2f), " % (node, waiting[node].value))
                    f.write("\n")
    return 0

if __name__ == '__main__':
    # Command line: input_file [output_file] [time_limit]; time_limit defaults to "20".
    if len(sys.argv) < 2:
        print("Usage: python darp.py input_file [output_file] [time_limit]", file=sys.stderr)
        sys.exit(1)

    instance_file = sys.argv[1]
    sol_file = sys.argv[2] if len(sys.argv) > 2 else None
    str_time_limit = sys.argv[3] if len(sys.argv) > 3 else "20"

    # main() only calls int(str_time_limit) after reading the instance and
    # building the whole model; check it here so a bad value fails immediately
    # with a readable message instead of a late ValueError traceback.
    try:
        int(str_time_limit)
    except ValueError:
        print("Invalid time limit %r: expected an integer" % str_time_limit, file=sys.stderr)
        print("Usage: python darp.py input_file [output_file] [time_limit]", file=sys.stderr)
        sys.exit(1)

    # Propagate main()'s return value as the process exit status.
    sys.exit(main(instance_file, str_time_limit, sol_file))