середу, 2 серпня 2023 р.

Візуалізація фільтрів згорткових шарів

Згорткові нейронні мережі (CNN) використовуються для багатьох завдань машинного навчання, особливо для обробки зображень. Одним з ключових елементів CNN є згорткові шари, які використовують набори ваг (або фільтрів) для сканування вхідного зображення та виявлення різних характеристик, таких як краї, текстури та кольори.

Ми можемо візуалізувати ці фільтри, щоб отримати краще уявлення про те, як CNN "бачить" зображення. В цьому пості ми покажемо, як це зробити за допомогою TensorFlow і Keras.

Спочатку, давайте навчимо модель класифікації на наборі MNIST:

import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

mnist = tf.keras.datasets.mnist.load_data()
(X_train_full, y_train_full), (X_test, y_test) = mnist
X_train_full = X_train_full / 255.
X_test = X_test / 255.
X_train, X_valid = X_train_full[:-5000], X_train_full[-5000:]
y_train, y_valid = y_train_full[:-5000], y_train_full[-5000:]

X_train = X_train[..., np.newaxis]
X_valid = X_valid[..., np.newaxis]
X_test = X_test[..., np.newaxis]

model= tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, kernel_size=13, padding="same",
                           activation="relu", kernel_initializer="he_normal"),
    tf.keras.layers.Conv2D(64, kernel_size=9, padding="same",
                           activation="relu", kernel_initializer="he_normal"),
    tf.keras.layers.MaxPool2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Dense(128, activation="relu",
                          kernel_initializer="he_normal"),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10, activation="softmax")
])

model.compile(loss="sparse_categorical_crossentropy", optimizer="nadam",
              metrics=["accuracy"])

model.fit(X_train, y_train, epochs=4, validation_data=(X_valid, y_valid))
model.evaluate(X_test, y_test)

Після тренування моделі ми можемо отримати ваги першого згорткового шару:

filters, biases = model.layers[0].get_weights()

Це дає нам масив фільтрів, які можна візуалізувати за допомогою Matplotlib:

fig, ax = plt.subplots(nrows=4, ncols=8, figsize=(12, 6))
for i in range(4):
    for j in range(8):
        ax[i][j].imshow(filters[:, :, 0, i*8+j], cmap='cividis')
        ax[i][j].axis('off')
plt.show()

На виході отримаємо 32 візуалізованих фільтри, які показують, як кожен фільтр реагує на вхідні дані. Це може бути корисним інструментом для розуміння того, як CNN працює.

Для виводу фільтрів другого шару використаємо такий код:

filters, biases = model.layers[1].get_weights()
fig, ax = plt.subplots(nrows=8, ncols=8, figsize=(12, 12))
for i in range(8):
    for j in range(8):
        ax[i][j].imshow(filters[:, :, 0, i*8+j], cmap='cividis')
        ax[i][j].axis('off')
plt.show()
Розміри ядер було вибрано великими для кращої візуалізації.

Немає коментарів:

Дописати коментар