48 lines
978 B
Python
48 lines
978 B
Python
from m import *
|
|
|
|
from matplotlib import pyplot as plt
|
|
|
|
def __prep_conf_matr(data, num_classes=None):
    """Count (predicted, real) class pairs into a confusion matrix.

    Args:
        data: A pair ``(predictions, labels)``. Each element is a sequence of
            per-sample score vectors (for example model outputs and one-hot
            labels); a sample's class is the index of the largest entry
            (``argmax``) of its vector.
        num_classes: Side length of the returned square matrix. When ``None``
            (the default) it is taken from the length of the score vectors,
            or 2 when ``data`` contains no samples.

    Returns:
        A float numpy array of shape ``(num_classes, num_classes)`` whose
        entry ``[i, j]`` is the number of samples predicted as class ``i``
        whose real class is ``j`` (rows = predicted, columns = real).

    Raises:
        ValueError: If predictions and labels have different lengths, i.e.
            the two sequences cannot be aligned sample by sample.
    """
    preds, reals = (np.asarray(part) for part in data)
    if len(preds) != len(reals):
        # Silently truncating (as zip would) hides misaligned inputs.
        raise ValueError(
            f"got {len(preds)} predictions but {len(reals)} labels"
        )

    if num_classes is None:
        # Infer the class count from the score vectors so more than two
        # classes work; binary inputs still produce a 2x2 matrix.
        num_classes = max(preds.shape[-1], reals.shape[-1]) if len(preds) else 2

    matrix = np.zeros((num_classes, num_classes))
    if len(preds):
        # Unbuffered scatter-add: every repeated (row, col) pair adds 1.
        np.add.at(matrix, (preds.argmax(axis=-1), reals.argmax(axis=-1)), 1)
    return matrix
|
|
|
|
|
|
def __plot_conf_matr(data):
    """Show the confusion matrix of ``data`` as an annotated heat map.

    Args:
        data: A pair ``(predictions, labels)`` of per-sample score vectors,
            forwarded unchanged to ``__prep_conf_matr``.

    The matrix built by ``__prep_conf_matr`` indexes rows by the predicted
    class (argmax of the first element of ``data``) and columns by the real
    class (argmax of the second). Each cell is labelled with its count and
    the figure is displayed with ``plt.show()``.
    """
    matr = __prep_conf_matr(data)

    _, ax = plt.subplots()
    ax.matshow(matr, cmap=plt.cm.Blues)

    # matshow draws matr[row, col] at x=col, y=row, so the annotation for a
    # cell must be placed at (col, row); (row, col) would transpose the
    # numbers relative to the coloured cells.
    for row, values in enumerate(matr):
        for col, count in enumerate(values):
            # int() keeps the label "3" rather than "3.0" for float counts.
            ax.text(col, row, str(int(round(count))), va="center", ha="center")

    ax.set_xlabel("Real class")
    ax.set_ylabel("Predicted class")
    plt.show()
|
|
|
|
ds = tf.keras.preprocessing.image_dataset_from_directory(
    '../dataset-orig-aug-1-mini-1/',
    labels='inferred',
    label_mode='categorical',
    color_mode='rgb',
    image_size=(300, 300),
    batch_size=32,
    # The dataset is iterated twice below: once by ``predict`` and once to
    # collect the labels. With the default ``shuffle=True`` the two passes
    # see different orders, so predictions and labels would not line up.
    shuffle=False,
    verbose=True,
)

# Model outputs, one row per image, in dataset order. ``mod`` comes from
# ``from m import *`` -- presumably the trained Keras model; confirm.
p = mod.predict(ds)
# Ground-truth labels collected in the same (unshuffled) order as ``p``.
r = np.concatenate([labels for _, labels in ds])

__plot_conf_matr([p, r])
|