diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index fc1319d79..9612e4f80 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -19,6 +19,12 @@ repos:
         additional_dependencies:
           - types-requests
           - types-python-dateutil
+      - id: mypy
+        name: mypy type-def-check
+        files: openml/datasets/.*
+        additional_dependencies:
+          - types-requests
+          - types-python-dateutil
       - id: mypy
         name: mypy top-level-functions
         files: openml/_api_calls.py
@@ -26,7 +32,8 @@
           - types-requests
           - types-python-dateutil
         args: [ --disallow-untyped-defs, --disallow-any-generics,
-                --disallow-any-explicit, --implicit-optional ]
+                --disallow-any-explicit, --implicit-optional, --allow-redefinition]
+
   - repo: https://github.com/pycqa/flake8
     rev: 6.0.0
     hooks:
diff --git a/openml/base.py b/openml/base.py
index 35a9ce58f..a53ba5e3a 100644
--- a/openml/base.py
+++ b/openml/base.py
@@ -46,7 +46,7 @@ def _entity_letter(cls) -> str:
         return cls.__name__.lower()[len("OpenML") :][0]

     @abstractmethod
-    def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str]]]]:
+    def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str], None]]]:
         """Collect all information to display in the __repr__ body.

         Returns
diff --git a/openml/datasets/data_feature.py b/openml/datasets/data_feature.py
index b4550b5d7..a4d2e5e19 100644
--- a/openml/datasets/data_feature.py
+++ b/openml/datasets/data_feature.py
@@ -1,6 +1,6 @@
 # License: BSD 3-Clause
-
 from typing import List
+from prettyprinter import PrettyPrinter


 class OpenMLDataFeature(object):
@@ -59,11 +59,11 @@ def __init__(
         self.nominal_values = nominal_values
         self.number_missing_values = number_missing_values

-    def __repr__(self):
+    def __repr__(self) -> str:
         return "[%d - %s (%s)]" % (self.index, self.name, self.data_type)

-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         return isinstance(other, OpenMLDataFeature) and self.__dict__ == other.__dict__

-    def _repr_pretty_(self, pp, cycle):
+    def _repr_pretty_(self, pp: PrettyPrinter, cycle: bool) -> None:
         pp.text(str(self))
diff --git a/openml/datasets/dataset.py b/openml/datasets/dataset.py
index dcdef162d..a50e38809 100644
--- a/openml/datasets/dataset.py
+++ b/openml/datasets/dataset.py
@@ -1,4 +1,5 @@
 # License: BSD 3-Clause
+from __future__ import annotations

 from collections import OrderedDict
 import re
@@ -6,14 +7,24 @@
 import logging
 import os
 import pickle
-from typing import List, Optional, Union, Tuple, Iterable, Dict
-import warnings
+from io import TextIOWrapper
+from typing import (
+    List,
+    Optional,
+    Union,
+    Tuple,
+    Iterable,
+    Dict,
+    cast,
+)

 import arff
 import numpy as np
 import pandas as pd
 import scipy.sparse
+import typing
 import xmltodict
+import warnings

 from openml.base import OpenMLBase
 from .data_feature import OpenMLDataFeature
@@ -104,47 +115,47 @@ class OpenMLDataset(OpenMLBase):
     def __init__(
         self,
-        name,
-        description,
-        data_format="arff",
-        cache_format="pickle",
-        dataset_id=None,
-        version=None,
-        creator=None,
-        contributor=None,
-        collection_date=None,
-        upload_date=None,
-        language=None,
-        licence=None,
-        url=None,
-        default_target_attribute=None,
-        row_id_attribute=None,
-        ignore_attribute=None,
-        version_label=None,
-        citation=None,
-        tag=None,
-        visibility=None,
-        original_data_url=None,
-        paper_url=None,
-        update_comment=None,
-        md5_checksum=None,
-        data_file=None,
+        name: str,
+        description: str,
+        data_format: str = "arff",
+        cache_format: str = "pickle",
+        dataset_id: Optional[str] = None,
+        version: Optional[str] = None,
+        creator: Optional[str] = None,
+        contributor: Optional[str] = None,
+        collection_date: Optional[str] = None,
+        upload_date: Optional[str] = None,
+        language: Optional[str] = None,
+        licence: Optional[str] = None,
+        url: Optional[str] = None,
+        default_target_attribute: Optional[str] = None,
+        row_id_attribute: Optional[str] = None,
+        ignore_attribute: Optional[Union[List[str], str]] = None,
+        version_label: Optional[str] = None,
+        citation: Optional[str] = None,
+        tag: Optional[str] = None,
+        visibility: Optional[str] = None,
+        original_data_url: Optional[str] = None,
+        paper_url: Optional[str] = None,
+        update_comment: Optional[str] = None,
+        md5_checksum: Optional[str] = None,
+        data_file: Optional[str] = None,
         features_file: Optional[str] = None,
         qualities_file: Optional[str] = None,
-        dataset=None,
+        dataset: Optional[str] = None,
         minio_url: Optional[str] = None,
         parquet_file: Optional[str] = None,
     ):
-        def find_invalid_characters(string, pattern):
-            invalid_chars = set()
+        def find_invalid_characters(string: str, pattern: str) -> str:
+            invalid_chars_set = set()
             regex = re.compile(pattern)
             for char in string:
                 if not regex.match(char):
-                    invalid_chars.add(char)
+                    invalid_chars_set.add(char)
             invalid_chars = ",".join(
                 [
                     "'{}'".format(char) if char != "'" else '"{}"'.format(char)
-                    for char in invalid_chars
+                    for char in invalid_chars_set
                 ]
             )
             return invalid_chars
@@ -207,7 +218,6 @@ def find_invalid_characters(string, pattern):
         self.paper_url = paper_url
         self.update_comment = update_comment
         self.md5_checksum = md5_checksum
-        self.data_file = data_file
         self.parquet_file = parquet_file
         self._dataset = dataset
         self._minio_url = minio_url
@@ -233,6 +243,7 @@ def find_invalid_characters(string, pattern):
             self._qualities = _read_qualities(qualities_file)

         if data_file is not None:
+            self.data_file = data_file
             rval = self._compressed_cache_file_paths(data_file)
             self.data_pickle_file = rval[0] if os.path.exists(rval[0]) else None
             self.data_feather_file = rval[1] if os.path.exists(rval[1]) else None
@@ -243,14 +254,14 @@ def find_invalid_characters(string, pattern):
             self.feather_attribute_file = None

     @property
-    def features(self):
+    def features(self) -> Optional[Dict[int, OpenMLDataFeature]]:
         if self._features is None:
             self._load_features()

         return self._features

     @property
-    def qualities(self):
+    def qualities(self) -> Optional[Dict[str, float]]:
         # We have to check `_no_qualities_found` as there might not be qualities for a dataset
         if self._qualities is None and (not self._no_qualities_found):
             self._load_qualities()
@@ -261,7 +272,7 @@ def qualities(self):
     def id(self) -> Optional[int]:
         return self.dataset_id

-    def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str]]]]:
+    def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str], None]]]:
         """Collect all information to display in the __repr__ body."""

         # Obtain number of features in accordance with lazy loading.
@@ -303,7 +314,7 @@ def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str]]]]
         ]
         return [(key, fields[key]) for key in order if key in fields]

-    def __eq__(self, other):
+    def __eq__(self, other) -> bool:  # type: ignore
         if not isinstance(other, OpenMLDataset):
             return False

@@ -332,61 +343,7 @@ def _download_data(self) -> None:
         if self._minio_url is not None:
             self.parquet_file = _get_dataset_parquet(self)

-    def _get_arff(self, format: str) -> Dict:
-        """Read ARFF file and return decoded arff.
-
-        Reads the file referenced in self.data_file.
-
-        Parameters
-        ----------
-        format : str
-            Format of the ARFF file.
-            Must be one of 'arff' or 'sparse_arff' or a string that will be either of those
-            when converted to lower case.
-
-
-        Returns
-        -------
-        dict
-            Decoded arff.
-
-        """
-
-        # TODO: add a partial read method which only returns the attribute
-        # headers of the corresponding .arff file!
-        import struct
-
-        filename = self.data_file
-        bits = 8 * struct.calcsize("P")
-        # Files can be considered too large on a 32-bit system,
-        # if it exceeds 120mb (slightly more than covtype dataset size)
-        # This number is somewhat arbitrary.
-        if bits != 64 and os.path.getsize(filename) > 120000000:
-            raise NotImplementedError(
-                "File {} too big for {}-bit system ({} bytes).".format(
-                    filename, os.path.getsize(filename), bits
-                )
-            )
-
-        if format.lower() == "arff":
-            return_type = arff.DENSE
-        elif format.lower() == "sparse_arff":
-            return_type = arff.COO
-        else:
-            raise ValueError("Unknown data format {}".format(format))
-
-        def decode_arff(fh):
-            decoder = arff.ArffDecoder()
-            return decoder.decode(fh, encode_nominal=True, return_type=return_type)
-
-        if filename[-3:] == ".gz":
-            with gzip.open(filename) as zipfile:
-                return decode_arff(zipfile)
-        else:
-            with open(filename, encoding="utf8") as fh:
-                return decode_arff(fh)
-
+    @typing.no_type_check
     def _parse_data_from_arff(
         self, arff_file_path: str
     ) -> Tuple[Union[pd.DataFrame, scipy.sparse.csr_matrix], List[bool], List[str]]:
@@ -545,7 +502,9 @@ def _cache_compressed_file_from_file(

         return data, categorical, attribute_names

-    def _load_data(self):
+    def _load_data(
+        self,
+    ) -> Tuple[Union[np.ndarray, pd.DataFrame, scipy.sparse.csr_matrix], List[bool], List[str]]:
         """Load data from compressed format or arff. Download data if not present on disk."""
         need_to_create_pickle = self.cache_format == "pickle" and self.data_pickle_file is None
         need_to_create_feather = self.cache_format == "feather" and self.data_feather_file is None
@@ -561,12 +520,12 @@ def _load_data(self):
         fpath = self.data_feather_file if self.cache_format == "feather" else self.data_pickle_file
         logger.info(f"{self.cache_format} load data {self.name}")
         try:
-            if self.cache_format == "feather":
+            if self.cache_format == "feather" and self.feather_attribute_file:
                 data = pd.read_feather(self.data_feather_file)
                 fpath = self.feather_attribute_file
                 with open(self.feather_attribute_file, "rb") as fh:
                     categorical, attribute_names = pickle.load(fh)
-            else:
+            elif self.data_pickle_file:
                 with open(self.data_pickle_file, "rb") as fh:
                     data, categorical, attribute_names = pickle.load(fh)
         except FileNotFoundError:
@@ -607,12 +566,20 @@ def _load_data(self):
         data_up_to_date = isinstance(data, pd.DataFrame) or scipy.sparse.issparse(data)
         if self.cache_format == "pickle" and not data_up_to_date:
             logger.info("Updating outdated pickle file.")
-            file_to_load = self.data_file if self.parquet_file is None else self.parquet_file
+            file_to_load = (
+                cast(str, self.data_file)
+                if self.parquet_file is None
+                else cast(str, self.parquet_file)
+            )
             return self._cache_compressed_file_from_file(file_to_load)
         return data, categorical, attribute_names

     @staticmethod
-    def _convert_array_format(data, array_format, attribute_names):
+    def _convert_array_format(
+        data: Union[np.ndarray, pd.DataFrame, scipy.sparse.csr_matrix],
+        array_format: str,
+        attribute_names: List[str],
+    ) -> Union[np.ndarray, pd.SparseDtype]:
         """Convert a dataset to a given array format.

         Converts to numpy array if data is non-sparse.
@@ -639,7 +606,7 @@ def _convert_array_format(data, array_format, attribute_names):
         if array_format == "array" and not scipy.sparse.issparse(data):
             # We encode the categories such that they are integer to be able
             # to make a conversion to numeric for backward compatibility
-            def _encode_if_category(column):
+            def _encode_if_category(column: pd.Series) -> pd.Series:
                 if column.dtype.name == "category":
                     column = column.cat.codes.astype(np.float32)
                     mask_nan = column == -1
@@ -673,9 +640,9 @@ def _encode_if_category(column):
         return data

     @staticmethod
-    def _unpack_categories(series, categories):
+    def _unpack_categories(series: pd.Series, categories: List[str]) -> pd.Series:
         # nan-likes can not be explicitly specified as a category
-        def valid_category(cat):
+        def valid_category(cat: Union[str, None]) -> bool:
             return isinstance(cat, str) or (cat is not None and not np.isnan(cat))

         filtered_categories = [c for c in categories if valid_category(c)]
@@ -697,7 +664,7 @@ def get_data(
         include_ignore_attribute: bool = False,
         dataset_format: str = "dataframe",
     ) -> Tuple[
-        Union[np.ndarray, pd.DataFrame, scipy.sparse.csr_matrix],
+        Union[np.ndarray, pd.DataFrame, pd.SparseDtype],
         Optional[Union[np.ndarray, pd.DataFrame]],
         List[bool],
         List[str],
@@ -815,7 +782,7 @@ def get_data(

         return data, targets, categorical, attribute_names

-    def _load_features(self):
+    def _load_features(self) -> None:
         """Load the features metadata from the server and store it in the dataset object."""
         # Delayed Import to avoid circular imports or having to import all of dataset.functions to
         # import OpenMLDataset.
@@ -830,7 +797,7 @@ def get_data(
             features_file = _get_dataset_features_file(None, self.dataset_id)
         self._features = _read_features(features_file)

-    def _load_qualities(self):
+    def _load_qualities(self) -> None:
         """Load qualities information from the server and store it in the dataset object."""
         # same reason as above for _load_features
         from openml.datasets.functions import _get_dataset_qualities_file
@@ -865,14 +832,19 @@ def retrieve_class_labels(self, target_name: str = "class") -> Union[None, List[
         -------
         list
         """
-        for feature in self.features.values():
-            if (feature.name == target_name) and (feature.data_type == "nominal"):
-                return feature.nominal_values
+        if self.features:
+            for feature in self.features.values():
+                if (feature.name == target_name) and (feature.data_type == "nominal"):
+                    return feature.nominal_values
         return None

     def get_features_by_type(
-        self, data_type, exclude=None, exclude_ignore_attribute=True, exclude_row_id_attribute=True
-    ):
+        self,
+        data_type: str,
+        exclude: Optional[List[str]] = None,
+        exclude_ignore_attribute: bool = True,
+        exclude_row_id_attribute: bool = True,
+    ) -> List[int]:
         """
         Return indices of features of a given type, e.g. all nominal features.

         Optional parameters to exclude various features by index or ontology.
@@ -921,16 +893,21 @@ def get_features_by_type(
         offset = 0
         # this function assumes that everything in to_exclude will
         # be 'excluded' from the dataset (hence the offset)
-        for idx in self.features:
-            name = self.features[idx].name
-            if name in to_exclude:
-                offset += 1
-            else:
-                if self.features[idx].data_type == data_type:
-                    result.append(idx - offset)
+        if self.features:
+            for idx in self.features:
+                name = self.features[idx].name
+                if name in to_exclude:
+                    offset += 1
+                else:
+                    if self.features[idx].data_type == data_type:
+                        result.append(idx - offset)
+        else:
+            raise ValueError(
+                "get_features_by_type can only be called if feature information is available."
+            )
         return result

-    def _get_file_elements(self) -> Dict:
+    def _get_file_elements(self) -> Dict[str, str]:
         """Adds the 'dataset' to file elements."""
         file_elements = {}
         path = None if self.data_file is None else os.path.abspath(self.data_file)
@@ -939,9 +916,9 @@ def _get_file_elements(self) -> Dict:
             file_elements["dataset"] = self._dataset
         elif path is not None and os.path.exists(path):
             with open(path, "rb") as fp:
-                file_elements["dataset"] = fp.read()
+                file_elements["dataset"] = cast(str, fp.read())
             try:
-                dataset_utf8 = str(file_elements["dataset"], "utf8")
+                dataset_utf8 = str(cast(bytes, file_elements["dataset"]), "utf8")
                 arff.ArffDecoder().decode(dataset_utf8, encode_nominal=True)
             except arff.ArffException:
                 raise ValueError("The file you have provided is not a valid arff file.")
@@ -949,11 +926,11 @@ def _get_file_elements(self) -> Dict:
             raise ValueError("No valid url/path to the data file was given.")
         return file_elements

-    def _parse_publish_response(self, xml_response: Dict):
+    def _parse_publish_response(self, xml_response: Dict[str, Dict[str, str]]) -> None:
         """Parse the id from the xml_response and assign it to self."""
         self.dataset_id = int(xml_response["oml:upload_data_set"]["oml:id"])

-    def _to_dict(self) -> "OrderedDict[str, OrderedDict]":
+    def _to_dict(self) -> "OrderedDict[str, OrderedDict[str, str]]":
         """Creates a dictionary representation of self."""
         props = [
             "id",
@@ -981,7 +958,7 @@ def _to_dict(self) -> "OrderedDict[str, OrderedDict]":
             "md5_checksum",
         ]

-        data_container = OrderedDict()  # type: 'OrderedDict[str, OrderedDict]'
+        data_container = OrderedDict()
         data_dict = OrderedDict([("@xmlns:oml", "http://openml.org/openml")])
         data_container["oml:data_set_description"] = data_dict

@@ -992,6 +969,80 @@ def _to_dict(self) -> "OrderedDict[str, OrderedDict]":

         return data_container

+    def _get_arff(
+        self, format: str
+    ) -> Dict[
+        str,
+        Union[
+            str,
+            List[Union[List[str], Tuple[str, List[str]]]],
+            List[List[Union[int, float, str]]],
+            Tuple[List[float], List[int], List[int]],
+        ],
+    ]:
+        """Read ARFF file and return decoded arff.
+
+        Reads the file referenced in self.data_file.
+
+        Parameters
+        ----------
+        format : str
+            Format of the ARFF file.
+            Must be one of 'arff' or 'sparse_arff' or a string that will be either of those
+            when converted to lower case.
+
+
+
+        Returns
+        -------
+        dict
+            Decoded arff.
+
+        """
+
+        # TODO: add a partial read method which only returns the attribute
+        # headers of the corresponding .arff file!
+        import struct
+
+        filename = cast(str, self.data_file)
+        bits = 8 * struct.calcsize("P")
+        # Files can be considered too large on a 32-bit system,
+        # if it exceeds 120mb (slightly more than covtype dataset size)
+        # This number is somewhat arbitrary.
+        if bits != 64 and os.path.getsize(filename) > 120000000:
+            raise NotImplementedError(
+                "File {} too big for {}-bit system ({} bytes).".format(
+                    filename, os.path.getsize(filename), bits
+                )
+            )
+
+        if format.lower() == "arff":
+            return_type = arff.DENSE
+        elif format.lower() == "sparse_arff":
+            return_type = arff.COO
+        else:
+            raise ValueError("Unknown data format {}".format(format))
+
+        def decode_arff(
+            fh: Union[gzip.GzipFile, TextIOWrapper]
+        ) -> Dict[
+            str,
+            Union[
+                str,
+                List[Union[List[str], Tuple[str, List[str]]]],
+                List[List[Union[int, float, str]]],
+                Tuple[List[float], List[int], List[int]],
+            ],
+        ]:
+            decoder = arff.ArffDecoder()
+            return decoder.decode(fh, encode_nominal=True, return_type=return_type)
+
+        if filename[-3:] == ".gz":
+            with gzip.open(filename) as zipfile:
+                return decode_arff(zipfile)
+        with open(filename, encoding="utf8") as fh:
+            return decode_arff(fh)
+

 def _read_features(features_file: str) -> Dict[int, OpenMLDataFeature]:
     features_pickle_file = _get_features_pickle_file(features_file)
@@ -1001,35 +1052,30 @@ def _read_features(features_file: str) -> Dict[int, OpenMLDataFeature]:
     except:  # noqa E722
         with open(features_file, encoding="utf8") as fh:
             features_xml_string = fh.read()
-
-        features = _parse_features_xml(features_xml_string)
+        xml_dict = xmltodict.parse(
+            features_xml_string, force_list=("oml:feature", "oml:nominal_value")
+        )
+        features_xml = xml_dict["oml:data_features"]
+
+        features = {}
+        for idx, xmlfeature in enumerate(features_xml["oml:feature"]):
+            nr_missing = xmlfeature.get("oml:number_of_missing_values", 0)
+            feature = OpenMLDataFeature(
+                int(xmlfeature["oml:index"]),
+                xmlfeature["oml:name"],
+                xmlfeature["oml:data_type"],
+                xmlfeature.get("oml:nominal_value"),
+                int(nr_missing),
+            )
+            if idx != feature.index:
+                raise ValueError("Data features not provided in right order")
+            features[feature.index] = feature

         with open(features_pickle_file, "wb") as fh_binary:
             pickle.dump(features, fh_binary)
     return features


-def _parse_features_xml(features_xml_string):
-    xml_dict = xmltodict.parse(features_xml_string, force_list=("oml:feature", "oml:nominal_value"))
-    features_xml = xml_dict["oml:data_features"]
-
-    features = {}
-    for idx, xmlfeature in enumerate(features_xml["oml:feature"]):
-        nr_missing = xmlfeature.get("oml:number_of_missing_values", 0)
-        feature = OpenMLDataFeature(
-            int(xmlfeature["oml:index"]),
-            xmlfeature["oml:name"],
-            xmlfeature["oml:data_type"],
-            xmlfeature.get("oml:nominal_value"),
-            int(nr_missing),
-        )
-        if idx != feature.index:
-            raise ValueError("Data features not provided in right order")
-        features[feature.index] = feature
-
-    return features
-
-
 def _get_features_pickle_file(features_file: str) -> str:
     """This function only exists so it can be mocked during unit testing"""
     return features_file + ".pkl"
@@ -1043,12 +1089,19 @@ def _read_qualities(qualities_file: str) -> Dict[str, float]:
     except:  # noqa E722
         with open(qualities_file, encoding="utf8") as fh:
             qualities_xml = fh.read()
-        qualities = _parse_qualities_xml(qualities_xml)
+        xml_as_dict = xmltodict.parse(qualities_xml, force_list=("oml:quality",))
+        qualities = xml_as_dict["oml:data_qualities"]["oml:quality"]
+        qualities = _check_qualities(qualities)
         with open(qualities_pickle_file, "wb") as fh_binary:
             pickle.dump(qualities, fh_binary)
     return qualities


+def _get_qualities_pickle_file(qualities_file: str) -> str:
+    """This function only exists so it can be mocked during unit testing"""
+    return qualities_file + ".pkl"
+
+
 def _check_qualities(qualities: List[Dict[str, str]]) -> Dict[str, float]:
     qualities_ = {}
     for xmlquality in qualities:
@@ -1061,14 +1114,3 @@ def _check_qualities(qualities: List[Dict[str, str]]) -> Dict[str, float]:
         value = float(xmlquality["oml:value"])
         qualities_[name] = value
     return qualities_
-
-
-def _parse_qualities_xml(qualities_xml):
-    xml_as_dict = xmltodict.parse(qualities_xml, force_list=("oml:quality",))
-    qualities = xml_as_dict["oml:data_qualities"]["oml:quality"]
-    return _check_qualities(qualities)
-
-
-def _get_qualities_pickle_file(qualities_file: str) -> str:
-    """This function only exists so it can be mocked during unit testing"""
-    return qualities_file + ".pkl"
diff --git a/openml/datasets/functions.py b/openml/datasets/functions.py
index d04ad8812..57fde91e6 100644
--- a/openml/datasets/functions.py
+++ b/openml/datasets/functions.py
@@ -4,12 +4,13 @@
 import logging
 import os
 from pyexpat import ExpatError
-from typing import List, Dict, Optional, Union, cast
+from typing import List, Dict, Union, Optional, cast, Tuple
 import warnings

 import numpy as np
 import arff
 import pandas as pd
+import scipy
 import urllib3
 import xmltodict
@@ -28,6 +29,7 @@
 from ..utils import _remove_cache_dir_for_id, _create_cache_directory_for_id, _get_cache_dir_for_id

 DATASETS_CACHE_DIR_NAME = "datasets"
+ARFF_ATTRIBUTE_TYPE = List[Union[Tuple[str, str], Tuple[str, List[str]]]]
 logger = logging.getLogger(__name__)
@@ -69,8 +71,8 @@ def list_datasets(
     status: Optional[str] = None,
     tag: Optional[str] = None,
     output_format: str = "dict",
-    **kwargs,
-) -> Union[Dict, pd.DataFrame]:
+    **kwargs: Optional[Union[Dict[str, Union[str, int]], Union[str, int]]],
+) -> Union[Dict[int, Dict[str, Union[str, int]]], pd.DataFrame]:
     """
     Return a list of all dataset which are on OpenML.
     Supports large amount of results.
@@ -149,7 +151,11 @@ def list_datasets(
     )


-def _list_datasets(data_id: Optional[List] = None, output_format="dict", **kwargs):
+def _list_datasets(
+    data_id: Optional[List[str]] = None,
+    output_format: str = "dict",
+    **kwargs: Optional[Dict[str, Union[str, int]]],
+) -> Union[Dict[str, Dict[str, Union[str, int]]], pd.DataFrame]:
     """
     Perform api call to return a list of all datasets.
@@ -186,7 +192,7 @@ def _list_datasets(data_id: Optional[List] = None, output_format="dict", **kwarg
     return __list_datasets(api_call=api_call, output_format=output_format)


-def __list_datasets(api_call, output_format="dict"):
+def __list_datasets(api_call: str, output_format: str = "dict") -> pd.DataFrame:
     xml_string = openml._api_calls._perform_api_call(api_call, "get")
     datasets_dict = xmltodict.parse(xml_string, force_list=("oml:dataset",))
@@ -219,7 +225,7 @@ def __list_datasets(api_call, output_format="dict"):
     return datasets


-def _expand_parameter(parameter: Union[str, List[str]]) -> List[str]:
+def _expand_parameter(parameter: Union[str, List[str], None]) -> List[str]:
     expanded_parameter = []
     if isinstance(parameter, str):
         expanded_parameter = [x.strip() for x in parameter.split(",")]
@@ -229,7 +235,9 @@ def _expand_parameter(parameter: Union[str, List[str]]) -> List[str]:


 def _validated_data_attributes(
-    attributes: List[str], data_attributes: List[str], parameter_name: str
+    attributes: List[str],
+    data_attributes: ARFF_ATTRIBUTE_TYPE,
+    parameter_name: str,
 ) -> None:
     for attribute_ in attributes:
         is_attribute_a_data_attribute = any([attr[0] == attribute_ for attr in data_attributes])
@@ -267,11 +275,11 @@ def check_datasets_active(
         A dictionary with items {did: bool}
     """
     datasets = list_datasets(status="all", data_id=dataset_ids, output_format="dataframe")
-    missing = set(dataset_ids) - set(datasets.get("did", []))
+    missing = set(dataset_ids) - set(datasets.get("did", []))  # type: ignore
     if raise_error_if_not_exist and missing:
         missing_str = ", ".join(str(did) for did in missing)
         raise ValueError(f"Could not find dataset(s) {missing_str} in OpenML dataset list.")
-    return dict(datasets["status"] == "active")
+    return dict(datasets["status"] == "active")  # type: ignore


 def _name_to_id(
@@ -303,6 +311,7 @@
         The id of the dataset.
     """
     status = None if version is not None else "active"
+    candidates = list_datasets(status=status, data_name=dataset_name, data_version=version)
     candidates = cast(
         pd.DataFrame,
         list_datasets(
@@ -512,14 +521,16 @@ def get_dataset(
     finally:
         if remove_dataset_cache:
             _remove_cache_dir_for_id(DATASETS_CACHE_DIR_NAME, did_cache_dir)
-
-    dataset = _create_dataset_from_description(
-        description, features_file, qualities_file, arff_file, parquet_file, cache_format
-    )
+    if qualities_file:
+        dataset = _create_dataset_from_description(
+            description, features_file, qualities_file, arff_file, parquet_file, cache_format
+        )
     return dataset


-def attributes_arff_from_df(df):
+def attributes_arff_from_df(
+    df: pd.DataFrame,
+) -> List[Union[Tuple[str, str], Tuple[str, List[str]]]]:
     """Describe attributes of the dataframe according to ARFF specification.
     Parameters
     ----------
@@ -575,24 +586,24 @@ def attributes_arff_from_df(df):


 def create_dataset(
-    name,
-    description,
-    creator,
-    contributor,
-    collection_date,
-    language,
-    licence,
-    attributes,
-    data,
-    default_target_attribute,
-    ignore_attribute,
-    citation,
-    row_id_attribute=None,
-    original_data_url=None,
-    paper_url=None,
-    update_comment=None,
-    version_label=None,
-):
+    name: str,
+    description: str,
+    creator: str,
+    contributor: str,
+    collection_date: str,
+    language: str,
+    licence: str,
+    attributes: Union[List[Union[Tuple[str, str], Tuple[str, List[str]]]], Dict[str, str]],
+    data: Union[np.ndarray, pd.DataFrame, scipy.sparse.csr_matrix],
+    default_target_attribute: str,
+    ignore_attribute: Union[str, List[str]],
+    citation: str,
+    row_id_attribute: Optional[str] = None,
+    original_data_url: Optional[str] = None,
+    paper_url: Optional[str] = None,
+    update_comment: Optional[str] = None,
+    version_label: Optional[str] = None,
+) -> OpenMLDataset:
     """Create a dataset.

     This function creates an OpenMLDataset object.
@@ -668,7 +679,6 @@ def create_dataset(
     # We need to reset the index such that it is part of the data.
     if data.index.name is not None:
         data = data.reset_index()
-
     if attributes == "auto" or isinstance(attributes, dict):
         if not hasattr(data, "columns"):
             raise ValueError(
@@ -685,9 +695,11 @@ def create_dataset(
             attributes_[attr_idx] = (attr_name, attributes[attr_name])
     else:
         attributes_ = attributes
+
+    # attributes: Union[List[Tuple[str, str]], List[Tuple[str, List[str]]], str]
+    # attributes_: List[Union[Tuple[str, str], Tuple[str, List[str]]]]
     ignore_attributes = _expand_parameter(ignore_attribute)
     _validated_data_attributes(ignore_attributes, attributes_, "ignore_attribute")
-
     default_target_attributes = _expand_parameter(default_target_attribute)
     _validated_data_attributes(default_target_attributes, attributes_, "default_target_attribute")
@@ -775,7 +787,7 @@ def create_dataset(
     )


-def status_update(data_id, status):
+def status_update(data_id: int, status: str) -> None:
     """
     Updates the status of a dataset to either 'active' or 'deactivated'.
     Please see the OpenML API documentation for a description of the status
@@ -792,7 +804,7 @@ def status_update(data_id, status):
     legal_status = {"active", "deactivated"}
     if status not in legal_status:
         raise ValueError("Illegal status value. " "Legal values: %s" % legal_status)
-    data = {"data_id": data_id, "status": status}
+    data: Dict[str, Union[str, int]] = {"data_id": data_id, "status": status}
     result_xml = openml._api_calls._perform_api_call("data/status/update", "post", data=data)
     result = xmltodict.parse(result_xml)
     server_data_id = result["oml:data_status_update"]["oml:id"]
@@ -803,18 +815,18 @@ def status_update(data_id, status):


 def edit_dataset(
-    data_id,
-    description=None,
-    creator=None,
-    contributor=None,
-    collection_date=None,
-    language=None,
-    default_target_attribute=None,
-    ignore_attribute=None,
-    citation=None,
-    row_id_attribute=None,
-    original_data_url=None,
-    paper_url=None,
+    data_id: int,
+    description: Optional[str] = None,
+    creator: Optional[str] = None,
+    contributor: Optional[str] = None,
+    collection_date: Optional[str] = None,
+    language: Optional[str] = None,
+    default_target_attribute: Optional[str] = None,
+    ignore_attribute: Optional[Union[str, List[str]]] = None,
+    citation: Optional[str] = None,
+    row_id_attribute: Optional[str] = None,
+    original_data_url: Optional[str] = None,
+    paper_url: Optional[str] = None,
 ) -> int:
     """Edits an OpenMLDataset.
@@ -877,8 +889,9 @@ def edit_dataset(
         raise TypeError("`data_id` must be of type `int`, not {}.".format(type(data_id)))

     # compose data edit parameters as xml
-    form_data = {"data_id": data_id}  # type: openml._api_calls.DATA_TYPE
-    xml = OrderedDict()  # type: 'OrderedDict[str, OrderedDict]'
+    form_data: Optional[Dict[str, Union[str, int]]] = {"data_id": data_id}
+    xml: OrderedDict[str, OrderedDict[str, Union[str, List[str], None]]]
+    xml = OrderedDict()
     xml["oml:data_edit_parameters"] = OrderedDict()
     xml["oml:data_edit_parameters"]["@xmlns:oml"] = "http://openml.org/openml"
     xml["oml:data_edit_parameters"]["oml:description"] = description
@@ -948,7 +961,7 @@ def fork_dataset(data_id: int) -> int:
     return int(data_id)


-def _topic_add_dataset(data_id: int, topic: str):
+def _topic_add_dataset(data_id: int, topic: str) -> int:
     """
     Adds a topic for a dataset. This API is not available for all OpenML users
     and is accessible only by admins.
@@ -968,7 +981,7 @@
     return int(data_id)


-def _topic_delete_dataset(data_id: int, topic: str):
+def _topic_delete_dataset(data_id: int, topic: str) -> int:
     """
     Removes a topic from a dataset. This API is not available for all OpenML users
     and is accessible only by admins.
@@ -989,7 +1002,7 @@
     return int(data_id)


-def _get_dataset_description(did_cache_dir, dataset_id):
+def _get_dataset_description(did_cache_dir: str, dataset_id: int) -> Dict[str, str]:
     """Get the dataset description as xml dictionary.

     This function is NOT thread/multiprocessing safe.
@@ -1033,7 +1046,7 @@ def _get_dataset_description(did_cache_dir, dataset_id):


 def _get_dataset_parquet(
-    description: Union[Dict, OpenMLDataset],
+    description: Union[Dict[str, str], OpenMLDataset],
     cache_directory: Optional[str] = None,
     download_all_files: bool = False,
 ) -> Optional[str]:
@@ -1064,12 +1077,12 @@ def _get_dataset_parquet(
     output_filename : string, optional
         Location of the Parquet file if successfully downloaded, None otherwise.
     """
-    if isinstance(description, dict):
-        url = cast(str, description.get("oml:minio_url"))
-        did = description.get("oml:id")
-    elif isinstance(description, OpenMLDataset):
+    if isinstance(description, OpenMLDataset):
         url = cast(str, description._minio_url)
         did = description.dataset_id
+    elif isinstance(description, dict):
+        url = cast(str, description.get("oml:minio_url"))
+        did = int(description.get("oml:id", ""))
     else:
         raise TypeError("`description` should be either OpenMLDataset or Dict.")
@@ -1102,7 +1115,8 @@ def _get_dataset_parquet(


 def _get_dataset_arff(
-    description: Union[Dict, OpenMLDataset], cache_directory: Optional[str] = None
+    description: Union[Dict[str, str], OpenMLDataset],
+    cache_directory: Optional[str] = None,
 ) -> str:
     """Return the path to the local arff file of the dataset. If is not cached, it is downloaded.
@@ -1126,14 +1140,14 @@ def _get_dataset_arff(
     output_filename : string
         Location of ARFF file.
     """
-    if isinstance(description, dict):
-        md5_checksum_fixture = description.get("oml:md5_checksum")
-        url = description["oml:url"]
-        did = description.get("oml:id")
-    elif isinstance(description, OpenMLDataset):
+    if isinstance(description, OpenMLDataset):
         md5_checksum_fixture = description.md5_checksum
-        url = description.url
+        url = cast(str, description.url)
         did = description.dataset_id
+    elif isinstance(description, dict):
+        md5_checksum_fixture = description.get("oml:md5_checksum")
+        url = cast(str, description["oml:url"])
+        did = int(description.get("oml:id", ""))
     else:
         raise TypeError("`description` should be either OpenMLDataset or Dict.")
@@ -1153,7 +1167,7 @@
     return output_file_path


-def _get_features_xml(dataset_id):
+def _get_features_xml(dataset_id: int) -> str:
     url_extension = f"data/features/{dataset_id}"
     return openml._api_calls._perform_api_call(url_extension, "get")
@@ -1197,7 +1211,7 @@ def _get_dataset_features_file(did_cache_dir: Union[str, None], dataset_id: int)
     return features_file


-def _get_qualities_xml(dataset_id):
+def _get_qualities_xml(dataset_id: int) -> str:
     url_extension = f"data/qualities/{dataset_id}"
     return openml._api_calls._perform_api_call(url_extension, "get")
@@ -1267,9 +1281,9 @@ def _create_dataset_from_description(
     ----------
     description : dict
         Description of a dataset in xml dict.
-    featuresfile : str
+    features_file : str
         Path of the dataset features as xml file.
-    qualities : list
+    qualities_file : str
         Path of the dataset qualities as xml file.
     arff_file : string, optional
         Path of dataset ARFF file.
@@ -1284,8 +1298,8 @@
         Dataset object from dict and ARFF.
     """
     return OpenMLDataset(
-        description["oml:name"],
-        description.get("oml:description"),
+        name=description["oml:name"],
+        description=description.get("oml:description", ""),
         data_format=description["oml:format"],
         dataset_id=description["oml:id"],
         version=description["oml:version"],
@@ -1316,7 +1330,7 @@
     )


-def _get_online_dataset_arff(dataset_id):
+def _get_online_dataset_arff(dataset_id: int) -> Optional[str]:
     """Download the ARFF file for a given dataset id
     from the OpenML website.
@@ -1338,7 +1352,7 @@
     )


-def _get_online_dataset_format(dataset_id):
+def _get_online_dataset_format(dataset_id: int) -> str:
     """Get the dataset format for a given dataset id
     from the OpenML website.
diff --git a/openml/flows/flow.py b/openml/flows/flow.py
index b9752e77c..bd8d55509 100644
--- a/openml/flows/flow.py
+++ b/openml/flows/flow.py
@@ -173,7 +173,7 @@ def extension(self):
                 "No extension could be found for flow {}: {}".format(self.flow_id, self.name)
             )

-    def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str]]]]:
+    def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str], None]]]:
         """Collect all information to display in the __repr__ body."""
         fields = {
             "Flow Name": self.name,
diff --git a/openml/runs/run.py b/openml/runs/run.py
index 5528c8a67..822ee03e5 100644
--- a/openml/runs/run.py
+++ b/openml/runs/run.py
@@ -189,7 +189,7 @@ def _evaluation_summary(self, metric: str) -> str:

         return "{:.4f} +- {:.4f}".format(np.mean(rep_means), np.mean(rep_stds))

-    def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str]]]]:
+    def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str], None]]]:
         """Collect all information to display in the __repr__ body."""
         # Set up fields
         fields = {
diff --git a/openml/study/study.py b/openml/study/study.py
index cfc4cab3b..31b1294da 100644
--- a/openml/study/study.py
+++ b/openml/study/study.py
@@ -97,7 +97,7 @@ def _entity_letter(cls) -> str:
     def id(self) -> Optional[int]:
         return self.study_id

-    def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str]]]]:
+    def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str], None]]]:
         """Collect all information to display in the __repr__ body."""
         fields: Dict[str, Any] = {
             "Name": self.name,
diff --git a/openml/tasks/task.py b/openml/tasks/task.py
index 36e0ada1c..52554f0e1 100644
--- a/openml/tasks/task.py
+++ b/openml/tasks/task.py
@@ -80,7 +80,7 @@ def _entity_letter(cls) -> str:
     def id(self) -> Optional[int]:
         return self.task_id

-    def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str]]]]:
+    def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str], None]]]:
         """Collect all information to display in the __repr__ body."""
         fields: Dict[str, Any] = {
             "Task Type Description": "{}/tt/{}".format(
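Appended note (not part of the diff above): a minimal sketch of the attribute shape described by the new ARFF_ATTRIBUTE_TYPE alias and the annotated attributes_arff_from_df() in openml/datasets/functions.py. The sample DataFrame and the ARFF type strings mentioned in the comments are illustrative assumptions, not output verified against the library.

# Illustrative sketch only; attributes_arff_from_df and ARFF_ATTRIBUTE_TYPE are taken from
# openml/datasets/functions.py above, while the sample data and the expected type strings
# in the comments are assumptions.
import pandas as pd

from openml.datasets.functions import attributes_arff_from_df

df = pd.DataFrame(
    {
        "sepal_length": [5.1, 4.9, 6.3],
        "species": pd.Series(["setosa", "setosa", "virginica"], dtype="category"),
    }
)

# The result matches ARFF_ATTRIBUTE_TYPE = List[Union[Tuple[str, str], Tuple[str, List[str]]]]:
# a numeric column maps to a (name, ARFF primitive) pair, e.g. ("sepal_length", "REAL"),
# and a categorical column maps to a (name, nominal_values) pair, e.g. ("species", [...]).
attributes = attributes_arff_from_df(df)
print(attributes)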