Commit 678de338 authored by Luca Pasa's avatar Luca Pasa
Browse files

Merge branch 'master' of gitlab.iit.it:lpasa/AV_ASR

parents c7641d11 81a69637
......@@ -17,7 +17,7 @@ if __name__ == '__main__':
test_step= 10
learningDecay = 1
momentum = 0.9
test_name="GRID_100_Concat_DAE_4_speech_Test_lr-"+str(learningRate)+"_batch_size-"+str(batch_size)+"_n_hidden_encode-"+str(nHidden_encode)
test_name="CPU_GRID_100_Concat_DAE_4_speech_Test_lr-"+str(learningRate)+"_batch_size-"+str(batch_size)+"_n_hidden_encode-"+str(nHidden_encode)
#Code for running on CPU
# config = tf.ConfigProto(device_count={'GPU': 0})
......@@ -28,7 +28,7 @@ if __name__ == '__main__':
n_hidden_encode=nHidden_encode, batch_size=batch_size, learning_rate=learningRate,
learning_decay=learningDecay, momentum=momentum, updating_step=updating_step)
model.restore_model("./RESULT/BaseLine1/GRID_100_Concat_DAE_4_speech_Test_lr-0.001_batch_size-15_n_hidden_encode-750.ckpt-80")
#model.restore_model("./RESULT/BaseLine1/GRID_100_Concat_DAE_4_speech_Test_lr-0.001_batch_size-15_n_hidden_encode-750.ckpt-80")
model.training_model(training_set_path="/home/storage/Data/MULTI_GRID_100/multiModalTfRec/TRAIN_CTC_SENTENCES/",
test_set_path="/home/storage/Data/MULTI_GRID_100/multiModalTfRec/TEST_CTC_SENTENCES/",
validation_set_path="/home/storage/Data/MULTI_GRID_100/multiModalTfRec/VAL_CTC_SENTENCES/",
......
......@@ -8,7 +8,7 @@ from Utils.SDR_utils import bss_eval_sources as sdr
import numpy as np
generated_dataset_folder="/home/storage/Data/MULTI_GRID_/gd/TRAIN/"
generated_dataset_folder="/home/storage/Data/MULTI_GRID_100/gd/TRAIN/"
data_set_path = "/home/storage/Data/MULTI_GRID_100/rawMultiModalTfRec/TEST_CTC_SENTENCES/"
......@@ -39,11 +39,12 @@ def get_output(sess, model, input_audio, input_video, input_len):
return model_out
if __name__ == '__main__':
num_epochs = 120
batch_size = 15
nIn_audio = 123
num_epochs = 500
batch_size = 30
nIn_audio = 257
nIn_video = 134
nHidden = [123,500,600]
nHidden = [500,600]
nHidden_encode = 750
learningRate = 0.001
traininglog_dir = "./"
......@@ -90,7 +91,9 @@ if __name__ == '__main__':
output = get_output(sess, model, x_a, x_v, x_a_len)
for i,(pred,tar) in enumerate(zip(output,y)):
for i,(pred,tar) in enumerate(zip(output,x_s_a)):
print pred.shape
print tar.shape
print sdr(reference_sources=tar,estimated_sources=pred)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment