5353from deepmd .pt .utils .multi_task import (
5454 preprocess_shared_params ,
5555)
56- from deepmd .pt .utils .stat import (
57- make_stat_input ,
58- )
5956from deepmd .utils .argcheck import (
6057 normalize ,
6158)
@@ -104,36 +101,23 @@ def get_trainer(
104101 config ["model" ]["resuming" ] = (finetune_model is not None ) or (ckpt is not None )
105102
def prepare_trainer_input_single(
    model_params_single, data_dict_single, loss_dict_single, suffix="", rank=0
):
    """Build the training/validation data loaders and the statistics-file
    path for a single model head.

    Parameters
    ----------
    model_params_single : dict
        Model section of the config; ``["descriptor"]["type"]`` selects
        type splitting, and the dict is passed through to ``DpLoaderSet``.
    data_dict_single : dict
        Data section of the config; must contain ``"training_data"`` and
        may contain ``"validation_data"`` and ``"stat_file"``.
    loss_dict_single : dict
        Loss section of the config.
        NOTE(review): not read anywhere in this body — apparently kept only
        for interface compatibility with callers; confirm before removing.
    suffix : str
        NOTE(review): also unused in the visible body — confirm.
    rank : int
        Distributed rank of this process; only rank 0 keeps a stat-file
        path (see below).

    Returns
    -------
    tuple
        ``(train_data_single, validation_data_single, stat_file_path_single)``
        where the validation loader and the stat path may each be ``None``.
    """
    training_dataset_params = data_dict_single["training_data"]
    type_split = False
    if model_params_single["descriptor"]["type"] in ["se_e2_a"]:
        type_split = True
    # Validation data is optional: missing/empty config yields None systems,
    # which later yields a None validation loader.
    validation_dataset_params = data_dict_single.get("validation_data", None)
    validation_systems = (
        validation_dataset_params["systems"] if validation_dataset_params else None
    )
    training_systems = training_dataset_params["systems"]

    # stat files
    # Only rank 0 gets a stat-file path; all other ranks see None so that
    # distributed workers do not contend on the same file.
    stat_file_path_single = data_dict_single.get("stat_file", None)
    if rank != 0:
        stat_file_path_single = None
    elif stat_file_path_single is not None:
        if Path(stat_file_path_single).is_dir():
            raise ValueError(
                f"stat_file should be a file, not a directory: {stat_file_path_single}"
            )
        # NOTE(review): the scraped diff has a hunk gap here (new lines
        # ~124-127 are not shown); presumably they create the stat file when
        # it does not yet exist — confirm against the upstream source.
        # Opened in append mode ("a") so existing statistics are reused.
        stat_file_path_single = DPPath(stat_file_path_single, "a")

    # validation and training data
    validation_data_single = (
        DpLoaderSet(
            validation_systems,
            validation_dataset_params["batch_size"],
            model_params_single,
        )
        if validation_systems
        else None
    )
    # ckpt / finetune_model are closure variables from the enclosing
    # get_trainer.  NOTE(review): after this diff removed the stat-sampling
    # code, both branches below construct an identical loader — the branch
    # looks collapsible; confirm nothing else distinguishes the two paths.
    if ckpt or finetune_model:
        train_data_single = DpLoaderSet(
            training_systems,
            training_dataset_params["batch_size"],
            model_params_single,
        )
    else:
        train_data_single = DpLoaderSet(
            training_systems,
            training_dataset_params["batch_size"],
            model_params_single,
        )
    return (
        train_data_single,
        validation_data_single,
        stat_file_path_single,
    )
183157
158+ rank = dist .get_rank () if dist .is_initialized () else 0
184159 if not multi_task :
185160 (
186161 train_data ,
187162 validation_data ,
188- sampled ,
189163 stat_file_path ,
190164 ) = prepare_trainer_input_single (
191- config ["model" ], config ["training" ], config ["loss" ]
165+ config ["model" ],
166+ config ["training" ],
167+ config ["loss" ],
168+ rank = rank ,
192169 )
193170 else :
194- train_data , validation_data , sampled , stat_file_path = {}, {}, {}, {}
171+ train_data , validation_data , stat_file_path = {}, {}, {}
195172 for model_key in config ["model" ]["model_dict" ]:
196173 (
197174 train_data [model_key ],
198175 validation_data [model_key ],
199- sampled [model_key ],
200176 stat_file_path [model_key ],
201177 ) = prepare_trainer_input_single (
202178 config ["model" ]["model_dict" ][model_key ],
203179 config ["training" ]["data_dict" ][model_key ],
204180 config ["loss_dict" ][model_key ],
205181 suffix = f"_{ model_key } " ,
182+ rank = rank ,
206183 )
207184
208185 trainer = training .Trainer (
209186 config ,
210187 train_data ,
211- sampled = sampled ,
212188 stat_file_path = stat_file_path ,
213189 validation_data = validation_data ,
214190 init_model = init_model ,
0 commit comments