-
Notifications
You must be signed in to change notification settings - Fork 15
multi-outcome tabnet-fit and predict
#118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
switch target `y` from vector to array
manage loss cases for multi_output lint and refactor code
encode output_dim for multi-outcome improve multi-outcome classification loss split predict based on `is_multi_outcome`
switch to hardhat v1.3.0
fix tests vqlues add mixed-outcome and multi-outcome with valid test
move multi-output tests is a dedicated file
fix multi-outcome classification
tabnet-fit and predicttabnet-fit and predict
increase the GPU timeout
get rid of `library()` in tests remove typo in vignette
dfalbel
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome work @cregouby ! Looks great to me!
R/model.R
Outdated
| # if target is_multi_outcome, loss has to be applied to each label-group | ||
| if (max(batch$output_dim$shape) > 1) { | ||
| # TODO maybe torch_stack here would help loss$backward and better to shift right torch_sum at the end ? | ||
| outcome_nlevels <- as.numeric(batch$output_dim) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that batch$output_dim is placed on the model device, thus as.numeric() doesn't work if the model is on CUDA.
| outcome_nlevels <- as.numeric(batch$output_dim) | |
| outcome_nlevels <- as.numeric(batch$output_dim$to(device="cpu")) |
|
GPU tests are recovered now that they run in a larger instance. I think we should move the batch_dim to cpu before |
|
You might need to re-add the env check in the testthat.R file when submitting to CRAN, so tests doesn't run on CRAN machines where torch is usually not installed. |
Codecov Report
@@ Coverage Diff @@
## main #118 +/- ##
==========================================
+ Coverage 87.50% 88.02% +0.52%
==========================================
Files 10 10
Lines 1072 1127 +55
==========================================
+ Hits 938 992 +54
- Misses 134 135 +1
📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
tabnet_fitandpredict, relying on the fresh hardhat::spruce_*_multiple