14 lines
543 B
Python
14 lines
543 B
Python
import tensorflow as tf
|
|
from tensorflow import keras
|
|
|
|
def CTCLoss(y_true, y_pred):
    """Compute the CTC loss for a batch via ``keras.backend.ctc_batch_cost``.

    Every sample in the batch is given the same input length (the size of
    axis 1 of ``y_pred``) and the same label length (the size of axis 1 of
    ``y_true``); the batch size is read from axis 0 of ``y_true``.

    Args:
        y_true: Label tensor; axis 0 is used as the batch dimension and
            axis 1 as the label length.
        y_pred: Prediction tensor; axis 1 is used as the input length.

    Returns:
        The result of ``keras.backend.ctc_batch_cost`` for the batch.
    """
    # Read the dynamic shapes once so this also works when the static
    # shape is not fully known.
    labels_shape = tf.shape(y_true)
    preds_shape = tf.shape(y_pred)

    batch_len = tf.cast(labels_shape[0], dtype="int64")
    steps_per_sample = tf.cast(preds_shape[1], dtype="int64")
    labels_per_sample = tf.cast(labels_shape[1], dtype="int64")

    # A (batch, 1) column of ones turns each scalar length into the
    # per-sample length column that ctc_batch_cost expects.
    ones_column = tf.ones(shape=(batch_len, 1), dtype="int64")
    input_length = steps_per_sample * ones_column
    # NOTE(review): the label length is the full width of y_true, so any
    # padding in the labels is counted as part of the label -- confirm
    # this is intended by the caller.
    label_length = labels_per_sample * ones_column

    return keras.backend.ctc_batch_cost(
        y_true, y_pred, input_length, label_length
    )
|