
Introducing a dask-based parallel backend for AnalysisBase.run() #4158

@marinegor

Description


Motivation

This project focuses on improving MDAnalysis speed and scalability. I am planning to implement a parallel backend for the MDAnalysis library. The backend will be implemented using the dask library, allowing users to seamlessly run their analyses either on powerful local machines or on various clusters, such as SLURM-managed ones. There was a proof-of-concept fork of the MDAnalysis library, pmda, which implemented this idea for a subset of the analysis methods available in MDAnalysis.

Basically, the idea of this proposal is to port the pmda approach into the current version of MDAnalysis.

Proposed solution

A key component of the MDAnalysis library is the AnalysisBase class, from which all objects that allow the user to run an analysis of a trajectory are inherited. Namely, it implements a run method that looks roughly like this:

def run(self, start=None, stop=None, step=None, frames=None, ...):
    self._setup_frames(self._trajectory, start=start, stop=stop,
                       step=step, frames=frames)

    self._prepare()
    for i, ts in enumerate(self._sliced_trajectory, ...):
        ...
        self._single_frame()
    self._conclude()

and consists of four steps (a minimal subclass illustrating these hooks is sketched after the list):

  • setting up the frames for reading – may include options to analyze only a part of the trajectory along the time coordinate (changing the time step, start/stop, etc.)
  • preparing for the analysis – may include allocating arrays that store intermediate data, etc.
  • running the analysis on a single frame
  • concluding the results – e.g. averaging some quantity calculated for each frame separately.
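
For reference, here is a minimal sketch of how a typical AnalysisBase subclass plugs into these hooks today; the class and attribute names (RadiusOfGyration, atomgroup) are illustrative and not part of this proposal:

# illustrative subclass: one value per frame, averaged at the end
import numpy as np
from MDAnalysis.analysis.base import AnalysisBase

class RadiusOfGyration(AnalysisBase):
    def __init__(self, atomgroup, **kwargs):
        super().__init__(atomgroup.universe.trajectory, **kwargs)
        self._atomgroup = atomgroup

    def _prepare(self):
        # allocate storage for per-frame intermediate data
        self.results.rgyr = np.zeros(self.n_frames)

    def _single_frame(self):
        # analyze the current frame; _frame_index is set by run()
        self.results.rgyr[self._frame_index] = self._atomgroup.radius_of_gyration()

    def _conclude(self):
        # combine per-frame values into the final result
        self.results.mean_rgyr = self.results.rgyr.mean()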

For a setup with multiple worker processes, this protocol requires an additional step: the trajectory is first split into blocks. Each block is processed by a separate worker process, and the results from different blocks are then potentially concluded separately:

def run(self, start=None, stop=None, step=None, frames=None, scheduler: Optional[str] = None):
    if scheduler is None:
        # fall back to the old serial behavior
        self._setup_frames(self._trajectory, start=start, stop=stop,
                           step=step, frames=frames)

        self._prepare()
        for i, ts in enumerate(self._sliced_trajectory, ...):
            ...
            self._single_frame()
        self._conclude()

    else:
        self._configure_scheduler(scheduler=scheduler)
        # split the trajectory into blocks according to the scheduler settings
        self._setup_blocks(start=start, stop=stop, step=step, frames=frames)

        tasks = []
        for block in self._blocks:
            # create separate sub-runs that fall back to the old serial
            # behavior on their block, and schedule them as dask tasks
            subrun = self.__class__(start=block.start, stop=block.stop,
                                    step=block.step, frames=block.frames,
                                    scheduler=None)
            tasks.append(dask.delayed(subrun.run)())

        # perform the dask computation and combine the per-block results
        res = dask.delayed(tasks).compute(**self._scheduler_params)
        self._parallel_conclude(res)
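
Independently of MDAnalysis, the dask pattern used above can be summarized in a few lines: wrap a per-block function with dask.delayed, build one task per block, and gather everything with a single compute call. Everything here (analyze_block, the squaring, the block count) is a stand-in; only the dask and numpy calls are real:

import dask
import numpy as np

def analyze_block(frames):
    # stand-in for subrun.run(): produce one value per frame in the block
    return [int(frame) ** 2 for frame in frames]

all_frames = np.arange(100)
blocks = np.array_split(all_frames, 4)    # e.g. 4 blocks for 4 workers

# one lazy task per block, executed only when compute() is called
tasks = [dask.delayed(analyze_block)(block) for block in blocks]
per_block = dask.delayed(tasks).compute(scheduler="processes")

# "parallel conclude": flatten the per-block results back into per-frame order
per_frame = np.concatenate(per_block)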

This requires introducing the following methods for the AnalysisBase class:

class AnalysisBase(object):
    def _configure_scheduler(self, scheduler=None, **params):
        # store the scheduler choice and its parameters (self._scheduler_params)
        ...

    @property
    def _blocks(self):
        ...

    def _setup_blocks(self, start=None, stop=None, step=None, frames=None):
        # will also update `self._blocks` accordingly
        ...

    def _parallel_conclude(self, results):
        ...
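
To make the block-splitting step concrete, here is one possible shape for _setup_blocks: resolve the requested frame selection to explicit frame indices and split them into contiguous chunks. This is purely illustrative; the Block container and the n_workers scheduler parameter are assumptions of this sketch, not part of the proposed interface:

# illustrative sketch only; Block and n_workers are hypothetical names
from collections import namedtuple

import numpy as np

Block = namedtuple("Block", ["start", "stop", "step", "frames"])

def _setup_blocks(self, start=None, stop=None, step=None, frames=None):
    # would live on AnalysisBase; resolve the selection to explicit frame indices
    if frames is None:
        start, stop, step = self._trajectory.check_slice_indices(start, stop, step)
        frames = np.arange(start, stop, step)
    n_workers = self._scheduler_params.get("n_workers", 1)  # hypothetical parameter
    # split the frame indices into contiguous, roughly equal blocks
    self._blocks = [Block(start=None, stop=None, step=None, frames=chunk)
                    for chunk in np.array_split(frames, n_workers)]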

This is similar to the protocol implemented in pmda. Such a modular design has the following advantages:

  • it reuses the previously existing code of the _prepare and _conclude methods of the subclasses, which will therefore run unchanged on sub-trajectories
  • it requires developers to introduce only a proper non-default _parallel_conclude method for sophisticated combination of results (see the sketch after this list)
  • it allows raising an exception for subclasses that for some reason do not allow such parallelization (e.g. they rely on results from previous frames when computing the current one), by re-implementing _configure_scheduler and raising whenever scheduler is not None.
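
For simple analyses that collect one value per frame, a non-default _parallel_conclude might just stitch the per-block results back together in block order. A minimal sketch, assuming each dask task returns its finished sub-run and that results are stored per frame as in the earlier RadiusOfGyration example (both assumptions, not a fixed interface):

import numpy as np

def _parallel_conclude(self, block_runs):
    # block_runs: finished sub-run instances returned by the dask tasks,
    # in block order (an assumption of this sketch)
    self.results.rgyr = np.concatenate([sub.results.rgyr for sub in block_runs])
    self.results.mean_rgyr = self.results.rgyr.mean()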

Alternatives

If you're looking into speeding up your MD trajectory analysis, I'd recommend looking into pmda. It is an older fork of MDAnalysis (pre-2.0) that explores the same idea I want to implement here, but only for a limited set of AnalysisBase subclasses.

Additional context

This issue is part of a GSoC project, and most of the code-related communication is supposed to happen here.
