#include "optimizer/hexalyoptimizer.h"
#include <algorithm>
#include <fstream>
#include <iostream>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

using namespace hexaly;
using namespace std;

class BatchScheduling {
private:
    // Hexaly Optimizer
    HexalyOptimizer optimizer;
    // Decision variables: collection of sets (representing batches)
    vector<vector<HxExpression>> batchContent;
    // Decision variables: intervals of batches on a resource
    vector<vector<HxExpression>> batchInterval;
    // Decision variables: interval of a task
    vector<HxExpression> taskInterval;
    // Objective = minimize the makespan: end of the last task of the last job
    HxExpression makespan;

public:
    BatchScheduling() {}

    // Number of tasks
    int nbTasks;
    // Number of resources
    int nbResources;
    // Capacity of each resource
    vector<int> capacity;

    // Number of tasks assigned to each resource
    vector<int> nbTasksPerResource;
    // Types of tasks assigned to each resource
    vector<vector<int>> typesInResources;
    // Tasks assigned to each resource
    vector<vector<int>> tasksInResource;
    // Index of task i on resource[i]
    vector<int> taskIndexInResource;

    // Type of task i
    vector<int> type;
    // Resource required for task i
    vector<int> resource;
    // Duration of task i
    vector<int> duration;
    // Number of tasks that must succeed task i
    vector<int> nbSuccessors;
    // Task ids that must succeed task i
    vector<vector<int>> successors;

    // Longest possible time horizon
    int timeHorizon{0};

    // Read input file
    void readInstance(const string& fileName) {
        ifstream infile;
        infile.exceptions(ifstream::failbit | ifstream::badbit);
        infile.open(fileName.c_str());

        // first line has number of tasks, number of resources
        infile >> nbTasks;
        infile >> nbResources;

        // second line has the capacity of each resource
        capacity.resize(nbResources);
        for (int i = 0; i < nbResources; i++) {
            infile >> capacity[i];
        }
        // initalize
        nbTasksPerResource.resize(nbResources);
        for (int j = 0; j < nbResources; j++) {
            typesInResources.push_back(vector<int>());
            tasksInResource.push_back(vector<int>());
            nbTasksPerResource[j] = 0;
        }
        type.resize(nbTasks);
        resource.resize(nbTasks);
        duration.resize(nbTasks);
        nbSuccessors.resize(nbTasks);
        taskIndexInResource.resize(nbTasks);

        // Task information: [type, machine, duration, nbSuccessors, [successors]]
        for (int i = 0; i < nbTasks; i++) {
            infile >> type[i];
            infile >> resource[i];
            infile >> duration[i];
            infile >> nbSuccessors[i];
            // collect which tasks that must succeed task i
            successors.push_back(vector<int>());
            for (int j = 0; j < nbSuccessors[i]; j++) {
                successors[i].push_back(j);
                infile >> successors[i][j];
            }

            // Index of task i on resource[i]
            taskIndexInResource[i] = nbTasksPerResource[resource[i]];
            // Map from name of task i on resource[i] to task i type
            typesInResources[resource[i]].push_back(type[i]);
            // Map from name of task i on resource[i] to task i
            tasksInResource[resource[i]].push_back(i);
            // Incremenet number of tasks required by this resource
            nbTasksPerResource[resource[i]] += 1;

            // Add task time to the overall trivial time horizon
            timeHorizon += duration[i];
        }
        
        infile.close();
    }

    void solve(int timeLimit) {
        // Declare the optimization model
        HxModel model = optimizer.getModel();

        // For each resource, the contents of each batch of tasks performed
        batchContent.resize(nbResources);
        for (unsigned int r = 0; r < nbResources; r++) {
            batchContent[r].resize(nbTasksPerResource[r]);
            for (unsigned int b = 0; b < nbTasksPerResource[r]; b++) {
                batchContent[r][b] = model.setVar(nbTasksPerResource[r]);
            }
        }

        // Create HexalyOptimizer arrays in order to be able to access them with "at" operators
        vector<HxExpression> batchContentArray = vector<HxExpression>(nbResources);
        for (unsigned int r = 0; r < nbResources; r++) {
            batchContentArray[r] = model.array();
            for (unsigned int b = 0; b < nbTasksPerResource[r]; b++) {
                batchContentArray[r].addOperand(batchContent[r][b]);
            }
        }

        // All tasks are assigned to a batch
        for (unsigned int r = 0; r < nbResources; r++) {
            model.constraint(model.partition(batchContentArray[r]));
        }

        // Create HexalyOptimizer arrays in order to be able to access them with "at" operators
        vector<HxExpression> typesInResourcesArray = vector<HxExpression>(nbResources);
        for (unsigned int r = 0; r < nbResources; r++) {
            typesInResourcesArray[r] = model.array();
            for (unsigned int i = 0; i < nbTasksPerResource[r]; i++) {
                typesInResourcesArray[r].addOperand(typesInResources[r][i]);
            }
        }

        // Each batch must consist of tasks with the same type
        for (unsigned int r = 0; r < nbResources; r++) {
            HxExpression resourceTypeLambda = model.createLambdaFunction([&](HxExpression i) {
                return typesInResourcesArray[r][i];
            });
            for (unsigned int b = 0; b < nbTasksPerResource[r]; b++) {
                HxExpression batch = batchContent[r][b];
                model.constraint(model.count(model.distinct(batch, resourceTypeLambda)) <= 1);
            }
        }

        // Each batch cannot exceed the maximum capacity of the resource
        for (unsigned int r = 0; r < nbResources; r++) {
            for (unsigned int b = 0; b < nbTasksPerResource[r]; b++) {
                model.constraint(model.count(batchContent[r][b]) <= capacity[r]);
            }           
        }

        // Interval decisions: time range of each batch of tasks
        batchInterval.resize(nbResources);
        for (unsigned int r = 0; r < nbResources; r++) {
            batchInterval[r].resize(nbTasksPerResource[r]);
            for (unsigned int b = 0; b < nbTasksPerResource[r]; b++) {
                batchInterval[r][b] = model.intervalVar(0, timeHorizon);
            }
        }

        // Create HexalyOptimizer arrays in order to be able to access them with "at" operators
        vector<HxExpression> batchIntervalArray = vector<HxExpression>(nbResources);
        for (unsigned int r = 0; r < nbResources; r++) {
            batchIntervalArray[r] = model.array();
            for (unsigned int b = 0; b < nbTasksPerResource[r]; b++) {
                batchIntervalArray[r].addOperand(batchInterval[r][b]);
            }
        }

        // Non-overlap of batch intervals on the same resource
        for (unsigned int r = 0; r < nbResources; r++) {
            for (unsigned int b = 1; b < nbTasksPerResource[r]; b++) {
                model.constraint(batchInterval[r][b-1] < batchInterval[r][b]);
            }
        }

        // Interval decisions: time range of each task
        taskInterval.resize(nbTasks);
        for (unsigned int t = 0; t < nbTasks; t++) {
            // Retrieve the batch index and resource for this task
            int r = resource[t];
            HxExpression b = model.find(batchContentArray[r], taskIndexInResource[t]);
            // Task interval associated with task t
            taskInterval[t] = batchIntervalArray[r][b];
        }

        // Task durations
        for (unsigned int t = 0; t < nbTasks; t++) {
            model.constraint(model.length(taskInterval[t]) == duration[t]);
        }

        // Precedence constraints between tasks
        for (unsigned int t = 0; t < nbTasks; t++) {
            for (unsigned int s = 0; s < nbSuccessors[t]; s++) {
                model.constraint(taskInterval[t] < taskInterval[successors[t][s]]);
            }
        }

        // Makespan: end of the last task
        makespan = model.max();
        for (unsigned int t = 0; t < nbTasks; t++) {
            makespan.addOperand(model.end(taskInterval[t]));
        }
        model.minimize(makespan);
        model.close();

        // Parameterize the optimizer
        optimizer.getParam().setTimeLimit(timeLimit);
        optimizer.solve();
    }

    /*
     Write the solution in a file with the following format:
      - makespan
      - machine number
      - preceeding lines are the ordered intervals of the tasks 
        for the corresponding machine number
    */
    void writeSolution(const string& fileName) {
        
        ofstream outfile;
        outfile.exceptions(ofstream::failbit | ofstream::badbit);
        outfile.open(fileName.c_str());
        
        outfile << makespan.getValue() << endl;

        for (unsigned int r = 0; r < nbResources; r++) {
            outfile << r << endl;
            for (unsigned int b = 0; b < nbTasksPerResource[r]; b++) {
                int t = tasksInResource[r][b];
                int start = taskInterval[t].getIntervalValue().getStart();
                int end = taskInterval[t].getIntervalValue().getEnd();
                outfile << t << " " << start << " " << end << endl;
            }
        }
        cout << "Solution written in file " << fileName << endl;
        outfile.close();
    }
    
};

int main(int argc, char** argv) {
    if (argc < 2) {
        cout << "Usage: batch_scheduling instanceFile [outputFile] [timeLimit]" << endl;
        exit(1);
    }

    const char* instanceFile = argv[1];
    const char* outputFile = argc > 2 ? argv[2] : NULL;
    const char* strTimeLimit = argc > 3 ? argv[3] : "60";

    BatchScheduling model;
    try {
        model.readInstance(instanceFile);
        const int timeLimit = atoi(strTimeLimit);
        model.solve(timeLimit);
        if (outputFile != NULL)
            model.writeSolution(outputFile);
        return 0;
    } catch (const exception& e) {
        cerr << "An error occurred: " << e.what() << endl;
        return 1;
    }
}