import re
import sys

import hexaly.optimizer


def main(instance_file, output_file, str_time_limit):
    """Build, solve and optionally report the CARP model for one instance.

    Args:
        instance_file: path of the DIMACS-format instance, read by CarpInstance.
        output_file: path of the solution report; when empty or None nothing
            is written.
        str_time_limit: value of Hexaly's ``time_limit`` parameter, given as
            a string and converted with ``int``.
    """
    data = CarpInstance(instance_file)
    # Every required edge exists in two orientations (indices 2e and 2e + 1).
    nb_oriented_edges = 2 * data.nb_required_edges

    with hexaly.optimizer.HexalyOptimizer() as solver:
        m = solver.model

        # One list variable per truck: the oriented edges it services, in order.
        truck_lists = [m.list(nb_oriented_edges) for _ in range(data.nb_trucks)]
        all_lists = m.array(truck_lists)

        # Constant arrays, so that data can be indexed by decision expressions.
        demand_arr = m.array(data.demands_data)
        cost_arr = m.array(data.costs_data)
        from_depot_arr = m.array(data.dist_from_depot_data)
        to_depot_arr = m.array(data.dist_to_depot_data)
        between_arr = m.array()
        for row in data.edges_dist_data:
            between_arr.add_operand(m.array(row))

        # An oriented edge can appear in at most one truck's list.
        m.constraint(m.disjoint(all_lists))

        # Exactly one of the two orientations of each required edge is serviced,
        # so its demand is satisfied once.
        for e in range(data.nb_required_edges):
            m.constraint(
                m.contains(all_lists, 2 * e)
                + m.contains(all_lists, 2 * e + 1)
                == 1)

        truck_distances = []
        for seq in truck_lists:
            size = m.count(seq)

            # Capacity: the summed demand of the serviced edges is bounded.
            load = m.sum(seq, m.lambda_function(lambda edge: demand_arr[edge]))
            m.constraint(load <= data.truck_capacity)

            # For positions 1..size-1: servicing cost of the edge plus the
            # shortest path from the previous edge's end to this edge's start.
            leg = m.lambda_function(
                lambda pos:
                    cost_arr[seq[pos]]
                    + m.at(between_arr, seq[pos - 1], seq[pos]))
            # A non-empty route also pays for its first edge and for the
            # trips from the depot and back to it.
            distance = m.sum(m.range(1, size), leg) + m.iif(
                size > 0,
                cost_arr[seq[0]] + from_depot_arr[seq[0]]
                + to_depot_arr[seq[size - 1]],
                0)
            truck_distances.append(distance)

        # Objective: minimize the total distance travelled by all trucks.
        total_distance = m.sum(truck_distances)
        m.minimize(total_distance)

        m.close()

        solver.param.time_limit = int(str_time_limit)
        solver.solve()

        if not output_file:
            return

        # Report: objective value, number of trucks, then for each truck the
        # (origin, destination) pairs of the edges it services.
        with open(output_file, 'w') as out:
            out.write(f"Objective function value : {total_distance.value}\n"
                      f"Number of routes : {data.nb_trucks}\n")
            for truck, seq_var in enumerate(truck_lists, start=1):
                edges_text = "".join(
                    f"({data.origins_data[e]}, {data.destinations_data[e]})  "
                    for e in seq_var.value)
                out.write(f"Sequence of truck {truck}: {edges_text}\n")


class CarpInstance:
    """Capacitated Arc Routing Problem instance read from a DIMACS-format file.

    Every required edge ``e`` of the file (0-based, in file order) is stored
    twice: index ``2 * e`` is the edge as written (first node -> second node)
    and index ``2 * e + 1`` is the same edge travelled in reverse. All the
    per-edge lists below use that indexing.

    Attributes:
        nb_required_edges: number of required edges (from the header).
        nb_trucks: number of trucks (from the header).
        truck_capacity: capacity of every truck (from the header).
        demands_data: demand of each oriented required edge.
        costs_data: traversal cost of each oriented required edge.
        origins_data: 1-based start node of each oriented required edge.
        destinations_data: 1-based end node of each oriented required edge.
        edges_dist_data: ``edges_dist_data[i][j]`` is the shortest-path
            distance from the end of edge ``i`` to the start of edge ``j``
            (0 when they are the same node).
        dist_to_depot_data: shortest-path distance from the end of each edge
            to the depot.
        dist_from_depot_data: shortest-path distance from the depot to the
            start of each edge.

    Nodes that cannot be reached keep a distance of ``sys.maxsize``.
    """

    def read_elem(self, filename):
        """Return the lines of *filename*, after stripping the whole file's outer whitespace."""
        with open(filename) as f:
            return f.read().strip().split("\n")

    @staticmethod
    def _parse_ints(line):
        """Return, in order, every integer written in *line*.

        Edge lines are expected to look like ``( 1, 2)   <label> 13   <label> 1``.
        Extracting the integers directly does not depend on the exact spacing
        between the fields.
        """
        return [int(token) for token in re.findall(r"-?\d+", line)]

    @staticmethod
    def _add_neighbor(node_neighbors, a, b, cost):
        """Record the undirected edge a-b (1-based nodes) in the cost matrix.

        0 in the matrix means "no edge". When the same pair appears several
        times, the cheapest cost is kept, because only shortest paths are
        derived from this matrix.
        """
        current = node_neighbors[a - 1][b - 1]
        if current == 0 or cost < current:
            node_neighbors[a - 1][b - 1] = cost
            node_neighbors[b - 1][a - 1] = cost

    @staticmethod
    def _required_positions(nb_required_nodes, required_nodes):
        """Map each of the first *nb_required_nodes* required nodes to its index.

        On duplicates the last index wins.
        """
        return {required_nodes[k]: k for k in range(nb_required_nodes)}

    # The input files follow the format of the DIMACS challenge
    def __init__(self, filename):
        lines = iter(self.read_elem(filename))

        def skip(count):
            # Consume lines whose content is not used.
            for _ in range(count):
                next(lines)

        def header_value():
            # Header lines have the form "KEY : value"; return the value.
            return int(next(lines).strip().split(":")[1])

        skip(2)
        nb_nodes = header_value()
        self.nb_required_edges = header_value()
        nb_not_required_edges = header_value()
        self.nb_trucks = header_value()
        self.truck_capacity = header_value()
        skip(3)

        self.demands_data = []
        self.costs_data = []
        self.origins_data = []
        self.destinations_data = []
        # Distinct end nodes of required edges, in first-seen order.
        required_nodes = []
        seen_required = set()
        # Symmetric edge-cost matrix; 0 means "no edge".
        node_neighbors = [[0] * nb_nodes for _ in range(nb_nodes)]

        for _ in range(self.nb_required_edges):
            first, second, cost, demand = self._parse_ints(next(lines))[:4]
            # Both orientations share the same cost and demand.
            self.costs_data += [cost, cost]
            self.demands_data += [demand, demand]
            # Even indices store direct edges, odd indices store reverse edges.
            self.origins_data += [first, second]
            self.destinations_data += [second, first]
            for node in (first, second):
                if node not in seen_required:
                    seen_required.add(node)
                    required_nodes.append(node)
            self._add_neighbor(node_neighbors, first, second, cost)

        if nb_not_required_edges > 0:
            skip(1)  # section header line of the non-required edges
            for _ in range(nb_not_required_edges):
                first, second, cost = self._parse_ints(next(lines))[:3]
                self._add_neighbor(node_neighbors, first, second, cost)

        depot_node = header_value()

        # Shortest paths from every required node to all the nodes.
        nb_required_nodes = len(required_nodes)
        required_distances = [
            self.shortest_path_finder(node, nb_nodes, node_neighbors)
            for node in required_nodes]

        # Since edges can be explored in both directions, every oriented edge
        # has its own index in the distance tables below.
        self.find_distance_between_edges(
            nb_required_nodes, required_nodes, required_distances)
        self.find_distance_to_depot(
            nb_required_nodes, depot_node, required_nodes, required_distances)
        self.find_distance_from_depot(
            nb_required_nodes, nb_nodes, depot_node,
            required_nodes, required_distances, node_neighbors)

    def min_distance(self, nb_nodes, shortest_path, sptSet):
        """Return the index of the closest node not yet settled.

        Returns None when every unsettled node is still at ``sys.maxsize``,
        i.e. unreachable. The previous code raised UnboundLocalError in that
        case. Ties go to the lowest index.
        """
        best_index = None
        best_distance = sys.maxsize
        for i in range(nb_nodes):
            if not sptSet[i] and shortest_path[i] < best_distance:
                best_distance = shortest_path[i]
                best_index = i
        return best_index

    def shortest_path_finder(self, origin, nb_nodes, node_neighbors):
        """Dijkstra's algorithm from the 1-based node *origin* over a cost matrix.

        Args:
            origin: 1-based start node.
            nb_nodes: number of nodes in *node_neighbors*.
            node_neighbors: square cost matrix (0-based), where 0 means "no edge".

        Returns:
            A 0-based list of distances from *origin*; unreachable nodes
            stay at ``sys.maxsize``.
        """
        shortest_path = [sys.maxsize] * nb_nodes
        shortest_path[origin - 1] = 0
        sptSet = [False] * nb_nodes
        for _ in range(nb_nodes):
            current_node = self.min_distance(nb_nodes, shortest_path, sptSet)
            if current_node is None:
                # Remaining nodes are unreachable (disconnected graph).
                break
            sptSet[current_node] = True
            current_neighbors = node_neighbors[current_node]
            for neighbor in range(nb_nodes):
                distance = current_neighbors[neighbor]
                if distance == 0 or sptSet[neighbor]:
                    continue
                candidate = shortest_path[current_node] + distance
                if candidate < shortest_path[neighbor]:
                    shortest_path[neighbor] = candidate
        return shortest_path

    def find_distance_between_edges(self, nb_required_nodes, required_nodes, required_distances):
        """Fill edges_dist_data: distance from the end of edge i to the start of edge j."""
        nb_edges = 2 * self.nb_required_edges
        position = self._required_positions(nb_required_nodes, required_nodes)
        self.edges_dist_data = [[None] * nb_edges for _ in range(nb_edges)]
        for i in range(nb_edges):
            destination = self.destinations_data[i]
            k = position.get(destination)
            row = self.edges_dist_data[i]
            for j in range(nb_edges):
                origin = self.origins_data[j]
                if destination == origin:
                    row[j] = 0
                elif k is not None:
                    row[j] = required_distances[k][origin - 1]

    def find_distance_to_depot(
            self, nb_required_nodes, depot_node, required_nodes, required_distances):
        """Fill dist_to_depot_data: distance from the end of each edge to the depot."""
        nb_edges = 2 * self.nb_required_edges
        position = self._required_positions(nb_required_nodes, required_nodes)
        self.dist_to_depot_data = [None] * nb_edges
        for i in range(nb_edges):
            destination = self.destinations_data[i]
            if destination == depot_node:
                self.dist_to_depot_data[i] = 0
                continue
            k = position.get(destination)
            if k is not None:
                self.dist_to_depot_data[i] = required_distances[k][depot_node - 1]

    def find_distance_from_depot(self, nb_required_nodes, nb_nodes, depot_node,
                                 required_nodes, required_distances, node_neighbors):
        """Fill dist_from_depot_data: distance from the depot to the start of each edge.

        When the depot is a required node, its distances are already in
        *required_distances*. Otherwise Dijkstra is run from the depot once,
        on first need, and reused for every edge.
        """
        nb_edges = 2 * self.nb_required_edges
        position = self._required_positions(nb_required_nodes, required_nodes)
        depot_position = position.get(depot_node)
        from_depot = (required_distances[depot_position]
                      if depot_position is not None else None)
        self.dist_from_depot_data = [None] * nb_edges
        for i in range(nb_edges):
            origin = self.origins_data[i]
            if origin == depot_node:
                self.dist_from_depot_data[i] = 0
                continue
            if from_depot is None:
                from_depot = self.shortest_path_finder(
                    depot_node, nb_nodes, node_neighbors)
            self.dist_from_depot_data[i] = from_depot[origin - 1]


if __name__ == '__main__':
    # Command line: input_file [output_file] [time_limit]; time limit defaults to "20".
    args = sys.argv[1:]
    if not args:
        print("Usage: python capacitated_arc_routing.py input_file [output_file] [time_limit]")
        sys.exit(1)

    instance_file = args[0]
    output_file = args[1] if len(args) > 1 else None
    str_time_limit = args[2] if len(args) > 2 else "20"

    main(instance_file, output_file, str_time_limit)
