NeurIPS 2020

Refactoring Policy for Compositional Generalizability using Self-Supervised Object Proposals

Review 1

Summary and Contributions: This work focuses on compositional generalizability through learning an object-centric graph of the dynamics. The authors propose a self-supervised approach to learn object detectors. They argue that an end-to-end approach is naive and prone to errors, and therefore introduce a two-stage framework: it first learns a teacher policy, then trains an unsupervised object detector to output proposals, which a GNN uses to imitate the teacher policy. The authors call this teacher-student approach a "refactorization," and argue that it yields better compositional generalization and interpretability.

Strengths: The improvement in performance over the teacher policy in Pacman and BigFish is impressive. The use of self-supervised object detection methods that don't rely on dynamics or the MDP seems novel, and empirically appears to work well in comparison to end-to-end graph network methods ([36]). I appreciate the thoughtful analysis in the conclusion, and commend the authors for pointing out the limitations of the method.

Weaknesses: Only one environment each from ProcGen and Atari is showcased. One of the benefits of these benchmarks is the ability to test the robustness of your method on many different but related tasks. It would significantly strengthen your paper to include additional environments, even ones where performance is poor, provided the performance can be explained. There should be ablations on the refactorization process: how important is it compared to the object proposal method? Also, for a fair comparison, shouldn't the other baselines use a similar distillation process?

Correctness: Yes, they seem correct.

Clarity: There are some grammatical mistakes and misspellings, and some sections and subsections contain no text but jump directly into a new subsection. There should be introductory comments that provide a roadmap for each section. What is DP in Table 1? Update: I now understand it stands for data parameters, but this should be clarified in the text. In line 159, do you mean CIFAR-Task rather than CIFAR-Recon? I assume CIFAR-Recon and ImageNet-Recon are features learned with an autoencoder? This needs to be defined explicitly in the text. Why do CIFAR-Task and ImageNet-Task compute the features in two different ways, if they are just two different datasets for the same task/method? Why are there background features for ImageNet but not for CIFAR? The legends in Fig. 1 are too small and should be enlarged so that the colors can be distinguished.

Relation to Prior Work: A similar work that also learns a GN with self-supervised object detectors is [1]. Given the comparison to relation networks, a more in-depth discussion of the algorithmic differences that account for the performance gap would be helpful. Is the lower performance of relation networks relative to your method due to a poor choice of nodes, i.e., because the end-to-end method is worse at finding objects?
[1] T. Kipf, E. van der Pol, and M. Welling. Contrastive Learning of Structured World Models. ICLR 2020.

Reproducibility: Yes

Additional Feedback: I don't understand L165: "Second, we investigate how data parameters help generalization." It is unclear how this paragraph demonstrates generalization. Were data parameters used for BigFish? Update: some of my concerns were answered satisfactorily in the rebuttal, mainly by an ablation that shows how much improvement is gained from refactorization. I am improving my score to a 7.

Review 2

Summary and Contributions: This paper proposes a two-stage framework to refactor an overfitting CNN-based policy trained from image input into (1) object proposal generation and (2) a generalizable GNN-based policy that takes the objects in the image as input. In the second stage, the GNN-based policy is cloned from the CNN-based policy. The resulting policies generalize better to unseen environments.

Strengths: One contribution is down-weighting incomplete proposals when training the GNN. The authors also leverage the advantages of GNNs, such as generalizability and interpretability. All three experiments show that the GNN-based two-stage framework generalizes better, significantly reducing the overfitting of the CNN-based policy. An interesting finding is that the GNN can achieve better results than the CNN policy even with sparse or lost information.

Weaknesses: At the level of ideas, this paper combines existing methods, so its novelty is incremental. Why propose a two-stage framework rather than a unified end-to-end GNN-based policy? One reason (I guess) is that an end-to-end GNN policy would need to generate object proposals online, because the environment (or image) changes depending on which action is taken. This would be time-consuming.

Correctness: Yes

Clarity: Yes

Relation to Prior Work: Yes

Reproducibility: Yes

Additional Feedback: The results are good, especially when the GNN policy clones the CNN policy using inputs from a different domain. The two-stage framework depends heavily on the GNN and the object proposals, which limits its applicability.

Review 3

Summary and Contributions: The paper proposes to improve the generalization of model-free RL agents by refactoring a previously trained policy network into an object-centric one with simple behavioral cloning. More specifically, the authors use an existing unsupervised object discovery method, SPACE, to obtain a pretrained object detector and extract object proposals from the scene. With these object proposals as inputs, a new policy network is trained to minimize the L2 prediction discrepancy between its outputs and those of the original policy network. Further, the loss is re-weighted with a simple method such that images with poor object detection results contribute less to the total loss. The authors verify the effectiveness of their approach on some hand-crafted environments and show that the learned feature representations are more task-centric than those from purely reconstruction-based methods. In addition, the authors test their method on a rather simple, customized Pacman environment and show that the refactored policy network generalizes better than the original one.
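The re-weighted behavioral-cloning objective summarized above can be illustrated with a small sketch. This is not the paper's implementation: the review does not specify the reweighting rule, so the sketch assumes one learnable log-scale s_i per training image and swaps in an uncertainty-style weighting, exp(-s_i) * err_i + s_i, which lets samples that the student cannot fit (e.g. images with poor proposals) end up with smaller weights. The function names are hypothetical.

```python
import numpy as np


def weighted_distillation_loss(student_out, teacher_out, log_scale):
    """Per-sample re-weighted L2 behavioral-cloning loss (hypothetical sketch).

    student_out, teacher_out: arrays of shape (N, A), e.g. action logits.
    log_scale: array of shape (N,), one assumed learnable parameter s_i per image.
    Returns (mean loss, per-sample losses).
    """
    # Squared L2 discrepancy between student and teacher outputs, per sample
    sq_err = np.sum((student_out - teacher_out) ** 2, axis=1)
    # exp(-s_i) scales each sample's error; the + s_i term prevents all weights
    # from collapsing to zero when s_i is optimized jointly with the student
    per_sample = np.exp(-log_scale) * sq_err + log_scale
    return per_sample.mean(), per_sample


def optimal_weights(sq_err, eps=1e-12):
    """Weights exp(-s_i*) at the per-sample optimum of exp(-s) * err + s.

    Setting the derivative -exp(-s) * err + 1 to zero gives exp(-s) = 1 / err,
    so samples with a large residual error receive a small weight.
    """
    return 1.0 / (np.asarray(sq_err, dtype=float) + eps)
```

With all s_i = 0 the loss reduces to the plain mean squared error, and optimal_weights([1.0, 4.0]) gives weights of roughly 1.0 and 0.25, i.e. the harder sample counts a quarter as much.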

Strengths: 1. The paper demonstrates that object-centric representations discovered without supervision are beneficial to out-of-distribution generalization, and the proposed two-stage, distillation-like algorithm yields a better model at no significant cost. 2. A data-dependent loss function is introduced to alleviate the negative effects of image samples with poor object detection results. 3. The paper shows that task-relevant representations are more natural and interpretable than reconstruction-based representations.

Weaknesses: 1. Some existing papers (e.g. [1]) have shown that object-centric representations generalize well to novel scenes in Atari games. Given that the environments in this paper are relatively simple and SPACE works quite well in such environments, the novelty of this paper is somewhat undermined. [1] Davidson, G. and Lake, B. M. (2020). Investigating simple object representations in model-free deep reinforcement learning. In Proceedings of the 42nd Annual Conference of the Cognitive Science Society.

Correctness: The claims and methods are correct.

Clarity: The rough idea of this paper is easy to grasp, though some parts are confusing. For instance, the abbreviation "DP" in Table 1 is unclear, if not completely unexplained.

Relation to Prior Work: Yes.

Reproducibility: Yes

Additional Feedback: 1. As mentioned before, what is "DP" in Table 1? The re-weighted loss? If so, the authors need to explain it clearly in the results section. 2. How much does reweighting the individual losses improve performance? In addition, a visualization of the most and least heavily weighted data samples would be helpful.

Review 4

Summary and Contributions: The main focus of the paper is better generalization of RL policies for tasks that require graph reasoning. The proposal is to first train a CNN-based policy on the task, then to "refactor" that policy by using it as a teacher to train a GNN, in conjunction with a pre-trained object detector that feeds the GNN a graph. Edges are computed in a task-specific way. The idea is that for graph tasks the GNN has better generalization properties than the original (teacher) CNN but requires the structured input provided by the object detector in order to work well. The idea is explored on several visual tasks (e.g. summing MNIST digits with varied backgrounds).
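Why a GNN over object proposals can cope with a varying number of objects can be shown with a minimal sketch. The review notes that edges are task-specific; here we simply assume a fully connected graph over the K proposals, one round of message passing with sum aggregation, and a sum readout to action logits. The weight matrices and function name are hypothetical. The structural point is that the same weights apply to any K and the output does not depend on the order of the proposals.

```python
import numpy as np


def gnn_policy_logits(node_feats, W_msg, W_upd, W_out):
    """One round of message passing over a fully connected object graph (hypothetical).

    node_feats: (K, D) features of K object proposals; K may differ between scenes.
    W_msg: (2*D, H), W_upd: (D + H, D), W_out: (D, A) are illustrative weights.
    Returns action logits of shape (A,).
    """
    K = node_feats.shape[0]
    # All ordered pairs (sender, receiver) with sender != receiver (assumed edges)
    senders = np.repeat(np.arange(K), K)
    receivers = np.tile(np.arange(K), K)
    keep = senders != receivers
    senders, receivers = senders[keep], receivers[keep]
    # Edge messages from concatenated sender/receiver features, ReLU nonlinearity
    edge_in = np.concatenate([node_feats[senders], node_feats[receivers]], axis=1)
    msgs = np.maximum(0.0, edge_in @ W_msg)
    # Sum the incoming messages of each receiver node
    agg = np.zeros((K, W_msg.shape[1]))
    np.add.at(agg, receivers, msgs)
    # Node update, then a permutation-invariant sum readout over all objects
    nodes = np.maximum(0.0, np.concatenate([node_feats, agg], axis=1) @ W_upd)
    return nodes.sum(axis=0) @ W_out
```

Calling it with 5 or 9 proposals and the same weights returns logits of the same shape, and permuting the proposals leaves the logits unchanged up to floating-point error.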

Strengths:
- The refactorization is an interesting idea and seems to help.
- The results are positive and the approach is practical.
- Representing the state as a graph enables better interpretation of the object features.

Weaknesses:
- It seems somewhat unlikely that the chosen baselines could really do as well as this approach, since they are not given access to the object proposals. I wonder whether there is a more reasonable baseline that could make use of that information for comparison with the GNN. Could the object proposals be supplied as a kernel to the CNN of the relation network?
- It wasn't totally clear to me why you needed to use the teacher policy. Does the GNN fail to converge if you train it from scratch with the (pre-trained) object detector?
- How well does this work if you just use AIR with no background image? Being able to handle complicated backgrounds is certainly a positive thing, but it is hard to tease apart the difficulty of handling that "noise" robustly from the difficulty of generalizing to more objects. In fact, looking at Table 1, it seems that the difference between your approach and the baselines decreases with the (presumably) easier CIFAR backgrounds. The same holds for the less noisy BigFish backgrounds.
- There is another relevant line of work (not cited) that has similar aims and uses a task similar to Pacman:

Typos and suggestions:
- line 62: I would use a different symbol than $\pi_i$ here, since $\pi$ is typically used for policies and this represents the output, not the policy itself.
- line 68: "necessary to have compositional generaliziabilty": generalizaiability -> generalizability.
- line 128: "on which some MNIST digtis": digtis -> digits.

Correctness: The method seems correct. One minor point: in Table 1a, your results are shown in bold but are statistically indistinguishable from CNN+DP.

Clarity: The paper is reasonably well written and I understood it (I think) without too much trouble.

Relation to Prior Work: Prior work is adequately covered (see the suggestion above regarding the Garnelo paper).

Reproducibility: Yes

Additional Feedback: I have read the rebuttal and it largely addressed my concerns.