Skip to content
Merged
9 changes: 9 additions & 0 deletions .all-contributorsrc
Original file line number Diff line number Diff line change
Expand Up @@ -2447,6 +2447,15 @@
"doc"
]
},
{
"login": "Cyril-Meyer",
"name": "Cyril Meyer",
"avatar_url": "https://avatars.githubusercontent.com/u/69190238?v=4",
"profile": "https://github.com/Cyril-Meyer",
"contributions": [
"test"
]
},
{
"login": "Moonzyyy",
"name": "Daniele Carli",
Expand Down
72 changes: 72 additions & 0 deletions aeon/networks/tests/test_all_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,75 @@ def test_all_networks_functionality(network):
)
else:
pytest.skip(f"{network.__name__} not to be tested since its a base class.")


@pytest.mark.parametrize("network", _networks)
def test_all_networks_params(network):
"""Test the functionality of all networks."""
input_shape = (100, 2)

if network.__name__ in ["BaseDeepLearningNetwork", "EncoderNetwork"]:
pytest.skip(f"{network.__name__} not to be tested since its a base class.")

if network._config["structure"] == "auto-encoder":
pytest.skip(
f"{network.__name__} not to be tested (AE networks have their own tests)."
)

if not (
_check_soft_dependencies(
network._config["python_dependencies"], severity="none"
)
and _check_python_version(network._config["python_version"], severity="none")
):
pytest.skip(
f"{network.__name__} dependencies not satisfied or invalid \
Python version."
)

# check with default parameters
my_network = network()
my_network.build_network(input_shape=input_shape)

# check with list parameters
params = dict()
for attrname in [
"kernel_size",
"n_filters",
"avg_pool_size",
"activation",
"padding",
"strides",
"dilation_rate",
"use_bias",
]:

# Exceptions to fix
if (
attrname in ["kernel_size", "padding"]
and network.__name__ == "TapNetNetwork"
):
continue
# LITENetwork does not seem to work with list args
if network.__name__ == "LITENetwork":
continue

# Here we use 'None' string as default to differentiate with None values
attr = getattr(my_network, attrname, "None")
if attr != "None":
if attr is None:
attr = 3
elif isinstance(attr, list):
attr = attr[0]
else:
if network.__name__ in ["ResNetNetwork"]:
attr = [attr] * my_network.n_conv_per_residual_block
elif network.__name__ in ["InceptionNetwork"]:
attr = [attr] * my_network.depth
else:
attr = [attr] * my_network.n_layers
params[attrname] = attr

if params:
my_network = network(**params)
my_network.build_network(input_shape=input_shape)