diff --git a/syft/frameworks/torch/nn/conv.py b/syft/frameworks/torch/nn/conv.py index 8baf96abd8f..6d0fad14a3a 100644 --- a/syft/frameworks/torch/nn/conv.py +++ b/syft/frameworks/torch/nn/conv.py @@ -87,13 +87,9 @@ def forward(self, data): for i in range(0, rows - self.kernel_size + 1): for j in range(0, cols - self.kernel_size + 1): kernel_out = ( - ( - expanded_data[:, :, :, i : i + self.kernel_size, j : j + self.kernel_size] - * expanded_model - ) - .sum(3) - .sum(3) - ) + expanded_data[:, :, :, i : i + self.kernel_size, j : j + self.kernel_size] + * expanded_model + ).sum((3, 4)) kernel_results.append(kernel_out) pred = th.cat(kernel_results, axis=2).view( diff --git a/syft/frameworks/torch/nn/pool.py b/syft/frameworks/torch/nn/pool.py index 009b2cda9cb..1021e3ffd4f 100644 --- a/syft/frameworks/torch/nn/pool.py +++ b/syft/frameworks/torch/nn/pool.py @@ -62,7 +62,7 @@ def forward(self, data): for i in range(0, rows - self.kernel_size + 1, self.stride): for j in range(0, cols - self.kernel_size + 1, self.stride): kernel_out = ( - data[:, :, i : i + self.kernel_size, j : j + self.kernel_size].sum(2).sum(2) + data[:, :, i : i + self.kernel_size, j : j + self.kernel_size].sum((2, 3)) * self._one_over_kernel_size ) kernel_results.append(kernel_out.unsqueeze(2)) diff --git a/test/torch/nn/test_pool.py b/test/torch/nn/test_pool.py index a4b67e7edf3..edfaa3a7062 100644 --- a/test/torch/nn/test_pool.py +++ b/test/torch/nn/test_pool.py @@ -30,4 +30,4 @@ def test_pool2d(): out = pool(model_out) out2 = pool2(model_out) - assert th.isclose(out, out2, atol=1e-6).all() + assert th.eq(out, out2).all()