|
19 | 19 |
|
20 | 20 | from monai.config import DtypeLike, KeysCollection, PathLike |
21 | 21 | from monai.data.utils import correct_nifti_header_if_necessary, is_supported_format, orientation_ras_lps |
22 | | -from monai.utils import ensure_tuple, optional_import, require_pkg |
| 22 | +from monai.transforms.utility.array import EnsureChannelFirst |
| 23 | +from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import, require_pkg |
23 | 24 |
|
24 | 25 | if TYPE_CHECKING: |
25 | 26 | import itk |
|
38 | 39 | CuImage, _ = optional_import("cucim", name="CuImage") |
39 | 40 | TiffFile, _ = optional_import("tifffile", name="TiffFile") |
40 | 41 |
|
41 | | -__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader"] |
| 42 | +__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "WSIReader"] |
42 | 43 |
|
43 | 44 |
|
44 | 45 | class ImageReader(ABC): |
@@ -713,3 +714,265 @@ def _get_spatial_shape(self, img): |
713 | 714 | img: a PIL Image object loaded from an image file. |
714 | 715 | """ |
715 | 716 | return np.asarray((img.width, img.height)) |
| 717 | + |
| 718 | + |
class WSIReader(ImageReader):
    """
    Read whole slide images and extract patches.

    Args:
        backend: backend library to load the images, available options: "cuCIM", "OpenSlide" and "TiffFile".
            The name is matched case-insensitively (it is lower-cased before use).
        level: the whole slide image level at which the image is extracted. (default=0)
            This is overridden if the level argument is provided in `get_data`.
        kwargs: additional args for backend reading API in `read()`, more details in `cuCIM`, `TiffFile`, `OpenSlide`:
            https://github.com/rapidsai/cucim/blob/v21.12.00/cpp/include/cucim/cuimage.h#L100.
            https://github.com/cgohlke/tifffile.
            https://openslide.org/api/python/#openslide.OpenSlide.

    Note:
        While "cuCIM" and "OpenSlide" backends both can load patches from large whole slide images
        without loading the entire image into memory, "TiffFile" backend needs to load the entire image into memory
        before extracting any patch; thus, memory consideration is needed when using "TiffFile" backend for
        patch extraction.

    """

    def __init__(self, backend: str = "OpenSlide", level: int = 0, **kwargs):
        super().__init__()
        self.backend = backend.lower()
        # NOTE(review): `require_pkg` presumably checks that the backend package is importable
        # before `_set_reader` runs -- confirm against `monai.utils.require_pkg`.
        func = require_pkg(self.backend)(self._set_reader)
        self.wsi_reader = func(self.backend)
        self.level = level
        self.kwargs = kwargs

    @staticmethod
    def _set_reader(backend: str):
        """
        Map a lower-cased backend name to the class/callable used to open a slide.

        Raises:
            ValueError: when `backend` is not one of "openslide", "cucim" or "tifffile".
        """
        if backend == "openslide":
            return OpenSlide
        if backend == "cucim":
            return CuImage
        if backend == "tifffile":
            return TiffFile
        raise ValueError("`backend` should be 'cuCIM', 'OpenSlide' or 'TiffFile'.")

    def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool:
        """
        Verify whether the specified file or files format is supported by WSI reader.

        Args:
            filename: file name or a list of file names to read.
                if a list of files, verify all the suffixes.
        """
        return is_supported_format(filename, ["tif", "tiff"])

    def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs):
        """
        Read image data from given file or list of files.

        Args:
            data: file name or a list of file names to read.
            kwargs: additional args for backend reading API in `read()`, will override `self.kwargs` for existing keys.
                more details in `cuCIM`, `TiffFile`, `OpenSlide`:
                https://github.com/rapidsai/cucim/blob/v21.12.00/cpp/include/cucim/cuimage.h#L100.
                https://github.com/cgohlke/tifffile.
                https://openslide.org/api/python/#openslide.OpenSlide.

        Returns:
            image object or list of image objects

        """
        img_: List = []

        filenames: Sequence[PathLike] = ensure_tuple(data)
        # per-call kwargs take precedence over the ones given at construction time
        kwargs_ = self.kwargs.copy()
        kwargs_.update(kwargs)
        for name in filenames:
            img = self.wsi_reader(name, **kwargs_)
            if self.backend == "openslide":
                # attach a (height, width, 3) `shape` attribute; `dimensions` is indexed as (width, height) here
                img.shape = (img.dimensions[1], img.dimensions[0], 3)
            img_.append(img)

        return img_ if len(filenames) > 1 else img_[0]

    def get_data(
        self,
        img,
        location: Tuple[int, int] = (0, 0),
        size: Optional[Tuple[int, int]] = None,
        level: Optional[int] = None,
        dtype: DtypeLike = np.uint8,
        grid_shape: Tuple[int, int] = (1, 1),
        patch_size: Optional[Union[int, Tuple[int, int]]] = None,
    ):
        """
        Extract regions as numpy array from WSI image and return them.

        Args:
            img: a single image object as returned by `read` for one file.
                The implementation below does not iterate over lists of images.
            location: top left pixel of the region in the level 0 reference frame (default=(0, 0)).
                The tuple is reversed before being handed to the backend's `read_region`, i.e. it is
                expected in the same (row, column) order as `size`.
            size: (height, width) tuple giving the region size (default to full image size).
                This is the size of image at the given level (`level`)
            level: the level number. When `None`, the `level` given at construction time is used.
            dtype: the data type of output image
            grid_shape: (row, columns) tuple define a grid to extract patches on that.
                It is only used when `patch_size` is provided.
            patch_size: (height, width) the size of extracted patches at the given level.
                A single int is expanded to a square patch. When `None`, the whole region is returned.

        Returns:
            A tuple `(patches, metadata)`. `patches` is the channel-first region, or the stack of patches
            extracted from it when `patch_size` is given. `metadata` holds "spatial_shape" (spatial shape of
            the extracted region, not of the individual patches) and "original_channel_dim" (-1).

        Raises:
            ValueError: when the requested level exceeds the maximum level of the image.
        """
        # Verify inputs
        if level is None:
            level = self.level
        max_level = self._get_max_level(img)
        if level > max_level:
            raise ValueError(f"The maximum level of this image is {max_level} while level={level} is requested!")

        # Extract a region or the entire image
        region = self._extract_region(img, location=location, size=size, level=level, dtype=dtype)

        # Add necessary metadata
        metadata: Dict = {}
        metadata["spatial_shape"] = np.asarray(region.shape[:-1])
        metadata["original_channel_dim"] = -1

        # Make it channel first
        region = EnsureChannelFirst()(region, metadata)

        # Split into patches
        if patch_size is None:
            patches = region
        else:
            tuple_patch_size = ensure_tuple_rep(patch_size, 2)
            patches = self._extract_patches(
                region, patch_size=tuple_patch_size, grid_shape=grid_shape, dtype=dtype  # type: ignore
            )

        return patches, metadata

    def _get_max_level(self, img_obj):
        """
        Return the index of the last (maximum) level in the whole slide image.

        Args:
            img_obj: the whole slide image object

        """
        if self.backend == "openslide":
            return img_obj.level_count - 1
        if self.backend == "cucim":
            return img_obj.resolutions["level_count"] - 1
        if self.backend == "tifffile":
            return len(img_obj.pages) - 1

    def _get_image_size(self, img, size, level, location):
        """
        Calculate the maximum region size for the given level and starting location (if size is None).
        Note that region size in OpenSlide and cuCIM are WxH (but the final image output would be HxW)

        Args:
            img: the whole slide image object (OpenSlide or cuCIM backend).
            size: requested (height, width) or `None` to use everything from `location` to the image border.
            level: the level at which the region is extracted.
            location: top left corner at level 0, in the same order as `size`.

        Returns:
            The region size in the reversed (width, height) order.

        Raises:
            ValueError: when `size` is `None` and the backend is neither "openslide" nor "cucim".
        """
        if size is not None:
            return size[::-1]

        if self.backend == "openslide":
            downsampling_factor = img.level_downsamples[level]
            max_size = img.level_dimensions[level]
        elif self.backend == "cucim":
            downsampling_factor = img.resolutions["level_downsamples"][level]
            max_size = img.resolutions["level_dimensions"][level]
        else:
            raise ValueError("The maximum region size can only be computed for cuCIM and OpenSlide backends.")

        # subtract the top left corner of the patch (at given level) from maximum size
        location_at_level = (round(location[1] / downsampling_factor), round(location[0] / downsampling_factor))
        return [full - offset for full, offset in zip(max_size, location_at_level)]

    def _extract_region(
        self,
        img_obj,
        size: Optional[Tuple[int, int]],
        location: Tuple[int, int] = (0, 0),
        level: int = 0,
        dtype: DtypeLike = np.uint8,
    ):
        """
        Read a region (or the entire image) at the given level and return it as an RGB numpy array.

        Args:
            img_obj: the whole slide image object.
            size: (height, width) of the region, or `None` for the full image. Must be `None` for TiffFile.
            location: top left corner at level 0. Must be `(0, 0)` for TiffFile.
            level: the level at which the region is extracted.
            dtype: the data type of the output array.

        Raises:
            ValueError: when `size` or a non-default `location` is given with the TiffFile backend.
        """
        if self.backend == "tifffile":
            # Read the entire image
            if size is not None:
                raise ValueError(
                    f"TiffFile backend reads the entire image only, so size '{size}' should not be provided! "
                    "For more flexibility or extracting regions, please use cuCIM or OpenSlide backend."
                )
            if location != (0, 0):
                raise ValueError(
                    f"TiffFile backend reads the entire image only, so location '{location}' should not be provided! "
                    "For more flexibility and extracting regions, please use cuCIM or OpenSlide backend."
                )
            region = img_obj.asarray(level=level)
        else:
            # Get region size to be extracted
            region_size = self._get_image_size(img_obj, size, level, location)
            # reverse the order of location's dimensions to become WxH (for cuCIM and OpenSlide)
            region_location = location[::-1]
            # Extract a region (or the entire image)
            region = img_obj.read_region(location=region_location, size=region_size, level=level)

        region = self.convert_to_rgb_array(region, dtype)
        return region

    def convert_to_rgb_array(self, raw_region, dtype: DtypeLike = np.uint8):
        """
        Convert to RGB mode and numpy array

        Args:
            raw_region: the region as returned by the backend.
            dtype: the data type of the output array.

        Returns:
            A 3-dimensional array whose last axis has exactly three channels.

        Raises:
            ValueError: when the array is not 3-dimensional or its last axis is neither 3 nor 4 long.
        """
        if self.backend == "openslide":
            # convert to RGB
            raw_region = raw_region.convert("RGB")

        # convert to numpy (if not already in numpy)
        raw_region = np.asarray(raw_region, dtype=dtype)

        # check if the image has three dimensions (2D + color)
        if raw_region.ndim != 3:
            raise ValueError(
                f"The input image dimension should be 3 but {raw_region.ndim} is given. "
                "`WSIReader` is designed to work only with 2D colored images."
            )

        # check if the color channel is 3 (RGB) or 4 (RGBA)
        if raw_region.shape[-1] not in [3, 4]:
            raise ValueError(
                f"There should be three or four color channels but {raw_region.shape[-1]} is given. "
                "`WSIReader` is designed to work only with 2D colored images."
            )

        # remove alpha channel if exist (RGBA)
        if raw_region.shape[-1] > 3:
            raw_region = raw_region[..., :3]

        return raw_region

    def _extract_patches(
        self,
        region: np.ndarray,
        grid_shape: Tuple[int, int] = (1, 1),
        patch_size: Optional[Tuple[int, int]] = None,
        dtype: DtypeLike = np.uint8,
    ):
        """
        Split a channel-first region into a grid of patches, each centered in its grid cell.

        Args:
            region: channel-first array; its trailing two axes are treated as the spatial axes.
            grid_shape: (row, columns) number of grid cells along the two spatial axes.
            patch_size: (height, width) of each patch; defaults to the grid cell size.
            dtype: the data type of the output array.

        Returns:
            `region` itself when there is nothing to split, otherwise an array of shape
            `(grid_shape[0] * grid_shape[1], 3, patch_size[0], patch_size[1])`.
        """
        if patch_size is None and grid_shape == (1, 1):
            return region

        n_patches = grid_shape[0] * grid_shape[1]
        region_size = region.shape[1:]

        if patch_size is None:
            patch_size = (region_size[0] // grid_shape[0], region_size[1] // grid_shape[1])

        # split the region into patches on the grid and center crop them to patch size
        flat_patch_grid = np.zeros((n_patches, 3, patch_size[0], patch_size[1]), dtype=dtype)
        # start offset of every patch along each spatial axis: cell center minus half a patch
        start_points = [
            np.round(region_size[i] * (0.5 + np.arange(grid_shape[i])) / grid_shape[i] - patch_size[i] / 2).astype(int)
            for i in range(2)
        ]
        # the second spatial axis is the outer loop, so consecutive patches advance along the first axis
        idx = 0
        for col_start in start_points[1]:
            for row_start in start_points[0]:
                row_end = row_start + patch_size[0]
                col_end = col_start + patch_size[1]
                flat_patch_grid[idx] = region[:, row_start:row_end, col_start:col_end]
                idx += 1

        return flat_patch_grid
0 commit comments