add more nnx and sl tutorials #304

Open
da-niao-dan wants to merge 11 commits into main from peicheng/nnx_split_learning_tutorial

@da-niao-dan

Member

add more nnx and sl tutorials

@da-niao-dan da-niao-dan requested a review from a team as a code owner December 18, 2025 09:01
@gemini-code-assist

Contributor

Summary of Changes

Hello @da-niao-dan, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request expands the tutorial collection by adding two new examples that demonstrate advanced usage patterns within MPLang v1. The first tutorial provides a comprehensive guide on effectively managing Flax NNX model states and optimizers by leveraging pure Python dictionaries for seamless transfer across distributed function calls. The second tutorial builds upon these concepts to illustrate a practical implementation of vertical split learning, demonstrating how multiple parties can collaboratively train a machine learning model on vertically partitioned data while maintaining data privacy.

Highlights

  • NNX Model State Management Tutorial: Added a new tutorial (08_initialize_nnx_on_device.py) demonstrating best practices for managing Flax NNX model states and optimizers within MPLang v1. It focuses on converting model and optimizer states to pure Python dictionaries for seamless transfer across @mp.function boundaries, enabling stateful training and checkpointing.
  • Vertical Split Learning Tutorial: Introduced a comprehensive tutorial (09_split_learning_vertical.py) showcasing vertical split learning in MPLang v1. This example illustrates how two parties (Alice and Bob) can collaboratively train a machine learning model on vertically partitioned data, utilizing a three-model architecture and surrogate loss to maintain data privacy.
  • Memory-Efficient GraphDef Creation: Both new tutorials emphasize the use of nnx.eval_shape() for creating abstract model GraphDefs, which is a memory-efficient approach as it avoids allocating actual arrays during model reconstruction.
  • Optax Integration for Optimizer State: The tutorials demonstrate the integration of optax for flexible optimizer state management, showing how optimizer states can also be represented and passed as pure Python dictionaries, crucial for iterative training in a distributed environment.

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces two valuable tutorials demonstrating how to use Flax NNX with MPLang v1 for model initialization and for implementing a vertical split learning scenario. The code is well-structured and the tutorials are clear. My feedback focuses on the split learning tutorial (09_split_learning_vertical.py), with suggestions to improve maintainability by removing redundant code and to enhance efficiency by using more optimal JAX APIs.

@da-niao-dan da-niao-dan force-pushed the peicheng/nnx_split_learning_tutorial branch from c036ab3 to e648e05 Compare December 18, 2025 09:17

Copilot AI left a comment

Contributor

Pull request overview

This PR adds two comprehensive NNX (Flax Neural Network) tutorials for MPLang v1, demonstrating proper patterns for distributed machine learning with state management and split learning architectures.

  • Tutorial 08: Introduces the foundational pattern for using Flax NNX with MPLang by showing how to convert model states to pure Python dicts for cross-device transfer
  • Tutorial 09: Builds on Tutorial 08 to implement split learning with vertical data partitioning, showcasing privacy-preserving collaborative training

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| tutorials/v1/device/08_initialize_nnx_on_device.py | New tutorial demonstrating NNX model initialization on devices, state dict conversion pattern, and stateful training with optimizer state management |
| tutorials/v1/device/09_split_learning_vertical.py | New tutorial implementing split learning with vertical data partitioning using three models (Alice base, Bob base, Alice aggregate) and surrogate loss for privacy-preserving training |

@da-niao-dan da-niao-dan requested a review from Copilot December 18, 2025 09:22
@da-niao-dan

Member Author

@gemini review

Copilot AI left a comment

Contributor

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 11 comments.



Copilot AI left a comment

Contributor

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 11 comments.


@da-niao-dan

Member Author

@copilot open a new pull request to apply changes based on the comments in this thread

Copilot AI commented Dec 19, 2025

Contributor

@da-niao-dan I've opened a new pull request, #306, to work on those changes. Once the pull request is ready, I'll request review from you.

Copilot AI and others added 2 commits December 19, 2025 10:33
…cy (#306)

* Initial plan

* Address review comments: improve documentation and clarity

Co-authored-by: da-niao-dan <9532472+da-niao-dan@users.noreply.github.com>

* Fix type consistency for final_loss initialization

Co-authored-by: da-niao-dan <9532472+da-niao-dan@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: da-niao-dan <9532472+da-niao-dan@users.noreply.github.com>

Copilot AI left a comment

Contributor

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.


@da-niao-dan

Member Author

@copilot open a new pull request to apply changes based on the comments in this thread

Copilot AI commented Dec 19, 2025

Contributor

@da-niao-dan I've opened a new pull request, #307, to work on those changes. Once the pull request is ready, I'll request review from you.

* Initial plan

* Replace _graphdef with _ for unused variables

Co-authored-by: da-niao-dan <9532472+da-niao-dan@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: da-niao-dan <9532472+da-niao-dan@users.noreply.github.com>

Copilot AI left a comment

Contributor

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 8 comments.


Comment on lines +227 to +229
# Create temporary model to get GraphDef structure
temp_model = model_class(*model_args, **model_kwargs)
graphdef, temp_state = nnx.split(temp_model)

Copilot AI Dec 19, 2025

Tutorial 09 creates a temporary full model instance to get the GraphDef structure (line 228), while Tutorial 08 uses nnx.eval_shape(), which is more memory-efficient for large models because it avoids allocating real arrays. For consistency, consider using nnx.eval_shape() here as well, especially since Tutorial 08 specifically mentions this as a best practice (lines 522-524).

Suggested change
- # Create temporary model to get GraphDef structure
- temp_model = model_class(*model_args, **model_kwargs)
- graphdef, temp_state = nnx.split(temp_model)
+ # Use nnx.eval_shape to get GraphDef and abstract state structure
+ graphdef, temp_state = nnx.eval_shape(
+     lambda: nnx.split(model_class(*model_args, **model_kwargs))
+ )

def _init_alice_agg():
    model = AliceAggregateModel(
        input_dim=h1 + h2,
        hidden_dim=16,

Copilot AI Dec 19, 2025

The hardcoded value 16 for the hidden dimension of Alice's aggregate model should be documented or defined as a constant. This magic number appears in both the model definition (line 664) and the loss function (line 328), making it harder to maintain if this architecture parameter needs to change.
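A minimal sketch of the named-constant approach, assuming the model constructor and the loss function live in the same module. `ALICE_AGG_HIDDEN_DIM` and `make_alice_agg_kwargs` are hypothetical names used for illustration; they do not appear in the tutorial.

```python
# Hypothetical module-level constant; the name is illustrative only.
ALICE_AGG_HIDDEN_DIM = 16


def make_alice_agg_kwargs(h1: int, h2: int) -> dict:
    """Build constructor kwargs for the aggregate model in one place.

    Both the model-initialization code and the loss function could call this
    helper, so the hidden dimension is defined exactly once.
    """
    return {"input_dim": h1 + h2, "hidden_dim": ALICE_AGG_HIDDEN_DIM}


# Stand-in for AliceAggregateModel(**kwargs) in the init and loss code paths.
kwargs = make_alice_agg_kwargs(8, 8)
print(kwargs)
```

Changing the architecture then means editing one constant rather than hunting for every literal 16.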

Comment on lines +573 to +583
# NOTE: mp.ops.basic.table_to_tensor requires that all columns in a table share
# the same dtype. Since Alice's feature columns are FLOAT64, we also declare
# the label column as FLOAT64 so that we can load features + labels in a single
# table_to_tensor call. In this tutorial, the labels in the CSV are integer-valued
# class IDs (e.g. 0/1), so converting the label column from FLOAT64 back to
# jnp.int32 after tensorization is an exact, lossless cast.
#
# If your labels are not integer-valued, or if they must be stored with a
# different dtype than the features, prefer a more robust pattern such as
# reading features and labels via separate schemas / read calls so you do not
# rely on a FLOAT64 → int cast.

Copilot AI Dec 19, 2025

The comment states "all columns must have same dtype for table_to_tensor" but this is a limitation that should be documented more prominently. The workaround of casting labels from FLOAT64 to int32 after loading could result in precision loss for non-integer label values. Consider warning users that this pattern is only safe when labels are truly integer-valued, and suggest using separate read operations if labels need different dtypes.

Suggested change
- # NOTE: mp.ops.basic.table_to_tensor requires that all columns in a table share
- # the same dtype. Since Alice's feature columns are FLOAT64, we also declare
- # the label column as FLOAT64 so that we can load features + labels in a single
- # table_to_tensor call. In this tutorial, the labels in the CSV are integer-valued
- # class IDs (e.g. 0/1), so converting the label column from FLOAT64 back to
- # jnp.int32 after tensorization is an exact, lossless cast.
- #
- # If your labels are not integer-valued, or if they must be stored with a
- # different dtype than the features, prefer a more robust pattern such as
- # reading features and labels via separate schemas / read calls so you do not
- # rely on a FLOAT64 → int cast.
+ # NOTE: mp.ops.basic.table_to_tensor currently requires that all columns in a
+ # table share the same dtype. Since Alice's feature columns are FLOAT64, we
+ # also declare the label column as FLOAT64 so that we can load features +
+ # labels in a single table_to_tensor call.
+ #
+ # In this tutorial, the labels in the CSV are integer-valued class IDs
+ # (e.g. 0/1). Under that assumption, converting the label column from
+ # FLOAT64 back to jnp.int32 after tensorization is an exact, lossless cast,
+ # because every label value is representable as an int32 without rounding.
+ #
+ # WARNING: If your labels are not truly integer-valued, or if they must be
+ # stored with a different dtype than the features (e.g. floating-point
+ # regression targets), copying this pattern can silently introduce precision
+ # loss or semantic errors due to the FLOAT64 → int cast.
+ #
+ # In those cases, prefer a more robust pattern: define separate schemas and
+ # read operations for features and labels (each with their appropriate
+ # dtypes), and avoid casting labels from FLOAT64 to int entirely.
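The "only safe when labels are integer-valued" condition can also be checked at runtime rather than only documented. The sketch below is one possible guard, written against plain Python floats so it runs without JAX; `labels_to_int` is a hypothetical helper, not part of MPLang or the tutorial, and a real implementation would apply the equivalent check to the tensor returned by table_to_tensor.

```python
import math


def labels_to_int(values):
    """Cast float labels to int, refusing any value the cast would alter.

    Hypothetical helper: raises instead of silently truncating 0.5 to 0.
    """
    out = []
    for v in values:
        # NaN and infinity are never valid class IDs; non-integers would be truncated.
        if not math.isfinite(v) or not float(v).is_integer():
            raise ValueError(f"label {v!r} is not integer-valued; refusing lossy cast")
        out.append(int(v))
    return out


print(labels_to_int([0.0, 1.0, 1.0]))
```

Failing loudly here turns a silent semantic error into an immediate, explainable exception.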

bob_base_dict = mp.device("P1", fe_type="nnx")(_init_bob_base)()

# === Step 2: Training loop (all iterations run inside this @mp.function) ===

Copilot AI Dec 19, 2025

final_loss is initialized to 0.0 before the training loop, so if n_steps is 0 (an empty range) the function returns that initial value instead of None or an error. This could mask bugs where no training actually occurs. Consider either validating that n_steps > 0 at the start of the function, or initializing final_loss to None and assigning it inside the loop.

Suggested change
+ if n_steps <= 0:
+     raise ValueError(f"n_steps must be a positive integer, got {n_steps}")
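The two options can also be combined. The following is a sketch only: `run_training` and `step_fn` are hypothetical stand-ins for the tutorial's training function and a single split-learning step, and the loop is ordinary Python rather than code traced inside @mp.function.

```python
from typing import Callable, Optional


def run_training(n_steps: int, step_fn: Callable[[int], float]) -> float:
    """Run step_fn n_steps times and return the loss of the last step."""
    if n_steps <= 0:
        raise ValueError(f"n_steps must be a positive integer, got {n_steps}")

    # None (rather than 0.0) makes "no step ran" distinguishable from a real loss.
    final_loss: Optional[float] = None
    for step in range(n_steps):
        final_loss = step_fn(step)

    # The guard above ensures the loop body executed at least once.
    assert final_loss is not None
    return final_loss


print(run_training(3, lambda step: 1.0 / (step + 1)))
```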

Comment on lines +198 to +203
Note: graphdef is NOT included - it can be reconstructed from the model class
definition, but requires the same model initialization parameters (input_dim,
hidden_dim, output_dim, seed) to be provided when reconstructing via
reconstruct_model_from_dict. Only the state dict (parameters), optimizer state,
and step counter are stored for transfer.
"""

Copilot AI Dec 19, 2025

The comment "Note: graphdef is NOT included - it can be reconstructed from the model class definition" is accurate, but it's worth noting that this approach requires the exact same model initialization parameters (input_dim, hidden_dim, etc.) to be available when reconstructing. If these parameters change between serialization and reconstruction, the state dict won't match the GraphDef structure. Consider adding a note about parameter version compatibility.
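One way to make that compatibility requirement checkable is to store the constructor parameters next to the state and compare them before reconstruction. This is a sketch under assumptions: `ModelConfig`, `pack_model`, and `check_compatible` are hypothetical names, and the plain dict below stands in for the real NNX state dict.

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class ModelConfig:
    """Constructor parameters that determine the GraphDef structure."""
    input_dim: int
    hidden_dim: int
    output_dim: int
    seed: int


def pack_model(config: ModelConfig, state_dict: dict, step: int) -> dict:
    """Bundle state with the config it was created from (pure Python dict)."""
    return {"config": asdict(config), "state": state_dict, "step": step}


def check_compatible(payload: dict, expected: ModelConfig) -> None:
    """Raise if the payload was produced with different constructor parameters."""
    stored = payload.get("config")
    if stored != asdict(expected):
        raise ValueError(f"config mismatch: stored={stored}, expected={asdict(expected)}")


cfg = ModelConfig(input_dim=4, hidden_dim=16, output_dim=2, seed=0)
payload = pack_model(cfg, {"linear": {"kernel": [[0.0]]}}, step=0)
check_compatible(payload, cfg)  # passes silently when the configs match
```

Because the config is itself a plain dict of ints, it crosses function boundaries the same way the state dict does.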


# Handle final_loss if it's wrapped
if isinstance(final_loss, (list, tuple)):
    final_loss = final_loss[0] if len(final_loss) > 0 else 0.0

Copilot AI Dec 19, 2025

The fallback value of 0.0 when the list is empty could mask errors. If final_loss is unexpectedly wrapped as an empty list or tuple, using 0.0 as a default silently hides the problem. Consider using a more obvious sentinel value or raising an error if the structure is unexpected, especially in tutorial code where correctness is important for learning.

Suggested change
-     final_loss = final_loss[0] if len(final_loss) > 0 else 0.0
+     if not final_loss:
+         raise ValueError(
+             "Unexpected empty sequence for final_loss; expected a scalar loss "
+             "or a single-element sequence."
+         )
+     if len(final_loss) == 1:
+         final_loss = final_loss[0]
+     else:
+         raise ValueError(
+             f"Unexpected sequence length for final_loss: expected 1 element, "
+             f"got {len(final_loss)}."
+         )

Comment on lines +440 to +443
if isinstance(output, (list, tuple)) and len(output) > 0:
    output = output[0]
if not isinstance(output, jnp.ndarray):
    output = jnp.array(output)

Copilot AI Dec 19, 2025

The handling of wrapped output could be more robust. The code checks if output is a list/tuple and has length > 0, then takes the first element. However, if the output is an empty list/tuple, it falls through without converting to ndarray, which could cause issues in the subsequent print statements. Consider handling the empty case explicitly or ensuring the conversion happens in all branches.
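A sketch of handling the empty case explicitly, assuming a wrapped result is expected to hold exactly one element (the tutorial code accepts any non-empty sequence, so this is stricter). `unwrap_single` is a hypothetical helper; the conversion to jnp.ndarray would follow it unchanged.

```python
def unwrap_single(value, name="value"):
    """Return value unchanged, or its only element if it is a 1-element list/tuple.

    Empty or multi-element sequences raise instead of falling through silently.
    """
    if isinstance(value, (list, tuple)):
        if len(value) != 1:
            raise ValueError(
                f"expected {name} to be a scalar or a single-element sequence, "
                f"got {len(value)} elements"
            )
        return value[0]
    return value


print(unwrap_single([3.5], name="output"))
print(unwrap_single(3.5, name="output"))
```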

Comment on lines +488 to +491
if isinstance(loss_history, (list, tuple)):
    loss_history = loss_history[0] if len(loss_history) > 0 else loss_history
if not isinstance(loss_history, jnp.ndarray):
    loss_history = jnp.array(loss_history)

Copilot AI Dec 19, 2025

The fallback for an empty list/tuple returns loss_history unchanged (which would be an empty list), but then the code attempts to convert it to a jnp.array and later access loss_history[0] and loss_history[-1], which would raise IndexError. Consider either validating that loss_history is non-empty before accessing indices, or providing a more meaningful default value.
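Validating before indexing could look like the sketch below. `first_and_last_loss` is a hypothetical helper that works on any sequence of numbers; it is written with plain Python so it runs without JAX, and it assumes the history has already been unwrapped to a flat sequence.

```python
def first_and_last_loss(loss_history):
    """Return (first, last) loss, raising a clear error on an empty history."""
    losses = [float(x) for x in loss_history]
    if not losses:
        raise ValueError("loss_history is empty; did any training step run?")
    return losses[0], losses[-1]


first, last = first_and_last_loss([0.9, 0.5, 0.2])
print(f"loss went from {first} to {last}")
```

The explicit ValueError replaces the bare IndexError that the reader would otherwise hit at loss_history[0].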


3 participants