import hexaly.optimizer
import sys

def read_elem(filename):
    """Return the whitespace-separated tokens of `filename` as a list of strings."""
    with open(filename) as handle:
        content = handle.read()
    # split() without arguments cuts on any run of whitespace (spaces, tabs,
    # newlines) and already yields str objects.
    return content.split()

#
# Read instance data
#
def read_instance(filename):
    """Parse a clustering instance from the token stream of `filename`.

    The tokens are consumed in this order: the number of observations, the
    number of dimensions, then for every observation its coordinates (parsed
    as floats) followed by one extra token that is read and discarded.

    Returns a tuple (nb_observations, nb_dimensions, coordinates_data) where
    coordinates_data[o][d] is coordinate d of observation o.
    """
    tokens = iter(read_elem(filename))

    # Header: problem sizes
    nb_observations = int(next(tokens))
    nb_dimensions = int(next(tokens))

    coordinates_data = []
    for _obs in range(nb_observations):
        point = [float(next(tokens)) for _dim in range(nb_dimensions)]
        next(tokens)  # skip initial clusters
        coordinates_data.append(point)

    return nb_observations, nb_dimensions, coordinates_data

def main(instance_file, output_file, time_limit, k):
    """Build and solve a clustering model that splits the observations into k sets.

    Parameters:
        instance_file: path of the instance, parsed by read_instance.
        output_file: path where the solution is written, or None to skip
            writing.
        time_limit: value assigned to optimizer.param.time_limit
            (presumably seconds; confirm against the Hexaly documentation).
        k: number of clusters (set variables) in the model.

    The objective minimized is, summed over clusters and dimensions, the sum
    of squared deviations of each point's coordinate from its cluster
    centroid.
    """
    nb_observations, nb_dimensions, coordinates_data = read_instance(instance_file)

    with hexaly.optimizer.HexalyOptimizer() as optimizer:
        #
        # Declare the optimization model
        #
        model = optimizer.model

        # clusters[c] represents the points in cluster c
        clusters = [model.set(nb_observations) for _ in range(k)]

        # Each point must be in one cluster and one cluster only
        model.constraint(model.partition(clusters))

        # Coordinates of points
        coordinates = model.array(coordinates_data)

        # Compute variances
        # NOTE(review): the lambdas below close over the loop variable d (and
        # over centroid[d]). This is only correct if model.lambda_function
        # evaluates the Python callable immediately -- confirm against the
        # Hexaly API documentation before restructuring these loops.
        variances = []
        for cluster in clusters:
            size = model.count(cluster)

            # Compute centroid of cluster
            centroid = [0] * nb_dimensions
            for d in range(nb_dimensions):
                coordinate_lambda = model.lambda_function(
                    lambda i: model.at(coordinates, i, d))
                # An empty cluster gets centroid 0 instead of dividing by zero
                centroid[d] = model.iif(
                    size == 0,
                    0,
                    model.sum(cluster, coordinate_lambda) / size)

            # Compute variance of cluster: one squared-deviation sum per
            # dimension, accumulated as operands of a single sum expression
            variance = model.sum()
            for d in range(nb_dimensions):
                dimension_variance_lambda = model.lambda_function(lambda i:
                    model.pow(model.at(coordinates, i, d) - centroid[d], 2))
                dimension_variance = model.sum(cluster, dimension_variance_lambda)
                variance.add_operand(dimension_variance)
            variances.append(variance)

        # Minimize the total variance
        obj = model.sum(variances)
        model.minimize(obj)

        model.close()

        # Parameterize the optimizer
        optimizer.param.time_limit = time_limit

        optimizer.solve()

        #
        # Write the solution in a file in the following format:
        #  - objective value
        #  - k
        #  - for each cluster, a line with the elements in the cluster
        #    (separated by spaces)
        #
        if output_file is not None:
            with open(output_file, 'w') as f:
                f.write("%f\n" % obj.value)
                f.write("%d\n" % k)
                for cluster in clusters:
                    # Every element is followed by a space (including the
                    # last one), matching the established output format.
                    f.write("".join("%d " % o for o in cluster.value) + "\n")

if __name__ == '__main__':
    argc = len(sys.argv)
    if argc < 2:
        print("Usage: python kmeans.py inputFile [outputFile] [timeLimit] [k value]")
        sys.exit(1)

    # Only the input file is mandatory; the remaining positional arguments
    # fall back to these defaults when omitted.
    instance_file = sys.argv[1]
    output_file = None
    time_limit = 60
    k = 2

    if argc >= 3:
        output_file = sys.argv[2]
    if argc >= 4:
        time_limit = int(sys.argv[3])
    if argc >= 5:
        k = int(sys.argv[4])

    main(instance_file, output_file, time_limit, k)