@@ -330,22 +330,47 @@ def eval_local_cmd():
def pdf2model_init(
    cache: Path = typer.Option(Path("."), help="Cache dir"),
    qa: str = typer.Option("kbc", help="Which pipeline to init (vqa or kbc)"),
    model: Optional[str] = typer.Option(None, help="Base model name or path"),
    train_backend: str = typer.Option(
        "base",
        "--train-backend",
        help="With --qa kbc: 'base' (LlamaFactory) or a registered dataflex-* backend (see cli_pdf.DATAFLEX_BACKEND_SPECS). vqa only allows 'base'.",
    ),
):
    # Validate the CLI options, then delegate to cli_pdf2model_init.
    #
    # Options:
    #   cache         -- cache directory, forwarded as a string (cache_path).
    #   qa            -- pipeline selector; must be "vqa" or "kbc".
    #   model         -- optional base model name or path (model_name).
    #   train_backend -- training backend; "vqa" accepts only "base", "kbc"
    #                    accepts "base" plus every key of DATAFLEX_BACKEND_SPECS.
    #
    # Any validation or initialisation failure is reported in red through
    # _echo and ends the command with typer.Exit(code=1).
    #
    # NOTE: documented with comments instead of a docstring on purpose --
    # Typer would otherwise surface a docstring as the command's --help text.
    if qa not in ("vqa", "kbc"):
        _echo(f"Invalid qa type: {qa}. Must be 'vqa' or 'kbc'.", "red")
        raise typer.Exit(code=1)

    if qa == "vqa":
        if train_backend != "base":
            _echo("vqa only supports --train-backend base.", "red")
            raise typer.Exit(code=1)
    else:
        # Imported lazily, like cli_pdf2model_init below. Guarded so that a
        # broken or missing cli_pdf module yields the same red error message
        # and exit code as an init failure rather than a raw traceback.
        try:
            from dataflow.cli_funcs.cli_pdf import DATAFLEX_BACKEND_SPECS  # type: ignore
        except Exception as e:
            _echo(f"pdf2model init error: {e}", "red")
            raise typer.Exit(code=1)

        allowed_kbc = {"base", *DATAFLEX_BACKEND_SPECS.keys()}
        if train_backend not in allowed_kbc:
            # Sorted so the list of supported backends is stable across runs.
            supported = ", ".join(sorted(allowed_kbc))
            _echo(
                f"Invalid --train-backend={train_backend!r} for --qa kbc. Supported: {supported}.",
                "red",
            )
            raise typer.Exit(code=1)

    try:
        from dataflow.cli_funcs.cli_pdf import cli_pdf2model_init  # type: ignore

        cli_pdf2model_init(
            cache_path=str(cache),
            qa_type=qa,
            model_name=model,
            pdf2model_train_backend=train_backend,
        )
    except Exception as e:
        _echo(f"pdf2model init error: {e}", "red")
        raise typer.Exit(code=1)
344369
345370
346371@pdf_app .command ("train" )
347372def pdf2model_train (cache : Path = typer .Option (Path ("." ), help = "Cache dir" ),
348- lf_yaml : Optional [Path ] = typer .Option (None , help = "LlamaFactory yaml" )):
373+ lf_yaml : Optional [Path ] = typer .Option (None , help = "LlamaFactory yaml (base backend only) " )):
349374
350375 try :
351376 from dataflow .cli_funcs .cli_pdf import cli_pdf2model_train # type: ignore
0 commit comments