import hexaly.optimizer
import sys

def read_instance(filename):
    """Parse a batch scheduling instance from a whitespace-separated text file.

    Layout read by this function (the original comment calls it the
    "Taillard" format):
      - line 0: number of tasks, then number of resources
      - line 1: one capacity value per resource
      - line 2 + i: for task i -> type, resource, duration, number of
        successors, followed by the indices of the successor tasks

    Returns a 14-element tuple, in this order:
        nb_tasks, nb_resources, capacity, types, resources, duration,
        nb_successors, successors, nb_tasks_per_resource,
        task_index_in_resource, types_in_resource, tasks_in_resource,
        nb_tasks_per_resource (same list, repeated), time_horizon

    Raises IndexError if a line or field is missing and ValueError if a
    field is not an integer.
    """
    with open(filename, 'r') as f:
        lines = f.readlines()

    # Header: problem dimensions
    header = lines[0].split()
    nb_tasks = int(header[0])
    nb_resources = int(header[1])

    # Capacity of each resource
    capacity = [int(token) for token in lines[1].split()]

    # Per-task data, indexed by task
    types = []
    resources = []
    duration = []
    nb_successors = []
    successors = [[] for _ in range(nb_tasks)]
    task_index_in_resource = []

    # Per-resource data, indexed by resource
    nb_tasks_per_resource = [0] * nb_resources
    types_in_resource = [[] for _ in range(nb_resources)]
    tasks_in_resource = [[] for _ in range(nb_resources)]

    for task in range(nb_tasks):
        # Task rows start right after the two header lines
        fields = lines[task + 2].split()

        task_type = int(fields[0])
        resource = int(fields[1])
        types.append(task_type)
        resources.append(resource)

        # Local index of the task on its resource = number of tasks
        # already seen on that resource
        task_index_in_resource.append(nb_tasks_per_resource[resource])
        # Local index on the resource -> task type
        types_in_resource[resource].append(task_type)
        # Local index on the resource -> global task index
        tasks_in_resource[resource].append(task)
        nb_tasks_per_resource[resource] += 1

        duration.append(int(fields[2]))
        nb_successors.append(int(fields[3]))
        # Every remaining field is a task that must come after this one
        successors[task].extend(int(token) for token in fields[4:])

    # Trivial time horizon: the sum of all task durations
    time_horizon = sum(duration)

    return (
        nb_tasks, nb_resources, capacity, types, resources, duration,
        nb_successors, successors, nb_tasks_per_resource,
        task_index_in_resource, types_in_resource, tasks_in_resource,
        nb_tasks_per_resource, time_horizon,
    )


def main(instance_file, output_file, time_limit):
    """Build, solve and optionally export a batch scheduling model.

    The instance is loaded with read_instance(). Tasks of a resource are
    grouped into batches; each batch holds tasks of a single type, is limited
    by the resource capacity, and all tasks of a batch share the batch time
    interval. The objective is to minimize the makespan.

    Args:
        instance_file: path of the instance file passed to read_instance().
        output_file: path of the solution file to write, or None to skip
            writing the solution.
        time_limit: value assigned to optimizer.param.time_limit.
    """
    nb_tasks, nb_resources, capacity, types, resources, duration, \
            nb_successors, successors, nb_tasks_per_resource, \
              task_index_in_resource, types_in_resource, tasks_in_resource, \
                nb_tasks_per_resource, time_horizon \
                 = read_instance(instance_file)

    with hexaly.optimizer.HexalyOptimizer() as optimizer:
        #
        # Declare the optimization model
        #
        model = optimizer.model

        # For each resource, the contents of each batch of tasks performed.
        # A resource with n tasks gets n candidate batches (some may stay empty).
        batch_content = [[model.set(nb_tasks_per_resource[r])
                          for _ in range(nb_tasks_per_resource[r])]
                         for r in range(nb_resources)]

        # Create HexalyOptimizer arrays in order to be able to access them with "at" operators
        batch_content_arrays = [model.array(batch_content[r]) for r in range(nb_resources)]

        # All tasks are assigned to a batch
        for r in range(nb_resources):
            model.constraint(model.partition(batch_content_arrays[r]))

        # Each batch must consist of tasks with the same type
        types_in_resource_array = model.array(types_in_resource)
        for r in range(nb_resources):
            # NOTE(review): the lambda closes over the loop variable r; this
            # relies on lambda_function evaluating it right away - confirm.
            resource_type_lambda = model.lambda_function(lambda i: types_in_resource_array[r][i])
            for batch in batch_content[r]:
                model.constraint(model.count(model.distinct(batch, resource_type_lambda)) <= 1)

        # Each batch cannot exceed the maximum capacity of the resource
        for r in range(nb_resources):
            for batch in batch_content[r]:
                model.constraint(model.count(batch) <= capacity[r])

        # Interval decisions: time range of each batch of tasks
        batch_interval = [[model.interval(0, time_horizon)
                           for _ in range(nb_tasks_per_resource[r])]
                          for r in range(nb_resources)]
        batch_interval_arrays = [model.array(batch_interval[r]) for r in range(nb_resources)]

        # Non-overlap of batch intervals on the same resource
        for r in range(nb_resources):
            for b in range(1, nb_tasks_per_resource[r]):
                model.constraint(batch_interval[r][b-1] < batch_interval[r][b])

        # Interval of each task: the interval of the batch that contains it
        task_interval = [None for _ in range(nb_tasks)]
        for t in range(nb_tasks):
            # Retrieve the batch index and resource for this task
            r = resources[t]
            b = model.find(batch_content_arrays[r], task_index_in_resource[t])
            # Task interval associated with task t
            task_interval[t] = batch_interval_arrays[r][b]

        # Task durations
        for t in range(nb_tasks):
            model.constraint(model.length(task_interval[t]) == duration[t])

        # Precedence constraints between tasks
        for t in range(nb_tasks):
            for s in successors[t]:
                model.constraint(task_interval[t] < task_interval[s])

        # Makespan: latest end over every batch of every resource.
        # The resource loop must come first: in a comprehension only the first
        # "for" iterable is evaluated in the enclosing scope, so iterating the
        # batch index first would size the range with a stale value of r left
        # over from the loops above instead of the current resource.
        makespan = model.max([model.end(model.at(batch_interval_arrays[r], i))
                              for r in range(nb_resources)
                              for i in range(nb_tasks_per_resource[r])])
        model.minimize(makespan)
        model.close()

        # Parametrize the optimizer
        optimizer.param.time_limit = time_limit
        optimizer.solve()

        #
        # Write the solution in a file with the following format:
        #  - makespan
        #  - machine number
        #  - following lines give, for each task of that machine,
        #    the task index followed by the start and end of its interval
        if output_file is not None:
            with open(output_file, 'w') as f:
                f.write(str(makespan.value) + "\n")
                for r in range(nb_resources):
                    f.write(str(r) + "\n")
                    for b in range(nb_tasks_per_resource[r]):
                        t = tasks_in_resource[r][b]
                        line = str(t) + " " + str(task_interval[t].value.start()) + " " + str(task_interval[t].value.end())
                        f.write(line + "\n")
            print("Solution written in file ", output_file)


if __name__ == "__main__":
    # Command line: instance_file [output_file] [time_limit]
    if len(sys.argv) < 2:
        print(
            "Usage: python batch_scheduling.py instance_file [output_file] [time_limit]")
        sys.exit(1)

    instance_file = sys.argv[1]

    # Optional second argument: where to write the solution
    if len(sys.argv) >= 3:
        output_file = sys.argv[2]
    else:
        output_file = None

    # Optional third argument: time limit, 60 when omitted
    if len(sys.argv) >= 4:
        time_limit = int(sys.argv[3])
    else:
        time_limit = 60

    main(instance_file, output_file, time_limit)