44 lines
1.5 KiB
Python
44 lines
1.5 KiB
Python
|
|
import tensorflow as tf
|
||
|
|
from tensorflow import keras
|
||
|
|
from jiwer import wer
|
||
|
|
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
from preprocessing import *
|
||
|
|
|
||
|
|
def decode_batch_predictions(pred):
    """Greedy-CTC-decode a batch of model outputs into Python strings.

    Args:
        pred: Batch of model outputs. Only ``pred.shape[0]`` and
            ``pred.shape[1]`` are read here; presumably the layout is
            (batch, time_steps, vocab) -- TODO confirm against the model.

    Returns:
        A list with one UTF-8 decoded string per decoded sequence.
    """
    batch_size = pred.shape[0]
    time_steps = pred.shape[1]
    # One length entry per sample, each equal to the full second dimension.
    input_len = np.ones(batch_size) * time_steps

    decoded = keras.backend.ctc_decode(
        pred, input_length=input_len, greedy=True
    )
    # First element of the result, then the first entry inside it.
    best_path = decoded[0][0]

    # `num_to_char` is not defined in this file; presumably it comes from the
    # `preprocessing` star import and maps label ids back to characters.
    return [
        tf.strings.reduce_join(num_to_char(sequence)).numpy().decode("utf-8")
        for sequence in best_path
    ]
|
||
|
|
|
||
|
|
class ce(keras.callbacks.Callback):
    """Keras callback that prints the word error rate after each epoch.

    At the end of every epoch the callback runs the given model over the
    given dataset, decodes the predictions and the labels to text, prints
    the word error rate and then prints ten randomly chosen
    target/prediction pairs.
    """

    def __init__(self, dataset, model):
        """Store the evaluation dataset and the model used for prediction.

        Args:
            dataset: Iterable yielding ``(X, y)`` batches.
            model: Object exposing ``predict(X, verbose=0)``.
        """
        super().__init__()
        self.dataset = dataset
        # Stored under a private name instead of ``model`` -- presumably to
        # avoid clashing with the ``model`` attribute that Keras manages on
        # callbacks (TODO confirm).
        self.____model = model

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate the whole dataset and print WER plus sample transcripts.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Metric dict supplied by Keras (unused).
        """
        predictions = []
        targets = []
        for X, y in self.dataset:
            batch_predictions = self.____model.predict(X, verbose=0)
            predictions.extend(decode_batch_predictions(batch_predictions))
            # `num_to_char` is not defined in this file; presumably it comes
            # from the `preprocessing` star import.
            targets.extend(
                tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
                for label in y
            )

        # An empty dataset would make np.random.randint(0, 0, 10) raise
        # ValueError (low >= high) and abort training; report and skip instead.
        if not predictions:
            print("Word Error Rate: skipped (dataset yielded no samples)")
            return

        wer_score = wer(targets, predictions)
        print(f"Word Error Rate: {wer_score:.4f}")
        # Ten indices drawn with replacement, so fewer than ten samples is fine.
        for i in np.random.randint(0, len(predictions), 10):
            print(f"Target : {targets[i]}")
            print(f"Prediction: {predictions[i]}")
|