# -*- coding: utf-8 -*-
# Created on Sat Jan 10 08:15:00 2026
# @author: mhebding
# TP5 - IA1



import numpy as np
import matplotlib.pyplot as plt
import random

## 1 - k-moyennes sur un exemple de donnees bidimensionnelles

## Chargement des données ##
with open("PT_TP5b_data.txt") as fichier:
    listeLignes = fichier.readlines()

X = np.empty([600, 2], dtype=float)

for k in range(600):
    X[k][0] = float(listeLignes[k].split(";")[0])
    X[k][1] = float(listeLignes[k].split(";")[1])
    
## Affichage des données ##
plt.plot(X[:, 0], X[:, 1], 'ko')
plt.show()

# 1.
# k = 3 semble pertinent.

# 2.
def distance(d1, d2):
    return(((d2[0]-d1[0])**2 + (d2[1]-d1[1])**2)**(1/2))

# 3.
def initialisation(k):
    return(np.array([random.choice(X) for _ in range(k)]))

# 4.
def plusProcheCentre(d, tabCentres):
    indMin = 0
    for i in range(len(tabCentres)):
        if distance(d, tabCentres[i]) < distance(d, tabCentres[indMin]):
            indMin = i
    return(indMin)

# 5.
def kMoyennes(k):
    tabCentres = initialisation(k)
    Y = np.empty(600, dtype=int)

    while True:
        tabNouveauxCentres = np.zeros([k, 2], dtype = float)
        nbElementsClasses = [0] * k
        
        for i in range(600):
            Y[i] = plusProcheCentre(X[i], tabCentres)
            tabNouveauxCentres[Y[i]] += X[i]
            nbElementsClasses[Y[i]] += 1

        for l in range(k):
            tabNouveauxCentres[l] /= nbElementsClasses[l]
        
        if (tabNouveauxCentres == tabCentres).all():
            return(Y)
        else:
            tabCentres = tabNouveauxCentres

# 6.
def afficher(k):
    Y = kMoyennes(k)
    plt.scatter(X[:, 0], X[:, 1], marker="o", c=Y, s=25, edgecolor="k")
    plt.show()

# afficher(3)



## 2 - Reduction du nombre de couleurs dans une image

from PIL import Image

# Chargement d'une image d'un format classique vers un tableau numpy
imgpil = Image.open("PT_TP5b_parrot.jpg")
img = np.array(imgpil)

# 7.
def nbCouleurs(img):
    dico = {}
    H, L = np.shape(img)[0:2]
    
    for i in range(H):
        for j in range(L):
            if tuple(img[i,j]) not in dico:
                dico[tuple(img[i,j])] = None

    return(len(dico))

# 8.
def dist(pix1, pix2):
    return(((pix2[0]-pix1[0])**2 + (pix2[1]-pix1[1])**2 + (pix2[2]-pix1[2])**2)**(1/2))

def kMoyennes2(img, k, N):
    H, L = np.shape(img)[0:2]
    tabCentres = np.array([random.choice(random.choice(img)) for _ in range(k)], dtype=int)
    Y = np.empty([H, L], dtype=int)

    for _ in range(N):
        tabNouveauxCentres = np.zeros([k, 3], dtype=float)
        nbElementsClasses = [0] * k
        
        for i in range(H):
            for j in range(L):
                pix = np.asarray(img[i,j], dtype=int)
                Y[i,j] = plusProcheCentre(pix, tabCentres)
                tabNouveauxCentres[Y[i,j]] += pix
                nbElementsClasses[Y[i,j]] += 1

        for l in range(k):
            tabNouveauxCentres[l] /= nbElementsClasses[l]
        
        #imgRes = np.copy(img)
        #for i in range(H):
        #    for j in range(L):
        #        imgRes[i,j] = tabCentres[Y[i,j]]
        #plt.imshow(imgRes)
        #plt.show()

    return(tabNouveauxCentres)
            
# 9.
def reduire(img, k, N):
    H, L = np.shape(img)[0:2]
    tabCentres = kMoyennes2(img, k, N)

    imgRes = np.empty([H, L, 3], dtype=np.uint8)
    
    for i in range(H):
        for j in range(L):
            imgRes[i,j] = tabCentres[plusProcheCentre(img[i,j], tabCentres)]

    plt.imshow(imgRes)
    plt.show()  

# reduire(img, 10, 5)