add metrics, generalize code, remove old variables
This commit is contained in:
parent
a5c85e5060
commit
c6ad0969f8
59
f.py
59
f.py
@ -49,6 +49,39 @@ def __plot_conf_matr(m):
|
|||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def __conf_matr_to_binary(index, matr):
|
||||||
|
bm = np.zeros([2, 2])
|
||||||
|
|
||||||
|
for i, x in enumerate(matr):
|
||||||
|
for j, y in enumerate(x):
|
||||||
|
bm[int(index != i), int(index != j)] += y
|
||||||
|
|
||||||
|
return bm
|
||||||
|
|
||||||
|
|
||||||
|
def __calc_accuracy_for(index, bcm):
|
||||||
|
return (bcm[0,0] + bcm[1,1]) / (bcm[0,0] + bcm[0,1] + bcm[1,0] + bcm[1,1])
|
||||||
|
|
||||||
|
|
||||||
|
def __calc_precision_for(index, bcm):
|
||||||
|
return (bcm[0,0]) / (bcm[0,0] + bcm[0,1])
|
||||||
|
|
||||||
|
|
||||||
|
def __calc_recall_for(index, bcm):
|
||||||
|
return (bcm[0,0]) / (bcm[0,0] + bcm[1,0])
|
||||||
|
|
||||||
|
|
||||||
|
def __calc_specificity_for(index, bcm):
|
||||||
|
return (bcm[1,1]) / (bcm[0,1] + bcm[1,1])
|
||||||
|
|
||||||
|
|
||||||
|
def __calc_f1_score_for(index, bcm):
|
||||||
|
p = __calc_precision_for(index, bcm)
|
||||||
|
r = __calc_recall_for(index, bcm)
|
||||||
|
|
||||||
|
return 2 * p * r / (p + r)
|
||||||
|
|
||||||
|
|
||||||
def __plot_acc_rate(h):
|
def __plot_acc_rate(h):
|
||||||
plt.plot(h.history['accuracy'], label = 'train_acc')
|
plt.plot(h.history['accuracy'], label = 'train_acc')
|
||||||
plt.plot(h.history['val_accuracy'], label = 'valid_acc')
|
plt.plot(h.history['val_accuracy'], label = 'valid_acc')
|
||||||
@ -74,8 +107,32 @@ def train(m, label):
|
|||||||
m.save_weights(f"save-{label}.weights.h5")
|
m.save_weights(f"save-{label}.weights.h5")
|
||||||
|
|
||||||
__plot_acc_rate(h)
|
__plot_acc_rate(h)
|
||||||
|
|
||||||
|
|
||||||
|
def model_quality(m, label):
|
||||||
|
(_, _), (x_test, y_test) = __prep_data()
|
||||||
|
|
||||||
|
m.compile(optimizer = "adam",
|
||||||
|
loss = "categorical_crossentropy",
|
||||||
|
metrics = ["accuracy"])
|
||||||
|
|
||||||
|
m.load_weights(f"save-{label}.weights.h5")
|
||||||
|
|
||||||
__plot_conf_matr(m)
|
__plot_conf_matr(m)
|
||||||
|
|
||||||
|
cm = __prep_conf_matr(m)
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
bcm = __conf_matr_to_binary(i, cm)
|
||||||
|
|
||||||
|
acc = __calc_accuracy_for(i, bcm)
|
||||||
|
pre = __calc_precision_for(i, bcm)
|
||||||
|
rec = __calc_recall_for(i, bcm)
|
||||||
|
f1s = __calc_f1_score_for(i, bcm)
|
||||||
|
spe = __calc_specificity_for(i, bcm)
|
||||||
|
|
||||||
|
print(f"{i}: acc={acc} pre={pre} rec={rec} f1s={f1s} spe={spe}")
|
||||||
|
|
||||||
|
|
||||||
def classify(m, label, imgfn):
|
def classify(m, label, imgfn):
|
||||||
m.compile(optimizer = "adam",
|
m.compile(optimizer = "adam",
|
||||||
@ -95,8 +152,6 @@ def classify(m, label, imgfn):
|
|||||||
plt.title(np.argmax(res))
|
plt.title(np.argmax(res))
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
put_active = 0
|
|
||||||
take_active = 0
|
|
||||||
|
|
||||||
def classify_live(m, label):
|
def classify_live(m, label):
|
||||||
import lv
|
import lv
|
||||||
|
9
nn1.py
9
nn1.py
@ -11,9 +11,12 @@ m = tf.keras.models.Sequential([
|
|||||||
l.Dense(10, activation = "softmax")
|
l.Dense(10, activation = "softmax")
|
||||||
])
|
])
|
||||||
|
|
||||||
#f.train(m, "1")
|
model_label = "1"
|
||||||
|
|
||||||
|
#f.train(m, model_label)
|
||||||
|
f.model_quality(m, model_label)
|
||||||
|
|
||||||
if len(argv) == 2:
|
if len(argv) == 2:
|
||||||
f.classify(m, "1", argv[1])
|
f.classify(m, model_label, argv[1])
|
||||||
else:
|
else:
|
||||||
f.classify_live(m, "1")
|
f.classify_live(m, model_label)
|
||||||
|
9
nn2.py
9
nn2.py
@ -11,9 +11,12 @@ m = tf.keras.models.Sequential([
|
|||||||
l.Dense(10, activation = "softmax")
|
l.Dense(10, activation = "softmax")
|
||||||
])
|
])
|
||||||
|
|
||||||
#f.train(m, "2")
|
model_label = "2"
|
||||||
|
|
||||||
|
#f.train(m, model_label)
|
||||||
|
f.model_quality(m, model_label)
|
||||||
|
|
||||||
if len(argv) == 2:
|
if len(argv) == 2:
|
||||||
f.classify(m, "2", argv[1])
|
f.classify(m, model_label, argv[1])
|
||||||
else:
|
else:
|
||||||
f.classify_live(m, "2")
|
f.classify_live(m, model_label)
|
||||||
|
9
nn3.py
9
nn3.py
@ -12,9 +12,12 @@ m = tf.keras.models.Sequential([
|
|||||||
l.Dense(10, activation = "softmax")
|
l.Dense(10, activation = "softmax")
|
||||||
])
|
])
|
||||||
|
|
||||||
#f.train(m, "3")
|
model_label = "3"
|
||||||
|
|
||||||
|
#f.train(m, model_label)
|
||||||
|
f.model_quality(m, model_label)
|
||||||
|
|
||||||
if len(argv) == 2:
|
if len(argv) == 2:
|
||||||
f.classify(m, "3", argv[1])
|
f.classify(m, model_label, argv[1])
|
||||||
else:
|
else:
|
||||||
f.classify_live(m, "3")
|
f.classify_live(m, model_label)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user