Skip to content

Commit 4718376

Browse files
committed
feat(core_text): enhance data validation in PromptedFilter and Text2MultiHopQAGenerator
1 parent e25d859 commit 4718376

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

dataflow/operators/core_text/filter/prompted_filter.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,21 @@ def __init__(self, llm_serving: LLMServingABC, system_prompt: str = "Please eval
1818
self.prompted_evaluator = PromptedEvaluator(llm_serving, system_prompt)
1919
self.min_score = min_score
2020
self.max_score = max_score
21+
22+
@staticmethod
def _has_valid_content(value) -> bool:
    """Return True when ``value`` holds usable, non-empty content.

    Rules, applied in order:
      * ``None`` is never valid.
      * Strings are valid when something remains after stripping whitespace.
      * Anything that supports ``len()`` (list, tuple, set, dict, and also
        array-likes such as ``numpy.ndarray`` or ``pandas.Series``) is valid
        when it is non-empty.
      * Remaining scalars are invalid when pandas considers them missing
        (NaN, ``pd.NA``, ``pd.NaT``); otherwise their truthiness decides.

    Args:
        value: A single cell value taken from a dataframe column.

    Returns:
        bool: True if the value should be kept for evaluation.
    """
    if value is None:
        return False
    if isinstance(value, str):
        return value.strip() != ""

    # Check sized objects before calling pd.isna(): on an array-like,
    # pd.isna() returns an element-wise array whose truth value is ambiguous
    # and raises ValueError when used in an ``if``.
    try:
        return len(value) > 0
    except TypeError:
        # Unsized object (plain scalar, 0-d array, pd.NA, ...): fall through.
        pass

    try:
        if pd.isna(value):
            return False
    except (TypeError, ValueError):
        # pd.isna() could not give a single yes/no answer; rely on truthiness.
        pass
    return bool(value)
2136

2237
@staticmethod
2338
def get_desc(lang: str = "zh"):
@@ -71,7 +86,7 @@ def run(self, storage: DataFlowStorage, input_key: str = "raw_content", output_k
7186
self.logger.info(f"Loading, number of rows: {len(dataframe)}")
7287

7388
# Drop rows where input_key is empty/null before evaluation
74-
valid_mask = dataframe[input_key].notna() & (dataframe[input_key].astype(str).str.strip() != '')
89+
valid_mask = dataframe[input_key].apply(self._has_valid_content)
7590
valid_dataframe = dataframe[valid_mask]
7691
self.logger.info(f"Skipping {(~valid_mask).sum()} rows with empty '{input_key}'")
7792

dataflow/operators/core_text/generate/text2multihopqa_generator.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,13 +225,23 @@ def run(
225225
dataframe = storage.read("dataframe")
226226
self._validate_dataframe(dataframe)
227227
texts = dataframe[self.input_key].tolist()
228-
outputs=self.process_batch(texts)
229-
dataframe[self.output_key] = [
228+
outputs = self.process_batch(texts)
229+
qa_pairs_column = [
230230
output['qa_pairs'][:self.num_q] if len(output['qa_pairs']) >= self.num_q else output['qa_pairs']
231231
for output in outputs
232232
]
233+
metadata_column = [output['metadata'] for output in outputs]
234+
235+
dataframe = dataframe.copy()
236+
dataframe[self.output_key] = qa_pairs_column
237+
dataframe[self.output_meta_key] = metadata_column
238+
239+
valid_mask = dataframe[self.output_key].apply(lambda qa_pairs: isinstance(qa_pairs, list) and len(qa_pairs) > 0)
240+
filtered_count = int((~valid_mask).sum())
241+
if filtered_count:
242+
self.logger.info(f"Filtering out {filtered_count} rows with empty '{self.output_key}'")
243+
dataframe = dataframe[valid_mask].reset_index(drop=True)
233244

234-
dataframe[self.output_meta_key] = [output['metadata'] for output in outputs]
235245
output_file = storage.write(dataframe)
236246
self.logger.info(f"Results saved to {output_file}")
237247

0 commit comments

Comments
 (0)