'load_from_checkpoint fails after transfer learning a LightningModule

I try to transfer learn a LightningModule. The relevant part of the code is this:

class DeepFilteringTransferLearning(pl.LightningModule):
    def __init__(self, chk_path = None):
        super().__init__()
        
        #init class members
        self.prediction = []
        self.label = []
        self.loss = MSELoss()
        
        #init pretrained model
        self.chk_path = chk_path
        model = DeepFiltering.load_from_checkpoint(chk_path)
        backbone = model.sequential
        layers = list(backbone.children())[:-1]
        self.groundModel = Sequential(*layers)
        
        #use the pretrained modell the same way to regress Lshall and neq
        self.regressor = nn.Linear(64,2)

    def forward(self, x):
        self.groundModel.eval()
        with torch.no_grad():
            groundOut = self.groundModel(x)
        yPred = self.regressor(groundOut)
        return yPred

I save my model in a different, main file which relevant part is:

#callbacks
callbacks = [
        ModelCheckpoint(
            dirpath = "checkpoints/maxPooling16StandardizedL2RegularizedReproduceableSeeded42Ampl1ConvTransferLearned",
            save_top_k=5,
            monitor="val_loss",
        ),
    ]

#trainer
trainer = pl.Trainer(gpus=[1,2],strategy="dp",max_epochs=150,logger=wandb_logger,callbacks=callbacks,precision=32,deterministic=True)
trainer.fit(model,train_dataloaders=trainDl,val_dataloaders=valDl)

After try to load the modell from checkpoint like this:

chk_patH = "path/to/transfer_learned/model"
standardizedL2RegularizedL1 = DeepFilteringTransferLearning("path/to/model/trying/to/use/for/transfer_learning").load_from_checkpoint(chk_patH)

I got the following error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~/anaconda3/envs/skimageTrial/lib/python3.6/site-packages/torch/serialization.py in _check_seekable(f)
    307     try:
--> 308         f.seek(f.tell())
    309         return True

AttributeError: 'NoneType' object has no attribute 'seek'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
<ipython-input-6-13f5fd0c7b85> in <module>
      1 chk_patH = "checkpoints/maxPooling16StandardizedL2RegularizedReproduceableSeeded42Ampl1/epoch=4-step=349.ckpt"
----> 2 standardizedL2RegularizedL1 = DeepFilteringTransferLearning("checkpoints/maxPooling16StandardizedL2RegularizedReproduceableSeeded42Ampl2/epoch=145-step=10219.ckpt").load_from_checkpoint(chk_patH)

~/anaconda3/envs/skimageTrial/lib/python3.6/site-packages/pytorch_lightning/core/saving.py in load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
    154         checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
    155 
--> 156         model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
    157         return model
    158 

~/anaconda3/envs/skimageTrial/lib/python3.6/site-packages/pytorch_lightning/core/saving.py in _load_model_state(cls, checkpoint, strict, **cls_kwargs_new)
    196             _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name}
    197 
--> 198         model = cls(**_cls_kwargs)
    199 
    200         # give model a chance to load something

~/whistlerProject/gitHub/whistler/mathe/gwInspired/deepFilteringTransferLearning.py in __init__(self, chk_path)
     34         #init pretrained model
     35         self.chk_path = chk_path
---> 36         model = DeepFiltering.load_from_checkpoint(chk_path)
     37         backbone = model.sequential
     38         layers = list(backbone.children())[:-1]

~/anaconda3/envs/skimageTrial/lib/python3.6/site-packages/pytorch_lightning/core/saving.py in load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
    132                 checkpoint = pl_load(checkpoint_path, map_location=map_location)
    133             else:
--> 134                 checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
    135 
    136         if hparams_file is not None:

~/anaconda3/envs/skimageTrial/lib/python3.6/site-packages/pytorch_lightning/utilities/cloud_io.py in load(path_or_url, map_location)
     31     if not isinstance(path_or_url, (str, Path)):
     32         # any sort of BytesIO or similiar
---> 33         return torch.load(path_or_url, map_location=map_location)
     34     if str(path_or_url).startswith("http"):
     35         return torch.hub.load_state_dict_from_url(str(path_or_url), map_location=map_location)

~/anaconda3/envs/skimageTrial/lib/python3.6/site-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
    579         pickle_load_args['encoding'] = 'utf-8'
    580 
--> 581     with _open_file_like(f, 'rb') as opened_file:
    582         if _is_zipfile(opened_file):
    583             # The zipfile reader is going to advance the current file position.

~/anaconda3/envs/skimageTrial/lib/python3.6/site-packages/torch/serialization.py in _open_file_like(name_or_buffer, mode)
    233             return _open_buffer_writer(name_or_buffer)
    234         elif 'r' in mode:
--> 235             return _open_buffer_reader(name_or_buffer)
    236         else:
    237             raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")

~/anaconda3/envs/skimageTrial/lib/python3.6/site-packages/torch/serialization.py in __init__(self, buffer)
    218     def __init__(self, buffer):
    219         super(_open_buffer_reader, self).__init__(buffer)
--> 220         _check_seekable(buffer)
    221 
    222 

~/anaconda3/envs/skimageTrial/lib/python3.6/site-packages/torch/serialization.py in _check_seekable(f)
    309         return True
    310     except (io.UnsupportedOperation, AttributeError) as e:
--> 311         raise_err_msg(["seek", "tell"], e)
    312     return False
    313 

~/anaconda3/envs/skimageTrial/lib/python3.6/site-packages/torch/serialization.py in raise_err_msg(patterns, e)
    302                                 + " Please pre-load the data into a buffer like io.BytesIO and"
    303                                 + " try to load from it instead.")
--> 304                 raise type(e)(msg)
    305         raise e
    306 

AttributeError: 'NoneType' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

which I can't resolve. I try to this according to the available tutorials on the official page of pytorch lightning here. I can't figure it out what I miss.

Could somebody point me in the right direction?



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source