Commit 60e0d32b authored by Luca Pasa

add test for spectrogram GRID_100

parent 4015a91e
@@ -2,3 +2,4 @@
 *.pyc
 datasets/*
 Model/DAE/Restore/gd/*
+*.npy
@@ -97,12 +97,12 @@ def video_batch_align(audio_batch, video_batch):
 if __name__ == '__main__':
-    path = '/home/storage/Data/MULTI_GRID/multiModalTfRec/TRAIN_CTC_SENTENCES/'
+    path = '/home/storage/Data/MULTI_GRID/rawMultiModalTfRec/TRAIN_CTC_SENTENCES/'
     n_batch = 4
     n_epoch = 5
     buffer_size = 2
-    dm = dataManager()
+    dm = dataManager(single_audio_frame_size=257)
     ds = dm.get_dataset(path)
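
The value 257 passed as `single_audio_frame_size` matches the number of frequency bins in a magnitude spectrogram frame computed with a 512-point FFT (512/2 + 1 = 257). The STFT parameters themselves are not part of this commit, so the snippet below is only a sketch of that assumed relationship:

import numpy as np

n_fft = 512                                 # assumed FFT size, not shown in this commit
n_bins = n_fft // 2 + 1                     # 257 frequency bins per spectrogram frame
frame = np.abs(np.fft.rfft(np.random.randn(n_fft)))
assert frame.shape[0] == n_bins             # matches single_audio_frame_size=257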
new file (GRID_100 spectrogram test script):
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '../../../'))
from Concat_DAE_4_speech import DAE_4_speech
import tensorflow as tf

if __name__ == '__main__':
    num_epochs = 500
    batch_size = 9
    nIn_audio = 257
    nIn_video = 134
    nHidden = [500, 600]
    nHidden_encode = 750
    learningRate = 0.001
    traininglog_dir = "./"
    updating_step = 2250
    test_step = 10
    learningDecay = 1
    momentum = 0.9
    test_name = "GRID_100_Spectro_Concat_DAE_4_speech_Test_lr-" + str(learningRate) + "_batch_size-" + str(batch_size) + "_n_hidden_encode-" + str(nHidden_encode)

    graph = tf.Graph()
    with tf.Session(graph=graph) as sess:
        model = DAE_4_speech(sess=sess, graph=graph, n_in_audio=nIn_audio, n_in_video=nIn_video, n_hidden=nHidden,
                             n_hidden_encode=nHidden_encode, batch_size=batch_size, learning_rate=learningRate,
                             learning_decay=learningDecay, momentum=momentum, updating_step=updating_step)
        model.training_model(training_set_path="/home/storage/Data/MULTI_GRID/rawMultiModalTfRec/TRAIN_CTC_SENTENCES/",
                             test_set_path="/home/storage/Data/MULTI_GRID/rawMultiModalTfRec/TEST_CTC_SENTENCES/",
                             validation_set_path="/home/storage/Data/MULTI_GRID/rawMultiModalTfRec/VAL_CTC_SENTENCES/",
                             n_epoch=num_epochs, test_step=test_step, test_name=test_name,
                             log_dir="./test_log/")
@@ -97,7 +97,8 @@ class DAE_4_speech:
     @define_scope("audio_acc")
     def audio_acc(self):
-        return tf.contrib.metrics.streaming_pearson_correlation(self.model.regression,self.y_ph)
+        return tf.contrib.metrics.streaming_pearson_correlation(self.model.regression,self.y_ph), \
+               tf.norm(self.model.regression-self.y_ph)
         # return np.linalg.norm(x-y)

     def eval_training(self, data_set_path):
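
`audio_acc` now returns two values for a batch: the streaming Pearson correlation between the reconstructed audio features (`self.model.regression`) and the clean targets (`self.y_ph`), and the Euclidean (L2) norm of their difference. A rough per-batch NumPy equivalent of what is being measured, assuming `pred` and `target` are arrays of shape `[time, 257]` (the names are illustrative, not part of the repository):

import numpy as np

def batch_metrics(pred, target):
    # Pearson correlation over all flattened entries of the batch
    pearson = np.corrcoef(pred.ravel(), target.ravel())[0, 1]
    # Euclidean distance of the residual, analogous to tf.norm(pred - target)
    ecl_dist = np.linalg.norm(pred - target)
    return pearson, ecl_dist

Note that the TensorFlow version is a streaming metric, so it accumulates statistics across batches instead of recomputing the correlation from scratch each time.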
@@ -121,6 +122,9 @@
         # problem with Hyper cell code
         get_next = it_data.get_next()
+        pearson_cor_list=[]
+        ecl_dist_list=[]
         while (True):
             try:
@@ -128,16 +132,21 @@
                 x_len, x_a_val,_,x_ma_val,_, x_v_val, _, _ = self.sess.run(get_next)
                 x_v_val = Data.video_batch_align(x_a_val, x_v_val)
                 # compute model output
-                model_out, acc = self.sess.run([self.model.regression,self.audio_acc], feed_dict={self.x_audio_ph: x_ma_val,
+                model_out, (acc, ecl_dist) = self.sess.run([self.model.regression,self.audio_acc], feed_dict={self.x_audio_ph: x_ma_val,
                                                                                                   self.x_video_ph: x_v_val,
                                                                                                   self.x_audio_len_ph: x_len,
                                                                                                   self.y_ph: x_a_val})
+                pearson_cor_list.append(acc[0])
+                ecl_dist_list.append(ecl_dist)
             except tf.errors.OutOfRangeError:
                 print("End of test")
                 break
-        return acc[0]
+        return np.mean(np.asarray(pearson_cor_list)), np.mean(np.asarray(ecl_dist_list))

     def prepare_log_files(self, test_name, log_dir):
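
Two assumptions about this evaluation loop are worth flagging. First, `streaming_pearson_correlation` is a TF 1.x streaming metric: it keeps its accumulators in local variables, so they must be initialized somewhere before the loop, for example (hypothetical placement, if it is not already done elsewhere in `eval_training`):

self.sess.run(tf.local_variables_initializer())

Second, because each fetch of `self.audio_acc` also runs the metric's update op, `acc[0]` is a running estimate over all batches seen so far rather than an independent per-batch score, while `ecl_dist` is genuinely per batch; averaging both lists with `np.mean` therefore mixes a cumulative statistic with a per-batch one.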
@@ -148,7 +157,7 @@ class DAE_4_speech:
         for f in (train_log, test_log, valid_log):
             f.write("test_name: %s \n" % test_name)
             f.write(str(datetime.datetime.now()) + '\n')
-            f.write("#epoch \t distance \t time \t epoch_cost \n")
+            f.write("#epoch \t pearson correlation \t euclidean distance \t time \t epoch_cost \n")
         return train_log, test_log, valid_log
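
With the extra column, each log file now holds two free-form header lines (test name and timestamp), the `#epoch ...` column header, and then five tab-separated values per logged epoch. A minimal way to load such a file for plotting, assuming exactly that layout and a hypothetical file name, could be:

import pandas as pd

log = pd.read_csv("test_log/GRID_100_train.log", sep="\t", skiprows=3, header=None,
                  names=["epoch", "pearson_corr", "eucl_dist", "epoch_time", "epoch_cost"])
print(log[["epoch", "pearson_corr", "eucl_dist"]].tail())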
@@ -218,17 +227,17 @@ class DAE_4_speech:
                 epoch_start_time = time.time()
                 if epoch_counter % test_step == 0 or n_processed_batch == 1:
-                    training_ler = self.eval_training(data_set_path=training_set_path)
+                    training_pc, training_ed = self.eval_training(data_set_path=training_set_path)
                     train_log.write(
-                        "{:d}\t{:.8f}\t{:.8f}\t{:.8f}\n".format(epoch_counter, training_ler, epoch_time,
+                        "{:d}\t{:.8f}\t{:.8f}\t{:.8f}\t{:.8f}\n".format(epoch_counter, training_pc, training_ed, epoch_time,
                                                                 epoch_cost))

-                    test_ler = self.eval_training(data_set_path=test_set_path)
+                    test_pc, test_ed = self.eval_training(data_set_path=test_set_path)
                     test_log.write(
-                        "{:d}\t{:.8f}\t{:.8f}\t{:.8f}\n".format(epoch_counter, test_ler, epoch_time,
+                        "{:d}\t{:.8f}\t{:.8f}\t{:.8f}\t{:.8f}\n".format(epoch_counter, test_pc, test_ed, epoch_time,
                                                                 epoch_cost))

-                    valid_ler = self.eval_training(data_set_path=validation_set_path)
+                    valid_pc, valid_ed = self.eval_training(data_set_path=validation_set_path)
                     valid_log.write(
-                        "{:d}\t{:.8f}\t{:.8f}\t{:.8f}\n".format(epoch_counter, valid_ler, epoch_time,
+                        "{:d}\t{:.8f}\t{:.8f}\t{:.8f}\t{:.8f}\n".format(epoch_counter, valid_pc, valid_ed, epoch_time,
                                                                 epoch_cost))

                     saver.save(sess=self.sess, save_path=os.path.join(log_dir, test_name+".ckpt"),
@@ -266,12 +275,12 @@ if __name__ == '__main__':
     _n_epoch = 10
     graph = tf.Graph()
-    model = DAE_4_speech(graph, n_in_audio=nIn_audio, n_in_video=nIn_video, n_hidden=nHidden, n_hidden_encode=nHidden_encode,
-                         batch_size=_batch_size, learning_rate=learningRate, learning_decay=learningDecay,
-                         momentum=momentum, updating_step=updatingStep)
-    model.training_model(training_set_path="/home/storage/Data/MULTI_GRID/multiModalTfRec/TRAIN_CTC_SENTENCES/",
-                         test_set_path="/home/storage/Data/MULTI_GRID/multiModalTfRec/TEST_CTC_SENTENCES/",
-                         validation_set_path="/home/storage/Data/MULTI_GRID/multiModalTfRec/VAL_CTC_SENTENCES/",
-                         n_epoch=_n_epoch, test_step=1, test_name="init_test",
-                         log_dir="./")
+    with tf.Session(graph=graph) as _sess:
+        model = DAE_4_speech(_sess, graph, n_in_audio=nIn_audio, n_in_video=nIn_video, n_hidden=nHidden, n_hidden_encode=nHidden_encode,
+                             batch_size=_batch_size, learning_rate=learningRate, learning_decay=learningDecay,
+                             momentum=momentum, updating_step=updatingStep)
+        model.training_model(training_set_path="/home/storage/Data/MULTI_GRID/multiModalTfRec/TRAIN_CTC_SENTENCES/",
+                             test_set_path="/home/storage/Data/MULTI_GRID/multiModalTfRec/TEST_CTC_SENTENCES/",
+                             validation_set_path="/home/storage/Data/MULTI_GRID/multiModalTfRec/VAL_CTC_SENTENCES/",
+                             n_epoch=_n_epoch, test_step=1, test_name="init_test",
+                             log_dir="./")