This is the code for the PopulationTransformer.
Quick Start | Prerequisites | Data | Pretraining | Fine-tuning | Citation
With the requirements installed and the data downloaded, run these scripts:
| Script | Outcome |
|---|---|
| 0_write_pretraining_data.sh | Write pretraining data |
| 1_create_pretraining_manifest.sh | Create the pretraining manifest file |
| 2_run_pretraining.sh | Run pretraining |
| 3_write_finetuning_data.sh | Write fine-tuning data |
| 4_create_finetuning_manifest.sh | Create fine-tuning manifest file |
| 5_run_finetuning.sh | Run fine-tuning |
More details about each script are given in the Pretraining and Fine-tuning sections below.
Requirements:
- pytorch >= 1.13.1
```
pip install -r requirements.txt
```

If using the data from the Brain Treebank, the data can be downloaded from braintreebank.dev.
The following commands expect the Brain Treebank data to have the following structure:
```
/braintreebank_data
|_electrode_labels
|_subject_metadata
|_localization
|_all_subject_data
  |_sub_*_trial*.h5
```
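Before running anything, a quick sanity check of the layout can save time (a minimal sketch, assuming the directory structure shown above):

```
BRAINTREEBANK_DIR="/path/to/braintreebank_data"
# Verify the expected subdirectories exist and that at least one recording file is present.
for d in electrode_labels subject_metadata localization all_subject_data; do
  [ -d "${BRAINTREEBANK_DIR}/${d}" ] || echo "missing directory: ${d}"
done
ls "${BRAINTREEBANK_DIR}"/all_subject_data/sub_*_trial*.h5 | head -n 3
```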
Pretraining
The commands below assume that you will be using the PopulationTransformer in conjunction with BrainBERT. If so, you will need to download the BrainBERT weights.
First, we write the BrainBERT features for pre-training. This command takes a list of brain recordings (see below) and creates a training dataset from their BrainBERT representations:
```
REPO_DIR="/path/to/PopulationTransformer"
BRAINTREEBANK_DIR="/path/to/braintreebank_data"
python3 -m data.write_nsp_pretraining_data \
+preprocessor=multi_elec_spec_pretrained \
++preprocessor.upstream_ckpt=${REPO_DIR}/pretrained_weights/stft_large_pretrained.pth \
+data_prep=pretrain_multi_subj_multi_chan_template \
++data_prep.task_name=nsp_pretraining \
++data_prep.brain_runs=${REPO_DIR}/trial_selections/pretrain_split_trials.json \
++data_prep.electrodes=${REPO_DIR}/electrode_selections/clean_laplacian.json \
++data_prep.output_directory=${REPO_DIR}/saved_examples/cr_pretrain_examples \
+data=pretraining_subject_data_template \
++data.cached_transcript_aligns=${REPO_DIR}/semantics/saved_aligns \
++data.cached_data_array=${REPO_DIR}/cached_data_arrays/ \
++data.raw_brain_data_dir=${BRAINTREEBANK_DIR}
```

Salient arguments:
- Input:
  - `preprocessor.upstream_ckpt` is the path to the BrainBERT weights
  - `data_prep.brain_runs` is the path to a json file of the following format: `{<sub_name>: [trial_name]}`. This specifies the brain recording files that will be used (see the sketch after this list for an example).
  - `data_prep.electrodes` is the path to a json file of the following format: `{<sub_name>: [electrode_name]}`. Similar to the above, this specifies which channels will be used.
  - `data.raw_brain_data_dir` is the path to the root of the Brain Treebank data (see the Data section above)
- Output:
  - `data_prep.output_directory` is the path where the output will be written
  - `data.cached_data_array` is the path to an (optional) cache where intermediate outputs can be written for faster processing
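For reference, the two selection files are plain json; a hypothetical pair might look like this (the subject, trial, and electrode names below are placeholders; see `trial_selections/` and `electrode_selections/` in the repo for the real files):

```
# Hypothetical brain_runs file: {<sub_name>: [trial_name]}
cat > my_brain_runs.json << 'EOF'
{"sub_1": ["trial000", "trial001"]}
EOF

# Hypothetical electrodes file: {<sub_name>: [electrode_name]}
cat > my_electrodes.json << 'EOF'
{"sub_1": ["T1b2", "T1b3"]}
EOF
```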
Next, we need to create a manifest for all the training examples we've just created.
```
REPO_DIR="/path/to/PopulationTransformer"
python3 -m data.make_pretrain_replace_manifest +data_prep=combine_nsp_datasets \
++data_prep.source_dir=${REPO_DIR}/saved_examples/cr_pretrain_examples \
++data_prep.output_dir=${REPO_DIR}/saved_examples/nsp_replace_task-0_5s \
++data_prep.task="nsp_negative_any"
```

Salient arguments:
- Input:
  - `data_prep.source_dir` should match `data_prep.output_directory` from above.
- Output:
  - `data_prep.output_dir` is the path where the output will be written.
Now, we can run the pretraining:
```
REPO_DIR="/path/to/PopulationTransformer"
python3 run_train.py \
+exp=multi_elec_pretrain \
++exp.runner.device=cuda \
+data=nsp_replace_only_pretrain \
++data.data_path=${REPO_DIR}/saved_examples/nsp_replace_task-0_5s \
++data.saved_data_split=${REPO_DIR}/saved_data_splits/pretrain_split \
++data.test_data_cfg.name=nsp_replace_only_deterministic \
++data.test_data_cfg.data_path=${REPO_DIR}/saved_examples/nsp_replace_task-0_5s \
+model=pt_custom_model \
+task=nsp_replace_only_pretrain \
+criterion=nsp_replace_only_pretrain \
+preprocessor=empty_preprocessor
```

Salient arguments:
- Input:
  - `data.data_path` should match the `data_prep.output_dir` from the manifest creation step above.
- Output:
  - The final weights will be saved in an automatically created directory under `outputs`.
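For instance, after training finishes you can locate the newest checkpoint with something like this (a sketch; the exact directory and file names under `outputs` depend on the run configuration):

```
# List checkpoint files under outputs/, newest last (the .pth extension is an assumption).
find outputs -name "*.pth" -printf '%T@ %p\n' | sort -n
```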
Fine-tuning
Now, let's write the BrainBERT features for a finetuning task. For this example, let's decode volume (rms) from one electrode over the course of one trial.
```
REPO_DIR="/path/to/PopulationTransformer"
BRAINTREEBANK_DIR="/path/to/braintreebank_data"
python3 -m data.write_multi_subject_multi_channel \
+data_prep=pretrain_multi_subj_multi_chan_template \
++data_prep.task_name=rms \
++data_prep.brain_runs=${REPO_DIR}/trial_selections/test_trials.json \
++data_prep.electrodes=${REPO_DIR}/electrode_selections/clean_laplacian.json \
++data_prep.output_directory=${REPO_DIR}/saved_examples/all_test_rms \
+preprocessor=multi_elec_spec_pretrained \
++preprocessor.upstream_ckpt=${REPO_DIR}/pretrained_weights/stft_large_pretrained.pth \
+data=subject_data_template \
++data.cached_transcript_aligns=${REPO_DIR}/semantics/saved_aligns \
++data.cached_data_array=${REPO_DIR}/cached_data_arrays/ \
++data.raw_brain_data_dir=${BRAINTREEBANK_DIR}/ \
++data.movie_transcripts_dir=${BRAINTREEBANK_DIR}/transcripts
```

- Inputs:
  - `data_prep.electrodes` and `data_prep.brain_runs`: as in Pretraining, these files specify the trials and channels that will be used to create the dataset.
- Outputs:
  - `data_prep.output_directory` is the path to where the BrainBERT embeddings will be written.
Let's write the manifest for this decoding task.
```
REPO_DIR="/path/to/PopulationTransformer"
SUBJECT=sub_1; TASK=rms; python3 -m data.make_subject_specific_manifest \
+data_prep=subject_specific_manifest \
++data_prep.data_path=${REPO_DIR}/saved_examples/all_test_${TASK} \
++data_prep.subj=${SUBJECT} \
++data_prep.out_path=${REPO_DIR}/saved_examples/${SUBJECT}_${TASK}_cr
```

- Inputs:
  - `data_prep.data_path` should match the `output_directory` given above
Now, we are ready to run the fine-tuning. You can either fine-tune a model that you have pre-trained yourself, or use a model from our huggingface repo.
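If you use a Hub checkpoint, it can be fetched with the `huggingface-cli` tool; the repo id below is a placeholder, and the file name is assumed to match the `${WEIGHTS}.pth` expected by the command that follows:

```
REPO_DIR="/path/to/PopulationTransformer"
# Placeholder repo id; substitute the actual PopulationTransformer repo from our Hugging Face page.
huggingface-cli download <org>/<popt_weights_repo> pretrained_popt_brainbert_stft.pth \
    --local-dir ${REPO_DIR}/pretrained_weights
```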
```
REPO_DIR="/path/to/PopulationTransformer"
SUBJECT=sub_1; TASK=rms; N=1; NAME=popt_brainbert_stft; WEIGHTS=pretrained_popt_brainbert_stft;
python3 run_train.py \
+exp=multi_elec_feature_extract \
++exp.runner.results_dir=${REPO_DIR}/outputs/${SUBJECT}_${TASK}_top${N}_${NAME} \
++exp.runner.save_checkpoints=False \
++model.frozen_upstream=False \
+task=pt_feature_extract_coords \
+criterion=pt_feature_extract_coords_criterion \
+preprocessor=empty_preprocessor \
+data=pt_supervised_task_coords \
++data.data_path=${REPO_DIR}/saved_examples/${SUBJECT}_${TASK}_cr \
++data.saved_data_split=${REPO_DIR}/saved_data_splits/${SUBJECT}_${TASK}_fine_tuning \
++data.sub_sample_electrodes=${REPO_DIR}/electrode_selections/debug_electrodes.json \
+model=pt_downstream_model \
++model.upstream_path=${REPO_DIR}/pretrained_weights/${WEIGHTS}.pth
```

- Inputs:
  - `data.data_path` should match the `out_path` of the manifest creation step above.
  - `model.upstream_path` should be a path to the weights from pretraining --- either from the steps above, or from the huggingface repo.
  - `data.saved_data_split` is a path to where the indices for the train/val/test splits will be written. You can use this to ensure that splits are consistent between runs.
- Outputs:
  - `exp.runner.results_dir` will contain performance metrics (F1, ROC-AUC) on the test set.
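For example, to take a quick look at what the run produced (a sketch; the metric file names inside the results directory depend on the runner):

```
# With the variables from the command above, the results land here.
RESULTS_DIR=${REPO_DIR}/outputs/sub_1_rms_top1_popt_brainbert_stft
find "${RESULTS_DIR}" -type f
```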
Citation

```
@inproceedings{
chau_wang_2025_population,
title={Population Transformer: Learning Population-level Representations of Neural Activity},
url={https://openreview.net/forum?id=FVuqJt3c4L},
booktitle={The Thirteenth International Conference on Learning Representations},
author={Chau, Geeling and Wang, Christopher and Talukder, Sabera J and Subramaniam, Vighnesh and Soedarmadji, Saraswati and Yue, Yisong and Katz, Boris and Barbu, Andrei},
year={2025}
}
```