Access all batch outputs at the end of epoch in callback with PyTorch Lightning

The documentation for the on_train_epoch_end hook (https://pytorch-lightning.readthedocs.io/en/stable/extensions/callbacks.html#on-train-epoch-end) states:
To access all batch outputs at the end of the epoch, either:
- Implement training_epoch_end in the LightningModule and access outputs via the module OR
- Cache data across train batch hooks inside the callback implementation to post-process in this hook.
I am trying to use the first alternative with the following LightningModule and Callback setup:
import pytorch_lightning as pl
from pytorch_lightning import Callback


class LightningModule(pl.LightningModule):
    def __init__(self, *args):
        super().__init__()
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        return {'batch': batch}

    def training_epoch_end(self, training_step_outputs):
        # training_step_outputs has all my batches
        return


class MyCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # pl_module.batch ???
        return
How do I access the outputs via the pl_module in the callback? What is the recommended way of getting access to training_step_outputs in my callback?
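For reference, the first alternative boils down to letting the LightningModule keep the epoch's outputs on itself, so that a callback can then read them from pl_module. Below is a minimal sketch of that pattern; the attribute name cached_outputs and the class names are illustrative choices, not a Lightning API, and the sketch relies on training_epoch_end running before the callback's on_train_epoch_end, which is what the documentation's first alternative implies:

import pytorch_lightning as pl
from pytorch_lightning import Callback

class CachingModule(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        return {'batch': batch}

    def training_epoch_end(self, training_step_outputs):
        # stash all step outputs on the module itself
        self.cached_outputs = training_step_outputs

class ReadingCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # training_epoch_end has already populated the attribute
        all_outputs = pl_module.cached_outputs
        print(f'collected {len(all_outputs)} batch outputs')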
Solution 1
You can cache the outputs of each training batch as state on the callback itself and access them at the end of the training epoch (the second alternative from the documentation). Here is an example:
from pytorch_lightning import Callback


class MyCallback(Callback):
    def __init__(self):
        super().__init__()
        self.state = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, unused=0):
        # cache each batch's training_step output on the callback
        self.state.append(outputs)

    def on_train_epoch_end(self, trainer, pl_module):
        # all training_step outputs for the epoch are available here
        all_outputs = self.state
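To exercise the callback it has to be registered with the Trainer. A minimal usage sketch; the model and dataloader here are assumed placeholders, not part of the original answer:

import pytorch_lightning as pl

# assumed: `model` is a LightningModule and `train_loader` a DataLoader
callback = MyCallback()
trainer = pl.Trainer(max_epochs=1, callbacks=[callback])
trainer.fit(model, train_loader)

# after fit, callback.state holds one entry per training batch
print(len(callback.state))

Note that the list is never cleared, so when training runs for several epochs the cached outputs accumulate; resetting self.state in on_train_epoch_start is one way to keep the cache per epoch.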
Hope this helps you!
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Aniket Maurya |