import { indexOf, minBy, sumBy } from 'lodash'
import { ClusterCenterInitializer } from './ClusterCenterInitializer'
import { Geography } from './Geography'
import { IPoint } from './GeographyTypes'
import { buildPoint } from './GeometryBuilder'
import { WeightedKmeansUtils } from './WeightedKmeansUtils'

/** One cluster produced by `WeightedKmeans.cluster`: its center and the points assigned to it. */
interface IClusteredPoints {
  /** Points assigned to this cluster. */
  points: IPoint[]
  /** Center point of the cluster. */
  centroid: IPoint
}

// TODO: remove — NOTE(review): presumably meant to be replaced by a shared heatmap type; confirm before deleting.
/** A location paired with a numeric value; `WeightedKmeans` uses the value as the point's weight. */
interface IHeatmapPoint {
  /** Raw value at this location; not normalized in general. */
  value: number
  /** Location of the sample. */
  point: IPoint
}

/**
 * Weighted k-means clustering: points are assigned to their nearest centroid, and each centroid
 * is recomputed as the mean of its points' coordinates weighted by the points' `value`.
 */
export class WeightedKmeans {
  /** Iteration budget used when the caller does not pass `maxIterations`. */
  private static readonly maxIterations = 10000

  /**
   * Precision argument forwarded to `Geography.arePointsEqual` when testing whether a centroid
   * moved. NOTE(review): exact meaning (digits vs. distance) is defined by `Geography` — confirm.
   */
  private static readonly convergencePrecision = 30

  /**
   * Clusters `data` around centers chosen by `ClusterCenterInitializer.getInitialCenters`.
   *
   * Each iteration replaces every centroid with the weighted mean of the points currently
   * assigned to it, then reassigns every point to its nearest centroid. Iteration stops once no
   * centroid moves (per `Geography.arePointsEqual`) or the iteration budget is exhausted.
   * A cluster with no points (or whose weights sum to 0) keeps its previous centroid.
   *
   * @param data - points to cluster; each point's `value` is used as its weight
   * @param numCentroids - requested number of clusters; `<= 0` returns `[]`
   * @param maxIterations - iteration budget (defaults to 10000); `<= 0` skips refinement and
   *   returns the initial centers together with their nearest points
   * @param seed - forwarded to `ClusterCenterInitializer.getInitialCenters`
   * @returns one entry per centroid, each listing the points assigned to that centroid
   * @throws Error if a point has no nearest centroid (e.g. there are no centroids, or every
   *   distance to the point is NaN)
   */
  public static cluster(
    data: IHeatmapPoint[],
    numCentroids: number,
    maxIterations?: number,
    seed?: string
  ): IClusteredPoints[] {
    if (numCentroids <= 0) {
      return []
    }
    let centroids = ClusterCenterInitializer.getInitialCenters(
      data.map(({ point }) => point),
      numCentroids,
      seed
    )
    // Assign up front so every centroid has a bucket even if the loop below never runs
    // (maxIterations <= 0); otherwise the final mapping would read undefined buckets.
    let clusters = WeightedKmeans.assignToClusters(data, centroids)

    let converged = false
    for (
      let remainingIterations = maxIterations ?? WeightedKmeans.maxIterations;
      !converged && remainingIterations > 0;
      remainingIterations--
    ) {
      converged = true
      const currentClusters = clusters
      centroids = centroids.map((centroid, index) => {
        // An empty or zero-weight cluster has no weighted mean; keep its centroid in place rather
        // than producing NaN, which would never compare equal and so would block convergence.
        const nextCentroid =
          WeightedKmeans.calculateWeightedCentroid(currentClusters[index]) ?? centroid
        if (!Geography.arePointsEqual(centroid, nextCentroid, WeightedKmeans.convergencePrecision)) {
          converged = false
        }
        return nextCentroid
      })
      // Reassign against the updated centroids so the returned points always belong to the
      // returned centroids, including when the iteration budget runs out before convergence.
      clusters = WeightedKmeans.assignToClusters(data, centroids)
    }

    return centroids.map((centroid, index) => ({
      centroid,
      points: clusters[index].map(({ point }) => point),
    }))
  }

  /**
   * Groups `data` by nearest centroid.
   *
   * @returns an array parallel to `centroids`; entry `i` holds the points closest to centroid `i`
   * @throws Error if some point has no nearest centroid (index -1 from `getClosestCentroidIndex`)
   */
  private static assignToClusters(data: IHeatmapPoint[], centroids: IPoint[]): IHeatmapPoint[][] {
    const clusters: IHeatmapPoint[][] = centroids.map(() => [])
    for (const heatmapPoint of data) {
      const closestCentroidIndex = WeightedKmeans.getClosestCentroidIndex(heatmapPoint.point, centroids)
      if (closestCentroidIndex < 0) {
        throw new Error('WeightedKmeans: could not find a nearest centroid for a point')
      }
      clusters[closestCentroidIndex].push(heatmapPoint)
    }
    return clusters
  }

  /**
   * Index of the centroid with the smallest squared Euclidean distance to `point`
   * (`WeightedKmeansUtils.euclideanDistanceSquared`), or -1 when `minBy` finds no minimum
   * (e.g. `centroids` is empty).
   */
  private static getClosestCentroidIndex(point: IPoint, centroids: IPoint[]): number {
    const closestCentroid = minBy(centroids, (centroid) =>
      WeightedKmeansUtils.euclideanDistanceSquared(point, centroid)
    )
    return indexOf(centroids, closestCentroid)
  }

  /**
   * Mean of the points' coordinates (`coordinates[0]` longitude, `coordinates[1]` latitude),
   * weighted by each point's `value`.
   *
   * @returns the weighted centroid, or `undefined` when `points` is empty or the weights sum to
   *   0 (the weighted mean would be a division by zero)
   */
  private static calculateWeightedCentroid(points: IHeatmapPoint[]): IPoint | undefined {
    const totalWeight = sumBy(points, ({ value }) => value)
    if (points.length === 0 || totalWeight === 0) {
      return undefined
    }
    const weightedSumOfLongitudes = sumBy(points, ({ point, value }) => point.coordinates[0] * value)
    const weightedSumOfLatitudes = sumBy(points, ({ point, value }) => point.coordinates[1] * value)

    return buildPoint(weightedSumOfLongitudes / totalWeight, weightedSumOfLatitudes / totalWeight)
  }
}
