---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[12], line 3
1 #| notest
2 tuner = Tuner(trainer)
----> 3 lr_finder = tuner.lr_find(
4 model,
5 datamodule=dm,
6 min_lr=1e-6,
7 max_lr=1.0,
8 num_training=100, # number of iterations
9 # attr_name="optimizer.lr",
10 )
11 fig = lr_finder.plot(suggest=True)
12 plt.show()
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/tuner/tuning.py:180, in Tuner.lr_find(self, model, train_dataloaders, val_dataloaders, dataloaders, datamodule, method, min_lr, max_lr, num_training, mode, early_stop_threshold, update_attr, attr_name)
177 lr_finder_callback._early_exit = True
178 self._trainer.callbacks = [lr_finder_callback] + self._trainer.callbacks
--> 180 self._trainer.fit(model, train_dataloaders, val_dataloaders, datamodule)
182 self._trainer.callbacks = [cb for cb in self._trainer.callbacks if cb is not lr_finder_callback]
184 return lr_finder_callback.optimal_lr
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:539, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
537 self.state.status = TrainerStatus.RUNNING
538 self.training = True
--> 539 call._call_and_handle_interrupt(
540 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
541 )
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:47, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
45 if trainer.strategy.launcher is not None:
46 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 47 return trainer_fn(*args, **kwargs)
49 except _TunerExitException:
50 _call_teardown_hook(trainer)
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:575, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
568 assert self.state.fn is not None
569 ckpt_path = self._checkpoint_connector._select_ckpt_path(
570 self.state.fn,
571 ckpt_path,
572 model_provided=True,
573 model_connected=self.lightning_module is not None,
574 )
--> 575 self._run(model, ckpt_path=ckpt_path)
577 assert self.state.stopped
578 self.training = False
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:962, in Trainer._run(self, model, ckpt_path)
960 # hook
961 if self.state.fn == TrainerFn.FITTING:
--> 962 call._call_callback_hooks(self, "on_fit_start")
963 call._call_lightning_module_hook(self, "on_fit_start")
965 _log_hyperparams(self)
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:222, in _call_callback_hooks(trainer, hook_name, monitoring_callbacks, *args, **kwargs)
220 if callable(fn):
221 with trainer.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"):
--> 222 fn(trainer, trainer.lightning_module, *args, **kwargs)
224 if pl_module:
225 # restore current_fx when nested context
226 pl_module._current_fx_name = prev_fx_name
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/callbacks/lr_finder.py:130, in LearningRateFinder.on_fit_start(self, trainer, pl_module)
128 @override
129 def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
--> 130 self.lr_find(trainer, pl_module)
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/callbacks/lr_finder.py:113, in LearningRateFinder.lr_find(self, trainer, pl_module)
111 def lr_find(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
112 with isolate_rng():
--> 113 self.optimal_lr = _lr_find(
114 trainer,
115 pl_module,
116 min_lr=self._min_lr,
117 max_lr=self._max_lr,
118 num_training=self._num_training_steps,
119 mode=self._mode,
120 early_stop_threshold=self._early_stop_threshold,
121 update_attr=self._update_attr,
122 attr_name=self._attr_name,
123 )
125 if self._early_exit:
126 raise _TunerExitException()
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/tuner/lr_finder.py:278, in _lr_find(trainer, model, min_lr, max_lr, num_training, mode, early_stop_threshold, update_attr, attr_name)
275 lr_finder._exchange_scheduler(trainer)
277 # Fit, lr & loss logged in callback
--> 278 _try_loop_run(trainer, params)
280 # Prompt if we stopped early
281 if trainer.global_step != num_training + start_steps:
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/tuner/lr_finder.py:523, in _try_loop_run(trainer, params)
521 loop.load_state_dict(deepcopy(params["loop_state_dict"]))
522 loop.restarting = False
--> 523 loop.run()
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:216, in _FitLoop.run(self)
214 try:
215 self.on_advance_start()
--> 216 self.advance()
217 self.on_advance_end()
218 except StopIteration:
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:455, in _FitLoop.advance(self)
453 with self.trainer.profiler.profile("run_training_epoch"):
454 assert self._data_fetcher is not None
--> 455 self.epoch_loop.run(self._data_fetcher)
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py:150, in _TrainingEpochLoop.run(self, data_fetcher)
148 while not self.done:
149 try:
--> 150 self.advance(data_fetcher)
151 self.on_advance_end(data_fetcher)
152 except StopIteration:
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py:322, in _TrainingEpochLoop.advance(self, data_fetcher)
320 batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
321 else:
--> 322 batch_output = self.manual_optimization.run(kwargs)
324 self.batch_progress.increment_processed()
326 # update non-plateau LR schedulers
327 # update epoch-interval ones only when we are at the end of training epoch
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/manual.py:94, in _ManualOptimization.run(self, kwargs)
92 self.on_run_start()
93 with suppress(StopIteration): # no loop to break at this level
---> 94 self.advance(kwargs)
95 self._restarting = False
96 return self.on_run_end()
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/manual.py:114, in _ManualOptimization.advance(self, kwargs)
111 trainer = self.trainer
113 # manually capture logged metrics
--> 114 training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
115 del kwargs # release the batch from memory
116 self.trainer.strategy.post_training_step() # unused hook - call anyway for backward compatibility
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:323, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
320 return None
322 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 323 output = fn(*args, **kwargs)
325 # restore current_fx when nested context
326 pl_module._current_fx_name = prev_fx_name
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py:391, in Strategy.training_step(self, *args, **kwargs)
389 if self.model != self.lightning_module:
390 return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
--> 391 return self.lightning_module.training_step(*args, **kwargs)
File ~/Projects/nimrod/nimrod/models/core.py:186, in Classifier.training_step(self, batch, batch_idx)
183 sched.step() #reduce plateau sched is updated at end of epoch only instead TODO: should it be applied to val loop by default?
185 self.train_loss(loss)
--> 186 self.train_acc(preds, y)
187 metrics = {"train/loss": self.train_loss, "train/acc": self.train_acc}
188 self.log_dict(metrics, on_epoch=True, on_step=True, prog_bar=True)# Pass the validation loss to the scheduler
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/torchmetrics/metric.py:316, in Metric.forward(self, *args, **kwargs)
314 self._forward_cache = self._forward_full_state_update(*args, **kwargs)
315 else:
--> 316 self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
318 return self._forward_cache
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/torchmetrics/metric.py:385, in Metric._forward_reduce_state_update(self, *args, **kwargs)
382 self._enable_grad = True # allow grads for batch computation
384 # calculate batch state and compute batch value
--> 385 self.update(*args, **kwargs)
386 batch_val = self.compute()
388 # reduce batch and global state
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/torchmetrics/metric.py:560, in Metric._wrap_update.<locals>.wrapped_func(*args, **kwargs)
552 if "Expected all tensors to be on" in str(err):
553 raise RuntimeError(
554 "Encountered different devices in metric calculation (see stacktrace for details)."
555 " This could be due to the metric class not being on the same device as input."
(...)
558 " device corresponds to the device of the input."
559 ) from err
--> 560 raise err
562 if self.compute_on_cpu:
563 self._move_list_states_to_cpu()
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/torchmetrics/metric.py:550, in Metric._wrap_update.<locals>.wrapped_func(*args, **kwargs)
548 with torch.set_grad_enabled(self._enable_grad):
549 try:
--> 550 update(*args, **kwargs)
551 except RuntimeError as err:
552 if "Expected all tensors to be on" in str(err):
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/torchmetrics/classification/stat_scores.py:339, in MulticlassStatScores.update(self, preds, target)
337 """Update state with predictions and targets."""
338 if self.validate_args:
--> 339 _multiclass_stat_scores_tensor_validation(
340 preds, target, self.num_classes, self.multidim_average, self.ignore_index
341 )
342 preds, target = _multiclass_stat_scores_format(preds, target, self.top_k)
343 tp, fp, tn, fn = _multiclass_stat_scores_update(
344 preds, target, self.num_classes, self.top_k, self.average, self.multidim_average, self.ignore_index
345 )
File ~/miniconda3/envs/nimrod/lib/python3.11/site-packages/torchmetrics/functional/classification/stat_scores.py:318, in _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index)
316 num_unique_values = len(torch.unique(t, dim=None))
317 if num_unique_values > check_value:
--> 318 raise RuntimeError(
319 f"Detected more unique values in `{name}` than expected. Expected only {check_value} but found"
320 f" {num_unique_values} in `{name}`. Found values: {torch.unique(t, dim=None)}."
321 )
RuntimeError: Detected more unique values in `preds` than expected. Expected only 10 but found 30 in `preds`. Found values: tensor([ 0, 1, 3, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 19, 20, 21, 22,
23, 25, 27, 28, 29, 30, 31, 32, 33, 34, 38, 39], device='cuda:0').