import hexaly.optimizer
import sys
import math


def read_elem(filename):
    """Return the whitespace-separated tokens of *filename* as a list of strings."""
    with open(filename) as handle:
        content = handle.read()
    # str.split() with no argument splits on any run of whitespace and
    # already yields str objects.
    return content.split()


def main(instance_file, str_time_limit, sol_file):
    """Build and solve the location-routing model, optionally writing a solution file.

    Args:
        instance_file: path of the instance, parsed by ``read_input_lrp``.
        str_time_limit: solver time limit in seconds, given as a string and
            converted with ``int``.
        sol_file: path of the solution file to write, or ``None`` to skip
            writing one.
    """
    #
    # Read instance data
    #
    nb_customers, nb_depots, vehicle_capacity, opening_route_cost, demands_data, \
        capacity_depots, opening_depots_cost, dist_matrix_data, dist_depots_data = \
        read_input_lrp(instance_file)

    # Lower bound on the number of trucks from total demand and vehicle
    # capacity; 1.5x that many candidate routes are created. At least one
    # route is always created so the customer partition below is never
    # empty (e.g. when every demand is zero).
    min_nb_trucks = int(math.ceil(sum(demands_data) / vehicle_capacity))
    nb_trucks = max(1, int(math.ceil(1.5 * min_nb_trucks)))

    with hexaly.optimizer.HexalyOptimizer() as optimizer:
        #
        # Declare the optimization model
        #
        m = optimizer.model

        # A route is represented as a list containing the customers in the order they are
        # visited
        customers_sequences = [m.list(nb_customers) for _ in range(nb_trucks)]
        # All customers should be assigned to a route
        m.constraint(m.partition(customers_sequences))

        # A depot is represented as a set containing the indices of the routes it serves
        depots = [m.set(nb_trucks) for _ in range(nb_depots)]
        # All the routes should be assigned to a depot
        m.constraint(m.partition(depots))
        # Built once and shared by every "find" below instead of being rebuilt per route
        depots_array = m.array(depots)

        route_costs = [None] * nb_trucks
        sequence_used = [None] * nb_trucks
        dist_routes = [None] * nb_trucks
        associated_depot = [None] * nb_trucks

        # Create Hexaly arrays to be able to access them with "at" operators
        demands = m.array(demands_data)
        dist_matrix = m.array()
        dist_depots = m.array()
        quantity_served = m.array()
        for i in range(nb_customers):
            dist_matrix.add_operand(m.array(dist_matrix_data[i]))
            dist_depots.add_operand(m.array(dist_depots_data[i]))

        for r in range(nb_trucks):
            sequence = customers_sequences[r]
            c = m.count(sequence)

            # A route is used if it serves at least one customer
            sequence_used[r] = c > 0
            # The "find" function gets the depot whose set contains route r
            associated_depot[r] = m.find(depots_array, r)

            # The quantity needed in each route must not exceed the vehicle capacity
            demand_lambda = m.lambda_function(lambda j: demands[j])
            quantity_served.add_operand(m.sum(sequence, demand_lambda))
            m.constraint(quantity_served[r] <= vehicle_capacity)

            # Distance traveled by each truck: legs between consecutive
            # customers, plus the legs from and back to the depot when used
            dist_lambda = m.lambda_function(
                lambda i: m.at(dist_matrix, sequence[i], sequence[i + 1]))
            depot = associated_depot[r]
            dist_routes[r] = m.sum(m.range(0, c - 1), dist_lambda) + m.iif(
                sequence_used[r],
                m.at(dist_depots, sequence[0], depot)
                + m.at(dist_depots, sequence[c - 1], depot),
                0)

            # The route cost is the sum of the opening cost and the route length
            route_costs[r] = sequence_used[r] * opening_route_cost + dist_routes[r]

        depot_cost = [None] * nb_depots
        for d in range(nb_depots):
            # A depot is open if at least one route starts from there
            depot_cost[d] = (m.count(depots[d]) > 0) * opening_depots_cost[d]

            # The total demand served by a depot must not exceed its capacity
            depot_lambda = m.lambda_function(lambda r: quantity_served[r])
            depot_quantity = m.sum(depots[d], depot_lambda)
            m.constraint(depot_quantity <= capacity_depots[d])

        depots_cost = m.sum(depot_cost)
        routing_cost = m.sum(route_costs)
        total_cost = routing_cost + depots_cost

        m.minimize(total_cost)

        m.close()

        # Parameterize the optimizer
        optimizer.param.time_limit = int(str_time_limit)

        optimizer.solve()

        if sol_file is not None:
            with open(sol_file, 'w') as file:
                # "%s" prints an integer total exactly as "%d" would, and no
                # longer truncates the total when costs are real-valued.
                file.write("File name: %s; totalCost = %s \n" % (instance_file, total_cost.value))
                for r in range(nb_trucks):
                    if sequence_used[r].value:
                        file.write("Route %d, assigned to depot %d: " % (r, associated_depot[r].value))
                        for customer in customers_sequences[r].value:
                            file.write("%d " % customer)
                        file.write("\n")


def read_input_lrp_dat(filename):
    """Parse a location-routing instance from a whitespace-separated ``.dat`` file.

    Tokens are consumed in this order: number of customers, number of depots,
    depot coordinates (x, y per depot), customer coordinates (x, y per
    customer), vehicle capacity, depot capacities, customer demands, depot
    opening costs, route opening cost, and an ``are_cost_double`` flag.
    When the flag is 1 the opening costs are kept as floats; otherwise they
    are rounded to integers. The flag is also passed on to the distance
    computations.

    Returns:
        Tuple ``(nb_customers, nb_depots, vehicle_capacity, opening_route_cost,
        demands, capacity_depots, opening_depots_cost, distance_customers,
        distance_customers_depots)``.
    """
    file_it = iter(read_elem(filename))

    nb_customers = int(next(file_it))
    nb_depots = int(next(file_it))

    # Coordinates are stored as interleaved (x, y) pairs
    x_depot = [None] * nb_depots
    y_depot = [None] * nb_depots
    for i in range(nb_depots):
        x_depot[i] = int(next(file_it))
        y_depot[i] = int(next(file_it))

    x_customer = [None] * nb_customers
    y_customer = [None] * nb_customers
    for i in range(nb_customers):
        x_customer[i] = int(next(file_it))
        y_customer[i] = int(next(file_it))

    vehicle_capacity = int(next(file_it))
    capacity_depots = [int(next(file_it)) for _ in range(nb_depots)]
    demands = [int(next(file_it)) for _ in range(nb_customers)]

    temp_opening_cost_depot = [float(next(file_it)) for _ in range(nb_depots)]
    # Parsed as float like the depot opening costs: the value is either kept
    # as a real or rounded below, and int() would reject a token like "1000.5".
    temp_opening_route_cost = float(next(file_it))
    are_cost_double = int(next(file_it))

    if are_cost_double == 1:
        opening_depots_cost = temp_opening_cost_depot
        opening_route_cost = temp_opening_route_cost
    else:
        # NOTE: Python's round() rounds halves to the nearest even integer.
        opening_route_cost = round(temp_opening_route_cost)
        opening_depots_cost = [round(cost) for cost in temp_opening_cost_depot]

    distance_customers = compute_distance_matrix(x_customer, y_customer, are_cost_double)
    distance_customers_depots = compute_distance_depot(x_customer, y_customer,
                                                       x_depot, y_depot, are_cost_double)

    return nb_customers, nb_depots, vehicle_capacity, opening_route_cost, demands, \
        capacity_depots, opening_depots_cost, distance_customers, distance_customers_depots

def compute_distance_matrix(customers_x, customers_y, are_cost_double):
    """Return the symmetric customer-to-customer distance matrix.

    Entry ``[i][j]`` is ``compute_dist`` between customers i and j. Each
    unordered pair (including the diagonal) is computed once and mirrored.
    The original loop evaluated every pair twice and overwrote its own
    diagonal initialization.
    """
    nb_customers = len(customers_x)
    dist_customers = [[None] * nb_customers for _ in range(nb_customers)]
    for i in range(nb_customers):
        # Start at j == i so the diagonal is still produced by compute_dist
        # (0 when are_cost_double == 0, 0.0 otherwise), keeping the element
        # types identical to the previous implementation.
        for j in range(i, nb_customers):
            dist = compute_dist(customers_x[i], customers_x[j],
                                customers_y[i], customers_y[j], are_cost_double)
            dist_customers[i][j] = dist
            dist_customers[j][i] = dist
    return dist_customers

def compute_distance_depot(customers_x, customers_y, depot_x, depot_y, are_cost_double):
    """Build the customer-to-depot distance table.

    Returns one row per customer (``len(customers_x)`` rows). Row c has one
    entry per depot (``len(depot_x)`` columns), holding ``compute_dist``
    between customer c and depot d.
    """
    depot_indices = range(len(depot_x))
    return [
        [compute_dist(customers_x[c], depot_x[d], customers_y[c], depot_y[d], are_cost_double)
         for d in depot_indices]
        for c in range(len(customers_x))
    ]


def compute_dist(xi, xj, yi, yj, are_cost_double):
    """Return the Euclidean distance between points (xi, yi) and (xj, yj).

    When ``are_cost_double`` is 0, the distance is multiplied by 100 and
    rounded up to an int. For any other flag value, the raw float distance
    is returned.
    """
    dx_sq = math.pow(xi - xj, 2)
    dy_sq = math.pow(yi - yj, 2)
    euclidean = math.sqrt(dx_sq + dy_sq)
    return euclidean if are_cost_double != 0 else math.ceil(100 * euclidean)


def read_input_lrp(filename):
    """Read a location-routing instance, dispatching on the file extension.

    Only ``.dat`` files are supported; they are parsed by
    ``read_input_lrp_dat``, whose tuple is returned unchanged.

    Raises:
        ValueError: if *filename* does not end with ``.dat``. ValueError is a
            subclass of Exception, so existing ``except Exception`` handlers
            still catch it.
    """
    if filename.endswith(".dat"):
        return read_input_lrp_dat(filename)
    raise ValueError("Unknown file format: %s" % filename)


if __name__ == '__main__':
    if len(sys.argv) < 2:
        # Adjacent literals are joined at compile time. A backslash
        # continuation inside the literal would instead embed the next
        # line's indentation spaces into the printed message.
        print("Usage: python location_routing_problem.py input_file "
              "[output_file] [time_limit]")
        sys.exit(1)

    instance_file = sys.argv[1]
    sol_file = sys.argv[2] if len(sys.argv) > 2 else None
    # Default solver time limit, in seconds
    str_time_limit = sys.argv[3] if len(sys.argv) > 3 else "20"

    main(instance_file, str_time_limit, sol_file)
