Regularization and Hyperparameter Tuning using Fashion-MNIST¶
Author: Tianxiang (Adam) Gao
Course: CSC 383 / 483 – Applied Deep Learning
Description:
In this assignment, you will explore how neural networks can overfit training data and how regularization techniques such as dropout and weight decay can improve generalization.
You will also experiment with hyperparameter tuning, adjusting parameters like the learning rate and dropout rate to find combinations that lead to better model performance.
Setup¶
We will first import some useful libraries:
- numpy for numerical operations (e.g., arrays, random sampling).
- keras for loading the Fashion-MNIST dataset and building deep learning models.
- keras.layers provides the building blocks (dense layers, convolutional layers, activation functions, etc.) to design neural networks.
- matplotlib for visualizing images and plotting graphs.
- sklearn.model_selection for splitting a validation set from the training data.
import numpy as np
import keras
from keras import layers, regularizers
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
Prepare the Data [0/0]¶
- Use keras.datasets.fashion_mnist.load_data() to load the Fashion-MNIST training and test sets.
- Normalize all pixel values from integers in the range [0, 255] to floating-point numbers between 0 and 1.
(x_train_full, y_train_full), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
x_train_full = x_train_full.astype("float32") / 255.0
y_train_full = y_train_full.squeeze().astype("int32")
x_test = x_test.astype("float32") / 255.0
y_test = y_test.squeeze().astype("int32")
print("x_train shape:", x_train_full.shape)
num_classes, input_shape = 10, x_train_full.shape[1:]
print("num_classes:", num_classes)
print("input_shape:", input_shape)
x_train shape: (60000, 28, 28)
num_classes: 10
input_shape: (28, 28)
Visualize the Data [0/0]¶
- Randomly select 9 images from the training set x_train.
- Display them in a 3×3 grid using Matplotlib (plt.subplot).
- For each image, show its corresponding class label (from y_train) as the subplot title.
indices = np.random.choice(len(x_train_full), 9, replace=False)
plt.figure(figsize=(6, 6))
for i, idx in enumerate(indices):
    plt.subplot(3, 3, i + 1)
    plt.imshow(x_train_full[idx], cmap="gray")  # grayscale; drop cmap for the default colormap
    plt.title(f"Label: {y_train_full[idx]}")
    plt.axis("off")
plt.tight_layout()
plt.show()
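The integer labels shown above are Fashion-MNIST class indices (0-9). If you prefer human-readable titles, the sketch below maps each label to its garment name using the standard class list from the dataset documentation; the class_names variable is our own addition, not part of the assignment:

# Standard Fashion-MNIST class names, indexed by label (0-9)
class_names = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
plt.figure(figsize=(6, 6))
for i, idx in enumerate(indices):
    plt.subplot(3, 3, i + 1)
    plt.imshow(x_train_full[idx], cmap="gray")
    plt.title(class_names[y_train_full[idx]])  # garment name instead of the raw integer
    plt.axis("off")
plt.tight_layout()
plt.show()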
Validation Set [10/10]¶
- We will use train_test_split() to split off 50% of the images from the training dataset as a validation set, which will be used to help us tune the hyperparameters during training.
# One possible completion (hedged): a 50/50 split; random_state=42 is an
# assumption for reproducibility, not specified by the assignment.
x_train, x_val, y_train, y_val = train_test_split(
    x_train_full, y_train_full, test_size=0.5, random_state=42
)
print("Train subset:", x_train.shape)
print("Validation subset:", x_val.shape)
print("Test set:", x_test.shape)
Train subset: (30000, 28, 28)
Validation subset: (30000, 28, 28)
Test set: (10000, 28, 28)
Build the Model [30/30]¶
Implement a helper function make_model() that returns a simple two-layer MLP built using keras.Sequential with the following layers:
- Input layer: accepts images of shape input_shape.
- Flatten layer: converts each 2D image into a 1D vector.
- Dense layer: fully connected layer with width hidden units and "relu" activation.
- Dropout layer: randomly drops a fraction of hidden activations (dropout_rate) during training to prevent overfitting. Skip this layer if dropout_rate = 0.
- Output layer: fully connected layer with num_classes units (one per class) and "softmax" activation.

Create a base_model using your helper function and inspect it by calling model.summary() to display the network architecture, output shapes, and number of parameters in each layer.

Save the initial_weights of the base_model for reuse in optimizer comparisons.
def make_model(num_classes, input_shape, width=128, dropout_rate=0.0):
    # One possible completion (hedged). When dropout_rate == 0, an identity Lambda
    # stands in for the skipped dropout layer so the layer count stays fixed
    # (matching the summary below); Lambda adds no weights.
    model = keras.Sequential([
        keras.Input(shape=input_shape),
        layers.Flatten(),
        layers.Dense(width, activation="relu"),
        layers.Dropout(dropout_rate) if dropout_rate > 0 else layers.Lambda(lambda t: t),
        layers.Dense(num_classes, activation="softmax"),
    ])
    return model
base_model = make_model(num_classes, input_shape, width=128)
base_model.summary()
initial_weights = base_model.get_weights()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ flatten (Flatten)               │ (None, 784)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense (Dense)                   │ (None, 128)            │       100,480 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ lambda (Lambda)                 │ (None, 128)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 10)             │         1,290 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 101,770 (397.54 KB)
Trainable params: 101,770 (397.54 KB)
Non-trainable params: 0 (0.00 B)
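To see concretely what the Dropout layer in make_model does, here is a minimal sketch (illustrative only, not part of the assignment): Keras dropout is active only when the layer is called with training=True, where it zeroes each activation with probability rate and scales the survivors by 1/(1 - rate) so the expected value matches inference time.

# Minimal dropout demo (illustrative): zeroing plus 1/(1 - rate) rescaling.
demo = layers.Dropout(0.5)
x = np.ones((1, 8), dtype="float32")
print(np.asarray(demo(x, training=False)))  # inference: values pass through unchanged
print(np.asarray(demo(x, training=True)))   # training: random zeros, survivors scaled to 2.0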
Define Optimizer Function [0/0]¶
To easily switch between different optimization algorithms, copy and paste the helper function get_optimizer() from previous assignments. This function should return the corresponding Keras optimizer object based on its name.

Implement the function with the following behavior:
- "sgd" → standard stochastic gradient descent (SGD).
- "momentum" → SGD with momentum (momentum=0.9).
- "rmsprop" → RMSprop optimizer (adaptive learning rate).
- "adam" → Adam optimizer (adaptive learning rate with momentum).
- Raise a ValueError if an unknown name is provided.
def get_optimizer(name, lr=1e-3):
    if name == "sgd":
        return keras.optimizers.SGD(learning_rate=lr)
    if name == "momentum":
        return keras.optimizers.SGD(learning_rate=lr, momentum=0.9)
    if name == "rmsprop":
        return keras.optimizers.RMSprop(learning_rate=lr)
    if name == "adam":
        return keras.optimizers.Adam(learning_rate=lr)
    raise ValueError(f"Unknown optimizer: {name}")
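A quick optional sanity check confirms that each name maps to the expected optimizer class:

# Optional sanity check: print the optimizer class chosen for each name.
for name in ["sgd", "momentum", "rmsprop", "adam"]:
    print(name, "->", type(get_optimizer(name)).__name__)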
Train and Compare Optimizers [20/20]¶
We will also define a helper function train() to train the model using different optimization choices. This function is adapted from previous assignments and includes additional input arguments such as dropout_rate, lr, and width.

The function should:
- Recreate a fresh model via make_model(...).
- Reset to the same initial_weights to ensure all optimizers start from the same point.
- Build the optimizer using get_optimizer(name, lr).
- Compile the model with loss="sparse_categorical_crossentropy" and metrics=["accuracy"].
- Train the model on (x_train, y_train) and evaluate it on the validation data (x_val, y_val).
- Return the training History object for later comparison.
def train(name, dropout_rate=0.0, lr=1e-3, batch_size=64, width=128, epochs=100):
    # Fresh model, reset to the shared starting weights
    # (assumes width matches the saved base_model).
    model = make_model(num_classes, input_shape, width=width, dropout_rate=dropout_rate)
    model.set_weights(initial_weights)
    opt = get_optimizer(name, lr)
    model.compile(loss="sparse_categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
    print(f"\n===Training with {name}===")
    hist = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
                     validation_data=(x_val, y_val))
    return hist
Experiment: Training with SGD [5/5]¶
- Now train your model using stochastic gradient descent (SGD) with a high learning rate (lr=1e-1) and no dropout. This setup will intentionally cause overfitting or even unstable training, allowing you to observe how training and validation accuracy diverge.
hist = train("sgd", dropout_rate=0.0, lr=1e-1)
===Training with sgd===
Epoch 1/100
469/469 ━━━━━━━━━━━━━━━━━━━━ 7s 12ms/step - accuracy: 0.6902 - loss: 0.9003 - val_accuracy: 0.7994 - val_loss: 0.5654
Epoch 2/100
469/469 ━━━━━━━━━━━━━━━━━━━━ 6s 14ms/step - accuracy: 0.8273 - loss: 0.4948 - val_accuracy: 0.8408 - val_loss: 0.4528
[... epochs 3-98 omitted: training accuracy keeps climbing toward 0.98 while validation loss bottoms out near 0.32 and then rises toward 0.45 ...]
Epoch 99/100
469/469 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9779 - loss: 0.0658 - val_accuracy: 0.8898 - val_loss: 0.4308
Epoch 100/100
469/469 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9786 - loss: 0.0663 - val_accuracy: 0.8828 - val_loss: 0.4462
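The divergence is easy to quantify directly from the History object before plotting; a small sketch using the hist returned above:

# Gap between final training and validation accuracy (a quick overfitting check).
final_train = hist.history["accuracy"][-1]
final_val = hist.history["val_accuracy"][-1]
print(f"train acc {final_train:.3f}, val acc {final_val:.3f}, gap {final_train - final_val:.3f}")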
Plot Training and Validation Curves [5/5]¶
Define a helper function plot_history(hist, log_scale=False) to visualize the training and validation loss curves from the History object returned by model.fit().
- Plot both training loss and validation loss on the same graph.
- Add axis labels, a title, and a legend for clarity.
- Use the optional argument log_scale=True to show the loss on a logarithmic scale.
def plot_history(hist, log_scale=False):
    plt.figure(figsize=(8, 5))
    plt.plot(hist.history["loss"], color="blue", linestyle="-", label="train")
    plt.plot(hist.history["val_loss"], color="red", linestyle="--", label="val")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training vs Validation Loss")
    plt.legend()
    plt.grid(True, which="both", ls=":")
    if log_scale:
        plt.yscale("log")
        plt.ylabel("Loss (log scale)")
    plt.show()
# Plot the result
plot_history(hist)
Experiment: Training with SGD + Dropout [5/5]¶
- Now train the model again using SGD with the same learning rate (lr=1e-1) but add dropout = 0.5 to apply regularization. Compare the results with the previous run to see how dropout reduces overfitting and improves validation performance.
hist = train("sgd", dropout_rate=0.5, lr=1e-1)
===Training with sgd===
Epoch 1/100
469/469 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - accuracy: 0.6264 - loss: 1.0564 - val_accuracy: 0.8114 - val_loss: 0.5471
Epoch 2/100
469/469 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.7895 - loss: 0.6035 - val_accuracy: 0.8197 - val_loss: 0.4904
[... epochs 3-98 omitted: with dropout the train/validation gap stays small and validation loss settles around 0.33-0.35 ...]
Epoch 99/100
469/469 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.9064 - loss: 0.2416 - val_accuracy: 0.8865 - val_loss: 0.3458
Epoch 100/100
469/469 ━━━━━━━━━━━━━━━━━━━━ 5s 10ms/step - accuracy: 0.9042 - loss: 0.2392 - val_accuracy: 0.8868 - val_loss: 0.3574
plot_history(hist)
Hyperparameter Search: Learning Rate vs Dropout [25/25]¶
- Now we will perform a simple grid search over two important hyperparameters: the learning rate (lr) and the dropout rate (dropout).

For each combination of learning rate and dropout rate:
- Build a new model using make_model(...).
- Train the model for a fixed number of epochs (e.g., 20) using the SGD optimizer.
- Record both the training accuracy and validation accuracy from the last epoch.

This will help us visualize which combinations lead to underfitting, overfitting, or the best generalization.
# Define search ranges
learning_rates = [1e-4, 1e-3, 1e-2, 1e-1]
dropouts = [0.0, 0.2, 0.5, 0.7]

results = []  # one [lr, dropout, train_acc, val_acc] row per combination

for lr in learning_rates:
    for drop in dropouts:
        print(f"\nTraining with lr={lr}, dropout={drop}")
        model = make_model(num_classes, input_shape, width=128, dropout_rate=drop)
        model.compile(
            optimizer=keras.optimizers.SGD(learning_rate=lr),
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"]
        )
        # Fixed 20-epoch budget; verbose=0 keeps the grid output compact.
        hist = model.fit(x_train, y_train, batch_size=64, epochs=20,
                         validation_data=(x_val, y_val), verbose=0)
        train_acc = hist.history["accuracy"][-1]
        val_acc = hist.history["val_accuracy"][-1]
        print(f"train_acc={train_acc:.3f}, val_acc={val_acc:.3f}")
        results.append([lr, drop, train_acc, val_acc])
Training with lr=0.0001, dropout=0.0
train_acc=0.569, val_acc=0.573

Training with lr=0.0001, dropout=0.2
train_acc=0.501, val_acc=0.569

Training with lr=0.0001, dropout=0.5
train_acc=0.394, val_acc=0.577

Training with lr=0.0001, dropout=0.7
train_acc=0.322, val_acc=0.538

Training with lr=0.001, dropout=0.0
train_acc=0.755, val_acc=0.760

Training with lr=0.001, dropout=0.2
train_acc=0.720, val_acc=0.750

Training with lr=0.001, dropout=0.5
train_acc=0.679, val_acc=0.734

Training with lr=0.001, dropout=0.7
train_acc=0.615, val_acc=0.718

Training with lr=0.01, dropout=0.0
train_acc=0.843, val_acc=0.842

Training with lr=0.01, dropout=0.2
train_acc=0.831, val_acc=0.840

Training with lr=0.01, dropout=0.5
train_acc=0.813, val_acc=0.835

Training with lr=0.01, dropout=0.7
train_acc=0.780, val_acc=0.827

Training with lr=0.1, dropout=0.0
train_acc=0.894, val_acc=0.862

Training with lr=0.1, dropout=0.2
train_acc=0.886, val_acc=0.880

Training with lr=0.1, dropout=0.5
train_acc=0.863, val_acc=0.869

Training with lr=0.1, dropout=0.7
train_acc=0.828, val_acc=0.858
print("\n=== Summary ===")
print(" lr\t dropout\t train_acc\t val_acc")
for r in results:
    print(f"{r[0]:.0e}\t {r[1]:.1f}\t\t {r[2]:.3f}\t\t {r[3]:.3f}")
=== Summary ===
 lr      dropout   train_acc   val_acc
1e-04    0.0       0.569       0.573
1e-04    0.2       0.501       0.569
1e-04    0.5       0.394       0.577
1e-04    0.7       0.322       0.538
1e-03    0.0       0.755       0.760
1e-03    0.2       0.720       0.750
1e-03    0.5       0.679       0.734
1e-03    0.7       0.615       0.718
1e-02    0.0       0.843       0.842
1e-02    0.2       0.831       0.840
1e-02    0.5       0.813       0.835
1e-02    0.7       0.780       0.827
1e-01    0.0       0.894       0.862
1e-01    0.2       0.886       0.880
1e-01    0.5       0.863       0.869
1e-01    0.7       0.828       0.858
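The setup cell also imports keras.regularizers, which the grid above does not exercise. If you want to explore weight decay (the other regularizer named in the assignment description), a hedged sketch of a make_model variant follows; the l2_factor argument is our own addition, not part of the assignment spec:

# Sketch: make_model variant with L2 weight decay on the hidden layer.
# l2_factor is a hypothetical extra hyperparameter, not required by the assignment.
def make_model_l2(num_classes, input_shape, width=128, dropout_rate=0.0, l2_factor=1e-4):
    model = keras.Sequential([
        keras.Input(shape=input_shape),
        layers.Flatten(),
        layers.Dense(width, activation="relu",
                     kernel_regularizer=regularizers.l2(l2_factor)),
        layers.Dropout(dropout_rate) if dropout_rate > 0 else layers.Lambda(lambda t: t),
        layers.Dense(num_classes, activation="softmax"),
    ])
    return model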
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D
# Convert results to NumPy array
results = np.array(results, dtype=float)
lr = results[:, 0]
drop = results[:, 1]
train_acc = results[:, 2]
val_acc = results[:, 3]
# Normalize validation accuracy to [0, 1]
val_acc_norm = (val_acc - val_acc.min()) / (val_acc.max() - val_acc.min() + 1e-8)
# Exaggerate size differences for visibility
sizes = 200 + 4000 * (val_acc_norm ** 3)
plt.figure(figsize=(8, 6))
plt.scatter(
lr, drop,
s=sizes,
color="royalblue",
alpha=0.7,
edgecolors="black",
linewidths=0.5
)
plt.xscale("log")
plt.xlabel("Learning Rate (log scale)")
plt.ylabel("Dropout Rate")
plt.title("Validation Accuracy across Learning Rate and Dropout\n(dot size ∝ validation accuracy)")
plt.grid(True, ls="--", lw=0.5)
# Identify best point
best_idx = np.argmax(val_acc)
best_lr, best_drop, best_val = lr[best_idx], drop[best_idx], val_acc[best_idx]
# Draw dashed guide lines for the best point
plt.axvline(best_lr, color="red", linestyle="--", linewidth=1)
plt.axhline(best_drop, color="red", linestyle="--", linewidth=1)
# Mark the best point with a circle
plt.scatter(best_lr, best_drop, s=2500, facecolors="none", edgecolors="red", linewidths=2)
# Add annotation
plt.text(best_lr, best_drop + 0.05,
f"Best val_acc = {best_val:.3f}\n(lr={best_lr:.0e}, dropout={best_drop:.1f})",
color="red", fontsize=10, ha="center", va="bottom",
bbox=dict(facecolor="white", alpha=0.7, edgecolor="none"))
# Create a custom line legend handle
custom_line = Line2D([0], [0], color="red", lw=1.5, linestyle="--", label="Best configuration")
# Add legend using the custom line
plt.legend(handles=[custom_line])
plt.show()
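As an optional follow-up (a sketch, not required by the assignment), you can retrain the best configuration found above and score it once on the held-out test set:

# Optional sketch: retrain the best (lr, dropout) pair and evaluate on the test set.
final_model = make_model(num_classes, input_shape, width=128, dropout_rate=best_drop)
final_model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=best_lr),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
final_model.fit(x_train, y_train, batch_size=64, epochs=20,
                validation_data=(x_val, y_val), verbose=0)
test_loss, test_acc = final_model.evaluate(x_test, y_test, verbose=0)
print(f"Test accuracy with lr={best_lr:.0e}, dropout={best_drop:.1f}: {test_acc:.3f}")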