Skip to content
Merged

add ndr #1239

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 158 additions & 14 deletions gui/datumaro_gui/components/single/tabs/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import streamlit_antd_components as sac
from datumaro_gui.utils.dataset.data_loader import SingleDatasetHelper
from datumaro_gui.utils.dataset.info import get_category_info, get_subset_info
from datumaro_gui.utils.drawing import Dashboard, Pie, Radar
from datumaro_gui.utils.drawing import Dashboard, Gallery, Pie, Radar
from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode
from streamlit import session_state as state
from streamlit_elements import elements
Expand Down Expand Up @@ -55,15 +55,13 @@ def info(self) -> str:

@staticmethod
def _do_label_remap(data_helper, grid_table, delete_unselected):
print(f"{__class__} called")
sel_row = grid_table["selected_rows"]
mapping_dict = {item["src"]: item["dst"] for item in sel_row}
default = "delete" if delete_unselected else "keep"
data_helper.transform("remap_labels", mapping=mapping_dict, default=default)
st.toast("Remap Success!", icon="🎉")

def gui(self, data_helper: SingleDatasetHelper):
print(f"{__class__} called")
dataset = data_helper.dataset()
stats_anns = data_helper.get_ann_stats()
labels: LabelCategories = dataset.categories().get(AnnotationType.label, LabelCategories())
Expand Down Expand Up @@ -138,13 +136,10 @@ def info(self) -> str:

@staticmethod
def _do_aggregation(data_helper, selected_subsets, dst_subset_name):
print(f"selected_subsets = {selected_subsets}, dst_subset_name={dst_subset_name}")
data_helper.aggregate(from_subsets=selected_subsets, to_subset=dst_subset_name)
st.toast("Aggregation Success!", icon="🎉")

def gui(self, data_helper: SingleDatasetHelper):
print(f"{__class__} called")

subsets = list(data_helper.dataset().subsets().keys())

c1, c2 = st.columns([0.3, 0.7])
Expand All @@ -166,7 +161,7 @@ def gui(self, data_helper: SingleDatasetHelper):
selected_subsets = st.multiselect(
"Select subsets to be aggregated", subsets, default=subsets
)
dst_subset_name = st.text_input("Aggreated Subset Name:", "default")
dst_subset_name = st.text_input("Aggregated Subset Name:", "default")
st.button(
"Do aggregation",
use_container_width=True,
Expand All @@ -193,7 +188,6 @@ def _add_subset(self):
idx = 0 # default is 'train'
default_names = tuple(split.subset for split in default_splits)
for split in reversed(state["subset"]):
print(split)
if split.subset in default_names:
idx = (default_names.index(split.subset) + 1) % len(default_names)
break
Expand All @@ -214,8 +208,6 @@ def _do_split(data_helper):
st.toast("Sum of ratios is expected to be 1!", icon="🚨")

def gui(self, data_helper: SingleDatasetHelper):
print(f"{__class__} called")

c1, c2 = st.columns(2)
c1.button("Add subset", use_container_width=True, on_click=self._add_subset)
c2.button(
Expand Down Expand Up @@ -612,6 +604,155 @@ def gui(self, data_helper: SingleDatasetHelper):
st.dataframe(self._get_df(summary), use_container_width=True, hide_index=True)


class TransformNDR(TransformBase):
@property
def name(self) -> str:
return "Near Duplicate Removal"

@property
def info(self) -> str:
return "This helps to remove near-duplicated images in a subset"

@staticmethod
def _correct_dataset(data_helper, selected_task):
try:
reports_src = data_helper.validate(selected_task)
data_helper.transform("correct", reports=reports_src)
reports_dst = data_helper.validate(selected_task)
state["correct-reports"] = {"src": reports_src, "dst": reports_dst}
st.toast("Correction Success!", icon="🎉")
except Exception as e:
st.toast(f"Error: {repr(e)}", icon="🚨")

@staticmethod
def _run_ndr(
data_helper, working_subset, duplicated_subset, num_cut, over_sample, under_sample
):
try:
result = data_helper.transform(
"ndr",
working_subset=working_subset,
duplicated_subset=duplicated_subset,
num_cut=num_cut,
over_sample=over_sample,
under_sample=under_sample,
)
try:
duplicated = result.get_subset(duplicated_subset)
except KeyError:
duplicated = []
if len(duplicated) > 0:
st.toast("NDR Success!", icon="🎉")
else:
st.toast("No duplication found!", icon="⚠️")
except Exception as e:
st.toast(f"Error: {repr(e)}", icon="🚨")

@staticmethod
def _remove_duplicated(data_helper, duplicated_subset):
try:
duplicated = data_helper.dataset().get_subset(duplicated_subset)
except KeyError:
st.toast("No items to remove!", icon="⚠️")
return

ids = []
for item in duplicated:
ids.append((item.id, duplicated_subset))

try:
data_helper.transform("remove_items", ids=ids)
st.toast("Removal Success!", icon="🎉")
except Exception as e:
st.toast(f"Error: {repr(e)}", icon="🚨")

def gui(self, data_helper: SingleDatasetHelper):
subsets = list(data_helper.dataset().subsets().keys())

c1, c2 = st.columns([0.3, 0.7])
with c1:
subset_info = get_subset_info(data_helper.dataset())
with elements("single-transform-ndr-subset"):
board = Dashboard()
w = SimpleNamespace(
dashboard=board,
subset_info=Pie(
name="Subset info",
**{"board": board, "x": 0, "y": 0, "w": 4, "h": 4, "minW": 3, "minH": 3},
),
)
with st.container():
with w.dashboard(rowHeight=100):
w.subset_info(subset_info)
with c2:
working_subset = st.selectbox("Select a subset to apply NDR:", subsets)
duplicated_subset = st.text_input("Subset name for the removed data:", "duplicated")
advanced_option = st.toggle("Advanced option")
if advanced_option:
num_items = len(data_helper.dataset().get_subset(working_subset))
num_cut = st.number_input(
"Maximum output dataset size:", value=min(100, num_items), step=1
)
over_sample = st.radio(
"Oversample Policy",
["random", "similarity"],
help="Specify the strategy when num_cut > length of the result after removal.",
captions=[
"Sample from removed data randomly",
"Select from removed data with ascending order of similarity",
],
)
under_sample = st.radio(
"Undersample Policy",
["uniform", "inverse"],
help="Specify the strategy when num_cut < length of the result after removal.",
captions=[
"Sample data with uniform distribution",
"Select data with reciprocal of the number",
],
)
else:
num_cut = None
over_sample = None
under_sample = None

st.button(
"Find Near Duplicate",
use_container_width=True,
on_click=self._run_ndr,
args=(
data_helper,
working_subset,
duplicated_subset,
num_cut,
over_sample,
under_sample,
),
)
try:
duplicated = data_helper.dataset().get_subset(duplicated_subset)
except KeyError:
duplicated = []

st.button(
"Remove Near Duplicate",
use_container_width=True,
on_click=self._remove_duplicated,
args=(data_helper, duplicated_subset),
disabled=len(duplicated) == 0,
)
if len(duplicated) > 0:
with elements("single-transform-ndr-duplicated"):
board = Dashboard()
w = SimpleNamespace(
dashboard=board,
player=Gallery(board, 0, 0, 8, 3, minH=3),
)
with st.container():
with w.dashboard(rowHeight=100):
w.player(duplicated, max_number=12, title=duplicated_subset)


class TransformCategory(NamedTuple):
type: str
transforms: tuple[TransformBase]
Expand All @@ -622,19 +763,23 @@ def on_click(transform: TransformBase):


def main():
print(f"{__file__} called")
data_helper: SingleDatasetHelper = state["data_helper"]
transform_categories = (
TransformCategory("Category Management", (TransformLabelRemap,)),
TransformCategory("Subset Management", (TransformAggregation, TransformSplit)),
TransformCategory(
"Item Management",
(TransformReindexing, TransformFiltration, TransformRemove, TransformAutoCorrection),
(
TransformReindexing,
TransformFiltration,
TransformRemove,
TransformAutoCorrection,
TransformNDR,
),
),
)
if "selected_transform" not in state or state["selected_transform"] is None:
state["selected_transform"] = transform_categories[0].transforms[0]()
print(state["selected_transform"])

c1, c2 = st.columns([0.3, 0.7])
with c1:
Expand All @@ -649,7 +794,6 @@ def main():
args=(transform,),
)
with c2:
print(f"transform->c2 called : {state['selected_transform']}")
transform = state["selected_transform"]
st.subheader(transform.name)
info_str = transform.info
Expand Down
4 changes: 2 additions & 2 deletions gui/datumaro_gui/utils/drawing/gallery.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Gallery(Dashboard.Item):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def __call__(self, dataset, max_number: int = 100):
def __call__(self, dataset, max_number: int = 100, title="Gallery"):
with mui.Paper(
key=self._key,
sx={
Expand All @@ -30,7 +30,7 @@ def __call__(self, dataset, max_number: int = 100):
):
with self.title_bar(padding="10px 15px 10px 15px", dark_switcher=False):
mui.icon.OndemandVideo()
mui.Typography("Gallery")
mui.Typography(title)

# Create a Streamlit Material-UI Box
with mui.Box(sx={"flex": 1, "minHeight": 0, "overflow": "auto"}):
Expand Down