add metrics, generalize code, remove old variables
This commit is contained in:
		
							parent
							
								
									a5c85e5060
								
							
						
					
					
						commit
						c6ad0969f8
					
				
							
								
								
									
										59
									
								
								f.py
									
									
									
									
									
								
							
							
						
						
									
										59
									
								
								f.py
									
									
									
									
									
								
							@ -49,6 +49,39 @@ def __plot_conf_matr(m):
 | 
			
		||||
    plt.show()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def __conf_matr_to_binary(index, matr):
 | 
			
		||||
    bm = np.zeros([2, 2])
 | 
			
		||||
 | 
			
		||||
    for i, x in enumerate(matr):
 | 
			
		||||
        for j, y in enumerate(x):
 | 
			
		||||
            bm[int(index != i), int(index != j)] += y
 | 
			
		||||
 | 
			
		||||
    return bm
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def __calc_accuracy_for(index, bcm):
 | 
			
		||||
    return (bcm[0,0] + bcm[1,1]) / (bcm[0,0] + bcm[0,1] + bcm[1,0] + bcm[1,1])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def __calc_precision_for(index, bcm):
 | 
			
		||||
    return (bcm[0,0]) / (bcm[0,0] + bcm[0,1])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def __calc_recall_for(index, bcm):
 | 
			
		||||
    return (bcm[0,0]) / (bcm[0,0] + bcm[1,0])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def __calc_specificity_for(index, bcm):
 | 
			
		||||
    return (bcm[1,1]) / (bcm[0,1] + bcm[1,1])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def __calc_f1_score_for(index, bcm):
 | 
			
		||||
    p = __calc_precision_for(index, bcm)
 | 
			
		||||
    r = __calc_recall_for(index, bcm)
 | 
			
		||||
 | 
			
		||||
    return 2 * p * r / (p + r)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def __plot_acc_rate(h):
 | 
			
		||||
    plt.plot(h.history['accuracy'], label = 'train_acc')
 | 
			
		||||
    plt.plot(h.history['val_accuracy'], label = 'valid_acc')
 | 
			
		||||
@ -74,8 +107,32 @@ def train(m, label):
 | 
			
		||||
    m.save_weights(f"save-{label}.weights.h5")
 | 
			
		||||
 | 
			
		||||
    __plot_acc_rate(h)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def model_quality(m, label):
 | 
			
		||||
    (_, _), (x_test, y_test) = __prep_data()
 | 
			
		||||
 | 
			
		||||
    m.compile(optimizer = "adam",
 | 
			
		||||
              loss = "categorical_crossentropy",
 | 
			
		||||
              metrics = ["accuracy"])
 | 
			
		||||
 | 
			
		||||
    m.load_weights(f"save-{label}.weights.h5")
 | 
			
		||||
 | 
			
		||||
    __plot_conf_matr(m)
 | 
			
		||||
 | 
			
		||||
    cm = __prep_conf_matr(m)
 | 
			
		||||
 | 
			
		||||
    for i in range(10):
 | 
			
		||||
        bcm = __conf_matr_to_binary(i, cm)
 | 
			
		||||
 | 
			
		||||
        acc = __calc_accuracy_for(i, bcm)
 | 
			
		||||
        pre = __calc_precision_for(i, bcm)
 | 
			
		||||
        rec = __calc_recall_for(i, bcm)
 | 
			
		||||
        f1s = __calc_f1_score_for(i, bcm)
 | 
			
		||||
        spe = __calc_specificity_for(i, bcm)
 | 
			
		||||
 | 
			
		||||
        print(f"{i}: acc={acc} pre={pre} rec={rec} f1s={f1s} spe={spe}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def classify(m, label, imgfn):
 | 
			
		||||
    m.compile(optimizer = "adam",
 | 
			
		||||
@ -95,8 +152,6 @@ def classify(m, label, imgfn):
 | 
			
		||||
    plt.title(np.argmax(res))
 | 
			
		||||
    plt.show()
 | 
			
		||||
 | 
			
		||||
put_active = 0
 | 
			
		||||
take_active = 0
 | 
			
		||||
 | 
			
		||||
def classify_live(m, label):
 | 
			
		||||
    import lv
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										9
									
								
								nn1.py
									
									
									
									
									
								
							
							
						
						
									
										9
									
								
								nn1.py
									
									
									
									
									
								
							@ -11,9 +11,12 @@ m = tf.keras.models.Sequential([
 | 
			
		||||
    l.Dense(10, activation = "softmax")
 | 
			
		||||
])
 | 
			
		||||
 | 
			
		||||
#f.train(m, "1")
 | 
			
		||||
model_label = "1"
 | 
			
		||||
 | 
			
		||||
#f.train(m, model_label)
 | 
			
		||||
f.model_quality(m, model_label)
 | 
			
		||||
 | 
			
		||||
if len(argv) == 2:
 | 
			
		||||
    f.classify(m, "1", argv[1])
 | 
			
		||||
    f.classify(m, model_label, argv[1])
 | 
			
		||||
else:
 | 
			
		||||
    f.classify_live(m, "1")
 | 
			
		||||
    f.classify_live(m, model_label)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										9
									
								
								nn2.py
									
									
									
									
									
								
							
							
						
						
									
										9
									
								
								nn2.py
									
									
									
									
									
								
							@ -11,9 +11,12 @@ m = tf.keras.models.Sequential([
 | 
			
		||||
    l.Dense(10, activation = "softmax")
 | 
			
		||||
])
 | 
			
		||||
 | 
			
		||||
#f.train(m, "2")
 | 
			
		||||
model_label = "2"
 | 
			
		||||
 | 
			
		||||
#f.train(m, model_label)
 | 
			
		||||
f.model_quality(m, model_label)
 | 
			
		||||
 | 
			
		||||
if len(argv) == 2:
 | 
			
		||||
    f.classify(m, "2", argv[1])
 | 
			
		||||
    f.classify(m, model_label, argv[1])
 | 
			
		||||
else:
 | 
			
		||||
    f.classify_live(m, "2")
 | 
			
		||||
    f.classify_live(m, model_label)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										9
									
								
								nn3.py
									
									
									
									
									
								
							
							
						
						
									
										9
									
								
								nn3.py
									
									
									
									
									
								
							@ -12,9 +12,12 @@ m = tf.keras.models.Sequential([
 | 
			
		||||
    l.Dense(10, activation = "softmax")
 | 
			
		||||
])
 | 
			
		||||
 | 
			
		||||
#f.train(m, "3")
 | 
			
		||||
model_label = "3"
 | 
			
		||||
 | 
			
		||||
#f.train(m, model_label)
 | 
			
		||||
f.model_quality(m, model_label)
 | 
			
		||||
 | 
			
		||||
if len(argv) == 2:
 | 
			
		||||
    f.classify(m, "3", argv[1])
 | 
			
		||||
    f.classify(m, model_label, argv[1])
 | 
			
		||||
else:
 | 
			
		||||
    f.classify_live(m, "3")
 | 
			
		||||
    f.classify_live(m, model_label)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user