Line https://github.com/pytorch/ignite/blob/master/examples/gan/dcgan.py#L393 fails with an error saying that nothing in Int64Index([1, 101, 201]) is in df.columns (or something like that). To fix it, set the index of the dataframe instead of passing the x argument.
Replace this:
df = pd.read_csv(os.path.join(output_dir, LOGS_FNAME), delimiter='\t')
x = np.arange(1, engine.state.iteration + 1, PRINT_FREQ)
_ = df.plot(x=x, subplots=True, figsize=(20, 20))
with this:
df = pd.read_csv(os.path.join(output_dir, LOGS_FNAME), delimiter='\t') \
.set_index(np.arange(1, engine.state.iteration + 1, PRINT_FREQ))
_ = df.plot(subplots=True, figsize=(20, 20))
My python version is 3.7, pandas is 0.24.2.
I have tested that the modified file works by running dcgan.py to completion on lfw dataset once.
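For readers who want to try the workaround without running the whole GAN example, here is a minimal sketch. It assumes PRINT_FREQ is 100 and that training stopped at iteration 300 (both are stand-in values chosen only for illustration), and it reads the log from an in-memory string instead of logs.tsv. The metric values are sample numbers.

```python
import io

import numpy as np
import pandas as pd

# Stand-in values (assumptions for this sketch, not taken from dcgan.py).
PRINT_FREQ = 100
last_iteration = 300  # plays the role of engine.state.iteration

# A tiny tab-separated log with one row per logged iteration.
logs_tsv = (
    "errD\terrG\tD_x\n"
    "1.65096\t6.79936\t0.66887\n"
    "0.58172\t13.08555\t0.85667\n"
    "0.5255\t6.71655\t0.83787\n"
)

df = pd.read_csv(io.StringIO(logs_tsv), delimiter="\t")

# One iteration number per logged row: 1, 101, 201.
iterations = np.arange(1, last_iteration + 1, PRINT_FREQ)

# Attach the iteration numbers as the index instead of passing them as x=...
# to df.plot, where they would be looked up as column labels.
df = df.set_index(iterations)
```

After this, `df.plot(subplots=True, figsize=(20, 20))` uses the index as the x axis, so no `x` argument is needed.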
@philip-bl yes, I reproduce the problem too. However, I would prefer to improve the written logs.tsv file. Currently, it looks like
errD errG D_x D_G_z1 D_G_z2
1.65096 6.79936 0.66887 0.63026 0.00156
0.58172 13.08555 0.85667 0.18992 0.00104
0.5255 6.71655 0.83787 0.16335 0.01601
0.45577 7.49637 0.8753 0.11774 0.00778
0.22892 7.91881 0.92359 0.07014 0.00368
0.36326 6.71926 0.89081 0.10043 0.0104
0.44686 5.09332 0.85201 0.13819 0.03282
0.74268 4.46701 0.78121 0.21232 0.06204
and we have no idea at which iteration each row was logged. IMO it is worth adding the iteration as the first column, named iteration, and setting the index to that column when loading the csv... What do you think?
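A rough sketch of what loading such a file could look like, assuming the column is called iteration and that rows were logged at iterations 1, 101 and 201 (these numbers are made up for the example, and the log is read from an in-memory string rather than the real logs.tsv):

```python
import io

import pandas as pd

# Hypothetical log contents with the iteration number as the first column.
logs_tsv = (
    "iteration\terrD\terrG\n"
    "1\t1.65096\t6.79936\n"
    "101\t0.58172\t13.08555\n"
    "201\t0.5255\t6.71655\n"
)

# index_col tells pandas to use the named column as the row index,
# so the iteration numbers no longer have to be reconstructed later.
df = pd.read_csv(io.StringIO(logs_tsv), delimiter="\t", index_col="iteration")
```

With the index stored in the file itself, plotting no longer depends on knowing PRINT_FREQ or the final iteration count.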
@vfdev-5 Yes, that sounds better. Just to be clear, I don't want to be the person responsible for implementing that change.
It's not a big change :) But as you wish. Anyway, thanks for reporting.
Actually scratch that. This fixes it:
1,2d0
<
<
328,329c326,327
< columns = list(engine.state.metrics.keys())
< values = [str(round(value, 5)) for value in list(engine.state.metrics.values())]
---
> columns = ["iteration"] + list(engine.state.metrics.keys())
> values = [str(engine.state.iteration)] + [str(round(value, 5)) for value in list(engine.state.metrics.values())]
391,392c389
< df = pd.read_csv(os.path.join(output_dir, LOGS_FNAME), delimiter='\t') \
< .set_index(np.arange(1, engine.state.iteration + 1, PRINT_FREQ))
---
> df = pd.read_csv(os.path.join(output_dir, LOGS_FNAME), delimiter='\t', index_col="iteration")
I have tested that it works by running on the lfw dataset once. Apply this patch AFTER applying the changes I suggested in the original post.
Also I wonder if there's a good reason why this script implements tsv formatting on its own instead of using csv.writer.
@philip-bl you still don't want to send a PR with the improved code? :)
> Also I wonder if there's a good reason why this script implements tsv formatting on its own instead of using csv.writer.
I would say there is no particular reason. If you can improve it, go ahead!
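For illustration only, a minimal sketch of how the log rows could be written with the standard library's csv.writer. The helper name write_log_row and the sample metric values are hypothetical and not part of dcgan.py; an in-memory buffer stands in for the logs.tsv file handle.

```python
import csv
import io


def write_log_row(fileobj, iteration, metrics, write_header=False):
    """Append one tab-separated row of metrics, optionally preceded by a header.

    `metrics` is a plain dict mapping metric names to floats; it plays the
    role of engine.state.metrics here.
    """
    writer = csv.writer(fileobj, delimiter="\t", lineterminator="\n")
    if write_header:
        writer.writerow(["iteration"] + list(metrics.keys()))
    # Round to 5 decimal places, as the existing script does.
    writer.writerow([iteration] + [round(value, 5) for value in metrics.values()])


# Usage with an in-memory buffer instead of a real file.
buffer = io.StringIO()
write_log_row(buffer, 1, {"errD": 1.650964, "errG": 6.799358}, write_header=True)
write_log_row(buffer, 101, {"errD": 0.581719, "errG": 13.085551})
log_text = buffer.getvalue()
```

Letting csv.writer handle the delimiter and line endings also takes care of quoting if a value ever contains a tab, which hand-rolled joining would not.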
@vfdev-5 Not really. The overhead of doing git clone, setting up remote repositories, applying the patch there, pushing, creating a PR - it's too much.