In [3]:
import matplotlib
from matplotlib import pyplot as plt
%matplotlib inline
matplotlib.rc("image", cmap="gray")

1. Convolution d'une image par un filtre dans une direction

Pour calculer la transformée en ondelettes, on a besoin de filtrer suivant les « colonnes » ou suivant les « lignes » de l'image. On définit la fonction convolve_2d(image, response, axis, reverse=False).

In [4]:
import numpy as np

def convolve_2d(image, response, axis, reverse=False):
    """Provided a discrete signal and a response, computes the convolution.
    Symmetric boundary is assumed.

    Parameters
    ----------
    image : array-like, shape (M, N)
        The image to be filtered.
    response : array-like, shape (P, )
        The impulse response of the filter to apply; note that the filter will
        only be applied in one direction (see the `axis` option).
    axis : int, either 0 or 1
        The image axis along which the filter will be applied.
    reverse : boolean
        If true, the convolution will be computed using 
        `filter[p] = response[-p]`, for p = -P+1, ..., 0

    Return
    ------
    filtered_signal : array-like, shape (M, N)
    """
    M, N = image.shape
    filtered_image = np.zeros((M, N))
    P = response.shape[0]
    if reverse:
        sign = -1
    else:
        sign = 1
    for p in range(P):
        if axis == 0:
            indices = (np.arange(M) - sign * p) % M
            filtered_image += image[indices] * response[p]
        if axis == 1:
            indices = (np.arange(N) - sign * p) % N
            filtered_image += image[:, indices] * response[p]
    return filtered_image

On teste la convolution pour un filtre qui moyenne les P pixels adjacents

In [9]:
from imageio import imread
P = 20
response = np.ones(P) / P

image = imread("boat.tiff")
filtered_v = convolve_2d(image, response, 0)
filtered_h = convolve_2d(image, response, 1)

fig = plt.figure(figsize=(12, 4))
ax1 = fig.add_subplot(1, 3, 1)
plt.imshow(image)
plt.axis("off")
plt.title("Image originale")

ax2 = fig.add_subplot(1, 3, 2)
plt.imshow(filtered_v)
plt.axis("off")
plt.title("Image filtrée selon l'axe vertical")

ax3 = fig.add_subplot(1, 3, 3)
plt.imshow(filtered_h)
plt.axis("off")
plt.title("Image filtrée selon l'axe horizontal");

2, 3. Sous-échantillonnage et sur-échantillonnage d'une image

In [11]:
def subsample_2d(image, axis):
    """
    Returns the image by selecting over two pixels in the direction of `axis`.
    """
    if axis == 0:
        return np.copy(image[::2])
    if axis == 1:
        return np.copy(image[:, ::2])
    raise(ValueError("`axis` should be either 0 or 1."))


def upsample_2d(image, axis):
    """
    Returns the image constructed by inserting a `0` between every 2 pixels 
    along the direction given by `axis`.
    """
    M, N = image.shape
    if axis == 0:
        expanded = np.zeros((2 * M, N))
        expanded[::2] = image
    elif axis == 1:
        expanded = np.zeros((M, 2 * N))
        expanded[:, ::2] = image
    else:
        raise(ValueError("`axis` should be either 0 or 1."))
    return expanded

On teste les fonctions ci-dessus sur une image ; pour bien voir le résultat on se limite à une petite partie de l'image originale.

In [25]:
M, N = image.shape
K, L = M // 16, N // 16
patch = image[M//2 - K//2:M//2 + K//2, N//2 - L//2:N//2 + L//2]

fig = plt.figure(figsize=(15, 3))
ax1 = fig.add_subplot(1, 5, 1)
plt.imshow(patch)
plt.axis("off")
plt.title("Image originale")

ax2 = fig.add_subplot(1, 5, 2)
plt.imshow(subsample_2d(patch, 0))
plt.axis("off")
plt.title("Sous-échantillonnage\n vertical")

ax3 = fig.add_subplot(1, 5, 3)
plt.imshow(subsample_2d(patch, 1))
plt.axis("off")
plt.title("Sous-échantillonnage\n horizontal")

ax4 = fig.add_subplot(1, 5, 4)
plt.imshow(upsample_2d(patch, 0))
plt.axis("off")
plt.title("Sur-échantillonnage\n vertical")

ax5 = fig.add_subplot(1, 5, 5)
plt.imshow(upsample_2d(patch, 1))
plt.axis("off")
plt.title("Sur-échantillonnage\n horizontal");

4, 5 Transformée en ondelettes discrète (directe et inverse)

In [47]:
def dwt_2d(image, h, g, level):
    """
    Computes the discrete wavelet transform, applying filters `h` and `g`
    recursively `level` times.

    Parameters
    ----------
    image : array-like, shape (M, N)
    h : low-pass filter, shape (P, )
    g : high-pass filter, shape (P, )
    level : int
        level of recursion.
    """
    M, N = image.shape
    transform = np.zeros((M, N), dtype=np.double)
    transform[:] = image

    K, L = M, N
    for i in range(level):
        # We define 4 sub-images, `ll`, `lh`, `hl`, `hh`
        # where `l` stands for low-pass and `h` stands for high-pass
        axis = 0
        img_l = subsample_2d(convolve_2d(transform[:K, :L], h, axis, 
                                         reverse=True), axis)
        img_h = subsample_2d(convolve_2d(transform[:K, :L], g, axis, 
                                         reverse=True), axis)
        axis = 1
        img_ll = subsample_2d(convolve_2d(img_l, h, axis, reverse=True), axis)
        img_lh = subsample_2d(convolve_2d(img_l, g, axis, reverse=True), axis)
        img_hl = subsample_2d(convolve_2d(img_h, h, axis, reverse=True), axis)
        img_hh = subsample_2d(convolve_2d(img_h, g, axis, reverse=True), axis)

        transform[        :(K // 2),         :(L // 2)] = img_ll
        transform[(K // 2):K       ,         :(L // 2)] = img_hl
        transform[        :(K // 2), (L // 2):L       ] = img_lh
        transform[(K // 2):K       , (L // 2):L       ] = img_hh

        K, L =  K // 2, L // 2
    return transform


def idwt_2d(transform, h, g, level):
    """
    Computes the inverse discrete wavelet transform.

    Parameters
    ----------
    transform : array-like, shape (M, N)
        Result of applying the DWT to the original image.
    h : low-pass filter, shape (P, )
    g : high-pass filter, shape (P, )
    level : int
        level of recursion.
    """
    M, N = transform.shape
    image = np.zeros((M, N), dtype=np.double)
    image[:] = transform
    for i in range(level - 1, -1, -1):
        K, L = M // 2**i, N // 2**i
        img_ll = image[        :(K // 2),         :(L // 2)]
        img_hl = image[(K // 2):K       ,         :(L // 2)]
        img_lh = image[        :(K // 2), (L // 2):L       ]
        img_hh = image[(K // 2):K       , (L // 2):L       ]

        axis = 1        
        img_ll_u = upsample_2d(img_ll, axis)
        img_hl_u = upsample_2d(img_hl, axis)
        img_lh_u = upsample_2d(img_lh, axis)
        img_hh_u = upsample_2d(img_hh, axis)

        img_h = convolve_2d(img_hl_u, h, axis) + convolve_2d(img_hh_u, g, axis)
        img_l = convolve_2d(img_ll_u, h, axis) + convolve_2d(img_lh_u, g, axis)

        axis = 0
        img_h_u = upsample_2d(img_h, axis)
        img_l_u = upsample_2d(img_l, axis)
        img = convolve_2d(img_h_u, g, axis) + convolve_2d(img_l_u, h, axis)
        image[:K, :L] = img

    return image

On teste la transformée et son inverse sur une image en utilisant l'ondelette de Haar, dont les filtres $h$ et $g$ correspondant sont : $$ h = [\frac{1}{\sqrt{2}}\ \frac{1}{\sqrt{2}}],\qquad g = [\frac{1}{\sqrt{2}}\ -\frac{1}{\sqrt{2}}],\qquad $$

In [42]:
h = np.array([1, 1]) / np.sqrt(2)
g = np.array([1, -1]) / np.sqrt(2)

On applique la transformée en ondelettes avec un niveau de récursion level = 3.

In [82]:
level = 3
image_transformed = dwt_2d(image, h, g, level)
image_reconstructed = idwt_2d(image_transformed, h, g, level)

fig = plt.figure(figsize=(12, 4))
fig.add_subplot(1, 3, 1)
plt.imshow(image)
plt.axis("off")
plt.title("Image originale")

fig.add_subplot(1, 3, 2)
plt.imshow(image_transformed)
plt.axis("off")
plt.title("DWT (après %d itérations)" % level)

fig.add_subplot(1, 3, 3)
plt.imshow(image_reconstructed)
plt.axis("off")
plt.title("iDWT (après %d itérations)" % level);

On vérifie qu'on retrouve bien l'image originale en calculant la différence.

In [83]:
difference = image - image_reconstructed
print("L'erreur la plus grande (en valeur absolue) commise est %.3e" % np.max(np.abs(difference)))
L'erreur la plus grande (en valeur absolue) commise est 4.263e-13

7, 8, 9 Utilisation de la transformée en ondelettes pour le débruitage

On s'aperçoit que la transformée en ondelettes d'une image naturelle contient beaucoup de coefficients nuls (ou presque nuls) ; on peut utiliser cette propriété pour débruiter une image. Regardons la répartition des coefficients et d'une « image » aléatoire constituée uniquement de bruit. On considère un rapport signal à bruit de 3dB.

In [110]:
SNR = 3
sigma = np.sqrt(np.mean(image ** 2) / 10 ** (SNR / 10))

noise = np.random.normal(0, sigma, size=(M, N))
noise_transformed = dwt_2d(noise, h, g, level)

fig = plt.figure(figsize=(8, 4))
fig.add_subplot(1, 2, 1)
plt.imshow(noise)
plt.axis("off")
plt.title("Une image de bruit")

fig.add_subplot(1, 2, 2)
plt.imshow(noise_transformed)
plt.axis("off")
plt.title("Sa transformée en ondelettes");

Pour mieux se rendre compte, on va comparer les histogrammes.

In [112]:
fig = plt.figure(figsize=(16, 4))

ax1 = fig.add_subplot(1, 4, 1)
plt.hist(np.ravel(image), bins=100)
plt.title("Histogramme\n d'une image naturelle")

ax2 = fig.add_subplot(1, 4, 2)
plt.hist(np.ravel(image_transformed), bins=100)
plt.title("Histogramme\n de sa DWT")

ax3 = fig.add_subplot(1, 4, 3)
plt.hist(np.ravel(noise), bins=100)
plt.title("Histogramme\n d'une image de bruit")

ax4 = fig.add_subplot(1, 4, 4)
plt.hist(np.ravel(noise_transformed), bins=100)
plt.title("Histogramme\n de sa DWT");

On peut par ailleurs vérifier sur cet exemple que la transformée ondelettes est linéaire.

In [104]:
# We need to take care of the data type
noisy_image = np.asarray(image, dtype=np.double) + noise
noisy_image_transformed = dwt_2d(noisy_image, h, g, level)
if not np.allclose(noisy_image_transformed, image_transformed + noise_transformed):
    print("DWT(image + noise) != DWT(image) + DWT(noise)")
else:
    print("DWT(image + noise) = DWT(image) + DWT(noise)")
DWT(image + noise) = DWT(image) + DWT(noise)

On propose de filtrer le bruit en seuillant les coefficients de l'image.

In [113]:
threshold = 3 * sigma
noisy_image_transformed_thr = np.copy(noisy_image_transformed)
noisy_image_transformed_thr[np.abs(noisy_image_transformed_thr) < threshold] = 0
noisy_image_filtered = idwt_2d(noisy_image_transformed_thr, h, g, level)

fig = plt.figure(figsize=(12, 4))
ax1 = fig.add_subplot(1, 3, 1)
plt.imshow(image)
plt.axis("off")
plt.title("Image originale")

ax2 = fig.add_subplot(1, 3, 2)
plt.imshow(noisy_image)
plt.axis("off")
plt.title("Image bruitée")

ax3 = fig.add_subplot(1, 3, 3)
plt.imshow(noisy_image_filtered)
plt.axis("off")
plt.title("Image bruitée filtrée");