
Conversation


@cregouby cregouby commented Apr 30, 2023

  • add multi-outcome capability to tabnet_fit and predict, relying on the fresh hardhat::spruce_*_multiple

switch target `y` from vector to array
manage loss cases for multi_output
lint and refactor code
encode output_dim for multi-outcome
improve multi-outcome classification loss
split predict based on `is_multi_outcome`
fix test values
add mixed-outcome and multi-outcome with valid test
move multi-output tests into a dedicated file
@cregouby cregouby marked this pull request as draft April 30, 2023 19:30
fix multi-outcome classification
@cregouby cregouby changed the title multi-output tabnet-fit and predict multi-outcome tabnet-fit and predict May 2, 2023
@cregouby cregouby marked this pull request as ready for review May 7, 2023 22:06
@cregouby cregouby requested a review from dfalbel May 7, 2023 22:06
@cregouby cregouby closed this May 7, 2023
@cregouby cregouby deleted the feature/multilabel branch May 7, 2023 22:10
@cregouby cregouby restored the feature/multilabel branch May 7, 2023 22:11
@cregouby cregouby reopened this May 7, 2023

@dfalbel dfalbel left a comment


Awesome work @cregouby ! Looks great to me!

R/model.R Outdated
# if target is_multi_outcome, loss has to be applied to each label-group
if (max(batch$output_dim$shape) > 1) {
# TODO: maybe torch_stack here would help loss$backward, and it might be better to move torch_sum to the end?
outcome_nlevels <- as.numeric(batch$output_dim)
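As a side note, here is a minimal sketch of how a per-label-group loss could work, assuming the torch R package; `multi_outcome_loss`, `pred`, `y`, and `loss_fn` are illustrative names, not the package's actual code:

```r
library(torch)

# Sketch: apply one loss per label group of a multi-outcome prediction.
# `pred` is (batch, sum(outcome_nlevels)), `y` is (batch, n_outcomes),
# `loss_fn` stands in for the per-outcome loss (e.g. nnf_cross_entropy).
multi_outcome_loss <- function(pred, y, outcome_nlevels, loss_fn) {
  # split the prediction columns into one chunk per outcome
  chunks <- torch_split(pred, outcome_nlevels, dim = 2)
  losses <- lapply(seq_along(chunks), function(i) {
    loss_fn(chunks[[i]], y[, i])
  })
  # stacking then summing keeps a single graph for loss$backward()
  torch_stack(losses)$sum()
}
```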

It seems that batch$output_dim is placed on the model device, thus as.numeric() doesn't work if the model is on CUDA.

Suggested change
outcome_nlevels <- as.numeric(batch$output_dim)
outcome_nlevels <- as.numeric(batch$output_dim$to(device="cpu"))


dfalbel commented May 9, 2023

GPU tests are back now that they run on a larger instance. I think we should move batch$output_dim to the CPU before as.numeric(), and that should fix it!
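The failure mode is easy to reproduce; a quick sketch, assuming the torch R package (the `cuda_is_available()` guard keeps it runnable on CPU-only machines):

```r
library(torch)

x <- torch_tensor(c(2, 3))
if (cuda_is_available()) {
  x <- x$to(device = "cuda")
}

# as.numeric() cannot read a CUDA tensor directly;
# move it to the CPU first, then convert
outcome_nlevels <- as.numeric(x$to(device = "cpu"))
```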


dfalbel commented May 9, 2023

You might need to re-add the env check in the testthat.R file when submitting to CRAN, so tests don't run on CRAN machines where torch is usually not installed.

library(testthat)
library(tabnet)

if (Sys.getenv("TORCH_TEST", unset = 0) == 1)
  test_check("tabnet")

d90d749


codecov bot commented May 9, 2023

Codecov Report

Merging #118 (0bdd416) into main (162134c) will increase coverage by 0.52%.
The diff coverage is 97.22%.

❗ Current head 0bdd416 differs from pull request most recent head 3af6434. Consider uploading reports for the commit 3af6434 to get more accurate results

@@            Coverage Diff             @@
##             main     #118      +/-   ##
==========================================
+ Coverage   87.50%   88.02%   +0.52%     
==========================================
  Files          10       10              
  Lines        1072     1127      +55     
==========================================
+ Hits          938      992      +54     
- Misses        134      135       +1     
Impacted Files  | Coverage Δ
--------------- | -------------------------
R/plot.R        | 100.00% <ø> (ø)
R/tab-network.R | 100.00% <ø> (ø)
R/hardhat.R     | 88.78% <96.42%> (+0.34%) ⬆️
R/model.R       | 94.75% <97.40%> (+0.69%) ⬆️
R/pretraining.R | 95.45% <100.00%> (ø)


@cregouby cregouby merged commit e5c5306 into main May 9, 2023
@cregouby cregouby deleted the feature/multilabel branch May 9, 2023 18:28