import tensorflow as tf from tensorflow import keras from jiwer import wer import numpy as np from preprocessing import * def decode_batch_predictions(pred): input_len = np.ones(pred.shape[0]) * pred.shape[1] results = keras.backend.ctc_decode(pred, input_length = input_len, greedy = True)[0][0] output_text = [] for result in results: result = tf.strings.reduce_join(num_to_char(result)).numpy().decode("utf-8") output_text.append(result) return output_text class ce(keras.callbacks.Callback): def __init__(self, dataset, model): super().__init__() self.dataset = dataset self.____model = model def on_epoch_end(self, epoch, logs = None): predictions = [] targets = [] for batch in self.dataset: X, y = batch batch_predictions = self.____model.predict(X, verbose = 0) batch_predictions = decode_batch_predictions(batch_predictions) predictions.extend(batch_predictions) for label in y: label = ( tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8") ) targets.append(label) wer_score = wer(targets, predictions) print(f"Word Error Rate: {wer_score:.4f}") for i in np.random.randint(0, len(predictions), 10): print(f"Target : {targets[i]}") print(f"Prediction: {predictions[i]}")