import statsmodels.api as sm
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y


class SMWrapper(BaseEstimator):
    """A universal sklearn-style wrapper for statsmodels regressors.

    Parameters
    ----------
    estimator : callable
        Model factory (e.g. ``sm.OLS``). It is called as
        ``estimator(exog=X, endog=y, **init_params)`` inside ``fit``.
    fit_intercept : bool, default=True
        When true, ``sm.add_constant`` is applied to the design matrix
        before fitting and before predicting.
    **init_params
        Extra keyword arguments forwarded to ``estimator`` at fit time.

    Attributes
    ----------
    estimator_ : object
        The model instance built in ``fit``.
    results_ : object
        Whatever ``estimator_.fit(...)`` returned.
    n_exog_ : int
        Number of columns of the design matrix used in ``fit`` (including
        the intercept column when one was added).
    """

    def __init__(self, estimator, fit_intercept=True, **init_params):
        self.estimator = estimator
        self.fit_intercept = fit_intercept
        self.init_params = init_params

    def fit(self, X, y, **fit_params):
        """Build the wrapped model on ``(X, y)`` and fit it.

        Parameters
        ----------
        X : array-like
            Explanatory variables.
        y : array-like
            Response variable.
        **fit_params
            Keyword arguments forwarded to the wrapped model's ``fit``.

        Returns
        -------
        self
        """
        if self.fit_intercept:
            X = sm.add_constant(X)

        # Check that X and y have correct shape
        X, y = check_X_y(X, y)

        # Remember the fitted design width so predict() can tell whether an
        # intercept column was really added here (add_constant skips it when
        # X already contains a constant column).
        self.n_exog_ = X.shape[1]

        # Constructor options were stored as ``init_params``; ``fit_params``
        # belong to the model's own fit() call.
        self.estimator_ = self.estimator(exog=X, endog=y, **self.init_params)
        self.results_ = self.estimator_.fit(**fit_params)
        return self

    def predict(self, X):
        """Return ``results_.predict`` evaluated on ``X``.

        Parameters
        ----------
        X : array-like
            Explanatory variables, without the intercept column that ``fit``
            added (if any).
        """
        # Check that fit has been called
        check_is_fitted(self, 'estimator_')

        # Input validation
        X = check_array(X)

        if self.fit_intercept and X.shape[1] < self.n_exog_:
            # Force the intercept column: with the default
            # has_constant='skip' it would be silently omitted whenever the
            # prediction batch happens to contain a constant column (e.g. a
            # single row), leaving X narrower than the fitted design.
            X = sm.add_constant(X, has_constant='add')
        return self.results_.predict(X)

    def summary(self, **summary_params):
        """Return ``results_.summary(**summary_params)``."""
        return self.results_.summary(**summary_params)


class SMRegressor(RegressorMixin, SMWrapper):
    """``SMWrapper`` combined with sklearn's ``RegressorMixin``."""

    def __init__(self, estimator, fit_intercept=True, **init_params):
        super().__init__(estimator, fit_intercept, **init_params)


class SMClassifier(ClassifierMixin, SMWrapper):
    """``SMWrapper`` combined with sklearn's ``ClassifierMixin``.

    NOTE(review): ``predict`` returns ``results_.predict(X)`` unchanged;
    confirm that the wrapped model yields class labels if label-based
    scoring is expected.
    """

    def __init__(self, estimator, fit_intercept=True, **init_params):
        super().__init__(estimator, fit_intercept, **init_params)
# Model definition
model = SMRegressor(sm.OLS)
# model = SMRegressor(sm.GLS, sigma = sigma)

# Training
model.fit(X, y)

# Summary at a Bonferroni-corrected level: the nominal alpha is divided by
# the number of columns in X before being passed to summary().
alpha = 0.05
sig = alpha / X.shape[1]  # bonferroni-correction
print(model.summary(alpha=sig))

# Inference
model.predict(X)