Commit 90b0d3e4 authored by Luca Pasa
Browse files

add sdr_compute

parent ddeede6e
......@@ -18,7 +18,7 @@ if __name__ == '__main__':
learningDecay = 1
momentum = 0.9
test_name="MULTI_GRID_100_Spectro_motion_Concat_DAE_4_speech_Test_lr-"+str(learningRate)+"_batch_size-"+str(batch_size)+"_n_hidden_encode-"+str(nHidden_encode)
data_path="/home/storage/Data/MULTI_GRID_100/rawMultiModalTfRec/"
data_path="/home/storage/Data/MULTI_GRID_100/rawAudioMotionTfRec/"
......
import tensorflow as tf
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '../../../'))
from Data.Data_reader import DatabaseMultiSpeechReader as Data
from Utils.SDR_utils import bss_eval_sources as sdr
import numpy as np
# Presumably the folder of the generated dataset; it is not referenced in the visible part of this script.
generated_dataset_folder="/home/storage/Data/MULTI_GRID_100/gd/TRAIN/"
# Dataset directory handed to the data manager's get_dataset() in the __main__ section.
data_set_path = "/home/storage/Data/MULTI_GRID_100/multiModalTfRec/TRAIN_CTC_SENTENCES/"
from Model.DAE.Baseline_Models.Concat_DAE_4_speech import DAE_4_speech
def restore_and_get_model(graph, sess, batch_size, nIn_audio, nIn_video, nHidden, nHidden_encode, learningRate,
                          updating_step, learningDecay, momentum, ckpt_file):
    """Build a DAE_4_speech model and load its variables from a checkpoint.

    The model is constructed first, then a default ``tf.train.Saver`` is
    created and asked to restore ``ckpt_file`` into ``sess``.

    Args:
        graph: Graph object forwarded to the model constructor.
        sess: Session forwarded to the model and used for the restore.
        batch_size: Forwarded as ``batch_size``.
        nIn_audio: Forwarded as ``n_in_audio``.
        nIn_video: Forwarded as ``n_in_video``.
        nHidden: Forwarded as ``n_hidden``.
        nHidden_encode: Forwarded as ``n_hidden_encode``.
        learningRate: Forwarded as ``learning_rate``.
        updating_step: Forwarded as ``updating_step``.
        learningDecay: Forwarded as ``learning_decay``.
        momentum: Forwarded as ``momentum``.
        ckpt_file: Path of the checkpoint to restore.

    Returns:
        The constructed DAE_4_speech instance.
    """
    model_kwargs = dict(
        graph=graph,
        sess=sess,
        n_in_audio=nIn_audio,
        n_in_video=nIn_video,
        n_hidden=nHidden,
        n_hidden_encode=nHidden_encode,
        batch_size=batch_size,
        learning_rate=learningRate,
        learning_decay=learningDecay,
        momentum=momentum,
        updating_step=updating_step,
    )
    restored_model = DAE_4_speech(**model_kwargs)
    # The saver is created only after the model exists, as in the original flow.
    tf.train.Saver().restore(sess, ckpt_file)
    return restored_model
def get_output(sess, model, input_audio, input_video, input_len):
    """Run the model's regression output for one batch.

    Args:
        sess: Object whose ``run(fetches, feed_dict=...)`` evaluates the fetch.
        model: Wrapper exposing ``model.regression`` plus the placeholders
            ``x_audio_ph``, ``x_video_ph`` and ``x_audio_len_ph``.
        input_audio: Value fed to ``model.x_audio_ph``.
        input_video: Value fed to ``model.x_video_ph``.
        input_len: Value fed to ``model.x_audio_len_ph``.

    Returns:
        Whatever ``sess.run`` returns for ``model.model.regression``.
    """
    feeds = {
        model.x_audio_ph: input_audio,
        model.x_video_ph: input_video,
        model.x_audio_len_ph: input_len,
    }
    return sess.run(model.model.regression, feed_dict=feeds)
if __name__ == '__main__':
    # Evaluation entry point: rebuild the model, restore it from a checkpoint,
    # iterate once over the dataset at data_set_path and print the value
    # returned by sdr() for every (prediction, target) pair.

    # Hyperparameters passed to restore_and_get_model(). num_epochs,
    # traininglog_dir and test_step are not read anywhere in this section.
    num_epochs = 120
    batch_size = 15
    nIn_audio = 123
    nIn_video = 134
    nHidden = [123, 500, 600]
    nHidden_encode = 750
    learningRate = 0.001
    traininglog_dir = "./"
    updating_step = 2250
    test_step = 10
    learningDecay = 1
    momentum = 0.9
    # NOTE(review): the checkpoint name mentions "batch_size-30" while
    # batch_size above is 15 -- confirm this mismatch is intended.
    ckpt_file = "../Baseline_Models/RESULT/BaseLine1/MULTI_GRID_100_Spectro_Concat_DAE_4_speech_Test_lr-0.001_batch_size-30_n_hidden_encode-750.ckpt-230"

    graph = tf.Graph()
    with graph.as_default():
        with tf.Session(graph=graph) as sess:
            model = restore_and_get_model(graph, sess, batch_size, nIn_audio, nIn_video, nHidden, nHidden_encode,
                                          learningRate,
                                          updating_step, learningDecay, momentum, ckpt_file)
            data_set_dm = Data.dataManager(single_audio_frame_size=model.model.n_in_audio,
                                           single_video_frame_size=model.model.n_in_video)
            # read dataset
            data_set = data_set_dm.get_dataset(data_set_path)
            # get iterator
            _, it_data = data_set_dm.get_iterator(data_set)
            # init iterator: a single epoch is enough for evaluation; per the
            # original author's note, the batch size has to remain the same
            # due to the problem with the Hyper cell code.
            sess.run(it_data.initializer,
                     feed_dict={data_set_dm.batch_size_ph: batch_size, data_set_dm.n_epoch_ph: 1,
                                data_set_dm.buffer_size_ph: 2})
            # Create the fetch once, outside the loop, and reuse it per batch.
            get_next = it_data.get_next()
            while True:
                try:
                    _, x_s_a, x_a_len, x_a, _, x_v, _, y = sess.run(get_next)
                    x_v = Data.video_batch_align(x_a, x_v)
                    output = get_output(sess, model, x_a, x_v, x_a_len)
                    for pred, tar in zip(output, y):
                        # print() call form is valid on both Python 2 and 3
                        # and matches the print("End of dataset") below.
                        print(sdr(reference_sources=tar, estimated_sources=pred))
                except tf.errors.OutOfRangeError:
                    # Raised when the one-epoch iterator is exhausted.
                    print("End of dataset")
                    break
This diff is collapsed.
This diff is collapsed.
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment