import hexaly.optimizer
import sys
import math


def read_elem(filename):
    """Return the whitespace-separated tokens of *filename*, in file order."""
    with open(filename) as handle:
        content = handle.read()
    return content.split()


def main(instance_file, str_time_limit, output_file):
    """Build and solve the VRP-with-backhauls model for one instance.

    The model minimizes the number of trucks used, then the total distance
    traveled. It enforces three constraints:
      - every customer is visited by exactly one truck;
      - on a route, no delivery customer follows a backhaul (pickup) customer;
      - on a route, the delivery total and the pickup total are each at most
        the truck capacity.

    Args:
        instance_file: path of the instance to parse with read_input_vrpb.
        str_time_limit: optimizer time limit in seconds, as a string.
        output_file: path of the solution file to write, or None to skip it.
    """
    #
    # Read instance data
    #
    nb_customers, nb_trucks, truck_capacity, dist_matrix_data, dist_depot_data, \
        delivery_demands_data, pickup_demands_data, backhaul_data = read_input_vrpb(
            instance_file)

    with hexaly.optimizer.HexalyOptimizer() as optimizer:
        #
        # Declare the optimization model
        #
        model = optimizer.model

        # Sequence of customers visited by each truck
        customers_sequences = [model.list(nb_customers)
                               for _ in range(nb_trucks)]

        # All customers must be visited by exactly one truck
        model.constraint(model.partition(customers_sequences))

        # Create Hexaly arrays to be able to access them with an "at" operator
        delivery_demands = model.array(delivery_demands_data)
        pickup_demands = model.array(pickup_demands_data)
        dist_matrix = model.array(dist_matrix_data)
        dist_depot = model.array(dist_depot_data)

        # A truck is used if it visits at least one customer
        trucks_used = [(model.count(customers_sequences[k]) > 0)
                       for k in range(nb_trucks)]

        dist_routes = [None] * nb_trucks
        # Build the backhaul flags as a plain list explicitly ordered by
        # customer index, rather than passing a dict view and depending on
        # its iteration order matching the customer indices.
        is_backhaul = model.array(
            [backhaul_data[i] for i in range(nb_customers)])
        for k in range(nb_trucks):
            sequence = customers_sequences[k]
            c = model.count(sequence)

            # A pickup cannot be followed by a delivery: for each consecutive
            # pair, "previous is a backhaul" implies "current is a backhaul".
            precedency_lambda = model.lambda_function(lambda i: model.or_(model.not_(
                model.at(is_backhaul, sequence[i - 1])), model.at(is_backhaul, sequence[i])))
            model.constraint(model.and_(model.range(1, c), precedency_lambda))

            # Delivery and pickup loads are each bounded by the truck capacity
            delivery_demand_lambda = model.lambda_function(
                lambda j: delivery_demands[j])
            route_delivery_quantity = model.sum(sequence, delivery_demand_lambda)
            model.constraint(route_delivery_quantity <= truck_capacity)

            pickup_demand_lambda = model.lambda_function(
                lambda j: pickup_demands[j])
            route_pickup_quantity = model.sum(sequence, pickup_demand_lambda)
            model.constraint(route_pickup_quantity <= truck_capacity)

            # Distance traveled by each truck: the legs between consecutive
            # customers plus, for a non-empty route, the depot -> first and
            # last -> depot legs.
            dist_lambda = model.lambda_function(lambda i:
                                                model.at(dist_matrix,
                                                         sequence[i - 1],
                                                         sequence[i]))
            dist_routes[k] = model.sum(model.range(1, c), dist_lambda) \
                + model.iif(c > 0,
                            dist_depot[sequence[0]] +
                            dist_depot[sequence[c - 1]],
                            0)

        # Total number of trucks used
        nb_trucks_used = model.sum(trucks_used)

        # Total distance traveled
        total_distance = model.sum(dist_routes)

        # Objective: minimize the number of trucks used, then minimize the distance traveled
        model.minimize(nb_trucks_used)
        model.minimize(total_distance)

        model.close()

        # Parameterize the optimizer
        optimizer.param.time_limit = int(str_time_limit)

        optimizer.solve()

        #
        # Write the solution in a file with the following format:
        #  - number of trucks used and total distance
        #  - for each truck the customers visited (omitting the start/end at the depot)
        #
        if output_file is not None:
            with open(output_file, 'w') as f:
                f.write("%d %d\n" %
                        (nb_trucks_used.value, total_distance.value))
                for k in range(nb_trucks):
                    if trucks_used[k].value != 1:
                        continue
                    # Values in sequence are in 0...nb_customers-1. +2 maps them
                    # back to 2...nb_customers+1 as in the data files (1 being the depot)
                    for customer in customers_sequences[k].value:
                        f.write("%d " % (customer + 2))
                    f.write("\n")


# The input files follow the "CVRPLib" format
def read_input_vrpb(filename):
    """Parse a vehicle-routing-with-backhauls instance file.

    The file is read as a stream of whitespace-separated tokens, in this order:
      - header fields (DIMENSION, VEHICLES, CAPACITY, EDGE_WEIGHT_TYPE);
      - NODE_COORD_SECTION;
      - DEMAND_SECTION;
      - BACKHAUL_SECTION (customer ids terminated by -1);
      - DEPOT_SECTION.
    Node 1 must be the depot. Nodes 2..DIMENSION are customers, re-indexed
    here as 0..nb_customers-1.

    Returns:
        A tuple (nb_customers, nb_trucks, truck_capacity, distance_matrix,
        distance_depots, delivery_demands, pickup_demands, is_backhaul).
        If a customer is listed in BACKHAUL_SECTION, its demand goes into
        pickup_demands. Otherwise it goes into delivery_demands. The other
        list holds 0 for that customer. is_backhaul is a dict mapping each
        customer index to a bool.

    On malformed input, prints a message and exits with status 1.
    """
    file_it = iter(read_elem(filename))

    nb_nodes = 0
    nb_customers = None
    nb_trucks = None
    truck_capacity = None
    while True:
        # Report a truncated header cleanly instead of leaking StopIteration
        token = next(file_it, None)
        if token is None:
            print("Expected token NODE_COORD_SECTION")
            sys.exit(1)
        if token == "DIMENSION":
            next(file_it)  # Removes the ":"
            nb_nodes = int(next(file_it))
            nb_customers = nb_nodes - 1
        elif token == "VEHICLES":
            next(file_it)  # Removes the ":"
            nb_trucks = int(next(file_it))
        elif token == "CAPACITY":
            next(file_it)  # Removes the ":"
            truck_capacity = int(next(file_it))
        elif token == "EDGE_WEIGHT_TYPE":
            next(file_it)  # Removes the ":"
            token = next(file_it)
            if token != "EXACT_2D":
                print("Edge Weight Type " + token +
                      " is not supported (only EXACT_2D)")
                sys.exit(1)
        elif token == "NODE_COORD_SECTION":
            break

    # All of these are used below; a missing one would otherwise surface
    # as an opaque NameError/UnboundLocalError.
    if nb_customers is None or nb_trucks is None or truck_capacity is None:
        print("Expected DIMENSION, VEHICLES and CAPACITY before NODE_COORD_SECTION")
        sys.exit(1)

    customers_x = [None] * nb_customers
    customers_y = [None] * nb_customers
    depot_x = 0
    depot_y = 0
    for n in range(nb_nodes):
        node_id = int(next(file_it))
        if node_id != n + 1:
            print("Unexpected index")
            sys.exit(1)
        if node_id == 1:
            depot_x = int(next(file_it))
            depot_y = int(next(file_it))
        else:
            # -2 because original customer indices are in 2..nbNodes
            customers_x[node_id - 2] = int(next(file_it))
            customers_y[node_id - 2] = int(next(file_it))

    distance_matrix = compute_distance_matrix(customers_x, customers_y)
    distance_depots = compute_distance_depots(
        depot_x, depot_y, customers_x, customers_y)

    token = next(file_it)
    if token != "DEMAND_SECTION":
        print("Expected token DEMAND_SECTION")
        sys.exit(1)

    demands = [None] * nb_customers
    for n in range(nb_nodes):
        node_id = int(next(file_it))
        if node_id != n + 1:
            print("Unexpected index")
            sys.exit(1)
        if node_id == 1:
            if int(next(file_it)) != 0:
                print("Demand for depot should be 0")
                sys.exit(1)
        else:
            # -2 because original customer indices are in 2..nbNodes
            demands[node_id - 2] = int(next(file_it))

    token = next(file_it)
    if token != "BACKHAUL_SECTION":
        print("Expected token BACKHAUL_SECTION")
        sys.exit(1)

    is_backhaul = {i: False for i in range(nb_customers)}
    while True:
        node_id = int(next(file_it))
        if node_id == -1:
            break
        # Only customers (ids 2..nb_nodes) can be backhauls. Any other id
        # would add a stray key outside 0..nb_customers-1 to the dict.
        if node_id < 2 or node_id > nb_nodes:
            print("Invalid customer id in BACKHAUL_SECTION")
            sys.exit(1)
        # -2 because original customer indices are in 2..nbNodes
        is_backhaul[node_id - 2] = True
    delivery_demands = [0 if is_backhaul[i] else demands[i]
                        for i in range(nb_customers)]
    pickup_demands = [demands[i] if is_backhaul[i]
                      else 0 for i in range(nb_customers)]

    token = next(file_it)
    if token != "DEPOT_SECTION":
        print("Expected token DEPOT_SECTION")
        sys.exit(1)

    depot_id = int(next(file_it))
    if depot_id != 1:
        print("Depot id is supposed to be 1")
        sys.exit(1)

    end_of_depot_section = int(next(file_it))
    if end_of_depot_section != -1:
        print("Expecting only one depot, more than one found")
        sys.exit(1)

    return nb_customers, nb_trucks, truck_capacity, distance_matrix, distance_depots, \
        delivery_demands, pickup_demands, is_backhaul


# Compute the distance matrix
def compute_distance_matrix(customers_x, customers_y):
    """Return the symmetric matrix of rounded customer-to-customer distances.

    distance_matrix[i][j] is compute_dist between customer i, located at
    (customers_x[i], customers_y[i]), and customer j. The diagonal is 0.
    """
    nb_customers = len(customers_x)
    distance_matrix = [[0] * nb_customers for _ in range(nb_customers)]
    # The distance is symmetric, so each unordered pair is computed once and
    # mirrored. The diagonal keeps its initial value of 0.
    for i in range(nb_customers):
        for j in range(i + 1, nb_customers):
            dist = compute_dist(
                customers_x[i], customers_x[j], customers_y[i], customers_y[j])
            distance_matrix[i][j] = dist
            distance_matrix[j][i] = dist
    return distance_matrix


# Compute the distances to depot
def compute_distance_depots(depot_x, depot_y, customers_x, customers_y):
    """Return the rounded depot-to-customer distance for every customer.

    Element idx is compute_dist between the depot (depot_x, depot_y) and
    customer idx at (customers_x[idx], customers_y[idx]).
    """
    return [compute_dist(depot_x, x, depot_y, customers_y[idx])
            for idx, x in enumerate(customers_x)]


def compute_dist(xi, xj, yi, yj):
    """Return the Euclidean distance between (xi, yi) and (xj, yj).

    The distance is rounded to the nearest integer, with exact halves
    rounded up.
    """
    delta_x = xi - xj
    delta_y = yi - yj
    euclidean = math.sqrt(math.pow(delta_x, 2) + math.pow(delta_y, 2))
    # math.floor already returns an int; adding 0.5 first rounds halves up
    return math.floor(euclidean + 0.5)


if __name__ == '__main__':
    # Command line: input_file [output_file] [time_limit]
    if len(sys.argv) < 2:
        print(
            "Usage: python vehicle_routing_backhauls.py input_file [output_file] [time_limit]")
        sys.exit(1)

    cli_args = sys.argv[1:]
    instance_file = cli_args[0]

    output_file = None
    if len(cli_args) > 1:
        output_file = cli_args[1]

    str_time_limit = "20"
    if len(cli_args) > 2:
        str_time_limit = cli_args[2]

    main(instance_file, str_time_limit, output_file)
