Thanks for sharing this great project!
```python
# Add warning that bptt_truncated_learning is forced to be true
# due to incomplete implementation of CUDA kernel for bptt_learning
#
# @TODO : remove this warning once the CUDA kernel, with state gradient, is implemented
if self.bptt_truncated_learning == False:
    print("====================================================================")
    print("[WARNING]: bptt_truncated_learning is set as true (was configured as false), due to incomplete implementation of CUDA kernel for bptt_learning")
    print("====================================================================")
    self.bptt_truncated_learning = True
```
https://github.com/RWKV/RWKV-infctx-trainer/blob/70d02c4997578a027d110e3acb03a523d3986448/RWKV-v6/src/model.py#L291C1-L300C1
Just to confirm: when doing TBPTT here, is the gradient estimator essentially the same as the one used in Transformer-XL, i.e. the recurrent state is carried forward across segments but gradients are stopped at segment boundaries?