Use KDTree instead of mdtraj.compute_contacts in _filter_unphysical_traj_masks to mitigate MemoryError and improve performance#158
Conversation
|
@microsoft-github-policy-service agree |
|
Thank you for your contribution @ahmedselim2017 ! We'll review this asap. |
|
I am guessing that for small systems, computing all pairs is more efficient and this approach wins beyond a certain system size. Can we test for a couple of systems of different sizes, and implement a switchover if this assumption is correct? |
|
Thanks! For a simple benchmark you can just run repeats of "GGS", which is extremely flexible, and will also produce a lot of clashes for longer chains. Maybe 50 and 100 residues will be informative - looking at these plots I'm getting the impression that 100 residues may be a good switchover point. Which GPU are you using? |
|
No problem! For the other proteins, I have used A6000 but as it has some other queued jobs, I am using RTX4080 for 50 and 100 residue test runs. |
|
The MD equilibrium phase of the GGS repeated sequences did not progressed 1 step in >~10 hours probably because of the unnatural sequence. I will restart the runs with real proteins but it will probably take a similar time. |
|
Oh, that's weird. Any idea why that may happen @josejimenezluna ? I think the code shouldn't just get stuck when encountering a weird sequence. In case if this is a problem of not finding an alignment at least an error message should be returned. Can we create a separate issue for this? |
|
The sidechain generation was crashed because of some issues with the |
|
Hi again, I have ran the same benchmark with 50 and 100 residue proteins and as you have predicted for 100 residues or less KDTree was slower so we can use 100 residues as a cutoff.
By the way I have came across the same issue in other runs that uses |
|
Sounds good! Leaving this to reviewed by @josejimenezluna after his return to office. |
|
Hi, hope you are doing well! Are there any updates about the review? If you like, I can run some more tests but 100 residues looks like a good lower bound cutoff as |
|
Thank you for the great work @ahmedselim2017. I wanted to add a test so I have made a branch from yours, where I added a test checking that the kdtree and mdtraj versions give the same answers for some junk coordinates. The branch is here: https://github.com/microsoft/bioemu/tree/sarahlewis/kdtree-tests. Could you please check if you're happy with those changes, if so, incorporate them in your branch, and I'll approve and merge your PR? I don't want to just make a PR directly from my branch because it feels like taking credit for your work. One small thing I noticed (because of the hacky way I constructed the test trajectory) is that if a trajectory's 'time' attribute is too short, 'enumerate(traj)' does not iterate through all the frames. So I have changed your kdtree implementation to use 'range(len(traj))' instead. |
|
Of course! Thanks for the feedback, added your changes. |
sarahnlewis
left a comment
There was a problem hiding this comment.
Thanks for the nice work



Hi, first of all thanks for the project!
I was running bioemu for a structure with 994 residues for 10,000 samples. The model and side chain generation ran successfully but I got
MemoryError: Unable to allocate 445. GiB for an array with shape (10000, 11954402) and data type float32in thesave_pdb_and_xtcfunction.After some digging, I found that the error was caused by the
mdtraj.compute_contactsfunction that is used to find clashes in the_filter_unphysical_traj_masksfunction. As themdtraj.compute_contactsfunction works by first calculating the distances between atoms for each frame of the trajectory (ignoring i+1th and i+2th residues for residue i), it tries to create a huge array with a shape(10000, 11954402). To avoid creating this huge array, I processed each frame one by one.Also, since we are not directly interested in the distances between all atom pairs but whether if there are any clashes, I have used KDTree to check if there are any clashes without directly computing the distance matrices, which improved the performance of finding clashes dramatically as you can see from the figure below where I found the clashes for the first N frames of the original trajectory. The
mdtraj.compute_contactscan only find clashes up to 4000 frames as it gaves memory errors after that.