How to use gradient accumulation in detectron2
import logging
import os
import torch

import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer
from detectron2.data import build_detection_train_loader
from detectron2.solver import build_lr_scheduler
from detectron2.utils.events import (
    CommonMetricPrinter,
    EventStorage,
    JSONWriter,
    TensorboardXWriter,
)

logger = logging.getLogger("detectron2")


def do_train(cfg, model, resume=False):
    # my_build_optimizer() and do_test() are defined elsewhere in my script
    model.train()
    optimizer = my_build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    checkpointer = DetectionCheckpointer(
        model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler
    )
    start_iter = (
        checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1
    )
    max_iter = cfg.SOLVER.MAX_ITER

    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
    )

    writers = (
        [
            CommonMetricPrinter(max_iter),
            JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(cfg.OUTPUT_DIR),
        ]
        if comm.is_main_process()
        else []
    )

    # compared to "train_net.py", we do not support accurate timing and
    # precise BN here, because they are not trivial to implement in a small training loop
    data_loader = build_detection_train_loader(cfg)
    logger.info("Starting training from iteration {}".format(start_iter))
    with EventStorage(start_iter) as storage:
        for data, iteration in zip(data_loader, range(start_iter, max_iter)):
            storage.iter = iteration

            loss_dict = model(data)
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            # Reduce losses across workers for logging only
            loss_dict_reduced = {k: v.item() for k, v in comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced)

            optimizer.zero_grad()
            losses.backward()
            if (
                cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
                and cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE > 0.0
            ):
                norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
                )
                storage.put_scalar("log_grad_norm", norm)
            optimizer.step()

            storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
            scheduler.step()

            if (
                cfg.TEST.EVAL_PERIOD > 0
                and (iteration + 1) % cfg.TEST.EVAL_PERIOD == 0
                and iteration != max_iter - 1
            ):
                do_test(cfg, model)
                # Compared to "train_net.py", the test results are not dumped to EventStorage
                comm.synchronize()

            if iteration - start_iter > 5 and (
                (iteration + 1) % 20 == 0 or iteration == max_iter - 1
            ):
                for writer in writers:
                    writer.write()
            periodic_checkpointer.step(iteration)
I tried to modify the code above, but I can't get it to work properly. I would like to know how to use gradient accumulation in detectron2. I learned about gradient accumulation from the link below, but I haven't been able to make it work in detectron2. https://cowarder.site/2019/10/29/Gradient-Accumulation/
Solution 1:[1]
You should consider modifying how often optimizer.zero_grad() (and optimizer.step()) is called to enable gradient accumulation. This is more of a PyTorch problem than a detectron2 one, so you should check the PyTorch documentation for such an implementation.
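For reference, here is a minimal sketch of how the inner loop above could be changed to accumulate gradients. This is the standard PyTorch pattern rather than a built-in detectron2 feature: scale each loss by the number of accumulation steps, call backward() every iteration, but only call optimizer.step() and optimizer.zero_grad() once every accum_steps iterations. accum_steps is a value you choose yourself (it is not an existing detectron2 config key), and the logging/checkpointing lines from the original loop are omitted for brevity.

# Sketch only: the relevant part of the training loop with gradient accumulation.
accum_steps = 4  # effective batch size becomes accum_steps * cfg.SOLVER.IMS_PER_BATCH

optimizer.zero_grad()
for data, iteration in zip(data_loader, range(start_iter, max_iter)):
    storage.iter = iteration

    loss_dict = model(data)
    losses = sum(loss_dict.values())
    assert torch.isfinite(losses).all(), loss_dict

    # Scale the loss so the accumulated gradient is the average over the
    # accumulated mini-batches, then accumulate into .grad with backward().
    (losses / accum_steps).backward()

    # Only update the weights (and reset the gradients) every accum_steps iterations.
    if (iteration + 1) % accum_steps == 0:
        if (
            cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
            and cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE > 0.0
        ):
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
            )
        optimizer.step()
        optimizer.zero_grad()

    # The LR schedule still steps every iteration, as in the original loop;
    # stepping it only on update iterations is an alternative design choice.
    scheduler.step()

Two caveats: gradient accumulation is not exactly equivalent to a larger batch when the model contains BatchNorm layers, because the statistics are still computed per forward pass; and if the model is wrapped in DistributedDataParallel, every backward() triggers a gradient all-reduce, which you can avoid on the non-update iterations by entering the model.no_sync() context.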
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
| --- | --- |
| Solution 1 | Michael Koo |