Is the state gradient not yet implemented for the CUDA kernel? (Is that why bptt_truncated_learning is still forced to True?) #102

@shouldsee

Thanks for sharing this great project! RWKV-v6/src/model.py currently contains the following check:

        # Add warning that bptt_truncated_learning is forced to be true
        # due to incomplete implementation of CUDA kernel for bptt_learning
        #
        # @TODO : remove this warning once the CUDA kernel, with state gradient, is implemented
        if self.bptt_truncated_learning == False:
            print("====================================================================")
            print("[WARNING]: bptt_truncated_learning is set as true (was configured as false), due to incomplete implementation of CUDA kernel for bptt_learning")
            print("====================================================================")
            self.bptt_truncated_learning = True

https://github.com/RWKV/RWKV-infctx-trainer/blob/70d02c4997578a027d110e3acb03a523d3986448/RWKV-v6/src/model.py#L291C1-L300C1

Just to confirm: when doing TBPTT, is this essentially the same gradient estimator as the one used in Transformer-XL?
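A minimal sketch of the distinction, assuming a toy scalar linear recurrence rather than RWKV's actual state (the helpers segment_grad, full_bptt_grad and tbptt_grad are hypothetical, and this is not the trainer's CUDA kernel). In truncated BPTT the state value is carried from one segment to the next, but the gradient with respect to the incoming state is discarded at each boundary. That is roughly analogous to the stop-gradient Transformer-XL applies to its cached memory; full BPTT is the special case of a single segment spanning the whole sequence.

```python
# Toy scalar recurrence h_t = a * h_{t-1} + x_t with loss sum_t 0.5 * (h_t - y_t)^2.
# Hypothetical illustration only: not RWKV's state update or the trainer's CUDA kernel.

def segment_grad(a, xs, ys, h_start):
    """Forward and backward pass over one segment.

    Returns (loss, dL/da, dL/dh_start, h_end). dL/dh_start is what full BPTT
    would pass back to the previous segment; truncated BPTT drops it.
    """
    hs = [h_start]
    for x in xs:
        hs.append(a * hs[-1] + x)
    loss = sum(0.5 * (h - y) ** 2 for h, y in zip(hs[1:], ys))

    grad_a = 0.0
    grad_h = 0.0  # dL/dh_t, accumulated from later time steps
    for t in range(len(xs), 0, -1):
        grad_h += hs[t] - ys[t - 1]   # direct loss term at step t
        grad_a += grad_h * hs[t - 1]  # partial dh_t/da = h_{t-1}
        grad_h *= a                   # chain rule: dh_t/dh_{t-1} = a
    return loss, grad_a, grad_h, hs[-1]


def full_bptt_grad(a, xs, ys, h0=0.0):
    """Gradient flows through the whole sequence (one segment)."""
    return segment_grad(a, xs, ys, h0)[1]


def tbptt_grad(a, xs, ys, seg_len, h0=0.0):
    """Truncated BPTT: carry the state value forward, but stop its gradient
    at every segment boundary (the returned dL/dh_start is discarded)."""
    h = h0
    total = 0.0
    for i in range(0, len(xs), seg_len):
        _, g, _, h = segment_grad(a, xs[i:i + seg_len], ys[i:i + seg_len], h)
        total += g
    return total


if __name__ == "__main__":
    xs = [1.0, 0.5, -0.3, 0.8]
    ys = [0.0, 0.2, 0.1, -0.4]
    a = 0.9
    print("full BPTT gradient:   ", full_bptt_grad(a, xs, ys))
    print("truncated (seg_len=2):", tbptt_grad(a, xs, ys, seg_len=2))
```

With seg_len at least the sequence length the two gradients coincide; with seg_len=2 the truncated version omits the contribution that would flow back from steps 3 and 4 through the carried state.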