# Cluster1.py

from gpanel import *
import random
import math

# Three clusters, identified by the ids 0, 1 and 2.
# Every sample is stored as a mutable list [x, y, id], where id is the
# index of the centroid the sample is currently assigned to (-1 before
# the first iteration has run).

N, spread = 200, 5  # samples per cluster, std deviation of the gaussian noise

def euklidian(pt1, pt2):
    """Return the Euclidean distance between two 2-D points.

    Only the first two items (x, y) of each point are read, so samples of
    the form [x, y, id] can be passed in directly.
    """
    dx = pt1[0] - pt2[0]
    dy = pt1[1] - pt2[1]
    return math.sqrt(dx * dx + dy * dy)

def onMousePressed(x, y):
    """Mouse handler for the panel.

    The first three clicks place the initial centroids at the clicked
    positions. Every later click runs one k-means iteration and reports
    the iteration count in the title bar.
    """
    global nbClicks
    nbClicks += 1

    if nbClicks > 3:
        # Drawing happens in XOR mode (set at startup), so painting the
        # centroids before iterate() wipes the old ones, and painting them
        # again afterwards shows the updated positions.
        drawCentroids()
        iterate()
        drawCentroids()
        title("# Iterations: " + str(nbClicks - 3) + ". - Click for next")
        return

    # one of the first three clicks: add a new centroid at the mouse position
    newCentroid = [x, y]
    pos(newCentroid)
    fillCircle(4)
    centroids.append(newCentroid)
    if nbClicks == 3:
        title("Click to start iteration.")

def drawCentroids():
    """Paint every current centroid as a filled circle of radius 4.

    While XOR mode is active, calling this twice with unchanged centroids
    restores the previous picture.
    """
    for center in centroids:
        pos(center)
        fillCircle(4)
                    
def iterate():
    """Perform one k-means (Lloyd) step on the global samples X and centroids.

    1. Assignment: the id field (pt[2]) of every sample [x, y, id] is set to
       the index of its nearest centroid by Euclidean distance. On ties the
       lowest index wins.
    2. Update: every centroid is moved to the mean of the samples now
       assigned to it. A centroid that received no samples keeps its
       previous position instead of causing a division by zero.

    Works for any number of centroids. X and centroids are modified in
    place; returns None.
    """
    nbCentroids = len(centroids)

    # assignment step: nearest centroid for each sample
    for pt in X:
        distances = [euklidian(pt, center) for center in centroids]
        pt[2] = distances.index(min(distances))  # first minimum on ties

    # update step: per-cluster coordinate sums and sample counts
    # (float start values keep the means float even under Python 2 division)
    xSums = [0.0] * nbCentroids
    ySums = [0.0] * nbCentroids
    nSums = [0] * nbCentroids
    for pt in X:
        k = pt[2]
        xSums[k] += pt[0]
        ySums[k] += pt[1]
        nSums[k] += 1

    for k in range(nbCentroids):
        if nSums[k] == 0:
            # empty cluster (e.g. a centroid clicked far away from all data):
            # leave it where it is rather than dividing by zero
            continue
        centroids[k][0] = xSums[k] / nSums[k]  # mean of x
        centroids[k][1] = ySums[k] / nSums[k]  # mean of y
        
makeGPanel(-10, 110, -10, 110, mousePressed=onMousePressed)
drawGrid(0, 100, 0, 100, "gray")
X = []  # samples [x, y, id]; id stays -1 until the first iteration
centroids = []  # current centroids [x, y], placed by the first three clicks
nbClicks = 0  # mouse clicks received so far

# true cluster centres, marked with small red dots
z = [[30, 40], [50, 70], [70, 50]]
setColor("red")
for center in z:
    pos(center)
    fillCircle(1)
setColor("black")

# N gaussian-distributed samples around each true centre, drawn in black
for cx, cy in z:
    for _ in range(N):
        sx = cx + random.gauss(0, spread)
        sy = cy + random.gauss(0, spread)
        pos(sx, sy)
        fillCircle(0.5)
        X.append([sx, sy, -1])

# centroids are drawn in XOR mode so drawCentroids() can erase them by repainting
setColor("green")
setXORMode("white")
title("Click To Set 3 Centroids")
keep()
