Skip to content
Snippets Groups Projects
model_implementation.py 760 B
Newer Older
Jan Klimaschewski's avatar
Jan Klimaschewski committed
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
import torch

class DummyRegressor(BaseEstimator):
    """Baseline regressor that always predicts a single constant.

    The constant is the mean of the first ``sample_count`` training targets.

    Parameters
    ----------
    sample_count : int, default=20
        Number of leading entries of ``y`` averaged in ``fit``. It is used as
        a slice bound (``y[:sample_count]``), so ``None`` averages all targets.

    Attributes
    ----------
    estimate_ : torch.Tensor
        0-d float tensor holding the fitted constant.
    n_features_in_ : int
        Number of feature columns seen during ``fit``.
    """

    def __init__(self, sample_count=20):
        self.sample_count = sample_count

    def fit(self, X, y):
        """Fit the constant estimate from the leading training targets.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training inputs. Validated for shape and consistency with ``y``,
            otherwise unused.
        y : array-like of shape (n_samples,)
            Training targets.

        Returns
        -------
        self : DummyRegressor
            The fitted estimator.
        """
        # y_numeric=True casts object-dtype targets to float64, as
        # scikit-learn recommends for regressors; numeric y is unchanged.
        X, y = check_X_y(X, y, y_numeric=True)
        self.n_features_in_ = X.shape[1]

        # Slice before converting so only the averaged targets are copied
        # into a tensor. X needs no tensor copy: it is only validated.
        head = torch.tensor(y[:self.sample_count], dtype=torch.float)
        # NOTE: an empty slice (e.g. sample_count=0) makes the mean NaN.
        self.estimate_ = head.mean()

        print('Finished Training')
        return self

    def predict(self, X):
        """Predict the fitted constant for every row of ``X``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Inputs. Only the number of rows is used.

        Returns
        -------
        torch.Tensor of shape (n_samples,)
            Every entry equals ``estimate_``. Note this is a torch tensor,
            not a NumPy array as most scikit-learn estimators return.

        Raises
        ------
        sklearn.exceptions.NotFittedError
            If ``fit`` has not been called.
        """
        check_is_fitted(self)
        X = check_array(X)
        with torch.no_grad():
            y_pred = self.estimate_ * torch.ones(X.shape[0])
        return y_pred