"""
Implementation of ContactFrequency parallelization using dask.distributed
"""
from . import frequency_task
from .contact_map import ContactFrequency, ContactObject
from .contact_trajectory import ContactTrajectory
import mdtraj as md
def dask_run(trajectory, client, run_info):
"""
Runs dask version of ContactFrequency. Note that this API on this will
definitely change before the release.
Parameters
----------
trajectory : mdtraj.trajectory
client : dask.distributed.Client
path to dask scheduler file
run_info : dict
keys are 'trajectory_file' (trajectory filename), 'load_kwargs'
(additional kwargs passed to md.load), and 'parameters' (dict of
kwargs for the ContactFrequency object)
Returns
-------
:class:`.ContactFrequency` :
total contact frequency for the trajectory
"""
subtrajs = dask_load_and_slice(trajectory, client, run_info)
maps = client.map(frequency_task.map_task, subtrajs,
parameters=run_info['parameters'])
freq = client.submit(frequency_task.reduce_all_results, maps)
return freq.result()
def dask_load_and_slice(trajectory, client, run_info):
"""
Run dask to load a trajectory and make a list of subtraj futures that
can be used for the other analysis
Parameters
----------
trajectory : mdtraj.Trajectory
client : dask.distributed.Client
run_info : dict
Keys are 'trajectory_file' (trajectory filename) and 'load_kwargs'
(additional kwargs passed to md.load)
Returns
-------
list of Futures :
A list of futures for loaded subtrajectory
"""
slices = frequency_task.default_slices(n_total=len(trajectory),
n_workers=len(client.ncores()))
subtrajs = client.map(frequency_task.load_trajectory_task, slices,
file_name=run_info['trajectory_file'],
**run_info['load_kwargs'])
return subtrajs
class DaskContactTrajectory(ContactTrajectory):
"""Dask-based parallelization of contact trajectory.
The contact trajectory tracks all contacts of a trajectory, frame-by-frame.
See :class:`.ContactTrajectory` for details. This implementation
parallelizes the contact map calculations using
``dask.distributed``, which must be installed separately to use this
object.
Notes
-----
The interface for this object closely mimics that of the
:class:`.ContactTrajectory` object, with the addition requiring the
``dask.distributed.Client`` as input. However, there is one important
difference. Whereas :class:`.ContactTrajectory` takes an
``mdtraj.Trajectory`` object as input, :class:`.DaskContactTrajectory`
takes a file name, plus any extra kwargs that MDTraj needs to load the
file.
Parameters
----------
client : dask.distributed.Client
Client object connected to the dask network.
filename : str
Name of the file where the trajectory is located. File must be
accessible by all workers in the dask network.
query : list of int
Indices of the atoms to be included as query. Default ``None``
means all atoms.
haystack : list of int
Indices of the atoms to be included as haystack. Default ``None``
means all atoms.
cutoff : float
Cutoff distance for contacts, in nanometers. Default 0.45.
n_neighbors_ignored : int
Number of neighboring residues (in the same chain) to ignore.
Default 2.
"""
def __init__(self, client, filename, query=None, haystack=None,
cutoff=0.45, n_neighbors_ignored=2, **kwargs):
self.client = client
self.filename = filename
self.kwargs = kwargs
self.trajectory = md.load(filename, **kwargs)
self.contact_object = ContactObject(
self.trajectory.topology,
query=query,
haystack=haystack,
cutoff=cutoff,
n_neighbors_ignored=n_neighbors_ignored,
)
super(DaskContactTrajectory, self).__init__(
self.trajectory, query, haystack, cutoff, n_neighbors_ignored,
)
def _build_contacts(self, trajectory):
subtrajs = dask_load_and_slice(self.trajectory, self.client,
self.run_info)
contact_lists = self.client.map(frequency_task.contacts_per_frame_task,
subtrajs,
contact_object=self.contact_object)
# Return a generator for this to work out
gen = ((atom_contacts, residue_contacts)
for contacts in contact_lists
for atom_contacts, residue_contacts in contacts.result())
return gen
@property
def parameters(self):
return {'query': self.query,
'haystack': self.haystack,
'cutoff': self.cutoff,
'n_neighbors_ignored': self.n_neighbors_ignored}
@property
def run_info(self):
return {'parameters': self.parameters,
'trajectory_file': self.filename,
'load_kwargs': self.kwargs}