import hexaly.optimizer
import sys


def main(instance_file, output_file, time_limit):
    """Build and solve a multi-mode RCPSP model, then optionally export the solution.

    The instance is loaded with ``read_input``. Every task gets one optional
    interval per execution mode; exactly one mode must be present per task,
    and the objective is to minimize the makespan.

    Args:
        instance_file: Path of the instance file, handed to ``read_input``.
        output_file: Path of the solution file to write, or ``None`` to skip
            the export.
        time_limit: Value assigned to the optimizer's ``time_limit`` parameter.
    """
    nb_tasks, nb_renewable_resources, nb_resources, nb_modes, capacity, duration, weight, \
        nb_successors, successors, horizon = read_input(instance_file)

    with hexaly.optimizer.HexalyOptimizer() as optimizer:
        # Declare the optimization model
        model = optimizer.model

        # Optional interval decisions: time range for each task and mode,
        # each bounded by [0, horizon]
        tasks_in_mode = [
            [model.optional_interval(0, horizon) for _ in range(nb_modes[task])]
            for task in range(nb_tasks)
        ]
        # Presence expression of each (task, mode) interval
        present_mode = [
            [model.presence(tasks_in_mode[task][mode]) for mode in range(nb_modes[task])]
            for task in range(nb_tasks)
        ]

        # Hull: one interval per task, built from the intervals of all its modes
        tasks = [
            model.hull(tasks_in_mode[task][mode] for mode in range(nb_modes[task]))
            for task in range(nb_tasks)
        ]

        # Constraints: Task duration
        # A present mode must last exactly its duration; for an absent mode
        # the iif falls back to the constant 1, so the constraint always holds.
        for task in range(nb_tasks):
            for mode in range(nb_modes[task]):
                model.constraint(model.iif(
                    present_mode[task][mode],
                    model.eq(model.length(tasks_in_mode[task][mode]), duration[task][mode]),
                    1
                ))

        # Constraints: Precedence between tasks
        # (interval `lt` comparison between a task and each of its successors)
        for task in range(nb_tasks):
            for s in range(nb_successors[task]):
                model.constraint(model.lt(tasks[task], tasks[successors[task][s]]))

        # Constraints: Exactly one active mode for each task
        for task in range(nb_tasks):
            model.constraint(model.eq(sum(present_mode[task]), 1))

        # Makespan: end of the last task
        makespan = model.max([model.end(tasks[task]) for task in range(nb_tasks)])

        # Constraints: Renewable resources
        def capacity_respected(resource, time):
            """Return the expression "load on `resource` at `time` <= its capacity"."""
            total_weight = model.sum()
            for task in range(nb_tasks):
                for mode in range(nb_modes[task]):
                    # A (task, mode) interval contributes its weight only
                    # when it contains `time`.
                    total_weight.add_operand(
                        weight[task][resource][mode]
                        * model.contains(tasks_in_mode[task][mode], time)
                    )
            return model.leq(total_weight, capacity[resource])

        for resource in range(nb_renewable_resources):
            # NOTE(review): the lambda closes over the loop variable `resource`;
            # this presumes lambda_function evaluates it before the next
            # iteration -- confirm against the Hexaly API. A default argument
            # is deliberately not added, since it would change the lambda's
            # signature.
            capacity_respected_lambda = model.lambda_function(lambda time: capacity_respected(resource, time))
            # The capacity check is required for every value of range(makespan)
            model.constraint(model.and_(model.range(makespan), capacity_respected_lambda))

        # Constraints: Non-renewable resources
        # The consumption summed over the present modes must not exceed capacity.
        for resource in range(nb_renewable_resources, nb_resources):
            total_weight = model.sum()
            for task in range(nb_tasks):
                for mode in range(nb_modes[task]):
                    total_weight.add_operand(weight[task][resource][mode] * present_mode[task][mode])

            model.constraint(model.leq(total_weight, capacity[resource]))

        # Objective: Minimize the makespan
        model.minimize(makespan)

        model.close()

        # Parameterize the optimizer
        optimizer.param.time_limit = time_limit

        optimizer.solve()

        #
        # Write the solution in a file with the following format:
        # - total makespan
        # - for each task, the task ID, the mode ID, the start and end times
        #
        if output_file is not None:
            with open(output_file, "w") as file:
                file.write(f"{makespan.value}\n")
                for task in range(nb_tasks):
                    # First mode flagged present; -1 (written as mode ID 0)
                    # when no mode of the task is present.
                    active_mode_id = next(
                        (mode for mode in range(nb_modes[task]) if present_mode[task][mode].value),
                        -1,
                    )
                    interval = tasks[task].value
                    # IDs are written 1-based
                    file.write(f"{task + 1} {active_mode_id + 1} {interval.start()} {interval.end()}\n")
            # Announce the export only once the file is written and closed
            print("Solution written in file", output_file)

def read_input(filename):
    """Parse a multi-mode RCPSP instance file.

    The parser relies on fixed positions (0-based line indices): the number of
    tasks is read after the colon on line 5, the renewable / non-renewable
    resource counts on lines 8 and 9, and the precedence table starts on
    line 18. The duration/requirement table starts 5 lines after the last
    precedence row, and the capacities sit 3 lines after the end of that table.
    (This presumably matches the PSPLIB ".mm" layout -- confirm against the
    instances in use.)

    Returns:
        A tuple ``(nb_tasks, nb_renewable_resources, nb_resources, nb_modes,
        capacity, duration, weight, nb_successors, successors, horizon)`` where
        ``duration[task][mode]`` and ``weight[task][resource][mode]`` are
        integers, ``successors`` holds 0-based task indices, and ``horizon`` is
        the sum over tasks of their longest mode duration.
    """
    with open(filename) as file:
        lines = file.readlines()

    def header_value(row):
        # Text located right after the first colon of the given header line
        return lines[row].split(":")[1]

    # Number of tasks and of resources of each kind
    nb_tasks = int(header_value(5))
    nb_renewable_resources = int(header_value(8).split()[0])
    nb_non_renewable_resources = int(header_value(9).split()[0])
    nb_resources = nb_renewable_resources + nb_non_renewable_resources

    # Per-task mode count, successor count and successor list
    nb_modes = [0] * nb_tasks
    nb_successors = [0] * nb_tasks
    successors = [[] for _ in range(nb_tasks)]

    # Precedence table: one row per task, read until the row of the last task ID
    cursor = 18
    while True:
        tokens = lines[cursor].split()
        task_id = int(tokens[0]) - 1
        nb_modes[task_id] = int(tokens[1])
        successor_count = int(tokens[2])
        nb_successors[task_id] = successor_count
        successors[task_id] = [int(tokens[3 + k]) - 1 for k in range(successor_count)]
        if task_id + 1 == nb_tasks:
            break
        cursor += 1

    # Jump from the last precedence row to the first duration/requirement row
    cursor += 5

    # duration[task][mode] and weight[task][resource][mode]
    duration = []
    weight = []
    for task in range(nb_tasks):
        mode_count = nb_modes[task]
        task_durations = []
        task_weights = [[0] * mode_count for _ in range(nb_resources)]
        for mode in range(mode_count):
            tokens = lines[cursor].split()
            # The first row of a task also carries the task ID, which shifts
            # every following column by one.
            duration_column = 2 if mode == 0 else 1
            task_durations.append(int(tokens[duration_column]))
            for resource in range(nb_resources):
                task_weights[resource][mode] = int(tokens[duration_column + 1 + resource])
            cursor += 1
        duration.append(task_durations)
        weight.append(task_weights)

    # Maximum capacity of each resource
    capacity_row = cursor + 3
    capacity = [int(lines[capacity_row].split()[resource]) for resource in range(nb_resources)]

    # Trivial upper bound for the start times of the tasks
    horizon = sum(max(task_durations) for task_durations in duration)

    return (
        nb_tasks,
        nb_renewable_resources,
        nb_resources,
        nb_modes,
        capacity,
        duration,
        weight,
        nb_successors,
        successors,
        horizon,
    )

if __name__ == '__main__':
    # Command line: instance_file [output_file] [time_limit]
    cli_args = sys.argv[1:]
    if not cli_args:
        # The instance file is mandatory
        print("Usage: python rcpsp_multi_mode.py instance_file [output_file] [time_limit]")
        sys.exit(1)

    instance_file = cli_args[0]
    # Optional arguments: no export by default, 60 as default time limit
    output_file = cli_args[1] if len(cli_args) > 1 else None
    time_limit = int(cli_args[2]) if len(cli_args) > 2 else 60
    main(instance_file, output_file, time_limit)
