# Cluster2.py

from gpanel import *
import math

# Number of clusters k; cluster ids run from 0 to nbClusters - 1.
nbClusters = 4
# Text file with one comma-separated "x,y" sample per line.
# Each loaded sample is kept in memory as [x, y, cluster_id].
datafile = "samples.dat"

def euklidian(pt1, pt2):
    """Return the Euclidean distance between the points pt1 and pt2.

    Only the first two items (x, y) of each point are used, so samples of
    the form [x, y, cluster_id] can be passed directly.
    """
    # math.hypot avoids the intermediate overflow of squaring large
    # differences by hand and is at least as accurate.
    return math.hypot(pt1[0] - pt2[0], pt1[1] - pt2[1])

def loadData(fileName):
    """Read comma-separated samples from the text file fileName.

    Each non-blank line becomes the list of its comma-separated fields,
    kept as strings, e.g. "12.5,40" -> ["12.5", "40"].
    Returns [] when the file cannot be opened.
    """
    try:
        fData = open(fileName, 'r')
    except (IOError, OSError):  # missing/unreadable file: no samples
        return []
    out = []
    with fData:  # closes the file even if reading fails
        for line in fData:
            # strip() instead of line[:-1]: the last line may have no
            # trailing newline (line[:-1] would cut off a real character),
            # and surrounding whitespace such as a stray '\r' is dropped.
            line = line.strip()
            if not line:  # empty or whitespace-only line
                continue
            out.append(line.split(","))
    return out

def onMousePressed(x, y):
    """Mouse-press callback: place the initial centroids, then iterate.

    Each of the first nbClusters clicks adds a centroid at the clicked
    point (x, y). Every later click erases the drawn centroids, performs
    one iterate() step and draws the moved centroids again.
    """
    global nbClicks
    nbClicks += 1
    if nbClicks <= nbClusters:
        centroid = [x, y]
        pos(centroid)
        fillCircle(2)
        centroids.append(centroid)
        if nbClicks == nbClusters:
            title("Left click to start iteration.")
    else:
        # Drawing is in XOR mode, so drawing the old centroids again
        # erases them before the updated ones are drawn.
        drawCentroids()
        iterate()
        drawCentroids()
        # The first iteration happens on click nbClusters + 1, so this
        # counts iterations from 1 regardless of the cluster count.
        nbIterations = nbClicks - nbClusters
        title("# Iterations: " + str(nbIterations) + ". - Click for next.")

def drawCentroids():
    """Draw each current centroid as a filled circle of radius 2."""
    for center in centroids:
        pos(center)
        fillCircle(2)

def iterate():
    """Perform one k-means step on the global samples X and centroids.

    1. Assignment: each sample's cluster id pt[2] is set to the index of
       its nearest centroid. On ties, the lowest cluster id wins.
    2. Update: each centroid moves to the mean position of its samples.
       A centroid that received no samples stays where it is. Without
       this check, the update would divide by zero and raise inside the
       mouse callback.
    """
    # Step 1: assign every sample to its nearest centroid.
    for pt in X:
        distances = [euklidian(pt, centroids[k]) for k in range(nbClusters)]
        pt[2] = distances.index(min(distances))

    # Step 2: accumulate coordinate sums per cluster. After step 1, pt[2]
    # is a valid index into these lists.
    xSums = [0.0] * nbClusters
    ySums = [0.0] * nbClusters
    nSums = [0] * nbClusters
    for pt in X:
        k = pt[2]
        xSums[k] += pt[0]
        ySums[k] += pt[1]
        nSums[k] += 1

    # Move each non-empty cluster's centroid to the mean of its samples.
    for k in range(nbClusters):
        if nSums[k] > 0:
            centroids[k][0] = xSums[k] / nSums[k]  # mean of x
            centroids[k][1] = ySums[k] / nSums[k]  # mean of y

# Shared state used by the mouse callback. It is defined before the panel
# registers onMousePressed, so these names exist before any click is handled.
X = []          # samples as [x, y, cluster_id]
centroids = []  # centroids as [x, y]
nbClicks = 0    # mouse clicks received so far

makeGPanel(-10, 110, -10, 110, mousePressed = onMousePressed)
drawGrid(0, 100, 0, 100, "gray")
text(98, -8, "Age")
text(-8, 105, "Income (in 1000)")
data = loadData(datafile)
for sample in data:
    # pt = [x-coord, y-coord, cluster_id]; cluster_id -1 means "unassigned"
    pt = [float(sample[0]), float(sample[1]), -1]
    X.append(pt)
    pos(pt)
    fillCircle(0.5)
setColor("green")
# XOR drawing lets the callback erase a centroid by drawing it again.
setXORMode("white")
# Derive the prompt from nbClusters so it always matches the clicks needed.
title("Click to set " + str(nbClusters) + " centroids")
keep()
