Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions openml/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ def _encode_if_category(column):
)
elif array_format == "dataframe":
if scipy.sparse.issparse(data):
return pd.DataFrame.sparse.from_spmatrix(data, columns=attribute_names)
data = pd.DataFrame.sparse.from_spmatrix(data, columns=attribute_names)
else:
data_type = "sparse-data" if scipy.sparse.issparse(data) else "non-sparse data"
logger.warning(
Expand Down Expand Up @@ -732,6 +732,7 @@ def get_data(
else:
target = [target]
targets = np.array([True if column in target else False for column in attribute_names])
target_names = np.array([column for column in attribute_names if column in target])
if np.sum(targets) > 1:
raise NotImplementedError(
"Number of requested targets %d is not implemented." % np.sum(targets)
Expand All @@ -752,11 +753,17 @@ def get_data(
attribute_names = [att for att, k in zip(attribute_names, targets) if not k]

x = self._convert_array_format(x, dataset_format, attribute_names)
if scipy.sparse.issparse(y):
y = np.asarray(y.todense()).astype(target_dtype).flatten()
y = y.squeeze()
y = self._convert_array_format(y, dataset_format, attribute_names)
if dataset_format == "array" and scipy.sparse.issparse(y):
# scikit-learn requires dense representation of targets
y = np.asarray(y.todense()).astype(target_dtype)
# the dense representation of a single-column sparse array is a 2-d array,
# so it needs to be flattened to a 1-d array for _convert_array_format()
y = y.squeeze()
y = self._convert_array_format(y, dataset_format, target_names)
y = y.astype(target_dtype) if dataset_format == "array" else y
if len(y.shape) > 1 and y.shape[1] == 1:
# single column targets should be 1-d for both `array` and `dataframe` formats
y = y.squeeze()
data, targets = x, y

return data, targets, categorical, attribute_names
Expand Down
21 changes: 18 additions & 3 deletions tests/test_datasets/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def setUp(self):

self.sparse_dataset = openml.datasets.get_dataset(4136, download_data=False)

def test_get_sparse_dataset_with_target(self):
def test_get_sparse_dataset_array_with_target(self):
X, y, _, attribute_names = self.sparse_dataset.get_data(
dataset_format="array", target="class"
)
Expand All @@ -303,7 +303,22 @@ def test_get_sparse_dataset_with_target(self):
self.assertEqual(len(attribute_names), 20000)
self.assertNotIn("class", attribute_names)

def test_get_sparse_dataset(self):
def test_get_sparse_dataset_dataframe_with_target(self):
    """Check the dataframe output of a sparse dataset when a target is requested.

    Calls ``get_data`` on ``self.sparse_dataset`` with
    ``dataset_format="dataframe"`` and ``target="class"`` and verifies that:

    * the features are a ``pd.DataFrame`` of shape (600, 20000) whose first
      column has a sparse dtype,
    * the target is a 1-d ``pd.Series`` of shape (600,) with a sparse dtype,
    * the returned attribute names have 20000 entries and exclude the target.
    """
    X, y, _, attribute_names = self.sparse_dataset.get_data(
        dataset_format="dataframe", target="class"
    )
    self.assertIsInstance(X, pd.DataFrame)
    # X.dtypes is a Series indexed by column name; use positional .iloc
    # because integer keys with [] on a label-indexed Series are deprecated.
    self.assertIsInstance(X.dtypes.iloc[0], pd.SparseDtype)
    self.assertEqual(X.shape, (600, 20000))

    self.assertIsInstance(y, pd.Series)
    self.assertIsInstance(y.dtype, pd.SparseDtype)
    self.assertEqual(y.shape, (600,))

    self.assertEqual(len(attribute_names), 20000)
    self.assertNotIn("class", attribute_names)

def test_get_sparse_dataset_array(self):
rval, _, categorical, attribute_names = self.sparse_dataset.get_data(dataset_format="array")
self.assertTrue(sparse.issparse(rval))
self.assertEqual(rval.dtype, np.float32)
Expand All @@ -315,7 +330,7 @@ def test_get_sparse_dataset(self):
self.assertEqual(len(attribute_names), 20001)
self.assertTrue(all([isinstance(att, str) for att in attribute_names]))

def test_get_sparse_dataframe(self):
def test_get_sparse_dataset_dataframe(self):
rval, *_ = self.sparse_dataset.get_data()
self.assertIsInstance(rval, pd.DataFrame)
np.testing.assert_array_equal(
Expand Down