diff --git a/README.md b/README.md
index 82a4fcc2a..4aa7d5b1d 100644
--- a/README.md
+++ b/README.md
@@ -54,14 +54,42 @@ Some examples from our YOLACT base model (33.5 fps on a Titan Xp and 29.8 mAP on
    ```Shell
    sh data/scripts/COCO_test.sh
    ```
- - If you want to use YOLACT++, compile deformable convolutional layers (from [DCNv2](https://github.com/CharlesShang/DCNv2/tree/pytorch_1.0)).
+ - If you want to use YOLACT++, compile deformable convolutional layers. Make sure you have the latest CUDA toolkit installed from [NVIDIA's website](https://developer.nvidia.com/cuda-toolkit).
+
+   Case 1: You only need the PyTorch code
    ```Shell
-   cd external/DCNv2
-   python setup.py build develop
+   pip install mmcv-full
    ```
+   Case 2: You want to run inference with ONNX models. Follow https://mmcv.readthedocs.io/en/latest/deployment/onnxruntime_op.html
+   ```Shell
+   wget https://github.com/microsoft/onnxruntime/releases/download/v1.8.1/onnxruntime-linux-x64-1.8.1.tgz
+   tar -zxvf onnxruntime-linux-x64-1.8.1.tgz
+   cd onnxruntime-linux-x64-1.8.1
+   export ONNXRUNTIME_DIR=$(pwd)
+   export LD_LIBRARY_PATH=$ONNXRUNTIME_DIR/lib:$LD_LIBRARY_PATH
+   ```
+   ```Shell
+   cd ..
+   git clone https://github.com/open-mmlab/mmcv.git
+   cd mmcv ## to MMCV root directory
+   MMCV_WITH_OPS=1 MMCV_WITH_ORT=1 python setup.py develop
+   ```
+   ```Shell
+   pip install onnxruntime==1.8.1
+   ```
+## ONNX Export and Inference (tested with yolact_plus at input size 550, batch size 1)
+To convert the PyTorch model to ONNX:
+```Shell
+python3 yolact2onnx.py --config yolact_plus_base_config --ckpt_path yolact_plus_base_159_180000.pth --onnx_paths yolact_plus.onnx maskiou_net.onnx
+```
+
+To run ONNX inference:
+```Shell
+python3 onnx_inference.py --config yolact_plus_base_config --img_path 15387869.jpg --onnx_paths yolact_plus.onnx maskiou_net.onnx --score_threshold 0.5
+```
 # Evaluation
 Here are our YOLACT models (released on April 5th, 2019) along with their FPS on a Titan Xp and mAP on `test-dev`:
diff --git a/backbone.py b/backbone.py
index 4df59d023..1a601a137 100644
--- a/backbone.py
+++ b/backbone.py
@@ -5,7 +5,7 @@ from collections import OrderedDict
 
 try:
-    from dcn_v2 import DCN
+    from mmcv.ops import ModulatedDeformConv2dPack as DCN
 except ImportError:
     def DCN(*args, **kwdargs):
         raise Exception('DCN could not be imported. If you want to use YOLACT++ models, compile DCN. Check the README for instructions.')
@@ -21,9 +21,6 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=nn.Ba
         if use_dcn:
             self.conv2 = DCN(planes, planes, kernel_size=3, stride=stride,
                             padding=dilation, dilation=dilation, deformable_groups=1)
-            self.conv2.bias.data.zero_()
-            self.conv2.conv_offset_mask.weight.data.zero_()
-            self.conv2.conv_offset_mask.bias.data.zero_()
         else:
             self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                                 padding=dilation, bias=False, dilation=dilation)
diff --git a/data/config.py b/data/config.py
index 91b4c82ea..39d57d636 100644
--- a/data/config.py
+++ b/data/config.py
@@ -625,6 +625,7 @@ def print(self):
 
     # Use command-line arguments to set this.
     'no_jit': False,
+    'export_onnx': False,
 
     'backbone': None,
     'name': 'base_config',
diff --git a/layers/functions/detection.py b/layers/functions/detection.py
index 4e5fd068c..c0520bcc0 100644
--- a/layers/functions/detection.py
+++ b/layers/functions/detection.py
@@ -29,7 +29,7 @@ def __init__(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh):
         self.use_cross_class_nms = False
         self.use_fast_nms = False
 
-    def __call__(self, predictions, net):
+    def __call__(self, predictions, net=None):
         """
         Args:
             loc_data: (tensor) Loc preds from loc layers
@@ -72,8 +72,10 @@ def __call__(self, predictions, net):
 
             if result is not None and proto_data is not None:
                 result['proto'] = proto_data[batch_idx]
-
-            out.append({'detection': result, 'net': net})
+            if net is not None:
+                out.append({'detection': result, 'net': net})
+            else:
+                out.append({'detection': result})
 
         return out
diff --git a/layers/output_utils.py b/layers/output_utils.py
index 27efac935..d1488ae5a 100644
--- a/layers/output_utils.py
+++ b/layers/output_utils.py
@@ -12,7 +12,7 @@
 from utils import timer
 from .box_utils import crop, sanitize_coordinates
 
-def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear',
+def postprocess(det_output, w, h, maskiou_net=None, batch_idx=0, interpolation_mode='bilinear',
                 visualize_lincomb=False, crop_masks=True, score_threshold=0):
     """
     Postprocesses the output of Yolact on testing mode into a format that makes sense,
@@ -31,10 +31,13 @@ def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear',
        - boxes [num_det, 4]: The bounding box for each detection in absolute point form.
        - masks [num_det, h, w]: Full image masks for each detection.
    """
    dets = det_output[batch_idx]
-    net = dets['net']
-    dets = dets['detection']
+    if not maskiou_net:
+        net = dets['net']
+        dets = dets['detection']
+    else:
+        dets = dets['detection']
 
    if dets is None:
        return [torch.Tensor()] * 4 # Warning, this is 4 copies of the same thing
@@ -79,7 +81,11 @@ def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear',
        if cfg.use_maskiou:
            with timer.env('maskiou_net'):
                with torch.no_grad():
-                    maskiou_p = net.maskiou_net(masks.unsqueeze(1))
+                    if maskiou_net is not None:
+                        maskiou_p = maskiou_net.run(None, {"input": masks.unsqueeze(1).numpy()})
+                        maskiou_p = torch.from_numpy(maskiou_p[0])
+                    else:
+                        maskiou_p = net.maskiou_net(masks.unsqueeze(1))
                    maskiou_p = torch.gather(maskiou_p, dim=1, index=classes.unsqueeze(1)).squeeze(1)
                    if cfg.rescore_mask:
                        if cfg.rescore_bbox:
diff --git a/onnx_inference.py b/onnx_inference.py
new file mode 100644
index 000000000..23c075b26
--- /dev/null
+++ b/onnx_inference.py
@@ -0,0 +1,318 @@
+from data import cfg, MEANS, STD, set_cfg, mask_type
+from layers.box_utils import mask_iou
+import torch.nn.functional as F
+import torch
+import cv2
+import os
+from layers.output_utils import postprocess
+
+import colorsys
+import random
+
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib import patches
+from matplotlib.patches import Polygon
+from skimage.measure import find_contours
+import onnxruntime as ort
+from layers import Detect
+from yolact import FastMaskIoUNet
+from mmcv.ops import get_onnxruntime_op_path
+
+############################################################
+# Visualization
+############################################################
+
+
+def random_colors(N, bright=True):
+    """
+    Generate random colors.
+    To get visually distinct colors, generate them in HSV space then
+    convert to RGB.
+ """ + brightness = 1.0 if bright else 0.7 + hsv = [(i / N, 1, brightness) for i in range(N)] + colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv)) + random.shuffle(colors) + return colors + + +def apply_mask(image, mask, color, alpha=0.5): + """Apply the given mask to the image.""" + for c in range(3): + image[:, :, c] = np.where( + mask == 1, + image[:, :, c] * (1 - alpha) + alpha * color[c] * 255, + image[:, :, c], + ) + # cv2.imwrite('maskapplied.png',image) + return image + + +def display_instances( + image, + boxes, + masks, + class_ids, + class_names, + scores=None, + title="", + figsize=(8, 8), + show_mask=True, + show_bbox=True, + colors=None, + captions=None, + plot_path=None, +): + """ + boxes: [num_instance, (y1, x1, y2, x2, class_id)] in image coordinates. + masks: [height, width, num_instances] + class_ids: [num_instances] + class_names: list of class names of the dataset + scores: (optional) confidence scores for each box + title: (optional) Figure title + show_mask, show_bbox: To show masks and bounding boxes or not + figsize: (optional) the size of the image + colors: (optional) An array or colors to use with each object + captions: (optional) A list of strings to use as captions for each object + """ + # Number of instances + N = boxes.shape[0] + # print(image.shape, masks.shape) + if not N: + print("\n*** No instances to display *** \n") + else: + assert boxes.shape[0] == masks.shape[-1] == class_ids.shape[0] + + # If no axis is passed, create one and automatically call show() + fig, ax = plt.subplots( + 1, + figsize=figsize, + ) + + # Generate random colors + colors = colors or random_colors(N) + + # Show area outside image boundaries. + height, width = image.shape[:2] + ax.set_ylim(height + 10, -10) + ax.set_xlim(-10, width + 10) + ax.axis("off") + ax.set_title(title) + + masked_image = image.astype(np.uint32).copy() + # print(N, '-------') + for i in range(N): + color = colors[i] + + # Bounding box + if not np.any(boxes[i]): + + # Skip this instance. Has no bbox. Likely lost in image cropping. + continue + y1, x1, y2, x2 = boxes[i] + if show_bbox: + p = patches.Rectangle( + (x1, y1), + x2 - x1, + y2 - y1, + linewidth=2, + alpha=0.7, + linestyle="dashed", + edgecolor=color, + facecolor="none", + ) + ax.add_patch(p) + + # Label + if not captions: + class_id = class_ids[i] + score = scores[i] if scores is not None else None + label = class_names[class_id] + caption = "{} {:.3f}".format(label, score) if score else label + else: + caption = captions[i] + ax.text(x1, y1 + 8, caption, color="w", size=11, backgroundcolor="none") + + # Mask + mask = masks[:, :, i] + # print(show_mask,'---------------') + if show_mask: + masked_image = apply_mask(masked_image, mask, color, alpha=0.2) + + # Mask Polygon + # Pad to ensure proper polygons for masks that touch image edges. + padded_mask = np.zeros((mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8) + padded_mask[1:-1, 1:-1] = mask + contours = find_contours(padded_mask, 0.5) + for verts in contours: + # Subtract the padding and flip (y, x) to (x, y) + verts = np.fliplr(verts) - 1 + p = Polygon(verts, facecolor="none", edgecolor=color) + ax.add_patch(p) + ax.imshow(masked_image.astype(np.uint8)) + if plot_path is not None: + fig.subplots_adjust(left=0, bottom=0, right=1, top=1) + fig.savefig(plot_path) + plt.close(fig) + + +#preprocess +class FastBaseTransform(torch.nn.Module): + """ + Transform that does all operations on the GPU for super speed. 
+ This doesn't suppport a lot of config settings and should only be used for production. + Maintain this as necessary. + """ + + def __init__(self): + super().__init__() + + self.mean = torch.Tensor(MEANS).float() + self.std = torch.Tensor( STD ).float() + + if torch.cuda.is_available(): + self.mean = self.mean.cuda() + self.std = self.std.cuda() + self.mean = self.mean[None, :, None, None] + self.std = self.std[None, :, None, None] + self.transform = cfg.backbone.transform + + def forward(self, img): + self.mean = self.mean.to(img.device) + self.std = self.std.to(img.device) + + # img assumed to be a pytorch BGR image with channel order [n, h, w, c] + # if cfg.preserve_aspect_ratio: + # raise NotImplementedError + + img = img.permute(0, 3, 1, 2).contiguous() + img = F.interpolate(img, (cfg.max_size, cfg.max_size), mode='bilinear', align_corners=False) + + if self.transform.normalize: + img = (img - self.mean) / self.std + elif self.transform.subtract_means: + img = (img - self.mean) + elif self.transform.to_float: + img = img / 255 + + if self.transform.channel_order != 'RGB': + raise NotImplementedError + + img = img[:, (2, 1, 0), :, :].contiguous() + + # Return value is in channel order [n, c, h, w] and RGB + return img + +def load_image(path: str): + img =cv2.imread(path) + img =cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + h,w,_ = img.shape + frame = torch.from_numpy(img).float() + batch = FastBaseTransform()(frame.unsqueeze(0)) + return batch.numpy(), (h,w) + +def PostProcess(preds, orig_shapes, maskiou_sess, score_threshold: float): + h, w = orig_shapes + classes, scores, boxes, masks = postprocess(preds, w, h, maskiou_net=maskiou_sess, crop_masks=True, score_threshold=score_threshold) + # if classes.size(0) == 0: + # return + + classes = list(classes.cpu().numpy().astype(int)) + if isinstance(scores, list): + box_scores = list(scores[0].cpu().detach().numpy().astype(float)) + mask_scores = list(scores[1].cpu().detach().numpy().astype(float)) + else: + scores = list(scores.detach().cpu().detach().numpy().astype(float)) + box_scores = scores + mask_scores = scores + masks = masks.view(-1, h*w) + boxes = boxes.cpu().numpy() + masks = masks.view(-1, h, w).detach().cpu().numpy() + + structure_bbox_list =[] + mask_list=[] + for i in range(masks.shape[0]): + # Make sure that the bounding box actually makes sense and a mask was produced + if (boxes[i, 3] - boxes[i, 1]) * (boxes[i, 2] - boxes[i, 0]) > 0: + + if mask_scores[i]>0.4: + bbox = boxes[i,:] + bbox = [round(float(x)*10)/10 for x in bbox] + structure_bbox_list.append(bbox) + + mask_list.append(masks[i,:,:].astype(np.uint8)) + + return structure_bbox_list, mask_list, classes, mask_scores + +def main(args): + #load model + config = args.config + set_cfg(config) + cfg.mask_proto_debug = False + + detect_layer = Detect(cfg.num_classes, bkg_label=0, top_k=cfg.nms_top_k, + conf_thresh=cfg.nms_conf_thresh, nms_thresh=cfg.nms_thresh) + if cfg.use_maskiou: + maskiou_net_sess = ort.InferenceSession(args.onnx_paths[1], providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + maskiou_net_sess_input_names = [inp.name for inp in maskiou_net_sess.get_inputs()] + print('Input Names:', maskiou_net_sess_input_names) + maskiou_net_sess_output_names = [out.name for out in maskiou_net_sess.get_outputs()] + print(maskiou_net_sess_output_names) + + img_path = args.img_path + img, orig_shapes =load_image(img_path) + + ## exported ONNX model with custom operators + ort_custom_op_path = get_onnxruntime_op_path() + assert os.path.exists(ort_custom_op_path) + 
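+    # Note: the exported yolact_plus graph contains MMCV custom ops (the deformable
+    # convolutions from ModulatedDeformConv2dPack), which a stock onnxruntime build
+    # does not implement. Registering mmcv's custom-op library (built with
+    # MMCV_WITH_ORT=1 as described in the README) makes those ops available below.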
session_options = ort.SessionOptions() + session_options.register_custom_ops_library(ort_custom_op_path) + + session = ort.InferenceSession(args.onnx_paths[0], session_options, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + input_names = [inp.name for inp in session.get_inputs()] + print('Input Names:', input_names) + output_names = [out.name for out in session.get_outputs()] + print(output_names) + onnx_results = session.run(None, {input_names[0] : img}) + detect_input = {"loc": torch.from_numpy(onnx_results[0]), + "conf": torch.from_numpy(onnx_results[1]), + "mask": torch.from_numpy(onnx_results[2]), + "priors": torch.from_numpy(onnx_results[3]), + "proto": torch.from_numpy(onnx_results[4]), + } + + detect_layer.use_fast_nms = True + detect_layer.use_cross_class_nms = False + + final_out = detect_layer(detect_input) + + structure_bbox_list, mask_list, classes, scores = PostProcess(final_out, orig_shapes, maskiou_net_sess, args.score_threshold) + + display_instances( + cv2.imread(img_path), + np.array([[y1, x1, y2, x2] for x1, y1, x2, y2 in structure_bbox_list]), + np.stack(mask_list, axis=2), + np.arange(len(structure_bbox_list)), + [str(i) for i in range(len(structure_bbox_list))], + show_mask=True, + show_bbox=True, + plot_path=img_path.split(".jpg")[0]+"_ort_out.png", + figsize=(16, 16), + ) + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description='Yolact onnx inference') + parser.add_argument('--config', type=str, default='yolact_plus_base_config', + help='The config object to use.') + parser.add_argument('--img_path', type=str, default="", + help='Give the path to image for inference') + parser.add_argument('--score_threshold', type=float, default=0.5, + help='Give the path to exported ONNX weights') + parser.add_argument('--onnx_paths', type=str, default="", nargs='+', + help='Give the path to exported ONNX weights') + + args = parser.parse_args() + main(args) + \ No newline at end of file diff --git a/yolact.py b/yolact.py index d83703bb7..bd26c22b0 100644 --- a/yolact.py +++ b/yolact.py @@ -22,7 +22,8 @@ torch.cuda.current_device() # As of March 10, 2019, Pytorch DataParallel still doesn't support JIT Script Modules -use_jit = torch.cuda.device_count() <= 1 + +use_jit = not cfg.export_onnx if not use_jit: print('Multiple GPUs detected! Turning off JIT.') @@ -370,8 +371,10 @@ def __init__(self): def forward(self, x): x = self.maskiou_net(x) - maskiou_p = F.max_pool2d(x, kernel_size=x.size()[2:]).squeeze(-1).squeeze(-1) - + if not cfg.export_onnx: + maskiou_p = F.max_pool2d(x, kernel_size=x.size()[2:]).squeeze(-1).squeeze(-1) + else: + maskiou_p = F.max_pool2d(x, kernel_size=(3,3)).squeeze(-1).squeeze(-1) return maskiou_p @@ -477,12 +480,16 @@ def save_weights(self, path): def load_weights(self, path): """ Loads weights from a compressed save file. 
""" state_dict = torch.load(path) - + # For backward compatability, remove these (the new variable is called layers) for key in list(state_dict.keys()): if key.startswith('backbone.layer') and not key.startswith('backbone.layers'): del state_dict[key] - + + if "conv2.conv_offset_mask" in key: + new_key = key.replace("conv2.conv_offset_mask", "conv2.conv_offset", 1) + state_dict[new_key] = state_dict[key] + del state_dict[key] # Also for backward compatibility with v1.0 weights, do this check if key.startswith('fpn.downsample_layers.'): if cfg.fpn is not None and int(key.split('.')[2]) >= cfg.fpn.num_downsample: @@ -672,11 +679,10 @@ def forward(self, x): else: pred_outs['conf'] = F.softmax(pred_outs['conf'], -1) - - return self.detect(pred_outs, self) - - - + if cfg.export_onnx: + return pred_outs + else: + return self.detect(pred_outs, self) # Some testing code if __name__ == '__main__': diff --git a/yolact2onnx.py b/yolact2onnx.py new file mode 100644 index 000000000..9c76f00f5 --- /dev/null +++ b/yolact2onnx.py @@ -0,0 +1,46 @@ +from data import cfg, MEANS, STD, set_cfg, mask_type +import torch.nn.functional as F +import torch +import numpy as np + + +def main(args): + config = args.config + trained_model = args.ckpt_path + + #load model + set_cfg(config) + cfg.mask_proto_debug = False + cfg.export_onnx=True + #torch.set_default_tensor_type('torch.cuda.FloatTensor') + + print('Loading model...', end='') + from yolact import Yolact + net = Yolact() + net.load_weights(trained_model) + net.eval() + dummy = torch.randn(1,3,550,550) + torch.onnx.export(net, dummy, "yolact_plus.onnx", opset_version=13) + net.detect.use_fast_nms = True + net.detect.use_cross_class_nms = False + if cfg.use_maskiou: + dummy2 = torch.randn(1,1,138,138) + torch.onnx.export(net.maskiou_net, dummy2, "maskiou_net.onnx", + input_names = ['input'], # the model's input names + output_names = ['output'], + dynamic_axes={'input' : {0 : 'batch_size'}, 'output' : {0 : 'batch_size'}}, + opset_version=13) + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description='Yolact onnx export') + parser.add_argument('--config', type=str, default='yolact_plus_base_config', + help='The config object to use.') + parser.add_argument('--ckpt_path', type=str, default="", + help='Give the path to trained pytorch weights') + parser.add_argument('--onnx_paths', type=str, default="", nargs='+', + help='Give the path to exported ONNX weights') + + args = parser.parse_args() + main(args) + \ No newline at end of file