feat: implement nanamin/nanamax#30
Conversation
| if x.numel() == 0: | ||
| return torch.as_tensor(math.nan).to(x) |
There was a problem hiding this comment.
次元の処理がおかしい。dimやkeepdimをちゃんと見て要素が0個のときのshapeはsumとかと挙動を一緒にしてほしい。
There was a problem hiding this comment.
空テンソルの場合は、torch.sum で dim, keepdim に応じた正しい shape のテンソルを生成し、NaN を掛けて返すように修正しました。
| return torch.as_tensor(math.nan).to(x) | ||
|
|
||
| # 2. Build a mask for slices with at least one valid (non-NaN) element. | ||
| is_valid = (~x.isnan()).sum(dim=dim, keepdim=keepdim) > 0 |
There was a problem hiding this comment.
any に変更しました。ただし any(dim=()) は sum(dim=()) と異なり全次元を集約しないので、dim==() のときは引数なしの any() を呼ぶ分岐を追加しました。
| @@ -0,0 +1,577 @@ | |||
| import math | |||
There was a problem hiding this comment.
amax/aminなので複数の要素が最大/最小となったときの勾配に関するテストがほしいです。
There was a problem hiding this comment.
追加しました。その時、勾配が均等分配されることを、1D, 2D, NaNありの3条件で検証しました。
| """ | ||
| # 1. Handle empty tensor (amax raises RuntimeError for numel() == 0). | ||
| if x.numel() == 0: | ||
| return x.sum(dim=dim, keepdim=keepdim) * math.nan |
There was a problem hiding this comment.
torch/numpy動かしてみましたがいずれもRuntimeErrorだったのでRuntimeErrorにして良いと思います。
torch
RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.
numpy
ValueError: zero-size array to reduction operation maximum which has no identity
There was a problem hiding this comment.
x.numel() == 0 のときRuntimeErrorを出すようにしました。
エラーメッセージは torch.amax のものに合わせました。
|
|
||
| def nanamax( | ||
| x: torch.Tensor, | ||
| dim: typing.Union[int, typing.Tuple[int, ...]] = (), |
There was a problem hiding this comment.
torch.amaxを確認したらNoneを指定したときが全reduceなのでこの記述が正しそう。()はreduceしないという意味なのでちょっと違う。
| dim: typing.Union[int, typing.Tuple[int, ...]] = (), | |
| dim: typing.Union[None, int, typing.Tuple[int, ...]] = None, |
There was a problem hiding this comment.
いただいたsuggestion のとおり修正しました (nanamin.py も同様)。
type: ignore[arg-type] としたのは、dim のデフォルトを None に変更したことで amax(dim=dim) に None が渡りうるようになり、mypyで arg-type エラーが出たためです。
ignoreがよくなければ、dim is None でif文で分岐するように書く必要があると思いますが、そちらの方が良ければ修正します。
| not_nan = ~x.isnan() | ||
| is_valid = ( | ||
| not_nan.any() if dim == () else not_nan.any(dim=dim, keepdim=keepdim) | ||
| ) |
There was a problem hiding this comment.
| not_nan = ~x.isnan() | |
| is_valid = ( | |
| not_nan.any() if dim == () else not_nan.any(dim=dim, keepdim=keepdim) | |
| ) | |
| is_invalid = x.isnan().all(dim=dim, keepdim=keepdim) |
There was a problem hiding this comment.
suggestionの通り修正しました。
| ) | ||
|
|
||
| # 4. Restore NaN for all-NaN slices. | ||
| return torch.where(is_valid, y, torch.as_tensor(math.nan).to(y)) |
There was a problem hiding this comment.
| return torch.where(is_valid, y, torch.as_tensor(math.nan).to(y)) | |
| return torch.where(is_invalid, torch.as_tensor(math.nan).to(y), y) |
There was a problem hiding this comment.
suggestionの通り修正しました。
0130d11 to
eb5be04
Compare
Closes #28
Implement
nanamaxandnanamin— NaN-aware maximum/minimum functions that support multiple dimensions viatorch.amax/torch.amin.Motivation
nanmax/nanminonly accept a singleintfordimbecause they rely ontorch.max/torch.min(which return indices).nanamax/nanaminusetorch.amax/torch.amininstead, allowingdimto be anintor atupleofints at the cost of not returning indices.Changes
qfeval_functions/functions/nanamax.py— new function following thenansum.py3-line patternqfeval_functions/functions/nanamin.py— delegates tonanamaxvia sign negation (-nanamax(-x))qfeval_functions/functions/__init__.py— register both functionstests/functions/test_nanamax.py— 43 teststests/functions/test_nanamin.py— 44 testsTests
make test: 1390 passedmake format: cleanmake lint: clean (black, isort, pflake8, mypy)