Skip to content

Commit 209db9e

Browse files
authored
Bug fix and improvement in WSI (#4216)
* Make all transforms optional
  Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Update wsireader tests
  Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Remove optional from PersistentDataset and its derivatives
  Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Add unittests for cache without transform
  Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Add default replace_rate
  Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Add default value
  Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Set default replace_rate to 0.1
  Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Update metadata to include path
  Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Adds SmartCachePatchWSIDataset
  Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Add unittests for SmartCachePatchWSIDataset
  Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Update references
  Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Update docs
  Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Remove smart cache
  Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Remove unused imports
  Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Add path metadata for OpenSlide
  Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Update metadata to be unified across different backends
  Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Update wsi metadata for multi wsi objects
  Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
* Add unittests for wsi metadata
  Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
1 parent 5f00bf1 commit 209db9e

File tree

3 files changed

+85
-78
lines changed

3 files changed

+85
-78
lines changed

monai/data/wsi_datasets.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# limitations under the License.
1111

1212
import inspect
13-
from typing import Callable, Dict, List, Optional, Tuple, Union
13+
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
1414

1515
import numpy as np
1616

@@ -32,10 +32,12 @@ class PatchWSIDataset(Dataset):
3232
size: the size of patch to be extracted from the whole slide image.
3333
level: the level at which the patches are to be extracted (defaults to 0).
3434
transform: transforms to be executed on input data.
35-
reader: the module to be used for loading whole slide imaging,
36-
- if `reader` is a string, it defines the backend of `monai.data.WSIReader`. Defaults to cuCIM.
37-
- if `reader` is a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader.
38-
- if `reader` is an instance of a a class inherited from `BaseWSIReader`, it is set as the wsi_reader.
35+
reader: the module to be used for loading whole slide imaging. If `reader` is
36+
37+
- a string, it defines the backend of `monai.data.WSIReader`. Defaults to cuCIM.
38+
- a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader.
39+
- an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader.
40+
3941
kwargs: additional arguments to pass to `WSIReader` or provided whole slide reader class
4042
4143
Note:
@@ -45,14 +47,14 @@ class PatchWSIDataset(Dataset):
4547
4648
[
4749
{"image": "path/to/image1.tiff", "location": [200, 500], "label": 0},
48-
{"image": "path/to/image2.tiff", "location": [100, 700], "label": 1}
50+
{"image": "path/to/image2.tiff", "location": [100, 700], "size": [20, 20], "level": 2, "label": 1}
4951
]
5052
5153
"""
5254

5355
def __init__(
5456
self,
55-
data: List,
57+
data: Sequence,
5658
size: Optional[Union[int, Tuple[int, int]]] = None,
5759
level: Optional[int] = None,
5860
transform: Optional[Callable] = None,

monai/data/wsi_reader.py

Lines changed: 54 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# limitations under the License.
1111

1212
from abc import abstractmethod
13+
from os.path import abspath
1314
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
1415

1516
import numpy as np
@@ -53,6 +54,7 @@ class BaseWSIReader(ImageReader):
5354
"""
5455

5556
supported_suffixes: List[str] = []
57+
backend = ""
5658

5759
def __init__(self, level: int, **kwargs):
5860
super().__init__()
@@ -63,7 +65,7 @@ def __init__(self, level: int, **kwargs):
6365
@abstractmethod
6466
def get_size(self, wsi, level: int) -> Tuple[int, int]:
6567
"""
66-
Returns the size of the whole slide image at a given level.
68+
Returns the size (height, width) of the whole slide image at a given level.
6769
6870
Args:
6971
wsi: a whole slide image object loaded from a file
@@ -83,6 +85,11 @@ def get_level_count(self, wsi) -> int:
8385
"""
8486
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
8587

88+
@abstractmethod
89+
def get_file_path(self, wsi) -> str:
90+
"""Return the file path for the WSI object"""
91+
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
92+
8693
@abstractmethod
8794
def get_patch(
8895
self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str
@@ -102,20 +109,29 @@ def get_patch(
102109
"""
103110
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
104111

105-
@abstractmethod
106-
def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict:
112+
def get_metadata(
113+
self, wsi, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int
114+
) -> Dict:
107115
"""
108116
Returns metadata of the extracted patch from the whole slide image.
109117
110118
Args:
119+
wsi: the whole slide image object, from which the patch is loaded
111120
patch: extracted patch from whole slide image
112121
location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0).
113122
size: (height, width) tuple giving the patch size at the given level (`level`).
114123
If None, it is set to the full image size at the given level.
115124
level: the level number. Defaults to 0
116125
117126
"""
118-
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
127+
metadata: Dict = {
128+
"backend": self.backend,
129+
"original_channel_dim": 0,
130+
"spatial_shape": np.asarray(patch.shape[1:]),
131+
"wsi": {"path": self.get_file_path(wsi)},
132+
"patch": {"location": location, "size": size, "level": level},
133+
}
134+
return metadata
119135

120136
def get_data(
121137
self,
@@ -194,8 +210,26 @@ def get_data(
194210
patch_list.append(patch)
195211

196212
# Set patch-related metadata
197-
each_meta = self.get_metadata(patch=patch, location=location, size=size, level=level)
198-
metadata.update(each_meta)
213+
each_meta = self.get_metadata(wsi=each_wsi, patch=patch, location=location, size=size, level=level)
214+
215+
if len(wsi) == 1:
216+
metadata = each_meta
217+
else:
218+
if not metadata:
219+
metadata = {
220+
"backend": each_meta["backend"],
221+
"original_channel_dim": each_meta["original_channel_dim"],
222+
"spatial_shape": each_meta["spatial_shape"],
223+
"wsi": [each_meta["wsi"]],
224+
"patch": [each_meta["patch"]],
225+
}
226+
else:
227+
if metadata["original_channel_dim"] != each_meta["original_channel_dim"]:
228+
raise ValueError("original_channel_dim is not consistent across wsi objects.")
229+
if any(metadata["spatial_shape"] != each_meta["spatial_shape"]):
230+
raise ValueError("spatial_shape is not consistent across wsi objects.")
231+
metadata["wsi"].append(each_meta["wsi"])
232+
metadata["patch"].append(each_meta["patch"])
199233

200234
return _stack_images(patch_list, metadata), metadata
201235

@@ -247,7 +281,7 @@ def get_level_count(self, wsi) -> int:
247281

248282
def get_size(self, wsi, level: int) -> Tuple[int, int]:
249283
"""
250-
Returns the size of the whole slide image at a given level.
284+
Returns the size (height, width) of the whole slide image at a given level.
251285
252286
Args:
253287
wsi: a whole slide image object loaded from a file
@@ -256,19 +290,9 @@ def get_size(self, wsi, level: int) -> Tuple[int, int]:
256290
"""
257291
return self.reader.get_size(wsi, level)
258292

259-
def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict:
260-
"""
261-
Returns metadata of the extracted patch from the whole slide image.
262-
263-
Args:
264-
patch: extracted patch from whole slide image
265-
location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0).
266-
size: (height, width) tuple giving the patch size at the given level (`level`).
267-
If None, it is set to the full image size at the given level.
268-
level: the level number. Defaults to 0
269-
270-
"""
271-
return self.reader.get_metadata(patch=patch, size=size, location=location, level=level)
293+
def get_file_path(self, wsi) -> str:
294+
"""Return the file path for the WSI object"""
295+
return self.reader.get_file_path(wsi)
272296

273297
def get_patch(
274298
self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str
@@ -317,6 +341,7 @@ class CuCIMWSIReader(BaseWSIReader):
317341
"""
318342

319343
supported_suffixes = ["tif", "tiff", "svs"]
344+
backend = "cucim"
320345

321346
def __init__(self, level: int = 0, **kwargs):
322347
super().__init__(level, **kwargs)
@@ -335,7 +360,7 @@ def get_level_count(wsi) -> int:
335360
@staticmethod
336361
def get_size(wsi, level: int) -> Tuple[int, int]:
337362
"""
338-
Returns the size of the whole slide image at a given level.
363+
Returns the size (height, width) of the whole slide image at a given level.
339364
340365
Args:
341366
wsi: a whole slide image object loaded from a file
@@ -344,27 +369,9 @@ def get_size(wsi, level: int) -> Tuple[int, int]:
344369
"""
345370
return (wsi.resolutions["level_dimensions"][level][1], wsi.resolutions["level_dimensions"][level][0])
346371

347-
def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict:
348-
"""
349-
Returns metadata of the extracted patch from the whole slide image.
350-
351-
Args:
352-
patch: extracted patch from whole slide image
353-
location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0).
354-
size: (height, width) tuple giving the patch size at the given level (`level`).
355-
If None, it is set to the full image size at the given level.
356-
level: the level number. Defaults to 0
357-
358-
"""
359-
metadata: Dict = {
360-
"backend": "cucim",
361-
"spatial_shape": np.asarray(patch.shape[1:]),
362-
"original_channel_dim": 0,
363-
"location": location,
364-
"size": size,
365-
"level": level,
366-
}
367-
return metadata
372+
def get_file_path(self, wsi) -> str:
373+
"""Return the file path for the WSI object"""
374+
return str(abspath(wsi.path))
368375

369376
def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs):
370377
"""
@@ -440,6 +447,7 @@ class OpenSlideWSIReader(BaseWSIReader):
440447
"""
441448

442449
supported_suffixes = ["tif", "tiff", "svs"]
450+
backend = "openslide"
443451

444452
def __init__(self, level: int = 0, **kwargs):
445453
super().__init__(level, **kwargs)
@@ -458,7 +466,7 @@ def get_level_count(wsi) -> int:
458466
@staticmethod
459467
def get_size(wsi, level: int) -> Tuple[int, int]:
460468
"""
461-
Returns the size of the whole slide image at a given level.
469+
Returns the size (height, width) of the whole slide image at a given level.
462470
463471
Args:
464472
wsi: a whole slide image object loaded from a file
@@ -467,27 +475,9 @@ def get_size(wsi, level: int) -> Tuple[int, int]:
467475
"""
468476
return (wsi.level_dimensions[level][1], wsi.level_dimensions[level][0])
469477

470-
def get_metadata(self, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int) -> Dict:
471-
"""
472-
Returns metadata of the extracted patch from the whole slide image.
473-
474-
Args:
475-
patch: extracted patch from whole slide image
476-
location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0).
477-
size: (height, width) tuple giving the patch size at the given level (`level`).
478-
If None, it is set to the full image size at the given level.
479-
level: the level number. Defaults to 0
480-
481-
"""
482-
metadata: Dict = {
483-
"backend": "openslide",
484-
"spatial_shape": np.asarray(patch.shape[1:]),
485-
"original_channel_dim": 0,
486-
"location": location,
487-
"size": size,
488-
"level": level,
489-
}
490-
return metadata
478+
def get_file_path(self, wsi) -> str:
479+
"""Return the file path for the WSI object"""
480+
return str(abspath(wsi._filename))
491481

492482
def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs):
493483
"""

tests/test_wsireader_new.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,13 @@ class Tests(unittest.TestCase):
125125
def test_read_whole_image(self, file_path, level, expected_shape):
126126
reader = WSIReader(self.backend, level=level)
127127
with reader.read(file_path) as img_obj:
128-
img = reader.get_data(img_obj)[0]
128+
img, meta = reader.get_data(img_obj)
129129
self.assertTupleEqual(img.shape, expected_shape)
130+
self.assertEqual(meta["backend"], self.backend)
131+
self.assertEqual(meta["wsi"]["path"], str(os.path.abspath(file_path)))
132+
self.assertEqual(meta["patch"]["level"], level)
133+
self.assertTupleEqual(meta["patch"]["size"], expected_shape[1:])
134+
self.assertTupleEqual(meta["patch"]["location"], (0, 0))
130135

131136
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
132137
def test_read_region(self, file_path, patch_info, expected_img):
@@ -138,29 +143,39 @@ def test_read_region(self, file_path, patch_info, expected_img):
138143
reader.get_data(img_obj, **patch_info)[0]
139144
else:
140145
# Read twice to check multiple calls
141-
img = reader.get_data(img_obj, **patch_info)[0]
146+
img, meta = reader.get_data(img_obj, **patch_info)
142147
img2 = reader.get_data(img_obj, **patch_info)[0]
143148
self.assertTupleEqual(img.shape, img2.shape)
144149
self.assertIsNone(assert_array_equal(img, img2))
145150
self.assertTupleEqual(img.shape, expected_img.shape)
146151
self.assertIsNone(assert_array_equal(img, expected_img))
152+
self.assertEqual(meta["backend"], self.backend)
153+
self.assertEqual(meta["wsi"]["path"], str(os.path.abspath(file_path)))
154+
self.assertEqual(meta["patch"]["level"], patch_info["level"])
155+
self.assertTupleEqual(meta["patch"]["size"], expected_img.shape[1:])
156+
self.assertTupleEqual(meta["patch"]["location"], patch_info["location"])
147157

148158
@parameterized.expand([TEST_CASE_3])
149-
def test_read_region_multi_wsi(self, file_path, patch_info, expected_img):
159+
def test_read_region_multi_wsi(self, file_path_list, patch_info, expected_img):
150160
kwargs = {"name": None, "offset": None} if self.backend == "tifffile" else {}
151161
reader = WSIReader(self.backend, **kwargs)
152-
img_obj = reader.read(file_path, **kwargs)
162+
img_obj_list = reader.read(file_path_list, **kwargs)
153163
if self.backend == "tifffile":
154164
with self.assertRaises(ValueError):
155-
reader.get_data(img_obj, **patch_info)[0]
165+
reader.get_data(img_obj_list, **patch_info)[0]
156166
else:
157167
# Read twice to check multiple calls
158-
img = reader.get_data(img_obj, **patch_info)[0]
159-
img2 = reader.get_data(img_obj, **patch_info)[0]
168+
img, meta = reader.get_data(img_obj_list, **patch_info)
169+
img2 = reader.get_data(img_obj_list, **patch_info)[0]
160170
self.assertTupleEqual(img.shape, img2.shape)
161171
self.assertIsNone(assert_array_equal(img, img2))
162172
self.assertTupleEqual(img.shape, expected_img.shape)
163173
self.assertIsNone(assert_array_equal(img, expected_img))
174+
self.assertEqual(meta["backend"], self.backend)
175+
self.assertEqual(meta["wsi"][0]["path"], str(os.path.abspath(file_path_list[0])))
176+
self.assertEqual(meta["patch"][0]["level"], patch_info["level"])
177+
self.assertTupleEqual(meta["patch"][0]["size"], expected_img.shape[1:])
178+
self.assertTupleEqual(meta["patch"][0]["location"], patch_info["location"])
164179

165180
@parameterized.expand([TEST_CASE_RGB_0, TEST_CASE_RGB_1])
166181
@skipUnless(has_tiff, "Requires tifffile.")

0 commit comments

Comments
 (0)