Nirvana is a Specialized Generalist Model with a task-aware memory mechanism, linear time complexity, and test-time task information extraction.
```
Nirvana/
├── nirvana_backbone/                    # Core model architecture and training
│   ├── train/                           # Training scripts and configurations
│   ├── eval/                            # Evaluation scripts and benchmarks
│   ├── modeling_transformer_rnn.py      # Nirvana model
│   ├── nirvana_1_3B.json                # Nirvana 1.3B configuration
│   ├── configuration_transformer_rnn.py # Nirvana configuration
│   ├── task_aware_delta_net.py          # Specialized Memory Updater
│   └── ttt_cross_layer.py               # Task-Aware Trigger (with cross-layer online gradient descent)
├── specialized_ability/                 # Domain-specific capabilities
│   └── MRI_reconstruction/              # MRI image reconstruction and analysis
│       ├── model/                       # Custom MRI reconstruction and analysis models
│       ├── dataset/                     # MRI dataset handling
│       └── train/                       # MRI-specific training
└── requirements.txt                     # Python dependencies
```
- Python 3.10+
- CUDA 11.8+ (for GPU acceleration)
- Conda or Miniconda
- 8+ GPUs recommended for training
```bash
# Clone the repository, then enter it
cd Nirvana

# Create and activate the conda environment
conda create -n nirvana python=3.10
conda activate nirvana

# Install dependencies
pip install -r requirements.txt

# Install Flash Attention (optional, for enhanced performance in SWA)
pip install flash-attn==2.7.0.post2 --no-build-isolation
```
- PyTorch: 2.5.0
- Transformers: 4.52.4
- Accelerate: 1.1.1
- Flash Attention: 2.7.0.post2
- Flash Linear Attention: 0.2.2
- FastMRI: 0.3.0 (for MRI datasets)
- WandB: 0.21.1 (for experiment tracking)
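To confirm your environment matches these pins, a quick check like the following can help (a hypothetical helper, not part of the repository; PyPI distribution names are assumed):

```python
import importlib.metadata as md

# Pinned versions from the dependency list above (distribution names assumed).
pinned = {
    "torch": "2.5.0",
    "transformers": "4.52.4",
    "accelerate": "1.1.1",
    "flash-attn": "2.7.0.post2",
    "wandb": "0.21.1",
}

for name, want in pinned.items():
    try:
        have = md.version(name)
        print(f"{name}: {have}" + ("" if have == want else f" (expected {want})"))
    except md.PackageNotFoundError:
        print(f"{name}: not installed")
```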
```bash
cd nirvana_backbone/train
bash train.sh
```

Training Configuration:
- Model: 1.3B parameters
- Data: FineWebEdu dataset
- Precision: BF16
- Distributed training with 8 GPUs
- Checkpointing every 1910 steps
- WandB integration for experiment tracking
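As an illustration of how these settings fit together, here is a minimal Accelerate-style BF16 loop with the checkpoint cadence above. This is a sketch with assumed `model`, `optimizer`, and `dataloader` objects, not the actual logic inside `train.sh`:

```python
from accelerate import Accelerator

def train_loop(model, optimizer, dataloader, ckpt_every: int = 1910):
    # BF16 mixed precision, matching the training configuration above.
    accelerator = Accelerator(mixed_precision="bf16")
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
    model.train()
    for step, batch in enumerate(dataloader, start=1):
        loss = model(**batch).loss
        accelerator.backward(loss)      # handles mixed-precision scaling/distributed sync
        optimizer.step()
        optimizer.zero_grad()
        if step % ckpt_every == 0:      # checkpoint cadence from the list above
            accelerator.save_state(f"checkpoints/step_{step}")
```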
```bash
cd specialized_ability/MRI_reconstruction
bash ./train/run_two_stage_training.sh
```

```bash
cd nirvana_backbone/eval

# In-context learning evaluation
bash eval_nirvana_1.3B-icl.sh

# Long sequence evaluation
bash eval_nirvana_1.3B-longbench.sh

# Commonsense reasoning evaluation
bash eval_nirvana_1.3B-commonsense.sh

# NIAH evaluation
bash eval_nirvana_1.3B-niah.sh
```

Supported Benchmarks:
- S-NIAH
- LongBench
- Commonsense reasoning tasks
- FastMRI
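For intuition, S-NIAH-style evaluation hides a "needle" sentence at a controlled depth inside long filler text and asks the model to retrieve it. A toy prompt builder (illustrative only; the filler and question strings are made up, and the real harness lives in `eval/`):

```python
def build_niah_prompt(needle: str, question: str, filler: str,
                      target_chars: int, depth: float) -> str:
    """Place `needle` at relative `depth` (0=start, 1=end) inside filler text."""
    haystack = (filler * (target_chars // len(filler) + 1))[:target_chars]
    pos = int(len(haystack) * depth)
    return haystack[:pos] + " " + needle + " " + haystack[pos:] + "\n\n" + question

prompt = build_niah_prompt(
    needle="The secret number is 7421.",
    question="What is the secret number?",
    filler="The grass is green. The sky is blue. ",
    target_chars=8000,
    depth=0.5,
)
```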
The model configuration is defined in `nirvana_1_3B.json`:

```json
{
  "hidden_size": 2048,
  "num_heads": 16,
  "num_hidden_layers": 22,
  "max_position_embeddings": 32768,
  "vocab_size": 32000,
  "concept_dim": 64,
  "logit_dim": 32,
  "window_size": 2048
}
```
To add a new specialized ability:

- Create a new directory under `specialized_ability/`
- Implement your custom models in the `model/` subdirectory
- Add dataset handling in the `dataset/` subdirectory
- Create training scripts in the `train/` subdirectory
- Update the main `__init__.py` files to register your models
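For example, the package `__init__.py` of a new ability might simply re-export its model so the registration step above can find it (hypothetical `my_ability` names, mirroring the `MRI_reconstruction` layout):

```python
# specialized_ability/my_ability/model/__init__.py (hypothetical)
from .my_ability_model import MyAbilityModel

__all__ = ["MyAbilityModel"]
```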
To customize the core components:

- Task-Aware Delta Network: Implement custom delta functions in `task_aware_delta_net.py`
- Cross-Layer Connections: Modify `ttt_cross_layer.py` for custom layer interactions
- Transformer Variants: Extend `modeling_transformer_rnn.py` for new architectures
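As a point of reference, a generic (non-task-aware) delta-rule memory update looks like the sketch below. Nirvana's task-aware variant in `task_aware_delta_net.py` conditions this update on extracted task information; the single-step formulation and tensor shapes here are assumptions:

```python
import torch

def delta_rule_step(S: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    beta: torch.Tensor) -> torch.Tensor:
    """One recurrent step of a generic delta-rule memory update (sketch).

    S:    (d_k, d_v) associative memory state
    k:    (d_k,) key, assumed L2-normalized
    v:    (d_v,) value
    beta: scalar write strength in [0, 1]
    """
    pred = S.T @ k                               # what the memory currently returns for k
    return S + beta * torch.outer(k, v - pred)   # correct the stored value toward v

# Toy usage: the state stays a fixed-size matrix, so scanning T tokens is O(T).
d_k, d_v = 4, 4
S = torch.zeros(d_k, d_v)
k = torch.nn.functional.normalize(torch.randn(d_k), dim=0)
v = torch.randn(d_v)
S = delta_rule_step(S, k, v, beta=torch.tensor(0.5))
```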
- Parameters: 1.3B
- Training Context Length: 4096 tokens
- Training Precision: BF16
- Acceleration: Flash Linear Attention
- Parallelism: Data, tensor, and sequence parallelism support
- Selective Recompute: Configurable gradient checkpointing
- Mixed Precision: BF16 training with automatic mixed precision
- Distributed Training: Multi-GPU and multi-node support
- Memory Optimization: Efficient memory management with FSDP
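A minimal sketch of how BF16 mixed precision and FSDP sharding might be wired up (an assumption for illustration; the repository's actual launch logic lives in the training scripts):

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

def wrap_for_distributed_training(model: torch.nn.Module) -> torch.nn.Module:
    # Assumes the process group environment variables were set up by torchrun.
    if not dist.is_initialized():
        dist.init_process_group("nccl")
    bf16 = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )
    return FSDP(model.cuda(), mixed_precision=bf16)
```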
- MRI Reconstruction: Fast and accurate MRI image reconstruction from undersampled k-space data (a baseline sketch follows this list)
- Report Generation: Automated medical report generation from MRI scans
- Multi-modal Learning: Integration of k-space, imaging, and textual data
- Language Understanding: Strong performance on specialized and general language tasks
- Task Adaptation: Efficient adaptation to specialized applications
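For context on the MRI task, the classical zero-filled baseline reconstructs an image from undersampled k-space with a masked, centered inverse FFT. The sketch below uses generic tensor shapes and is not the repository's learned reconstruction:

```python
import torch

def zero_filled_recon(kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Zero-filled baseline: mask k-space, then apply a centered 2D inverse FFT.

    kspace: complex tensor (..., H, W); mask: broadcastable 0/1 sampling mask.
    """
    under = kspace * mask                                    # keep only sampled lines
    img = torch.fft.ifftshift(under, dim=(-2, -1))
    img = torch.fft.ifft2(img, dim=(-2, -1), norm="ortho")
    img = torch.fft.fftshift(img, dim=(-2, -1))
    return img.abs()                                         # magnitude image
```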