"""Plugin for automated WE hyperparameter optimization."""
import msm_we.msm_we
import westpa
from westpa.core import extloader
from westpa.cli.core import w_run
from msm_we import optimization
import numpy as np
import pickle
from rich.progress import Progress
import ray
from synd.westpa.propagator import get_segment_parent_index
from westpa.core.data_manager import create_dataset_from_dsopts
@ray.remote
class GlobalModelActor:
    """
    Ray actor that keeps one copy of the haMSM model (and the SynD backmapper) resident in memory.

    PcoordCalculator actors fetch the model from this actor, and ask it for the original
    progress coordinate of a discrete state.
    """

    def __init__(self, model, processCoordinates, synd_model, original_pcoord_ndim):
        # Install the featurization function on the modelWE class inside this actor's process;
        # the model instance is expected to expose it (checked by the assert below).
        msm_we.msm_we.modelWE.processCoordinates = processCoordinates

        self.model = model
        assert hasattr(self.model, "processCoordinates")

        # Number of leading components that make up the original (non-extended) pcoord.
        self.original_pcoord_ndim = original_pcoord_ndim
        self.backmap = synd_model.backmap

    def get_model(self):
        """Return the model held by this actor."""
        return self.model

    def get_original_pcoord(self, state_index):
        """
        Return the original progress coordinate for a discrete state.

        This is the first ``original_pcoord_ndim`` entries of ``synd_model.backmap(state_index)``.
        """
        full_coords = self.backmap(state_index)
        return full_coords[: self.original_pcoord_ndim]
@ray.remote
class PcoordCalculator:
    """
    Ray actor that computes the extended progress coordinate of a structure.

    The extended pcoord is the original progress coordinate followed by the
    dimensionality-reduced MSM features of the structure.
    """

    def __init__(self, model_actor, processCoordinates):
        # Install the featurization function on the modelWE class in this actor's process too.
        msm_we.msm_we.modelWE.processCoordinates = processCoordinates
        self.model_actor = model_actor

        # Pull the model from the shared GlobalModelActor once, at construction time.
        model_ref = self.model_actor.get_model.remote()
        self.model = ray.get(model_ref)

    def compute_new_structure_pcoord(self, structure, state_index):
        """
        Compute the extended pcoord for one structure.

        Parameters
        ----------
        structure: A single full-coordinate structure.
        state_index: The discrete state the structure belongs to.

        Returns
        -------
        A tuple ``(new_pcoord, state_index)``, so callers can match results to states.
        """
        # Only one structure is passed in, so keep just the first reduced row.
        reduced_coords = self.model.reduceCoordinates(structure)[0]

        original_ref = self.model_actor.get_original_pcoord.remote(state_index)
        original_pcoord = ray.get(original_ref)

        extended = np.concatenate([original_pcoord, reduced_coords])
        return extended, state_index
class OptimizationDriver:
    """
    WESTPA plugin to automatically handle performing optimization.

    Using an haMSM, updates binning and allocation according to user-specified optimization algorithms.
    An OptimizedBinMapper is constructed from the optimized binning and allocation, and WE is continued with the new
    mapper.

    Can be used by including the following entries in your west.cfg::

        west:
          plugins:
          - plugin: msm_we.westpa_plugins.optimization_driver.OptimizationDriver
            full_coord_map: A pickled dictionary mapping discrete states to full-coordinate structures
            max_iters: Number of total iterations. WE will run for west.propagation.max_total_iterations, perform
                optimization, and continue for another west.propagation.max_total_iterations, up to this value.

            # The following parameters are optional, and provided as an example.
            binning_strategy: An arbitrary python function defining a bin optimization scheme.
                Takes in an msm_we.msm_we.modelWE and returns an array-like with one entry per MSM microstate
                (excluding the basis and target states), where each element is the index of the WE bin that
                MSM state will be assigned to by the OptimizedMapper.
            allocation_strategy: An arbitrary python function defining an allocation optimization scheme.
                Takes in an msm_we.msm_we.modelWE and returns an array of integer walker allocations for the WE bins.
            priority: Priority of the finalize_run callback (default 3).
            cluster_on_pcoord: Forwarded to the OptimizedBinMapper (default False).
    """

    def __init__(self, sim_manager, plugin_config):
        westpa.rc.pstatus("Initializing optimization plugin")

        # Only the master process drives optimization; other processes don't register the callback.
        if not sim_manager.work_manager.is_master:
            westpa.rc.pstatus("Not running on the master process, skipping")
            return

        self.data_manager = sim_manager.data_manager
        self.sim_manager = sim_manager
        self.we_driver = westpa.rc.get_we_driver()
        self.propagator = westpa.rc.get_propagator()
        # SynD-specific: the propagator exposes its underlying SynD model.
        self.synd_model = westpa.rc.get_propagator().synd_model

        self.plugin_config = plugin_config

        # NOTE: pickle can execute arbitrary code -- only point this at trusted files.
        coord_map_path = plugin_config.get("full_coord_map")
        with open(coord_map_path, "rb") as infile:
            self.coord_map = pickle.load(infile)

        # Big number is low priority
        self.priority = plugin_config.get("priority", 3)
        self.cluster_on_pcoord = plugin_config.get("cluster_on_pcoord", False)

        # Dimensionality of the original (pre-extension) progress coordinate.
        self.original_pcoord_dim = westpa.rc.config.get(
            ["west", "system", "system_options", "pcoord_ndim"]
        )

        sim_manager.register_callback(
            sim_manager.finalize_run, self.do_optimization, self.priority
        )

    def do_optimization(self):
        """
        Update WESTPA with an optimized bin mapper, bin allocation, and extend the progress coordinate. Then, continue
        the WE for more iterations.

        Registered on ``sim_manager.finalize_run``.
        """
        # 1. Discrepancy calculation / bin optimization
        westpa.rc.pstatus("Updating bin mapper")
        we_bin_mapper = self.compute_optimized_bins()
        self.we_driver.bin_mapper = we_bin_mapper

        # 2. Update allocation
        westpa.rc.pstatus("Updating allocation")
        we_allocation = self.compute_optimized_allocation()
        self.we_driver.bin_target_counts = we_allocation

        # 3. Update pcoord
        # TODO: This is SynMD specific -- how can I make extending the progress coordinate generic?
        # Maybe I could wrap the progress coordinate calculation as the original progress coordinate calculation,
        # whatever it may be, and then additionally the result of `model.reduceCoordinates` on the full-coord
        # structure (not sure the best way to get that, it is eventually stored in auxdata)
        westpa.rc.pstatus("Updating pcoord map")
        new_pcoord_map = self.compute_new_pcoord_map()
        self.update_westpa_pcoord(new_pcoord_map)

        # 4. Continue WE, with optimized parameters
        # No need to re-initialize/restart, just extend max iterations and continue
        remaining_iters = (
            self.plugin_config.get("max_iters") - self.sim_manager.max_total_iterations
        )
        if remaining_iters > 0:
            # Extend by at most one more block of west.propagation.max_total_iterations.
            new_iters = min(
                remaining_iters,
                westpa.rc.config.get(["west", "propagation", "max_total_iterations"]),
            )
            self.sim_manager.max_total_iterations += new_iters
            westpa.rc.pstatus(
                f"\n\n=== Applying optimization and continuing for {new_iters} more iterations ===\n"
            )
            w_run.run_simulation()
        else:
            westpa.rc.pstatus("No more iterations for optimization, completing.")

    @staticmethod
    def default_allocation_optimizer(model):
        """A (trivial) example allocation optimization function, which returns an array with the target number of
        walkers in each bin (i.e. the current allocation, unchanged)."""
        westpa.rc.pstatus("\tNot updating allocation")
        return westpa.rc.we_driver.bin_target_counts

    def compute_optimized_allocation(self):
        """
        Compute the optimal allocation.

        If `plugin.allocation_strategy` is None or not provided, the allocation is not updated.
        Otherwise, the constructed haMSM is passed to an arbitrary function that returns an array-like describing the
        new walker allocation over the WE bins.

        Returns
        -------
        The new per-bin target walker counts.
        """
        allocation_strategy = self.plugin_config.get("allocation_strategy", None)

        if allocation_strategy is None:
            # The default optimizer reports "Not updating allocation" itself.
            allocation_optimizer = self.default_allocation_optimizer
        else:
            westpa.rc.pstatus(f"\tUsing {allocation_strategy} to update allocation")
            allocation_optimizer = extloader.get_object(allocation_strategy)

        new_target_counts = allocation_optimizer(self.data_manager.hamsm_model)
        return new_target_counts

    @staticmethod
    def default_bin_optimizer(model):
        """Example bin optimization function, which assigns microstates to WE bins.

        Solves for the discrepancy and variance of the haMSM, then clusters microstates into the active WE bins
        with ``optimization.get_clustered_mfpt_bins``.
        """
        n_active_bins = np.count_nonzero(westpa.rc.we_driver.bin_target_counts)
        westpa.rc.pstatus(
            "\tUsing default k-means MFPT optimization (optimization.get_clustered_mfpt_bins) "
            "for bin optimization"
        )
        discrepancy, variance = optimization.solve_discrepancy(
            tmatrix=model.Tmatrix, pi=model.pSS, B=model.indTargets
        )
        microstate_assignments = optimization.get_clustered_mfpt_bins(
            variance, discrepancy, model.pSS, n_active_bins
        )
        return microstate_assignments

    def compute_optimized_bins(self):
        """
        Computes discrepancy and variance, and returns the resulting optimized bin mapper.

        If `plugin.binning_strategy` is None or not provided, :code:`optimization.get_clustered_mfpt_bins()` is used.
        Otherwise, the constructed haMSM is passed to an arbitrary function that returns an array-like with the WE bin
        index of all MSM microbins excluding the basis/target (model.indBasis and model.indTargets).

        Returns
        -------
        An OptimizedBinMapper
        """
        model = self.data_manager.hamsm_model
        binning_strategy = self.plugin_config.get("binning_strategy", None)
        n_active_bins = np.count_nonzero(self.we_driver.bin_target_counts)

        if binning_strategy is None:
            bin_optimizer = self.default_bin_optimizer
        else:
            westpa.rc.pstatus(f"\tUsing {binning_strategy} for bin optimization")
            bin_optimizer = extloader.get_object(binning_strategy)

        microstate_assignments = bin_optimizer(model)
        # The strategy excludes the basis and target states; append them as the last two bins.
        microstate_assignments = np.concatenate(
            [microstate_assignments, [n_active_bins - 2, n_active_bins - 1]]
        )
        westpa.rc.pstatus(f"\tMicrostate assignments are {microstate_assignments}")

        # Build the optimized bin mapper
        base_mapper = model.clusters.bin_mapper
        n_pcoord_dims = self.original_pcoord_dim
        we_bin_mapper = optimization.OptimizedBinMapper(
            n_active_bins,
            # In case the pcoord is extended, this is the original pcoord dimensionality
            n_pcoord_dims,
            # If the pcoord was extended, pcoord boundaries are in the original pcoord space
            model.basis_pcoord_bounds,
            model.target_pcoord_bounds,
            # The original, non-Optimized BinMapper that WESTPA was run with.
            # Used for stratified clustering
            base_mapper,
            microstate_assignments,
            model.clusters,
            cluster_on_pcoord=self.cluster_on_pcoord,
        )

        return we_bin_mapper

    def compute_new_pcoord_map(self):
        """
        SynD specific: Compute a new progress coordinate mapping.

        Each structure in ``self.coord_map`` is featurized in parallel by Ray actors, and its reduced coordinates are
        appended to the original progress coordinate of its state.

        Returns
        -------
        A dictionary of {state indices : extended progress coordinates}
        """
        model = self.data_manager.hamsm_model
        processCoordinates = self.data_manager.processCoordinates

        new_pcoord_map = {}

        # Make sure Ray is connected *before* querying it -- ray.available_resources() needs a
        # running Ray instance.
        msm_we.msm_we.modelWE.check_connect_ray()

        # int() truncates fractional/zero free CPUs to 0; keep at least one actor so the
        # round-robin assignment below never takes a modulo by zero.
        n_actors = max(1, int(ray.available_resources().get("CPU", 1)))

        # A single actor holds the model; the calculators fetch it from there.
        model_actor = GlobalModelActor.remote(
            model, processCoordinates, self.synd_model, self.original_pcoord_dim
        )
        pcoord_calculators = [
            PcoordCalculator.remote(model_actor, processCoordinates)
            for _ in range(n_actors)
        ]

        ids = []
        with Progress() as progress:
            submit_task = progress.add_task(
                "Submitting structures for pcoord calculation",
                total=len(self.coord_map),
            )
            retrieve_task = progress.add_task(
                f"Retrieving structure pcoords from {n_actors} workers",
                total=len(self.coord_map),
            )

            # Round-robin structures over the calculators (assumes integer state indices).
            for state_index, structure in self.coord_map.items():
                _id = pcoord_calculators[
                    state_index % n_actors
                ].compute_new_structure_pcoord.remote(structure, state_index)
                ids.append(_id)
                progress.advance(submit_task)

            # Collect results in batches as they finish.
            while ids:
                finished, ids = ray.wait(ids, num_returns=min(50, len(ids)), timeout=5)
                results = ray.get(finished)

                for pcoord, state_index in results:
                    new_pcoord_map[state_index] = pcoord
                    progress.advance(retrieve_task)

        return new_pcoord_map

    def update_westpa_pcoord(self, new_pcoord_map):
        """
        Changing a progress coordinate during a WE run requires a number of changes in WESTPA's internal state.
        This handles making those, so you can call w_run and continue with the new, changed pcoord

        Parameters
        ----------
        new_pcoord_map: A dictionary mapping discrete states to the new, extended pcoord

        Raises
        ------
        ValueError
            If ``new_pcoord_map`` is empty.
        """
        if not new_pcoord_map:
            raise ValueError(
                "new_pcoord_map is empty; cannot determine the new pcoord dimensionality"
            )

        # TODO: Replace this with propagator.get_pcoord
        # SynD-specific: make the propagator backmap discrete states to the extended pcoord.
        self.propagator.synd_model._backmappers["default"] = new_pcoord_map.get

        # Every entry has the same dimensionality; don't assume state 0 is a key.
        new_pcoord_dim = next(iter(new_pcoord_map.values())).shape[0]
        westpa.rc.pstatus(f"New pcoord dimensionality is {new_pcoord_dim}")

        system = westpa.rc.get_system_driver()
        data_manager = westpa.rc.get_data_manager()
        sim_manager = westpa.rc.get_sim_manager()

        # This function causes problems when extending pcoords, and isn't actually used for anything, so just skip it
        data_manager.get_new_weight_data = lambda x: None

        # Update system driver
        system.pcoord_ndim = new_pcoord_dim

        n_iter = sim_manager.n_iter

        def iter_path(suffix):
            # Absolute HDF5 path of ``suffix`` within the current iteration's group.
            return "/iterations/iter_{:0{prec}d}/{}".format(
                int(n_iter), suffix, prec=data_manager.iter_prec
            )

        data_manager.open_backing()
        try:
            # Update the pcoord dataset in west.h5
            iter_group = data_manager.get_iter_group(n_iter)
            segments = data_manager.get_segments(n_iter=n_iter, load_pcoords=True)

            del data_manager.we_h5file[iter_path("pcoord")]
            data_manager.flush_backing()

            pcoord_opts = data_manager.dataset_options.get(
                "pcoord",
                {"name": "pcoord", "h5path": "pcoord", "compression": False},
            )

            # Update the currently held segments
            westpa.rc.pstatus(f"Attempting to fetch segments for iter {n_iter}")
            for segment in segments:
                # TODO: This is SynD specific, but should be easy to port over to something generic.
                #   Use propagator.get_pcoord directly
                parent_state_index = get_segment_parent_index(segment)
                # First point: the parent's extended pcoord; remaining points are zero placeholders.
                segment.pcoord = np.concatenate(
                    [
                        [new_pcoord_map[parent_state_index]],
                        np.zeros(shape=(system.pcoord_len - 1, system.pcoord_ndim)),
                    ]
                )

            create_dataset_from_dsopts(
                iter_group,
                pcoord_opts,
                data=np.array([segment.pcoord for segment in segments]),
                # Size the dataset from the segments actually being written.
                shape=(
                    len(segments),
                    system.pcoord_len,
                    system.pcoord_ndim,
                ),
                dtype=system.pcoord_dtype,
            )
            data_manager.update_segments(n_iter, segments)

            # The initial states that were computed for the next iteration are using the old-style pcoord,
            # so we can't bin them correctly with the OptimizedBinMapper unless we update them.
            # Although the old istates won't be used if you have gen_istate enabled, we still have to
            # recreate the dataset. Otherwise, it'll be the wrong shape for the new istates.
            initial_states = data_manager.get_initial_states(n_iter)
            for suffix in ("ibstates/istate_pcoord", "ibstates/istate_index"):
                path = iter_path(suffix)
                if path in data_manager.we_h5file:
                    del data_manager.we_h5file[path]

            new_istates = data_manager.create_initial_states(
                n_states=len(initial_states), n_iter=n_iter
            )

            # Copy each old istate onto its freshly allocated slot, with the pcoord in the new
            # extended space. (Previously the loop rebound its loop variable, so none of this
            # reached the new istates or the file.)
            for old_istate, new_istate in zip(initial_states, new_istates):
                new_istate.basis_state_id = old_istate.basis_state_id
                new_istate.iter_created = old_istate.iter_created
                new_istate.iter_used = old_istate.iter_used
                new_istate.istate_type = old_istate.istate_type
                new_istate.istate_status = old_istate.istate_status

                bstate_id = old_istate.basis_state_id
                parent_state_index = int(sim_manager.next_iter_bstates[bstate_id].auxref)
                new_istate.pcoord = new_pcoord_map[parent_state_index]

            if new_istates:
                data_manager.update_initial_states(new_istates, n_iter=n_iter)
            data_manager.flush_backing()

            # Also update the segments held in memory by the sim manager.
            for segment in sim_manager.segments.values():
                parent_state = get_segment_parent_index(segment)
                segment.pcoord = new_pcoord_map[parent_state]
            data_manager.flush_backing()
        finally:
            # Close the backing file even if an update step fails.
            data_manager.close_backing()