Major refactoring for JAX-style classes.#29
Conversation
- Implemented `fo_integrators.py` for full orbit tracing with various methods and parameters. - Implemented `gc_integrators.py` for guiding center dynamics with adaptative and constant step sizes. - Enhanced `Tracing` class in `dynamics.py` to support multiple methods and step sizes.
…d adjust num_steps based on dt
… class for improved step size handling
…for adaptive step size
…ters for performance
…ance plots, and improve layout for better visualization
…arate gamma computation
|
Tests need fixing; |
There was a problem hiding this comment.
Pull request overview
This PR implements a major refactoring to make coils and surfaces proper JAX PyTrees, enabling automatic differentiation. It introduces a new loss wrapper system for gradient-based optimization and adds comprehensive analysis and validation code comparing ESSOS with SIMSOPT.
Key changes:
- Refactored
Coils,Curves,SurfaceRZFourier, andBiotSavartclasses as JAX PyTrees with proper tree flattening/unflattening - Added
essos/losses.pywithcustom_lossandcomposite_lossclasses for differentiable loss functions - Updated API:
Coils_from_json()→Coils.from_json(),tracing.energyproperty →tracing.energy()method - Added extensive analysis scripts for validation against SIMSOPT
Reviewed changes
Copilot reviewed 32 out of 32 changed files in this pull request and generated 39 comments.
Show a summary per file
| File | Description |
|---|---|
essos/losses.py |
New module implementing base_loss, custom_loss, and composite_loss classes for automatic differentiation |
essos/surfaces.py |
Refactored SurfaceRZFourier as PyTree with cached properties and improved initialization |
essos/coils.py |
Refactored Curves and Coils as PyTrees with cached properties, changed to classmethod constructors |
essos/fields.py |
Added MagneticField base class and registered BiotSavart as PyTree |
essos/dynamics.py |
Changed energy from cached property to method, added Particles.join() method |
essos/objective_functions.py |
Removed deprecated loss functions, added new coil separation and curvature losses |
essos/optimization.py |
Updated surface instantiation to include mpol/ntor parameters |
examples/optimize_coils_vmec_surface.py |
Major rewrite using new loss wrapper API instead of old optimization functions |
examples/trace_particles_coils_guidingcenter.py |
Updated imports and API calls (from_json, energy method) |
examples/trace_fieldlines_coils.py |
Updated to use Coils.from_json() |
examples/optimize_coils_particle_confinement_fullorbit.py |
Minor formatting and parameter updates |
examples/optimize_coils_and_surface.py |
Added mpol/ntor parameters, simplified loss calculations |
examples/input_files/*. |
Updated VMEC input file coefficients |
examples/comparisons_SIMSOPT/*.py |
Deleted old comparison scripts |
examples/compare_guidingcenter_fullorbit.py |
Updated particle initialization and energy calculation |
analysis/*.py |
New analysis scripts for validation and benchmarking |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| @property | ||
| def dependencies_buffer(self): | ||
| if self._dependencies_buffer is None: | ||
| self._dependencies_buffer = tree_util.tree_map(lambda x: jnp.zeros_like(x), self.dependencies) |
There was a problem hiding this comment.
This 'lambda' is just a simple wrapper around a callable object. Use that object directly.
| self._dependencies_buffer = tree_util.tree_map(lambda x: jnp.zeros_like(x), self.dependencies) | |
| self._dependencies_buffer = tree_util.tree_map(jnp.zeros_like, self.dependencies) |
| json_file_stel = curves_stel | ||
| field_simsopt = load(json_file_stel) | ||
| coils_simsopt = field_simsopt.coils | ||
| curves_simsopt = [coil.curve for coil in coils_simsopt] |
| compile_tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=100, method='Dopri5', | ||
| stepsize='adaptive', tol_step_size=trace_tolerance_array[0], particles=particles) | ||
| block_until_ready(compile_tracing.trajectories) | ||
|
|
||
| for index, trace_tolerance_ESSOS in enumerate(trace_tolerance_array): | ||
| num_steps_essos = avg_steps_SIMSOPT_array[index] | ||
| print(f'Tracing ESSOS guiding center with tolerance={trace_tolerance_ESSOS}') | ||
| start_time = time() | ||
| tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=num_steps_essos, method='Dopri5', | ||
| stepsize='adaptive', tol_step_size=trace_tolerance_ESSOS, particles=particles) |
There was a problem hiding this comment.
Keyword argument 'timesteps' is not a supported parameter name of Tracing.init.
Keyword argument 'tol_step_size' is not a supported parameter name of Tracing.init.
Keyword argument 'method' is not a supported parameter name of Tracing.init.
Keyword argument 'stepsize' is not a supported parameter name of Tracing.init.
| compile_tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=100, method='Dopri5', | |
| stepsize='adaptive', tol_step_size=trace_tolerance_array[0], particles=particles) | |
| block_until_ready(compile_tracing.trajectories) | |
| for index, trace_tolerance_ESSOS in enumerate(trace_tolerance_array): | |
| num_steps_essos = avg_steps_SIMSOPT_array[index] | |
| print(f'Tracing ESSOS guiding center with tolerance={trace_tolerance_ESSOS}') | |
| start_time = time() | |
| tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=num_steps_essos, method='Dopri5', | |
| stepsize='adaptive', tol_step_size=trace_tolerance_ESSOS, particles=particles) | |
| compile_tracing = Tracing('GuidingCenter', field_essos, tmax_gc, particles=particles) | |
| block_until_ready(compile_tracing.trajectories) | |
| for index, trace_tolerance_ESSOS in enumerate(trace_tolerance_array): | |
| num_steps_essos = avg_steps_SIMSOPT_array[index] | |
| print(f'Tracing ESSOS guiding center with tolerance={trace_tolerance_ESSOS}') | |
| start_time = time() | |
| tracing = Tracing('GuidingCenter', field_essos, tmax_gc, particles=particles) |
| tracing_fo = Tracing(field=field, model='FullOrbit', particles=particles, maxtime=tmax_fo, | ||
| timesteps=timesteps_fo, tol_step_size=trace_tolerance) |
There was a problem hiding this comment.
Keyword argument 'timesteps' is not a supported parameter name of Tracing.init.
Keyword argument 'tol_step_size' is not a supported parameter name of Tracing.init.
| tracing_gc = Tracing(field=field, model='GuidingCenter', particles=particles, maxtime=tmax_gc, | ||
| timesteps=timesteps_gc, tol_step_size=trace_tolerance) |
There was a problem hiding this comment.
Keyword argument 'timesteps' is not a supported parameter name of Tracing.init.
Keyword argument 'tol_step_size' is not a supported parameter name of Tracing.init.
| nfp=number_of_field_periods, stellsym=True) | ||
| coils_essos = Coils(curves=curves_essos, currents=[current_on_each_coil]*number_coils_per_half_field_period) | ||
| field_essos = BiotSavart(coils_essos) | ||
| surface_essos = SurfaceRZFourier_ESSOS(vmec, ntheta=ntheta, nphi=nphi, close=False) |
There was a problem hiding this comment.
Call to SurfaceRZFourier.init with too few arguments; should be no fewer than 5.
| surface_essos = SurfaceRZFourier_ESSOS(vmec, ntheta=ntheta, nphi=nphi, close=False) | |
| surface_essos = SurfaceRZFourier_ESSOS(vmec, order_Fourier_series_coils, ntheta=ntheta, nphi=nphi, close=False) |
| EXPORT = False |
There was a problem hiding this comment.
This statement is unreachable.
| EXPORT = False | |
| EXPORT = True |
There was a problem hiding this comment.
The name of the file is not correct. It does not take the gradients of any particle confinement loss.
Refactor of coils & surfaces to be proper PyTrees;
Added a loss wrapper to differentiate with respect to the dogs (PyTree leaves);
Added analysis & validation of the code