"""Plot a confusion matrix for model ``mod`` on an image dataset directory.

``np``, ``tf`` and ``mod`` are expected to come from the project module ``m``
(star-imported below) -- NOTE(review): confirm against ``m``.
"""
from m import *
from matplotlib import pyplot as plt


def __prep_conf_matr(data, num_classes=None):
    """Build a confusion matrix from per-sample score / one-hot vectors.

    Args:
        data: A pair ``(predictions, labels)``. Each element is a sequence of
            per-sample vectors (one entry per class); the class of a sample is
            taken as the ``argmax`` of its vector.
        num_classes: Side length of the square matrix. Defaults to the width
            of the prediction vectors, or 2 when there are no samples.

    Returns:
        A float array of shape ``(num_classes, num_classes)`` where entry
        ``[i, j]`` counts samples predicted as class ``i`` whose true class
        is ``j``.

    Raises:
        ValueError: If predictions and labels contain different numbers of
            samples (they cannot be paired up).
    """
    preds, labels = (np.asarray(part) for part in data)
    if len(preds) != len(labels):
        raise ValueError(
            f"got {len(preds)} predictions but {len(labels)} labels"
        )
    if num_classes is None:
        num_classes = preds.shape[-1] if preds.ndim == 2 else 2
    matrix = np.zeros((num_classes, num_classes))
    if len(preds) == 0:
        return matrix
    # np.add.at accumulates repeated (row, col) pairs; plain fancy-index `+=`
    # would count each distinct pair only once.
    np.add.at(matrix, (np.argmax(preds, axis=1), np.argmax(labels, axis=1)), 1)
    return matrix


def __plot_conf_matr(data, num_classes=None):
    """Show the confusion matrix of ``data`` as an annotated heatmap.

    ``data`` and ``num_classes`` are as for ``__prep_conf_matr``. Rows (y axis)
    are predicted classes and columns (x axis) are true classes.
    """
    matrix = __prep_conf_matr(data, num_classes)
    _, ax = plt.subplots()
    ax.matshow(matrix, cmap=plt.cm.Blues)
    for row, counts in enumerate(matrix):
        for col, count in enumerate(counts):
            # matshow draws matrix[row, col] at (x=col, y=row), so the text
            # must use the same (col, row) order to sit on its own cell.
            ax.text(col, row, str(int(count)), va="center", ha="center")
    ax.set_xlabel("true class")
    ax.set_ylabel("predicted class")
    plt.show()


ds = tf.keras.preprocessing.image_dataset_from_directory(
    '../dataset-orig-aug-1-mini-1/',
    labels='inferred',
    label_mode='categorical',
    color_mode='rgb',
    image_size=(300, 300),
    batch_size=32,
    # The dataset is iterated twice below (once by predict, once to collect
    # labels). With the default shuffle=True the order can differ between
    # passes, which would misalign predictions and labels.
    shuffle=False,
    verbose=True,
)
# For a quick smoke test, substitute ds.take(1) for ds in the next two lines.
p = mod.predict(ds)
r = np.concatenate([batch_labels for _, batch_labels in ds])
__plot_conf_matr([p, r])