import { UnexpectedError } from '@sparelabs/error-types'
import { RandomClass } from '@sparelabs/random'
import { minBy, times, uniqWith } from 'lodash'
import { Geography } from './Geography'
import { IPoint } from './GeographyTypes'
import { WeightedDistributionSampler } from './WeightedDistributionSampler'
import { WeightedKmeansUtils } from './WeightedKmeansUtils'

/** Associates a candidate point with its distance to the nearest centroid chosen so far. */
interface IShortestDistanceFromCentroid {
  /** Distance from `point` to its closest centroid (squared Euclidean distance where it is filled in this file). */
  shortestDistance: number
  /** The candidate point that was measured. */
  point: IPoint
}

/**
 * Chooses k initial centroids for the K-means clustering algorithm using the K-means++ seeding strategy.
 * The algorithm goes as follows:
 * 1. Pick an initial centroid uniformly at random from the (de-duplicated) list of points
 * 2. For each point x, compute D(x)², the squared distance from x to the closest centroid already chosen
 * 3. Choose one point x as a new centroid, sampled with probability weighted by D(x)²
 * 4. Repeat steps 2-3 until k centroids have been chosen
 */
export class ClusterCenterInitializer {
  /**
   * Picks up to `numCentroids` well-spread initial centroids from `points`.
   *
   * @param points - Candidate points. Points considered equal by `Geography.arePointsEqual(a, b, 30)` are collapsed
   *   before seeding, and a point already chosen is never chosen again.
   * @param numCentroids - Desired number of centroids (k).
   * @param seed - Optional seed making the selection deterministic. Each sampling step derives its own seed from it.
   * @returns `min(numCentroids, number of unique points)` centroids, or `[]` when `points` is empty or
   *   `numCentroids <= 0`. Fewer centroids are returned only if no remaining point has a positive distance to the
   *   chosen ones (not expected after de-duplication).
   * @throws UnexpectedError if the closest centroid of a point cannot be determined.
   */
  public static getInitialCenters(points: IPoint[], numCentroids: number, seed?: string): IPoint[] {
    if (points.length <= 0 || numCentroids <= 0) {
      return []
    }

    const uniquePoints = uniqWith(points, (cur, next) => Geography.arePointsEqual(cur, next, 30))

    const random = new RandomClass(seed)
    const totalNumCentroids = Math.min(uniquePoints.length, numCentroids)

    const centroids: IPoint[] = [random.chooseValue(uniquePoints)]
    for (const iteration of times(totalNumCentroids - 1)) {
      const distances = ClusterCenterInitializer.getShortestDistancesFromCentroids(uniquePoints, centroids)
      // A distance of 0 means the point already is a centroid; it must never be re-picked, and excluding it also
      // keeps the sampler from receiving an all-zero weight distribution.
      const candidates = distances.filter(({ shortestDistance }) => shortestDistance > 0)
      if (candidates.length === 0) {
        break
      }

      const newCentroidSampler = new WeightedDistributionSampler(
        candidates
          .map(({ point, shortestDistance }) => ({
            weight: shortestDistance,
            point,
          }))
          .sort((pointA, pointB) => pointA.weight - pointB.weight),
        // Passing the very same seed to every sampler would (assuming the sampler seeds its RNG from it) replay the
        // same random draw at every step; derive a distinct, still deterministic, seed per step instead.
        ClusterCenterInitializer.deriveSeed(seed, iteration)
      )

      centroids.push(newCentroidSampler.sample().point)
    }

    return centroids
  }

  /** Pairs every point with its squared Euclidean distance to the closest of the given centroids. */
  private static getShortestDistancesFromCentroids(
    points: IPoint[],
    centroids: IPoint[]
  ): IShortestDistanceFromCentroid[] {
    return points.map((point) => {
      const shortestDistance = minBy(
        centroids.map((centroid) => WeightedKmeansUtils.euclideanDistanceSquared(point, centroid))
      )
      if (shortestDistance === undefined) {
        throw new UnexpectedError('Unable to find closest centroid', { point, centroids })
      }

      return { point, shortestDistance }
    })
  }

  /** Derives a per-step seed from the caller's seed; remains `undefined` (unseeded) when no seed was supplied. */
  private static deriveSeed(seed: string | undefined, iteration: number): string | undefined {
    return seed === undefined ? undefined : `${seed}:${iteration}`
  }
}
