Hi,
I have been experimenting with the Attend Infer Repeat model. I downloaded the code from the examples and ran it with the args provided in the tutorial (http://pyro.ai/examples/air.html), which are:
python main.py -n 200000 -blr 0.1 --z-pres-prior 0.01 --scale-prior-sd 0.2 --predict-net 200 --bl-predict-net 200 --decoder-output-use-sigmoid --decoder-output-bias -2 --seed 287710
But the highest count accuracy I get is around 76%. Could someone let me know if the args are not correct? I did not change any part of the code.
Also, I am using Pyro version 0.2.1 and PyTorch version 0.4.0.
Many thanks
Hi, did you let inference run to completion or are you stopping it after a few epochs? Are you talking about count accuracy on the training images (as in the tutorial) or some unseen test images?
cc @null-a
Could someone let me know if the args are not correct?
The command line args in the tutorial are the ones I used, and it looks like you're using the same.
As well as answers to eb8680's questions, I'd be interested in hearing what other values for count accuracy you obtained, if that's possible. Thanks.
@eb8680 @null-a
I let the inference run to the end (I believe the number of steps is set by -n 200000). I am using the count accuracy function provided in 'main.py'. The only arg I added was '--eval-every', set to 1000. I have attached a plot of the count accuracy obtained. This plot looks very different from the plot given in the tutorial.

@thematrixduo How many times have you run inference? What other values for final count accuracy have you obtained? Setting the random seed doesn't make this deterministic, so some variance is expected. If you're seeing consistently poor results, I'll try running it myself and see if I can spot anything odd going on. Thanks.
@null-a I have tried at least 5 times but for none of them did the count accuracy go above 80%. I have attached another plot.

@thematrixduo Great, thanks for the info. That sounds worse than I would expect, so I'll take a look. To begin, I guess I'll confirm I'm getting similar results to you, and then I'll go back and try running against the Pyro commit recorded in the tutorial.
@thematrixduo It looks like performance may have degraded after c99ea673. If you apply the following patch (to v0.2.1, say) and then retry, I think you will see performance similar to that reported in the tutorial.
diff --git a/examples/air/main.py b/examples/air/main.py
index df89df2..09bdf25 100644
--- a/examples/air/main.py
+++ b/examples/air/main.py
@@ -195,7 +195,7 @@ def main(**kwargs):
         vis.images(draw_many(x, tensor_to_objs(latents_to_tensor(z))))
 
     def per_param_optim_args(module_name, param_name):
-        lr = args.baseline_learning_rate if 'bl_' in param_name else args.learning_rate
+        lr = args.baseline_learning_rate if 'bl_' in param_name or 'bl_' in module_name else args.learning_rate
         return {'lr': lr}
 
     svi = SVI(air.model, air.guide,
If we confirm this fixes the problem, then I'll open a PR to apply the patch to dev. (BTW, I've not been using a fixed random seed when I run this locally, and I seem to achieve OK results consistently.)
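To make the effect of the patch concrete, here is a minimal standalone sketch comparing the old and patched versions of per_param_optim_args. It does not import Pyro. The module and parameter names in `examples` are hypothetical, chosen only to show a case where the 'bl_' prefix appears in the module name rather than the parameter name, and the `learning_rate` value is an assumption (only `-blr 0.1` comes from the command line above).

```python
from types import SimpleNamespace

# baseline_learning_rate matches "-blr 0.1" from the command above;
# learning_rate is an assumed value used purely for illustration.
args = SimpleNamespace(learning_rate=1e-4, baseline_learning_rate=0.1)


def per_param_optim_args_old(module_name, param_name):
    # Pre-patch logic: only the parameter name is checked for 'bl_'.
    lr = args.baseline_learning_rate if 'bl_' in param_name else args.learning_rate
    return {'lr': lr}


def per_param_optim_args_new(module_name, param_name):
    # Patched logic: a 'bl_' prefix in either name selects the baseline rate.
    is_baseline = 'bl_' in param_name or 'bl_' in module_name
    lr = args.baseline_learning_rate if is_baseline else args.learning_rate
    return {'lr': lr}


# Hypothetical (module_name, param_name) pairs, not taken from the AIR code.
examples = [
    ('bl_rnn', 'weight_ih'),   # baseline module, plain parameter name
    ('rnn', 'weight_ih'),      # non-baseline module
    ('', 'bl_h_init'),         # 'bl_' carried by the parameter name itself
]

for module_name, param_name in examples:
    old = per_param_optim_args_old(module_name, param_name)
    new = per_param_optim_args_new(module_name, param_name)
    print(module_name, param_name, old, new)
```

In this sketch only the first pair changes between the two versions: the old check gives it the ordinary learning rate, while the patched check gives it the baseline learning rate. If the baseline networks' parameters were being named this way after c99ea673, that would be consistent with the lower count accuracy reported above, though this is an inference from the patch rather than something confirmed in the thread.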
@null-a good catch!
@null-a Thanks for the help!
@thematrixduo No problem, thanks for opening the issue.