ndwang/latent_beam_dynamics

Latent Beam Dynamics

A latent-space causal transformer that predicts beam distribution evolution through accelerator elements. The beam state is represented as a VAE latent vector, and the model predicts how it transforms through a variable-length sequence of elements (drifts, quadrupoles, RF cavities, etc.).

Architecture

ElementEncoder

graph TD
    subgraph Input ["Input: Element i"]
        RawVec["<b>Union Parameter Vector</b><br/>(L, K1, K2, Angle, V_rf, f_rf, phi)"]
        Length["<b>Element Length</b><br/>L_i (meters)"]
    end

    subgraph Normalization ["Physics Normalization"]
        UnitScale["<b>Unit Scaling</b><br/>Each parameter divided by<br/>its characteristic scale"]
    end

    subgraph Position ["Positional Encoding"]
        CumSum["<b>Start Position</b><br/>s_i = cumulative sum of<br/>preceding lengths"]
        Fourier["<b>Fourier Features</b><br/>Sin/Cos on geometrically<br/>spaced frequency basis<br/>(1 cm to 1 km)"]
    end

    subgraph Projection ["Projection & Fusion"]
        MLP["<b>3-Layer MLP</b><br/>Maps normalized physics<br/>to model dimension"]
        Concat["<b>Concatenate</b>"]
        Fuse["<b>Linear Fusion</b><br/>Mixes element features<br/>with position features"]
    end

    RawVec --> UnitScale
    UnitScale --> MLP
    Length --> CumSum
    CumSum --> Fourier
    MLP --> Concat
    Fourier --> Concat
    Concat --> Fuse
    Fuse --> Output["<b>Element Embedding</b><br/>Position-aware, ready<br/>for transformer"]

TrackingTransformer

graph LR
    subgraph Inputs
        PrevBeam(("<b>Beam History</b><br/>States z_0 … z_{t-1}"))
        ElemEmb["<b>Element Embedding</b><br/>(from ElementEncoder)"]
    end

    subgraph Transformer [Causal Transformer]
        Fuse["<b>Token Fusion</b><br/>add / concat / bilinear"]
        subgraph Layer ["Transformer Layer (×N)"]
            Attn["<b>Causal Self-Attention</b><br/>Sees only past elements 0…t"]
            FFN["<b>Feed-Forward Block</b>"]
        end
    end

    subgraph DeltaDynamics [Delta-Dynamics Head]
        Proj["<b>Latent Projection</b><br/>Back to beam state dimension"]
        Delta["<b>Predicted Correction</b><br/>Δz_t"]
    end

    subgraph NextState [State Update]
        Add(("<b>+</b>"))
        NewBeam(("<b>Updated Beam</b><br/>State z_t"))
    end

    PrevBeam --> Fuse
    ElemEmb --> Fuse

    Fuse --> Attn
    Attn --> FFN
    FFN --> Proj

    Proj --> Delta

    Delta --> Add
    PrevBeam -->|"Skip connection"| Add
    Add --> NewBeam

    NewBeam -.->|Feeds into next step| PrevBeam

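ElementEncoder takes the raw parameter vector for each element [L, K1, K2, Angle, V_rf, f_rf, phi_rf], applies physics-informed normalization (each parameter is divided by its characteristic scale), projects the result through a 3-layer MLP, and mixes in Fourier positional features that encode the element's longitudinal start position. Because every element type shares this union vector, a quadrupole with K1 = 0 has the same parameter vector as a drift of equal length and therefore produces the same embedding. No discrete type IDs are needed.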
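The non-learned parts of this pipeline (unit scaling, start positions, and the Fourier basis) can be sketched in plain Python. This is a minimal illustration, not the repository's code: the `SCALES` values and the `union_vector` helper are hypothetical, and the features are assumed to be sin/cos of 2πs/λ_k with n_freq wavelengths spaced geometrically between lambda_min and lambda_max (inclusive).

```python
import math
from itertools import accumulate

# Hypothetical characteristic scales for [L, K1, K2, Angle, V_rf, f_rf, phi_rf];
# the actual values used by the repository are not given in this README.
SCALES = [1.0, 10.0, 100.0, 0.1, 1.0e6, 1.0e9, math.pi]

def union_vector(L=0.0, K1=0.0, K2=0.0, angle=0.0, V_rf=0.0, f_rf=0.0, phi_rf=0.0):
    """Hypothetical helper: fields an element type does not use stay zero."""
    return [L, K1, K2, angle, V_rf, f_rf, phi_rf]

def normalize(params, scales=SCALES):
    """Physics normalization: divide each parameter by its characteristic scale."""
    return [p / s for p, s in zip(params, scales)]

def start_positions(lengths):
    """s_i = sum of the lengths of elements 0..i-1 (exclusive cumulative sum)."""
    return list(accumulate(lengths, initial=0.0))[:-1]

def wavelengths(n_freq=32, lambda_min=0.01, lambda_max=1000.0):
    """n_freq wavelengths spaced geometrically from lambda_min to lambda_max (meters)."""
    if n_freq == 1:
        return [lambda_min]
    ratio = (lambda_max / lambda_min) ** (1.0 / (n_freq - 1))
    return [lambda_min * ratio**k for k in range(n_freq)]

def fourier_features(s, lams):
    """sin and cos of 2*pi*s/lambda_k for every wavelength: 2 * n_freq features."""
    angles = [2.0 * math.pi * s / lam for lam in lams]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]

drift = union_vector(L=0.5)
quad_off = union_vector(L=0.5, K1=0.0)      # quadrupole with K1 = 0
lams = wavelengths()                        # 32 wavelengths, 1 cm to 1 km
starts = start_positions([0.5, 0.25, 1.0])  # [0.0, 0.5, 0.75]
pos = [fourier_features(s, lams) for s in starts]
```

The drift and the K1 = 0 quadrupole normalize to identical inputs, which is why they map to the same embedding without any type ID.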

TrackingTransformer is a GPT-style autoregressive model. Each token fuses the projected previous beam state with the element embedding using one of three fusion modes ("add", "concat", or "bilinear"). A causal transformer then predicts a correction Δz_t that is added to the previous state: z_t = z_{t-1} + Δz_t.
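The delta-dynamics update and the feedback loop from the diagram can be sketched as a rollout. Here `predict_delta` is a hypothetical stand-in for the fusion, causal transformer, and projection head; it receives the state history and the current element and returns Δz_t. Only the skip connection and the feeding of each new state back into the history reflect the description above.

```python
from typing import Callable, List, Sequence

Vector = List[float]

def rollout(z0: Vector,
            elements: Sequence[Vector],
            predict_delta: Callable[[List[Vector], Vector], Vector]) -> List[Vector]:
    """Autoregressive delta dynamics: z_t = z_{t-1} + predict_delta(history, element_t)."""
    history = [list(z0)]
    for elem in elements:
        dz = predict_delta(history, elem)                      # stand-in for the transformer
        z_next = [z + d for z, d in zip(history[-1], dz)]      # skip connection
        history.append(z_next)                                 # feeds into the next step
    return history[1:]  # z_1 ... z_T (z_0 excluded)

# Toy predictor: every latent component moves by the element's first parameter.
toy = lambda hist, e: [e[0]] * len(hist[-1])
states = rollout([0.0, 1.0], [[1.0], [2.0]], toy)
```

Predicting a correction rather than the full state means an element that barely changes the beam only has to produce a near-zero output.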

LatticeTransformer

A parallel (non-autoregressive) alternative. The initial beam state z₀ conditions every transformer layer via Adaptive Layer Norm (AdaLN). The per-element predictions Δz_i are accumulated with a cumulative sum, z_t = z₀ + Σ_{i=1}^{t} Δz_i. Scheduled sampling is unnecessary because the model sees the same inputs at training and inference time.
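The two ingredients can be sketched as follows, assuming the common DiT-style AdaLN form LN(x)·(1 + γ(z₀)) + β(z₀), where γ and β are linear maps of z₀ (the weights `W_scale` and `W_shift` are hypothetical, and biases are omitted). The README does not specify the exact AdaLN variant. The accumulation step is the cumulative sum given above, which is also what unrolling z_t = z_{t-1} + Δz_t from z₀ produces.

```python
import math

def layer_norm(x, eps=1e-5):
    """Plain LayerNorm without learned affine parameters."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]

def matvec(W, v):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(w * vi for w, vi in zip(row, v)) for row in W]

def ada_ln(x, z0, W_scale, W_shift):
    """Adaptive LayerNorm (assumed form): LN(x) * (1 + scale(z0)) + shift(z0)."""
    scale = matvec(W_scale, z0)
    shift = matvec(W_shift, z0)
    return [h * (1.0 + g) + b for h, g, b in zip(layer_norm(x), scale, shift)]

def accumulate_states(z0, deltas):
    """z_t = z0 + sum_{i<=t} dz_i for every t, via a running sum."""
    states, running = [], list(z0)
    for dz in deltas:
        running = [r + d for r, d in zip(running, dz)]
        states.append(running)
    return states

states = accumulate_states([0.0, 0.0], [[1.0, 2.0], [3.0, 4.0]])  # [[1.0, 2.0], [4.0, 6.0]]
```

Because all Δz_i come from one parallel pass and the sum is deterministic, training and inference use exactly the same computation.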

Forward Modes (TrackingTransformer)

| Mode | z_gt | sampling_prob | Execution |
|---|---|---|---|
| Teacher forcing | provided | 0.0 | Parallel |
| Scheduled sampling | provided | (0, 1) | Sequential |
| Autoregressive | None | ignored | Sequential |

Teacher forcing feeds the ground-truth states as inputs and runs in parallel. Scheduled sampling runs step by step and, at each step, replaces the ground-truth input with the model's own (detached) prediction with probability sampling_prob, narrowing the gap between training and inference. Autoregressive mode uses only the model's own predictions and is intended for inference.
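A minimal sketch of the scheduled-sampling input choice, under simplifying assumptions: states are scalars, and the hypothetical one-step predictor `step` sees only the previous state, whereas the real model attends over the whole history. The schedule function is also an assumption. The README only names the `ss_warmup` and `ss_k` training options, so a linear ramp after warmup is used here for illustration.

```python
import random

def scheduled_sampling_rollout(z0, z_gt, step, sampling_prob, rng):
    """Sequential pass: the input at step t is the ground-truth previous state or,
    with probability sampling_prob, the model's own previous prediction."""
    preds = []
    prev_gt, prev_pred = z0, z0
    for t in range(len(z_gt)):
        # At t = 0 both candidates are z0, so no coin flip is needed.
        use_pred = t > 0 and rng.random() < sampling_prob
        z_in = prev_pred if use_pred else prev_gt
        z_hat = step(z_in, t)
        preds.append(z_hat)
        # In PyTorch the fed-back prediction would be z_hat.detach().
        prev_gt, prev_pred = z_gt[t], z_hat
    return preds

def sampling_prob_at(epoch, ss_warmup=10, ss_k=0.05, p_max=1.0):
    """Hypothetical linear schedule: 0 during warmup, then +ss_k per epoch, capped."""
    return min(p_max, max(0.0, ss_k * (epoch - ss_warmup)))

rng = random.Random(0)
p = sampling_prob_at(epoch=20)  # 0.5 under this assumed schedule
preds = scheduled_sampling_rollout(0.0, [10.0, 20.0, 30.0], lambda z, t: z + 1.0, p, rng)
```

With sampling_prob = 0 the inputs are z₀ followed by the ground truth (teacher forcing); with sampling_prob = 1 every input is the previous prediction (autoregressive).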

The LatticeTransformer has a single forward(z0, x_raw) path for both training and inference.

Configuration

Hyperparameters are composed from YAML files in configs/ with CLI dot-notation overrides:

model:
  name: tracking       # tracking | lattice (required)
  latent_dim: 64       # VAE latent dimension
  d_model: 256         # transformer hidden dimension
  n_layers: 6          # transformer layers
  n_heads: 8           # attention heads
  n_freq: 32           # Fourier frequency pairs
  element_dim: 7       # raw element parameter dimension
  lambda_min: 0.01     # min positional wavelength (m)
  lambda_max: 1000.0   # max positional wavelength (m)
  dropout: 0.1
  mlp_ratio: 4         # feed-forward expansion ratio
  fusion: concat       # add | concat | bilinear (TrackingTransformer only)

training:
  epochs: 200
  batch_size: 32
  lr: 3.0e-4
  weight_decay: 1.0e-2
  grad_clip: 1.0
  ss_warmup: 10        # scheduled sampling warmup (TrackingTransformer only)
  ss_k: 0.05           # scheduled sampling ramp rate (TrackingTransformer only)
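The README does not say which library merges the YAML files and CLI overrides, so the following is only a standard-library sketch of the dot-notation semantics: an argument such as `model.name=lattice` walks the nested config by its dotted path and replaces the leaf value. The function names `parse_value` and `apply_overrides` are hypothetical.

```python
import ast
import copy

def parse_value(text):
    """Interpret '3e-4', '6', 'true', ... as Python values; fall back to the raw string."""
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text  # e.g. 'tracking', 'concat'

def apply_overrides(config, overrides):
    """Apply 'section.key=value' overrides to a nested dict and return a new dict."""
    merged = copy.deepcopy(config)
    for item in overrides:
        dotted, _, raw = item.partition("=")
        *parents, leaf = dotted.split(".")
        node = merged
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = parse_value(raw)
    return merged

base = {"model": {"name": "tracking", "d_model": 256}, "training": {"lr": 3.0e-4}}
cfg = apply_overrides(base, ["model.name=lattice", "training.lr=1e-3", "model.n_layers=8"])
```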

Usage

from src.models import ModelConfig, TrackingTransformer, LatticeTransformer

config = ModelConfig(latent_dim=64, d_model=256)

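# z0, x_raw, z_gt: initial latent state, raw element parameters, ground-truth latents (prepared elsewhere)

# TrackingTransformer (autoregressive)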
model = TrackingTransformer(config)
z_pred = model(z0, x_raw, z_gt=z_gt, sampling_prob=0.0)  # teacher forcing
z_pred = model(z0, x_raw, z_gt=z_gt, sampling_prob=0.3)  # scheduled sampling
z_pred = model(z0, x_raw)                                 # autoregressive inference

# LatticeTransformer (parallel)
model = LatticeTransformer(config)
z_pred = model(z0, x_raw)  # same call for training and inference

Setup

Runs on NERSC Perlmutter. Three conda environments cover different pipeline stages:

| Environment | Stage | Key packages |
|---|---|---|
| lbd_datagen | Lattice generation + Tao tracking | NumPy, distgen, pmd_beamphysics, Bmad/Tao |
| vae | VAE data prep (frequency maps) | beam_vae, PyTorch |
| lbd | Transformer training | PyTorch + CUDA 12.4 |

ml load conda

# Stage 1: Generate lattice + beam inputs
conda activate lbd_datagen
python scripts/generate_inputs.py --mode sectioned --n-samples 10000 --seq-len 32 --output-dir data/sectioned_10k

# Stage 2: Track particles through Bmad/Tao
conda activate lbd_datagen
find data/sectioned_10k -mindepth 1 -maxdepth 1 -type d | sort | \
    parallel -j$(nproc) bash scripts/track_one.sh {}

# Stage 3: Convert tracked beams to VAE training data
conda activate vae
python scripts/prepare_vae_data.py --data-dir data/sectioned_10k \
    --output data/vae_training/sectioned_10k --workers 128

# Stage 4: Train transformer
conda activate lbd
python scripts/train.py model.name=tracking   # or model.name=lattice

Run the model sanity check:

conda activate lbd
python scripts/check_models.py
