In this work, we study and compare multiple capsule routing algorithms for text classification, including dynamic routing, Heinsen routing, and capsule-routing-inspired attention-based sentence-encoding techniques such as dynamic self-attention. Further, similar to some works in computer vision, we run an ablation of the capsule network in which we remove the routing algorithm itself. We analyze the theoretical connection between attention and capsule routing, and contrast the two ways of normalizing the routing weights. Finally, we present a new way to do capsule routing, or rather iterative refinement, using a richer attention function to measure agreement between output and input capsules, with highway connections between iterations.
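For readers unfamiliar with dynamic routing (routing-by-agreement), the sketch below illustrates the core loop: routing logits are softmax-normalized over the *output* capsules, prediction vectors are combined into candidate outputs, and logits are updated by dot-product agreement. This is a minimal NumPy sketch of the general algorithm, not the implementation in this repository.

```python
import numpy as np

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def squash(s, axis=-1, eps=1e-8):
    # Squash non-linearity: keeps direction, maps the norm into [0, 1).
    sq = (s ** 2).sum(axis=axis, keepdims=True)
    return (sq / (1.0 + sq)) * s / np.sqrt(sq + eps)

def dynamic_routing(u_hat, num_iters=3):
    """u_hat: (num_in, num_out, dim) prediction vectors from input capsules."""
    num_in, num_out, _ = u_hat.shape
    b = np.zeros((num_in, num_out))              # routing logits
    for _ in range(num_iters):
        c = softmax(b, axis=1)                   # normalize over OUTPUT capsules
        s = (c[..., None] * u_hat).sum(axis=0)   # weighted sum -> (num_out, dim)
        v = squash(s)                            # output capsules
        b = b + (u_hat * v[None]).sum(axis=-1)   # dot-product agreement update
    return v
```

Normalizing over the output axis is what distinguishes this from standard attention, where weights are normalized over the inputs; that contrast is one focus of this work.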
Code from the final project of the course CS521: Statistical Natural Language Processing. The code was created around 2020 Quarter 1.
- Download the Google word2vec 300-dimension embeddings.
- Put the downloaded embedding file in `embeddings/word2vec`.
- After downloading the embeddings, go to `process/` and run both of the Python files there.
```
python train_AAP.py --model=[INSERT MODEL NAME]
python train_Reuters.py --model=[INSERT MODEL NAME]
python test_AAP.py --model=[INSERT MODEL NAME]
python test_Reuters.py --model=[INSERT MODEL NAME]
```
Valid `--model` values:
- CNN
- CNN_att
- CNN_capsule
- CNN_heinsen_capsule
- CNN_DSA
- CNN_DSA_global
- CNN_PCaps
- CNN_custom
- CNN_custom_alpha_ablation
- CNN_custom_global
- CNN_custom2
- CNN (Convolutional Neural Network, similar to the Kim (2014) version)
- CNN_att (Convolutional Neural Network + attention)
- CNN_capsule (CNN + Dynamic Capsule Routing)
- CNN_heinsen_capsule (CNN + Heinsen Capsule Routing)
- CNN_DSA (CNN + Dynamic Self Attention)
- CNN_DSA_global (CNN + Dynamic Self Attention for Sentence encoding)
- CNN_PCaps (CNN + “non-routing” mechanism inspired by PCapsNet)
- CNN_custom (CNN + “new routing” + “reverse normalization”)
- CNN_custom_alpha_ablation (CNN + “new routing” + “reverse normalization”, without highway connections)
- CNN_custom_global (CNN + “new routing” + “reverse normalization” for sentence encoding)
- CNN_custom2 (CNN + “new routing”)
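To make the “reverse normalization” and highway-connection ideas in the custom models concrete, here is a minimal NumPy sketch. It is an illustration under simplifying assumptions, not the repository's code: the fixed scalar `alpha` stands in for a learned highway gate, and plain dot-product agreement stands in for the richer attention function described above.

```python
import numpy as np

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def routing_step(u_hat, b, reverse=True):
    # "Reverse normalization": softmax over the INPUT capsules (axis=0),
    # as in attention; standard dynamic routing normalizes over outputs (axis=1).
    c = softmax(b, axis=0 if reverse else 1)
    return (c[..., None] * u_hat).sum(axis=0)    # -> (num_out, dim)

def iterative_refinement(u_hat, num_iters=3, alpha=0.5):
    """u_hat: (num_in, num_out, dim) prediction vectors."""
    num_in, num_out, _ = u_hat.shape
    b = np.zeros((num_in, num_out))
    v = routing_step(u_hat, b)                   # first estimate of outputs
    for _ in range(num_iters - 1):
        b = (u_hat * v[None]).sum(axis=-1)       # agreement (a richer attention
                                                 # function would go here)
        v_new = routing_step(u_hat, b)
        v = alpha * v + (1 - alpha) * v_new      # highway-style mix of iterations
    return v
```

Setting `alpha = 0` recovers plain iterative refinement with no carry-over between iterations, which corresponds to the `CNN_custom_alpha_ablation` variant above.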
- Hyperparameters used for each model on AAPD are in `configs/AAPD_args.py`
- Hyperparameters used for each model on Reuters are in `configs/Reuters_args.py`
- The `process/` directory has preprocessing code for AAPD and Reuters (the data is already preprocessed in `processed_data/`, so no further processing is needed).
- The `data/` directory has the data in use.
- Preprocessing requires word2vec embeddings in the `embeddings/word2vec` directory (download the Google word2vec 300-dimension embeddings, put them in that directory, and run `bin2txt.py`; word2vec is downloadable from here).
- `models/modules` has the routing code.
- `models/` has all the model code.
- `utils/` has evaluation and other utility code.
- `saved_params/` will hold the saved model parameters after training.
- Heinsen routing is based on this.
- The data were downloaded from https://github.com/castorini/hedwig (the word2vec bin file and `bin2txt.py` are also available through a link in their repository).
- The Hedwig library was also referenced for the initial CNN implementations.