# Source code for oximachinerunner

# -*- coding: utf-8 -*-
# pylint:disable=wrong-import-position
"""Implements methods to use oximachine as part of a Python package"""
import os
import sys
import warnings

# Installed before the third-party imports below so DeprecationWarnings they emit
# on import are hidden (hence the wrong-import-position disable at the top).
# NOTE: this changes the process-wide warning filters, not just this module's.
warnings.filterwarnings("ignore", category=DeprecationWarning)
from collections import OrderedDict
from typing import Dict, List, Tuple, Union

import joblib
import numpy as np
from ase import Atoms
from oximachine_featurizer import featurize
from oximachine_featurizer.featurize import get_feature_names
from pymatgen import Structure
from pymatgen.io.ase import AseAtomsAdaptor

import oximachinerunner.learnmofox as learnmofox

from ._version import get_versions
from .errors import FeaturizationError, ParsingError, PredictionError
from .utils import download_model, has_metal_sites, model_exists

# NOTE(review): MODEL_CONFIG and MODEL_DEFAULT_MAPPING are used throughout this
# module but are neither defined nor imported in the visible source — confirm
# where they come from.

# Resolve the package version once and drop the helper so it does not linger
# in the package namespace.
__version__ = get_versions()["version"]
del get_versions

# Make the bundled learnmofox module importable under the top-level name
# "learnmofox".
# NOTE(review): presumably required so that model files serialized with
# references to a top-level `learnmofox` module can be unpickled by joblib — confirm.
sys.modules["learnmofox"] = learnmofox

# Absolute directory of this file, with symlinks resolved.
THIS_DIR = os.path.dirname(os.path.realpath(__file__))

__all__ = ["OximachineRunner"]

# Template result for structures without metal sites: the same keys (in the same
# order) as a regular prediction, each mapped to an empty list. Treat as read-only
# and hand out copies rather than this object itself.
EMPTY_PREDICTION = OrderedDict(
    [
        ("metal_indices", []),
        ("metal_symbols", []),
        ("prediction", []),
        ("max_probas", []),
        ("base_predictions", []),
    ]
)

def _load_file(
    path, md5: str, url: str, automatic_download: bool = True
) -> object:
    """Load a joblib-serialized file, downloading it first if necessary.

    Args:
        path (str, PathLike): Path of the model file
        md5 (str): md5 hash of the model file
        url (str): url to download the file from
        automatic_download (bool): If true, it automatically downloads
            the file if it is not available

    Raises:
        FileNotFoundError: If automatic_download is not enabled and the file
            is not on the disk, or if the file still does not pass the
            ``model_exists`` check after it has been downloaded

    Returns:
        model, typically a sklearn estimator object
    """
    if not model_exists(path, md5):
        if not automatic_download:
            raise FileNotFoundError(
                "The model does not exist and you didn't allow automatic download. "
                "Probably you did not download it yet. You can either enable "
                "automatic downloads (automatic_download=True) or use the download "
                "functions from the utils module to download the files"
            )
        download_model(url, path, md5)
        # Verify exactly once after downloading. Re-entering this function after
        # every download would recurse without bound if the downloaded file keeps
        # failing the check (e.g. corrupted download or outdated md5).
        if not model_exists(path, md5):
            raise FileNotFoundError(
                "Could not find a valid file at {} after downloading it from {} "
                "(expected md5: {}).".format(path, url, md5)
            )
    # SECURITY: joblib.load unpickles the file and can execute arbitrary code.
    # It is only reached for files that passed model_exists(path, md5), which
    # presumably verifies the md5 hash — do not point this at untrusted files.
    return joblib.load(path)

def load_model(modelname: str, automatic_download: bool = True) -> Tuple:
    """Orchestrates the loading of the model and the scaler

    Args:
        modelname (str): name of the model. The aliases 'all' and 'mof' are
            resolved to a concrete model name through MODEL_DEFAULT_MAPPING.
        automatic_download (bool): if true,
            it will attempt to automatically download the model

    Raises:
        ValueError: if the modelname is not defined in the configuration

    Returns:
        model, scaler, featurenames
    """
    with warnings.catch_warnings():
        # catch_warnings() by itself only saves/restores the filter state; the
        # filter must be set explicitly so warnings emitted while loading the
        # serialized objects are actually silenced within this scope.
        warnings.simplefilter("ignore")

        # Check if one default model was selected
        if modelname == "all":
            modelname = MODEL_DEFAULT_MAPPING["all"]
        if modelname == "mof":
            modelname = MODEL_DEFAULT_MAPPING["mof"]

        if modelname not in MODEL_CONFIG:
            raise ValueError(
                "A model with name {} does not exist in the configuration.".format(
                    modelname
                )
            )

        config = MODEL_CONFIG[modelname]
        classifier_config = config["classifier"]
        scaler_config = config["scaler"]

        model = _load_file(
            classifier_config["path"],
            classifier_config["md5"],
            classifier_config["url"],
            automatic_download,
        )
        scaler = _load_file(
            scaler_config["path"],
            scaler_config["md5"],
            scaler_config["url"],
            automatic_download,
        )
        featureset = config["features"]

    return model, scaler, featureset

class OximachineRunner:
    """Loads a model and then runs the prediction.

    The model, scaler and feature set are loaded lazily on first access of the
    corresponding property (or explicitly via ``load_model()``).
    """

    def __init__(self, modelname: str = "all", automatic_download: bool = True):
        """
        Args:
            modelname (str, optional): Model to use. Either an alias from the
                default mapping (e.g. 'all') or a model name from the
                configuration. Defaults to 'all'. You can view all available
                models with the .available_models property
            automatic_download (bool, optional): If True, missing model files
                are downloaded when the model is first loaded. Defaults to True.

        Raises:
            ValueError: If modelname is neither an alias nor a configured model.
        """
        self.modelname = modelname
        self._automatic_download = automatic_download
        # Aliases map to a concrete model; concrete model names map to themselves.
        # (Indexing MODEL_DEFAULT_MAPPING directly raised KeyError for every
        # concrete model name listed by .available_models.)
        resolved_name = MODEL_DEFAULT_MAPPING.get(modelname, modelname)
        if resolved_name not in MODEL_CONFIG:
            raise ValueError(
                "A model with name {} does not exist in the configuration.".format(
                    modelname
                )
            )
        self.md5 = MODEL_CONFIG[resolved_name]["classifier"]["md5"]
        # Cache for the lazily loaded "model", "scaler" and "featureset" entries.
        self._model_dict: Dict[str, object] = {}

    @property
    def model(self):
        """Return the model object with `.predict` method"""
        return self._get("model")

    @property
    def scaler(self):
        """Return the scaler object with `.transform` method"""
        return self._get("scaler")

    @property
    def featureset(self) -> List[str]:
        """Return the list of feature names"""
        return self._get("featureset")

    def _get(self, key: str):
        """Return a cached model component, loading the model on first access.

        Raises:
            KeyError: If ``key`` is not one of the components produced by loading.
        """
        if key not in self._model_dict:
            self._load_model()
        return self._model_dict[key]

    def _load_model(self):
        """Load model, scaler and feature set into the internal cache."""
        # Inside a method body `load_model` resolves to the module-level
        # function, not to the method of the same name defined on this class.
        model, scaler, featureset = load_model(
            self.modelname, self._automatic_download
        )
        self._model_dict["model"] = model
        self._model_dict["scaler"] = scaler
        self._model_dict["featureset"] = featureset

    def load_model(self):
        """Load the model and populate the namespace with the model objects."""
        self._load_model()

    @property
    def available_models(self) -> List[str]:
        """List all the available models (configured names and aliases)."""
        return sorted([*MODEL_CONFIG.keys(), *MODEL_DEFAULT_MAPPING.keys()])

    @property
    def feature_names(self) -> List[str]:
        """Get a list of feature names"""
        return get_feature_names(self.featureset)

    @property
    def default_mapping(self) -> Dict[str, str]:
        """Return the default mapping between model name and filename"""
        return MODEL_DEFAULT_MAPPING

    def __repr__(self):
        return "OximachineRunner (version: {}) with model {} (md5: {})".format(
            __version__, self.modelname, self.md5
        )

    def _make_predictions(  # pylint:disable=invalid-name
        self, feature_matrix: np.ndarray
    ) -> Tuple[list, list, list]:
        """Makes predictions for a set of metal sites.

        Applies the scaler to the feature matrix.

        Args:
            feature_matrix (np.ndarray): feature matrix (two dimensional,
                metal sites in rows and features in columns)

        Raises:
            PredictionError: For possible exceptions when running the model.

        Returns:
            Tuple[list, list, list]: predictions (this is the vote of the four
                base estimators), maximum probabilities of the base estimators,
                the prediction of each base estimator
        """
        try:
            feature_matrix_scaled = self.scaler.transform(feature_matrix)
            prediction = self.model.predict(feature_matrix_scaled)
            max_probas = np.max(
                self.model.predict_proba(feature_matrix_scaled), axis=1
            )
            _base_predictions = self.model._predict(  # pylint:disable=protected-access
                feature_matrix_scaled
            )
            # _predict yields class indices; translate them into class labels.
            base_predictions = [
                [self.model.classes[prediction_index] for prediction_index in pred]
                for pred in _base_predictions
            ]
            return list(prediction), list(max_probas), base_predictions
        except Exception as exception:  # pylint:disable=broad-except
            # Deliberate boundary: any model failure is reported as PredictionError.
            raise PredictionError(
                "Could not make predictions for structure."
            ) from exception

    def _featurize_single(self, structure: Structure) -> Tuple[np.array, list, list]:
        """Finds metals in the structure, featurizes the metal sites and collects the features

        Args:
            structure (pymatgen.Structure): Structure to featurize

        Raises:
            FeaturizationError: Raised for all kinds of errors that might occur
                when featurizing the structure

        Returns:
            Tuple[np.array, list, list]: Feature array, metal indices, metal symbols
        """
        try:
            feature_matrix, metal_indices, metals = featurize(
                structure, self.featureset
            )
            return feature_matrix, metal_indices, metals
        except Exception as exception:  # pylint:disable=broad-except
            raise FeaturizationError("Could not featurize structure.") from exception

    def run_oximachine(
        self, structure: Union[str, os.PathLike, Structure, Atoms]
    ) -> OrderedDict:
        """Runs oximachine after attempting to guess what structure is

        Args:
            structure (Union[str, os.PathLike, Structure, Atoms]): can be a
                `pymatgen.Structure`, `ase.Atoms` or a filepath as `str` or
                `os.PathLike`, which we then attempt to parse with pymatgen.

        Raises:
            ParsingError: In case the format of structure is not implemented or
                in case we cannot convert the input into a pymatgen Structure object.
            FeaturizationError: In case the featurization fails.
            PredictionError: In case the prediction fails.

        Returns:
            OrderedDict: with the keys metal_indices, metal_symbols, prediction,
                max_probas, base_predictions
        """
        if isinstance(structure, Structure):
            return self._run_oximachine(structure)

        # Pick the converter for the supported non-pymatgen inputs.
        if isinstance(structure, Atoms):
            to_pymatgen = AseAtomsAdaptor.get_structure
        elif isinstance(structure, (str, os.PathLike)):
            to_pymatgen = Structure.from_file
        else:
            raise ParsingError(
                "Could not recognize structure! I can read Pymatgen structure "
                "objects, ASE atom objects and a filepath in a file format that "
                "can be read by pymatgen"
            )

        try:
            pymatgen_structure = to_pymatgen(structure)
        except Exception as exception:  # pylint:disable=broad-except
            raise ParsingError(
                "Could not convert structure into a pymatgen Structure object."
            ) from exception
        return self._run_oximachine(pymatgen_structure)

    def _run_oximachine(self, structure: Structure) -> OrderedDict:
        """Run the oximachine on one structure

        Args:
            structure (Structure): pymatgen Structure object

        Returns:
            OrderedDict: with the keys metal_indices, metal_symbols, prediction,
                max_probas, base_predictions (all empty when the structure
                contains no metal sites)
        """
        if not has_metal_sites(structure):
            warnings.warn(
                "Oximachine can only predict oxidation states of metals. "
                "This structure contains no metals."
            )
            # Hand out a fresh copy so callers cannot mutate the shared
            # module-level template.
            return OrderedDict(
                (key, list(value)) for key, value in EMPTY_PREDICTION.items()
            )

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            feature_matrix, metal_indices, metal_symbols = self._featurize_single(
                structure
            )
            prediction, max_probas, base_predictions = self._make_predictions(
                feature_matrix
            )

        return OrderedDict(
            [
                ("metal_indices", metal_indices),
                ("metal_symbols", metal_symbols),
                ("prediction", prediction),
                ("max_probas", max_probas),
                ("base_predictions", base_predictions),
            ]
        )