- See sae-viewer for the visualizer code, which is hosted publicly here.
- See model.py for details on the autoencoder model architecture (a schematic sketch follows after this list).
- See train.py for the autoencoder training code.
- See paths.py for more details on the available autoencoders.
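For orientation before diving into the example, here is a minimal, hypothetical sketch of what a sparse autoencoder of this kind looks like. It is illustrative only and does not reproduce the actual classes in model.py or the training procedure in train.py; the class name, dimensions, and the L1 penalty coefficient are all assumptions.

```python
import torch
import torch.nn as nn

class ToySparseAutoencoder(nn.Module):
    """Illustrative only: a basic ReLU autoencoder with a pre-encoder bias.
    The real architecture lives in model.py."""

    def __init__(self, n_inputs: int, n_latents: int):
        super().__init__()
        self.pre_bias = nn.Parameter(torch.zeros(n_inputs))  # subtracted before encoding
        self.encoder = nn.Linear(n_inputs, n_latents)
        self.decoder = nn.Linear(n_latents, n_inputs)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.encoder(x - self.pre_bias))

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        return self.decoder(latents) + self.pre_bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decode(self.encode(x))

# A single (hypothetical) training step: reconstruction loss plus an L1
# sparsity penalty on the latents; the real training code is in train.py.
ae = ToySparseAutoencoder(n_inputs=768, n_latents=32768)
opt = torch.optim.Adam(ae.parameters(), lr=1e-4)
batch = torch.randn(64, 768)  # stand-in for cached GPT-2 activations
latents = ae.encode(batch)
recon = ae.decode(latents)
loss = (recon - batch).pow(2).mean() + 1e-3 * latents.abs().mean()
loss.backward()
opt.step()
```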
Example usage
```python
import torch
import blobfile as bf
import transformer_lens

import sparse_autoencoder

# Extract neuron activations with transformer_lens
model = transformer_lens.HookedTransformer.from_pretrained("gpt2", center_writing_weights=False)
device = next(model.parameters()).device

prompt = "This is an example of a prompt that"
tokens = model.to_tokens(prompt)  # (1, n_tokens)
with torch.no_grad():
    logits, activation_cache = model.run_with_cache(tokens, remove_batch_dim=True)

# Choose which autoencoder to load: the layer and the activation site
layer_index = 6
location = "resid_post_mlp"

transformer_lens_loc = {
    "mlp_post_act": f"blocks.{layer_index}.mlp.hook_post",
    "resid_delta_attn": f"blocks.{layer_index}.hook_attn_out",
    "resid_post_attn": f"blocks.{layer_index}.hook_resid_mid",
    "resid_delta_mlp": f"blocks.{layer_index}.hook_mlp_out",
    "resid_post_mlp": f"blocks.{layer_index}.hook_resid_post",
}[location]

# Load the autoencoder weights from blob storage
with bf.BlobFile(sparse_autoencoder.paths.v5_32k(location, layer_index), mode="rb") as f:
    state_dict = torch.load(f)
autoencoder = sparse_autoencoder.Autoencoder.from_state_dict(state_dict)
autoencoder.to(device)

# Encode the cached activations with the autoencoder
input_tensor = activation_cache[transformer_lens_loc]

input_tensor_ln = input_tensor  # placeholder for input normalization; none is applied here

with torch.no_grad():
    latent_activations, info = autoencoder.encode(input_tensor_ln)
    reconstructed_activations = autoencoder.decode(latent_activations, info)

# Per-token reconstruction error, normalized by the squared norm of the input
normalized_mse = (reconstructed_activations - input_tensor).pow(2).sum(dim=1) / (input_tensor).pow(2).sum(dim=1)
print(location, normalized_mse)
```
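Once you have `latent_activations` from the example above, you can inspect which latents fire most strongly, for instance on the last token of the prompt. The sketch below continues directly from the example and uses only standard torch operations; it assumes `latent_activations` has shape `(n_tokens, n_latents)`, which follows from `remove_batch_dim=True`.

```python
# Inspect the strongest latents on the final token of the prompt.
# Most entries of latent_activations are zero, by design.
last_token_latents = latent_activations[-1]
values, indices = torch.topk(last_token_latents, k=10)
for idx, val in zip(indices.tolist(), values.tolist()):
    print(f"latent {idx}: activation {val:.3f}")

# Fraction of latents active per token, as a rough sparsity measure
sparsity = (latent_activations > 0).float().mean(dim=1)
print("fraction of latents active per token:", sparsity)
```

The latent indices printed here can be looked up in the visualizer (see sae-viewer above) to see what each latent responds to across a large corpus.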