"""
Contact map analysis.
"""
# Maintainer: David W.H. Swenson (dwhs@hyperblazer.net)
# Licensed under LGPL, version 2.1 or greater
import collections
import itertools
import pickle
import json
import numpy as np
import pandas as pd
import mdtraj as md
from .contact_count import ContactCount
from .py_2_3 import inspect_method_arguments
# TODO:
# * switch to something where you can define the haystack -- the trick is to
#   replace the current mdtraj._compute_neighbors with something that
#   builds a voxel list for the haystack, and then checks the voxel for each
#   query atom. Nothing seems to do that now: neighbors doesn't use
#   voxels, and neighborlist doesn't limit the haystack.
def residue_neighborhood(residue, n=1):
    """Find the residues within ``n`` positions of ``residue`` in its chain.

    The returned neighborhood includes ``residue`` itself.

    Parameters
    ----------
    residue : mdtraj.Residue
        this residue
    n : positive int
        number of neighbors to find on each side of ``residue``

    Returns
    -------
    list of int
        indices ``idx`` with ``residue.index - n <= idx <= residue.index + n``
        that belong to the same chain as ``residue``, in ascending order
    """
    chain = set(res.index for res in residue.chain.residues)
    # range is ascending, so the result is returned sorted (a set-based
    # window would give an arbitrary, hash-dependent order)
    window = range(residue.index - n, residue.index + n + 1)
    # we could probably choose a faster approach here, but this is pretty
    # good, and it only gets run once per residue
    return [idx for idx in window if idx in chain]
def _residue_and_index(residue, topology):
    """Normalize a residue-or-index argument to a ``(residue, index)`` pair.

    Parameters
    ----------
    residue : object with an ``index`` attribute (e.g. mdtraj.Residue) or int
        the residue itself, or its index within ``topology``
    topology : mdtraj.Topology
        used to look up the residue object when only an index is given

    Returns
    -------
    tuple
        ``(residue_object, residue_index)``
    """
    try:
        index = residue.index
    except AttributeError:
        # no ``index`` attribute: the argument is the index itself
        return (topology.residue(residue), residue)
    return (residue, index)
def _atom_slice(traj, indices):
    """Mock MDTraj.atom_slice without rebuilding topology.

    Builds a new ``md.Trajectory`` holding only the coordinates of
    ``indices`` (C-contiguous copy), with copies of the time and (if
    present) unit cell data. The topology is copied but not rebuilt.

    Parameters
    ----------
    traj : mdtraj.Trajectory
        trajectory to slice
    indices : sequence of int
        atom indices to keep, in the order they should appear

    Returns
    -------
    mdtraj.Trajectory
        the sliced trajectory
    """
    if traj._have_unitcell:
        lengths = traj._unitcell_lengths.copy()
        angles = traj._unitcell_angles.copy()
    else:
        lengths = angles = None

    # Hackish to make the smart slicing work: the copied topology's private
    # atom list is replaced by the raw indices and its atom count is set to
    # match, so it agrees with the sliced coordinates without a rebuild.
    # NOTE(review): relies on mdtraj Topology internals (_atoms, _numAtoms).
    sliced_topology = traj.topology.copy()
    sliced_topology._atoms = indices
    sliced_topology._numAtoms = len(indices)

    return md.Trajectory(
        xyz=np.array(traj.xyz[:, indices], order='C'),
        topology=sliced_topology,
        time=traj._time.copy(),
        unitcell_lengths=lengths,
        unitcell_angles=angles,
    )
def _residue_for_atom(topology, atom_list):
    """Return the set of residues containing any atom in ``atom_list``.

    Parameters
    ----------
    topology : mdtraj.Topology
        topology used to look up each atom
    atom_list : iterable of int
        atom indices

    Returns
    -------
    set
        the distinct ``topology.atom(idx).residue`` values
    """
    return {topology.atom(idx).residue for idx in atom_list}
def _range_from_object_list(object_list):
    """Half-open ``(min, max + 1)`` range of the objects' ``index`` values.

    Objects must have an ``.index`` attribute (e.g., MDTraj Residue/Atom).
    An empty ``object_list`` raises ``ValueError`` (from ``min``).
    """
    indices = [item.index for item in object_list]
    lowest, highest = min(indices), max(indices)
    return (lowest, highest + 1)
class ContactsDict(object):
    """Dict-like view onto the atom or residue contacts of a contact object.

    Some algorithms can work with either ``atom_contacts`` or
    ``residue_contacts``. Instead of branching with if-statements, or
    building a real dictionary (and paying the time cost of generating
    both), this class gives dict-style access to whichever one is
    requested; the attribute is only read when the key is looked up.

    The keys ``'atom'`` and ``'atoms'`` return ``contacts.atom_contacts``;
    the keys ``'residue'``, ``'residues'`` and ``'res'`` return
    ``contacts.residue_contacts``. Any other key raises ``RuntimeError``.

    Parameters
    ----------
    contacts : :class:`.ContactObject`
        contact object with fundamental data
    """
    def __init__(self, contacts):
        self.contacts = contacts

    def __getitem__(self, atom_or_res):
        if atom_or_res in ("atom", "atoms"):
            return self.contacts.atom_contacts
        if atom_or_res in ("residue", "residues", "res"):
            return self.contacts.residue_contacts
        raise RuntimeError("Bad value for atom_or_res: " +
                           str(atom_or_res))
class ContactObject(object):
    """
    Generic object for contact map related analysis. Effectively abstract.

    Much of what we need to do the contact map analysis is the same for all
    analyses. It's in here.

    Parameters
    ----------
    topology : :class:`mdtraj.Topology`
        topology of the system
    query : list of int or None
        indices of the query atoms; ``None`` selects the atoms matching
        ``"not water and symbol != 'H'"``
    haystack : list of int or None
        indices of the haystack atoms; ``None`` uses the same default
        selection as ``query``
    cutoff : float
        cutoff distance for contacts, in nanometers
    n_neighbors_ignored : int
        number of neighbor residues (in same chain) to ignore
    """
    # Class default for use atom slice; None means decide automatically
    # (slice only when some atoms of the topology are not used)
    _class_use_atom_slice = None

    def __init__(self, topology, query, haystack, cutoff, n_neighbors_ignored):
        # all inits required: no defaults for abstract class!
        self._topology = topology
        default_selection = "not water and symbol != 'H'"
        if query is None:
            query = topology.select(default_selection)
        if haystack is None:
            haystack = topology.select(default_selection)
        # make things private and accessible through read-only properties so
        # they don't get accidentally changed after analysis
        self._cutoff = cutoff
        self._query = set(query)
        self._haystack = set(haystack)
        self._n_neighbors_ignored = n_neighbors_ignored
        # Sorted tuple of every atom used (for efficient lookup); an atom's
        # position in it is its index in an atom-sliced trajectory
        self._all_atoms = tuple(sorted(self._query | self._haystack))
        self._use_atom_slice = self._set_atom_slice()
        self._build_index_tables()
        self._atom_idx_to_residue_idx = self._set_atom_idx_to_residue_idx()

    def _build_index_tables(self):
        """Build the real/sliced atom index and atom-to-residue lookups.

        Requires ``_topology``, ``_haystack``, ``_all_atoms`` and
        ``_use_atom_slice`` to already be set.
        """
        # Conversion dict to go from real atom index to sliced index
        self._idx_to_s_idx_dict = {e: i
                                   for i, e in enumerate(self._all_atoms)}
        # Sliced haystack, and the haystack actually used in contact_map
        self._s_haystack = set(map(self.idx_to_s_idx, self._haystack))
        self._u_haystack = self._set_used_haystack()
        # Conversion dicts between the real and sliced atoms and their
        # residues
        self._r_atom_idx_to_residue_idx = {atom.index: atom.residue.index
                                           for atom in self.topology.atoms}
        self._s_atom_idx_to_residue_idx = {
            i: self._r_atom_idx_to_residue_idx[e]
            for i, e in enumerate(self._all_atoms)
        }

    def _restore_index_tables(self):
        """Rebuild the lookup tables that ``to_dict`` does not store.

        Used after ``from_dict``. Does nothing if the topology, haystack,
        or atom list is missing from the deserialized data.
        """
        required = ('_topology', '_haystack', '_all_atoms')
        if any(getattr(self, attr, None) is None for attr in required):
            return
        if getattr(self, '_use_atom_slice', None) is None:
            self._use_atom_slice = self._set_atom_slice()
        self._build_index_tables()
        if getattr(self, '_atom_idx_to_residue_idx', None) is None:
            self._atom_idx_to_residue_idx = \
                self._set_atom_idx_to_residue_idx()

    def _set_atom_slice(self):
        """Decide whether to slice trajectories down to the used atoms.

        A non-None class-level ``_class_use_atom_slice`` wins. Otherwise
        slicing is used only if some atoms of the topology are not in
        ``all_atoms`` (there is nothing to gain from slicing otherwise).
        """
        if self._class_use_atom_slice is not None:
            return self._class_use_atom_slice
        return len(self._all_atoms) < self._topology.n_atoms

    def _set_used_haystack(self):
        """Return the haystack (sliced or real indices) used in contact_map"""
        if self._use_atom_slice:
            return self._s_haystack
        else:
            return self._haystack

    def _set_atom_idx_to_residue_idx(self):
        """Return the atom-to-residue dict matching the index convention"""
        if self._use_atom_slice:
            return self._s_atom_idx_to_residue_idx
        else:
            return self._r_atom_idx_to_residue_idx

    def s_idx_to_idx(self, idx):
        """function to convert a sliced atom index back to real index"""
        if self._use_atom_slice:
            return self._all_atoms[idx]
        else:
            return idx

    def idx_to_s_idx(self, idx):
        """function to convert a real atom index to a sliced one"""
        if self._use_atom_slice:
            return self._idx_to_s_idx_dict[idx]
        else:
            return idx

    @property
    def contacts(self):
        """:class:`.ContactsDict` : contact dict for these contacts"""
        return ContactsDict(self)

    def __hash__(self):
        return hash((self.cutoff, self.n_neighbors_ignored,
                     frozenset(self._query), frozenset(self._haystack),
                     self.topology))

    def __eq__(self, other):
        """Equal when cutoff, ignored neighbors, query, haystack, and
        topology all match. Non-ContactObjects are not comparable."""
        if not isinstance(other, ContactObject):
            return NotImplemented
        # query/haystack are compared as sets: the public properties are
        # lists built from sets, and equal sets can iterate in different
        # orders (this also keeps __eq__ consistent with __hash__)
        return (self.cutoff == other.cutoff
                and self.n_neighbors_ignored == other.n_neighbors_ignored
                and set(self.query) == set(other.query)
                and set(self.haystack) == set(other.haystack)
                and self.topology == other.topology)

    def __ne__(self, other):
        # needed explicitly for Python 2, which does not derive != from ==
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def to_dict(self):
        """Convert object to a dict.

        Keys should be strings; values should be (JSON-) serializable.

        See also
        --------
        from_dict
        """
        # need to explicitly convert possible np.int64 to int in several
        # places: json cannot serialize numpy integer types
        dct = {
            'topology': self._serialize_topology(self.topology),
            'cutoff': self._cutoff,
            'query': [int(val) for val in self._query],
            'haystack': [int(val) for val in self._haystack],
            'all_atoms': tuple(int(val) for val in self._all_atoms),
            'n_neighbors_ignored': self._n_neighbors_ignored,
            'atom_idx_to_residue_idx': self._atom_idx_to_residue_idx,
            'atom_contacts':
                self._serialize_contact_counter(self._atom_contacts),
            'residue_contacts':
                self._serialize_contact_counter(self._residue_contacts),
            'use_atom_slice': self._use_atom_slice,
        }
        return dct

    @classmethod
    def from_dict(cls, dct):
        """Create object from dict.

        The input dict is not modified. Arguments of ``cls.__init__`` that
        are absent from ``dct`` are set to None. Index lookup tables that
        ``to_dict`` does not store are rebuilt when the topology, haystack,
        and atom list are available.

        Parameters
        ----------
        dct : dict
            dict-formatted serialization (see to_dict for details)

        See also
        --------
        to_dict
        """
        deserialization_helpers = {
            'topology': cls._deserialize_topology,
            'atom_contacts': cls._deserialize_contact_counter,
            'residue_contacts': cls._deserialize_contact_counter,
            'query': set,
            'haystack': set,
            # must stay an ordered sequence: s_idx_to_idx indexes into it
            'all_atoms': lambda atoms: tuple(sorted(atoms)),
            # JSON turns the int keys into strings; convert them back
            'atom_idx_to_residue_idx': lambda d: {int(k): d[k] for k in d},
        }
        # work on a copy so the caller's dict is left untouched
        dct = dict(dct)
        for key, helper in deserialization_helpers.items():
            if key in dct:
                dct[key] = helper(dct[key])
        kwarg_keys = inspect_method_arguments(cls.__init__)
        missing = set(kwarg_keys) - set(dct.keys())
        dct.update({k: None for k in missing})
        instance = cls.__new__(cls)
        for key, value in dct.items():
            setattr(instance, "_" + key, value)
        instance._restore_index_tables()
        return instance

    @staticmethod
    def _deserialize_topology(topology_json):
        """Create MDTraj topology from JSON-serialized version"""
        table, bonds = json.loads(topology_json)
        topology_df = pd.read_json(table)
        topology = md.Topology.from_dataframe(topology_df,
                                              np.array(bonds))
        return topology

    @staticmethod
    def _serialize_topology(topology):
        """Serialize MDTraj topology (to JSON)"""
        table, bonds = topology.to_dataframe()
        json_tuples = (table.to_json(), bonds.tolist())
        return json.dumps(json_tuples)

    # TODO: adding a separate object for these frozenset counters will be
    # useful for many things, and this serialization should be moved there
    @staticmethod
    def _serialize_contact_counter(counter):
        """JSON string from contact counter"""
        # have to explicitly convert to int because json doesn't know how to
        # serialize np.int64 objects, which we get in Python 3
        serializable = {json.dumps([int(val) for val in key]): counter[key]
                        for key in counter}
        return json.dumps(serializable)

    @staticmethod
    def _deserialize_contact_counter(json_string):
        """Contact counter from JSON string"""
        dct = json.loads(json_string)
        counter = collections.Counter({
            frozenset(json.loads(key)): dct[key] for key in dct
        })
        return counter

    def to_json(self):
        """JSON-serialized version of this object.

        See also
        --------
        from_json
        """
        dct = self.to_dict()
        return json.dumps(dct)

    @classmethod
    def from_json(cls, json_string):
        """Create object from JSON string

        Parameters
        ----------
        json_string : str
            JSON-serialized version of the object

        See also
        --------
        to_json
        """
        dct = json.loads(json_string)
        return cls.from_dict(dct)

    def _check_compatibility(self, other, err=AssertionError):
        """Check that ``other`` was built with the same parameters.

        Parameters
        ----------
        other : :class:`.ContactObject`
            object to compare against
        err : exception class
            exception type raised on incompatibility

        Returns
        -------
        bool
            True if the objects are compatible

        Raises
        ------
        err
            if cutoff, topology, query, haystack, or n_neighbors_ignored
            differ; the message lists every mismatch
        """
        compatibility_attrs = ['cutoff', 'topology', 'query', 'haystack',
                               'n_neighbors_ignored']
        # query/haystack are lists built from sets (arbitrary order), so
        # compare them as sets to avoid spurious mismatches
        unordered_attrs = set(['query', 'haystack'])
        failed_attr = {}
        for attr in compatibility_attrs:
            self_val = getattr(self, attr)
            other_val = getattr(other, attr)
            if attr in unordered_attrs:
                differs = set(self_val) != set(other_val)
            else:
                differs = self_val != other_val
            if differs:
                failed_attr[attr] = (self_val, other_val)

        if failed_attr:
            msg = "Incompatible ContactObjects:\n"
            for (attr, vals) in failed_attr.items():
                msg += " {attr}: {self} != {other}\n".format(
                    attr=attr, self=str(vals[0]), other=str(vals[1])
                )
            raise err(msg)
        return True

    def save_to_file(self, filename, mode="w"):
        """Save this object to the given file (using pickle).

        Parameters
        ----------
        filename : string
            the file to write to
        mode : 'w' or 'a'
            file writing mode. Use 'w' to overwrite, 'a' to append. Note
            that writing by bytes ('b' flag) is automatically added.

        See also
        --------
        from_file : load from generated file
        """
        with open(filename, mode+"b") as f:
            pickle.dump(self, f)

    @classmethod
    def from_file(cls, filename):
        """Load this object from a given file

        The file is read with :mod:`pickle`, which can execute arbitrary
        code: only load files from trusted sources.

        Parameters
        ----------
        filename : string
            the file to read from

        Returns
        -------
        :class:`.ContactObject`:
            the reloaded object

        See also
        --------
        save_to_file : save to a file
        """
        with open(filename, "rb") as f:
            reloaded = pickle.load(f)
        return reloaded

    def __sub__(self, other):
        """Return a ContactDifference of ``self`` (positive) and ``other``
        (negative)."""
        return ContactDifference(positive=self, negative=other)

    @property
    def cutoff(self):
        """float : cutoff distance for contacts, in nanometers"""
        return self._cutoff

    @property
    def n_neighbors_ignored(self):
        """int : number of neighbor residues (in same chain) to ignore"""
        return self._n_neighbors_ignored

    @property
    def query(self):
        """list of int : indices of atoms to include as query (unordered)"""
        return list(self._query)

    @property
    def haystack(self):
        """list of int : indices of atoms to include as haystack
        (unordered)"""
        return list(self._haystack)

    @property
    def all_atoms(self):
        """list of int: all atom indices used in the contact map (sorted)"""
        return list(self._all_atoms)

    @property
    def topology(self):
        """
        :class:`mdtraj.Topology` :
            topology object for this system

            The topology includes information about the atoms, how they are
            grouped into residues, and how the residues are grouped into
            chains.
        """
        return self._topology

    @property
    def use_atom_slice(self):
        """bool : Indicates if `mdtraj.atom_slice()` is used before calculating
        the contact map"""
        return self._use_atom_slice

    @property
    def residue_query_atom_idxs(self):
        """dict : maps query residue index to atom indices in query

        Atom indices are sliced indices when atom slicing is used.
        """
        result = collections.defaultdict(list)
        for atom_idx in self._query:
            residue_idx = self._r_atom_idx_to_residue_idx[atom_idx]
            result[residue_idx].append(self.idx_to_s_idx(atom_idx))
        return result

    @property
    def residue_ignore_atom_idxs(self):
        """dict : maps query residue index to atom indices to ignore

        The ignored atoms are those of the residues returned by
        :func:`residue_neighborhood` with ``n_neighbors_ignored``; with atom
        slicing they are restricted to used atoms and given as sliced
        indices.
        """
        all_atoms_set = set(self._all_atoms)
        result = {}
        for residue_idx in self.residue_query_atom_idxs.keys():
            residue = self.topology.residue(residue_idx)
            # Several steps to go residue indices -> atom indices
            ignore_residue_idxs = residue_neighborhood(
                residue,
                self._n_neighbors_ignored
            )
            # chain.from_iterable is linear; sum(lists, []) was quadratic
            ignore_atoms = itertools.chain.from_iterable(
                self.topology.residue(idx).atoms
                for idx in ignore_residue_idxs
            )
            result[residue_idx] = self._ignore_atom_idx(ignore_atoms,
                                                        all_atoms_set)
        return result

    def _ignore_atom_idx(self, atoms, all_atoms_set):
        """Indices of ``atoms``; with atom slicing, only the used atoms,
        converted to sliced indices."""
        result = set([atom.index for atom in atoms])
        if self._use_atom_slice:
            result &= all_atoms_set
            result = set(map(self.idx_to_s_idx, result))
        return result

    @property
    def haystack_residues(self):
        """set : residues for atoms in the haystack"""
        return _residue_for_atom(self.topology, self.haystack)

    @property
    def query_residues(self):
        """set : residues for atoms in the query"""
        return _residue_for_atom(self.topology, self.query)

    @property
    def haystack_residue_range(self):
        """(int, int): min and (max + 1) of haystack residue indices"""
        return _range_from_object_list(self.haystack_residues)

    @property
    def query_residue_range(self):
        """(int, int): min and (max + 1) of query residue indices"""
        return _range_from_object_list(self.query_residues)

    def most_common_atoms_for_residue(self, residue):
        """
        Most common atom contact pairs for contacts with the given residue

        Parameters
        ----------
        residue : Residue or int
            the Residue object or index representing the residue for which
            the most common atom contact pairs will be calculated

        Returns
        -------
        list :
            Atom contact pairs involving given residue, order of frequency.
            Referring to the list as ``l``, each element of the list
            ``l[e]`` consists of two parts: ``l[e][0]`` is a list containing
            the two MDTraj Atom objects that make up the contact, and
            ``l[e][1]`` is the measure of how often the contact occurs.
        """
        residue = _residue_and_index(residue, self.topology)[0]
        residue_atoms = set(atom.index for atom in residue.atoms)
        results = []
        for atoms, number in self.atom_contacts.most_common_idx():
            if atoms & residue_atoms:
                as_atoms = [self.topology.atom(a) for a in atoms]
                results.append((as_atoms, number))
        return results

    def most_common_atoms_for_contact(self, contact_pair):
        """
        Most common atom contacts for a given residue contact pair

        Parameters
        ----------
        contact_pair : length 2 list of Residue or int
            the residue contact pair for which the most common atom contact
            pairs will be calculated

        Returns
        -------
        list :
            Atom contact pairs for the residue contact pair, in order of
            frequency. Referring to the list as ``l``, each element of the
            list ``l[e]`` consists of two parts: ``l[e][0]`` is a list
            containing the two MDTraj Atom objects that make up the contact,
            and ``l[e][1]`` is the measure of how often the contact occurs.
        """
        contact_pair = list(contact_pair)
        res_1 = _residue_and_index(contact_pair[0], self.topology)[0]
        res_2 = _residue_and_index(contact_pair[1], self.topology)[0]
        atom_idxs_1 = set(atom.index for atom in res_1.atoms)
        atom_idxs_2 = set(atom.index for atom in res_2.atoms)
        # a set, not a list: membership is tested once per atom contact
        all_atom_pairs = set(
            frozenset(pair)
            for pair in itertools.product(atom_idxs_1, atom_idxs_2)
        )
        result = [([self.topology.atom(idx) for idx in contact[0]], contact[1])
                  for contact in self.atom_contacts.most_common_idx()
                  if frozenset(contact[0]) in all_atom_pairs]
        return result

    def slice_trajectory(self, trajectory):
        """Return ``trajectory`` restricted to ``all_atoms`` if needed.

        Slicing happens only when atom slicing is enabled and the trajectory
        has more atoms than ``all_atoms``; otherwise the input is returned.
        """
        # Prevent (memory) expensive atom slicing if not needed.
        # This check is also needed here because ContactFrequency slices the
        # whole trajectory before calling this function.
        if self.use_atom_slice and (len(self._all_atoms) <
                                    trajectory.topology.n_atoms):
            sliced_trajectory = _atom_slice(trajectory, self._all_atoms)
        else:
            sliced_trajectory = trajectory
        return sliced_trajectory

    def contact_map(self, trajectory, frame_number, residue_query_atom_idxs,
                    residue_ignore_atom_idxs):
        """
        Returns atom and residue contact maps for the given frame.

        Parameters
        ----------
        trajectory : mdtraj.Trajectory
            trajectory containing the frame (atom-sliced here if needed)
        frame_number : int
            index of the frame in ``trajectory`` to analyze
        residue_query_atom_idxs : dict
            maps query residue index to its query atom indices, as given by
            :attr:`residue_query_atom_idxs`
        residue_ignore_atom_idxs : dict
            maps query residue index to atom indices to exclude, as given
            by :attr:`residue_ignore_atom_idxs`

        Returns
        -------
        atom_contacts : collections.Counter
            frozenset atom-index pairs in contact, each counted once
            (sliced atom indices when atom slicing is used)
        residue_contact : collections.Counter
            frozenset residue-index pairs in contact, each counted once
        """
        used_trajectory = self.slice_trajectory(trajectory)
        neighborlist = md.compute_neighborlist(used_trajectory, self.cutoff,
                                               frame_number)
        used_haystack = self._u_haystack
        atom_to_residue = self._atom_idx_to_residue_idx
        contact_pairs = set()
        residue_pairs = set()
        for residue_idx in residue_query_atom_idxs:
            ignore_atom_idxs = set(residue_ignore_atom_idxs[residue_idx])
            for atom_idx in residue_query_atom_idxs[residue_idx]:
                # sets should make this fast, esp since neighbor_idxs
                # should be small and s-t is avg cost len(s)
                neighbor_idxs = set(neighborlist[atom_idx])
                contact_neighbors = neighbor_idxs - ignore_atom_idxs
                contact_neighbors = contact_neighbors & used_haystack
                # frozenset is unique key independent of order
                contact_pairs |= set(map(
                    frozenset,
                    itertools.product([atom_idx], contact_neighbors)
                ))
                local_residue_partners = set(atom_to_residue[a]
                                             for a in contact_neighbors)
                residue_pairs |= set(map(
                    frozenset,
                    itertools.product([residue_idx], local_residue_partners)
                ))
        atom_contacts = collections.Counter(contact_pairs)
        residue_contacts = collections.Counter(residue_pairs)
        return (atom_contacts, residue_contacts)

    def convert_atom_contacts(self, atom_contacts):
        """Convert atom contact keys from sliced to real atom indices.

        Returns ``atom_contacts`` unchanged when atom slicing is not used;
        otherwise a new Counter with the same values.
        """
        if not self._use_atom_slice:
            return atom_contacts
        return collections.Counter({
            frozenset(map(self.s_idx_to_idx, key)): value
            for key, value in atom_contacts.items()
        })

    @property
    def atom_contacts(self):
        """:class:`.ContactCount` : atom contact counts for this object"""
        n_atoms = self.topology.n_atoms
        return ContactCount(self._atom_contacts, self.topology.atom,
                            n_atoms, n_atoms)

    @property
    def residue_contacts(self):
        """:class:`.ContactCount` : residue contact counts for this object"""
        n_res = self.topology.n_residues
        return ContactCount(self._residue_contacts, self.topology.residue,
                            n_res, n_res)