A latent-space causal transformer that predicts beam distribution evolution through accelerator elements. The beam state is represented as a VAE latent vector, and the model predicts how it transforms through a variable-length sequence of elements (drifts, quadrupoles, RF cavities, etc.).
```mermaid
graph TD
    subgraph Input ["Input: Element i"]
        RawVec["<b>Union Parameter Vector</b><br/>(L, K1, K2, Angle, V_rf, f_rf, phi)"]
        Length["<b>Element Length</b><br/>L_i (meters)"]
    end
    subgraph Normalization ["Physics Normalization"]
        UnitScale["<b>Unit Scaling</b><br/>Each parameter divided by<br/>its characteristic scale"]
    end
    subgraph Position ["Positional Encoding"]
        CumSum["<b>Start Position</b><br/>s_i = cumulative sum of<br/>preceding lengths"]
        Fourier["<b>Fourier Features</b><br/>Sin/Cos on geometrically<br/>spaced frequency basis<br/>(1 cm to 1 km)"]
    end
    subgraph Projection ["Projection & Fusion"]
        MLP["<b>3-Layer MLP</b><br/>Maps normalized physics<br/>to model dimension"]
        Concat["<b>Concatenate</b>"]
        Fuse["<b>Linear Fusion</b><br/>Mixes element features<br/>with position features"]
    end
    RawVec --> UnitScale
    UnitScale --> MLP
    Length --> CumSum
    CumSum --> Fourier
    MLP --> Concat
    Fourier --> Concat
    Concat --> Fuse
    Fuse --> Output["<b>Element Embedding</b><br/>Position-aware, ready<br/>for transformer"]
```
```mermaid
graph LR
    subgraph Inputs
        PrevBeam(("<b>Beam History</b><br/>States z_0 … z_{t-1}"))
        ElemEmb["<b>Element Embedding</b><br/>(from ElementEncoder)"]
    end
    subgraph Transformer [Causal Transformer]
        Fuse["<b>Token Fusion</b><br/>add / concat / bilinear"]
        subgraph Layer ["Transformer Layer (×N)"]
            Attn["<b>Causal Self-Attention</b><br/>Sees only past elements 0…t"]
            FFN["<b>Feed-Forward Block</b>"]
        end
    end
    subgraph DeltaDynamics [Delta-Dynamics Head]
        Proj["<b>Latent Projection</b><br/>Back to beam state dimension"]
        Delta["<b>Predicted Correction</b><br/>Δz_t"]
    end
    subgraph NextState [State Update]
        Add(("<b>+</b>"))
        NewBeam(("<b>Updated Beam</b><br/>State z_t"))
    end
    PrevBeam --> Fuse
    ElemEmb --> Fuse
    Fuse --> Attn
    Attn --> FFN
    FFN --> Proj
    Proj --> Delta
    Delta --> Add
    PrevBeam -->|"Skip connection"| Add
    Add --> NewBeam
    NewBeam -.->|Feeds into next step| PrevBeam
```
ElementEncoder takes the raw parameter vector for each element [L, K1, K2, Angle, V_rf, f_rf, phi_rf], applies physics-informed normalization, projects it through a 3-layer MLP, and mixes in Fourier positional features that encode the element's longitudinal position. A quadrupole with K1=0 therefore produces the same embedding as a drift of the same length at the same position, so no discrete type IDs are needed.
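The positional part of the encoder can be illustrated with a small standard-library sketch. It assumes the start position is the running sum of preceding element lengths and that the Fourier basis uses angular frequencies 2π/λ, with wavelengths λ spaced geometrically between `lambda_min` and `lambda_max`. The function names and example lengths are hypothetical and are not taken from the repository.

```python
import math
from itertools import accumulate


def start_positions(lengths):
    """Start position s_i of each element: the sum of all preceding lengths."""
    # accumulate(..., initial=0.0) yields 0, L0, L0+L1, ...; drop the final total.
    return list(accumulate(lengths, initial=0.0))[:-1]


def wavelengths(n_freq=32, lambda_min=0.01, lambda_max=1000.0):
    """Geometrically spaced wavelengths from lambda_min to lambda_max (meters)."""
    if n_freq == 1:
        return [lambda_min]
    ratio = (lambda_max / lambda_min) ** (1.0 / (n_freq - 1))
    return [lambda_min * ratio**k for k in range(n_freq)]


def fourier_features(s, lams):
    """Return [sin(2*pi*s/lam), ..., cos(2*pi*s/lam), ...] for one position s."""
    phases = [2.0 * math.pi * s / lam for lam in lams]
    return [math.sin(p) for p in phases] + [math.cos(p) for p in phases]


lengths = [0.5, 0.25, 1.0, 0.25]         # made-up element lengths L_i in meters
positions = start_positions(lengths)     # [0.0, 0.5, 0.75, 1.75]
lams = wavelengths()                     # defaults mirror n_freq, lambda_min, lambda_max
features = [fourier_features(s, lams) for s in positions]
print(positions, len(features[0]))       # each element gets 2 * n_freq features
```

With the default settings each element receives 64 positional features (32 sine/cosine pairs), which would then be concatenated with the MLP output before the linear fusion step.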
TrackingTransformer is a GPT-style autoregressive model. Each token fuses the projected previous beam state with the element embedding via one of three fusion modes ("add", "concat", "bilinear"). A causal transformer predicts the delta: z_t = z_{t-1} + Δz_t.
LatticeTransformer is a parallel (non-autoregressive) alternative. The initial beam state z₀ conditions all transformer layers via Adaptive Layer Norm (AdaLN). Per-element Δz predictions are accumulated with a cumulative sum: z_t = z₀ + Σ_{i=1}^{t} Δz_i. No scheduled sampling is needed, because the model sees the same inputs at training and inference.
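The cumulative-sum accumulation gives the same trajectory as applying the residual update z_t = z_{t-1} + Δz_t one element at a time. The following toy sketch checks this with plain Python lists; the helper names and numbers are illustrative only, and the Δz values are simply given rather than predicted by a model.

```python
from itertools import accumulate


def add(u, v):
    """Elementwise sum of two equal-length vectors."""
    return [a + b for a, b in zip(u, v)]


def rollout_cumsum(z0, deltas):
    """z_t = z0 + sum of the deltas up to and including step t."""
    running = accumulate(deltas, add)     # cumulative sum of the delta vectors
    return [add(z0, total) for total in running]


def rollout_sequential(z0, deltas):
    """Same trajectory, built one residual update at a time."""
    states, z = [], z0
    for dz in deltas:
        z = add(z, dz)
        states.append(z)
    return states


z0 = [1.0, -2.0]                                  # toy 2-D latent state
deltas = [[0.5, 0.0], [0.25, 1.0], [-1.0, 0.5]]   # toy per-element corrections
print(rollout_cumsum(z0, deltas))
print(rollout_sequential(z0, deltas))
```

Because the deltas do not depend on earlier predicted states in this formulation, all of them can be computed in one pass and summed afterwards.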
| Mode | z_gt | sampling_prob | Execution |
|---|---|---|---|
| Teacher forcing | provided | 0.0 | Parallel |
| Scheduled sampling | provided | (0, 1) | Sequential |
| Autoregressive | None | ignored | Sequential |
Teacher forcing feeds ground-truth states as input, so all steps run in parallel. Scheduled sampling replaces the ground-truth input at each step with the model's own (detached) prediction with probability sampling_prob, which narrows the gap between training and inference. Autoregressive mode feeds back only the model's own predictions and is used for inference.
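The input selection behind the three modes can be sketched as a plain Python loop. This is not the repository's implementation: `step` is a hypothetical stand-in for the model's one-element update, and the sketch assumes `z_gt[t]` holds the ground-truth state after element `t`.

```python
import random


def rollout(z0, elements, step, z_gt=None, sampling_prob=0.0, rng=random):
    """Predict one state per element, choosing each step's input by mode.

    step(z_prev, element) is a hypothetical stand-in for z_t = z_{t-1} + dz_t.
    """
    preds, prev_pred = [], z0
    for t, element in enumerate(elements):
        if t == 0:
            z_in = z0                      # every mode starts from z0
        elif z_gt is None:
            z_in = prev_pred               # autoregressive: own prediction only
        elif rng.random() < sampling_prob:
            z_in = prev_pred               # scheduled sampling: own prediction
                                           # (a real model would detach it here)
        else:
            z_in = z_gt[t - 1]             # ground-truth state after element t-1
        prev_pred = step(z_in, element)
        preds.append(prev_pred)
    return preds


def toy_step(z, k):
    """Toy scalar dynamics: the 'element' k is just added to the state."""
    return z + k


elements = [1.0, 2.0, 3.0]
z_gt = [10.0, 20.0, 30.0]                  # deliberately different from the toy dynamics

print(rollout(0.0, elements, toy_step, z_gt=z_gt, sampling_prob=0.0))  # teacher forcing
print(rollout(0.0, elements, toy_step, z_gt=z_gt, sampling_prob=0.3))  # scheduled sampling
print(rollout(0.0, elements, toy_step))                                # autoregressive
```

With `sampling_prob=0.0` no step reads an earlier prediction, which is why teacher forcing can be executed in parallel, while the other two modes must run step by step.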
The LatticeTransformer has a single forward(z0, x_raw) path for both training and inference.
Hyperparameters are composed from YAML files in configs/ with CLI dot-notation overrides:
```yaml
model:
  name: tracking          # tracking | lattice (required)
  latent_dim: 64          # VAE latent dimension
  d_model: 256            # transformer hidden dimension
  n_layers: 6             # transformer layers
  n_heads: 8              # attention heads
  n_freq: 32              # Fourier frequency pairs
  element_dim: 7          # raw element parameter dimension
  lambda_min: 0.01        # min positional wavelength (m)
  lambda_max: 1000.0      # max positional wavelength (m)
  dropout: 0.1
  mlp_ratio: 4            # feed-forward expansion ratio
  fusion: concat          # add | concat | bilinear (TrackingTransformer only)
training:
  epochs: 200
  batch_size: 32
  lr: 3.0e-4
  weight_decay: 1.0e-2
  grad_clip: 1.0
  ss_warmup: 10           # scheduled sampling warmup (TrackingTransformer only)
  ss_k: 0.05              # scheduled sampling ramp rate (TrackingTransformer only)
```

```python
from src.models import ModelConfig, TrackingTransformer, LatticeTransformer

config = ModelConfig(latent_dim=64, d_model=256)

# TrackingTransformer (autoregressive)
model = TrackingTransformer(config)
z_pred = model(z0, x_raw, z_gt=z_gt, sampling_prob=0.0)  # teacher forcing
z_pred = model(z0, x_raw, z_gt=z_gt, sampling_prob=0.3)  # scheduled sampling
z_pred = model(z0, x_raw)                                # autoregressive inference

# LatticeTransformer (parallel)
model = LatticeTransformer(config)
z_pred = model(z0, x_raw)  # same call for training and inference
```

Runs on NERSC Perlmutter. Three conda environments cover different pipeline stages:
| Environment | Stage | Key packages |
|---|---|---|
| lbd_datagen | Lattice generation + Tao tracking | NumPy, distgen, pmd_beamphysics, Bmad/Tao |
| vae | VAE data prep (frequency maps) | beam_vae, PyTorch |
| lbd | Transformer training | PyTorch + CUDA 12.4 |
```bash
ml load conda

# Stage 1: Generate lattice + beam inputs
conda activate lbd_datagen
python scripts/generate_inputs.py --mode sectioned --n-samples 10000 --seq-len 32 --output-dir data/sectioned_10k

# Stage 2: Track particles through Bmad/Tao
conda activate lbd_datagen
find data/sectioned_10k -mindepth 1 -maxdepth 1 -type d | sort | \
    parallel -j$(nproc) bash scripts/track_one.sh {}

# Stage 3: Convert tracked beams to VAE training data
conda activate vae
python scripts/prepare_vae_data.py --data-dir data/sectioned_10k \
    --output data/vae_training/sectioned_10k --workers 128

# Stage 4: Train transformer
conda activate lbd
python scripts/train.py model.name=tracking  # or model.name=lattice
```

Run the model sanity check:

```bash
conda activate lbd
python scripts/check_models.py
```
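For readers unfamiliar with dot-notation overrides such as `model.name=lattice`, the idea can be sketched with the standard library alone. This is an illustration of the general technique, not the repository's actual config loader; `apply_overrides` and `parse_scalar` are hypothetical names, and value parsing is deliberately limited to integers, floats, and strings.

```python
import copy


def parse_scalar(text):
    """Interpret an override value as int, then float, else leave it a string."""
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


def apply_overrides(config, overrides):
    """Return a copy of a nested dict with 'a.b=value' overrides applied."""
    result = copy.deepcopy(config)
    for item in overrides:
        dotted_key, raw_value = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = result
        for name in parents:
            node = node.setdefault(name, {})   # create missing sections on demand
        node[leaf] = parse_scalar(raw_value)
    return result


base = {
    "model": {"name": "tracking", "d_model": 256},
    "training": {"lr": 3.0e-4, "epochs": 200},
}
cfg = apply_overrides(base, ["model.name=lattice", "training.lr=1e-4"])
print(cfg["model"]["name"], cfg["training"]["lr"])
```

Working on a deep copy keeps the defaults loaded from YAML untouched, so the same base config can be reused across runs with different overrides.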