Skip to content

Commit f05e48f

Browse files
committed
Merge remote-tracking branch 'origin/develop' into vs/fix_yolo_bfp16
2 parents 116503a + ead8140 commit f05e48f

73 files changed

Lines changed: 591 additions & 923 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

src/otx/algorithms/classification/configs/base/data/data_pipeline.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,7 @@
11
"""Data Pipeline of Class-Incr model for Classification Task."""
22

3-
# Copyright (C) 2022 Intel Corporation
4-
#
5-
# Licensed under the Apache License, Version 2.0 (the "License");
6-
# you may not use this file except in compliance with the License.
7-
# You may obtain a copy of the License at
8-
#
9-
# http://www.apache.org/licenses/LICENSE-2.0
10-
#
11-
# Unless required by applicable law or agreed to in writing,
12-
# software distributed under the License is distributed on an "AS IS" BASIS,
13-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14-
# See the License for the specific language governing permissions
15-
# and limitations under the License.
3+
# Copyright (C) 2022-2023 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
165

176
# pylint: disable=invalid-name
187

src/otx/algorithms/classification/configs/base/data/selfsl/data_pipeline.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,21 @@
11
"""Data Pipeline of Self-SL model for Classification Task."""
22

3-
# Copyright (C) 2022 Intel Corporation
4-
#
5-
# Licensed under the Apache License, Version 2.0 (the "License");
6-
# you may not use this file except in compliance with the License.
7-
# You may obtain a copy of the License at
8-
#
9-
# http://www.apache.org/licenses/LICENSE-2.0
10-
#
11-
# Unless required by applicable law or agreed to in writing,
12-
# software distributed under the License is distributed on an "AS IS" BASIS,
13-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14-
# See the License for the specific language governing permissions
15-
# and limitations under the License.
3+
# Copyright (C) 2022-2023 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
165

176
# pylint: disable=invalid-name
187

198
__img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
209
__img_size = 224
2110

2211
__train_pipeline_v0 = [
23-
dict(type="LoadImageFromOTXDataset"),
12+
dict(
13+
type="LoadResizeDataFromOTXDataset",
14+
resize_cfg=dict(type="Resize", size=__img_size, downscale_only=True),
15+
# To be resized in this op only if input is larger than expected size
16+
# for speed & cache memory efficiency.
17+
enable_memcache=True, # Cache after resizing image
18+
),
2419
dict(type="RandomResizedCrop", size=__img_size),
2520
dict(type="RandomFlip"),
2621
dict(
@@ -37,7 +32,13 @@
3732
dict(type="Collect", keys=["img"]),
3833
]
3934
__train_pipeline_v1 = [
40-
dict(type="LoadImageFromOTXDataset"),
35+
dict(
36+
type="LoadResizeDataFromOTXDataset",
37+
resize_cfg=dict(type="Resize", size=__img_size, downscale_only=True),
38+
# To be resized in this op only if input is larger than expected size
39+
# for speed & cache memory efficiency.
40+
enable_memcache=True, # Cache after resizing image
41+
),
4142
dict(type="RandomResizedCrop", size=__img_size),
4243
dict(type="RandomFlip"),
4344
dict(

src/otx/algorithms/classification/configs/base/data/semisl/data_pipeline.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@
1818
__resize_target_size = 224
1919

2020
__common_pipeline = [
21-
dict(type="LoadImageFromOTXDataset"),
22-
dict(type="Resize", size=__resize_target_size),
21+
dict(
22+
type="LoadResizeDataFromOTXDataset",
23+
resize_cfg=dict(type="Resize", size=__resize_target_size, downscale_only=False),
24+
enable_memcache=True, # Cache after resizing image
25+
),
2326
dict(type="RandomFlip", flip_prob=0.5, direction="horizontal"),
2427
dict(type="AugMixAugment", config_str="augmix-m5-w3"),
2528
dict(type="RandomRotate", p=0.35, angle=(-10, 10)),
@@ -48,6 +51,17 @@
4851
dict(type="Collect", keys=["img", "img_strong"]),
4952
]
5053

54+
__val_pipeline = [
55+
dict(
56+
type="LoadResizeDataFromOTXDataset",
57+
resize_cfg=dict(type="Resize", size=__resize_target_size, downscale_only=False),
58+
enable_memcache=True, # Cache after resizing image
59+
),
60+
dict(type="Normalize", **__img_norm_cfg),
61+
dict(type="ImageToTensor", keys=["img"]),
62+
dict(type="Collect", keys=["img"]),
63+
]
64+
5165
__test_pipeline = [
5266
dict(type="LoadImageFromOTXDataset"),
5367
dict(type="Resize", size=__resize_target_size),
@@ -64,6 +78,6 @@
6478
type=__dataset_type,
6579
pipeline=__unlabeled_pipeline,
6680
),
67-
val=dict(type=__dataset_type, test_mode=True, pipeline=__test_pipeline),
81+
val=dict(type=__dataset_type, test_mode=True, pipeline=__val_pipeline),
6882
test=dict(type=__dataset_type, test_mode=True, pipeline=__test_pipeline),
6983
)
Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,7 @@
11
"""Data Pipeline of SupCon model for Classification Task."""
22

33
# Copyright (C) 2023 Intel Corporation
4-
#
5-
# Licensed under the Apache License, Version 2.0 (the "License");
6-
# you may not use this file except in compliance with the License.
7-
# You may obtain a copy of the License at
8-
#
9-
# http://www.apache.org/licenses/LICENSE-2.0
10-
#
11-
# Unless required by applicable law or agreed to in writing,
12-
# software distributed under the License is distributed on an "AS IS" BASIS,
13-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14-
# See the License for the specific language governing permissions
15-
# and limitations under the License.
4+
# SPDX-License-Identifier: Apache-2.0
165

176
# pylint: disable=invalid-name
187

@@ -21,11 +10,14 @@
2110

2211

2312
__train_pipeline = [
24-
dict(type="LoadImageFromOTXDataset"),
13+
dict(
14+
type="LoadResizeDataFromOTXDataset",
15+
resize_cfg=dict(type="Resize", size=__resize_target_size, downscale_only=False),
16+
enable_memcache=True, # Cache after resizing image
17+
),
2518
dict(
2619
type="TwoCropTransform",
2720
pipeline=[
28-
dict(type="Resize", size=__resize_target_size),
2921
dict(type="RandomFlip", flip_prob=0.5, direction="horizontal"),
3022
dict(type="AugMixAugment", config_str="augmix-m5-w3"),
3123
dict(type="RandomRotate", p=0.35, angle=(-10, 10)),
@@ -38,6 +30,17 @@
3830
),
3931
]
4032

33+
__val_pipeline = [
34+
dict(
35+
type="LoadResizeDataFromOTXDataset",
36+
resize_cfg=dict(type="Resize", size=__resize_target_size, downscale_only=False),
37+
enable_memcache=True, # Cache after resizing image
38+
),
39+
dict(type="Normalize", **__img_norm_cfg),
40+
dict(type="ImageToTensor", keys=["img"]),
41+
dict(type="Collect", keys=["img"]),
42+
]
43+
4144
__test_pipeline = [
4245
dict(type="LoadImageFromOTXDataset"),
4346
dict(type="Resize", size=__resize_target_size),
@@ -50,6 +53,6 @@
5053

5154
data = dict(
5255
train=dict(type=__dataset_type, pipeline=__train_pipeline),
53-
val=dict(type=__dataset_type, test_mode=True, pipeline=__test_pipeline),
56+
val=dict(type=__dataset_type, test_mode=True, pipeline=__val_pipeline),
5457
test=dict(type=__dataset_type, test_mode=True, pipeline=__test_pipeline),
5558
)

src/otx/algorithms/common/adapters/mmcv/hooks/dual_model_ema_hook.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -71,16 +71,17 @@ def before_run(self, runner):
7171
self.src_model = getattr(model, self.src_model_name, None)
7272
self.dst_model = getattr(model, self.dst_model_name, None)
7373
if self.src_model and self.dst_model:
74+
self.enabled = True
7475
self.src_params = self.src_model.state_dict(keep_vars=True)
7576
self.dst_params = self.dst_model.state_dict(keep_vars=True)
76-
77-
def before_train_epoch(self, runner):
78-
"""Momentum update."""
79-
if runner.epoch + 1 == self.start_epoch:
80-
self._copy_model()
81-
self.enabled = True
82-
83-
if self.epoch_momentum > 0.0 and self.enabled:
77+
if runner.epoch == 0 and runner.iter == 0:
78+
self._copy_model(sync_model=True)
79+
logger.info("Initialized student model by teacher model")
80+
logger.info(f"model_s model_t diff: {self._diff_model()}")
81+
82+
def before_epoch(self, runner):
83+
"""Compute adaptive EMA momentum."""
84+
if self.epoch_momentum > 0.0:
8485
iter_per_epoch = len(runner.data_loader)
8586
epoch_decay = 1 - self.epoch_momentum
8687
iter_decay = math.pow(epoch_decay, self.interval / iter_per_epoch)
@@ -91,6 +92,12 @@ def before_train_epoch(self, runner):
9192
def after_train_iter(self, runner):
9293
"""Update ema parameter every self.interval iterations."""
9394
if not self.enabled or (runner.iter % self.interval != 0):
95+
# Skip update
96+
return
97+
98+
if runner.epoch + 1 < self.start_epoch:
99+
# Just copy parameters before start epoch
100+
self._copy_model()
94101
return
95102

96103
# EMA
@@ -107,12 +114,15 @@ def _get_model(self, runner):
107114
model = model.module
108115
return model
109116

110-
def _copy_model(self):
117+
def _copy_model(self, sync_model=False):
111118
with torch.no_grad():
112119
for name, src_param in self.src_params.items():
113120
if not name.startswith("ema_"):
114121
dst_param = self.dst_params[name]
115-
dst_param.data.copy_(src_param.data)
122+
if sync_model:
123+
src_param.data.copy_(dst_param.data)
124+
else:
125+
dst_param.data.copy_(src_param.data)
116126

117127
def _ema_model(self):
118128
momentum = min(self.momentum, 1.0)

src/otx/algorithms/common/adapters/mmcv/hooks/mean_teacher_hook.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def before_train_epoch(self, runner):
3434
logger.info(f"avr_ps_ratio: {average_pseudo_label_ratio}")
3535
self._get_model(runner).enable_unlabeled_loss(True)
3636
self.unlabeled_loss_enabled = True
37-
logger.info("---------- Enabled unlabeled loss and EMA smoothing")
37+
logger.info("---------- Enabled unlabeled loss and EMA smoothing ----------")
3838

3939
def after_train_iter(self, runner):
4040
"""Update ema parameter every self.interval iterations."""
@@ -44,6 +44,7 @@ def after_train_iter(self, runner):
4444

4545
if runner.epoch + 1 < self.start_epoch or self.unlabeled_loss_enabled is False:
4646
# Just copy parameters before enabled
47+
self._copy_model()
4748
return
4849

4950
# EMA

0 commit comments

Comments
 (0)