add metrics, generalize code, remove old variables

This commit is contained in:
ІО-23 Шмуляр Олег 2025-10-11 22:52:07 +03:00
parent a5c85e5060
commit c6ad0969f8
4 changed files with 75 additions and 11 deletions

59
f.py
View File

@ -49,6 +49,39 @@ def __plot_conf_matr(m):
plt.show() plt.show()
def __conf_matr_to_binary(index, matr):
bm = np.zeros([2, 2])
for i, x in enumerate(matr):
for j, y in enumerate(x):
bm[int(index != i), int(index != j)] += y
return bm
def __calc_accuracy_for(index, bcm):
return (bcm[0,0] + bcm[1,1]) / (bcm[0,0] + bcm[0,1] + bcm[1,0] + bcm[1,1])
def __calc_precision_for(index, bcm):
return (bcm[0,0]) / (bcm[0,0] + bcm[0,1])
def __calc_recall_for(index, bcm):
return (bcm[0,0]) / (bcm[0,0] + bcm[1,0])
def __calc_specificity_for(index, bcm):
return (bcm[1,1]) / (bcm[0,1] + bcm[1,1])
def __calc_f1_score_for(index, bcm):
p = __calc_precision_for(index, bcm)
r = __calc_recall_for(index, bcm)
return 2 * p * r / (p + r)
def __plot_acc_rate(h): def __plot_acc_rate(h):
plt.plot(h.history['accuracy'], label = 'train_acc') plt.plot(h.history['accuracy'], label = 'train_acc')
plt.plot(h.history['val_accuracy'], label = 'valid_acc') plt.plot(h.history['val_accuracy'], label = 'valid_acc')
@ -74,8 +107,32 @@ def train(m, label):
m.save_weights(f"save-{label}.weights.h5") m.save_weights(f"save-{label}.weights.h5")
__plot_acc_rate(h) __plot_acc_rate(h)
def model_quality(m, label):
(_, _), (x_test, y_test) = __prep_data()
m.compile(optimizer = "adam",
loss = "categorical_crossentropy",
metrics = ["accuracy"])
m.load_weights(f"save-{label}.weights.h5")
__plot_conf_matr(m) __plot_conf_matr(m)
cm = __prep_conf_matr(m)
for i in range(10):
bcm = __conf_matr_to_binary(i, cm)
acc = __calc_accuracy_for(i, bcm)
pre = __calc_precision_for(i, bcm)
rec = __calc_recall_for(i, bcm)
f1s = __calc_f1_score_for(i, bcm)
spe = __calc_specificity_for(i, bcm)
print(f"{i}: acc={acc} pre={pre} rec={rec} f1s={f1s} spe={spe}")
def classify(m, label, imgfn): def classify(m, label, imgfn):
m.compile(optimizer = "adam", m.compile(optimizer = "adam",
@ -95,8 +152,6 @@ def classify(m, label, imgfn):
plt.title(np.argmax(res)) plt.title(np.argmax(res))
plt.show() plt.show()
put_active = 0
take_active = 0
def classify_live(m, label): def classify_live(m, label):
import lv import lv

9
nn1.py
View File

@ -11,9 +11,12 @@ m = tf.keras.models.Sequential([
l.Dense(10, activation = "softmax") l.Dense(10, activation = "softmax")
]) ])
#f.train(m, "1") model_label = "1"
#f.train(m, model_label)
f.model_quality(m, model_label)
if len(argv) == 2: if len(argv) == 2:
f.classify(m, "1", argv[1]) f.classify(m, model_label, argv[1])
else: else:
f.classify_live(m, "1") f.classify_live(m, model_label)

9
nn2.py
View File

@ -11,9 +11,12 @@ m = tf.keras.models.Sequential([
l.Dense(10, activation = "softmax") l.Dense(10, activation = "softmax")
]) ])
#f.train(m, "2") model_label = "2"
#f.train(m, model_label)
f.model_quality(m, model_label)
if len(argv) == 2: if len(argv) == 2:
f.classify(m, "2", argv[1]) f.classify(m, model_label, argv[1])
else: else:
f.classify_live(m, "2") f.classify_live(m, model_label)

9
nn3.py
View File

@ -12,9 +12,12 @@ m = tf.keras.models.Sequential([
l.Dense(10, activation = "softmax") l.Dense(10, activation = "softmax")
]) ])
#f.train(m, "3") model_label = "3"
#f.train(m, model_label)
f.model_quality(m, model_label)
if len(argv) == 2: if len(argv) == 2:
f.classify(m, "3", argv[1]) f.classify(m, model_label, argv[1])
else: else:
f.classify_live(m, "3") f.classify_live(m, model_label)