diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 70e6e15c16c26da5f6d89de982ee8ec89f548df3..403fff14b4069f035ebfc4b9bbb6da64a8095a90 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -20,11 +20,13 @@ before_script:
   - virtualenv venv
   - source venv/bin/activate
   - pip install setuptools wheel
+  - pip install torch torchvision torchaudio
 
 test:
   stage: test
   script:
     - pip install .
+    - pip install soundfile
     - python -m unittest discover tests
 
 deploy to package registry:
diff --git a/setup.py b/setup.py
index 4fa1c8b4e061154927fb145d8c511e51acb97686..4032ee93be67afabc282a98e60998e92d46cdbbd 100644
--- a/setup.py
+++ b/setup.py
@@ -9,6 +9,7 @@ setup(
     packages=find_packages(),
     install_requires=[
         'levenshtein',
+        'phonemizer',
         'transformers'
     ],
     python_requires='>=3.8'
diff --git a/src/lps/lps.py b/src/lps/lps.py
index 6ffbcd60d66eb559d7bab673c07bd21c90591a58..192bf0f2df27e6f67981fda5ae360e1972bffa42 100644
--- a/src/lps/lps.py
+++ b/src/lps/lps.py
@@ -22,8 +22,12 @@ class PhonemePredictor(Module):
         return self.processor.batch_decode(predicted_ids)
 
 
-def lps(self, sample: np.ndarray, reference=None, sampling_frequency=16000) -> float:
-    sample_phonems = self.phoneme_predictor.forward(sample)[0].replace(" ", "")
-    ref_phonems = self.phoneme_predictor.forward(reference)[0].replace(" ", "")
-    lev_distance = distance(sample_phonems, ref_phonems)
-    return 1 - lev_distance / len(ref_phonems)
+class LevenshteinPhonemeSimilarity:
+    def __init__(self):
+        self.phoneme_predictor = PhonemePredictor()
+
+    def __call__(self, sample: np.ndarray, reference: np.ndarray) -> float:
+        sample_phonems = self.phoneme_predictor.forward(sample)[0].replace(" ", "")
+        ref_phonems = self.phoneme_predictor.forward(reference)[0].replace(" ", "")
+        lev_distance = distance(sample_phonems, ref_phonems)
+        return 1 - lev_distance / len(ref_phonems)
diff --git a/tests/resources/speech.wav b/tests/resources/speech.wav
new file mode 100644
index 0000000000000000000000000000000000000000..0fa4e9e7f978e56a655e754d068b69d625dc9a3f
Binary files /dev/null and b/tests/resources/speech.wav differ
diff --git a/tests/resources/speech_bab_0dB.wav b/tests/resources/speech_bab_0dB.wav
new file mode 100644
index 0000000000000000000000000000000000000000..1bed1071a5d07dfc667b5e2d06ccbec0775fd6a9
Binary files /dev/null and b/tests/resources/speech_bab_0dB.wav differ
diff --git a/tests/test_levenshtein.py b/tests/test_levenshtein.py
index 8ce7a613a51ada64f7fcfc911cfafe2cbaa290b4..a5c2b19dbf8d9fd953e8e88796a160116a6164b0 100644
--- a/tests/test_levenshtein.py
+++ b/tests/test_levenshtein.py
@@ -1,7 +1,18 @@
 from unittest import TestCase
 from Levenshtein import distance
+from lps.lps import LevenshteinPhonemeSimilarity
+import torchaudio
+import soundfile as sf
 
 
 class TestLevenshtein(TestCase):
     def test_levenshtein(self):
         self.assertEqual(distance("foo", "nooo"), 2)
+
+    def test_levenshtein_sim(self):
+        (ref, _), (sample, _) = torchaudio.load("resources/speech.wav"), torchaudio.load("resources/speech_bab_0dB.wav")
+        lps = LevenshteinPhonemeSimilarity()
+        ref = ref.numpy()
+        sample = sample.numpy()
+        self.assertEqual(lps(ref, ref), 1.0)
+        self.assertLess(lps(ref, sample), 1.0)