diff --git a/doc/progress.rst b/doc/progress.rst index 04a036f64..a000890a8 100644 --- a/doc/progress.rst +++ b/doc/progress.rst @@ -9,6 +9,9 @@ Changelog next ~~~~~~ + * ADD #1335: Improve MinIO support. + * Add progress bar for downloading MinIO files. Enable it with setting `show_progress` to true on either `openml.config` or the configuration file. + * When using `download_all_files`, files are only downloaded if they do not yet exist in the cache. * MAINT #1340: Add Numpy 2.0 support. Update tests to work with scikit-learn <= 1.5. * ADD #1342: Add HTTP header to requests to indicate they are from openml-python. diff --git a/examples/20_basic/simple_datasets_tutorial.py b/examples/20_basic/simple_datasets_tutorial.py index c525a3ef9..35b325fd9 100644 --- a/examples/20_basic/simple_datasets_tutorial.py +++ b/examples/20_basic/simple_datasets_tutorial.py @@ -50,6 +50,15 @@ X, y, categorical_indicator, attribute_names = dataset.get_data( dataset_format="dataframe", target=dataset.default_target_attribute ) + +############################################################################ +# Tip: you can get a progress bar for dataset downloads, simply set it in +# the configuration. 
Either in code or in the configuration file +# (see also the introduction tutorial) + +openml.config.show_progress = True + + ############################################################################ # Visualize the dataset # ===================== diff --git a/openml/_api_calls.py b/openml/_api_calls.py index 0aa5ba635..994f52b8b 100644 --- a/openml/_api_calls.py +++ b/openml/_api_calls.py @@ -1,6 +1,7 @@ # License: BSD 3-Clause from __future__ import annotations +import contextlib import hashlib import logging import math @@ -26,6 +27,7 @@ OpenMLServerException, OpenMLServerNoResult, ) +from .utils import ProgressBar _HEADERS = {"user-agent": f"openml-python/{__version__}"} @@ -161,12 +163,12 @@ def _download_minio_file( proxy_client = ProxyManager(proxy) if proxy else None client = minio.Minio(endpoint=parsed_url.netloc, secure=False, http_client=proxy_client) - try: client.fget_object( bucket_name=bucket, object_name=object_name, file_path=str(destination), + progress=ProgressBar() if config.show_progress else None, request_headers=_HEADERS, ) if destination.is_file() and destination.suffix == ".zip": @@ -206,11 +208,12 @@ def _download_minio_bucket(source: str, destination: str | Path) -> None: if file_object.object_name is None: raise ValueError("Object name is None.") - _download_minio_file( - source=source.rsplit("/", 1)[0] + "/" + file_object.object_name.rsplit("/", 1)[1], - destination=Path(destination, file_object.object_name.rsplit("/", 1)[1]), - exists_ok=True, - ) + with contextlib.suppress(FileExistsError): # Simply use cached version instead + _download_minio_file( + source=source.rsplit("/", 1)[0] + "/" + file_object.object_name.rsplit("/", 1)[1], + destination=Path(destination, file_object.object_name.rsplit("/", 1)[1]), + exists_ok=False, + ) def _download_text_file( diff --git a/openml/config.py b/openml/config.py index 1af8a7456..6a37537dc 100644 --- a/openml/config.py +++ b/openml/config.py @@ -28,6 +28,7 @@ class _Config(TypedDict): 
avoid_duplicate_runs: bool retry_policy: Literal["human", "robot"] connection_n_retries: int + show_progress: bool def _create_log_handlers(create_file_handler: bool = True) -> None: # noqa: FBT001, FBT002 @@ -111,6 +112,7 @@ def set_file_log_level(file_output_level: int) -> None: "avoid_duplicate_runs": True, "retry_policy": "human", "connection_n_retries": 5, + "show_progress": False, } # Default values are actually added here in the _setup() function which is @@ -131,6 +133,7 @@ def get_server_base_url() -> str: apikey: str = _defaults["apikey"] +show_progress: bool = _defaults["show_progress"] # The current cache directory (without the server name) _root_cache_directory = Path(_defaults["cachedir"]) avoid_duplicate_runs = _defaults["avoid_duplicate_runs"] @@ -238,6 +241,7 @@ def _setup(config: _Config | None = None) -> None: global server # noqa: PLW0603 global _root_cache_directory # noqa: PLW0603 global avoid_duplicate_runs # noqa: PLW0603 + global show_progress # noqa: PLW0603 config_file = determine_config_file_path() config_dir = config_file.parent @@ -255,6 +259,7 @@ def _setup(config: _Config | None = None) -> None: avoid_duplicate_runs = config["avoid_duplicate_runs"] apikey = config["apikey"] server = config["server"] + show_progress = config["show_progress"] short_cache_dir = Path(config["cachedir"]) n_retries = int(config["connection_n_retries"]) @@ -328,11 +333,11 @@ def _parse_config(config_file: str | Path) -> _Config: logger.info("Error opening file %s: %s", config_file, e.args[0]) config_file_.seek(0) config.read_file(config_file_) - if isinstance(config["FAKE_SECTION"]["avoid_duplicate_runs"], str): - config["FAKE_SECTION"]["avoid_duplicate_runs"] = config["FAKE_SECTION"].getboolean( - "avoid_duplicate_runs" - ) # type: ignore - return dict(config.items("FAKE_SECTION")) # type: ignore + configuration = dict(config.items("FAKE_SECTION")) + for boolean_field in ["avoid_duplicate_runs", "show_progress"]: + if 
isinstance(config["FAKE_SECTION"][boolean_field], str): + configuration[boolean_field] = config["FAKE_SECTION"].getboolean(boolean_field) # type: ignore + return configuration # type: ignore def get_config_as_dict() -> _Config: @@ -343,6 +348,7 @@ def get_config_as_dict() -> _Config: "avoid_duplicate_runs": avoid_duplicate_runs, "connection_n_retries": connection_n_retries, "retry_policy": retry_policy, + "show_progress": show_progress, } diff --git a/openml/datasets/functions.py b/openml/datasets/functions.py index 590955a5e..6a9f57abb 100644 --- a/openml/datasets/functions.py +++ b/openml/datasets/functions.py @@ -1262,10 +1262,9 @@ def _get_dataset_parquet( if old_file_path.is_file(): old_file_path.rename(output_file_path) - # For this release, we want to be able to force a new download even if the - # parquet file is already present when ``download_all_files`` is set. - # For now, it would be the only way for the user to fetch the additional - # files in the bucket (no function exists on an OpenMLDataset to do this). + # The call below skips files already on disk, so avoids downloading the parquet file twice. 
+ # To force the old behavior of always downloading everything, use `force_refresh_cache` + # of `get_dataset` if download_all_files: openml._api_calls._download_minio_bucket(source=url, destination=cache_directory) diff --git a/openml/utils.py b/openml/utils.py index 80d7caaae..a03610512 100644 --- a/openml/utils.py +++ b/openml/utils.py @@ -12,6 +12,8 @@ import numpy as np import pandas as pd import xmltodict +from minio.helpers import ProgressType +from tqdm import tqdm import openml import openml._api_calls @@ -471,3 +473,39 @@ def _create_lockfiles_dir() -> Path: with contextlib.suppress(OSError): path.mkdir(exist_ok=True, parents=True) return path + + +class ProgressBar(ProgressType): + """Progressbar for MinIO function's `progress` parameter.""" + + def __init__(self) -> None: + self._object_name = "" + self._progress_bar: tqdm | None = None + + def set_meta(self, object_name: str, total_length: int) -> None: + """Initializes the progress bar. + + Parameters + ---------- + object_name: str + Not used. + + total_length: int + File size of the object in bytes. + """ + self._object_name = object_name + self._progress_bar = tqdm(total=total_length, unit_scale=True, unit="B") + + def update(self, length: int) -> None: + """Updates the progress bar. + + Parameters + ---------- + length: int + Number of bytes downloaded since last `update` call. 
+ """ + if not self._progress_bar: + raise RuntimeError("Call `set_meta` before calling `update`.") + self._progress_bar.update(length) + if self._progress_bar.total <= self._progress_bar.n: + self._progress_bar.close() diff --git a/pyproject.toml b/pyproject.toml index b970a35b2..f401fa8a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "numpy>=1.6.2", "minio", "pyarrow", + "tqdm", # For MinIO download progress bars "packaging", ] requires-python = ">=3.8" diff --git a/tests/test_openml/test_api_calls.py b/tests/test_openml/test_api_calls.py index 8c4c03276..c6df73e0a 100644 --- a/tests/test_openml/test_api_calls.py +++ b/tests/test_openml/test_api_calls.py @@ -1,11 +1,16 @@ from __future__ import annotations import unittest.mock +from pathlib import Path +from typing import NamedTuple, Iterable, Iterator +from unittest import mock +import minio import pytest import openml import openml.testing +from openml._api_calls import _download_minio_bucket class TestConfig(openml.testing.TestBase): @@ -30,3 +35,39 @@ def test_retry_on_database_error(self, Session_class_mock, _): openml._api_calls._send_request("get", "/abc", {}) assert Session_class_mock.return_value.__enter__.return_value.get.call_count == 20 + +class FakeObject(NamedTuple): + object_name: str + +class FakeMinio: + def __init__(self, objects: Iterable[FakeObject] | None = None): + self._objects = objects or [] + + def list_objects(self, *args, **kwargs) -> Iterator[FakeObject]: + yield from self._objects + + def fget_object(self, object_name: str, file_path: str, *args, **kwargs) -> None: + if object_name in [obj.object_name for obj in self._objects]: + Path(file_path).write_text("foo") + return + raise FileNotFoundError + + +@mock.patch.object(minio, "Minio") +def test_download_all_files_observes_cache(mock_minio, tmp_path: Path) -> None: + some_prefix, some_filename = "some/prefix", "dataset.arff" + some_object_path = f"{some_prefix}/{some_filename}" + some_url = 
f"https://not.real.com/bucket/{some_object_path}"
+    mock_minio.return_value = FakeMinio(
+        objects=[
+            FakeObject(some_object_path),
+        ],
+    )
+
+    _download_minio_bucket(source=some_url, destination=tmp_path)
+    mtime_first = (tmp_path / some_filename).stat().st_mtime
+
+    _download_minio_bucket(source=some_url, destination=tmp_path)
+    mtime_second = (tmp_path / some_filename).stat().st_mtime
+
+    assert mtime_first == mtime_second
diff --git a/tests/test_openml/test_config.py b/tests/test_openml/test_config.py
index 67d2ce895..58528c5c9 100644
--- a/tests/test_openml/test_config.py
+++ b/tests/test_openml/test_config.py
@@ -133,3 +133,16 @@ def test_configuration_file_not_overwritten_on_load():
 
     assert config_file_content == new_file_content
     assert "abcd" == read_config["apikey"]
+
+def test_configuration_loads_booleans(tmp_path):
+    # Use values opposite to the defaults (avoid_duplicate_runs=True,
+    # show_progress=False) so the test fails if the file is never read.
+    config_file_content = "avoid_duplicate_runs=false\nshow_progress=true"
+    with (tmp_path / "config").open("w") as config_file:
+        config_file.write(config_file_content)
+    # _parse_config expects the path of the config *file*, not its directory.
+    read_config = openml.config._parse_config(tmp_path / "config")
+
+    # Explicit comparison to avoid truthy/falsy matches of other types
+    assert False == read_config["avoid_duplicate_runs"]
+    assert True == read_config["show_progress"]