LR Finder Helper
# Uses the LRFinder Python module (another version uses Lightning).
# The import below assumes the torch-lr-finder package.
import torch
import torch.nn as nn
from torch_lr_finder import LRFinder

def find_optimal_lr(model, train_loader, criterion=None, optimizer=None, device='cuda'):
    # If no criterion provided, use default CrossEntropyLoss
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    # If no optimizer provided, use Adam
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-7, weight_decay=1e-2)
    # Initialize LR Finder
    lr_finder = LRFinder(model, optimizer, criterion, device=device)
    # Run LR range test
    lr_finder.range_test(
        train_loader,
        start_lr=1e-7,  # Very small starting learning rate
        end_lr=10,      # Large ending learning rate
        num_iter=100,   # Number of iterations to test
        smooth_f=0.05,  # Smoothing factor for the loss
    )
    # Plot the learning rate vs loss. With suggest_lr=True, plot() also returns
    # the suggested learning rate when one can be computed.
    result = lr_finder.plot(log_lr=True, suggest_lr=True)
    suggested_lr = result[1] if isinstance(result, tuple) else None
    # Restore the model and optimizer to their initial state (reset() returns None)
    lr_finder.reset()
    print(f"Suggested Learning Rate: {suggested_lr}")
    return suggested_lr
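To make the range test less of a black box, here is a rough, dependency-free sketch of the idea behind it: sweep the learning rate on a log scale, smooth the recorded losses, and pick the learning rate where the loss falls fastest. This is not the library's implementation. The helper names (`exp_lr_schedule`, `smooth`, `steepest_descent_lr`) and the synthetic `toy_loss` curve are made up for illustration, and the slope is estimated with simple forward differences.

```python
import math

def exp_lr_schedule(start_lr, end_lr, num_iter):
    """Learning rates spaced evenly on a log scale from start_lr to end_lr."""
    if num_iter < 2:
        raise ValueError("num_iter must be at least 2")
    ratio = end_lr / start_lr
    return [start_lr * ratio ** (i / (num_iter - 1)) for i in range(num_iter)]

def smooth(losses, smooth_f):
    """Exponential smoothing: mix each new loss with the previous smoothed value."""
    out = []
    for loss in losses:
        if out:
            loss = smooth_f * loss + (1 - smooth_f) * out[-1]
        out.append(loss)
    return out

def steepest_descent_lr(lrs, losses):
    """Return the learning rate at which the loss drops fastest (forward difference)."""
    slopes = [losses[i + 1] - losses[i] for i in range(len(losses) - 1)]
    idx = min(range(len(slopes)), key=slopes.__getitem__)
    return lrs[idx]

def toy_loss(lr):
    """Synthetic loss curve (hypothetical): flat, then falling, then diverging."""
    x = math.log10(lr)
    return 2.3 - 1.5 / (1 + math.exp(-3 * (x + 3))) + math.exp(2 * (x + 0.5))

# Same sweep settings as the helper above: 1e-7 to 10 over 100 iterations.
lrs = exp_lr_schedule(1e-7, 10, 100)
raw = [toy_loss(lr) for lr in lrs]
suggested = steepest_descent_lr(lrs, smooth(raw, smooth_f=0.05))
print(f"Suggested learning rate: {suggested:.2e}")
```

With this smoothing formula, a small `smooth_f` such as 0.05 weights the previous smoothed value heavily, so the smoothed curve lags the raw losses. Treat the suggested value as a starting point and check it against the plot.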
[21:54:44] INFO - Init ImageDataModule for fashion_mnist
[21:54:46] INFO - loading dataset fashion_mnist with args () from split train
[21:54:46] INFO - loading dataset fashion_mnist from split train
[21:54:48] INFO - Overwrite dataset info from restored data version if exists.
[21:54:48] INFO - Loading Dataset info from ../data/image/fashion_mnist/fashion_mnist/0.0.0/531be5e2ccc9dba0c201ad3ae567a4f3d16ecdd2
[21:54:48] INFO - Found cached dataset fashion_mnist (/user/s/slegroux/Projects/nimrod/nbs/../data/image/fashion_mnist/fashion_mnist/0.0.0/531be5e2ccc9dba0c201ad3ae567a4f3d16ecdd2)
[21:54:48] INFO - Loading Dataset info from /user/s/slegroux/Projects/nimrod/nbs/../data/image/fashion_mnist/fashion_mnist/0.0.0/531be5e2ccc9dba0c201ad3ae567a4f3d16ecdd2
[21:54:52] INFO - loading dataset fashion_mnist with args () from split test
[21:54:52] INFO - loading dataset fashion_mnist from split test
[21:54:53] INFO - Overwrite dataset info from restored data version if exists.
[21:54:53] INFO - Loading Dataset info from ../data/image/fashion_mnist/fashion_mnist/0.0.0/531be5e2ccc9dba0c201ad3ae567a4f3d16ecdd2
[21:54:53] INFO - Found cached dataset fashion_mnist (/user/s/slegroux/Projects/nimrod/nbs/../data/image/fashion_mnist/fashion_mnist/0.0.0/531be5e2ccc9dba0c201ad3ae567a4f3d16ecdd2)
[21:54:53] INFO - Loading Dataset info from /user/s/slegroux/Projects/nimrod/nbs/../data/image/fashion_mnist/fashion_mnist/0.0.0/531be5e2ccc9dba0c201ad3ae567a4f3d16ecdd2
[21:54:53] INFO - split train into train/val [0.8, 0.2]
[21:54:53] INFO - train: 48000 val: 12000, test: 10000
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[21:54:54] INFO - ConvNetX: init
[21:54:54] INFO - Classifier: init
/user/s/slegroux/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'nnet' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['nnet'])`.
/user/s/slegroux/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /user/s/slegroux/Projects/nimrod/nbs/checkpoints/FASHION-MNIST-Classifier/ConvNetX-bs:128-epochs:1 exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[21:54:54] INFO - Optimizer: <class 'torch.optim.adamw.AdamW'>
[21:54:54] INFO - Scheduler: <class 'torch.optim.lr_scheduler.OneCycleLR'>
  | Name         | Type               | Params | Mode
------------------------------------------------------------
0 | nnet         | ConvNet            | 110 K  | train
1 | loss         | CrossEntropyLoss   | 0      | train
2 | train_acc    | MulticlassAccuracy | 0      | train
3 | val_acc      | MulticlassAccuracy | 0      | train
4 | test_acc     | MulticlassAccuracy | 0      | train
5 | train_loss   | MeanMetric         | 0      | train
6 | val_loss     | MeanMetric         | 0      | train
7 | test_loss    | MeanMetric         | 0      | train
8 | val_acc_best | MaxMetric          | 0      | train
------------------------------------------------------------
110 K Trainable params
0 Non-trainable params
110 K Total params
0.440 Total estimated model params size (MB)
41 Modules in train mode
0 Modules in eval mode
/user/s/slegroux/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
/user/s/slegroux/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=1` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/user/s/slegroux/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
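The log above reports AdamW paired with OneCycleLR. As a minimal sketch of how a learning rate from the range test could feed that pair, the snippet below uses it as `max_lr` and records the schedule over one epoch. The `nn.Linear` stand-in model and the `suggested_lr` value of 1e-3 are placeholders, not results from this run. The step count assumes 48000 training samples at batch size 128, as the log and checkpoint directory name suggest.

```python
import torch
from torch import nn

model = nn.Linear(4, 2)   # stand-in model (hypothetical), not the ConvNet from the log
suggested_lr = 1e-3       # placeholder value, not an output of this run
total_steps = 48000 // 128  # 375 optimizer steps for one epoch (assumed batch size 128)

optimizer = torch.optim.AdamW(model.parameters(), lr=suggested_lr)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=suggested_lr,      # peak of the cycle
    total_steps=total_steps,  # the scheduler is stepped once per batch
)

lrs = []
for _ in range(total_steps):
    lrs.append(scheduler.get_last_lr()[0])  # learning rate used for this step
    optimizer.step()                        # no gradients here, so parameters are untouched
    scheduler.step()

print(f"start={lrs[0]:.2e} peak={max(lrs):.2e} end={lrs[-1]:.2e}")
```

With the scheduler's default settings, the rate starts well below `max_lr`, warms up to the peak, then anneals to a value far smaller than the starting rate.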