Commit 81a69637 authored by Luca Pasa's avatar Luca Pasa
Browse files

debug

parent 6f472281
......@@ -8,7 +8,7 @@ from Utils.SDR_utils import bss_eval_sources as sdr
import numpy as np
generated_dataset_folder="/home/storage/Data/MULTI_GRID_/gd/TRAIN/"
generated_dataset_folder="/home/storage/Data/MULTI_GRID_100/gd/TRAIN/"
data_set_path = "/home/storage/Data/MULTI_GRID_100/rawMultiModalTfRec/TEST_CTC_SENTENCES/"
......@@ -39,11 +39,12 @@ def get_output(sess, model, input_audio, input_video, input_len):
return model_out
if __name__ == '__main__':
num_epochs = 120
batch_size = 15
nIn_audio = 123
num_epochs = 500
batch_size = 30
nIn_audio = 257
nIn_video = 134
nHidden = [123,500,600]
nHidden = [500,600]
nHidden_encode = 750
learningRate = 0.001
traininglog_dir = "./"
......@@ -90,7 +91,9 @@ if __name__ == '__main__':
output = get_output(sess, model, x_a, x_v, x_a_len)
for i,(pred,tar) in enumerate(zip(output,y)):
for i,(pred,tar) in enumerate(zip(output,x_s_a)):
print pred.shape
print tar.shape
print sdr(reference_sources=tar,estimated_sources=pred)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment