initial commit (pre ctc)
commit d998a04d47

cc.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import tensorflow as tf
from tensorflow import keras
from jiwer import wer

import numpy as np

from preprocessing import *


def decode_batch_predictions(pred):
    # Greedy CTC decoding: every frame of every sample counts as valid input.
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    results = keras.backend.ctc_decode(pred,
                                       input_length = input_len,
                                       greedy = True)[0][0]
    output_text = []
    for result in results:
        # Map the decoded indices back to characters and join into a string.
        result = tf.strings.reduce_join(num_to_char(result)).numpy().decode("utf-8")
        output_text.append(result)
    return output_text


class ce(keras.callbacks.Callback):
    def __init__(self, dataset, model):
        super().__init__()
        self.dataset = dataset
        # Stored under a mangled name so it does not collide with the
        # `model` attribute Keras itself sets on every callback.
        self.____model = model

    def on_epoch_end(self, epoch, logs = None):
        predictions = []
        targets = []
        for batch in self.dataset:
            X, y = batch
            batch_predictions = self.____model.predict(X, verbose = 0)
            batch_predictions = decode_batch_predictions(batch_predictions)
            predictions.extend(batch_predictions)
            for label in y:
                label = (
                    tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
                )
                targets.append(label)
        # Word error rate over the whole dataset, plus a few random samples.
        wer_score = wer(targets, predictions)
        print(f"Word Error Rate: {wer_score:.4f}")
        for i in np.random.randint(0, len(predictions), 10):
            print(f"Target    : {targets[i]}")
            print(f"Prediction: {predictions[i]}")
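
For reference, keras.backend.ctc_decode can also run beam search instead of the greedy decoding used above. A minimal sketch, reusing the same num_to_char lookup; this helper and its untuned beam_width are illustrative, not part of the commit:

# Hypothetical variant, not in this commit: beam-search CTC decoding.
def decode_batch_predictions_beam(pred, beam_width = 100):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    results = keras.backend.ctc_decode(pred,
                                       input_length = input_len,
                                       greedy = False,
                                       beam_width = beam_width)[0][0]
    return [tf.strings.reduce_join(num_to_char(r)).numpy().decode("utf-8")
            for r in results]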

loss.py (new file, 13 lines)
@@ -0,0 +1,13 @@
import tensorflow as tf
from tensorflow import keras


def CTCLoss(y_true, y_pred):
    # Per-sample sequence lengths: the full time dimension for the inputs
    # and the full (padded) label dimension for the targets.
    batch_len = tf.cast(tf.shape(y_true)[0], dtype = "int64")
    input_length = tf.cast(tf.shape(y_pred)[1], dtype = "int64")
    label_length = tf.cast(tf.shape(y_true)[1], dtype = "int64")

    input_length = input_length * tf.ones(shape = (batch_len, 1), dtype = "int64")
    label_length = label_length * tf.ones(shape = (batch_len, 1), dtype = "int64")

    # ctc_batch_cost returns one loss value per sample, shape (batch, 1).
    loss = keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length)
    return loss
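
A quick shape sanity check for CTCLoss, using random tensors and an assumed 32-class softmax output (illustrative only, not part of the commit):

# y_true: (batch, label_len) int labels; y_pred: (batch, time, classes) softmax.
y_true = tf.random.uniform((2, 5), maxval = 30, dtype = tf.int32)
y_pred = tf.nn.softmax(tf.random.normal((2, 50, 32)), axis = -1)
print(CTCLoss(y_true, y_pred).shape)  # one loss per sample: (2, 1)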

model.py (new file, 58 lines)
@@ -0,0 +1,58 @@
from tensorflow.keras import layers as kl
from tensorflow.keras import models as km
from tensorflow.keras import losses as ks
from tensorflow.keras import optimizers as ko
from tensorflow.keras import callbacks as kc

from tensorflow import keras

from loss import CTCLoss


def model(input_dim, output_dim, rnn_layers = 3, rnn_units = 72):
    # (time, features) input; add a channel axis for the 2D convolutions.
    li = kl.Input((None, input_dim))
    l1 = kl.Reshape((-1, input_dim, 1))(li)

    l2 = kl.Conv2D(32,
                   kernel_size = [11, 41],
                   strides = [2, 2],
                   padding = 'same',
                   use_bias = False)(l1)
    l3 = kl.BatchNormalization()(l2)
    l4 = kl.ReLU()(l3)

    l5 = kl.Conv2D(32,
                   kernel_size = [11, 21],
                   strides = [1, 2],
                   padding = 'same',
                   use_bias = False)(l4)
    l6 = kl.BatchNormalization()(l5)
    l7 = kl.ReLU()(l6)

    # Flatten (freq, channels) back into one feature vector per time step.
    lb = kl.Reshape((-1, l7.shape[-2] * l7.shape[-1]))(l7)

    for i in range(1, rnn_layers + 1):
        r = kl.GRU(rnn_units,
                   activation = 'tanh',
                   recurrent_activation = 'sigmoid',
                   use_bias = True,
                   return_sequences = True,
                   reset_after = True)

        lb = kl.Bidirectional(r,
                              merge_mode = 'concat')(lb)

        # Dropout between recurrent layers, but not after the last one.
        if i < rnn_layers:
            lb = kl.Dropout(rate = 0.5)(lb)

    lc1 = kl.Dense(rnn_units * 2, activation = 'relu')(lb)
    lc2 = kl.Dropout(0.5)(lc1)
    # output_dim characters plus one extra class for the CTC blank token.
    lo = kl.Dense(output_dim + 1, activation = 'softmax')(lc2)

    m = keras.Model(li, lo)
    # Lion requires a recent TF/Keras release.
    m.compile(optimizer = ko.Lion(0.0004),
              loss = CTCLoss)

    return m
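
A smoke test for the builder, assuming the spectrogram settings from preprocessing.py (fft_length = 384 gives fft_length // 2 + 1 = 193 frequency bins) and the char_to_num vocabulary, whose vocabulary_size() should resolve to 31 (30 characters plus the OOV token). Illustrative only, not part of the commit:

# Illustrative: 193 = fft_length // 2 + 1, 31 = char_to_num.vocabulary_size().
m = model(input_dim = 193, output_dim = 31)
m.summary()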

preprocessing.py (new file, 45 lines)
@@ -0,0 +1,45 @@
from tensorflow import keras
import tensorflow as tf


characters = list("abcdefghijklmnopqrstuvwxyz'?! ")

char_to_num = keras.layers.StringLookup(vocabulary = characters,
                                        oov_token = "")

num_to_char = keras.layers.StringLookup(vocabulary = char_to_num.get_vocabulary(),
                                        oov_token = "",
                                        invert = True)

# STFT parameters (in samples): window size, hop size, and FFT size.
frame_length = 256
frame_step = 160
fft_length = 384

wavs = '/mnt/tmpfs1/LJSpeech-1.1/wavs/'


def encode_single_sample(wav, label):
    file = tf.io.read_file(wavs + wav + ".wav")
    audio, _ = tf.audio.decode_wav(file)
    audio = tf.squeeze(audio, axis = -1)
    audio = tf.cast(audio, tf.float32)
    spectrogram = tf.signal.stft(audio,
                                 frame_length = frame_length,
                                 frame_step = frame_step,
                                 fft_length = fft_length)
    # Magnitude spectrogram, square-root compressed, then normalized
    # per time frame across frequency bins.
    spectrogram = tf.abs(spectrogram)
    spectrogram = tf.math.pow(spectrogram, 0.5)
    means = tf.math.reduce_mean(spectrogram, 1, keepdims = True)
    stddevs = tf.math.reduce_std(spectrogram, 1, keepdims = True)
    spectrogram = (spectrogram - means) / (stddevs + 1e-10)
    # Lower-case the transcription and encode it character by character.
    label = tf.strings.lower(label)
    label = tf.strings.unicode_split(label, input_encoding = "UTF-8")
    label = char_to_num(label)
    return spectrogram, label


def to_dataset(df, batch_size = 32):
    ds = tf.data.Dataset.from_tensor_slices((list(df["file_name"]),
                                             list(df["normalized_transcription"])))
    ds = ds.map(encode_single_sample, num_parallel_calls = tf.data.AUTOTUNE) \
           .padded_batch(batch_size) \
           .prefetch(buffer_size = tf.data.AUTOTUNE)

    return ds
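
Illustrative usage of to_dataset with a hand-built two-row DataFrame; the IDs and transcriptions are hypothetical, and running it assumes the matching .wav files exist under the wavs path above. Not part of the commit:

import pandas as pd

# Hypothetical rows; each file_name must match a wav in the wavs directory.
df = pd.DataFrame({"file_name": ["LJ001-0001", "LJ001-0002"],
                   "normalized_transcription": ["some text", "more text"]})
ds = to_dataset(df, batch_size = 2)
for spectrograms, labels in ds.take(1):
    print(spectrograms.shape, labels.shape)  # (2, frames, 193), (2, label_len)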

train.py (new file, 41 lines)
@@ -0,0 +1,41 @@
from model import *
from preprocessing import *

import pandas as pd

from cc import ce

bs = 32

# LJSpeech metadata.csv: id | raw transcription | normalized transcription.
data = pd.read_csv("/mnt/tmpfs1/LJSpeech-1.1/metadata.csv",
                   sep = '|',
                   header = None,
                   quoting = 3,
                   names = ['file_name', 'i', 'normalized_transcription'])
# Use the first half of the corpus, split 90/10 into train/validation.
s = int(len(data) // 2 * 0.90)
train_data = data[:s]
valid_data = data[s:len(data) // 2]

train_ds = to_dataset(train_data, batch_size = bs)
valid_ds = to_dataset(valid_data, batch_size = bs)

m = model(input_dim = fft_length // 2 + 1,
          output_dim = char_to_num.vocabulary_size())

# Resume from the previous run's weights.
m.load_weights('model20-latest.keras')
ckpt1 = kc.ModelCheckpoint('model21-latest.keras',
                           monitor = 'val_loss',
                           save_best_only = False,
                           verbose = 1)

ckpt2 = kc.ModelCheckpoint('model21-best.keras',
                           monitor = 'val_loss',
                           save_best_only = True,
                           verbose = 1)

# WER callback from cc.py, evaluated on the validation set each epoch.
ce1 = ce(valid_ds, m)

m.fit(train_ds,
      epochs = 8,
      validation_data = valid_ds,
      callbacks = [ckpt1, ckpt2, ce1])