ProblematicStructures
- class mbgdml.analysis.problematic.ProblematicStructures(models, predict_model, use_ray=False, n_workers=1, ray_address='auto', wkr_chunk_size=100)
Find problematic structures for models in datasets.
Clusters all structures in a dataset with agglomerative and k-means algorithms, based on a structural descriptor and energies.
- Parameters:
  - models (list of mbgdml.models.Model) – Machine learning model objects that contain all information to make predictions using predict_model.
  - predict_model (callable) – A function that takes Z, R, entity_ids, nbody_gen, and model and computes energies and forces. This will be turned into a ray remote function if use_ray = True. This can return total properties or all individual \(n\)-body energies and forces.
  - use_ray (bool, default: False) – Use ray to parallelize computations.
  - n_workers (int, default: 1) – Total number of workers available for ray. This is ignored if use_ray is False.
  - ray_address (str, default: "auto") – Ray cluster address to connect to.
  - wkr_chunk_size (int, default: 100) – Number of \(n\)-body structures to assign to each spawned worker with ray.
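To illustrate what wkr_chunk_size controls, the sketch below splits a range of \(n\)-body structure indices into per-worker chunks. chunk_nbody_indices is a hypothetical helper; how mbGDML actually builds and dispatches its ray tasks, and how n_workers bounds concurrency, is not described here and may differ.

```python
import numpy as np


def chunk_nbody_indices(n_structures, wkr_chunk_size=100):
    """Hypothetical helper: split n-body structure indices into consecutive chunks.

    Each chunk holds at most wkr_chunk_size indices and would be handed to one worker.
    """
    idxs = np.arange(n_structures)
    # The last chunk holds whatever is left over, so it may be shorter.
    return [idxs[i:i + wkr_chunk_size] for i in range(0, n_structures, wkr_chunk_size)]


chunks = chunk_nbody_indices(250, wkr_chunk_size=100)
print([len(c) for c in chunks])  # [100, 100, 50]
```

Smaller chunks spread work more evenly across workers at the cost of more task overhead; the default of 100 is the documented value, not a tuned recommendation.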
- course_n_cl_e
Number of clusters used in the coarse stage for energies. After structures are clustered by geometric descriptor (using course_n_cl_r), each cluster is further refined by energies. There will be a total of course_n_cl_r \(\times\) course_n_cl_e clusters.
- Type:
int, default: 5
- course_n_cl_r
Number of clusters used in the coarse stage for geometries. This stage produces a total of course_n_cl_r clusters.
- Type:
int, default: 10
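The total cluster count follows from simple indexing: if each structure gets a geometry label in range(course_n_cl_r) and an energy label in range(course_n_cl_e), every (geometry, energy) pair maps to one of course_n_cl_r \(\times\) course_n_cl_e flat cluster indices. The labels below are made up, and this flattening scheme is an illustrative assumption rather than how mbGDML stores its clusters.

```python
import numpy as np

course_n_cl_r, course_n_cl_e = 10, 5  # documented defaults

# Made-up labels from the two coarse stages for six structures.
r_labels = np.array([0, 0, 3, 9, 9, 9])  # geometry cluster, 0 .. course_n_cl_r - 1
e_labels = np.array([1, 4, 0, 0, 2, 4])  # energy sub-cluster within its geometry cluster

# Map each (geometry, energy) pair to a unique flat cluster index.
flat = r_labels * course_n_cl_e + e_labels
print(flat.tolist())                  # [1, 4, 15, 45, 47, 49]
print(course_n_cl_r * course_n_cl_e)  # 50
```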
- find(dset, n_find, dset_is_train=True, train_idxs=None, write_json=True, save_cl_plot=True, image_format='png', image_dpi=600, save_dir='.')
Find problematic structures in a dataset.
Uses agglomerative and k-means clustering on a dataset. First, the dataset is split into 10 clusters (course_n_cl_r) based on atomic pairwise distances. Then each cluster is further split into 5 clusters (course_n_cl_e) based on energies. Energies and forces are predicted, and problematic structures are taken from clusters with higher-than-average losses. By default, the force MSE is used as the loss function (see loss_func).
Finally, n_find structures are sampled from the resulting 50 clusters (10 \(\times\) 5 with the default settings) based on a weighted cluster error distribution.
- Parameters:
  - dset (mbgdml.data.DataSet) – Dataset to cluster and analyze errors.
  - n_find (int) – Number of dataset indices to find.
  - dset_is_train (bool, default: True) – Whether dset is the training dataset. If so, training indices will be dropped from the analyses.
  - train_idxs (numpy.ndarray, ndim: 1, default: None) – Training indices that will be dropped if dset_is_train is True. These do not need to be provided for GDML models because they are already stored in the model.
  - write_json (bool, default: True) – Write a JSON file detailing clustering and prediction errors.
  - save_cl_plot (bool, default: True) – Plot cluster losses and histogram.
  - image_format (str, default: 'png') – Format to save the image in.
  - image_dpi (int, default: 600) – Dots per inch to save the image.
  - save_dir (str, default: '.') – Directory to save any files.
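The coarse stage of find can be sketched in NumPy as follows. This is not mbGDML's implementation: the text does not say which stage uses agglomerative clustering and which uses k-means, so to stay dependency-free both stages here use a minimal Lloyd's k-means. The helpers kmeans and coarse_clusters are hypothetical, and random arrays stand in for a real descriptor and energies.

```python
import numpy as np


def kmeans(X, k, n_iter=50, seed=0):
    """Plain Lloyd's k-means; returns one integer label per row of X."""
    X = np.asarray(X, dtype=float).reshape(len(X), -1)
    k = min(k, len(X))  # never ask for more clusters than points
    rng = np.random.default_rng(seed)
    centers = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        # Assign every point to its nearest center.
        dist = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=-1)
        labels = dist.argmin(axis=1)
        # Move each center to the mean of its members; keep it if the cluster is empty.
        new_centers = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return labels


def coarse_clusters(pdists, E, n_cl_r=10, n_cl_e=5, train_idxs=None):
    """Cluster dataset indices by geometry, then split each geometry cluster by energy."""
    idxs = np.arange(len(E))
    if train_idxs is not None:
        idxs = np.setdiff1d(idxs, train_idxs)  # drop training structures
    clusters = []
    r_labels = kmeans(pdists[idxs], n_cl_r)
    for r in np.unique(r_labels):
        members = idxs[r_labels == r]
        e_labels = kmeans(E[members], n_cl_e)
        for e in np.unique(e_labels):
            clusters.append(members[e_labels == e])
    return clusters


rng = np.random.default_rng(1)
pdists = rng.random((200, 6))  # stand-in for pairwise distances (6 pairs = 4 atoms)
E = rng.normal(size=200)       # stand-in energies
clusters = coarse_clusters(pdists, E, train_idxs=np.arange(20))
print(len(clusters))  # at most 10 * 5 = 50
```

Each returned array holds the dataset indices of one cluster, which is the same shape of data that prob_cl_indices and select_prob_indices take as cl_idxs.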
- get_pd(R)
Computes pairwise distances from atomic positions.
- Parameters:
  - R (numpy.ndarray, shape: (n_samples, n_atoms, 3)) – Atomic positions.
- Returns:
Pairwise distances of atoms in each structure with shape (n_samples, n_atoms*(n_atoms-1)/2).
- Return type:
numpy.ndarray
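A NumPy sketch of what get_pd computes: for each structure, the distance between every unique atom pair, giving n_atoms*(n_atoms-1)/2 columns. pairwise_distances is a hypothetical name, and the upper-triangle pair ordering used here (atom 0 with 1, 0 with 2, and so on, then 1 with 2) is an assumption that may differ from mbGDML's.

```python
import numpy as np


def pairwise_distances(R):
    """Pairwise atomic distances for each structure in R (n_samples, n_atoms, 3)."""
    R = np.asarray(R, dtype=float)
    n_atoms = R.shape[1]
    i, j = np.triu_indices(n_atoms, k=1)  # each unique atom pair once, with i < j
    return np.linalg.norm(R[:, i, :] - R[:, j, :], axis=-1)


# One 3-4-5 right triangle: distances 0-1, 0-2, 1-2.
R = np.array([[[0.0, 0.0, 0.0], [3.0, 0.0, 0.0], [0.0, 4.0, 0.0]]])
print(pairwise_distances(R).tolist())  # [[3.0, 4.0, 5.0]]
```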
- kwargs_subplot
pyplot.subplot keyword arguments. Default:
{'figsize': (5.5, 3), 'constrained_layout': True}
- Type:
dict
- loss_func
Loss function used to determine problematic structures.
- Type:
callable, default: mbgdml.losses.loss_f_mse
- loss_func_kwargs
Any keyword arguments, beyond results, to pass to the loss function.
- Type:
dict, default: {}
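The sketch below shows how loss_func and loss_func_kwargs might fit together, with the extra keyword arguments unpacked after results. force_mse and the dict layout of results (keys "F_ref" and "F_pred") are hypothetical; the real signature of mbgdml.losses.loss_f_mse and the structure of results are not documented in this section.

```python
import numpy as np


def force_mse(results, scale=1.0):
    """Hypothetical force MSE loss.

    ASSUMPTION: results is a dict with reference ("F_ref") and predicted ("F_pred")
    forces; mbGDML's actual results layout may differ.
    """
    diff = np.asarray(results["F_pred"]) - np.asarray(results["F_ref"])
    return scale * np.mean(diff ** 2)


results = {"F_ref": np.zeros((2, 3, 3)), "F_pred": np.full((2, 3, 3), 0.5)}
loss_func_kwargs = {"scale": 2.0}  # extra keyword arguments beyond results
print(force_mse(results, **loss_func_kwargs))  # 0.5
```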
- n_cl_samples(n_sample, cl_weights, cl_pop, cl_losses)
Number of dataset indices to sample from each cluster.
- Parameters:
  - n_sample (int) – Total number of dataset indices to sample from all clusters.
  - cl_weights (numpy.ndarray) – Normalized cluster weights.
  - cl_pop (numpy.ndarray) – Cluster populations.
- Returns:
Number of dataset indices to sample from each cluster.
- Return type:
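This section does not say how n_cl_samples splits n_sample among clusters or how it uses cl_losses. One plausible scheme, sketched below under that caveat, hands out picks in proportion to cl_weights while never exceeding a cluster's population, and redistributes leftovers to clusters that still have room. allocate_samples is a hypothetical name, and it ignores cl_losses.

```python
import numpy as np


def allocate_samples(n_sample, cl_weights, cl_pop):
    """Split n_sample picks across clusters in proportion to cl_weights,
    never taking more from a cluster than its population."""
    cl_weights = np.asarray(cl_weights, dtype=float)
    cl_pop = np.asarray(cl_pop, dtype=int)
    n_sample = min(n_sample, int(cl_pop.sum()))  # cannot sample more than exist
    counts = np.zeros(len(cl_pop), dtype=int)
    while counts.sum() < n_sample:
        open_ = counts < cl_pop  # clusters that still have unsampled structures
        w = np.where(open_, cl_weights, 0.0)
        if w.sum() == 0:  # no weight left: fall back to uniform over open clusters
            w = open_.astype(float)
        remaining = n_sample - counts.sum()
        target = w / w.sum() * remaining
        # Whole picks first, capped by what each cluster has left.
        add = np.minimum(np.floor(target).astype(int), cl_pop - counts)
        if add.sum() == 0:
            # Only fractions remain: give one pick to the open cluster with the largest share.
            add[np.argmax(np.where(open_, target, -np.inf))] = 1
        counts += add
    return counts


print(allocate_samples(10, [0.5, 0.3, 0.2], [3, 10, 10]).tolist())  # [3, 5, 2]
```

The first cluster is capped at its population of 3, and its unused share flows to the other two clusters.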
- plot_annotate_cl_idx
Add the cluster index above the cluster loss value.
- Type:
bool, default: False
- plot_cl_losses(cl_pop, cl_losses)
Plot cluster losses and population histogram using matplotlib.
- Parameters:
  - cl_pop (numpy.ndarray) – Cluster populations (unsorted).
  - cl_losses (numpy.ndarray) – Cluster losses (unsorted).
- Returns:
A matplotlib figure object.
- Return type:
object
- prob_cl_indices(cl_idxs, cl_losses)
Identify problematic dataset indices.
- Parameters:
  - cl_idxs (list of numpy.ndarray) – Clustered dataset indices.
  - cl_losses (numpy.ndarray) – Losses for each cluster in cl_idxs.
- Returns:
Dataset indices from clusters with higher-than-average losses.
- Return type:
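A minimal reading of "higher-than-average losses" is a strict comparison against the unweighted mean of the cluster losses. prob_cluster_indices below is a hypothetical sketch on that assumption; a population-weighted mean would be an equally valid reading of the text.

```python
import numpy as np


def prob_cluster_indices(cl_idxs, cl_losses):
    """Concatenate dataset indices from clusters whose loss exceeds the mean cluster loss."""
    cl_losses = np.asarray(cl_losses, dtype=float)
    above = np.flatnonzero(cl_losses > cl_losses.mean())
    if len(above) == 0:  # e.g. all clusters have the same loss
        return np.array([], dtype=int)
    return np.concatenate([cl_idxs[i] for i in above])


cl_idxs = [np.array([0, 4]), np.array([1, 2, 7]), np.array([3, 5, 6])]
cl_losses = [0.1, 0.9, 0.2]  # mean is 0.4, so only the second cluster qualifies
print(prob_cluster_indices(cl_idxs, cl_losses).tolist())  # [1, 2, 7]
```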
- refine_min_r_ratio
Minimum ratio of structures to number of clusters in the refine stage. The minimum loss set point for refinement is reduced until refine_n_cl \(\times\) refine_min_r_ratio structures are available.
- Type:
float, default: 2.0
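The relaxation described for refine_min_r_ratio can be sketched as a loop that lowers a loss set point until enough structures clear it. refine_candidates, the starting loss_set_point, and the fixed 10% decrement (step_frac) are hypothetical; the section does not say how far or how fast the set point is reduced. With refine_n_cl = 2 and the default ratio of 2.0, at least 4 structures must qualify.

```python
import numpy as np


def refine_candidates(losses, loss_set_point, refine_n_cl, refine_min_r_ratio=2.0, step_frac=0.1):
    """Lower the loss set point until at least refine_n_cl * refine_min_r_ratio
    structures have a loss at or above it (step_frac is a hypothetical step size)."""
    losses = np.asarray(losses, dtype=float)
    n_needed = min(int(np.ceil(refine_n_cl * refine_min_r_ratio)), len(losses))
    step = step_frac * loss_set_point  # fixed decrement per iteration
    threshold = loss_set_point
    while np.count_nonzero(losses >= threshold) < n_needed:
        # Never go below the smallest loss, where every structure qualifies.
        threshold = max(threshold - step, losses.min())
    return np.flatnonzero(losses >= threshold), threshold


losses = [0.9, 0.5, 0.45, 0.35, 0.1, 0.05]
idx, threshold = refine_candidates(losses, loss_set_point=0.8, refine_n_cl=2)
print(idx.tolist())  # [0, 1, 2, 3]
```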
- select_prob_indices(n_select, cl_idxs, idx_loss_cl)
Select n_select problematic dataset indices based on weighted cluster losses and distribution.
- Parameters:
  - n_select (int) – Number of problematic dataset indices to select.
  - cl_idxs (list of numpy.ndarray) – Clustered dataset indices.
  - idx_loss_cl (list of numpy.ndarray) – Clustered individual structure losses. Same shape as cl_idxs.
- Returns:
Problematic dataset indices.
- Return type:
numpy.ndarray, shape: (n_select,)
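One way to select indices "based on weighted cluster losses and distribution" is weighted sampling without replacement, where each structure's probability is proportional to its cluster's mean loss. select_prob_indices_sketch is a hypothetical stand-in, not mbGDML's method, which may instead allocate a quota per cluster first (as n_cl_samples suggests).

```python
import numpy as np


def select_prob_indices_sketch(n_select, cl_idxs, idx_loss_cl, seed=0):
    """Draw n_select dataset indices without replacement, weighting each structure
    by the mean loss of its cluster."""
    rng = np.random.default_rng(seed)
    all_idxs = np.concatenate(cl_idxs)
    # Every structure inherits its cluster's mean loss as a sampling weight, so a
    # cluster's total weight grows with both its loss and its population.
    p = np.concatenate([np.full(len(c), np.mean(l)) for c, l in zip(cl_idxs, idx_loss_cl)])
    p = p / p.sum()
    return rng.choice(all_idxs, size=n_select, replace=False, p=p)


cl_idxs = [np.array([0, 4]), np.array([1, 2, 7]), np.array([3, 5, 6])]
idx_loss_cl = [np.array([0.1, 0.1]), np.array([0.8, 1.0, 0.9]), np.array([0.2, 0.3, 0.1])]
sel = select_prob_indices_sketch(3, cl_idxs, idx_loss_cl)
print(sel)  # three distinct indices, most likely drawn from [1, 2, 7]
```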