[v2.1] Model Zoo: Community architectures with unified CREDIT training interface #312

Draft

jsschreck wants to merge 32 commits into main from model-zoo

Conversation

@jsschreck
Collaborator

Summary

  • Add community weather AI architectures (Stormer, ClimaX, FourCastNet v1/v3, SFNO, SwinRNN, FengWu, GraphCast, HEALPix, Aurora, Pangu-Weather, AIFS) under a unified load_model interface: same config file and same credit train / credit_rollout entry points as WXFormer (see the usage sketch after this list)
  • Add pretrained_weights support to load_model for all zoo models; key remapping handles DDP prefixes and architecture differences automatically (strict=False)
  • Handle (B, C, T, H, W) 5-D inputs in all zoo model forwards to match CREDIT trainer expectations
  • Fix initialization failures across models: HEALPix face padding, the GraphCast kNN-graph OOM, and activation-checkpoint policy discovery
  • Filter unknown kwargs before calling model constructors so zoo configs don't error on CREDIT-specific keys
  • Add per-model README.md files and YAML config snippets for every architecture
  • Add a MODELS.md reference table covering distributed-training support and torch.compile compatibility (spectral-norm models require use_spectral_norm: false to compile)
  • Add convergence training configs and submit scripts for Derecho
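
A minimal usage sketch of the unified interface (the exact load_model signature and config keys are assumed from this summary, not copied from the code):

    from credit.models import load_model

    # Hedged sketch: every zoo model is selected by the config "type" key and
    # built through the same entry point the trainer uses.
    conf = {
        "model": {
            "type": "stormer",                        # or "graphcast", "sfno", ...
            "image_height": 181, "image_width": 360,  # illustrative kwargs
            "pretrained_weights": "/path/to/ckpt.pt", # optional; partial match OK
        }
    }
    model = load_model(conf["model"])  # CREDIT-only keys are filtered automatically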

Test plan

  • Smoke test: all zoo models initialize and produce the correct output shape (12/12 pass in tests/test_model_zoo.py)
  • pretrained_weights loads correctly for Stormer (194/295 keys), ClimaX (105/130), and Aurora (~74%), with the expected partial match
  • WXFormer 5D forward works with and without pretrained weights
  • torch.compile runs on WXFormer with use_spectral_norm: false
  • Convergence configs launch on Derecho without error

🤖 Generated with Claude Code

jsschreck and others added 26 commits March 27, 2026 15:36
credit/models/aurora/ — MIT-licensed Aurora architecture (Bodnar et al., 2024)
  Perceiver3DEncoder → Swin3DTransformerBackbone → Perceiver3DDecoder.
  CREDITAurora wraps the native Batch interface to CREDIT's (B, C, H, W) tensors.
  All HuggingFace dependencies removed; no pretrained weights required.

credit/models/pangu/ — Pangu-Weather 3D Earth Transformer (Bi et al., 2023)
  PyTorch implementation written from scratch, following the official pseudocode.
  EarthSpecificBias, 3-D shifted window attention, U-shaped encoder-decoder.
  CREDITPangu provides the CREDIT flat-tensor interface.

credit/models/aifs/ — AIFS-inspired lat/lon Transformer processor (Lang et al., 2024)
  Simplified: linear embed + sinusoidal pos enc + N-layer Transformer processor.
  GNN encoder/decoder replaced by linear projections; operates on regular lat/lon grid.

All three: registered in credit/models/__init__.py, smoke-tested (forward + backward)
on GPU (casper29 V100), junk-data __main__ blocks included.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…raphCast, HEALPix

All models trainable from scratch with CREDIT's (B, C, H, W) flat tensor pipeline.
No external package dependencies (no HuggingFace, no anemoi, no earth2studio).

Stormer   — plain ViT (patch embed + positional encoding + transformer blocks)
ClimaX    — per-variable tokenization ViT (separate embed per channel, cross-var aggregation)
FourCastNet — AFNO: Adaptive Fourier Neural Operators replacing self-attention
SFNO      — Spherical FNO; uses torch-harmonics SHT when available, falls back to rfft2
SwinRNN   — 3-stage Swin encoder-decoder with skip connections
FengWu    — hierarchical ViT with per-variable-group encoders + cross-group fuser
GraphCast — kNN GNN encoder-processor-decoder on flat lat/lon grid
HEALPix   — DLWP-HEALPix U-Net with transparent lat/lon <-> HEALPix reprojection;
            uses healpy when installed, approximate grid otherwise

Also: wxformer-sdl canonical name + crossformer-style legacy alias kept
wxformer-v2-sdl: guarded import (requires wxformer_v2.py)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Multi-scale U-Net with SNO blocks at each stage. SHT via torch-harmonics
when available; falls back to rfft2. GroupNorm for rotation equivariance.
Registered as 'fourcastnet3' in model_types.
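
The SHT fallback might look roughly like this (a sketch assuming the torch-harmonics RealSHT API, not the repo's exact code):

    import torch

    nlat, nlon = 181, 360  # illustrative grid

    # Spherical harmonic transform when torch-harmonics is installed,
    # plain 2-D real FFT on the lat/lon grid otherwise.
    try:
        from torch_harmonics import RealSHT
        spectral_transform = RealSHT(nlat, nlon, grid="equiangular")
    except ImportError:
        def spectral_transform(x):  # x: (B, C, nlat, nlon)
            return torch.fft.rfft2(x, norm="ortho")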

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
AFNO: duplicate 'b' subscript in bhwnb caused RuntimeError. Renamed
block_size dim to 's' and hidden dim to 'e' throughout.
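
For illustration, the failure mode and the fix (shapes hypothetical):

    import torch

    B, H, W, n, s, e = 2, 8, 8, 4, 16, 32  # batch, grid, blocks, block_size, hidden
    x = torch.randn(B, H, W, n, s)
    w = torch.randn(n, s, e)

    # torch.einsum("bhwnb,nbe->bhwne", x, w)    # RuntimeError: 'b' bound to two sizes
    y = torch.einsum("bhwns,nse->bhwne", x, w)  # 's' = block_size, 'e' = hidden
    assert y.shape == (B, H, W, n, e)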

tests/test_model_zoo.py: smoke test all 12 zoo models on GPU; correct
kwarg names for graphcast (latent_dim/processor_depth), healpix (embed_dim),
aurora/pangu/aifs (variable-list interface). All pass on A100-80GB.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Documents: paper + original code source, architecture summary,
what CREDIT's implementation does/doesn't do, validation status,
and known caveats. Honest about what is smoke-tested vs trained.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…2/12 train

healpix.py rewrite:
- HEALPixPadding: full face-adjacency with rotated borders (pn/pe/ps)
  for all 3 polar/equatorial face types (ported from PhysicsNeMo Zephyr,
  authored by Noah Brenowitz / nbren12 at NVIDIA, Apache 2.0)
- healpy now installed; pix2xyf/pix2ang used for correct pixel geometry
- _build_index_buffers: hp_to_ll (12,nside,nside) + ll_to_hp (H*W,)
  via ang2pix + pix2xyf for exact nearest-neighbour reprojection
  (see the sketch after this list)
- HEALPixConv: skips face-aware pad for kernel_size=1 (no neighbors needed)
- HEALPixUNet: n_stages-deep encoder-decoder, all face-aware convs
- Previous impl used F.pad circular on a flat array -- not face-aware at all
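
A condensed sketch of the two index buffers (grid conventions assumed; the real code is in healpix.py):

    import numpy as np
    import healpy as hp

    nside, H, W = 32, 181, 360
    npix = hp.nside2npix(nside)

    # ll_to_hp (H*W,): nearest HEALPix pixel for every lat/lon cell
    lat = np.linspace(90.0, -90.0, H)
    lon = np.linspace(0.0, 360.0, W, endpoint=False)
    th, ph = np.meshgrid(np.deg2rad(90.0 - lat), np.deg2rad(lon), indexing="ij")
    ll_to_hp = hp.ang2pix(nside, th.ravel(), ph.ravel())

    # hp_to_ll (12, nside, nside): nearest lat/lon cell for every pixel,
    # laid out face-by-face via pix2xyf so faces form square arrays
    th_p, ph_p = hp.pix2ang(nside, np.arange(npix))
    iy = np.clip(np.rint(th_p / np.pi * (H - 1)).astype(int), 0, H - 1)
    ix = np.rint(ph_p / (2 * np.pi) * W).astype(int) % W
    x, y, f = hp.pix2xyf(nside, np.arange(npix))
    hp_to_ll = np.empty((12, nside, nside), dtype=np.int64)
    hp_to_ll[f, y, x] = iy * W + ix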

tests/test_model_zoo_train.py: 50-step overfit test, 12/12 converging on A100
All models show meaningful loss drops: sfno 83%, stormer 71%, pangu 99.7%

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Each model directory README now includes a ready-to-use config block,
practical notes (memory warnings, optional deps, caveats), and accurate
parameter names matching the actual constructors.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
healpix: rewrite README to credit Noah Brenowitz (nbren12) / NVlabs/earth2grid
(Apache 2.0) for face-aware HEALPixPadding — the core innovation we ported.
Update implementation description to match actual rewrite (not the old
flat-array approximation). Note 41% loss drop and healpy confirmed working.

fengwu: credit Shanghai AI Lab / SJTU / Microsoft Research Asia as original
authors (no public repo; PhysicsNeMo is the architecture reference).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
crossformer.py is based on lucidrains' CrossFormer PyTorch implementation.
Original GitHub repo taken down; now at gitlab.com/lucidrains.
Note NCAR/MILES adaptations: encoder-decoder head, SDL noise layer, CREDIT
channel conventions.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ins' CrossFormer

Phil Wang (lucidrains) implemented Zhang et al.'s CrossFormer in PyTorch.
John Schreck used that implementation to build WXFormer: encoder-decoder head,
SDL noise injection, CREDIT channel conventions.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
upsample_with_ps (sub-pixel conv decoder) was built by Will Chapman and
partially resolves the grid artefact problem from transposed-conv upsampling.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
metrics.py: two ACC bugs fixed:
1. Loop was over self.acc_vars but anomaly tensors were indexed in
   ordered_acc_vars order — variable labels could be assigned to the
   wrong anomaly slice.
2. Lines 203-204 subtracted unweighted spatial mean from climatology
   anomalies (double-demeaning). Standard WB2/WMO ACC is the
   latitude-weighted correlation of (f - clim) vs (obs - clim) with
   no additional mean removal.
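
For reference, the corrected form in a minimal NumPy sketch (array shapes assumed; not the metrics.py code):

    import numpy as np

    def lat_weighted_acc(forecast, obs, clim, lat_deg):
        """ACC of (f - clim) vs (obs - clim), cos(lat)-weighted, no extra demeaning.
        forecast/obs/clim: (H, W); lat_deg: (H,)."""
        w = np.cos(np.deg2rad(lat_deg))[:, None]
        fa, oa = forecast - clim, obs - clim
        num = np.sum(w * fa * oa)
        return num / np.sqrt(np.sum(w * fa**2) * np.sum(w * oa**2))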

test_model_zoo.py / test_model_zoo_train.py: add make_wxformer() factory
so WXFormer is included in both the smoke test and the 50-step train test.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
CrossFormer expects time as an explicit dimension, not folded into channels.
train_one() now handles optional 4th return value (frames) from factories.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
cast_tuple(val, n) only broadcasts scalars; tuples are returned unchanged.
All CrossFormer hyperparams (dim, depth, global_window_size, etc.) must
therefore be passed as length-4 tuples, or the assert len(dim) == 4 check
raises an AssertionError.
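
The helper in question, sketched in the lucidrains style (illustrative):

    def cast_tuple(val, length=1):
        # scalars are broadcast; tuples pass through unchanged, whatever their length
        return val if isinstance(val, tuple) else (val,) * length

    cast_tuple(64, 4)         # -> (64, 64, 64, 64)
    cast_tuple((64, 128), 4)  # -> (64, 128): not padded, so len(dim) == 4 fails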

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…t) pattern

UpBlock skip connections require dim[i] = last_dim // 2^(3-i).
Default prod config (64,128,256,512) follows this; test factory now uses
(4,8,16,32). Also fixes GroupNorm: num_groups=dim[0]=4 divides all out_ch.
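
Checking the constraint numerically (a sanity sketch):

    for last_dim, expected in [(512, (64, 128, 256, 512)), (32, (4, 8, 16, 32))]:
        dims = tuple(last_dim // 2 ** (3 - i) for i in range(4))
        assert dims == expected  # dim[i] = last_dim // 2^(3-i)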

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Two issues:
1. dim_head=32 default but dim[0]=4 → heads=0 → zero-element weights
   → spectral_norm reshape error. Fix: pass dim_head=4.
2. H=W=16 with 4 strides of 2 → spatial at deepest stage is 1×1,
   but local_window_size=4. Fix: use HW=64 (deepest stage → 4×4).
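
Issue 1 spelled out (head derivation as in lucidrains' CrossFormer; values from the test factory):

    dim0 = 4
    heads = dim0 // 32  # default dim_head=32 -> 0 heads -> zero-element qkv weight
    heads = dim0 // 4   # dim_head=4 -> 1 head, spectral_norm reshape succeeds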

Also add scripts/casper_model_zoo.sh for GPU submission.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ILES

Weyn and Karlbauer are University of Washington / NVIDIA, not affiliated
with NCAR/MILES. Face-aware padding credit remains with Noah Brenowitz.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…table

The conda env has credit installed from main (no model-zoo models).
Prepend repo to PYTHONPATH so local credit/models/ takes precedence.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
9 configs in config/model_zoo/: stormer, climax, fourcastnet, sfno, swinrnn,
fengwu, graphcast, healpix, fourcastnet3. All use ERA5 1-deg data,
in_channels=80, out_channels=84, 1 node × 4 GPUs, 12h walltime.
Outputs to /glade/derecho/scratch/schreck/tmp/model_zoo_train/<model>.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Parser injects 'levels' into conf['model'] for coordinate checks, but
flat-channel zoo models (Stormer, ClimaX, etc.) don't accept 'levels'
in their __init__. Use inspect.signature to pass only accepted kwargs,
falling back to all kwargs when **kwargs is present in the signature.
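
A sketch of the filtering (helper name hypothetical):

    import inspect

    def filter_kwargs(cls, kwargs):
        """Drop keys the constructor doesn't accept, unless it takes **kwargs."""
        params = inspect.signature(cls.__init__).parameters
        if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
            return kwargs  # **kwargs present: pass everything through
        return {k: v for k, v in kwargs.items() if k in params}

    # e.g. Stormer(**filter_kwargs(Stormer, conf["model"])) silently drops 'levels'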

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Stormer/SFNO/SwinRNN/FourCastNet/FengWu: pad input to the next multiple of
  patch_size in the CREDIT wrapper forward; the inner model is built with the
  padded size. Fixes the AssertionError for the 181×360 grid (181 is prime,
  not divisible by 4). See the sketch after this list.
- GraphCast: replace the O(N²) haversine matrix with a scipy cKDTree on 3-D
  XYZ coordinates, avoiding a ~34 GB allocation for the 65k-node global grid
  (also sketched below).
- All zoo configs: add activation_checkpoint: False to trainer section.
  Parser defaults this to True, which triggered load_fsdp_or_checkpoint_policy
  and raised OSError for models not in the hard-coded list.
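
Sketches of the first two fixes (function names hypothetical):

    import numpy as np
    import torch.nn.functional as F
    from scipy.spatial import cKDTree

    def pad_to_multiple(x, patch):  # x: (B, C, H, W)
        H, W = x.shape[-2:]
        return F.pad(x, (0, -W % patch, 0, -H % patch))  # pad right/bottom

    def knn_edges(lat_deg, lon_deg, k=8):
        """kNN on unit-sphere XYZ via cKDTree; no O(N^2) haversine matrix."""
        lat, lon = np.deg2rad(lat_deg), np.deg2rad(lon_deg)
        xyz = np.stack([np.cos(lat) * np.cos(lon),
                        np.cos(lat) * np.sin(lon),
                        np.sin(lat)], axis=-1)
        _, idx = cKDTree(xyz).query(xyz, k=k + 1)  # first hit is the node itself
        src = np.repeat(np.arange(xyz.shape[0]), k)
        return src, idx[:, 1:].ravel()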

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
When conf['model']['pretrained_weights'] is set and load_weights=False,
the model is constructed fresh then initialized from the external checkpoint
with strict=False (allows partial weight transfer when input channels differ).
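
Roughly what the partial load looks like (a sketch; checkpoint layout assumed):

    import torch

    def load_pretrained(model, path):
        """Strip DDP prefixes, then strict=False partial weight transfer."""
        state = torch.load(path, map_location="cpu")
        state = state.get("model_state_dict", state)  # layout assumed
        state = {k.removeprefix("module."): v for k, v in state.items()}
        missing, unexpected = model.load_state_dict(state, strict=False)
        print(f"matched {len(state) - len(unexpected)}/{len(state)} keys")
        return model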

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
The CREDIT trainer feeds (B, C, T, H, W) tensors; zoo models expect
(B, C, H, W). Flatten the history dim T into channels at forward entry:
  x = x.reshape(B, C * T, H, W) for 5-D inputs.

Affects: Stormer, SFNO, SwinRNN, FourCastNet, FengWu, ClimaX,
         GraphCast, HEALPix, FourCastNetV3.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ctures.md

Documents which models compile cleanly, which require config changes (wxformer
upsamplePS: true), and which are blocked by spectral norm (camulator, fuxi, swin,
crossformer, graph). Explains the spectral norm constraint and provides guidance
for sfno/fourcastnet3 SHT path and graphcast dynamic shapes.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…fy preset partial-match reason

- torch.compile table: spectral-norm models show ✗ by default; fix says
  use_spectral_norm: false (the actual config knob), not upsamplePS
- Add warning that disabling spectral norm risks rollout divergence
- Presets table: explain partial key match for Stormer/ClimaX/Aurora is due to
  CREDIT data input structure — these presets are transfer learning starting
  points, not drop-in inference checkpoints

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@codecov-commenter

Codecov Report

❌ Patch coverage is 20.54997% with 2687 lines in your changes missing coverage. Please review.
✅ Project coverage is 16.78%. Comparing base (fd24d56) to head (d9d20e8).

Files with missing lines Patch % Lines
credit/models/pangu/pangu.py 14.21% 325 Missing and 1 partial ⚠️
credit/models/aurora/swin3d.py 16.34% 261 Missing ⚠️
credit/models/healpix/healpix.py 18.77% 211 Missing and 1 partial ⚠️
credit/models/swinrnn/swinrnn.py 14.66% 191 Missing and 1 partial ⚠️
credit/models/fengwu/fengwu.py 16.75% 153 Missing and 1 partial ⚠️
credit/models/aifs/aifs.py 18.64% 143 Missing and 1 partial ⚠️
credit/models/fourcastnet3/fcn3.py 19.48% 123 Missing and 1 partial ⚠️
credit/models/aurora/encoder.py 13.13% 119 Missing ⚠️
credit/models/stormer/stormer.py 20.00% 115 Missing and 1 partial ⚠️
credit/models/sfno/sfno.py 18.43% 114 Missing and 1 partial ⚠️
... and 19 more
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #312      +/-   ##
==========================================
+ Coverage   16.13%   16.78%   +0.64%     
==========================================
  Files         122      162      +40     
  Lines       19604    22982    +3378     
  Branches     3308     3559     +251     
==========================================
+ Hits         3164     3858     +694     
- Misses      16148    18817    +2669     
- Partials      292      307      +15     


jsschreck and others added 3 commits March 30, 2026 07:41
…her, MambaVision, CorrDiff)

Each follows the standard CREDIT zoo pattern: credit/models/<name>/ directory with wrapper
class, README, __init__.py, config YAML, and smoke-test factory in tests/test_model_zoo.py.

- iTransformer: inverted attention across variable tokens (H*W→d_model per channel)
- FuXi-ENS: ViT backbone + VAE bottleneck for ensemble perturbations
- ArchesWeather: alternating window attention + column attention blocks
- MambaVision: 4-stage hybrid Mamba+attention U-Net (pure-PyTorch fallback if mamba_ssm absent)
- CorrDiff: EDM-preconditioned score-based diffusion with CondEncoder + SongUNet

Also wrap legacy model imports (crossformer family) in try/except to handle numba/NumPy>=2.3
environments cleanly — zoo models have no such transitive dependency.

Fix CorrDiff decoder in_c calculation: use prev level's output channels for h, not current ch.

17/18 smoke tests pass (wxformer excluded: pre-existing numba/NumPy 2.4 env conflict).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
bridgescaler.distributed eagerly imports numba, which hard-errors when
NumPy >= 2.3 is installed. Wrapping the two eager import sites in
transforms_quantile.py and deprecated/_transforms.py lets PostBlock
(and therefore wxformer) import cleanly in modern NumPy environments.
All 18 smoke tests now pass.
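
The guard is just a deferred import, roughly (a sketch of the pattern, not the exact diff):

    try:
        from bridgescaler import distributed as _bsd  # eagerly imports numba
    except Exception:  # numba raises at import time under NumPy >= 2.3
        _bsd = None  # quantile transforms unavailable; PostBlock still imports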

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ils)

Fills in iTransformer, FuXi-ENS, ArchesWeather, MambaVision, CorrDiff in:
- Pretrained Weight Compatibility table
- Credits & Licenses table
- Model Details / Model Zoo section (architecture blurbs)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@jsschreck jsschreck changed the title [v2.1] Model Zoo: 13 community architectures with unified CREDIT training interface [v2.1] Model Zoo: Community architectures with unified CREDIT training interface Mar 31, 2026
@dkimpara
Collaborator

dkimpara commented Apr 1, 2026

Can you make the default dry-run submit to derecho-develop?

jsschreck and others added 3 commits April 3, 2026 09:31
… wxf_v2 branch

- credit/models/__init__.py: remove wxformer_v2_ensemble import and model_types entry
- credit/models/wxformer/wxformer_v2_ensemble.py: deleted (belongs only in wxf_v2)
- credit/models/wxformer/README.md: remove wxformer_v2_ensemble row

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ub-package

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
