'load_from_checkpoint fails after transfer learning a LightningModule
I try to transfer learn a LightningModule. The relevant part of the code is this:
class DeepFilteringTransferLearning(pl.LightningModule):
def __init__(self, chk_path = None):
super().__init__()
#init class members
self.prediction = []
self.label = []
self.loss = MSELoss()
#init pretrained model
self.chk_path = chk_path
model = DeepFiltering.load_from_checkpoint(chk_path)
backbone = model.sequential
layers = list(backbone.children())[:-1]
self.groundModel = Sequential(*layers)
#use the pretrained modell the same way to regress Lshall and neq
self.regressor = nn.Linear(64,2)
def forward(self, x):
self.groundModel.eval()
with torch.no_grad():
groundOut = self.groundModel(x)
yPred = self.regressor(groundOut)
return yPred
I save my model in a different, main file which relevant part is:
#callbacks
callbacks = [
ModelCheckpoint(
dirpath = "checkpoints/maxPooling16StandardizedL2RegularizedReproduceableSeeded42Ampl1ConvTransferLearned",
save_top_k=5,
monitor="val_loss",
),
]
#trainer
trainer = pl.Trainer(gpus=[1,2],strategy="dp",max_epochs=150,logger=wandb_logger,callbacks=callbacks,precision=32,deterministic=True)
trainer.fit(model,train_dataloaders=trainDl,val_dataloaders=valDl)
After try to load the modell from checkpoint like this:
chk_patH = "path/to/transfer_learned/model"
standardizedL2RegularizedL1 = DeepFilteringTransferLearning("path/to/model/trying/to/use/for/transfer_learning").load_from_checkpoint(chk_patH)
I got the following error:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
~/anaconda3/envs/skimageTrial/lib/python3.6/site-packages/torch/serialization.py in _check_seekable(f)
307 try:
--> 308 f.seek(f.tell())
309 return True
AttributeError: 'NoneType' object has no attribute 'seek'
During handling of the above exception, another exception occurred:
AttributeError Traceback (most recent call last)
<ipython-input-6-13f5fd0c7b85> in <module>
1 chk_patH = "checkpoints/maxPooling16StandardizedL2RegularizedReproduceableSeeded42Ampl1/epoch=4-step=349.ckpt"
----> 2 standardizedL2RegularizedL1 = DeepFilteringTransferLearning("checkpoints/maxPooling16StandardizedL2RegularizedReproduceableSeeded42Ampl2/epoch=145-step=10219.ckpt").load_from_checkpoint(chk_patH)
~/anaconda3/envs/skimageTrial/lib/python3.6/site-packages/pytorch_lightning/core/saving.py in load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
154 checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
155
--> 156 model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
157 return model
158
~/anaconda3/envs/skimageTrial/lib/python3.6/site-packages/pytorch_lightning/core/saving.py in _load_model_state(cls, checkpoint, strict, **cls_kwargs_new)
196 _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name}
197
--> 198 model = cls(**_cls_kwargs)
199
200 # give model a chance to load something
~/whistlerProject/gitHub/whistler/mathe/gwInspired/deepFilteringTransferLearning.py in __init__(self, chk_path)
34 #init pretrained model
35 self.chk_path = chk_path
---> 36 model = DeepFiltering.load_from_checkpoint(chk_path)
37 backbone = model.sequential
38 layers = list(backbone.children())[:-1]
~/anaconda3/envs/skimageTrial/lib/python3.6/site-packages/pytorch_lightning/core/saving.py in load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
132 checkpoint = pl_load(checkpoint_path, map_location=map_location)
133 else:
--> 134 checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
135
136 if hparams_file is not None:
~/anaconda3/envs/skimageTrial/lib/python3.6/site-packages/pytorch_lightning/utilities/cloud_io.py in load(path_or_url, map_location)
31 if not isinstance(path_or_url, (str, Path)):
32 # any sort of BytesIO or similiar
---> 33 return torch.load(path_or_url, map_location=map_location)
34 if str(path_or_url).startswith("http"):
35 return torch.hub.load_state_dict_from_url(str(path_or_url), map_location=map_location)
~/anaconda3/envs/skimageTrial/lib/python3.6/site-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
579 pickle_load_args['encoding'] = 'utf-8'
580
--> 581 with _open_file_like(f, 'rb') as opened_file:
582 if _is_zipfile(opened_file):
583 # The zipfile reader is going to advance the current file position.
~/anaconda3/envs/skimageTrial/lib/python3.6/site-packages/torch/serialization.py in _open_file_like(name_or_buffer, mode)
233 return _open_buffer_writer(name_or_buffer)
234 elif 'r' in mode:
--> 235 return _open_buffer_reader(name_or_buffer)
236 else:
237 raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")
~/anaconda3/envs/skimageTrial/lib/python3.6/site-packages/torch/serialization.py in __init__(self, buffer)
218 def __init__(self, buffer):
219 super(_open_buffer_reader, self).__init__(buffer)
--> 220 _check_seekable(buffer)
221
222
~/anaconda3/envs/skimageTrial/lib/python3.6/site-packages/torch/serialization.py in _check_seekable(f)
309 return True
310 except (io.UnsupportedOperation, AttributeError) as e:
--> 311 raise_err_msg(["seek", "tell"], e)
312 return False
313
~/anaconda3/envs/skimageTrial/lib/python3.6/site-packages/torch/serialization.py in raise_err_msg(patterns, e)
302 + " Please pre-load the data into a buffer like io.BytesIO and"
303 + " try to load from it instead.")
--> 304 raise type(e)(msg)
305 raise e
306
AttributeError: 'NoneType' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.
which I can't resolve. I try to this according to the available tutorials on the official page of pytorch lightning here. I can't figure it out what I miss.
Could somebody point me in the right direction?
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|