Commit 02b5d768 authored by Yinhao Li, committed by GitHub

[feature]: Able to use save_best option (#575)


* Add save_best option in eval_hook.
* Update meta to fix the bug that the best model could not be tested.
* Refactor with _do_evaluate.
* Remove redundant code.
* Add meta.

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
parent 725d5aa0
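Note: the new option is exposed through the `evaluation` section of a training config, whose keyword arguments are forwarded to the eval hooks changed below. A minimal sketch follows; the interval and metric values are illustrative, and the keys assume the standard mmcv EvalHook signature with `save_best`.

# Illustrative config snippet (values are placeholders): ask the eval hook to
# also keep a checkpoint of the best mIoU seen during training.
evaluation = dict(interval=4000, metric='mIoU', save_best='mIoU')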
 import os.path as osp
 
+import torch.distributed as dist
 from mmcv.runner import DistEvalHook as _DistEvalHook
 from mmcv.runner import EvalHook as _EvalHook
+from torch.nn.modules.batchnorm import _BatchNorm
 
 
 class EvalHook(_EvalHook):
@@ -23,33 +25,17 @@ class EvalHook(_EvalHook):
         super().__init__(*args, by_epoch=by_epoch, **kwargs)
         self.efficient_test = efficient_test
 
-    def after_train_iter(self, runner):
-        """After train epoch hook.
-
-        Override default ``single_gpu_test``.
-        """
-        if self.by_epoch or not self.every_n_iters(runner, self.interval):
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
+        if not self._should_evaluate(runner):
             return
-        from mmseg.apis import single_gpu_test
-        runner.log_buffer.clear()
-        results = single_gpu_test(
-            runner.model,
-            self.dataloader,
-            show=False,
-            efficient_test=self.efficient_test)
-        self.evaluate(runner, results)
-
-    def after_train_epoch(self, runner):
-        """After train epoch hook.
-
-        Override default ``single_gpu_test``.
-        """
-        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
-            return
+
         from mmseg.apis import single_gpu_test
-        runner.log_buffer.clear()
         results = single_gpu_test(runner.model, self.dataloader, show=False)
-        self.evaluate(runner, results)
+        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+        key_score = self.evaluate(runner, results)
+        if self.save_best:
+            self._save_ckpt(runner, key_score)
 
 
 class DistEvalHook(_DistEvalHook):
@@ -71,39 +57,38 @@ class DistEvalHook(_DistEvalHook):
         super().__init__(*args, by_epoch=by_epoch, **kwargs)
         self.efficient_test = efficient_test
 
-    def after_train_iter(self, runner):
-        """After train epoch hook.
-
-        Override default ``multi_gpu_test``.
-        """
-        if self.by_epoch or not self.every_n_iters(runner, self.interval):
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
+        # Synchronization of BatchNorm's buffer (running_mean
+        # and running_var) is not supported in the DDP of pytorch,
+        # which may cause the inconsistent performance of models in
+        # different ranks, so we broadcast BatchNorm's buffers
+        # of rank 0 to other ranks to avoid this.
+        if self.broadcast_bn_buffer:
+            model = runner.model
+            for name, module in model.named_modules():
+                if isinstance(module,
+                              _BatchNorm) and module.track_running_stats:
+                    dist.broadcast(module.running_var, 0)
+                    dist.broadcast(module.running_mean, 0)
+
+        if not self._should_evaluate(runner):
             return
-        from mmseg.apis import multi_gpu_test
-        runner.log_buffer.clear()
-        results = multi_gpu_test(
-            runner.model,
-            self.dataloader,
-            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
-            gpu_collect=self.gpu_collect,
-            efficient_test=self.efficient_test)
-        if runner.rank == 0:
-            print('\n')
-            self.evaluate(runner, results)
-
-    def after_train_epoch(self, runner):
-        """After train epoch hook.
-
-        Override default ``multi_gpu_test``.
-        """
-        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
-            return
+
+        tmpdir = self.tmpdir
+        if tmpdir is None:
+            tmpdir = osp.join(runner.work_dir, '.eval_hook')
+
         from mmseg.apis import multi_gpu_test
-        runner.log_buffer.clear()
         results = multi_gpu_test(
             runner.model,
             self.dataloader,
-            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
+            tmpdir=tmpdir,
             gpu_collect=self.gpu_collect)
         if runner.rank == 0:
             print('\n')
-            self.evaluate(runner, results)
+            runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+            key_score = self.evaluate(runner, results)
+            if self.save_best:
+                self._save_ckpt(runner, key_score)
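For reference, a rough sketch of how these hooks end up registered during training; `runner`, `val_dataloader` and `distributed` are assumed to come from the surrounding training setup, and the keyword arguments mirror what `cfg.evaluation` would provide. This is not the verbatim mmseg train loop, just the shape of the wiring.

# Rough sketch (assumed setup): register the eval hook on an mmcv runner so
# _do_evaluate runs at the configured interval and, with save_best set,
# stores the best checkpoint via the inherited _save_ckpt.
from mmseg.core import DistEvalHook, EvalHook

eval_cfg = dict(interval=4000, metric='mIoU', save_best='mIoU')
eval_hook = DistEvalHook if distributed else EvalHook
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))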
@@ -122,8 +122,16 @@ def main():
     if fp16_cfg is not None:
         wrap_fp16_model(model)
     checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
-    model.CLASSES = checkpoint['meta']['CLASSES']
-    model.PALETTE = checkpoint['meta']['PALETTE']
+    if 'CLASSES' in checkpoint.get('meta', {}):
+        model.CLASSES = checkpoint['meta']['CLASSES']
+    else:
+        print('"CLASSES" not found in meta, use dataset.CLASSES instead')
+        model.CLASSES = dataset.CLASSES
+    if 'PALETTE' in checkpoint.get('meta', {}):
+        model.PALETTE = checkpoint['meta']['PALETTE']
+    else:
+        print('"PALETTE" not found in meta, use dataset.PALETTE instead')
+        model.PALETTE = dataset.PALETTE
 
     efficient_test = False
     if args.eval_options is not None:
...
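Once a best checkpoint has been saved, it can be loaded like any other checkpoint for evaluation or inference. A small sketch using mmseg's high-level API follows; the config, image and checkpoint paths are placeholders, and the best-checkpoint filename depends on the mmcv version in use.

# Sketch only: load a checkpoint kept by save_best and run inference.
# init_segmentor picks up CLASSES/PALETTE from the checkpoint meta when they
# are present, which is the same meta the fallback above guards against
# being absent. All paths below are hypothetical.
from mmseg.apis import inference_segmentor, init_segmentor

config_file = 'configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py'
checkpoint_file = 'work_dirs/pspnet_r50/best_mIoU_iter_36000.pth'  # placeholder
model = init_segmentor(config_file, checkpoint_file, device='cuda:0')
result = inference_segmentor(model, 'demo/demo.png')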
@@ -149,6 +149,8 @@ def main():
             PALETTE=datasets[0].PALETTE)
     # add an attribute for visualization convenience
     model.CLASSES = datasets[0].CLASSES
+    # passing checkpoint meta for saving best checkpoint
+    meta.update(cfg.checkpoint_config.meta)
     train_segmentor(
         model,
         datasets,
...
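The `PALETTE=datasets[0].PALETTE)` context line above is the tail of the block that fills `cfg.checkpoint_config.meta`. A reconstructed sketch of that surrounding code (an assumption, simplified, not part of this diff) shows why the added `meta.update(...)` matters: the same dict is passed on to the runner through `train_segmentor`, so checkpoints written by `save_best` also carry `CLASSES` and `PALETTE` and can be tested directly.

# Reconstructed sketch of the surrounding train script (assumed, simplified):
# checkpoint meta records the dataset's class names and palette, and
# meta.update(...) forwards them to the runner so best checkpoints keep them.
if cfg.checkpoint_config is not None:
    cfg.checkpoint_config.meta = dict(
        config=cfg.pretty_text,
        CLASSES=datasets[0].CLASSES,
        PALETTE=datasets[0].PALETTE)
model.CLASSES = datasets[0].CLASSES
meta.update(cfg.checkpoint_config.meta)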