Skip to content
Snippets Groups Projects
Commit 8e75297c authored by Benjamin Murauer's avatar Benjamin Murauer
Browse files

sped up string kernel implementation

parent 7dfdec3d
No related branches found
No related tags found
No related merge requests found
Pipeline #51128 failed
from typing import Callable
from sklearn.datasets import fetch_20newsgroups
from time import perf_counter
from tuhlbox.stringkernels import intersection_kernel, spectrum_kernel, presence_kernel, \
legacy_presence_kernel, legacy_spectrum_kernel, legacy_intersection_kernel
data = fetch_20newsgroups()['data'][:100]
def benchmark(kernel_method: Callable) -> float:
    """Return the wall-clock seconds taken to apply *kernel_method* to the data.

    The factory is called with n-gram bounds (1, 4); the timing covers both
    kernel construction and the full (data x data) matrix computation.
    """
    started = perf_counter()
    kernel = kernel_method(1, 4)
    kernel(data, data)
    return perf_counter() - started
# Each fast implementation is listed directly before its legacy counterpart so
# the printed timings can be compared pairwise, in order.
kernels = [
    intersection_kernel,
    legacy_intersection_kernel,
    presence_kernel,
    legacy_presence_kernel,
    spectrum_kernel,
    legacy_spectrum_kernel,
]
for kernel_factory in kernels:
    seconds = benchmark(kernel_factory)
    print(f'{kernel_factory.__name__}: {seconds}')
......@@ -2,7 +2,8 @@
import numpy as np
from numpy.testing import assert_array_equal
from tuhlbox.stringkernels import (intersection_kernel, presence_kernel,
spectrum_kernel)
spectrum_kernel, legacy_intersection_kernel,
legacy_spectrum_kernel, legacy_presence_kernel)
docs = [
"I like this old movie. The movie is very nice.",
......@@ -26,6 +27,18 @@ def test_intersection_kernel() -> None:
assert_array_equal(expected, intersection_kernel(ngram_min, ngram_max)(docs, docs))
def test_legacy_intersection_kernel() -> None:
    """Test intersection kernel by comparing with original code."""
    # Reference matrix produced by the original Java implementation:
    # java ComputeStringKernel intersection 1 4 sentences.txt <outfile>
    expected = np.array(
        [[178, 95, 66, 49], [95, 254, 72, 72], [66, 72, 278, 112], [49, 72, 112, 334]],
        dtype=int,
    )
    kernel = legacy_intersection_kernel(ngram_min, ngram_max)
    assert_array_equal(expected, kernel(docs, docs))
def test_presence_kernel() -> None:
"""Test presence kernel by comparing with original code."""
# obtained from:
......@@ -38,6 +51,19 @@ def test_presence_kernel() -> None:
assert_array_equal(expected, presence_kernel(ngram_min, ngram_max)(docs, docs))
def test_legacy_presence_kernel() -> None:
    """Test presence kernel by comparing with original code."""
    # Reference matrix produced by the original Java implementation:
    # java ComputeStringKernel presence 1 4 sentences.txt <outfile>
    expected = np.array(
        [[128, 67, 42, 29], [67, 197, 38, 42], [42, 38, 209, 64], [29, 42, 64, 235]],
        dtype=int,
    )
    kernel = legacy_presence_kernel(ngram_min, ngram_max)
    assert_array_equal(expected, kernel(docs, docs))
def test_spectrum_kernel() -> None:
"""Test spectrum kernel by comparing with original code."""
# obtained from:
......@@ -52,3 +78,19 @@ def test_spectrum_kernel() -> None:
dtype=int,
)
assert_array_equal(expected, spectrum_kernel(ngram_min, ngram_max)(docs, docs))
def test_legacy_spectrum_kernel() -> None:
    """Test spectrum kernel by comparing with original code."""
    # Reference matrix produced by the original Java implementation:
    # java ComputeStringKernel spectrum 1 4 sentences.txt <outfile>
    expected = np.array(
        [
            [390, 335, 300, 313],
            [335, 598, 393, 458],
            [300, 393, 680, 585],
            [313, 458, 585, 1006],
        ],
        dtype=int,
    )
    kernel = legacy_spectrum_kernel(ngram_min, ngram_max)
    assert_array_equal(expected, kernel(docs, docs))
......@@ -3,7 +3,7 @@ from __future__ import annotations
import logging
from collections import defaultdict
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Tuple, Generator
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
......@@ -11,96 +11,135 @@ from sklearn.base import BaseEstimator, TransformerMixin
logger = logging.getLogger(__name__)
def iterate_ngrams(
    string: str, ngram_min: int, ngram_max: int
) -> Generator[str, None, None]:
    """Yield every lower-cased character n-gram of *string*.

    N-grams are produced size by size: first all n-grams of length
    ngram_min, then ngram_min + 1, and so on up to ngram_max (inclusive).
    Each n-gram is lower-cased individually; the input is assumed to
    support slicing and ``.lower()`` (i.e. to be a str).
    """
    for size in range(ngram_min, ngram_max + 1):
        last_start = len(string) - size
        for start in range(last_start + 1):
            yield string[start : start + size].lower()
def get_all_ngram_counts(string: str, ngram_min: int, ngram_max: int) -> Dict[str, int]:
    """Count occurrences of each lower-cased n-gram of *string*.

    Sizes ngram_min..ngram_max (inclusive) are considered; the returned
    mapping is a defaultdict, so absent n-grams read as 0.
    """
    counts: Dict[str, int] = defaultdict(int)
    for gram in iterate_ngrams(string, ngram_min, ngram_max):
        counts[gram] += 1
    return counts
def presence_kernel(ngram_min: int, ngram_max: int) -> Callable:
    """
    Calculate the presence kernel, Ionescu & Popescu 2017.

    The result is a matrix where each entry [i, j] is the number of distinct
    n-grams (sizes ngram_min..ngram_max) that documents x[i] and y[j] share.

    Args:
        ngram_min: smallest n-gram size considered (inclusive).
        ngram_max: largest n-gram size considered (inclusive).

    Returns:
        A kernel function mapping two document sequences to an int matrix of
        shape (len(x), len(y)).
    """

    def internal_presence_kernel(x: np.array, y: np.array) -> np.array:
        # Precompute the n-gram set of every document once; previously the
        # y-side sets were rebuilt inside the inner loop for every row of x.
        x_keys = [set(get_all_ngram_counts(d, ngram_min, ngram_max)) for d in x]
        y_keys = [set(get_all_ngram_counts(d, ngram_min, ngram_max)) for d in y]
        result = np.zeros((len(x), len(y)), dtype=int)
        for i, xk in enumerate(x_keys):
            for j, yk in enumerate(y_keys):
                result[i, j] = len(xk & yk)
        return result

    return internal_presence_kernel
def legacy_presence_kernel(ngram_min: int, ngram_max: int) -> Callable:
    """Old implementation of the presence kernel. Slower; kept as a reference.

    Entry [i, j] is the number of distinct shared n-grams between x[i] and
    y[j]. Unlike presence_kernel, the y-side n-gram sets are recomputed for
    every row of x, which is what makes this version slow.
    """

    def internal_legacy_presence_kernel(x: np.array, y: np.array) -> np.array:
        result = np.zeros((len(x), len(y)), dtype=int)
        for i, string in enumerate(x):
            if type(string) == str:
                # iterate_ngrams lowers each n-gram again; this pre-lowering is
                # redundant for str input but kept for fidelity to the old code.
                string = string.lower()
            ngrams1 = set(iterate_ngrams(string, ngram_min, ngram_max))
            for j, counterstring in enumerate(y):
                if type(counterstring) == str:
                    counterstring = counterstring.lower()
                ngrams2 = set(iterate_ngrams(counterstring, ngram_min, ngram_max))
                result[i, j] = len(ngrams1.intersection(ngrams2))
        return result

    return internal_legacy_presence_kernel
def spectrum_kernel(ngram_min: int, ngram_max: int) -> Callable:
    """
    Calculate the spectrum kernel, Ionescu & Popescu 2017.

    Entry [i, j] is the sum, over every n-gram g occurring in both x[i] and
    y[j], of count(g in x[i]) * count(g in y[j]).

    Args:
        ngram_min: smallest n-gram size considered (inclusive).
        ngram_max: largest n-gram size considered (inclusive).

    Returns:
        A kernel function mapping two document sequences to an int matrix of
        shape (len(x), len(y)).
    """

    def internal_spectrum_kernel(x: np.array, y: np.array) -> np.array:
        x_counts = [get_all_ngram_counts(d, ngram_min, ngram_max) for d in x]
        y_counts = [get_all_ngram_counts(d, ngram_min, ngram_max) for d in y]
        result = np.zeros((len(x), len(y)), dtype=int)
        for i, xc in enumerate(x_counts):
            # dict key views support set intersection directly; no need to
            # rebuild a set() on every inner iteration as before.
            xkeys = xc.keys()
            for j, yc in enumerate(y_counts):
                common = xkeys & yc.keys()
                result[i, j] = sum(xc[gram] * yc[gram] for gram in common)
        return result

    return internal_spectrum_kernel
def legacy_spectrum_kernel(ngram_min: int, ngram_max: int) -> Callable:
    """Old implementation of the spectrum kernel. Slower; kept as a reference.

    Entry [i, j] accumulates, for each n-gram occurrence in y[j], the count of
    that n-gram in x[i] — equivalent to the sum of count products over shared
    n-grams.
    """

    def internal_legacy_spectrum_kernel(x: np.array, y: np.array) -> np.array:
        result = np.zeros((len(x), len(y)), dtype=int)
        for i, string in enumerate(x):
            if type(string) == str:
                # Redundant with the lowering done inside iterate_ngrams, but
                # preserved from the original code.
                string = string.lower()
            ngrams: Dict[str, int] = defaultdict(int)
            for ngram in iterate_ngrams(string, ngram_min, ngram_max):
                ngrams[ngram] += 1
            for j, counterstring in enumerate(y):
                if type(counterstring) == str:
                    counterstring = counterstring.lower()
                # defaultdict lookup returns 0 for n-grams absent from x[i]
                # (and inserts the key, as the original code did).
                for ngram in iterate_ngrams(counterstring, ngram_min, ngram_max):
                    result[i, j] += ngrams[ngram]
        return result

    return internal_legacy_spectrum_kernel
def intersection_kernel(ngram_min: int, ngram_max: int) -> Callable:
    """
    Calculate the intersection kernel, Ionescu & Popescu 2017.

    Entry [i, j] is the sum, over every n-gram g occurring in both x[i] and
    y[j], of min(count(g in x[i]), count(g in y[j])).

    Args:
        ngram_min: smallest n-gram size considered (inclusive).
        ngram_max: largest n-gram size considered (inclusive).

    Returns:
        A kernel function mapping two document sequences to an int matrix of
        shape (len(x), len(y)).
    """

    def internal_intersection_kernel(x: np.array, y: np.array) -> np.array:
        x_counts = [get_all_ngram_counts(d, ngram_min, ngram_max) for d in x]
        y_counts = [get_all_ngram_counts(d, ngram_min, ngram_max) for d in y]
        result = np.zeros((len(x), len(y)), dtype=int)
        for i, xc in enumerate(x_counts):
            # dict key views support set intersection directly; no need to
            # rebuild a set() on every inner iteration as before.
            xkeys = xc.keys()
            for j, yc in enumerate(y_counts):
                common = xkeys & yc.keys()
                result[i, j] = sum(min(xc[gram], yc[gram]) for gram in common)
        return result

    return internal_intersection_kernel
def legacy_intersection_kernel(ngram_min: int, ngram_max: int) -> Callable:
    """Old implementation of the intersection kernel. Slower; kept as a reference.

    Entry [i, j] counts each n-gram occurrence in y[j] that can be matched to a
    remaining occurrence in x[i] — equivalent to summing min(count_x, count_y)
    over shared n-grams.
    """

    def internal_legacy_intersection_kernel(x: np.array, y: np.array) -> np.array:
        result = np.zeros((len(x), len(y)), dtype=int)
        for i, string in enumerate(x):
            if type(string) == str:
                # Redundant with the lowering done inside iterate_ngrams, but
                # preserved from the original code.
                string = string.lower()
            ngrams: Dict[str, int] = defaultdict(int)
            for ngram in iterate_ngrams(string, ngram_min, ngram_max):
                ngrams[ngram] += 1
            for j, counterstring in enumerate(y):
                if type(counterstring) == str:
                    counterstring = counterstring.lower()
                # Work on a copy so matches consumed for this y-document do not
                # deplete the counts available to the next one.
                ngrams2 = dict(ngrams)
                for ngram in iterate_ngrams(counterstring, ngram_min, ngram_max):
                    if ngram in ngrams2 and ngrams2[ngram] > 0:
                        result[i, j] += 1
                        ngrams2[ngram] -= 1
        return result

    return internal_legacy_intersection_kernel
kernel_map = {
......@@ -112,6 +151,8 @@ kernel_map = {
class StringKernelTransformer(BaseEstimator, TransformerMixin):
"""
DEPRECATED: you should probably use the Scikit SVC instead.
Converts (string) documents to a similarity matrix (kernel).
Input (fit): List of m strings
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment