Commit cec20282 authored by Benjamin Murauer

added counter contributor, removed cli (never worked)

parent 978fa7bd

@@ -20,11 +20,6 @@ from tqdm import tqdm  # type: ignore
 from transformers import MarianMTModel, MarianTokenizer  # type: ignore
 import stanza  # type: ignore
-from tuhlbox.contributors import (
-    StanzaContributor,
-    LanguageDetectionContributor,
-    LinguisticStatsContributor,
-)
 
 logger = logging.getLogger(__name__)
@@ -408,20 +403,3 @@ def translate(
     df2 = pd.DataFrame.from_records(new_rows)
     df3 = pd.concat([df, df2])
     df3.to_csv(os.path.join(input_dir, "dataset.csv"), index=False)
-
-
-@click.command(help="run a contributor on a dataset")
-@click.argument("csv")
-@click.argument("contributor")
-def run_contributor(csv, contributor):
-    available_contributors = dict(
-        stanza=StanzaContributor,
-        detect_language=LanguageDetectionContributor,
-        linguistic_stats=LinguisticStatsContributor,
-    )
-    if contributor not in available_contributors.keys():
-        raise Error(
-            f'contributor "{contributor}" not found.'
-            f"available contributors: {available_contributors.keys()}"
-        )
-    available_contributors[contributor].contribute(csv)
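
With run_contributor removed, a contributor now has to be invoked from Python directly. A minimal sketch of that, assuming contribute() accepts the path to the dataset CSV the way the removed command passed it; note the removed code called contribute(csv) on the class itself and raised an undefined Error, which is presumably part of why it never worked, so the instantiation shown here is an assumption as well.

# Sketch only: stands in for the removed run_contributor command.
# Assumes contribute() takes the dataset CSV path; whether the contributor
# needs constructor arguments is not visible in this diff.
from tuhlbox.contributors import LanguageDetectionContributor

contributor = LanguageDetectionContributor()
contributor.contribute("path/to/dataset.csv")
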
 import json
 import os
+import pickle
 from abc import ABC, abstractmethod
-from collections import defaultdict
-from typing import List, Any, Optional, Dict
+from collections import Counter, defaultdict
+from typing import Any, Dict, List
 
 import langdetect
 import pandas as pd
-import stanza
-import pickle
+from sklearn.pipeline import make_pipeline
 from tqdm import tqdm
+from treegrams.transformers import TreeGramExtractor
+import stanza
 
 from tuhlbox import logger
+from tuhlbox.stanza import StanzaToNltkTreesTransformer
 
 
 class Contributor(ABC):
@@ -54,7 +57,7 @@ class Contributor(ABC):
                 return text_fh.read()
         if mode == "pickle":
             with open(full_path, "rb") as pickle_fh:
-                pickle.load(pickle_fh)
+                return pickle.load(pickle_fh)
 
     def write_subdir_file(self, path: str, content: Any, mode: str = "text") -> None:
         full_path = os.path.join(self.base_dir, path)
@@ -166,6 +169,40 @@ class LanguageDetectionContributor(RowWiseContributor):
         return row
 
 
+class CounterContributor(RowWiseContributor):
+    def __init__(
+        self,
+        column_name: str = "dtgram_counter",
+        directory_name: str = "dtgram_counters",
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(column_name=column_name, required_columns=["stanza"], **kwargs)
+        self.column_name = column_name
+        self.directory_name = directory_name
+        self.dtgram_pipeline = make_pipeline(
+            StanzaToNltkTreesTransformer(),
+            TreeGramExtractor(),
+        )
+
+    def calculate_row(self, row: pd.Series) -> pd.Series:
+        stanza_document = self.read_subdir_file(row["stanza"], mode="pickle")
+        counter_dir = os.path.join(self.base_dir, self.directory_name)
+        if not os.path.isdir(counter_dir):
+            os.makedirs(counter_dir)
+        out_filename = os.path.splitext(os.path.basename(row["stanza"]))[0] + ".pckl"
+        out_filepath = os.path.join(self.directory_name, out_filename)
+        full_out_filepath = os.path.join(self.base_dir, out_filepath)
+        if not os.path.isfile(full_out_filepath) or self.overwrite:
+            dt_grams = self.dtgram_pipeline.transform([stanza_document])[0]
+            counter: Counter = Counter(dt_grams)
+            self.write_subdir_file(out_filepath, counter, mode="pickle")
+        row[self.column_name] = out_filepath
+        return row
+
+
 class ConstantContributor(Contributor):
     def __init__(self, column_name: str, value: Any) -> None:
         super().__init__(column_name=column_name, required_columns=[])
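
The new CounterContributor reads the pickled stanza document referenced by a row's stanza column, converts it to NLTK-style trees with StanzaToNltkTreesTransformer, extracts tree-grams with TreeGramExtractor, and pickles a collections.Counter of those tree-grams under <base_dir>/dtgram_counters/, recording the relative path in the dtgram_counter column. A minimal sketch of reading that output back, assuming a dataset.csv in the contributor's base directory (the paths are illustrative):

import os
import pickle

import pandas as pd

# Illustrative locations; base_dir is whatever directory the contributor was run against.
base_dir = "data/my_corpus"
df = pd.read_csv(os.path.join(base_dir, "dataset.csv"))

# Each dtgram_counter value is a path relative to base_dir pointing at a
# pickled collections.Counter of tree-grams for that document.
with open(os.path.join(base_dir, df.loc[0, "dtgram_counter"]), "rb") as fh:
    counter = pickle.load(fh)

print(counter.most_common(10))  # ten most frequent tree-grams in the document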