Source code for mbgdml.analysis.problematic

# MIT License
#
# Copyright (c) 2020 monopsony
# Copyright (c) 2020-2023, Alex M. Maldonado
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Identifies problematic (high error) structures for model to train on next.
Some code within this module is modified from https://github.com/fonsecag/MLFF.
"""

import os
import numpy as np
from scipy.spatial.distance import pdist

from . import clustering
from ..mbe import mbePredict
from ..utils import save_json
from ..losses import loss_f_mse
from ..logger import GDMLLogger

log = GDMLLogger(__name__)


class ProblematicStructures:
    r"""Find problematic structures for models in datasets.

    Clusters all structures in a dataset using agglomerative and k-means
    algorithms on a structural descriptor and energies.
    """

    def __init__(
        self,
        models,
        predict_model,
        use_ray=False,
        n_workers=1,
        ray_address="auto",
        wkr_chunk_size=100,
    ):
        """
        Parameters
        ----------
        models : :obj:`list` of :obj:`mbgdml.models.Model`
            Machine learning model objects that contain all information to
            make predictions using ``predict_model``.
        predict_model : ``callable``
            A function that takes ``Z``, ``R``, ``entity_ids``, ``nbody_gen``,
            and ``model`` and computes energies and forces. This will be
            turned into a ray remote function if ``use_ray = True``. This can
            return total properties or all individual :math:`n`-body energies
            and forces.
        use_ray : :obj:`bool`, default: ``False``
            Use `ray <https://docs.ray.io/en/latest/>`__ to parallelize
            computations.
        n_workers : :obj:`int`, default: ``1``
            Total number of workers available for ray. This is ignored if
            ``use_ray`` is ``False``.
        ray_address : :obj:`str`, default: ``"auto"``
            Ray cluster address to connect to.
        wkr_chunk_size : :obj:`int`, default: ``100``
            Number of :math:`n`-body structures to assign to each spawned
            worker with ray.
        """
        self.models = models
        self.mbe_pred = mbePredict(
            models,
            predict_model,
            use_ray=use_ray,
            n_workers=n_workers,
            ray_address=ray_address,
            wkr_chunk_size=wkr_chunk_size,
        )

        self.course_n_cl_r = 10
        r"""Number of clusters used in the coarse stage for geometries. There
        will be a total of ``course_n_cl_r`` clusters.

        :type: :obj:`int`, default: ``10``
        """
        self.course_n_cl_e = 5
        r"""Number of clusters used in the coarse stage for energies. After
        clustering structures by geometric descriptor (using
        ``course_n_cl_r``), each cluster is further refined by energies. There
        will be a total of ``course_n_cl_r`` :math:`\times` ``course_n_cl_e``
        clusters.

        :type: :obj:`int`, default: ``5``
        """
        self.refine_n_cl = 100
        r"""Number of clusters used in the refine stage.

        :type: :obj:`int`, default: ``100``
        """
        self.refine_min_r_ratio = 2.0
        r"""Minimum ratio of structures to number of clusters in the refine
        stage. Will reduce the minimum loss set point for refinement until
        ``refine_n_cl`` :math:`\times` ``refine_min_r_ratio`` structures are
        available.

        :type: :obj:`float`, default: ``2.0``
        """
        self.loss_func = loss_f_mse
        r"""Loss function used to determine problematic structures.

        :type: ``callable``, default: :obj:`mbgdml.losses.loss_f_mse`
        """
        self.loss_func_kwargs = {}
        r"""Any keyword arguments beyond ``results`` for the loss function.

        :type: :obj:`dict`, default: ``{}``
        """
        self.kwargs_subplot = {"figsize": (5.5, 3), "constrained_layout": True}
        r"""``pyplot.subplots`` keyword arguments.

        **Default:**

        .. code-block:: python

            {'figsize': (5.5, 3), 'constrained_layout': True}

        :type: :obj:`dict`
        """
        self.plot_lolli_color = "#223F4B"
        r"""Lollipop color.

        :type: :obj:`str`, default: ``'#223F4B'``
        """
        self.plot_annotate_cl_idx = False
        r"""Add the cluster index above the cluster loss value.

        :type: :obj:`bool`, default: ``False``
        """

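    # A minimal construction sketch (hypothetical model and callable names):
    # only the ``__init__`` signature documented above is assumed here; which
    # ``mbgdml.models.Model`` subclass and ``predict_model`` function to use
    # depends on the rest of mbgdml.
    #
    #     prob_s = ProblematicStructures(
    #         models=[my_model],            # list of mbgdml.models.Model objects
    #         predict_model=my_predict_fn,  # callable(Z, R, entity_ids, nbody_gen, model)
    #         use_ray=False,
    #     )
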
    def get_pd(self, R):
        r"""Computes pairwise distances from atomic positions.

        Parameters
        ----------
        R : :obj:`numpy.ndarray`, shape: ``(n_samples, n_atoms, 3)``
            Atomic positions.

        Returns
        -------
        :obj:`numpy.ndarray`
            Pairwise distances of atoms in each structure with shape
            ``(n_samples, n_atoms*(n_atoms-1)/2)``.
        """
        assert R.ndim == 3
        n_samples, n_atoms, _ = R.shape
        n_pd = int(n_atoms * ((n_atoms - 1) / 2))
        R_pd = np.zeros((n_samples, n_pd))

        for i, r in enumerate(R):
            R_pd[i] = pdist(r)

        return R_pd

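    # Descriptor sketch (hypothetical values): for five 3-atom structures,
    # ``R`` has shape ``(5, 3, 3)`` and ``get_pd`` returns shape ``(5, 3)``,
    # one row of 3*(3-1)/2 = 3 unique interatomic distances per structure.
    #
    #     R = np.random.rand(5, 3, 3)
    #     R_pd = prob_s.get_pd(R)  # -> shape (5, 3)
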
    def prob_cl_indices(self, cl_idxs, cl_losses):
        r"""Identify problematic dataset indices.

        Parameters
        ----------
        cl_idxs : :obj:`list` of :obj:`numpy.ndarray`
            Clustered dataset indices.
        cl_losses : :obj:`numpy.ndarray`
            Losses for each cluster in ``cl_idxs``.

        Returns
        -------
        :obj:`numpy.ndarray`
            Dataset indices from clusters with higher-than-average losses.
        """
        log.info("Finding problematic structures")
        loss_bound = np.mean(cl_losses)  # Initial minimum loss
        loss_step = loss_bound / 500
        loss_bound += loss_step

        idxs = []
        while len(idxs) < 1.5 * self.refine_n_cl:
            log.info("Minimum cluster loss : %.4f", loss_bound)
            cl_idxs_prob = np.concatenate(np.argwhere(cl_losses >= loss_bound))
            clusters = np.array(cl_idxs, dtype=object)[cl_idxs_prob]
            idxs = np.concatenate(clusters)
            loss_bound -= loss_step
            log.info("N structures included : %d\n", len(idxs))

        return idxs

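    # Note on the loop above: the loss threshold starts slightly above the
    # mean cluster loss and is lowered by ``mean / 500`` each iteration,
    # admitting clusters in roughly descending loss order until at least
    # ``1.5 * refine_n_cl`` structures are collected for refinement.
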
    def n_cl_samples(self, n_sample, cl_weights, cl_pop, cl_losses):
        r"""Number of dataset indices to sample from each cluster.

        Parameters
        ----------
        n_sample : :obj:`int`
            Total number of dataset indices to sample from all clusters.
        cl_weights : :obj:`numpy.ndarray`
            Normalized cluster weights.
        cl_pop : :obj:`numpy.ndarray`
            Cluster populations.
        cl_losses : :obj:`numpy.ndarray`
            Cluster losses used to prioritize sampling from higher-loss
            clusters.

        Returns
        -------
        :obj:`numpy.ndarray`
            Number of dataset indices to sample from each cluster.
        """
        samples = np.array(cl_weights * n_sample)
        samples = np.floor(samples)

        # Check that selections do not sample more than the population
        for i, pop in enumerate(cl_pop):
            if samples[i] > pop:
                samples[i] = pop

        # Try to have at least one sample from each cluster
        # (in order of max loss)
        arg_max = (-cl_losses).argsort()
        for i in arg_max:
            if np.sum(samples) == n_sample:
                return samples
            if samples[i] == 0:
                samples[i] = 1

        # If there are still not enough samples, we start adding additional
        # samples in order of highest cluster losses.
        for i in arg_max:
            if np.sum(samples) == n_sample:
                return samples
            if samples[i] < cl_pop[i]:
                samples[i] += 1

        return samples.astype(int)

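    # Worked example (hypothetical numbers): with ``n_sample = 10``,
    # ``cl_weights = [0.5, 0.3, 0.2]`` and ``cl_pop = [4, 20, 20]``, the first
    # pass floors the allocation to ``[5, 3, 2]``, the population check caps
    # the first cluster at 4 (total 9), and the final loss-ordered pass adds
    # one more structure to the highest-loss cluster that still has unsampled
    # members, reaching the requested 10.
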
    def select_prob_indices(self, n_select, cl_idxs, idx_loss_cl):
        r"""Select ``n_select`` problematic dataset indices based on weighted
        cluster losses and distribution.

        Parameters
        ----------
        n_select : :obj:`int`
            Number of problematic dataset indices to select.
        cl_idxs : :obj:`list` of :obj:`numpy.ndarray`
            Clustered dataset indices.
        idx_loss_cl : :obj:`list` of :obj:`numpy.ndarray`
            Clustered individual structure losses. Same shape as ``cl_idxs``.

        Returns
        -------
        :obj:`numpy.ndarray`, shape: ``(n_select,)``
            Problematic dataset indices.
        """
        log.info("\nSelecting problematic structures")
        cl_losses = np.array([np.mean(losses) for losses in idx_loss_cl])
        cl_pop = np.array([len(_) for _ in cl_idxs])  # Cluster population

        log.info("Computing cluster loss weights")
        cl_weights = (cl_losses / np.sum(cl_losses)) * (cl_pop / np.sum(cl_pop))
        cl_weights_norm = np.array(cl_weights) / np.sum(cl_weights)
        # pylint: disable-next=invalid-name
        Ns = self.n_cl_samples(n_select, cl_weights_norm, cl_pop, cl_losses)

        log.info("Sampling structures")
        n_cl = len(cl_losses)
        prob_idxs = []
        for i in range(n_cl):
            losses = idx_loss_cl[i]
            idxs = cl_idxs[i]
            ni = int(Ns[i])  # pylint: disable=invalid-name
            argmax = np.argsort(-losses)[:ni]
            prob_idxs.extend(idxs[argmax])

        prob_idxs = np.array(prob_idxs)
        log.debug("Selected dataset indices:")
        log.log_array(prob_idxs, level=10)
        return prob_idxs

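    # The weights above combine two normalized terms: a cluster's share of the
    # total loss and its share of the total population, so large, high-error
    # clusters contribute the most structures to the selection.
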
    # pylint: disable-next=too-many-branches, too-many-statements
    def find(
        self,
        dset,
        n_find,
        dset_is_train=True,
        train_idxs=None,
        write_json=True,
        save_cl_plot=True,
        image_format="png",
        image_dpi=600,
        save_dir=".",
    ):
        r"""Find problematic structures in a dataset.

        Uses agglomerative and k-means clustering on a dataset. First, the
        dataset is split into ``10`` clusters based on atomic pairwise
        distances. Then each cluster is further split into ``5`` clusters
        based on energies.

        Energies and forces are predicted, and then problematic structures
        are taken from clusters with higher-than-average losses. Here, the
        force MSE is used as the loss function.

        Finally, ``n_find`` structures are sampled from the 100 clusters
        based on a weighted cluster error distribution.

        Parameters
        ----------
        dset : :obj:`mbgdml.data.DataSet`
            Dataset to cluster and analyze errors.
        n_find : :obj:`int`
            Number of dataset indices to find.
        dset_is_train : :obj:`bool`, default: ``True``
            If ``dset`` is the training dataset. Training indices will be
            dropped from the analyses.
        train_idxs : :obj:`numpy.ndarray`, ndim: ``1``, default: :obj:`None`
            Training indices that will be dropped if ``dset_is_train`` is
            ``True``. These do not need to be provided for GDML models (as
            they are already stored in the model).
        write_json : :obj:`bool`, default: ``True``
            Write JSON file detailing clustering and prediction errors.
        save_cl_plot : :obj:`bool`, default: ``True``
            Plot cluster losses and histogram.
        image_format : :obj:`str`, default: ``'png'``
            Format to save the image in.
        image_dpi : :obj:`int`, default: ``600``
            Dots per inch to save the image.
        save_dir : :obj:`str`, default: ``'.'``
            Directory to save any files.
        """
        log.info(
            "---------------------------\n"
            "| Finding Problematic |\n"
            "| Structures |\n"
            "---------------------------\n"
        )
        if write_json:
            self.json_dict = {}

        log.info("Loading dataset\n")
        Z, R, E, F = dset.Z, dset.R, dset.E, dset.F
        entity_ids, comp_ids = dset.entity_ids, dset.comp_ids

        # Removing training indices.
        R_idxs_orig = np.array(list(range(len(R))))  # pylint: disable=invalid-name
        if dset_is_train:
            log.info("Dropping indices already in training set")
            if len(self.models) != 1:
                log.warning("Cannot drop training indices if there are multiple models")
                log.warning("Not dropping any indices")
            assert len(self.models) == 1

            if train_idxs is None:
                try:
                    train_idxs = self.models[0].model_dict["idxs_train"]
                except Exception as e:
                    raise AttributeError("Training indices were not provided") from e
            else:
                assert isinstance(train_idxs, np.ndarray)
            log.debug("Training indices")
            log.log_array(train_idxs, level=10)

            n_Ri = len(R_idxs_orig)  # pylint: disable=invalid-name
            log.info("There are a total of %d structures", n_Ri)
            R_idxs = np.setdiff1d(R_idxs_orig, train_idxs)
            n_Rf = len(R_idxs)  # pylint: disable=invalid-name
            log.info("Removed %d structures", n_Ri - n_Rf)
        else:
            R_idxs = R_idxs_orig
        # Note: Indices from this point on do not directly map to their index
        # in the dataset. We have to convert back to their original indices
        # when necessary. We refer to R_idxs as no-training indices.

        # Perform clustering based on pairwise distances and energies
        R, E, F = R[R_idxs], E[R_idxs], F[R_idxs]
        R_pd = self.get_pd(R)
        cl_data = (R_pd, E.reshape(-1, 1))
        cl_algos = (clustering.agglomerative, clustering.kmeans)
        cl_kwargs = ({"n_clusters": 10}, {"n_clusters": 5})
        cl_idxs = clustering.cluster_structures(cl_data, cl_algos, cl_kwargs)
        cl_pop = [len(i) for i in cl_idxs]

        if write_json:
            # Convert back to dataset indices just to write.
            # The no-train indices are still needed to compute errors and
            # problematic clustering.
            cl_idxs_write = [np.array(R_idxs[idxs]) for idxs in cl_idxs]
            self.json_dict["clustering"] = {}
            self.json_dict["clustering"]["indices"] = cl_idxs_write
            self.json_dict["clustering"]["population"] = cl_pop

        log.info("\nPredicting structures")
        t_prediction = log.t_start()
        # pylint: disable-next=unbalanced-tuple-unpacking
        E_pred, F_pred = self.mbe_pred.predict(Z, R, entity_ids, comp_ids)
        log.t_stop(t_prediction, message="Took {time} s")

        log.info("Computing prediction errors")
        E_errors = E_pred - E
        F_errors = F_pred - F
        log.debug("Energy errors")
        log.log_array(E_errors, level=10)
        log.debug("Force errors")
        log.log_array(F_errors, level=10)

        log.info("\nAggregating errors")
        # pylint: disable-next=invalid-name
        E_errors_cl = clustering.get_clustered_data(cl_idxs, E_errors)
        # pylint: disable-next=invalid-name
        F_errors_cl = clustering.get_clustered_data(cl_idxs, F_errors)

        # Computing cluster losses
        loss_kwargs = {"energy": E_errors_cl, "force": F_errors_cl}
        cl_losses = clustering.get_cluster_losses(self.loss_func, loss_kwargs)

        if write_json:
            self.json_dict["clustering"]["loss_function"] = self.loss_func.__name__
            self.json_dict["clustering"]["losses"] = cl_losses

        prob_indices = self.prob_cl_indices(cl_idxs, cl_losses)

        # Problematic clustering
        log.info("Refine clustering of problematic structures")
        # Switching to problematic idxs for clustering.
        R_pd_prob = R_pd[prob_indices]  # pylint: disable=invalid-name
        cl_data_prob = (R_pd_prob,)
        cl_algos_prob = (clustering.agglomerative,)
        cl_kwargs_prob = ({"n_clusters": self.refine_n_cl},)
        cl_idxs_prob = clustering.cluster_structures(
            cl_data_prob, cl_algos_prob, cl_kwargs_prob
        )
        # switching back to no-training idxs
        cl_idxs_prob = [np.array(prob_indices[idxs]) for idxs in cl_idxs_prob]
        cl_pop_prob = [len(i) for i in cl_idxs_prob]

        if write_json:
            # Convert back to dataset indices just to write.
            cl_idxs_prob_write = [np.array(R_idxs[idxs]) for idxs in cl_idxs_prob]
            self.json_dict["problematic_clustering"] = {}
            self.json_dict["problematic_clustering"]["indices"] = cl_idxs_prob_write
            self.json_dict["problematic_clustering"]["population"] = cl_pop_prob

        log.info("Aggregating errors for problematic structures")
        # pylint: disable-next=invalid-name
        # E_errors_cluster_prob = clustering.get_clustered_data(cl_idxs_prob, E_errors)
        # pylint: disable-next=invalid-name
        # F_errors_cluster_prob = clustering.get_clustered_data(cl_idxs_prob, F_errors)

        # idx_loss_kwargs = {"energy": E_errors, "force": F_errors}
        structure_loss = np.empty(E_errors.shape)
        for i in range(len(structure_loss)):  # pylint: disable=consider-using-enumerate
            structure_loss[i] = self.loss_func(
                {"energy": E_errors[i], "force": F_errors[i]}
            )
        structure_loss_cl = clustering.get_clustered_data(cl_idxs_prob, structure_loss)

        if write_json:
            self.json_dict["problematic_clustering"][
                "loss_function"
            ] = self.loss_func.__name__
            self.json_dict["problematic_clustering"]["losses"] = structure_loss_cl

        next_idxs = self.select_prob_indices(n_find, cl_idxs_prob, structure_loss_cl)

        # Convert back to dataset indices.
        next_idxs = R_idxs[next_idxs]
        if write_json:
            self.json_dict["add_training_indices"] = next_idxs
            save_json(
                os.path.join(save_dir, "find_problematic_indices.json"), self.json_dict
            )

        if save_cl_plot:
            fig = self.plot_cl_losses(cl_pop, cl_losses)
            fig.savefig(
                os.path.join(save_dir, f"cl_losses.{image_format}"), dpi=image_dpi
            )

        return next_idxs

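    # End-to-end sketch (hypothetical dataset and model objects): only ``find``
    # and its documented parameters are assumed from this module; how the
    # dataset and models are loaded depends on the rest of mbgdml.
    #
    #     dset = ...  # an mbgdml.data.DataSet providing Z, R, E, F, entity_ids, comp_ids
    #     prob_s = ProblematicStructures([model], predict_fn)
    #     next_idxs = prob_s.find(dset, n_find=100, dset_is_train=True, save_dir=".")
    #     # next_idxs are dataset indices worth adding to the next training set.
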
    def plot_cl_losses(self, cl_pop, cl_losses):
        r"""Plot cluster losses and population histogram using matplotlib.

        Parameters
        ----------
        cl_pop : :obj:`numpy.ndarray`
            Cluster populations (unsorted).
        cl_losses : :obj:`numpy.ndarray`
            Cluster losses (unsorted).

        Returns
        -------
        ``object``
            A matplotlib figure object.
        """
        import matplotlib.pyplot as plt  # pylint: disable=import-outside-toplevel

        cl_width = 1

        cl_losses = np.array(cl_losses)
        cl_pop = np.array(cl_pop)
        loss_sort = np.argsort(cl_losses)
        cl_pop = cl_pop[loss_sort]
        cl_losses = cl_losses[loss_sort]

        n_cl = len(cl_pop)
        cl_plot_x = np.array(range(n_cl)) * cl_width

        fig, ax_pop = plt.subplots(nrows=1, ncols=1, **self.kwargs_subplot)
        ax_loss = ax_pop.twinx()

        ax_loss.yaxis.set_ticks_position("left")
        ax_loss.yaxis.set_label_position("left")
        ax_pop.yaxis.set_ticks_position("right")
        ax_pop.yaxis.set_label_position("right")

        # Cluster losses
        ax_loss.set_ylabel(self.loss_func.__name__)
        ax_loss.vlines(
            x=cl_plot_x,
            ymin=0,
            ymax=cl_losses,
            linewidth=0.8,
            color=self.plot_lolli_color,
        )
        ax_loss.scatter(cl_plot_x, cl_losses, s=2, color=self.plot_lolli_color)

        # Losses mean
        ax_loss.axhline(
            np.mean(cl_losses),
            color=self.plot_lolli_color,
            alpha=1,
            linewidth=1.0,
            linestyle=":",
        )
        ax_loss.text(0.5, np.mean(cl_losses), "Mean", fontsize=8)

        # Population histogram (bar chart)
        ax_pop.set_xlabel("Cluster")
        ax_pop.set_ylabel("Size")
        edge_shift = cl_width / 2
        edges = [i - edge_shift for i in cl_plot_x] + [cl_plot_x[-1] + edge_shift]
        ax_pop.stairs(
            values=cl_pop,
            edges=edges,
            fill=False,
            baseline=0.0,
            zorder=-1.0,
            edgecolor="lightgrey",
            alpha=1.0,
        )

        # Annotate with cluster index
        if self.plot_annotate_cl_idx:
            for i, cl_idx in enumerate(loss_sort):
                cl_x = cl_plot_x[i]
                if cl_idx < 10:
                    x_disp = -1.5
                else:
                    x_disp = -2.7
                ax_loss.annotate(
                    str(cl_idx),
                    (cl_x, cl_losses[i]),
                    xytext=(x_disp, 3),
                    xycoords="data",
                    fontsize=4,
                    fontweight="bold",
                    textcoords="offset points",
                    color=self.plot_lolli_color,
                )

        # Handle axes labels
        ax_pop.set_xticks([])
        ax_loss.set_xlim(left=edges[0], right=edges[-1])
        ax_loss.set_ylim(bottom=0)
        ax_pop.set_ylim(bottom=0)

        return fig

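    # Standalone plotting sketch: ``plot_cl_losses`` only needs per-cluster
    # populations and losses (e.g., the values ``find`` writes to its JSON
    # file), so the figure can be regenerated without re-running predictions.
    #
    #     fig = prob_s.plot_cl_losses(cl_pop, cl_losses)
    #     fig.savefig("cl_losses.png", dpi=600)
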