Skip to content
Snippets Groups Projects
Commit 4a0eee00 authored by Benjamin Murauer's avatar Benjamin Murauer
Browse files

enable auto-detection of cuda

parent d8734538
No related branches found
No related tags found
No related merge requests found
......@@ -26,10 +26,13 @@ class TorchClassifier(ClassifierMixin, BaseEstimator):
batch_size: int = 64,
max_epochs: int = 5,
learn_rate: float = 1e-3,
device: str = "cuda",
device: str = None,
model_kwargs: Dict[str, Any] = None,
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
):
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.module = module
self.device = device
self.batch_size = batch_size
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment