diff --git a/mmseg/models/backbones/hrnet.py b/mmseg/models/backbones/hrnet.py
index 055fc985bb1d1841e40b5a9f0337b7a2fe2495e4..0f064cff7da23244417423a5e70b6ca81ba17067 100644
--- a/mmseg/models/backbones/hrnet.py
+++ b/mmseg/models/backbones/hrnet.py
@@ -230,6 +230,8 @@ class HRNet(BaseModule):
             and its variants only.
         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
             memory while slowing down the training speed.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters. Default: -1.
         zero_init_residual (bool): whether to use zero init for last norm layer
             in resblocks to let them behave as identity.
         pretrained (str, optional): model pretrained path. Default: None
@@ -285,6 +287,7 @@ class HRNet(BaseModule):
                  norm_cfg=dict(type='BN', requires_grad=True),
                  norm_eval=False,
                  with_cp=False,
+                 frozen_stages=-1,
                  zero_init_residual=False,
                  pretrained=None,
                  init_cfg=None):
@@ -315,6 +318,7 @@ class HRNet(BaseModule):
         self.norm_cfg = norm_cfg
         self.norm_eval = norm_eval
         self.with_cp = with_cp
+        self.frozen_stages = frozen_stages
 
         # stem net
         self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
@@ -388,6 +392,8 @@ class HRNet(BaseModule):
         self.stage4, pre_stage_channels = self._make_stage(
             self.stage4_cfg, num_channels)
 
+        self._freeze_stages()
+
     @property
     def norm1(self):
         """nn.Module: the normalization layer named "norm1" """
@@ -534,6 +540,32 @@ class HRNet(BaseModule):
 
         return Sequential(*hr_modules), in_channels
 
+    def _freeze_stages(self):
+        """Freeze stages param and norm stats."""
+        if self.frozen_stages >= 0:
+
+            self.norm1.eval()
+            self.norm2.eval()
+            for m in [self.conv1, self.norm1, self.conv2, self.norm2]:
+                for param in m.parameters():
+                    param.requires_grad = False
+
+        for i in range(1, self.frozen_stages + 1):
+            if i == 1:
+                m = getattr(self, f'layer{i}')
+                t = getattr(self, f'transition{i}')
+            elif i == 4:
+                m = getattr(self, f'stage{i}')
+            else:
+                m = getattr(self, f'stage{i}')
+                t = getattr(self, f'transition{i}')
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+            t.eval()
+            for param in t.parameters():
+                param.requires_grad = False
+
     def forward(self, x):
         """Forward function."""
 
@@ -575,6 +607,7 @@ class HRNet(BaseModule):
         """Convert the model into training mode will keeping the normalization
         layer freezed."""
         super(HRNet, self).train(mode)
+        self._freeze_stages()
         if mode and self.norm_eval:
             for m in self.modules():
                 # trick: eval have effect on BatchNorm only
diff --git a/tests/test_models/test_backbones/test_hrnet.py b/tests/test_models/test_backbones/test_hrnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..81611a0d115edfbd8e63f0c29e3fd5f505997989
--- /dev/null
+++ b/tests/test_models/test_backbones/test_hrnet.py
@@ -0,0 +1,63 @@
+from mmcv.utils.parrots_wrapper import _BatchNorm
+
+from mmseg.models.backbones import HRNet
+
+
+def test_hrnet_backbone():
+    # Test HRNET with two stage frozen
+
+    extra = dict(
+        stage1=dict(
+            num_modules=1,
+            num_branches=1,
+            block='BOTTLENECK',
+            num_blocks=(4, ),
+            num_channels=(64, )),
+        stage2=dict(
+            num_modules=1,
+            num_branches=2,
+            block='BASIC',
+            num_blocks=(4, 4),
+            num_channels=(32, 64)),
+        stage3=dict(
+            num_modules=4,
+            num_branches=3,
+            block='BASIC',
+            num_blocks=(4, 4, 4),
+            num_channels=(32, 64, 128)),
+        stage4=dict(
+            num_modules=3,
+            num_branches=4,
+            block='BASIC',
+            num_blocks=(4, 4, 4, 4),
+            num_channels=(32, 64, 128, 256)))
+    frozen_stages = 2
+    model = HRNet(extra, frozen_stages=frozen_stages)
+    model.init_weights()
+    model.train()
+    assert model.norm1.training is False
+
+    for layer in [model.conv1, model.norm1]:
+        for param in layer.parameters():
+            assert param.requires_grad is False
+    for i in range(1, frozen_stages + 1):
+        if i == 1:
+            layer = getattr(model, f'layer{i}')
+            transition = getattr(model, f'transition{i}')
+        elif i == 4:
+            layer = getattr(model, f'stage{i}')
+        else:
+            layer = getattr(model, f'stage{i}')
+            transition = getattr(model, f'transition{i}')
+
+        for mod in layer.modules():
+            if isinstance(mod, _BatchNorm):
+                assert mod.training is False
+        for param in layer.parameters():
+            assert param.requires_grad is False
+
+        for mod in transition.modules():
+            if isinstance(mod, _BatchNorm):
+                assert mod.training is False
+        for param in transition.parameters():
+            assert param.requires_grad is False