# Source code for finol.evaluation_layer.distiller_selector

from sklearn.linear_model import *
from finol.utils import load_config

# Registry of the scikit-learn linear regressors that can be chosen as a
# distiller. Each entry maps the regressor's class name to the class itself
# (not an instance); the key is taken from the class's ``__name__``, and the
# entries keep the order in which the classes are listed below.
distiller_dict = {
    regressor.__name__: regressor
    for regressor in (
        # ----- Classical linear regressors -----
        LinearRegression,  # Ordinary least squares Linear Regression
        Ridge,  # Linear least squares with l2 regularization
        RidgeCV,  # Ridge regression with built-in cross-validation
        SGDRegressor,  # Linear model fitted by minimizing a regularized empirical loss with SGD

        # ----- Regressors with variable selection -----
        ElasticNet,  # Linear regression with combined L1 and L2 priors as regularizer
        ElasticNetCV,  # Elastic Net model with iterative fitting along a regularization path
        Lars,  # Least Angle Regression model, a.k.a. LAR
        LarsCV,  # Cross-validated Least Angle Regression model
        Lasso,  # Linear Model trained with L1 prior as regularizer (aka the Lasso)
        LassoCV,  # Lasso linear model with iterative fitting along a regularization path
        LassoLars,  # Lasso model fit with Least Angle Regression, a.k.a. Lars
        LassoLarsCV,  # Cross-validated Lasso, using the LARS algorithm
        LassoLarsIC,  # Lasso model fit with Lars using BIC or AIC for model selection
        OrthogonalMatchingPursuit,  # Orthogonal Matching Pursuit model (OMP)
        OrthogonalMatchingPursuitCV,  # Cross-validated Orthogonal Matching Pursuit model (OMP)

        # ----- Bayesian regressors -----
        ARDRegression,  # Bayesian ARD regression
        BayesianRidge,  # Bayesian ridge regression

        # ----- Outlier-robust regressors -----
        HuberRegressor,  # L2-regularized linear regression model that is robust to outliers
        QuantileRegressor,  # Linear regression model that predicts conditional quantiles
        RANSACRegressor,  # RANSAC (RANdom SAmple Consensus) algorithm
        TheilSenRegressor,  # Theil-Sen Estimator: robust multivariate regression model

        # ----- Generalized linear models (GLM) for regression -----
        PoissonRegressor,  # Generalized Linear Model with a Poisson distribution
        TweedieRegressor,  # Generalized Linear Model with a Tweedie distribution
        GammaRegressor,  # Generalized Linear Model with a Gamma distribution

        # ----- Miscellaneous -----
        PassiveAggressiveRegressor,  # Passive Aggressive Regressor
    )
}


class DistillerSelector:
    """
    Class to select distiller for model economic distillation.

    The configuration returned by ``load_config()`` is read once, when the
    selector is constructed, and stored on ``self.config``.
    """

    def __init__(self) -> None:
        self.config = load_config()

    def select_distiller(self):
        """
        Select a distiller based on the configuration provided in the class.

        The name stored under
        ``config["INTERPRETABLE_ANALYSIS_CONFIG"]["DISTILLER_NAME"]`` is looked
        up in ``distiller_dict`` and the matching class is instantiated with
        no arguments.

        :return: A new instance of the configured distiller class.
        :raises ValueError: If the configured name is not a key of
            ``distiller_dict``.
        """
        # Read the configured name once so the lookup and the error message
        # are guaranteed to refer to the same value.
        distiller_name = self.config["INTERPRETABLE_ANALYSIS_CONFIG"]["DISTILLER_NAME"]

        distiller_cls = distiller_dict.get(distiller_name)
        if distiller_cls is None:
            # Report only the registered names: interpolating the mapping
            # itself would dump every class repr into the message.
            raise ValueError(
                f"Invalid distiller: {distiller_name}. "
                f"Supported distillers are: {list(distiller_dict)}"
            )

        return distiller_cls()