import hexaly.optimizer
import sys
import math


def read_elem(filename):
    """Return every whitespace-separated token of ``filename`` as a list of strings."""
    with open(filename) as handle:
        contents = handle.read()
    # str.split() with no separator already yields str tokens and collapses
    # any run of spaces, tabs and newlines.
    return contents.split()


def main(instance_file, str_time_limit, sol_file):
    """Build, solve and optionally export the VRP-with-facilities model.

    Each customer must be served exactly once: either a truck visits the
    customer point directly, or a truck visits one of the facility copies
    reserved for that customer (which adds that copy's assignment cost).

    Args:
        instance_file: path of the instance, parsed by ``read_input``.
        str_time_limit: value for ``optimizer.param.time_limit``, given as a
            string and converted with ``int``.
        sol_file: path of the solution file to write, or ``None`` to skip
            writing it.
    """
    #
    # Read instance data
    #
    nb_customers, nb_facilities, capacity, customers_demands, \
        depot_distances_data, distance_matrix_data, assignement_costs_data = read_input(instance_file)

    # A point is either a customer or a facility.
    # Facilities are duplicated for each customer: point
    # nb_customers + c * nb_facilities + f is facility f used to serve customer c.
    nb_points = nb_customers + nb_customers * nb_facilities

    # A facility copy carries the demand of the customer it is reserved for.
    demands_data = [None] * nb_points
    for c in range(nb_customers):
        demands_data[c] = customers_demands[c]
        for f in range(nb_facilities):
            demands_data[nb_customers + c * nb_facilities + f] = customers_demands[c]

    # Fleet size: capacity lower bound on the number of trucks, plus 50% slack.
    min_nb_trucks = int(math.ceil(sum(customers_demands) / capacity))
    nb_trucks = int(math.ceil(1.5 * min_nb_trucks))

    with hexaly.optimizer.HexalyOptimizer() as optimizer:
        #
        # Declare the optimization model
        #
        m = optimizer.model

        # Each route is represented as a list containing the points in the order they are visited
        routes_sequences = [m.list(nb_points) for _ in range(nb_trucks)]
        routes = m.array(routes_sequences)

        # Each point must be visited at most once
        m.constraint(m.disjoint(routes_sequences))

        dist_routes = [None] * nb_trucks
        assignement_cost_routes = [None] * nb_trucks

        # Create Hexaly arrays to be able to access them with "at" operators
        demands = m.array(demands_data)
        dist_matrix = m.array()
        dist_depots = m.array(depot_distances_data)
        assignement_costs = m.array(assignement_costs_data)
        for i in range(nb_points):
            dist_matrix.add_operand(m.array(distance_matrix_data[i]))

        for c in range(nb_customers):
            start_facilities = nb_customers + c * nb_facilities
            end_facilities = start_facilities + nb_facilities

            # Each customer is either contained in a route or assigned to a facility
            facility_used = [m.contains(routes, f) for f in range(start_facilities, end_facilities)]
            delivery_count = m.contains(routes, c) + m.sum(facility_used)
            m.constraint(delivery_count == 1)

        for r in range(nb_trucks):
            route = routes_sequences[r]
            c = m.count(route)

            # Each truck cannot carry more than its capacity
            demand_lambda = m.lambda_function(lambda j: demands[j])
            quantity_served = m.sum(route, demand_lambda)
            m.constraint(quantity_served <= capacity)

            # Distance traveled by each truck: consecutive legs, plus the
            # depot -> first point and last point -> depot legs when the
            # route is non-empty.
            dist_lambda = m.lambda_function(
                lambda i: m.at(dist_matrix, route[i], route[i + 1]))
            dist_routes[r] = m.sum(m.range(0, c - 1), dist_lambda) + m.iif(
                c > 0,
                m.at(dist_depots, route[0])
                + m.at(dist_depots, route[c - 1]),
                0)

            # Cost to assign customers to their facility
            # (assignement_costs_data holds 0 for plain customer points)
            assignment_cost_lambda = m.lambda_function(
                lambda i: assignement_costs[i]
            )
            assignement_cost_routes[r] = m.sum(route, assignment_cost_lambda)

        # The total distance travelled
        total_distance_cost = m.sum(dist_routes)
        # The total assignement cost
        total_assignement_cost = m.sum(assignement_cost_routes)

        # Objective: minimize the sum of the total distance travelled and the total assignement cost
        total_cost = total_distance_cost + total_assignement_cost

        m.minimize(total_cost)

        m.close()

        # Parameterize the optimizer
        optimizer.param.time_limit = int(str_time_limit)

        optimizer.solve()

        if sol_file is not None:
            with open(sol_file, 'w') as file:
                file.write("File name: {}; totalCost = {}; totalDistance = {}; totalAssignementCost = {}\n"
                           .format(instance_file, total_cost.value, total_distance_cost.value, total_assignement_cost.value))
                for r in range(nb_trucks):
                    route = routes_sequences[r].value
                    if len(route) == 0:
                        continue
                    labels = []
                    for point in route:
                        if point < nb_customers:
                            labels.append("Customer {}".format(point))
                        else:
                            # Invert point = nb_customers + customer * nb_facilities + facility.
                            # (The facility index is NOT point % nb_customers.)
                            customer, facility = divmod(point - nb_customers, nb_facilities)
                            labels.append("Facility {} assigned to Customer {}"
                                          .format(facility, customer))
                    file.write("Route {} [{}]\n".format(r, ", ".join(labels)))


def read_input_dat(filename):
    """Parse a ``.dat`` instance and derive every data array the model needs.

    Tokens are consumed in this order:

    1. the customer count;
    2. the facility count;
    3. one (x, y) integer pair per facility;
    4. one (x, y) integer pair per customer;
    5. the truck capacity;
    6. one token per facility, which is read and discarded;
    7. one integer demand per customer.

    Any remaining tokens are ignored. The depot position, the distances and
    the assignment costs are then computed from the coordinates.

    Returns:
        Tuple ``(nb_customers, nb_facilities, truck_capacity,
        customer_demands, depot_distances, distance_matrix,
        assignement_costs)``.
    """
    tokens = iter(read_elem(filename))

    def next_int():
        return int(next(tokens))

    nb_customers = next_int()
    nb_facilities = next_int()

    # Tuple elements are evaluated left to right, so each pair is (x, y).
    facility_xy = [(next_int(), next_int()) for _ in range(nb_facilities)]
    customer_xy = [(next_int(), next_int()) for _ in range(nb_customers)]
    facilities_x = [xy[0] for xy in facility_xy]
    facilities_y = [xy[1] for xy in facility_xy]
    customers_x = [xy[0] for xy in customer_xy]
    customers_y = [xy[1] for xy in customer_xy]

    truck_capacity = next_int()

    # Facility capacities are not used by the model: consume them unparsed.
    for _ in range(nb_facilities):
        next(tokens)

    customer_demands = [next_int() for _ in range(nb_customers)]

    depot_x, depot_y = compute_depot_coordinates(customers_x, customers_y,
                                                 facilities_x, facilities_y)
    depot_distances, distance_matrix = compute_distances(customers_x, customers_y,
                                                         facilities_x, facilities_y,
                                                         depot_x, depot_y)
    assignement_costs = compute_assignment_costs(nb_customers, nb_facilities, distance_matrix)

    return (nb_customers, nb_facilities, truck_capacity, customer_demands,
            depot_distances, distance_matrix, assignement_costs)


def compute_depot_coordinates(customers_x, customers_y, facilities_x, facilities_y):
    """Place the depot at the centre of the bounding box of all points.

    The bounding box covers every customer and every facility. The centre is
    taken with floor division (``low + (high - low) // 2``), so integer
    inputs yield integer depot coordinates.

    Returns:
        Tuple ``(depot_x, depot_y)``.
    """
    def midpoint(first, second):
        low = min(min(first), min(second))
        high = max(max(first), max(second))
        return low + (high - low) // 2

    return midpoint(customers_x, facilities_x), midpoint(customers_y, facilities_y)


def compute_distances(customers_x, customers_y, facilities_x, facilities_y, depot_x, depot_y):
    """Compute depot distances and the full point-to-point distance matrix.

    Points ``0 .. nb_customers - 1`` are customers. Point
    ``nb_customers + c * nb_facilities + f`` is the copy of facility ``f``
    reserved for customer ``c``. All copies of a facility share its
    coordinates, so copies of the same facility are at distance 0 from each
    other.

    Returns:
        Tuple ``(depot_distances, distance_matrix)``:

        - ``depot_distances`` is a list with one entry per point.
        - ``distance_matrix`` is a square list of lists indexed by point.
    """
    n_cust = len(customers_x)
    n_fac = len(facilities_x)
    n_pts = n_cust + n_cust * n_fac

    def copy_index(owner, facility):
        # Point index of the copy of `facility` reserved for customer `owner`.
        return n_cust + owner * n_fac + facility

    # Depot distances: customers first, then every facility copy in index order.
    depot_distances = [compute_dist(customers_x[k], depot_x, customers_y[k], depot_y)
                       for k in range(n_cust)]
    for _owner in range(n_cust):
        depot_distances.extend(
            compute_dist(facilities_x[f], depot_x, facilities_y[f], depot_y)
            for f in range(n_fac))

    matrix = [[None for _ in range(n_pts)] for _ in range(n_pts)]

    # Customer <-> customer
    for a in range(n_cust):
        row = matrix[a]
        for b in range(n_cust):
            row[b] = compute_dist(customers_x[a], customers_x[b],
                                  customers_y[a], customers_y[b])

    # Customer <-> facility copy (symmetric; identical for every copy owner)
    for cust in range(n_cust):
        for fac in range(n_fac):
            d = compute_dist(facilities_x[fac], customers_x[cust],
                             facilities_y[fac], customers_y[cust])
            for owner in range(n_cust):
                k = copy_index(owner, fac)
                matrix[k][cust] = d
                matrix[cust][k] = d

    # Facility copy <-> facility copy (depends only on the two facilities)
    for f1 in range(n_fac):
        for f2 in range(n_fac):
            d = compute_dist(facilities_x[f1], facilities_x[f2],
                             facilities_y[f1], facilities_y[f2])
            for o1 in range(n_cust):
                row = matrix[copy_index(o1, f1)]
                for o2 in range(n_cust):
                    row[copy_index(o2, f2)] = d

    return depot_distances, matrix


def compute_assignment_costs(nb_customers, nb_facilities, distance_matrix):
    """Return the per-point cost of serving a customer through a facility.

    Customer points cost 0. The copy of facility ``f`` reserved for customer
    ``c`` costs ``distance_matrix[c][copy]``, the customer-to-facility
    distance.

    Returns:
        A list with one cost per point
        (``nb_customers + nb_customers * nb_facilities`` entries).
    """
    costs = [0] * nb_customers
    for cust in range(nb_customers):
        base = nb_customers + cust * nb_facilities
        # Appending copies in (customer, facility) order reproduces the
        # point layout nb_customers + cust * nb_facilities + fac.
        costs.extend(distance_matrix[cust][base + fac] for fac in range(nb_facilities))
    return costs


def compute_dist(xi, xj, yi, yj):
    """Return the Euclidean distance between (xi, yi) and (xj, yj), rounded.

    The arguments are ordered as both x coordinates, then both y coordinates.
    Rounding uses Python's built-in ``round`` (round-half-to-even).
    """
    dx = xi - xj
    dy = yi - yj
    return round(math.sqrt(math.pow(dx, 2) + math.pow(dy, 2)))


def read_input(filename):
    """Parse an instance file, dispatching on its extension.

    Only ``.dat`` files are supported; they are parsed by ``read_input_dat``.

    Raises:
        ValueError: if ``filename`` does not end with ``.dat``. ValueError
            subclasses Exception, so existing ``except Exception`` handlers
            still catch it.
    """
    if filename.endswith(".dat"):
        return read_input_dat(filename)
    raise ValueError("Unknown file format: {!r} (expected a .dat file)".format(filename))


if __name__ == '__main__':
    if len(sys.argv) < 2:
        # Implicit concatenation of adjacent literals. A backslash
        # continuation inside the literal would embed the next line's
        # indentation into the printed message.
        print("Usage: python vrptf.py input_file "
              "[output_file] [time_limit]")
        sys.exit(1)

    instance_file = sys.argv[1]
    sol_file = sys.argv[2] if len(sys.argv) > 2 else None
    str_time_limit = sys.argv[3] if len(sys.argv) > 3 else "20"

    main(instance_file, str_time_limit, sol_file)