@@ -62,10 +62,6 @@ def configure(self, model_cfg, model_ckpt, data_cfg, training=True, **kwargs):
6262 if training :
6363 self .configure_regularization (cfg )
6464
65- # Other hyper-parameters
66- if 'hyperparams' in cfg :
67- self .configure_hyperparams (cfg , training , ** kwargs )
68-
6965 # Hooks
7066 self .configure_hook (cfg )
7167
@@ -117,12 +113,15 @@ def configure_data(self, cfg, training, **kwargs):
117113 seed = cfg .seed
118114 )
119115 )
120- if 'dataset' in cfg .data .train :
121- train_cfg = self .get_train_data_cfg (cfg )
122- train_cfg .ote_dataset = cfg .data .train .pop ('ote_dataset' , None )
123- train_cfg .labels = cfg .data .train .get ('labels' , None )
124- train_cfg .data_classes = cfg .data .train .pop ('data_classes' , None )
125- train_cfg .new_classes = cfg .data .train .pop ('new_classes' , None )
116+ for subset in ("train" , "val" , "test" ):
117+ if 'dataset' in cfg .data [subset ]:
118+ subset_cfg = self .get_data_cfg (cfg , subset )
119+ subset_cfg .ote_dataset = cfg .data [subset ].pop ('ote_dataset' , None )
120+ subset_cfg .labels = cfg .data [subset ].get ('labels' , None )
121+ if 'data_classes' in cfg .data [subset ]:
122+ subset_cfg .data_classes = cfg .data [subset ].pop ('data_classes' )
123+ if 'new_classes' in cfg .data [subset ]:
124+ subset_cfg .new_classes = cfg .data [subset ].pop ('new_classes' )
126125
127126 def configure_task (self , cfg , training , ** kwargs ):
128127 """Adjust settings for task adaptation
@@ -200,7 +199,7 @@ def configure_task_classes(self, cfg, task_adapt_type, task_adapt_op):
200199
201200 def configure_task_data_pipeline (self , cfg , model_classes , data_classes ):
202201 # Trying to alter class indices of training data according to model class order
203- tr_data_cfg = self .get_train_data_cfg (cfg )
202+ tr_data_cfg = self .get_data_cfg (cfg , "train" )
204203 class_adapt_cfg = dict (type = 'AdaptClassLabels' , src_classes = data_classes , dst_classes = model_classes )
205204 pipeline_cfg = tr_data_cfg .pipeline
206205 for i , op in enumerate (pipeline_cfg ):
@@ -240,7 +239,7 @@ def configure_task_cls_incr(self, cfg, task_adapt_type, org_model_classes, model
240239 else :
241240 bbox_head = cfg .model .roi_head .bbox_head
242241 if task_adapt_type == 'mpa' :
243- tr_data_cfg = self .get_train_data_cfg (cfg )
242+ tr_data_cfg = self .get_data_cfg (cfg , "train" )
244243 if tr_data_cfg .type != 'MPADetDataset' :
245244 tr_data_cfg .img_ids_dict = self .get_img_ids_for_incr (cfg , org_model_classes , model_classes )
246245 tr_data_cfg .org_type = tr_data_cfg .type
@@ -311,7 +310,7 @@ def configure_task_cls_incr(self, cfg, task_adapt_type, org_model_classes, model
311310 ConfigDict (type = 'AdaptiveTrainSchedulingHook' , ** adaptive_validation_interval )
312311 )
313312 else :
314- src_data_cfg = Stage .get_train_data_cfg (cfg )
313+ src_data_cfg = Stage .get_data_cfg (cfg , "train" )
315314 src_data_cfg .pop ('old_new_indices' , None )
316315
317316 def configure_regularization (self , cfg ):
@@ -338,7 +337,7 @@ def get_img_ids_for_incr(cfg, org_model_classes, model_classes):
338337 new_classes = np .setdiff1d (model_classes , org_model_classes ).tolist ()
339338 old_classes = np .intersect1d (org_model_classes , model_classes ).tolist ()
340339
341- src_data_cfg = Stage .get_train_data_cfg (cfg )
340+ src_data_cfg = Stage .get_data_cfg (cfg , "train" )
342341
343342 ids_old , ids_new = [], []
344343 data_cfg = cfg .data .test .copy ()
@@ -366,17 +365,6 @@ def get_img_ids_for_incr(cfg, org_model_classes, model_classes):
366365 )
367366 return outputs
368367
369- def configure_hyperparams (self , cfg , training , ** kwargs ):
370- hyperparams = kwargs .get ('hyperparams' , None )
371- if hyperparams is not None :
372- bs = hyperparams .get ('bs' , None )
373- if bs is not None :
374- cfg .data .samples_per_gpu = bs
375-
376- lr = hyperparams .get ('lr' , None )
377- if lr is not None :
378- cfg .optimizer .lr = lr
379-
380368 @staticmethod
381369 def add_yolox_hooks (cfg ):
382370 update_or_add_custom_hook (
0 commit comments