-
Notifications
You must be signed in to change notification settings - Fork 465
Enable DINO to OTX - Step 1. Enable Deformable DETR to OTX #2249
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
43677e2
Add Deformable DETR as a experimental model
jaegukhyun 76a0b0d
Change CHANGELOG.md
jaegukhyun d189af3
Add unit test
jaegukhyun 32e0bf8
Add intg test
jaegukhyun 6559c5a
Revert e2e test changes
jaegukhyun 7b51b56
Reflect Reviews
jaegukhyun File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| """Initial file for mmcv ops.""" | ||
| # Copyright (C) 2023 Intel Corporation | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
|
|
||
| from .multi_scale_deformable_attn_pytorch import multi_scale_deformable_attn_pytorch | ||
|
|
||
| __all__ = ["multi_scale_deformable_attn_pytorch"] |
140 changes: 140 additions & 0 deletions
140
otx/algorithms/common/adapters/mmcv/ops/multi_scale_deformable_attn_pytorch.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,140 @@ | ||
| """Custom patch of multi_scale_deformable_attn_pytorch for openvino export.""" | ||
|
|
||
| # Copyright (C) 2023 Intel Corporation | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
|
|
||
| import torch | ||
| import torch.nn.functional as F | ||
| from mmcv.ops import multi_scale_deform_attn | ||
|
|
||
|
|
||
| def multi_scale_deformable_attn_pytorch( | ||
| value: torch.Tensor, | ||
| value_spatial_shapes: torch.Tensor, | ||
| sampling_locations: torch.Tensor, | ||
| attention_weights: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| """Custom patch for multi_scale_deformable_attn_pytorch function. | ||
|
|
||
| Original implementation in mmcv.ops use torch.nn.functional.grid_sample. | ||
| It raises errors during inference with OpenVINO exported model. | ||
| Therefore this function change grid_sample function to _custom_grid_sample function. | ||
| """ | ||
|
|
||
| bs, _, num_heads, embed_dims = value.shape | ||
| _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape | ||
| value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) | ||
| sampling_grids = 2 * sampling_locations - 1 | ||
| sampling_value_list = [] | ||
| for level, (H_, W_) in enumerate(value_spatial_shapes): | ||
| # bs, H_*W_, num_heads, embed_dims -> | ||
| # bs, H_*W_, num_heads*embed_dims -> | ||
| # bs, num_heads*embed_dims, H_*W_ -> | ||
| # bs*num_heads, embed_dims, H_, W_ | ||
| value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_) | ||
| # bs, num_queries, num_heads, num_points, 2 -> | ||
| # bs, num_heads, num_queries, num_points, 2 -> | ||
| # bs*num_heads, num_queries, num_points, 2 | ||
| sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1) | ||
| # bs*num_heads, embed_dims, num_queries, num_points | ||
| sampling_value_l_ = _custom_grid_sample( | ||
| value_l_, | ||
| sampling_grid_l_, | ||
| # mode='bilinear', | ||
| # padding_mode='zeros', | ||
| align_corners=False, | ||
| ) | ||
| sampling_value_list.append(sampling_value_l_) | ||
| # (bs, num_queries, num_heads, num_levels, num_points) -> | ||
| # (bs, num_heads, num_queries, num_levels, num_points) -> | ||
| # (bs, num_heads, 1, num_queries, num_levels*num_points) | ||
| attention_weights = attention_weights.transpose(1, 2).reshape( | ||
| bs * num_heads, 1, num_queries, num_levels * num_points | ||
| ) | ||
| output = ( | ||
| (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) | ||
| .sum(-1) | ||
| .view(bs, num_heads * embed_dims, num_queries) | ||
| ) | ||
| return output.transpose(1, 2).contiguous() | ||
|
|
||
|
|
||
| def _custom_grid_sample(im: torch.Tensor, grid: torch.Tensor, align_corners: bool = False) -> torch.Tensor: | ||
| """Custom patch for mmcv.ops.point_sample.bilinear_grid_sample. | ||
|
|
||
| This function is almost same with mmcv.ops.point_sample.bilinear_grid_sample. | ||
| The only difference is this function use reshape instead of view. | ||
|
|
||
| Args: | ||
| im (torch.Tensor): Input feature map, shape (N, C, H, W) | ||
| grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2) | ||
| align_corners (bool): If set to True, the extrema (-1 and 1) are | ||
| considered as referring to the center points of the input’s | ||
| corner pixels. If set to False, they are instead considered as | ||
| referring to the corner points of the input’s corner pixels, | ||
| making the sampling more resolution agnostic. | ||
|
|
||
| Returns: | ||
| torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg) | ||
| """ | ||
| n, c, h, w = im.shape | ||
| gn, gh, gw, _ = grid.shape | ||
| assert n == gn | ||
|
|
||
| x = grid[:, :, :, 0] | ||
| y = grid[:, :, :, 1] | ||
|
|
||
| if align_corners: | ||
| x = ((x + 1) / 2) * (w - 1) | ||
| y = ((y + 1) / 2) * (h - 1) | ||
| else: | ||
| x = ((x + 1) * w - 1) / 2 | ||
| y = ((y + 1) * h - 1) / 2 | ||
|
|
||
| x = x.reshape(n, -1) | ||
| y = y.reshape(n, -1) | ||
|
|
||
| x0 = torch.floor(x).long() | ||
| y0 = torch.floor(y).long() | ||
| x1 = x0 + 1 | ||
| y1 = y0 + 1 | ||
|
|
||
| wa = ((x1 - x) * (y1 - y)).unsqueeze(1) | ||
| wb = ((x1 - x) * (y - y0)).unsqueeze(1) | ||
| wc = ((x - x0) * (y1 - y)).unsqueeze(1) | ||
| wd = ((x - x0) * (y - y0)).unsqueeze(1) | ||
|
|
||
| # Apply default for grid_sample function zero padding | ||
| im_padded = F.pad(im, pad=[1, 1, 1, 1], mode="constant", value=0) | ||
| padded_h = h + 2 | ||
| padded_w = w + 2 | ||
| # save points positions after padding | ||
| x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1 | ||
|
|
||
| # Clip coordinates to padded image size | ||
| x0 = torch.where(x0 < 0, torch.tensor(0), x0) | ||
| x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1), x0) | ||
| x1 = torch.where(x1 < 0, torch.tensor(0), x1) | ||
| x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1), x1) | ||
| y0 = torch.where(y0 < 0, torch.tensor(0), y0) | ||
| y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1), y0) | ||
| y1 = torch.where(y1 < 0, torch.tensor(0), y1) | ||
| y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1), y1) | ||
|
|
||
| im_padded = im_padded.view(n, c, -1) | ||
|
|
||
| x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1) | ||
| x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1) | ||
| x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1) | ||
| x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1) | ||
|
|
||
| Ia = torch.gather(im_padded, 2, x0_y0) | ||
| Ib = torch.gather(im_padded, 2, x0_y1) | ||
| Ic = torch.gather(im_padded, 2, x1_y0) | ||
| Id = torch.gather(im_padded, 2, x1_y1) | ||
|
|
||
| return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw) | ||
|
|
||
|
|
||
| multi_scale_deform_attn.multi_scale_deformable_attn_pytorch = multi_scale_deformable_attn_pytorch | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,12 @@ | ||
| """Adapters for mmdeploy.""" | ||
| # Copyright (C) 2023 Intel Corporation | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # SPDX-License-Identifier: MIT | ||
|
|
||
| from .ops import squeeze__default | ||
| from .utils.mmdeploy import is_mmdeploy_enabled | ||
|
|
||
| __all__ = [ | ||
| "squeeze__default", | ||
| "is_mmdeploy_enabled", | ||
| ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| """Initial file for mmdeploy ops.""" | ||
| # Copyright (C) 2023 Intel Corporation | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
|
|
||
| from .custom_ops import squeeze__default | ||
|
|
||
| __all__ = ["squeeze__default"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| """Custom patch of mmdeploy ops for openvino export.""" | ||
| # Copyright (C) 2023 Intel Corporation | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
|
|
||
| import torch | ||
| from mmdeploy.core import SYMBOLIC_REWRITER | ||
| from mmdeploy.utils import get_ir_config | ||
| from torch.onnx import symbolic_helper | ||
|
|
||
| # Remove previous registered symbolic | ||
| SYMBOLIC_REWRITER._registry._rewrite_records["squeeze"] = list() | ||
|
|
||
|
|
||
| @SYMBOLIC_REWRITER.register_symbolic("squeeze", is_pytorch=True) | ||
| def squeeze__default(ctx, g, self, dim=None): | ||
harimkang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """Register default symbolic function for `squeeze`. | ||
|
|
||
| squeeze might be exported with IF node in ONNX, which is not supported in | ||
| lots of backend. | ||
|
|
||
| mmdeploy 0.x version do not support opset13 version squeeze, therefore this function is for | ||
| custom patch for supporting opset13 version squeeze. | ||
|
|
||
| If we adapt mmdeploy1.x version, then this function is no longer needed. | ||
| """ | ||
| if dim is None: | ||
| dims = [] | ||
| for i, size in enumerate(self.type().sizes()): | ||
| if size == 1: | ||
| dims.append(i) | ||
| else: | ||
| dims = [symbolic_helper._get_const(dim, "i", "dim")] | ||
|
|
||
| if get_ir_config(ctx.cfg).get("opset_version", 11) >= 13: | ||
| axes = g.op("Constant", value_t=torch.tensor(dims, dtype=torch.long)) | ||
| return g.op("Squeeze", self, axes) | ||
|
|
||
| return g.op("Squeeze", self, axes_i=dims) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
51 changes: 51 additions & 0 deletions
51
otx/algorithms/detection/adapters/mmdet/models/detectors/custom_deformable_detr_detector.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| """OTX Deformable DETR Class for mmdetection detectors.""" | ||
|
|
||
| # Copyright (C) 2023 Intel Corporation | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
|
|
||
| from mmdet.models.builder import DETECTORS | ||
| from mmdet.models.detectors.deformable_detr import DeformableDETR | ||
|
|
||
| from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import ( | ||
| ActivationMapHook, | ||
| FeatureVectorHook, | ||
| ) | ||
| from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled | ||
|
|
||
|
|
||
| @DETECTORS.register_module() | ||
| class CustomDeformableDETR(DeformableDETR): | ||
| """Custom Deformable DETR with task adapt. | ||
|
|
||
| Deformable DETR does not support task adapt, so it just take task_adpat argument. | ||
| """ | ||
|
|
||
| def __init__(self, *args, task_adapt=None, **kwargs): | ||
jaegukhyun marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| super().__init__(*args, **kwargs) | ||
| self.task_adapt = task_adapt | ||
|
|
||
|
|
||
| if is_mmdeploy_enabled(): | ||
| from mmdeploy.core import FUNCTION_REWRITER | ||
|
|
||
| @FUNCTION_REWRITER.register_rewriter( | ||
| "otx.algorithms.detection.adapters.mmdet.models.detectors.custom_deformable_detr_detector.CustomDeformableDETR.simple_test" | ||
| ) | ||
| def custom_deformable_detr__simple_test(ctx, self, img, img_metas, **kwargs): | ||
| """Function for custom_deformable_detr__simple_test.""" | ||
| height = int(img_metas[0]["img_shape"][0]) | ||
| width = int(img_metas[0]["img_shape"][1]) | ||
| img_metas[0]["batch_input_shape"] = (height, width) | ||
| img_metas[0]["img_shape"] = (height, width, 3) | ||
| feat = self.extract_feat(img) | ||
| outs = self.bbox_head(feat, img_metas) | ||
| bbox_results = self.bbox_head.get_bboxes(*outs, img_metas=img_metas, **kwargs) | ||
|
|
||
| if ctx.cfg["dump_features"]: | ||
| feature_vector = FeatureVectorHook.func(feat) | ||
| cls_scores = outs[0] | ||
| saliency_map = ActivationMapHook.func(cls_scores) | ||
| return (*bbox_results, feature_vector, saliency_map) | ||
|
|
||
| return bbox_results | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.