ProblematicStructures
- class mbgdml.analysis.problematic.ProblematicStructures(models, predict_model, use_ray=False, n_workers=1, ray_address='auto', wkr_chunk_size=100)
Find problematic structures for models in datasets.
Clusters all structures in a dataset with agglomerative and k-means algorithms, based on a structural descriptor and energies.
- Parameters:
models (list of mbgdml.models.Model) – Machine learning model objects that contain all information to make predictions using predict_model.
predict_model (callable) – A function that takes Z, R, entity_ids, nbody_gen, and model and computes energies and forces. This will be turned into a ray remote function if use_ray = True. This can return total properties or all individual \(n\)-body energies and forces.
use_ray (bool, default: False) – Use ray to parallelize computations.
n_workers (int, default: 1) – Total number of workers available for ray. This is ignored if use_ray is False.
ray_address (str, default: "auto") – Ray cluster address to connect to.
wkr_chunk_size (int, default: 100) – Number of \(n\)-body structures to assign to each spawned worker with ray.
- course_n_cl_e
Number of clusters used in the coarse stage for energies. After structures are clustered by geometric descriptor (using course_n_cl_r), each cluster is further refined by energies. There will be a total of course_n_cl_r \(\times\) course_n_cl_e clusters.
- Type:
int, default: 5
- course_n_cl_r
Number of clusters used in the coarse stage for geometries. There will be a total of course_n_cl_r clusters.
- Type:
int, default: 10
- find(dset, n_find, dset_is_train=True, train_idxs=None, write_json=True, save_cl_plot=True, image_format='png', image_dpi=600, save_dir='.')
Find problematic structures in a dataset.
Uses agglomerative and k-means clustering on a dataset. First, the dataset is split into 10 clusters based on atomic pairwise distances. Then each cluster is further split into 5 clusters based on energies.
Energies and forces are predicted, and then problematic structures are taken from clusters with higher-than-average losses. Here, the force MSE is used as the loss function.
Finally, n_find structures are sampled from the 100 clusters based on a weighted cluster error distribution.
- Parameters:
dset (mbgdml.data.DataSet) – Dataset to cluster and analyze errors.
n_find (int) – Number of dataset indices to find.
dset_is_train (bool, default: True) – Whether dset is the training dataset. If so, training indices are dropped from the analyses.
train_idxs (numpy.ndarray, ndim: 1, default: None) – Training indices that will be dropped if dset_is_train is True. These do not need to be provided for GDML models (as they are already stored in the model).
write_json (bool, default: True) – Write JSON file detailing clustering and prediction errors.
save_cl_plot (bool, default: True) – Plot cluster losses and histogram.
image_format (str, default: png) – Format to save the image in.
image_dpi (int, default: 600) – Dots per inch to save the image.
save_dir (str, default: '.') – Directory to save any files.
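The two-stage split described for find can be illustrated with a small, self-contained sketch. It is not the mbgdml implementation: it assumes plain k-means for both stages (the library is documented as combining agglomerative and k-means clustering), and the helpers kmeans_labels and two_stage_clusters are hypothetical names introduced only for this example.

```python
import numpy as np


def kmeans_labels(X, k, n_iter=50, seed=0):
    """Plain Lloyd's k-means (hypothetical helper); returns one label per row."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)
    k = min(k, len(X))  # cannot have more clusters than samples
    centers = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        # Assign each sample to its nearest center.
        dists = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=-1)
        labels = dists.argmin(axis=1)
        # Move each center to the mean of its members (unchanged if empty).
        for j in range(k):
            if np.any(labels == j):
                centers[j] = X[labels == j].mean(axis=0)
    return labels


def two_stage_clusters(descriptors, energies, n_cl_r=10, n_cl_e=5):
    """Cluster by descriptor first, then split each cluster by energy."""
    energies = np.asarray(energies, dtype=float)
    cl_idxs = []
    labels_r = kmeans_labels(descriptors, n_cl_r)
    for label_r in np.unique(labels_r):
        idxs = np.flatnonzero(labels_r == label_r)
        labels_e = kmeans_labels(energies[idxs, None], n_cl_e)
        for label_e in np.unique(labels_e):
            cl_idxs.append(idxs[labels_e == label_e])
    return cl_idxs


# Synthetic data standing in for pairwise-distance descriptors and energies.
rng = np.random.default_rng(1)
descriptors = rng.normal(size=(200, 3))
energies = rng.normal(size=200)
cl_idxs = two_stage_clusters(descriptors, energies)
print(len(cl_idxs), "clusters covering", sum(len(c) for c in cl_idxs), "structures")
```

Every dataset index ends up in exactly one cluster, and the number of clusters is at most n_cl_r × n_cl_e (fewer if a stage produces an empty cluster).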
- get_pd(R)
Computes pairwise distances from atomic positions.
- Parameters:
R (numpy.ndarray, shape: (n_samples, n_atoms, 3)) – Atomic positions.
- Returns:
Pairwise distances of atoms in each structure with shape (n_samples, n_atoms*(n_atoms-1)/2).
- Return type:
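As an illustration of the descriptor computed by get_pd, the following sketch produces an array of the documented shape using NumPy. The function name pairwise_distances is hypothetical, and the ordering of atom pairs (upper triangle, row by row) is an assumption; the ordering used by mbgdml may differ.

```python
import numpy as np


def pairwise_distances(R):
    """Unique interatomic distances for each structure (hypothetical helper).

    R has shape (n_samples, n_atoms, 3); the result has shape
    (n_samples, n_atoms * (n_atoms - 1) / 2).
    """
    R = np.asarray(R, dtype=float)
    n_atoms = R.shape[1]
    # Index pairs (i, j) with i < j; this pair ordering is an assumption.
    i, j = np.triu_indices(n_atoms, k=1)
    return np.linalg.norm(R[:, i, :] - R[:, j, :], axis=-1)


# One structure with three atoms.
R = np.array([[[0.0, 0.0, 0.0], [3.0, 4.0, 0.0], [0.0, 0.0, 1.0]]])
pd_desc = pairwise_distances(R)
print(pd_desc.shape)  # (1, 3)
```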
- kwargs_subplot
pyplot.subplot keyword arguments. Default: {'figsize': (5.5, 3), 'constrained_layout': True}
- Type:
- loss_func
Loss function used to determine problematic structures.
- Type:
callable, default: mbgdml.losses.loss_f_mse
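For orientation, a force mean squared error can be computed per structure as sketched below. This is only an illustration of the quantity: the function force_mse_per_structure and its arguments are hypothetical, and it does not reproduce the signature of mbgdml.losses.loss_f_mse.

```python
import numpy as np


def force_mse_per_structure(F_true, F_pred):
    """Mean squared force error of each structure (hypothetical helper).

    Both inputs have shape (n_samples, n_atoms, 3); the result has shape
    (n_samples,), averaged over atoms and Cartesian components.
    """
    diff = np.asarray(F_pred, dtype=float) - np.asarray(F_true, dtype=float)
    return np.mean(diff**2, axis=(1, 2))


# Two structures with two atoms each; only the second one has errors.
F_true = np.zeros((2, 2, 3))
F_pred = np.zeros((2, 2, 3))
F_pred[1] = 2.0
losses = force_mse_per_structure(F_true, F_pred)
print(losses)
```

Averaging these per-structure values within a cluster gives one loss per cluster, which is the kind of quantity the cluster-level methods on this page operate on.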
- loss_func_kwargs
Any keyword arguments beyond results for the loss function.
- Type:
dict, default: {}
- n_cl_samples(n_sample, cl_weights, cl_pop, cl_losses)
Number of dataset indices to sample from each cluster.
- Parameters:
n_sample (int) – Total number of dataset indices to sample from all clusters.
cl_weights (numpy.ndarray) – Normalized cluster weights.
cl_pop (numpy.ndarray) – Cluster populations.
- Returns:
Number of dataset indices to sample from each cluster.
- Return type:
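One way to turn normalized weights into per-cluster sample counts is sketched below. The allocation rule is an assumption made for illustration (proportional allocation, capped by cluster population, with the remainder given to the heaviest clusters), not the rule mbgdml uses. The function allocate_samples is hypothetical and omits the cl_losses argument, whose role is not documented here.

```python
import numpy as np


def allocate_samples(n_sample, cl_weights, cl_pop):
    """Split n_sample across clusters by weight (hypothetical helper)."""
    cl_weights = np.asarray(cl_weights, dtype=float)
    cl_pop = np.asarray(cl_pop, dtype=int)
    # Never request more samples than there are structures.
    n_sample = min(int(n_sample), int(cl_pop.sum()))
    # Proportional share, rounded down and capped by the cluster population.
    counts = np.minimum(np.floor(n_sample * cl_weights).astype(int), cl_pop)
    # Hand out what is left, heaviest clusters first, where room remains.
    order = np.argsort(-cl_weights)
    while counts.sum() < n_sample:
        for i in order:
            if counts.sum() == n_sample:
                break
            if counts[i] < cl_pop[i]:
                counts[i] += 1
    return counts


counts = allocate_samples(10, [0.5, 0.3, 0.2], [100, 2, 50])
print(counts)
```

In this example the second cluster holds only two structures, so its proportional share of three is capped and the leftover sample goes to the cluster with the largest weight.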
- plot_annotate_cl_idx
Add the cluster index above the cluster loss value.
- Type:
bool, default: False
- plot_cl_losses(cl_pop, cl_losses)
Plot cluster losses and population histogram using matplotlib.
- Parameters:
cl_pop (numpy.ndarray) – Cluster populations (unsorted).
cl_losses (numpy.ndarray) – Cluster losses (unsorted).
- Returns:
A matplotlib figure object.
- Return type:
object
- prob_cl_indices(cl_idxs, cl_losses)
Identify problematic dataset indices.
- Parameters:
cl_idxs (list of numpy.ndarray) – Clustered dataset indices.
cl_losses (numpy.ndarray) – Losses for each cluster in cl_idxs.
- Returns:
Dataset indices from clusters with higher-than-average losses.
- Return type:
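The selection rule stated for prob_cl_indices (keep indices from clusters whose loss exceeds the average cluster loss) can be sketched as follows. The function high_loss_cluster_indices is a hypothetical stand-in, and the use of a strict "greater than the mean" comparison is an assumption.

```python
import numpy as np


def high_loss_cluster_indices(cl_idxs, cl_losses):
    """Dataset indices from clusters with above-average loss (hypothetical)."""
    cl_losses = np.asarray(cl_losses, dtype=float)
    # Positions of clusters whose loss is strictly above the mean loss.
    keep = np.flatnonzero(cl_losses > cl_losses.mean())
    if keep.size == 0:
        return np.array([], dtype=int)
    return np.concatenate([np.asarray(cl_idxs[i]) for i in keep])


cl_idxs = [np.array([0, 1]), np.array([2, 3, 4]), np.array([5])]
cl_losses = np.array([0.1, 0.9, 0.2])  # mean loss is 0.4
prob_idxs = high_loss_cluster_indices(cl_idxs, cl_losses)
print(prob_idxs)
```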
- refine_min_r_ratio
Minimum ratio of structures to number of clusters in the refine stage. Will reduce the minimum loss set point for refinement until refine_n_cl \(\times\) refine_min_r_ratio structures are available.
- Type:
float, default: 2.0
- select_prob_indices(n_select, cl_idxs, idx_loss_cl)
Select n_select problematic dataset indices based on weighted cluster losses and distribution.
- Parameters:
n_select (int) – Number of problematic dataset indices to select.
cl_idxs (list of numpy.ndarray) – Clustered dataset indices.
idx_loss_cl (list of numpy.ndarray) – Clustered individual structure losses. Same shape as cl_idxs.
- Returns:
Problematic dataset indices.
- Return type:
numpy.ndarray, shape: (n_select,)