Skip to content
Snippets Groups Projects
Commit cd9a076c authored by Benjamin Murauer's avatar Benjamin Murauer
Browse files

switched to CPU for torch tests

parent a5276629
No related branches found
No related tags found
No related merge requests found
Pipeline #52290 passed
......@@ -2,16 +2,11 @@ image: python:3.7
tests:
before_script:
- date
- curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python
- source $HOME/.poetry/env
- poetry config virtualenvs.create false
- date
- poetry install
- date
- python -m nltk.downloader punkt
- date
- python -c 'import stanza; stanza.download("en")'
- date
script:
- pytest tests
......@@ -10,7 +10,9 @@ from tuhlbox.torch_classifier import TorchClassifier
from tuhlbox.torch_cnn import CharCNN
from tuhlbox.torch_lstm import RNNClassifier
x, y = fetch_20newsgroups(return_X_y=True)
x, y = fetch_20newsgroups(
return_X_y=True, categories=["alt.atheism", "talk.religion.misc"]
)
x_train, x_test, y_train, y_test = train_test_split(x, y)
......@@ -25,7 +27,7 @@ def test_cnn() -> None:
Padder2d(pad_value=VOCAB_SIZE, max_len=MAX_SEQ_LEN, dtype=int),
TorchClassifier(
module=CharCNN,
device="cuda",
device="cpu", # the gitlab CI does not have cuda
batch_size=54,
max_epochs=5,
learn_rate=0.01,
......@@ -39,8 +41,7 @@ def test_cnn() -> None:
)
pipe.fit(x_train, y_train)
predictions = pipe.predict(x_test)
print(accuracy_score(predictions, y_test))
pipe.predict(x_test)
def test_lstm() -> None:
......@@ -49,8 +50,8 @@ def test_lstm() -> None:
Padder2d(pad_value=VOCAB_SIZE, max_len=MAX_SEQ_LEN, dtype=int),
TorchClassifier(
module=RNNClassifier,
device="cuda",
batch_size=54,
device="cpu", # the gitlab CI does not have cuda
batch_size=4,
max_epochs=5,
learn_rate=0.01,
optimizer=torch.optim.Adam,
......@@ -58,5 +59,4 @@ def test_lstm() -> None:
)
pipe.fit(x_train, y_train)
predictions = pipe.predict(x_test)
print(accuracy_score(predictions, y_test))
pipe.predict(x_test)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment