When I run from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss, I get an error: ModuleNotFoundError: No module named 'xentropy_cuda_lib'. I can't find a xentropy_cuda_lib package through conda or pip. How do I install xentropy_cuda_lib?
Thanks