incremental update
This commit is contained in:
parent
e0ea8c5387
commit
47a5a196e2
28
detect.py
Normal file
28
detect.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from sys import argv, exit
|
||||||
|
|
||||||
|
if len(argv) != 2:
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
from model import *
|
||||||
|
from preprocessing import *
|
||||||
|
from cc import decode_batch_predictions
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from spellchecker import SpellChecker
|
||||||
|
|
||||||
|
sc = SpellChecker()
|
||||||
|
|
||||||
|
m = model(input_dim = fft_length // 2 + 1,
|
||||||
|
output_dim = char_to_num.vocabulary_size())
|
||||||
|
|
||||||
|
m.load_weights('model41-best.keras')
|
||||||
|
|
||||||
|
sg, _ = encode_single_sample_selectable_dir(argv[1], "")
|
||||||
|
|
||||||
|
seq = m.predict(np.array([sg]))
|
||||||
|
|
||||||
|
dc = decode_batch_predictions(seq)[0]
|
||||||
|
print(f"Decode : {dc}")
|
||||||
|
|
||||||
|
cdc = ' '.join([sc.correction(i) if sc.correction(i) else i for i in dc.split()])
|
||||||
|
print(f"Correct: {cdc}")
|
||||||
@ -17,7 +17,11 @@ fft_length = 384
|
|||||||
wavs = '/mnt/tmpfs1/LJSpeech-1.1/wavs/'
|
wavs = '/mnt/tmpfs1/LJSpeech-1.1/wavs/'
|
||||||
|
|
||||||
def encode_single_sample(wav, label):
|
def encode_single_sample(wav, label):
|
||||||
file = tf.io.read_file(wavs + wav + ".wav")
|
# for backward compatibility
|
||||||
|
encode_single_sample_selectable_dir(wavs + wav + ".wav", label)
|
||||||
|
|
||||||
|
def encode_single_sample_selectable_dir(wav, label):
|
||||||
|
file = tf.io.read_file(wav)
|
||||||
audio, _ = tf.audio.decode_wav(file)
|
audio, _ = tf.audio.decode_wav(file)
|
||||||
audio = tf.squeeze(audio, axis = -1)
|
audio = tf.squeeze(audio, axis = -1)
|
||||||
audio = tf.cast(audio, tf.float32)
|
audio = tf.cast(audio, tf.float32)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user