#include "optimizer/hexalyoptimizer.h"
#include <exception>
#include <fstream>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>

using namespace hexaly;
using namespace std;

class Kmeans {
public:
    // Data properties
    int nbObservations;
    int nbDimensions;
    int k;

    vector<vector<double>> coordinatesData;

    // Hexaly Optimizer
    HexalyOptimizer optimizer;

    // Decisions
    vector<HxExpression> clusters;

    // Objective
    HxExpression obj;

    Kmeans(int k) : k(k) {}

    // Read instance data
    void readInstance(const string& fileName) {
        ifstream infile;
        infile.exceptions(ifstream::failbit | ifstream::badbit);
        infile.open(fileName.c_str());

        infile >> nbObservations;
        infile >> nbDimensions;

        coordinatesData.resize(nbObservations);
        string tmp;
        for (int o = 0; o < nbObservations; ++o) {
            coordinatesData[o].resize(nbDimensions);
            for (int d = 0; d < nbDimensions; ++d) {
                infile >> coordinatesData[o][d];
            }
            infile >> tmp; // skip initial clusters
        }
    }

    void solve(int limit) {
        // Declare the optimization model
        HxModel model = optimizer.getModel();

        // Set decisions: clusters[c] represents the points in cluster c
        clusters.resize(k);
        for (int c = 0; c < k; ++c) {
            clusters[c] = model.setVar(nbObservations);
        }

        // Each point must be in one cluster and one cluster only
        model.constraint(model.partition(clusters.begin(), clusters.end()));

        // Coordinates of points
        HxExpression coordinates = model.array();
        for (int o = 0; o < nbObservations; ++o) {
            coordinates.addOperand(model.array(coordinatesData[o].begin(), coordinatesData[o].end()));
        }

        // Compute variances
        vector<HxExpression> variances;
        variances.resize(k);
        for (int c = 0; c < k; ++c) {
            HxExpression cluster = clusters[c];
            HxExpression size = model.count(cluster);

            // Compute the centroid of the cluster
            HxExpression centroid = model.array();
            for (int d = 0; d < nbDimensions; ++d) {
                HxExpression coordinateLambda =
                    model.createLambdaFunction([&](HxExpression o) { return model.at(coordinates, o, d); });
                centroid.addOperand(model.iif(size == 0, 0, model.sum(cluster, coordinateLambda) / size));
            }

            // Compute the variance of the cluster
            HxExpression variance = model.sum();
            for (int d = 0; d < nbDimensions; ++d) {
                HxExpression dimensionVarianceLambda = model.createLambdaFunction(
                    [&](HxExpression o) { return model.pow(model.at(coordinates, o, d) - model.at(centroid, d), 2); });
                HxExpression dimensionVariance = model.sum(cluster, dimensionVarianceLambda);
                variance.addOperand(dimensionVariance);
            }
            variances[c] = variance;
        }

        // Minimize the total variance
        obj = model.sum(variances.begin(), variances.end());
        model.minimize(obj);

        model.close();

        // Parametrize the optimizer
        optimizer.getParam().setTimeLimit(limit);

        optimizer.solve();
    }

    /* Write the solution in a file in the following format:
     *  - objective value
     *  - k
     *  - for each cluster, a line with the elements in the cluster (separated by spaces) */
    void writeSolution(const string& fileName) {
        ofstream outfile;
        outfile.exceptions(ofstream::failbit | ofstream::badbit);
        outfile.open(fileName.c_str());

        outfile << obj.getDoubleValue() << endl;
        outfile << k << endl;
        for (int c = 0; c < k; ++c) {
            HxCollection clusterCollection = clusters[c].getCollectionValue();
            for (int i = 0; i < clusterCollection.count(); ++i) {
                outfile << clusterCollection[i] << " ";
            }
            outfile << endl;
        }
    }
};

// Parses a command-line value as a base-10 integer that is at least `minimum`.
// `label` names the argument in error messages.
// Throws std::invalid_argument when the text is empty, is not entirely an
// integer (trailing characters are rejected), does not fit in an int, or is
// smaller than `minimum`.
static int parseIntArgument(const std::string& text, const std::string& label, int minimum) {
    std::size_t consumed = 0;
    int value = 0;
    try {
        value = std::stoi(text, &consumed);
    } catch (const std::logic_error&) {
        // std::stoi reports both "not a number" and "out of range" through
        // exceptions derived from std::logic_error; force the check below to
        // fail so that both end up with the same descriptive message.
        consumed = std::string::npos;
    }
    if (consumed != text.size()) {
        throw std::invalid_argument("Invalid " + label + " '" + text + "': expected an integer");
    }
    if (value < minimum) {
        throw std::invalid_argument("Invalid " + label + " '" + text + "': must be at least " +
                                    std::to_string(minimum));
    }
    return value;
}

// Entry point: kmeans inputFile [outputFile] [timeLimit] [k value]
// The time limit defaults to 5 and k defaults to 2. Returns 0 on success and 1
// on a usage error or when any step throws.
int main(int argc, char** argv) {
    if (argc < 2) {
        std::cerr << "Usage: kmeans inputFile [outputFile] [timeLimit] [k value]" << std::endl;
        return 1;
    }

    const char* instanceFile = argv[1];
    const char* solFile = argc > 2 ? argv[2] : nullptr;
    const char* strTimeLimit = argc > 3 ? argv[3] : "5";
    const char* strK = argc > 4 ? argv[4] : "2";

    try {
        // Parsed inside the try block so that malformed numbers are reported
        // through the same error path as every other failure.
        const int timeLimit = parseIntArgument(strTimeLimit, "time limit", 0);
        const int k = parseIntArgument(strK, "k value", 1);

        Kmeans model(k);
        model.readInstance(instanceFile);
        model.solve(timeLimit);
        if (solFile != nullptr)
            model.writeSolution(solFile);
        return 0;
    } catch (const std::exception& e) {
        std::cerr << "An error occurred: " << e.what() << std::endl;
        return 1;
    }
}
