import hexaly.optimizer
import sys
import math

def read_elem(filename):
    """Return every whitespace-separated token of *filename*, in file order, as strings."""
    with open(filename) as stream:
        content = stream.read()
    # str.split() with no argument splits on any run of whitespace (spaces,
    # tabs, newlines) and drops empty tokens.
    return content.split()
    

def main(instance_file, str_time_limit, output_file):
    """Build and solve the clustered vehicle routing model, optionally writing the solution.

    The model uses one list decision per cluster (the visiting order of that
    cluster's customers, all of which must be visited) and one list decision
    per truck (the order of the clusters it serves). The truck lists are
    constrained with model.partition (each cluster served by exactly one
    truck). Each truck's summed cluster demand must not exceed the truck
    capacity. The total travelled distance is minimized.

    Args:
        instance_file: path of the instance file, parsed by read_input_cvrp.
        str_time_limit: solver time limit as a string; converted with int(),
            so a non-integer string raises ValueError. Assigned to Hexaly's
            ``time_limit`` parameter (presumably seconds — confirm in the
            Hexaly docs).
        output_file: path of the solution file to write, or None to skip
            writing it.
    """
    # Read instance data (nb_customers is unpacked but not used below)
    nb_customers, nb_trucks, nb_clusters, truck_capacity, dist_matrix_data, dist_depot_data, \
        demands_data, clusters_data = read_input_cvrp(instance_file)


    with hexaly.optimizer.HexalyOptimizer() as optimizer:
        # Declare the optimization model
        model = optimizer.model

        # Create Hexaly arrays to be able to access them with an "at" operator
        demands = model.array(demands_data)
        dist_matrix = model.array(dist_matrix_data)
        clusters = model.array(clusters_data)
        dist_depot = model.array(dist_depot_data)

        # A list is created for each cluster, to determine the order within the cluster.
        # Values in clusters_sequences[k] are positions inside clusters_data[k].
        clusters_sequences = []
        for k in range(nb_clusters):
            clusters_sequences.append(model.list(len(clusters_data[k])))
            # All customers in the cluster must be visited
            model.constraint(model.count(clusters_sequences[k]) == len(clusters_data[k]))

        # Per-cluster expressions, indexed by cluster number:
        # - clustersDistances[k]: distance travelled inside cluster k
        # - initialNodes[k] / endNodes[k]: first / last customer visited in cluster k
        clustersDistances = model.array()
        initialNodes = model.array()
        endNodes = model.array()
        for k in range(nb_clusters):
            sequence = clusters_sequences[k]
            # Number of elements in the cluster's list
            c = model.count(sequence)
            # Distance traveled within cluster k: sum over consecutive positions
            # (i - 1, i) for i in 1..c-1, mapped to customers through clusters[k].
            # NOTE(review): the Python lambda closes over the loop variables k and
            # sequence; this is only correct if model.lambda_function evaluates it
            # while building the model (not lazily later) — confirm in the Hexaly docs.
            clustersDistances_lambda = model.lambda_function(lambda i:
                    model.at(dist_matrix, clusters[k][sequence[i - 1]],
                    clusters[k][sequence[i]]))
            clustersDistances.add_operand(model.sum(model.range(1,c),
                    clustersDistances_lambda))

            # First and last point when visiting cluster k
            initialNodes.add_operand(clusters[k][sequence[0]])
            endNodes.add_operand(clusters[k][sequence[c - 1]])

        # Sequences of clusters visited by each truck (values are cluster indices)
        truckSequences = [model.list(nb_clusters) for _ in range(nb_trucks)]

        # Each cluster must be visited by exactly one truck
        model.constraint(model.partition(truckSequences))

        routeDistances = [None] * nb_trucks
        for k in range(nb_trucks):
            sequence = truckSequences[k]
            c = model.count(sequence)

            # The quantity needed in each route must not exceed the truck capacity
            # (demands are per cluster, and the truck list holds cluster indices)
            demand_lambda = model.lambda_function(lambda j: demands[j])
            route_quantity = model.sum(sequence, demand_lambda)
            model.constraint(route_quantity <= truck_capacity)

            # Distance traveled by each truck
            # = distance in each cluster + distance between clusters + distance with depot at
            # the beginning and at the end of a route.
            # The lambda covers clusters at positions 1..c-1 (their inner distance plus the
            # hop from the previous cluster's last customer); the iif term adds the first
            # cluster's inner distance and the two depot legs, and is 0 for an unused truck.
            routeDistances_lambda = model.lambda_function(lambda i:
                    model.at(clustersDistances, sequence[i]) + model.at(dist_matrix,
                    endNodes[sequence[i - 1]], initialNodes[sequence[i]]))
            routeDistances[k] = model.sum(model.range(1, c), routeDistances_lambda) \
                    + model.iif(c > 0, model.at(clustersDistances, sequence[0])
                    + dist_depot[initialNodes[sequence[0]]]
                    + dist_depot[endNodes[sequence[c - 1]]], 0)

        # Total distance traveled
        total_distance = model.sum(routeDistances)

        # Objective: minimize the distance traveled
        model.minimize(total_distance)

        model.close()

        # Parameterize the optimizer
        optimizer.param.time_limit = int(str_time_limit)

        optimizer.solve()
        # Write the solution in a file with the following format:
        # - total distance
        # - for each truck the customers visited (omitting the start/end at the depot)
        if output_file is not None:
            with open(output_file, 'w') as f:
                f.write("%d\n" % (total_distance.value))
                for k in range(nb_trucks):
                    # Truck lists hold cluster indices and cluster lists hold positions
                    # inside clusters_data[cluster], whose entries are 0-based customer
                    # indices in [0..nbCustomers - 1]. +2 is to put them back
                    # in [2..nbCustomers+1] as in the data files (1 being the depot)
                    for cluster in truckSequences[k].value:
                        for customer in clusters_sequences[cluster].value:
                            f.write("%d " % (clusters_data[cluster][customer] + 2))
                    f.write("\n")

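# The input files follow the "Augerat" CVRP format, extended with the VEHICLES:,
# GVRP_SETS: and GVRP_SET_SECTION fields of clustered (GVRP) instances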
def read_input_cvrp(filename):
    """Parse a clustered CVRP instance file.

    Layout expected:
      * a header (everything before NODE_COORD_SECTION) that must define the
        DIMENSION:, VEHICLES:, GVRP_SETS: and CAPACITY: fields; other header
        tokens are ignored;
      * NODE_COORD_SECTION with "id x y" for ids 1..DIMENSION, node 1 being
        the depot;
      * GVRP_SET_SECTION with, for each cluster id 1..GVRP_SETS, its customer
        ids terminated by -1;
      * DEMAND_SECTION with "id demand" for each cluster id 1..GVRP_SETS.

    Customer ids 2..DIMENSION are re-indexed to 0..DIMENSION-2. Coordinates
    are truncated to integers before the distances are computed.

    On malformed input (missing header field, unexpected index or section
    token, truncated file) a message is printed and the process exits with
    status 1.

    Returns:
        A tuple (nb_customers, nb_trucks, nb_clusters, truck_capacity,
        distance_matrix, distance_depots, demands, clusters_data) where
        demands[k] is the demand of cluster k and clusters_data[k] is the list
        of 0-based customer indices in cluster k.
    """
    file_it = iter(read_elem(filename))

    def fail(message):
        # Parser-wide error convention: message on stdout, then exit status 1.
        print(message)
        sys.exit(1)

    def next_token():
        # Without this, a truncated file would escape as a bare StopIteration.
        try:
            return next(file_it)
        except StopIteration:
            fail("Unexpected end of file")

    def expect_index(expected):
        # Every section entry starts with its 1-based id, which must be sequential.
        if int(next_token()) != expected:
            fail("Unexpected index")

    # Header fields start as None so a missing one is reported instead of
    # surfacing later as an unbound-name error.
    header = dict.fromkeys(("DIMENSION:", "VEHICLES:", "GVRP_SETS:", "CAPACITY:"))
    while True:
        token = next_token()
        if token == "NODE_COORD_SECTION":
            break
        if token in header:
            header[token] = int(next_token())
    for key, value in header.items():
        if value is None:
            fail("Missing header field %s" % key)

    nb_nodes = header["DIMENSION:"]
    nb_customers = nb_nodes - 1
    nb_trucks = header["VEHICLES:"]
    nb_clusters = header["GVRP_SETS:"]
    truck_capacity = header["CAPACITY:"]

    customers_x = [None] * nb_customers
    customers_y = [None] * nb_customers
    depot_x = 0
    depot_y = 0
    for n in range(nb_nodes):
        expect_index(n + 1)
        x = int(float(next_token()))
        y = int(float(next_token()))
        if n == 0:
            # Node 1 is the depot
            depot_x, depot_y = x, y
        else:
            # Node n + 1 is a customer: original ids 2..nbNodes map to 0..nbCustomers-1
            customers_x[n - 1] = x
            customers_y[n - 1] = y

    distance_matrix = compute_distance_matrix(customers_x, customers_y)
    distance_depots = compute_distance_depots(depot_x, depot_y, customers_x, customers_y)

    if next_token() != "GVRP_SET_SECTION":
        fail("Expected token GVRP_SET_SECTION")
    clusters_data = []
    for n in range(nb_clusters):
        expect_index(n + 1)
        cluster = []
        value = int(next_token())
        while value != -1:
            # -2 because original customer indices are in 2..nbNodes
            cluster.append(value - 2)
            value = int(next_token())
        clusters_data.append(cluster)

    if next_token() != "DEMAND_SECTION":
        fail("Expected token DEMAND_SECTION")
    demands = []
    for n in range(nb_clusters):
        expect_index(n + 1)
        demands.append(int(next_token()))

    return nb_customers, nb_trucks, nb_clusters, truck_capacity, distance_matrix, \
        distance_depots, demands, clusters_data

# Compute the distance matrix
def compute_distance_matrix(customers_x, customers_y):
    """Return the square matrix of rounded Euclidean distances between customers.

    Entry [i][j] is compute_dist between customer i (customers_x[i],
    customers_y[i]) and customer j; the diagonal is 0. The result has
    len(customers_x) rows and columns.
    """
    nb_customers = len(customers_x)
    # The diagonal stays 0 (a point's distance to itself).
    distance_matrix = [[0] * nb_customers for _ in range(nb_customers)]
    for i in range(nb_customers):
        # compute_dist is symmetric, so compute each pair once (j > i) and
        # mirror it, instead of computing every pair twice.
        for j in range(i + 1, nb_customers):
            dist = compute_dist(customers_x[i], customers_x[j], customers_y[i], customers_y[j])
            distance_matrix[i][j] = dist
            distance_matrix[j][i] = dist
    return distance_matrix

# Compute the distances to depot
def compute_distance_depots(depot_x, depot_y, customers_x, customers_y):
    """Return, for each customer, its rounded Euclidean distance to the depot.

    Entry i is compute_dist between (depot_x, depot_y) and
    (customers_x[i], customers_y[i]); the list has len(customers_x) entries.
    """
    return [
        compute_dist(depot_x, cust_x, depot_y, customers_y[idx])
        for idx, cust_x in enumerate(customers_x)
    ]


def compute_dist(xi, xj, yi, yj):
    """Return the Euclidean distance between (xi, yi) and (xj, yj) as an int.

    Note the argument order: both x coordinates first, then both y
    coordinates. The distance is rounded half up, via floor(d + 0.5).
    """
    squared_norm = math.pow(xi - xj, 2) + math.pow(yi - yj, 2)
    # math.floor on a float already yields an int.
    return math.floor(math.sqrt(squared_norm) + 0.5)


if __name__ == '__main__':
    # Positional CLI arguments: input_file, then optional output_file and time_limit.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python clustered-vehicle-routing.py input_file [output_file] [time_limit]")
        sys.exit(1)

    instance_file = cli_args[0]
    output_file = cli_args[1] if len(cli_args) > 1 else None
    # The time limit stays a string here; main converts it with int().
    str_time_limit = cli_args[2] if len(cli_args) > 2 else "20"

    main(instance_file, str_time_limit, output_file)



