forked from emreaksan/deepwriting
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtf_evaluate_hw.py
More file actions
315 lines (257 loc) · 16.5 KB
/
Copy pathtf_evaluate_hw.py
File metadata and controls
315 lines (257 loc) · 16.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
import tensorflow as tf
import numpy as np
import sys
import os
import argparse
import json
from scipy.misc import imsave
from tf_dataset_hw import *
from tf_models import VRNNGMM
from tf_models_hw import HandwritingVRNNGmmModel, HandwritingVRNNModel
from utils_visualization import plot_latent_variables, plot_latent_categorical_variables, plot_matrix_and_get_image, plot_and_get_image
import visualize_hw as visualize
# Sampling options
run_gmm_eval = False  # Visualize GMM latent space by using random samples and T-SNE.
run_original_sample = True  # Save an image of reference samples (see reference_sample_ids).
run_reconstruction = False  # Reconstruct reference samples and save reconstruction results.
run_biased_sampling = True  # Use a real reference sample to infer style (see reference_sample_ids) and synthesize the given text (see conditional_texts).
run_unbiased_sampling = True  # Use a random style to synthesize the given text (see conditional_texts).
# Also save a colored copy of every synthetic sample (written as '<name>_colored.svg' by
# plot_eval_details, despite the flag name). For now strokes are colored by whether the
# end-of-character probability exceeds eoc_threshold.
run_colored_png_output = False
# Sampling hyper-parameters
eoc_threshold = 0.05  # Passed to the text-conditioned samplers; also the coloring threshold above.
cursive_threshold = 0.005  # Passed to the text-conditioned samplers.
ref_len = None  # Use the whole sequence. NOTE(review): not referenced in this script's visible code.
seq_len = 800  # Maximum number of steps.
gmm_num_samples = 500  # For run_gmm_eval only: samples drawn per selected GMM component.
# Text to be written by the model.
conditional_texts = ["I am a synthetic sample", "I can write this line in so many styles."]
# Indices of reference style samples from validation split.
reference_sample_ids = [107, 226, 696]
# Concatenate reference sample with synthetic sample to make a direct comparison
# (the reference strokes are passed as prev_sample to the biased sampler).
concat_ref_and_synthetic_samples = False
# Sampling output options
plot_eoc = False  # Plot end-of-character probabilities.
plot_latent_vars = False  # Plot a matrix of approximate posterior and prior mu values.
save_plots = True  # Save plots as image. When False, synthetic samples are not written at all.
show_plots = False  # Show plots in a window.
def plot_eval_details(data_dict, sample, save_dir, save_name):
    """Save the SVG rendering of a sample plus optional diagnostic plots.

    Which extra files are produced is controlled by the module-level flags
    ``run_colored_png_output``, ``plot_latent_vars``, ``plot_eoc``,
    ``save_plots`` and ``show_plots``.

    Args:
        data_dict: Per-sample model outputs. Optional keys read here:
            'out_eoc', 'p_mu', 'q_mu', 'p_sigma', 'q_sigma', 'p_pi', 'q_pi',
            'gmm_mu', 'gmm_sigma'.
        sample: Stroke sequence handed to ``visualize.draw_stroke_svg``.
        save_dir: Output directory.
        save_name: File-name stem shared by every per-sample file.

    Returns:
        dict: The transposed latent-variable matrices that were plotted
        (empty if no latent plot was requested).
    """
    def _out_path(suffix):
        # All per-sample files share the same stem.
        return os.path.join(save_dir, save_name + suffix)

    def _first_transposed(key):
        # First entry along axis 0, with its two remaining axes swapped.
        # NOTE(review): presumably (batch, time, dim) -> (dim, time); confirm.
        return np.transpose(data_dict[key][0], [1, 0])

    visualize.draw_stroke_svg(sample, factor=0.001, svg_filename=_out_path('.svg'))
    plot_data = {}

    if run_colored_png_output:
        # Color strokes by whether the end-of-character probability exceeds the threshold.
        synthetic_eoc = np.squeeze(data_dict['out_eoc'])
        visualize.draw_stroke_svg(sample, factor=0.001, color_labels=synthetic_eoc > eoc_threshold,
                                  svg_filename=_out_path('_colored.svg'))

    if plot_latent_vars and 'p_mu' in data_dict:
        for key in ('p_mu', 'q_mu', 'q_sigma', 'p_sigma'):
            plot_data[key] = _first_transposed(key)
        plot_img = plot_latent_variables(plot_data, show_plot=show_plots)
        if save_plots:
            imsave(_out_path('_normal.png'), plot_img)

    if plot_latent_vars and 'p_pi' in data_dict:
        # plot_data keeps any Gaussian entries added above (unchanged behavior).
        for key in ('p_pi', 'q_pi'):
            plot_data[key] = _first_transposed(key)
        plot_img = plot_latent_categorical_variables(plot_data, show_plot=show_plots)
        if save_plots:
            imsave(_out_path('_pi.png'), plot_img)

    if plot_eoc and 'out_eoc' in data_dict:
        plot_img = plot_and_get_image(np.squeeze(data_dict['out_eoc']))
        # Respect save_plots like every other branch (it used to save unconditionally).
        if save_plots:
            imsave(_out_path('_eoc.png'), plot_img)

    # Same for every sample: written under fixed names and overwritten on each call.
    if 'gmm_mu' in data_dict:
        gmm_mu_img = plot_matrix_and_get_image(data_dict['gmm_mu'])
        gmm_sigma_img = plot_matrix_and_get_image(data_dict['gmm_sigma'])
        if save_plots:
            imsave(os.path.join(save_dir, 'gmm_mu.png'), gmm_mu_img)
            imsave(os.path.join(save_dir, 'gmm_sigma.png'), gmm_sigma_img)
    return plot_data
def do_evaluation(config, qualitative_analysis=True, quantitative_analysis=True, verbose=0):
    """Restore a trained handwriting model and write evaluation outputs.

    Builds a "validation" (inference) graph and a "sampling" graph, restores
    the weights from ``config['model_dir']`` and, depending on the
    module-level flags, visualizes the GMM latent space and saves reference,
    reconstructed and synthetic samples to ``config['eval_dir']``.

    Args:
        config: Configuration dict. Keys read here: 'model_cls', 'dataset_cls',
            'validation_data', 'use_bow_labels' (conditional datasets only),
            'model_dir', 'checkpoint_id', 'eval_dir' and optionally
            'use_real_pi_labels'. The whole dict is passed to the model class.
        qualitative_analysis: Generate and save samples.
        quantitative_analysis: Placeholder; nothing is computed yet.
        verbose: Currently unused.

    Raises:
        Exception: If the dataset class is unknown or the model checkpoint
            cannot be found/restored.
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    # Classes are looked up by name in this module's namespace, which the
    # imports at the top of the file populate.
    Model_cls = getattr(sys.modules[__name__], config['model_cls'])
    Dataset_cls = getattr(sys.modules[__name__], config['dataset_cls'])
    batch_size = 1
    data_sequence_length = None  # Unknown time dimension (variable-length sequences).

    # Load validation dataset to fetch statistics.
    if issubclass(Dataset_cls, HandWritingDatasetConditional):
        validation_dataset = Dataset_cls(config['validation_data'], var_len_seq=True,
                                         use_bow_labels=config['use_bow_labels'])
    elif issubclass(Dataset_cls, HandWritingDataset):
        validation_dataset = Dataset_cls(config['validation_data'], var_len_seq=True)
    else:
        raise Exception("Unknown dataset class.")

    strokes = tf.placeholder(tf.float32, shape=[batch_size, data_sequence_length, sum(validation_dataset.input_dims)])
    targets = tf.placeholder(tf.float32, shape=[batch_size, data_sequence_length, sum(validation_dataset.target_dims)])
    sequence_length = tf.placeholder(tf.int32, shape=[batch_size])

    # Create inference graph.
    with tf.name_scope("validation"):
        inference_model = Model_cls(config,
                                    reuse=False,
                                    input_op=strokes,
                                    target_op=targets,
                                    input_seq_length_op=sequence_length,
                                    input_dims=validation_dataset.input_dims,
                                    target_dims=validation_dataset.target_dims,
                                    batch_size=batch_size,
                                    mode="validation",
                                    data_processor=validation_dataset)
        inference_model.build_graph()
        inference_model.create_image_summary(validation_dataset.prepare_for_visualization)

    # Create sampling graph (reuse=True; presumably shares the inference model's variables).
    with tf.name_scope("sampling"):
        model = Model_cls(config,
                          reuse=True,
                          input_op=strokes,
                          target_op=None,
                          input_seq_length_op=sequence_length,
                          input_dims=validation_dataset.input_dims,
                          target_dims=validation_dataset.target_dims,
                          batch_size=batch_size,
                          mode="sampling",
                          data_processor=validation_dataset)
        model.build_graph()

    sess = tf.Session()
    # Close the session on every exit path, including a failed restore.
    try:
        _restore_checkpoint(sess, config)

        if run_gmm_eval:
            _run_gmm_eval(model, sess)

        keyword_args = dict()
        keyword_args['conditional_inputs'] = None
        keyword_args['eoc_threshold'] = eoc_threshold
        keyword_args['cursive_threshold'] = cursive_threshold
        keyword_args['use_sample_mean'] = True

        if quantitative_analysis:
            pass  # Not implemented yet.

        if qualitative_analysis:
            print("Generating samples...")
            eval_dir = config['eval_dir']
            # Only this model setup receives keyword_args (and hence the text to write).
            use_text_conditioning = config.get('use_real_pi_labels', False) and isinstance(model, VRNNGMM)
            for real_img_idx in reference_sample_ids:
                _, stroke_model_input, _ = validation_dataset.fetch_sample(real_img_idx)
                stroke_sample = stroke_model_input[:, :, 0:3]
                if run_reconstruction or run_biased_sampling:
                    inference_results = inference_model.reconstruct_given_sample(session=sess, inputs=stroke_model_input)

                if run_original_sample:
                    svg_path = os.path.join(eval_dir, "real_image_" + str(real_img_idx) + '.svg')
                    visualize.draw_stroke_svg(validation_dataset.undo_normalization(validation_dataset.samples[real_img_idx], detrend_sample=False),
                                              factor=0.001, svg_filename=svg_path)

                if run_reconstruction:
                    svg_path = os.path.join(eval_dir, "reconstructed_image_" + str(real_img_idx) + '.svg')
                    visualize.draw_stroke_svg(validation_dataset.undo_normalization(inference_results[0]['output_sample'][0], detrend_sample=False),
                                              factor=0.001, svg_filename=svg_path)

                reference_sample_in_img = stroke_sample if concat_ref_and_synthetic_samples else None

                # Conditional handwriting synthesis.
                # NOTE(review): unbiased file names do not include real_img_idx, so every
                # reference iteration regenerates and overwrites them.
                for text_id, text in enumerate(conditional_texts):
                    keyword_args['conditional_inputs'] = text
                    if use_text_conditioning:
                        if run_biased_sampling:
                            biased_sampling_results = model.sample_biased(session=sess, seq_len=seq_len,
                                                                          prev_state=inference_results[0]['state'],
                                                                          prev_sample=reference_sample_in_img,
                                                                          **keyword_args)
                            _save_synthetic_sample(validation_dataset, biased_sampling_results, eval_dir,
                                                   'synthetic_biased_ref(' + str(real_img_idx) + ')_(' + str(text_id) + ')')
                            # Without beautification: set False
                            # Apply beautification: set True.
                            # NOTE(review): use_sample_mean is already True, so this call repeats the
                            # one above with identical arguments; set False if the '_sampled' file
                            # is meant to hold the non-beautified variant.
                            keyword_args['use_sample_mean'] = True
                            biased_sampling_results = model.sample_biased(session=sess, seq_len=seq_len,
                                                                          prev_state=inference_results[0]['state'],
                                                                          prev_sample=reference_sample_in_img,
                                                                          **keyword_args)
                            _save_synthetic_sample(validation_dataset, biased_sampling_results, eval_dir,
                                                   'synthetic_biased_sampled_ref(' + str(real_img_idx) + ')_(' + str(text_id) + ')')
                        if run_unbiased_sampling:
                            unbiased_sampling_results = model.sample_unbiased(session=sess, seq_len=seq_len, **keyword_args)
                            _save_synthetic_sample(validation_dataset, unbiased_sampling_results, eval_dir,
                                                   'synthetic_unbiased_(' + str(text_id) + ')')
                            # Same use_sample_mean caveat as in the biased branch (True = beautification).
                            keyword_args['use_sample_mean'] = True
                            unbiased_sampling_results = model.sample_unbiased(session=sess, seq_len=seq_len, **keyword_args)
                            _save_synthetic_sample(validation_dataset, unbiased_sampling_results, eval_dir,
                                                   'synthetic_unbiased_sampled(' + str(text_id) + ')')
                    else:
                        # These models are sampled without keyword_args, so `text` is not used.
                        if run_biased_sampling:
                            biased_sampling_results = model.sample_biased(session=sess, seq_len=seq_len,
                                                                          prev_state=inference_results[0]['state'],
                                                                          prev_sample=reference_sample_in_img)
                            _save_synthetic_sample(validation_dataset, biased_sampling_results, eval_dir,
                                                   'synthetic_biased_ref(' + str(real_img_idx) + ')_(' + str(text_id) + ')')
                        if run_unbiased_sampling:
                            unbiased_sampling_results = model.sample_unbiased(session=sess, seq_len=seq_len)
                            _save_synthetic_sample(validation_dataset, unbiased_sampling_results, eval_dir,
                                                   'synthetic_unbiased_(' + str(text_id) + ')')
    finally:
        sess.close()


def _restore_checkpoint(sess, config):
    """Restore all saveable variables into *sess* from ``config['model_dir']``.

    Uses ``config['checkpoint_id']`` when set, otherwise the latest checkpoint.

    Raises:
        Exception: "Model is not found." if no checkpoint exists or restoring
            fails (the original error stays attached as the exception context).
    """
    if config['checkpoint_id'] is None:
        checkpoint_path = tf.train.latest_checkpoint(config['model_dir'])
    else:
        checkpoint_path = os.path.join(config['model_dir'], config['checkpoint_id'])
    if checkpoint_path is None:
        # latest_checkpoint returns None when the directory holds no checkpoint.
        raise Exception("Model is not found.")
    print("Loading model " + checkpoint_path)
    try:
        saver = tf.train.Saver()
        saver.restore(sess, checkpoint_path)
    except Exception:
        # Narrowed from a bare except so KeyboardInterrupt/SystemExit propagate.
        raise Exception("Model is not found.")


def _run_gmm_eval(model, sess):
    """Show a t-SNE scatter plot of samples from hand-picked GMM components.

    Draws ``gmm_num_samples`` points from each selected Gaussian component
    (mean + sigma * standard-normal noise), embeds them in 2-D with t-SNE and
    displays the result with ``plt.show()``.
    """
    from sklearn import manifold
    import matplotlib.pyplot as plt
    from matplotlib.ticker import NullFormatter

    gmm_mus, gmm_sigmas = model.evaluate_gmm_latent_space(sess)
    # We have ~70 components. Select a subset of them manually.
    gmm_component_ids = [2, 3, 11, 12, 13, 14, 15, 39, 40]
    gmm_legend_labels = ["1", "2", "a", "b", "c", "d", "e", "C", "D"]
    num_components = len(gmm_component_ids)
    size_components = gmm_mus.shape[1]
    gmm_samples = np.zeros((num_components * gmm_num_samples, size_components))
    gmm_sample_labels = np.zeros(num_components * gmm_num_samples)
    for label_idx, comp_id in enumerate(gmm_component_ids):
        rows = slice(label_idx * gmm_num_samples, (label_idx + 1) * gmm_num_samples)
        epsilon = np.random.normal(0, 1, (gmm_num_samples, size_components))
        # Sample the selected component (comp_id), not the label_idx-th one,
        # so each point cloud matches its legend label.
        gmm_samples[rows, :] = gmm_mus[comp_id] + gmm_sigmas[comp_id] * epsilon
        gmm_sample_labels[rows] = label_idx

    # Creating a discrete colorbar: one color per selected component.
    colors = plt.cm.jet(np.linspace(0, 1, num_components))
    Y = manifold.TSNE(n_components=2, init='pca', random_state=0).fit_transform(gmm_samples)
    fig = plt.figure(figsize=(15, 8))
    ax = fig.add_subplot(1, 1, 1)
    for label_idx, color in enumerate(colors):
        mask = gmm_sample_labels == label_idx
        plt.scatter(Y[mask, 0], Y[mask, 1],
                    20, lw=.25, marker='o', color=color, label=gmm_legend_labels[label_idx],
                    alpha=0.9, antialiased=True, zorder=3)
    ax.xaxis.set_major_formatter(NullFormatter())
    ax.yaxis.set_major_formatter(NullFormatter())
    plt.legend()
    plt.axis('tight')
    plt.show()


def _save_synthetic_sample(dataset, sampling_results, eval_dir, save_name):
    """De-normalize ``sampling_results[0]['output_sample'][0]`` and plot it.

    Files are written via plot_eval_details only when ``save_plots`` is set.
    """
    synthetic_sample = dataset.undo_normalization(sampling_results[0]['output_sample'][0], detrend_sample=False)
    if save_plots:
        plot_eval_details(sampling_results[0], synthetic_sample, eval_dir, save_name)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-S', '--model_save_dir', dest='model_save_dir', type=str, default='./runs/', help='path to main model save directory')
    parser.add_argument('-E', '--eval_dir', type=str, default='./runs_evaluation/', help='path to main log/output directory (pass an empty string to write into the model folder)')
    # Mandatory: the model folder name is joined into every path below.
    parser.add_argument('-M', '--model_id', dest='model_id', type=str, required=True, help='model folder')
    parser.add_argument('-C', '--checkpoint_id', type=str, default=None, help='checkpoint name inside the model folder (default: latest checkpoint)')
    parser.add_argument('-QN', '--quantitative', dest='quantitative', action="store_true", help='Run quantitative analysis')
    parser.add_argument('-QL', '--qualitative', dest='qualitative', action="store_true", help='Run qualitative analysis')
    parser.add_argument('-V', '--verbose', dest='verbose', type=int, default=1, help='Verbosity')
    args = parser.parse_args()

    model_dir = os.path.join(args.model_save_dir, args.model_id)
    # Close the config file deterministically instead of leaking the handle.
    with open(os.path.abspath(os.path.join(model_dir, 'config.json')), 'r') as config_file:
        config_dict = json.load(config_file)
    config_dict['model_dir'] = model_dir  # in case the folder is renamed.
    config_dict['checkpoint_id'] = args.checkpoint_id
    # eval_dir has a non-None default, so an empty value is the only way to
    # select the model folder as output directory.
    if not args.eval_dir:
        config_dict['eval_dir'] = config_dict['model_dir']
    else:
        config_dict['eval_dir'] = os.path.join(args.eval_dir, args.model_id)
    if not os.path.exists(config_dict['eval_dir']):
        os.makedirs(config_dict['eval_dir'])
    do_evaluation(config_dict, quantitative_analysis=args.quantitative, qualitative_analysis=args.qualitative, verbose=args.verbose)