NF2 training loop and loss scheduling

This note is about how NF2 actually optimises the neural field.

The bare PyTorch loop is ordinary: zero gradients, run the model, compute losses, backpropagate, step the optimiser. The NF2-specific part lives mainly in nf2/train/module.py (NF2Module), which owns the model, the losses, and the schedules.

The key to reading the code is to keep three mechanisms separate.

The objective

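At any step t the total loss is a weighted sum over named loss terms i:

$$\mathcal{L}(t) = \sum_i \lambda_i(t)\,\frac{1}{N_i}\sum_{j=1}^{N_i} s(\mathbf{x}_j)\,\ell_i(\mathbf{x}_j)$$

Here $\ell_i$ is the per-sample loss of term i, $s(\mathbf{x})$ is the spatial scaling factor, and $\lambda_i(t)$ is the global weight of the term. The optimiser then takes a step of size $\eta(t)$ on $\mathcal{L}(t)$.

Each factor has a different job: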

Mechanism                What it changes            Scope                      Update rate
Learning-rate schedule   optimiser step size        global scalar              per batch
Loss-weight schedule     importance of each term    global scalar per term     per batch
Loss-scaling module      per-point weighting        spatial / per sample       every forward pass

Mixing these up makes the training code look much stranger than it is.

Forward step

A batch is a dict of named loaders, for example boundary points, potential-field samples, and random volume points.

Per step, NF2 roughly does this:

  1. Set coords.requires_grad = True so autograd can take spatial derivatives.
  2. Apply any learnable coordinate transforms.
  3. Split points by whether they need a Jacobian.
  4. Run one model call per group, then slice the output back into datasets.
  5. Build an all pseudo-dataset so physics losses can apply across all collocation points.
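The grouping and slicing in steps 3 to 5 can be sketched in plain Python. This is a minimal sketch, not NF2's code: grouped_forward, the needs_jacobian flags and the dataset names are hypothetical, plain lists stand in for tensors, and it assumes "all" collects the Jacobian-enabled points, since those are the ones the derivative-based losses can use.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Point = Tuple[float, float, float]  # (x, y, z) coordinate

def grouped_forward(
    batch: Dict[str, List[Point]],
    needs_jacobian: Dict[str, bool],
    model: Callable[[Sequence[Point]], List[float]],
) -> Dict[str, List[float]]:
    """Call the model once per group, then slice outputs back per dataset."""
    outputs: Dict[str, List[float]] = {}
    for flag in (False, True):
        keys = [k for k in batch if needs_jacobian[k] == flag]
        if not keys:
            continue
        # In PyTorch, the True group is where coords would require grad so the
        # Jacobian can be taken; this pure-Python stand-in only tracks grouping.
        points = [p for k in keys for p in batch[k]]
        values = model(points)  # one model call for the whole group
        start = 0
        for k in keys:  # slice the flat output back into datasets, in order
            n = len(batch[k])
            outputs[k] = values[start:start + n]
            start += n
    # "all" pseudo-dataset: every Jacobian-enabled point, for the physics losses.
    outputs["all"] = [v for k in batch if needs_jacobian[k] for v in outputs[k]]
    return outputs
```

With boundary points flagged False and potential and random points flagged True, the model is called twice. The first call covers the boundary points. The second covers the potential and random points together.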

Pure boundary-fit points can skip the expensive Jacobian. Physics points need it because force-free and divergence losses depend on derivatives.

Each loss produces a per-sample tensor. NF2 multiplies it by the spatial scaling $s(\mathbf{x})$, takes the mean, and then applies the global weight $\lambda_i$.

That order matters:

per-sample loss -> spatial scaling -> mean -> global loss weight
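A minimal sketch of that order for a single loss term follows. weighted_term and the toy scale function are hypothetical names, and plain lists stand in for tensors.

```python
from typing import Callable, Sequence, Tuple

Point = Tuple[float, float, float]  # (x, y, z) collocation point

def weighted_term(
    per_sample: Sequence[float],
    points: Sequence[Point],
    scale: Callable[[Point], float],
    weight: float,
) -> float:
    """One loss term: per-sample loss -> spatial scaling -> mean -> global weight."""
    # Spatial scaling is applied point by point, before any averaging.
    scaled = [loss * scale(p) for loss, p in zip(per_sample, points)]
    # Mean over this term's samples.
    mean = sum(scaled) / len(scaled)
    # The global weight is a single scalar per term, applied last.
    return weight * mean
```

Take per-sample losses [4, 1], scales [1, 3] and weight 0.1. The result is 0.1 × mean([4, 3]) ≈ 0.35. Scaling after the mean instead (mean loss 2.5 × mean scale 2) would give 0.5, and it no longer weights individual points. The global weight is linear, so it can safely come last.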

Boundary annealing

The default idea is:

boundary    : weight 1e3 -> 1   over 1e5 iterations
force_free  : weight 1e-1       fixed
divergence  : weight 1e-1       fixed

Early in training, the large boundary weight forces the network to learn the magnetogram. Later, the boundary weight decays and the force-free/divergence terms can reshape the volume into a more physical equilibrium.

This is needed because $\mathbf{B} = 0$ is trivially force-free and divergence-free. If the physics losses dominate too early, the network can collapse to a clean, near-zero field that ignores the data.
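To see why the zero field is a trap, assume the common force-free, divergence and boundary penalties below. This is an assumption: the exact normalisation in nf2/train/module.py may differ. Evaluate each penalty on $\mathbf{B} \equiv \mathbf{0}$:

$$\ell_{\mathrm{ff}} = \frac{\lVert(\nabla\times\mathbf{B})\times\mathbf{B}\rVert^2}{\lVert\mathbf{B}\rVert^2+\epsilon}, \qquad \ell_{\mathrm{div}} = (\nabla\cdot\mathbf{B})^2, \qquad \ell_{\mathrm{bc}} = \lVert\mathbf{B}-\mathbf{B}_{\mathrm{obs}}\rVert^2$$

$$\mathbf{B}\equiv\mathbf{0} \;\Rightarrow\; \ell_{\mathrm{ff}} = 0, \quad \ell_{\mathrm{div}} = 0, \quad \ell_{\mathrm{bc}} = \lVert\mathbf{B}_{\mathrm{obs}}\rVert^2$$

Only the boundary term penalises the zero field. It is therefore the only term that can pull the network away from that solution, which is why it starts with a large weight.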

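For an exponential schedule from $\lambda_0$ to $\lambda_1$ over $T$ optimiser steps, the per-step decay factor is

$$\gamma = \left(\frac{\lambda_1}{\lambda_0}\right)^{1/T},$$

so after $t$ steps ($t \le T$):

$$\lambda(t) = \lambda_0\,\gamma^{t} = \lambda_0\left(\frac{\lambda_1}{\lambda_0}\right)^{t/T}.$$

For the boundary default ($\lambda_0 = 10^3$, $\lambda_1 = 1$, $T = 10^5$) this gives $\gamma \approx 0.99993$ and $\lambda(T/2) = 10^{1.5} \approx 31.6$.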

The weights are updated by hand under no_grad, but registered so they move with checkpoints and devices.
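The schedule and its checkpointable state can be sketched in plain Python. AnnealedWeight and its methods are hypothetical. In PyTorch the value would typically be a tensor registered with register_buffer and updated under torch.no_grad(). Holding the weight at its final value after T steps is also an assumption.

```python
from typing import Dict

class AnnealedWeight:
    """Hypothetical exponential loss-weight schedule with checkpointable state."""

    def __init__(self, start: float, end: float, iterations: int) -> None:
        self.start = start
        self.end = end
        self.iterations = iterations
        self.gamma = (end / start) ** (1.0 / iterations)  # per-step decay factor
        self.step_count = 0
        self.value = start

    def step(self) -> float:
        """Advance one optimiser step and return the new weight."""
        self.step_count += 1
        if self.step_count >= self.iterations:
            self.value = self.end  # assumption: hold the final weight after T steps
        else:
            # Closed form avoids drift from 1e5 repeated multiplications.
            self.value = self.start * self.gamma ** self.step_count
        return self.value

    def state_dict(self) -> Dict[str, float]:
        return {"step_count": self.step_count, "value": self.value}

    def load_state_dict(self, state: Dict[str, float]) -> None:
        self.step_count = int(state["step_count"])
        self.value = float(state["value"])
```

Create it as AnnealedWeight(1e3, 1.0, 100_000) for the boundary term and call step() once per optimiser step. The closed form gives the same values as multiplying by gamma each step, but it does not accumulate rounding error.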

Important: iterations: 1e5 counts optimiser steps, not epochs. If you change the batch size, the schedule still ends after the same number of steps, but it now covers a different number of samples and epochs.
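The arithmetic is short enough to sketch. schedule_coverage is a hypothetical helper, and the pool size n_points is illustrative. For randomly resampled volume points, "epochs" is only a bookkeeping figure.

```python
from typing import Tuple

def schedule_coverage(iterations: int, batch_size: int, n_points: int) -> Tuple[int, float]:
    """Samples drawn and equivalent epochs over a step-based schedule."""
    samples = iterations * batch_size  # points seen during the schedule
    epochs = samples / n_points        # passes over a fixed pool of n_points
    return samples, epochs
```

Take 1e5 steps over a hypothetical pool of 1e6 points. A batch of 10,000 gives 1e9 samples (1000 epochs). Doubling the batch doubles both figures, while the schedule's end point in steps stays the same.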

Loss scaling with height

There is a second problem: $|\mathbf{B}|$ falls quickly with height. If the loss is averaged over the volume, strong low-altitude points dominate and the upper corona gets a weak training signal.

Loss-scaling modules divide pointwise losses by an estimate of the local field magnitude. This makes high-altitude points matter more.

Common options:

  • potential-fit: use an analytic potential-field height profile; this is the principled default.
  • exponential: use a fixed exponential profile in height.
  • b_height: use the current predicted $|\mathbf{B}|$ dynamically.
  • radial: ramp loss with radius for spherical/global cases.
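Two of these options can be sketched as per-point factors that multiply the per-sample loss. The function names, the scale height H, the power exponent and eps are all hypothetical. The exponential sketch assumes |B| decays like exp(-z/H). The potential-fit and radial options are not sketched because their exact profiles are not given here.

```python
import math
from typing import List, Sequence

def exponential_scale(z: float, scale_height: float, power: float = 1.0) -> float:
    """Fixed profile: assume |B| ~ exp(-z / H), so 1/|B|**power = exp(power * z / H)."""
    return math.exp(power * z / scale_height)

def b_height_scale(b_norm: float, power: float = 1.0, eps: float = 1e-7) -> float:
    """Dynamic profile from the current predicted |B| at a point.

    In the real training code this input would be detached, so no gradient
    flows through the weight; eps guards against division by zero.
    """
    return 1.0 / (b_norm + eps) ** power

def scale_losses(losses: Sequence[float], factors: Sequence[float]) -> List[float]:
    """Multiply each per-sample loss by its scaling factor."""
    return [loss * f for loss, f in zip(losses, factors)]
```

Consider two points with the same raw loss at z = 0 and z = H. They get factors 1 and e, so the higher point counts about 2.7 times as much.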

These scaling factors are detached. They are weights, not extra physics, and gradients should not flow through them.

Guardrails

NF2 checks for NaNs and fails loudly. This is useful because the force-free loss divides by powers of $|\mathbf{B}|$, so a collapsed or near-zero field can make it blow up numerically.
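A minimal guard might look like the following. check_finite_losses is a hypothetical name. It uses math.isfinite, which also rejects infinities, so it is slightly stricter than a NaN-only check.

```python
import math
from typing import Mapping

def check_finite_losses(losses: Mapping[str, float], step: int) -> None:
    """Fail loudly if any loss term is NaN (or infinite)."""
    bad = {name: v for name, v in losses.items() if not math.isfinite(v)}
    if bad:
        raise FloatingPointError(f"non-finite loss at step {step}: {bad}")
```

Calling this on the per-term losses before the backward pass stops training at the first bad step. The error message names the offending term, which makes the failure easier to trace than a silently corrupted model.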

The validation step also needs gradients enabled. That is unusual for validation, but curl and divergence still need the Jacobian at evaluation time.