Skip to content
Merged

cleanup #1688

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions pytraj/actions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,31 @@ def __new__(cls, value):
from ..core.c_core import CpptrajState, Command


def add_reference_dataset(dslist, name, frame, topology=None):
"""Helper function to add reference dataset consistently

Parameters
----------
dslist : CpptrajDatasetList
Dataset list to add reference to
name : str
Name for the reference dataset
frame : Frame
Reference frame
topology : Topology, optional
Topology for the reference, defaults to frame.top

Returns
-------
dataset : Dataset
The created reference dataset
"""
dataset = dslist.add('reference', name)
dataset.top = topology or getattr(frame, 'top', None)
dataset.add_frame(frame)
return dataset


class CommandBuilder:

def __init__(self):
Expand Down Expand Up @@ -147,6 +172,13 @@ def add_dataset(self, dataset_type, dataset_name, data, aspect=None):
else:
dataset.data = np.asarray(data).astype('f8')

def add_reference(self, name, frame, topology=None):
"""Convenience method to add reference dataset"""
dataset = self.datasets.add(DatasetType.REFERENCE, name)
dataset.top = topology or getattr(frame, 'top', None)
dataset.add_frame(frame)
return dataset

def run_analysis(self, command):
self.analysis(command, dslist=self.datasets)
return self.datasets
Expand Down
12 changes: 2 additions & 10 deletions pytraj/actions/correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,8 @@ def atomiccorr(traj,
if byres:
command += " byres"

c_dslist = CpptrajDatasetList()
action = c_action.Action_AtomicCorr()
action.read_input(command, top=traj.top, dslist=c_dslist)
action.setup(traj.top)

for frame in traj:
action.compute(frame)

action.post_process()
return get_data_from_dtype(c_dslist, dtype=dtype)
action_datasets, _ = do_action(traj, command, c_action.Action_AtomicCorr)
return get_data_from_dtype(action_datasets, dtype=dtype)


def timecorr(vec0,
Expand Down
8 changes: 3 additions & 5 deletions pytraj/actions/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Geometric analysis functions: distances, angles, dihedrals
"""
from .base import *
from .base import add_reference_dataset

__all__ = [
'distance', 'pairwise_distance', 'angle', 'dihedral', 'mindist',
Expand Down Expand Up @@ -616,17 +617,14 @@ def dihedral_rms(traj=None,

if ref is not None:
ref_frame = get_reference(traj, ref)
ref_dataset = action_datasets.add(DatasetType.REFERENCE_FRAME,
name=ref_name)
ref_dataset.top = ref_frame.top or traj.top
ref_dataset.add_frame(ref_frame)
add_reference_dataset(action_datasets, ref_name, ref_frame, ref_frame.top or traj.top)

action_datasets, _ = do_action(traj,
command,
c_action.Action_DihedralRMS,
dslist=action_datasets)

if ref is not None:
action_datasets._pop(0)
action_datasets.remove_at(0)

return get_data_from_dtype(action_datasets, dtype=dtype)
2 changes: 1 addition & 1 deletion pytraj/actions/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def projection(traj,
command = f"evecs {mode_name} {mask} beg 1 end {n_vectors}"
projection_action(command, traj, top=top, dslist=action_datasets)

action_datasets._pop(0)
action_datasets.remove_at(0)

return get_data_from_dtype(action_datasets, dtype=dtype)

Expand Down
24 changes: 4 additions & 20 deletions pytraj/actions/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,8 @@ def surf(traj=None, mask="", dtype='ndarray', frame_indices=None, top=None):
>>> traj = pt.datafiles.load_tz2_ortho()
>>> data = pt.surf(traj, '@CA')
"""
action = c_action.Action_Surf()
c_dslist = CpptrajDatasetList()
action.read_input(mask, top=traj.top, dslist=c_dslist)
action.setup(traj.top)

for frame in traj:
action.compute(frame)

action.post_process()
return get_data_from_dtype(c_dslist, dtype=dtype)
action_datasets, _ = do_action(traj, mask, c_action.Action_Surf)
return get_data_from_dtype(action_datasets, dtype=dtype)


@super_dispatch()
Expand Down Expand Up @@ -217,16 +209,8 @@ def volume(traj=None, mask="", top=None, dtype='ndarray', frame_indices=None):
>>> traj = pt.datafiles.load_tz2_ortho()
>>> vol = pt.volume(traj, '@CA')
"""
action = c_action.Action_Volume()
c_dslist = CpptrajDatasetList()
action.read_input(mask, top=traj.top, dslist=c_dslist)
action.setup(traj.top)

for frame in traj:
action.compute(frame)

action.post_process()
return get_data_from_dtype(c_dslist, dtype=dtype)
action_datasets, _ = do_action(traj, mask, c_action.Action_Volume)
return get_data_from_dtype(action_datasets, dtype=dtype)


@super_dispatch()
Expand Down
28 changes: 9 additions & 19 deletions pytraj/actions/topology_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Topology manipulation functions: centering, alignment, imaging, etc.
"""
from .base import *
from .base import add_reference_dataset
# Ensure _assert_mutable is available
from .base import _assert_mutable, _ensure_mutable
from ..builder.build import make_structure
Expand Down Expand Up @@ -177,9 +178,7 @@ def align(traj,
reference_topology = traj.top

action_datasets = CpptrajDatasetList()
action_datasets.add(DatasetType.REFERENCE, name=reference_name)
action_datasets[0].top = reference_topology
action_datasets[0].add_frame(ref)
add_reference_dataset(action_datasets, reference_name, ref, reference_topology)

align_action = c_action.Action_Align()
align_action.read_input(command, top=top, dslist=action_datasets)
Expand All @@ -190,7 +189,7 @@ def align(traj,
align_action.post_process()

# remove ref
action_datasets._pop(0)
action_datasets.remove_at(0)

return traj

Expand Down Expand Up @@ -322,13 +321,8 @@ def closest(traj=None,

return traj_mut
else:
c_dslist = CpptrajDatasetList()
action = c_action.Action_Closest()
action.read_input(command, top=traj.top, dslist=c_dslist)
action.setup(traj.top)

_closest_iter(c_action, traj)
return get_data_from_dtype(c_dslist, dtype=dtype)
action_datasets, _ = do_action(traj, command, c_action.Action_Closest)
return get_data_from_dtype(action_datasets, dtype=dtype)


@super_dispatch()
Expand Down Expand Up @@ -774,25 +768,21 @@ def atom_map(traj, ref, rmsfit=False):
command = ' '.join(('my_target my_ref', options))
dataset_list = CpptrajDatasetList()

target = dataset_list.add('reference', name='my_target')
target.top = traj.top
target.append(traj[0])
add_reference_dataset(dataset_list, 'my_target', traj[0], traj.top)

refset = dataset_list.add('reference', name='my_ref')
refset.top = ref.top if ref.top is not None else traj.top
if not isinstance(ref, Frame):
ref_frame = ref[0]
else:
ref_frame = ref
refset.append(ref_frame)
add_reference_dataset(dataset_list, 'my_ref', ref_frame, ref.top if ref.top is not None else traj.top)

with capture_stdout() as (out, err):
act(command, traj, top=traj.top, dslist=dataset_list)
act.post_process()

# free memory of two reference
dataset_list._pop(0)
dataset_list._pop(0)
dataset_list.remove_at(0)
dataset_list.remove_at(0)

return (out.read(), get_data_from_dtype(dataset_list, dtype='ndarray'))

Expand Down
65 changes: 16 additions & 49 deletions pytraj/actions/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Utility and miscellaneous functions
"""
from .base import *
from .base import _assert_mutable
from .base import _assert_mutable, add_reference_dataset
from ..trajectory.trajectory_iterator import TrajectoryIterator

__all__ = [
Expand Down Expand Up @@ -444,11 +444,9 @@ def native_contacts(traj=None,
condition=include_solvent).add("byresidue", condition=byres).add(
options, condition=bool(options)).build())

action_datasets.add(DatasetType.REFERENCE_FRAME, 'myframe')
action_datasets[0].top = top
action_datasets[0].add_frame(ref)
add_reference_dataset(action_datasets, 'myframe', ref, top)
native_contacts_action(command, traj, top=top, dslist=action_datasets)
action_datasets._pop(0)
action_datasets.remove_at(0)

return get_data_from_dtype(action_datasets, dtype=dtype)

Expand All @@ -470,16 +468,8 @@ def grid(traj=None, command="", top=None, dtype='dataset'):
-------
out : DatasetList
"""
c_dslist = CpptrajDatasetList()
action = c_action.Action_Grid()
action.read_input(command, top=traj.top, dslist=c_dslist)
action.setup(traj.top)

for frame in traj:
action.compute(frame)

action.post_process()
return get_data_from_dtype(c_dslist, dtype=dtype)
action_datasets, _ = do_action(traj, command, c_action.Action_Grid, top=top)
return get_data_from_dtype(action_datasets, dtype=dtype)


def transform(traj, by, frame_indices=None):
Expand Down Expand Up @@ -588,17 +578,8 @@ def lipidscd(traj, mask='', options='', dtype='dict', top=None):
out : dict or DatasetList
"""
command = mask + " " + options

c_dslist = CpptrajDatasetList()
action = c_action.Action_LipidOrder()
action.read_input(command, top=traj.top, dslist=c_dslist)
action.setup(traj.top)

for frame in traj:
action.compute(frame)

action.post_process()
return get_data_from_dtype(c_dslist, dtype=dtype)
action_datasets, _ = do_action(traj, command, c_action.Action_LipidOrder, top=top)
return get_data_from_dtype(action_datasets, dtype=dtype)


@super_dispatch()
Expand Down Expand Up @@ -628,22 +609,13 @@ def xtalsymm(traj, mask='', options='', ref=None, **kwargs):

if ref is not None:
ref_frame = get_reference(traj, ref)
ref_dataset = c_dslist.add('reference', 'ref')
ref_dataset.top = ref_frame.top or traj.top
ref_dataset.add_frame(ref_frame)
add_reference_dataset(c_dslist, 'ref', ref_frame, ref_frame.top or traj.top)
command += " ref ref"

action = c_action.Action_XtalSymm()
action.read_input(command, top=traj.top, dslist=c_dslist)
action.setup(traj.top)

for frame in traj:
action.compute(frame)

action.post_process()
do_action(traj, command, c_action.Action_XtalSymm, dslist=c_dslist)

if ref is not None:
c_dslist._pop(0) # remove reference
c_dslist.remove_at(0) # remove reference

return c_dslist

Expand All @@ -667,7 +639,7 @@ def analyze_modes(mode_type,
command = ' '.join((mode_type, 'name {}'.format(my_modes), options))
runner.run_analysis(command)

runner.datasets._pop(0)
runner.datasets.remove_at(0)
return get_data_from_dtype(runner.datasets, dtype=dtype)


Expand Down Expand Up @@ -712,14 +684,9 @@ def hausdorff(matrix, options='', dtype='ndarray'):
"""
runner = AnalysisRunner(c_analysis.Analysis_Hausdorff)
runner.add_dataset(DatasetType.MATRIX_DBL, "my_matrix", matrix)

command = f"my_matrix {options}"
runner.run_analysis(command)

runner.datasets._pop(0)

data = get_data_from_dtype(runner.datasets, dtype)
return data
runner.run_analysis(f"my_matrix {options}")
runner.datasets.remove_at(0) # Remove input matrix
return get_data_from_dtype(runner.datasets, dtype)


def permute_dihedrals(traj, filename, options=''):
Expand Down Expand Up @@ -753,8 +720,8 @@ def permute_dihedrals(traj, filename, options=''):
with Command() as executor:
executor.dispatch(state, command)

state.data._pop(0)
state.data._pop(0)
state.data.remove_at(0)
state.data.remove_at(0)


def check_structure(traj,
Expand Down
14 changes: 3 additions & 11 deletions pytraj/actions/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def volmap(traj,
if volume_ds.key.endswith("[totalvol]"):
index = i
if index is not None:
action_datasets._pop(index)
action_datasets.remove_at(index)
return get_data_from_dtype(action_datasets, dtype)


Expand Down Expand Up @@ -339,16 +339,8 @@ def gist(traj,
if solvent_mask:
command = f"solvent {solvent_mask} " + command

c_dslist = CpptrajDatasetList()
action = c_action.Action_Gist()
action.read_input(command, top=traj.top, dslist=c_dslist)
action.setup(traj.top)

for frame in traj:
action.compute(frame)

action.post_process()
return get_data_from_dtype(c_dslist, dtype=dtype)
action_datasets, _ = do_action(traj, command, c_action.Action_Density)
return get_data_from_dtype(action_datasets, dtype=dtype)


def _grid(traj,
Expand Down
Loading
Loading