Skip to content

Commit 962a2e7

Browse files
authored
[pipeline & req]: implement pdf2model with DataFlex training backend; add pdf2model-dataflex dependency group (#502)
* [pipeline & req]: implement pdf2model with DataFlex; add pdf2model-dataflex dependency group
* delete shim
1 parent 6d29b5c commit 962a2e7

File tree

4 files changed

+477
-64
lines changed

4 files changed

+477
-64
lines changed

dataflow/cli.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -330,22 +330,47 @@ def eval_local_cmd():
330330
def pdf2model_init(cache: Path = typer.Option(Path("."),
331331
help = "Cache dir"),
332332
qa: str = typer.Option("kbc", help="Which pipeline to init (vqa or kbc)"),
333-
model: Optional[str] = typer.Option(None, help="Base model name or path")):
333+
model: Optional[str] = typer.Option(None, help="Base model name or path"),
334+
train_backend: str = typer.Option(
335+
"base",
336+
"--train-backend",
337+
help="With --qa kbc: 'base' (LlamaFactory) or a registered dataflex-* backend (see cli_pdf.DATAFLEX_BACKEND_SPECS). vqa only allows 'base'.",
338+
)):
334339
if qa not in ["vqa", "kbc"]:
335340
_echo(f"Invalid qa type: {qa}. Must be 'vqa' or 'kbc'.", "red")
336341
raise typer.Exit(code=1)
337-
342+
if qa == "vqa":
343+
if train_backend != "base":
344+
_echo("vqa only supports --train-backend base.", "red")
345+
raise typer.Exit(code=1)
346+
else:
347+
from dataflow.cli_funcs.cli_pdf import DATAFLEX_BACKEND_SPECS # type: ignore
348+
349+
allowed_kbc = {"base", *DATAFLEX_BACKEND_SPECS.keys()}
350+
if train_backend not in allowed_kbc:
351+
supported = ", ".join(sorted(allowed_kbc))
352+
_echo(
353+
f"Invalid --train-backend={train_backend!r} for --qa kbc. Supported: {supported}.",
354+
"red",
355+
)
356+
raise typer.Exit(code=1)
357+
338358
try:
339359
from dataflow.cli_funcs.cli_pdf import cli_pdf2model_init # type: ignore
340-
cli_pdf2model_init(cache_path=str(cache), qa_type=qa, model_name=model)
360+
cli_pdf2model_init(
361+
cache_path=str(cache),
362+
qa_type=qa,
363+
model_name=model,
364+
pdf2model_train_backend=train_backend,
365+
)
341366
except Exception as e:
342367
_echo(f"pdf2model init error: {e}", "red")
343368
raise typer.Exit(code=1)
344369

345370

346371
@pdf_app.command("train")
347372
def pdf2model_train(cache: Path = typer.Option(Path("."), help="Cache dir"),
348-
lf_yaml: Optional[Path] = typer.Option(None, help="LlamaFactory yaml")):
373+
lf_yaml: Optional[Path] = typer.Option(None, help="LlamaFactory yaml (base backend only)")):
349374

350375
try:
351376
from dataflow.cli_funcs.cli_pdf import cli_pdf2model_train # type: ignore

0 commit comments

Comments
 (0)