Ignite: dcgan.py: `df.plot(x=np.array([1, 101, 201]))` doesn't work

Created on 1 May 2019 · 8 Comments · Source: pytorch/ignite

Line https://github.com/pytorch/ignite/blob/master/examples/gan/dcgan.py#L393 fails with an error saying that nothing in Int64Index([1, 101, 201]) is in df.columns (or something like that). To fix it, set the index of the dataframe instead of passing the x argument.

Replace this:

df = pd.read_csv(os.path.join(output_dir, LOGS_FNAME), delimiter='\t')
x = np.arange(1, engine.state.iteration + 1, PRINT_FREQ)
_ = df.plot(x=x, subplots=True, figsize=(20, 20))

with this:

df = pd.read_csv(os.path.join(output_dir, LOGS_FNAME), delimiter='\t') \
    .set_index(np.arange(1, engine.state.iteration + 1, PRINT_FREQ))
_ = df.plot(subplots=True, figsize=(20, 20))
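
For anyone who wants to check the indexing part without running the whole training script, here is a minimal sketch. It assumes PRINT_FREQ is 100 and uses a hypothetical `last_iteration` variable in place of `engine.state.iteration`; the small dataframe is illustrative and stands in for the contents of the log file.

```python
import numpy as np
import pandas as pd

PRINT_FREQ = 100       # assumed logging interval
last_iteration = 201   # hypothetical stand-in for engine.state.iteration

# Three logged rows, as if read from the log file (values are illustrative).
df = pd.DataFrame({
    "errD": [1.65, 0.58, 0.53],
    "errG": [6.80, 13.09, 6.72],
})

# One index entry per logged row: iterations 1, 101, 201.
iterations = np.arange(1, last_iteration + 1, PRINT_FREQ)
df = df.set_index(iterations)

# With the iterations in the index, df.plot(subplots=True) needs no x argument.
print(df.index.tolist())
```

Note that this only works when the number of rows in the file matches the length of the generated range, which is one reason to store the iteration in the file itself.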

My python version is 3.7, pandas is 0.24.2.

bug

All 8 comments

I have tested that the modified file works by running dcgan.py to completion on lfw dataset once.

@philip-bl yes, I can reproduce the problem too. However, I would prefer to improve the logs.tsv file that gets written. Currently, it looks like this:

errD     errG      D_x      D_G_z1   D_G_z2
1.65096  6.79936   0.66887  0.63026  0.00156
0.58172  13.08555  0.85667  0.18992  0.00104
0.5255   6.71655   0.83787  0.16335  0.01601
0.45577  7.49637   0.8753   0.11774  0.00778
0.22892  7.91881   0.92359  0.07014  0.00368
0.36326  6.71926   0.89081  0.10043  0.0104
0.44686  5.09332   0.85201  0.13819  0.03282
0.74268  4.46701   0.78121  0.21232  0.06204

and we have no idea at which iteration each row was logged. IMO it is worth adding the iteration as the first column and, when loading the csv, setting the index to the iteration column. What do you think?

@vfdev-5 Yes, that sounds better. Just to be clear, I don't want to be the person responsible for implementing that change.

It's not a big change :) But as you wish. Anyway, thanks for reporting.

Actually scratch that. This fixes it:

1,2d0
<
<
328,329c326,327
<             columns = list(engine.state.metrics.keys())
<             values = [str(round(value, 5)) for value in list(engine.state.metrics.values())]
---
>             columns = ["iteration"] + list(engine.state.metrics.keys())
>             values = [str(engine.state.iteration)] + [str(round(value, 5)) for value in list(engine.state.metrics.values())]
391,392c389
<             df = pd.read_csv(os.path.join(output_dir, LOGS_FNAME), delimiter='\t') \
<                 .set_index(np.arange(1, engine.state.iteration + 1, PRINT_FREQ))
---
>             df = pd.read_csv(os.path.join(output_dir, LOGS_FNAME), delimiter='\t', index_col="iteration")

Have tested that it works by running on lfw once. Apply this patch AFTER applying the changes I suggested in the original post.
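
A rough sketch of the reading side of this patch, for illustration only. The log text is held in memory rather than read from the real file, the metric values are the first rows of the sample log above, and the iteration numbers 1, 101, 201 are an assumption about what the patched script would write.

```python
import io

import pandas as pd

# Illustrative log contents with the extra "iteration" column.
# The iteration numbers are assumed, not taken from a real run.
tsv_text = (
    "iteration\terrD\terrG\n"
    "1\t1.65096\t6.79936\n"
    "101\t0.58172\t13.08555\n"
    "201\t0.5255\t6.71655\n"
)

# index_col="iteration" makes the logged iteration the dataframe index,
# so the plot no longer depends on recomputing the range from PRINT_FREQ.
df = pd.read_csv(io.StringIO(tsv_text), delimiter="\t", index_col="iteration")

print(df.index.tolist())
print(df.columns.tolist())
```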

Also I wonder if there's a good reason why this script implements tsv formatting on its own instead of using csv.writer.

@philip-bl you still don't want to send a PR with the better code? :)

> Also I wonder if there's a good reason why this script implements tsv formatting on its own instead of using csv.writer.

I would say there is no reason. If you can improve it, go ahead!
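
One possible shape for the csv.writer variant, as a sketch rather than the script's actual code. The helper name `append_log_row` and its signature are hypothetical, and an in-memory buffer stands in for the log file.

```python
import csv
import io


def append_log_row(fileobj, iteration, metrics, write_header=False):
    """Hypothetical helper: write one tab-separated row for a metrics dict."""
    writer = csv.writer(fileobj, delimiter="\t", lineterminator="\n")
    if write_header:
        writer.writerow(["iteration"] + list(metrics.keys()))
    # Round to 5 decimals, matching what the script does by hand.
    writer.writerow([iteration] + [round(v, 5) for v in metrics.values()])


# An in-memory buffer stands in for the opened log file.
buffer = io.StringIO()
append_log_row(buffer, 1, {"errD": 1.650961, "errG": 6.799364}, write_header=True)
append_log_row(buffer, 101, {"errD": 0.581723, "errG": 13.085551})

print(buffer.getvalue())
```

Compared with joining strings manually, the writer also takes care of quoting if a value ever contains a tab or newline.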

@vfdev-5 Not really. The overhead of doing a git clone, setting up remote repositories, applying the patch there, pushing, and creating a PR is too much.
