diff --git a/train.py b/train.py index 1d3404ffc414..d4a5495d3b3b 100644 --- a/train.py +++ b/train.py @@ -34,7 +34,7 @@ from utils.datasets import create_dataloader from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \ strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \ - check_requirements, print_mutation, set_logging, one_cycle, colorstr + check_requirements, print_mutation, set_logging, one_cycle, colorstr, methods from utils.downloads import attempt_download from utils.loss import ComputeLoss from utils.plots import plot_labels, plot_evolution @@ -42,6 +42,7 @@ from utils.loggers.wandb.wandb_utils import check_wandb_resume from utils.metrics import fitness from utils.loggers import Loggers +from utils.callbacks import Callbacks LOGGER = logging.getLogger(__name__) LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html @@ -52,6 +53,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary opt, device, + callbacks=Callbacks() ): save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \ Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \ @@ -77,12 +79,16 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary # Loggers if RANK in [-1, 0]: - loggers = Loggers(save_dir, weights, opt, hyp, LOGGER).start() # loggers dict + loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance if loggers.wandb: data_dict = loggers.wandb.data_dict if resume: weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp + # Register actions + for k in methods(loggers): + callbacks.register_action(k, callback=getattr(loggers, k)) + # Config plots = not evolve # create plots cuda = device.type != 'cpu' @@ -215,13 +221,15 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary # cf = torch.bincount(c.long(), minlength=nc) + 1. 
# frequency # model._initialize_biases(cf.to(device)) if plots: - plot_labels(labels, names, save_dir, loggers) + plot_labels(labels, names, save_dir) # Anchors if not opt.noautoanchor: check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) model.half().float() # pre-reduce anchor precision + callbacks.on_pretrain_routine_end() + # DDP mode if cuda and RANK != -1: model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK) @@ -329,8 +337,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB) pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % ( f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])) - loggers.on_train_batch_end(ni, model, imgs, targets, paths, plots) - + callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots) # end batch ------------------------------------------------------------------------------------------------ # Scheduler @@ -339,7 +346,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if RANK in [-1, 0]: # mAP - loggers.on_train_epoch_end(epoch) + callbacks.on_train_epoch_end(epoch=epoch) ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights']) final_epoch = epoch + 1 == epochs if not noval or final_epoch: # Calculate mAP @@ -353,14 +360,14 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary save_json=is_coco and final_epoch, verbose=nc < 50 and final_epoch, plots=plots and final_epoch, - loggers=loggers, + callbacks=callbacks, compute_loss=compute_loss) # Update best mAP fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95] if fi > best_fitness: best_fitness = fi - loggers.on_train_val_end(mloss, results, lr, epoch, best_fitness, fi) + callbacks.on_fit_epoch_end(mloss, results, lr, epoch, best_fitness, fi) # Save model if (not nosave) or (final_epoch and not evolve): # if save @@ -377,7 +384,7 @@ def 
class Callbacks:
    """
    Handles all registered callbacks for YOLOv5 Hooks.

    Every instance owns an independent registry that maps each hook name to
    the list of actions registered on it. Keeping the registry per instance
    (rather than as a class attribute) means that actions registered on one
    Callbacks object never fire on, or leak into, another one.
    """

    # Names of all hooks that actions may be registered to, in firing order
    _HOOKS = (
        'on_pretrain_routine_start',
        'on_pretrain_routine_end',

        'on_train_start',
        'on_train_epoch_start',
        'on_train_batch_start',
        'optimizer_step',
        'on_before_zero_grad',
        'on_train_batch_end',
        'on_train_epoch_end',

        'on_val_start',
        'on_val_batch_start',
        'on_val_image_end',
        'on_val_batch_end',
        'on_val_end',

        'on_fit_epoch_end',  # fit = train + val
        'on_model_save',
        'on_train_end',

        'teardown',
    )

    def __init__(self):
        # Build a fresh dict with fresh lists so no state is shared between instances
        self._callbacks = {hook: [] for hook in self._HOOKS}

    def register_action(self, hook, name='', callback=None):
        """
        Register a new action to a callback hook

        Args:
            hook        The callback hook name to register the action to
            name        The name of the action
            callback    The callback to fire

        Raises:
            AssertionError  If the hook is unknown or the callback is not callable
        """
        # Asserts are kept (rather than raising ValueError/TypeError) so existing callers
        # that expect AssertionError keep working
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        assert callable(callback), f"callback '{callback}' is not callable"
        self._callbacks[hook].append({'name': name, 'callback': callback})

    def get_registered_actions(self, hook=None):
        """
        Returns all the registered actions by callback hook

        Args:
            hook    The name of the hook to check, defaults to all

        Returns:
            The list of actions registered to `hook`, or the whole
            hook -> actions dictionary when no hook is given. The live
            internal objects are returned, not copies.
        """
        if hook:
            return self._callbacks[hook]
        else:
            return self._callbacks

    @staticmethod
    def run_callbacks(register, *args, **kwargs):
        """
        Loop through the registered actions and fire all callbacks

        Args:
            register    Iterable of action dicts, each holding a 'callback' entry
            *args       Positional arguments forwarded to every callback
            **kwargs    Keyword arguments forwarded to every callback
        """
        for action in register:
            # print(f"Running callbacks.{action['callback'].__name__}()")
            action['callback'](*args, **kwargs)

    def on_pretrain_routine_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of each pretraining routine
        """
        self.run_callbacks(self._callbacks['on_pretrain_routine_start'], *args, **kwargs)

    def on_pretrain_routine_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each pretraining routine
        """
        self.run_callbacks(self._callbacks['on_pretrain_routine_end'], *args, **kwargs)

    def on_train_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of each training
        """
        self.run_callbacks(self._callbacks['on_train_start'], *args, **kwargs)

    def on_train_epoch_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of each training epoch
        """
        self.run_callbacks(self._callbacks['on_train_epoch_start'], *args, **kwargs)

    def on_train_batch_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of each training batch
        """
        self.run_callbacks(self._callbacks['on_train_batch_start'], *args, **kwargs)

    def optimizer_step(self, *args, **kwargs):
        """
        Fires all registered callbacks on each optimizer step
        """
        self.run_callbacks(self._callbacks['optimizer_step'], *args, **kwargs)

    def on_before_zero_grad(self, *args, **kwargs):
        """
        Fires all registered callbacks before zero grad
        """
        self.run_callbacks(self._callbacks['on_before_zero_grad'], *args, **kwargs)

    def on_train_batch_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each training batch
        """
        self.run_callbacks(self._callbacks['on_train_batch_end'], *args, **kwargs)

    def on_train_epoch_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each training epoch
        """
        self.run_callbacks(self._callbacks['on_train_epoch_end'], *args, **kwargs)

    def on_val_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of the validation
        """
        self.run_callbacks(self._callbacks['on_val_start'], *args, **kwargs)

    def on_val_batch_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of each validation batch
        """
        self.run_callbacks(self._callbacks['on_val_batch_start'], *args, **kwargs)

    def on_val_image_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each val image
        """
        self.run_callbacks(self._callbacks['on_val_image_end'], *args, **kwargs)

    def on_val_batch_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each validation batch
        """
        self.run_callbacks(self._callbacks['on_val_batch_end'], *args, **kwargs)

    def on_val_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of the validation
        """
        self.run_callbacks(self._callbacks['on_val_end'], *args, **kwargs)

    def on_fit_epoch_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each fit (train+val) epoch
        """
        self.run_callbacks(self._callbacks['on_fit_epoch_end'], *args, **kwargs)

    def on_model_save(self, *args, **kwargs):
        """
        Fires all registered callbacks after each model save
        """
        self.run_callbacks(self._callbacks['on_model_save'], *args, **kwargs)

    def on_train_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of training
        """
        self.run_callbacks(self._callbacks['on_train_end'], *args, **kwargs)

    def teardown(self, *args, **kwargs):
        """
        Fires all registered callbacks before teardown
        """
        self.run_callbacks(self._callbacks['teardown'], *args, **kwargs)
on_train_batch_end(self, ni, model, imgs, targets, paths, plots): # Callback runs on train batch end @@ -78,8 +84,8 @@ def on_train_epoch_end(self, epoch): if self.wandb: self.wandb.current_epoch = epoch + 1 - def on_val_batch_end(self, pred, predn, path, names, im): - # Callback runs on train batch end + def on_val_image_end(self, pred, predn, path, names, im): + # Callback runs on val image end if self.wandb: self.wandb.val_one_image(pred, predn, path, names, im) @@ -89,25 +95,20 @@ def on_val_end(self): files = sorted(self.save_dir.glob('val*.jpg')) self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]}) - def on_train_val_end(self, mloss, results, lr, epoch, best_fitness, fi): - # Callback runs on val end during training + def on_fit_epoch_end(self, mloss, results, lr, epoch, best_fitness, fi): + # Callback runs at the end of each fit (train+val) epoch vals = list(mloss) + list(results) + lr - keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss - 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', # metrics - 'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss - 'x/lr0', 'x/lr1', 'x/lr2'] # params - x = {k: v for k, v in zip(keys, vals)} # dict - + x = {k: v for k, v in zip(self.keys, vals)} # dict if self.csv: file = self.save_dir / 'results.csv' n = len(x) + 1 # number of cols - s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # add header + s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + self.keys)).rstrip(',') + '\n') # add header with open(file, 'a') as f: f.write(s + ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n') if self.tb: for k, v in x.items(): - self.tb.add_scalar(k, v, epoch) # TensorBoard + self.tb.add_scalar(k, v, epoch) if self.wandb: self.wandb.log(x) @@ -119,20 +120,22 @@ def on_model_save(self, last, epoch, final_epoch, best_fitness, fi): if ((epoch + 1) % self.opt.save_period == 0 and not 
final_epoch) and self.opt.save_period != -1: self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi) - def on_train_end(self, last, best, plots): + def on_train_end(self, last, best, plots, epoch): # Callback runs on training end if plots: plot_results(dir=self.save_dir) # save results.png files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]] files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter + + if self.tb: + from PIL import Image + import numpy as np + for f in files: + self.tb.add_image(f.stem, np.asarray(Image.open(f)), epoch, dataformats='HWC') + if self.wandb: wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]}) wandb.log_artifact(str(best if best.exists() else last), type='model', name='run_' + self.wandb.wandb_run.id + '_model', aliases=['latest', 'best', 'stripped']) self.wandb.finish_run() - - def log_images(self, paths): - # Log images - if self.wandb: - self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]}) diff --git a/utils/plots.py b/utils/plots.py index e13e316314dd..252e128168ee 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -281,7 +281,7 @@ def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_tx plt.savefig(str(Path(path).name) + '.png', dpi=300) -def plot_labels(labels, names=(), save_dir=Path(''), loggers=None): +def plot_labels(labels, names=(), save_dir=Path('')): # plot dataset labels print('Plotting labels... 
') c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes @@ -324,10 +324,6 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None): matplotlib.use('Agg') plt.close() - # loggers - if loggers: - loggers.log_images(save_dir.glob('*labels*.jpg')) - def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution() # Plot hyperparameter evolution results in evolve.txt diff --git a/val.py b/val.py index 86439b1380dc..58e8170da86c 100644 --- a/val.py +++ b/val.py @@ -25,7 +25,7 @@ from utils.metrics import ap_per_class, ConfusionMatrix from utils.plots import plot_images, output_to_target, plot_study_txt from utils.torch_utils import select_device, time_sync -from utils.loggers import Loggers +from utils.callbacks import Callbacks def save_one_txt(predn, save_conf, shape, file): @@ -97,7 +97,7 @@ def run(data, dataloader=None, save_dir=Path(''), plots=True, - loggers=Loggers(), + callbacks=Callbacks(), compute_loss=None, ): # Initialize/load model and set device @@ -213,7 +213,7 @@ def run(data, save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt')) if save_json: save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary - loggers.on_val_batch_end(pred, predn, path, names, img[si]) + callbacks.on_val_image_end(pred, predn, path, names, img[si]) # Plot images if plots and batch_i < 3: @@ -250,7 +250,7 @@ def run(data, # Plots if plots: confusion_matrix.plot(save_dir=save_dir, names=list(names.values())) - loggers.on_val_end() + callbacks.on_val_end() # Save JSON if save_json and len(jdict): @@ -282,7 +282,7 @@ def run(data, model.float() # for training if not training: s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else '' - print(f"Results saved to {save_dir}{s}") + print(f"Results saved to {colorstr('bold', save_dir)}{s}") maps = np.zeros(nc) + map for i, c in enumerate(ap_class): maps[c] = ap[i]