Skip to content
Snippets Groups Projects
Commit cec20282 authored by Benjamin Murauer
Browse files

added counter contributor, removed cli (never worked)

parent 978fa7bd
No related branches found
No related tags found
No related merge requests found
...@@ -20,11 +20,6 @@ from tqdm import tqdm # type: ignore ...@@ -20,11 +20,6 @@ from tqdm import tqdm # type: ignore
from transformers import MarianMTModel, MarianTokenizer # type: ignore from transformers import MarianMTModel, MarianTokenizer # type: ignore
import stanza # type: ignore import stanza # type: ignore
from tuhlbox.contributors import (
StanzaContributor,
LanguageDetectionContributor,
LinguisticStatsContributor,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -408,20 +403,3 @@ def translate( ...@@ -408,20 +403,3 @@ def translate(
df2 = pd.DataFrame.from_records(new_rows) df2 = pd.DataFrame.from_records(new_rows)
df3 = pd.concat([df, df2]) df3 = pd.concat([df, df2])
df3.to_csv(os.path.join(input_dir, "dataset.csv"), index=False) df3.to_csv(os.path.join(input_dir, "dataset.csv"), index=False)
@click.command(help="run a contributor on a dataset")
@click.argument("csv")
@click.argument("contributor")
def run_contributor(csv, contributor):
    """Run the contributor registered under ``contributor`` on ``csv``.

    Args:
        csv: Command-line argument forwarded unchanged to the contributor's
            ``contribute`` call.
        contributor: Key selecting the contributor; one of ``stanza``,
            ``detect_language`` or ``linguistic_stats``.

    Raises:
        click.BadParameter: If ``contributor`` is not a registered key.
    """
    available_contributors = dict(
        stanza=StanzaContributor,
        detect_language=LanguageDetectionContributor,
        linguistic_stats=LinguisticStatsContributor,
    )
    if contributor not in available_contributors:
        # The previous ``raise Error(...)`` referenced an undefined name and
        # therefore surfaced as a NameError; use click's own exception so the
        # CLI reports the problem to the user.
        known_names = ", ".join(available_contributors)
        raise click.BadParameter(
            f'contributor "{contributor}" not found. '
            f"available contributors: {known_names}"
        )
    # NOTE(review): ``contribute`` is invoked on the class object, not on an
    # instance -- confirm against the Contributor API whether the class must
    # be instantiated first.
    available_contributors[contributor].contribute(csv)
import json import json
import os import os
import pickle
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import Counter, defaultdict
from typing import List, Any, Optional, Dict from typing import Any, Dict, List
import langdetect import langdetect
import pandas as pd import pandas as pd
import stanza from sklearn.pipeline import make_pipeline
import pickle
from tqdm import tqdm from tqdm import tqdm
from treegrams.transformers import TreeGramExtractor
import stanza
from tuhlbox import logger from tuhlbox import logger
from tuhlbox.stanza import StanzaToNltkTreesTransformer
class Contributor(ABC): class Contributor(ABC):
...@@ -54,7 +57,7 @@ class Contributor(ABC): ...@@ -54,7 +57,7 @@ class Contributor(ABC):
return text_fh.read() return text_fh.read()
if mode == "pickle": if mode == "pickle":
with open(full_path, "rb") as pickle_fh: with open(full_path, "rb") as pickle_fh:
pickle.load(pickle_fh) return pickle.load(pickle_fh)
def write_subdir_file(self, path: str, content: Any, mode: str = "text") -> None: def write_subdir_file(self, path: str, content: Any, mode: str = "text") -> None:
full_path = os.path.join(self.base_dir, path) full_path = os.path.join(self.base_dir, path)
...@@ -166,6 +169,40 @@ class LanguageDetectionContributor(RowWiseContributor): ...@@ -166,6 +169,40 @@ class LanguageDetectionContributor(RowWiseContributor):
return row return row
class CounterContributor(RowWiseContributor):
    """Row-wise contributor that stores one pickled ``Counter`` per row.

    For every row, the pickled object referenced by the ``stanza`` column is
    passed through a pipeline of ``StanzaToNltkTreesTransformer`` followed by
    ``TreeGramExtractor``; the items of the first pipeline result are counted
    and the counter is pickled below ``directory_name``. The path of that
    file (relative to ``self.base_dir``) is written to ``column_name``.
    """

    def __init__(
        self,
        column_name: str = "dtgram_counter",
        directory_name: str = "dtgram_counters",
        **kwargs: Any,
    ) -> None:
        """Set up the output locations and the extraction pipeline.

        Args:
            column_name: Column that receives the relative path of the
                pickled counter.
            directory_name: Sub-directory of ``self.base_dir`` in which the
                counter files are stored.
            **kwargs: Forwarded unchanged to ``RowWiseContributor.__init__``.
        """
        super().__init__(column_name=column_name, required_columns=["stanza"], **kwargs)
        self.column_name = column_name
        self.directory_name = directory_name
        self.dtgram_pipeline = make_pipeline(
            StanzaToNltkTreesTransformer(),
            TreeGramExtractor(),
        )

    def calculate_row(self, row: pd.Series) -> pd.Series:
        """Create (or reuse) the counter file for ``row`` and record its path.

        Args:
            row: Dataset row; ``row["stanza"]`` must hold the path (relative
                to ``self.base_dir``) of a pickled document.

        Returns:
            The same row with ``row[self.column_name]`` set to the relative
            path of the counter file.
        """
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(os.path.join(self.base_dir, self.directory_name), exist_ok=True)

        stem = os.path.splitext(os.path.basename(row["stanza"]))[0]
        out_filepath = os.path.join(self.directory_name, stem + ".pckl")
        full_out_filepath = os.path.join(self.base_dir, out_filepath)

        if self.overwrite or not os.path.isfile(full_out_filepath):
            # Only unpickle the document when the counter actually has to be
            # (re)computed; rows with an existing counter file skip the load.
            stanza_document = self.read_subdir_file(row["stanza"], mode="pickle")
            dt_grams = self.dtgram_pipeline.transform([stanza_document])[0]
            counter: Counter = Counter(dt_grams)
            self.write_subdir_file(out_filepath, counter, mode="pickle")

        row[self.column_name] = out_filepath
        return row
class ConstantContributor(Contributor): class ConstantContributor(Contributor):
def __init__(self, column_name: str, value: Any) -> None: def __init__(self, column_name: str, value: Any) -> None:
super().__init__(column_name=column_name, required_columns=[]) super().__init__(column_name=column_name, required_columns=[])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment