Newer
Older
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
import torch
class DummyRegressor(BaseEstimator):
    """Baseline regressor that predicts one constant: the mean of the first targets.

    Parameters
    ----------
    sample_count : int or None, default=20
        Number of leading entries of ``y`` (in the order given to ``fit``)
        whose mean becomes the constant prediction. If it exceeds the number
        of samples, all samples are used; ``None`` also means "use all".

    Attributes
    ----------
    estimate_ : torch.Tensor
        0-d float32 tensor holding the constant prediction.
    n_features_in_ : int
        Number of feature columns seen during ``fit``.
    """

    def __init__(self, sample_count=20):
        # Only store the hyper-parameter here; scikit-learn's convention is
        # to validate in fit() so that set_params/clone behave predictably.
        self.sample_count = sample_count

    def fit(self, X, y):
        """Compute the constant prediction from the training targets.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training features. They are validated but otherwise ignored.
        y : array-like of shape (n_samples,)
            Training targets.

        Returns
        -------
        self : DummyRegressor
            The fitted estimator.

        Raises
        ------
        ValueError
            If ``sample_count`` is not None and is less than 1. Slicing with
            0 would produce an empty tensor whose mean is NaN, and a negative
            value would drop trailing samples instead of selecting leading
            ones. Also raised by ``check_X_y`` for invalid inputs.
        """
        if self.sample_count is not None and self.sample_count < 1:
            raise ValueError(
                "sample_count must be a positive integer or None, "
                f"got {self.sample_count!r}"
            )
        X, y = check_X_y(X, y)
        # X is only needed for validation and its width; the prediction
        # does not depend on the features, so it is not converted to torch.
        self.n_features_in_ = X.shape[1]
        y = torch.tensor(y, dtype=torch.float)
        # A slice stop of None (or one past the end) takes every sample.
        self.estimate_ = y[:self.sample_count].mean()
        # Kept for backward compatibility with callers that expect this output.
        print('Finished Training')
        return self

    def predict(self, X):
        """Return the fitted constant once per row of ``X``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict. They are validated, but only the row count
            is used.

        Returns
        -------
        torch.Tensor
            1-d tensor of length ``n_samples`` filled with ``estimate_``.

        Raises
        ------
        sklearn.exceptions.NotFittedError
            If called before ``fit``. This is raised by ``check_is_fitted``.
        """
        check_is_fitted(self)
        X = check_array(X)
        with torch.no_grad():
            y_pred = self.estimate_ * torch.ones(X.shape[0])
        return y_pred