From 5d4b836c2d3c384d4ebb8b756df52686465ec232 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=89=BE=E5=84=92?= Date: Mon, 10 Nov 2025 11:26:17 +0800 Subject: [PATCH 1/3] feat(jax): Add JAX export kernel --- mplang/kernels/context.py | 2 + mplang/kernels/jax_xla.py | 96 ++++++ mplang/ops/jax_cc.py | 65 +++- pyproject.toml | 11 +- tests/kernels/test_jax_xla.py | 572 ++++++++++++++++++++++++++++++++++ tests/ops/test_jax_cc.py | 120 +++++++ uv.lock | 4 + 7 files changed, 861 insertions(+), 9 deletions(-) create mode 100644 mplang/kernels/jax_xla.py create mode 100644 tests/kernels/test_jax_xla.py diff --git a/mplang/kernels/context.py b/mplang/kernels/context.py index 159da895..631715ad 100644 --- a/mplang/kernels/context.py +++ b/mplang/kernels/context.py @@ -38,6 +38,7 @@ def _ensure_impl_imported() -> None: from mplang.kernels import basic as _impl_basic # noqa: F401 from mplang.kernels import crypto as _impl_crypto # noqa: F401 from mplang.kernels import fhe as _impl_fhe # noqa: F401 + from mplang.kernels import jax_xla as _impl_jax_xla # noqa: F401 from mplang.kernels import mock_tee as _impl_tee # noqa: F401 from mplang.kernels import phe as _impl_phe # noqa: F401 from mplang.kernels import spu as _impl_spu # noqa: F401 @@ -99,6 +100,7 @@ def _ensure_impl_imported() -> None: "spu.run_pphlo": "spu.run_pphlo", # stablehlo "mlir.stablehlo": "mlir.stablehlo", + "jax.exec": "jax.exec", # sql # generic SQL op; backend-specific kernel id for duckdb "sql.run": "duckdb.run_sql", diff --git a/mplang/kernels/jax_xla.py b/mplang/kernels/jax_xla.py new file mode 100644 index 00000000..fa0eb962 --- /dev/null +++ b/mplang/kernels/jax_xla.py @@ -0,0 +1,96 @@ +# Copyright 2025 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import base64 +from typing import Any + +import jax.export as jax_export +import jax.numpy as jnp +import numpy as np + +from mplang.core import PFunction +from mplang.kernels.base import kernel_def +from mplang.kernels.value import TensorValue + + +@kernel_def("jax.exec") +def _jax_exec(pfunc: PFunction, *args: Any) -> Any: + """Execute a JAX exported function. + + Args: + pfunc: PFunction containing serialized JAX export data + *args: Input arguments for the function execution + + Returns: + The result of executing the JAX function with the provided arguments + """ + if pfunc.fn_type != "jax.exec": + raise ValueError(f"jax exec kernel received wrong fn_type: {pfunc.fn_type}") + + export_text = pfunc.fn_text + if export_text is None: + raise ValueError("jax exec kernel missing fn_text") + + try: + export_bytes = base64.b64decode(export_text) + except Exception as e: + raise ValueError(f"Failed to decode base64 export data: {e}") from e + + try: + exported = jax_export.deserialize(bytearray(export_bytes)) + except Exception as e: + raise ValueError(f"Failed to deserialize JAX export: {e}") from e + + # Convert TensorValue arguments to JAX arrays + jax_args = [] + for i, arg in enumerate(args): + if isinstance(arg, TensorValue): + # Convert TensorValue to JAX array + jax_array = jnp.array(arg.to_numpy()) + jax_args.append(jax_array) + elif isinstance(arg, (jnp.ndarray, np.ndarray)): + # Already a JAX/NumPy array + jax_args.append(jnp.array(arg)) + else: + # Try to convert to JAX array + try: + jax_args.append(jnp.array(arg)) + except Exception as e: + raise ValueError( + f"Cannot convert argument {i} of type {type(arg)} to JAX array: {e}" + ) from e + + # Execute the exported function + # The normalized function expects a single list argument containing all variables + try: + result = exported.call(jax_args) + except Exception as e: + raise RuntimeError(f"Failed to execute JAX exported function: {e}") from e + + # Convert result back to TensorValue if it's a JAX array + if isinstance(result, (jnp.ndarray, np.ndarray)): + return TensorValue(np.array(result)) + elif isinstance(result, (tuple, list)): + # Handle multiple outputs + converted_result = [] + for item in result: + if isinstance(item, (jnp.ndarray, np.ndarray)): + converted_result.append(TensorValue(np.array(item))) + else: + converted_result.append(item) + return type(result)(converted_result) + else: + return result diff --git a/mplang/ops/jax_cc.py b/mplang/ops/jax_cc.py index 59ba5013..b6dd389b 100644 --- a/mplang/ops/jax_cc.py +++ b/mplang/ops/jax_cc.py @@ -14,11 +14,13 @@ from __future__ import annotations +import base64 from collections.abc import Callable from typing import Any import jax import jax.numpy as jnp +from jax import export from jax.tree_util import PyTreeDef, tree_flatten from mplang.core import MPObject, PFunction, TensorType, get_fn_name @@ -28,6 +30,62 @@ # Enable 64-bit precision for JAX to match tensor types jax.config.update("jax_enable_x64", True) +USE_JAX_EXPORT = False + + +def jax_export( + is_variable: Callable[[Any], bool], flat_fn: Any, *args: Any, **kwargs: Any +) -> tuple[PFunction, list, PyTreeDef]: + """Compile JAX function to JAX export format for remote execution. + + Args: + is_variable: Predicate function to classify parameters as variables vs. constants. + Returns True for parameters that should be treated as PFunction inputs. + flat_fn: JAX function to be compiled into export format + *args: Positional arguments passed to the function during compilation + **kwargs: Keyword arguments passed to the function during compilation + Returns: + tuple[PFunction, list, PyTreeDef]: Compilation artifacts containing: + - PFunction: Serialized function with embedded export data and type metadata + - list: Extracted variable parameters (those satisfying is_variable predicate). + Non-variable parameters are captured as compile-time constants within + the PFunction body, while variables become runtime input parameters. + - PyTreeDef: Tree structure template for reconstructing nested output values + """ + # Flatten (args, kwargs) and capture immediates using the moved logic from primitive.py + normalized_fn, in_vars = normalize_fn(flat_fn, args, kwargs, is_variable) + + # Convert TensorType in_vars to ShapeDtypeStruct for JAX tracing + jax_params = [ + jax.ShapeDtypeStruct(arg.shape, jnp.dtype(arg.dtype.name)) for arg in in_vars + ] + + # Standard JAX serialization pipeline: jit → trace → lower → export + jitted_fn = jax.jit(normalized_fn) + traced = jitted_fn.trace(jax_params) + lowered = traced.lower() + + # Get JAX export representation - the portable format + exported = export.export(jitted_fn)(jax_params) + + # Get output info and tree structure for result reconstruction after remote execution + out_info_flat, out_tree = tree_flatten(lowered.out_info) + out_info_flat = [TensorType.from_obj(info) for info in out_info_flat] + + # This format tells JaxRT how to handle the compiled result + pfn_kwargs: dict[str, Any] = { + "fn_type": "jax.exec", + "ins_info": tuple(TensorType.from_obj(x) for x in in_vars), + "outs_info": tuple(out_info_flat), + "fn_name": get_fn_name(flat_fn), + "fn_text": base64.b64encode(exported.serialize()).decode( + "utf-8" + ), # Serialized export data, serializable for transmission + } + + pfn = PFunction(**pfn_kwargs) + return pfn, in_vars, out_tree + def jax2stablehlo( is_variable: Callable[[Any], bool], flat_fn: Any, *args: Any, **kwargs: Any @@ -171,7 +229,12 @@ def trace( def is_variable(arg: Any) -> bool: return isinstance(arg, MPObject) - pfunc, in_vars, out_tree = jax2stablehlo(is_variable, jax_fn, *args, **kwargs) + if USE_JAX_EXPORT: + pfunc, in_vars, out_tree = jax_export(is_variable, jax_fn, *args, **kwargs) + else: + pfunc, in_vars, out_tree = jax2stablehlo( + is_variable, jax_fn, *args, **kwargs + ) return pfunc, in_vars, out_tree diff --git a/pyproject.toml b/pyproject.toml index 2fd0460c..0e7c4811 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,20 +35,15 @@ dependencies = [ "fastapi", "uvicorn[standard]", "sqlglot>=23.0.0", + "absl-py>=2.3.1", + "flatbuffers>=25.2.10", ] [project.scripts] mplang-cli = "mplang.runtime.cli:main" [dependency-groups] -dev = [ - "pytest", - "pytest-cov", - "pytest-asyncio", - "ruff", - "mypy", - "pre-commit", -] +dev = ["pytest", "pytest-cov", "pytest-asyncio", "ruff", "mypy", "pre-commit"] examples = [ "scikit-learn", diff --git a/tests/kernels/test_jax_xla.py b/tests/kernels/test_jax_xla.py new file mode 100644 index 00000000..8788ca8f --- /dev/null +++ b/tests/kernels/test_jax_xla.py @@ -0,0 +1,572 @@ +# Copyright 2025 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import base64 + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from mplang.core import PFunction +from mplang.kernels.context import RuntimeContext +from mplang.kernels.jax_xla import _jax_exec +from mplang.kernels.value import TensorValue +from mplang.ops.jax_cc import jax_export + +# Enable 64-bit precision for testing +jax.config.update("jax_enable_x64", True) + + +class TestJaxXla: + """Test suite for JAX XLA kernel functionality.""" + + def setup_method(self): + """Initialize backend context for each test.""" + self.runtime = RuntimeContext(rank=0, world_size=1) + + def _create_test_pfunction(self, fn, *args, **kwargs) -> PFunction: + """Helper to create a PFunction from a JAX function.""" + is_variable = lambda obj: hasattr(obj, "dtype") and hasattr(obj, "shape") + pfunc, _, _ = jax_export(is_variable, fn, *args, **kwargs) + return pfunc + + def _tensor_value(self, arr: np.ndarray | jnp.ndarray) -> TensorValue: + """Helper to create TensorValue from array.""" + return TensorValue(np.array(arr)) + + def test_simple_function_execution(self): + """Test execution of a simple JAX function.""" + + def add_fn(x, y): + return x + y + + # Create test data + x = jnp.array([1.0, 2.0, 3.0]) + y = jnp.array([4.0, 5.0, 6.0]) + + # Create PFunction + pfunc = self._create_test_pfunction(add_fn, x, y) + + # Execute through kernel + result = _jax_exec(pfunc, self._tensor_value(x), self._tensor_value(y)) + + # Verify result + expected = np.array([5.0, 7.0, 9.0]) + assert isinstance(result, TensorValue) + np.testing.assert_array_equal(result.to_numpy(), expected) + + def test_scalar_function_execution(self): + """Test execution with scalar inputs and outputs.""" + + def multiply_fn(x, y): + return x * y + + # Create test data + x = jnp.array(3.0) + y = jnp.array(4.0) + + # Create PFunction + pfunc = self._create_test_pfunction(multiply_fn, x, y) + + # Execute through kernel + result = _jax_exec(pfunc, self._tensor_value(x), self._tensor_value(y)) + + # Verify result + expected = np.array(12.0) + assert isinstance(result, TensorValue) + np.testing.assert_array_equal(result.to_numpy(), expected) + + def test_matrix_operations(self): + """Test execution with matrix operations.""" + + def matrix_mul_fn(x, y): + return jnp.dot(x, y) + + # Create test matrices + x = jnp.array([[1.0, 2.0], [3.0, 4.0]]) + y = jnp.array([[5.0, 6.0], [7.0, 8.0]]) + + # Create PFunction + pfunc = self._create_test_pfunction(matrix_mul_fn, x, y) + + # Execute through kernel + result = _jax_exec(pfunc, self._tensor_value(x), self._tensor_value(y)) + + # Verify result + expected = np.array([[19.0, 22.0], [43.0, 50.0]]) + assert isinstance(result, TensorValue) + np.testing.assert_array_equal(result.to_numpy(), expected) + + def test_multiple_outputs(self): + """Test execution with multiple outputs.""" + + def multi_output_fn(x, y): + return x + y, x - y, x * y + + # Create test data + x = jnp.array([1.0, 2.0]) + y = jnp.array([3.0, 4.0]) + + # Create PFunction + pfunc = self._create_test_pfunction(multi_output_fn, x, y) + + # Execute through kernel + result = _jax_exec(pfunc, self._tensor_value(x), self._tensor_value(y)) + + # Verify results + assert isinstance(result, tuple) + assert len(result) == 3 + + # Check each output + expected_sum = np.array([4.0, 6.0]) + expected_diff = np.array([-2.0, -2.0]) + expected_prod = np.array([3.0, 8.0]) + + np.testing.assert_array_equal(result[0].to_numpy(), expected_sum) + np.testing.assert_array_equal(result[1].to_numpy(), expected_diff) + np.testing.assert_array_equal(result[2].to_numpy(), expected_prod) + + def test_complex_operations(self): + """Test execution with complex JAX operations.""" + + def complex_fn(x): + return jnp.sin(x) * jnp.cos(x) + jnp.sum(x**2) + + # Create test data + x = jnp.array([0.5, 1.0, 1.5]) + + # Create PFunction + pfunc = self._create_test_pfunction(complex_fn, x) + + # Execute through kernel + result = _jax_exec(pfunc, self._tensor_value(x)) + + # Verify result + expected = np.sin(x) * np.cos(x) + np.sum(x**2) + assert isinstance(result, TensorValue) + np.testing.assert_allclose(result.to_numpy(), expected) + + def test_different_dtypes(self): + """Test execution with different data types.""" + + def mixed_dtype_fn(x, y): + return x.astype(jnp.float32) + y.astype(jnp.int32) + + # Create test data with different dtypes + x = jnp.array([1.5, 2.5], dtype=jnp.float64) + y = jnp.array([10, 20], dtype=jnp.int64) + + # Create PFunction + pfunc = self._create_test_pfunction(mixed_dtype_fn, x, y) + + # Execute through kernel + result = _jax_exec(pfunc, self._tensor_value(x), self._tensor_value(y)) + + # Verify result + expected = np.array([11.5, 22.5], dtype=np.float32) + assert isinstance(result, TensorValue) + np.testing.assert_allclose(result.to_numpy(), expected) + + def test_high_dimensional_arrays(self): + """Test execution with high-dimensional arrays.""" + + def reduce_fn(x): + return jnp.sum(x, axis=(1, 2)) + + # Create 4D tensor + x = jnp.ones((2, 3, 4, 5)) + + # Create PFunction + pfunc = self._create_test_pfunction(reduce_fn, x) + + # Execute through kernel + result = _jax_exec(pfunc, self._tensor_value(x)) + + # Verify result shape and values + assert isinstance(result, TensorValue) + assert result.to_numpy().shape == (2, 5) + expected = np.full((2, 5), 12.0) # 3*4 = 12 for each sum + np.testing.assert_array_equal(result.to_numpy(), expected) + + def test_error_handling_wrong_fn_type(self): + """Test error handling for wrong function type.""" + # Create PFunction with wrong type + pfunc = PFunction( + fn_type="wrong.type", + ins_info=(), + outs_info=(), + fn_name="test", + fn_text="dummy", + ) + + with pytest.raises(ValueError, match="jax exec kernel received wrong fn_type"): + _jax_exec(pfunc) + + def test_error_handling_missing_fn_text(self): + """Test error handling for missing function text.""" + pfunc = PFunction( + fn_type="jax.exec", ins_info=(), outs_info=(), fn_name="test", fn_text=None + ) + + with pytest.raises(ValueError, match="jax exec kernel missing fn_text"): + _jax_exec(pfunc) + + def test_error_handling_invalid_base64(self): + """Test error handling for invalid base64 data.""" + pfunc = PFunction( + fn_type="jax.exec", + ins_info=(), + outs_info=(), + fn_name="test", + fn_text="invalid_base64!", + ) + + with pytest.raises(ValueError, match="Failed to decode base64 export data"): + _jax_exec(pfunc) + + def test_error_handling_invalid_serialized_data(self): + """Test error handling for invalid serialized JAX data.""" + # Create valid base64 but invalid JAX data + invalid_data = base64.b64encode(b"invalid_jax_data").decode() + pfunc = PFunction( + fn_type="jax.exec", + ins_info=(), + outs_info=(), + fn_name="test", + fn_text=invalid_data, + ) + + with pytest.raises(ValueError, match="Failed to deserialize JAX export"): + _jax_exec(pfunc) + + def test_error_handling_invalid_argument_type(self): + """Test error handling for invalid argument types.""" + + def simple_fn(x): + return x + 1 + + x = jnp.array(1.0) + pfunc = self._create_test_pfunction(simple_fn, x) + + # Pass invalid argument + with pytest.raises(ValueError, match="Cannot convert argument 0 of type"): + _jax_exec(pfunc, "invalid_argument") + + def test_argument_conversion_numpy_array(self): + """Test that numpy arrays are properly converted.""" + + def add_fn(x, y): + return x + y + + # Create test data as numpy arrays + x = np.array([1.0, 2.0]) + y = np.array([3.0, 4.0]) + + # Create PFunction + pfunc = self._create_test_pfunction(add_fn, jnp.array(x), jnp.array(y)) + + # Execute with numpy arrays + result = _jax_exec(pfunc, self._tensor_value(x), self._tensor_value(y)) + + # Verify result + expected = np.array([4.0, 6.0]) + assert isinstance(result, TensorValue) + np.testing.assert_array_equal(result.to_numpy(), expected) + + def test_argument_conversion_jax_array(self): + """Test that JAX arrays are properly converted.""" + + def add_fn(x, y): + return x + y + + # Create test data as JAX arrays + x = jnp.array([1.0, 2.0]) + y = jnp.array([3.0, 4.0]) + + # Create PFunction + pfunc = self._create_test_pfunction(add_fn, x, y) + + # Execute with JAX arrays converted to TensorValue + result = _jax_exec(pfunc, self._tensor_value(x), self._tensor_value(y)) + + # Verify result + expected = np.array([4.0, 6.0]) + assert isinstance(result, TensorValue) + np.testing.assert_array_equal(result.to_numpy(), expected) + + def test_pfunction_properties(self): + """Test that PFunction has correct properties after creation.""" + + def add_fn(x, y): + return x + y + + x = jnp.array([1.0, 2.0]) + y = jnp.array([3.0, 4.0]) + + # Create PFunction + pfunc = self._create_test_pfunction(add_fn, x, y) + + # Verify properties + assert pfunc.fn_type == "jax.exec" + assert pfunc.fn_name == "add_fn" + assert len(pfunc.ins_info) == 2 + assert len(pfunc.outs_info) == 1 + assert pfunc.fn_text is not None + assert len(pfunc.fn_text) > 0 + + # Verify input tensor info + assert pfunc.ins_info[0].shape == x.shape + assert pfunc.ins_info[1].shape == y.shape + + @pytest.mark.parametrize("shape", [(1,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]) + def test_various_shapes(self, shape): + """Test execution with various tensor shapes.""" + + def identity_fn(x): + return x + + # Create test data with specified shape + x = jnp.ones(shape) + + # Create PFunction + pfunc = self._create_test_pfunction(identity_fn, x) + + # Execute through kernel + result = _jax_exec(pfunc, self._tensor_value(x)) + + # Verify result + assert isinstance(result, TensorValue) + assert result.to_numpy().shape == shape + np.testing.assert_array_equal(result.to_numpy(), np.ones(shape)) + + def test_empty_arrays(self): + """Test execution with empty arrays.""" + + def empty_fn(x): + return jnp.sum(x) + + # Create empty array + x = jnp.array([]) + + # Create PFunction + pfunc = self._create_test_pfunction(empty_fn, x) + + # Execute through kernel + result = _jax_exec(pfunc, self._tensor_value(x)) + + # Verify result (sum of empty array should be 0.0) + assert isinstance(result, TensorValue) + np.testing.assert_array_equal(result.to_numpy(), np.array(0.0)) + + def test_large_arrays(self): + """Test execution with large arrays.""" + + def large_sum_fn(x): + return jnp.sum(x) + + # Create large array + x = jnp.ones(10000) + + # Create PFunction + pfunc = self._create_test_pfunction(large_sum_fn, x) + + # Execute through kernel + result = _jax_exec(pfunc, self._tensor_value(x)) + + # Verify result + assert isinstance(result, TensorValue) + np.testing.assert_array_equal(result.to_numpy(), np.array(10000.0)) + + +class TestJaxExportFullPipeline: + """Test suite for complete frontend-to-backend pipeline.""" + + def test_full_pipeline_simple_function(self): + """Test complete pipeline with a simple function.""" + + def add_fn(x, y): + return x + y + + # Frontend: Generate PFunction using jax_export + x = jnp.array([1.0, 2.0, 3.0]) + y = jnp.array([4.0, 5.0, 6.0]) + + is_variable = lambda obj: hasattr(obj, "dtype") and hasattr(obj, "shape") + pfunc, _in_vars, _out_tree = jax_export(is_variable, add_fn, x, y) + + # Backend: Execute using _jax_exec + result = _jax_exec(pfunc, TensorValue(np.array(x)), TensorValue(np.array(y))) + + # Verify result + expected = np.array([5.0, 7.0, 9.0]) + assert isinstance(result, TensorValue) + np.testing.assert_array_equal(result.to_numpy(), expected) + + def test_full_pipeline_complex_function(self): + """Test complete pipeline with a complex function.""" + + def complex_fn(x, y): + return jnp.sin(x) * jnp.cos(x) + jnp.sum(x**2 + y**2) + + # Frontend: Generate PFunction + x = jnp.array([0.5, 1.0, 1.5]) + y = jnp.array([0.3, 0.7, 1.1]) + + is_variable = lambda obj: hasattr(obj, "dtype") and hasattr(obj, "shape") + pfunc, _in_vars, _out_tree = jax_export(is_variable, complex_fn, x, y) + + # Backend: Execute + result = _jax_exec(pfunc, TensorValue(np.array(x)), TensorValue(np.array(y))) + + # Verify result + expected = np.sin(x) * np.cos(x) + np.sum(x**2 + y**2) + assert isinstance(result, TensorValue) + np.testing.assert_allclose(result.to_numpy(), expected) + + def test_full_pipeline_multiple_outputs(self): + """Test complete pipeline with multiple outputs.""" + + def multi_output_fn(x, y): + return x + y, x - y, x * y + + # Frontend: Generate PFunction + x = jnp.array([1.0, 2.0]) + y = jnp.array([3.0, 4.0]) + + is_variable = lambda obj: hasattr(obj, "dtype") and hasattr(obj, "shape") + pfunc, _in_vars, _out_tree = jax_export(is_variable, multi_output_fn, x, y) + + # Backend: Execute + result = _jax_exec(pfunc, TensorValue(np.array(x)), TensorValue(np.array(y))) + + # Verify multiple outputs + assert isinstance(result, tuple) + assert len(result) == 3 + + expected_sum = np.array([4.0, 6.0]) + expected_diff = np.array([-2.0, -2.0]) + expected_prod = np.array([3.0, 8.0]) + + np.testing.assert_array_equal(result[0].to_numpy(), expected_sum) + np.testing.assert_array_equal(result[1].to_numpy(), expected_diff) + np.testing.assert_array_equal(result[2].to_numpy(), expected_prod) + + def test_full_pipeline_with_constants(self): + """Test complete pipeline with constants captured during compilation.""" + + def scale_and_shift(x, scale=2.0, shift=1.0): + return x * scale + shift + + # Frontend: Generate PFunction with constants + x = jnp.array([1.0, 2.0, 3.0]) + + is_variable = lambda obj: hasattr(obj, "dtype") and hasattr(obj, "shape") + pfunc, _in_vars, _out_tree = jax_export( + is_variable, scale_and_shift, x, scale=5.0, shift=10.0 + ) + + # Backend: Execute (only need to provide variable arguments) + result = _jax_exec(pfunc, TensorValue(np.array(x))) + + # Verify result with constants applied + expected = x * 5.0 + 10.0 + assert isinstance(result, TensorValue) + np.testing.assert_array_equal(result.to_numpy(), expected) + + def test_full_pipeline_high_dimensional(self): + """Test complete pipeline with high-dimensional arrays.""" + + def tensor_operation(x): + # Reduce along specific axes and compute statistics + mean_val = jnp.mean(x, axis=(1, 2)) + std_val = jnp.std(x, axis=(1, 2)) + return mean_val, std_val + + # Frontend: Generate PFunction + x = jnp.ones((2, 3, 4, 5)) + + is_variable = lambda obj: hasattr(obj, "dtype") and hasattr(obj, "shape") + pfunc, _in_vars, _out_tree = jax_export(is_variable, tensor_operation, x) + + # Backend: Execute + result = _jax_exec(pfunc, TensorValue(np.array(x))) + + # Verify results + assert isinstance(result, tuple) + assert len(result) == 2 + + # For ones array, mean should be 1.0, std should be 0.0 + expected_mean = np.ones((2, 5)) + expected_std = np.zeros((2, 5)) + + np.testing.assert_allclose(result[0].to_numpy(), expected_mean, atol=1e-6) + np.testing.assert_allclose(result[1].to_numpy(), expected_std, atol=1e-6) + + def test_full_pipeline_serialization_roundtrip(self): + """Test that PFunction properties survive basic attribute access.""" + + def simple_fn(x): + return jnp.sum(x) + + # Frontend: Generate PFunction + x = jnp.array([1.0, 2.0, 3.0]) + + is_variable = lambda obj: hasattr(obj, "dtype") and hasattr(obj, "shape") + pfunc, _, _ = jax_export(is_variable, simple_fn, x) + + # Test that key properties are accessible (partial serialization test) + fn_type = pfunc.fn_type + fn_name = pfunc.fn_name + fn_text = pfunc.fn_text + ins_info = pfunc.ins_info + outs_info = pfunc.outs_info + + # Verify properties are preserved + assert fn_type == "jax.exec" + assert fn_name == "simple_fn" + assert fn_text is not None + assert len(ins_info) == 1 + assert len(outs_info) == 1 + + # Backend: Execute with original PFunction + result = _jax_exec(pfunc, TensorValue(np.array(x))) + + # Verify result + expected = np.array(6.0) + assert isinstance(result, TensorValue) + np.testing.assert_array_equal(result.to_numpy(), expected) + + def test_full_pipeline_robustness(self): + """Test that the pipeline handles edge cases gracefully.""" + + def edge_case_fn(x): + # Test with operations that might be problematic + return jnp.sqrt(jnp.abs(x)) # sqrt of absolute value (always safe) + + # Frontend: Generate PFunction + x = jnp.array([-1.0, 0.0, 1.0]) + + is_variable = lambda obj: hasattr(obj, "dtype") and hasattr(obj, "shape") + pfunc, _, _ = jax_export(is_variable, edge_case_fn, x) + + # Backend: Execute should work fine + result = _jax_exec(pfunc, TensorValue(np.array(x))) + + # Verify result - sqrt of absolute values + expected = np.sqrt(np.abs([-1.0, 0.0, 1.0])) + assert isinstance(result, TensorValue) + np.testing.assert_allclose(result.to_numpy(), expected) diff --git a/tests/ops/test_jax_cc.py b/tests/ops/test_jax_cc.py index ac3771a0..cd2fd5b3 100644 --- a/tests/ops/test_jax_cc.py +++ b/tests/ops/test_jax_cc.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import annotations +import base64 import jax import jax.numpy as jnp @@ -278,3 +279,122 @@ def func_with_unused(x, unused, z): assert isinstance(keep_map, list) assert len(keep_map) < 3 # Should be fewer than original 3 params assert 1 not in keep_map # Index 1 (unused) should not be kept + + +class TestJaxExportFrontend: + """Test suite for JAX export frontend generation.""" + + def test_jax_export_basic_function(self): + """Test jax_export with a basic function.""" + + def add_fn(x, y): + return x + y + + # Test data + x = jnp.array([1.0, 2.0, 3.0]) + y = jnp.array([4.0, 5.0, 6.0]) + + # Use jax_export directly + is_variable = lambda obj: hasattr(obj, "dtype") and hasattr(obj, "shape") + pfunc, in_vars, _out_tree = jax_cc.jax_export(is_variable, add_fn, x, y) + + # Verify PFunction properties + assert pfunc.fn_type == "jax.exec" + assert pfunc.fn_name == "add_fn" + assert len(pfunc.ins_info) == 2 + assert len(pfunc.outs_info) == 1 + assert pfunc.fn_text is not None + assert len(pfunc.fn_text) > 0 + + # Verify input info + assert pfunc.ins_info[0].shape == x.shape + assert pfunc.ins_info[1].shape == y.shape + + # Verify that fn_text is valid base64 + try: + base64.b64decode(pfunc.fn_text) + except Exception: + pytest.fail("fn_text should be valid base64") + + # Verify in_vars + assert len(in_vars) == 2 + assert in_vars[0].shape == x.shape + assert in_vars[1].shape == y.shape + + def test_jax_export_with_constants(self): + """Test jax_export with mixed variables and constants.""" + + def multiply_with_constant(x, factor=2.0): + return x * factor + + # Test data + x = jnp.array([1.0, 2.0, 3.0]) + + # Use jax_export with constant + is_variable = lambda obj: hasattr(obj, "dtype") and hasattr(obj, "shape") + pfunc, in_vars, _out_tree = jax_cc.jax_export( + is_variable, multiply_with_constant, x, factor=5.0 + ) + + # Only x should be a variable, factor should be captured as constant + assert len(in_vars) == 1 + assert len(pfunc.ins_info) == 1 + assert pfunc.ins_info[0].shape == x.shape + + def test_jax_export_complex_function(self): + """Test jax_export with a complex function.""" + + def complex_fn(x, y): + intermediate = jnp.sin(x) * jnp.cos(y) + result = jnp.sum(intermediate) + jnp.mean(x**2 + y**2) + return result, intermediate + + # Test data + x = jnp.array([0.5, 1.0, 1.5]) + y = jnp.array([0.3, 0.7, 1.1]) + + # Use jax_export + is_variable = lambda obj: hasattr(obj, "dtype") and hasattr(obj, "shape") + pfunc, _in_vars, _out_tree = jax_cc.jax_export(is_variable, complex_fn, x, y) + + # Verify multiple outputs + assert len(pfunc.outs_info) == 2 + assert pfunc.outs_info[0].shape == () # scalar result + assert pfunc.outs_info[1].shape == x.shape # intermediate array + + def test_jax_export_different_dtypes(self): + """Test jax_export with different data types.""" + + def mixed_dtype_fn(x, y): + return x.astype(jnp.float32) + y.astype(jnp.int32) + + # Test data with different dtypes + x = jnp.array([1.5, 2.5], dtype=jnp.float64) + y = jnp.array([10, 20], dtype=jnp.int64) + + # Use jax_export + is_variable = lambda obj: hasattr(obj, "dtype") and hasattr(obj, "shape") + pfunc, _in_vars, _out_tree = jax_cc.jax_export( + is_variable, mixed_dtype_fn, x, y + ) + + # Verify dtype preservation in input info + assert pfunc.ins_info[0].dtype.name == "float64" + assert pfunc.ins_info[1].dtype.name == "int64" + + def test_jax_export_high_dimensional(self): + """Test jax_export with high-dimensional arrays.""" + + def reduce_fn(x): + return jnp.sum(x, axis=(1, 2)) + + # Create 4D tensor + x = jnp.ones((2, 3, 4, 5)) + + # Use jax_export + is_variable = lambda obj: hasattr(obj, "dtype") and hasattr(obj, "shape") + pfunc, _in_vars, _out_tree = jax_cc.jax_export(is_variable, reduce_fn, x) + + # Verify shapes + assert pfunc.ins_info[0].shape == (2, 3, 4, 5) + assert pfunc.outs_info[0].shape == (2, 5) diff --git a/uv.lock b/uv.lock index e08ffd6b..17c3e3b7 100644 --- a/uv.lock +++ b/uv.lock @@ -868,8 +868,10 @@ wheels = [ name = "mplang" source = { editable = "." } dependencies = [ + { name = "absl-py" }, { name = "duckdb" }, { name = "fastapi" }, + { name = "flatbuffers" }, { name = "httpx" }, { name = "lightphe" }, { name = "pandas" }, @@ -899,8 +901,10 @@ examples = [ [package.metadata] requires-dist = [ + { name = "absl-py", specifier = ">=2.3.1" }, { name = "duckdb", specifier = ">=1.0.0" }, { name = "fastapi" }, + { name = "flatbuffers", specifier = ">=25.2.10" }, { name = "httpx", specifier = ">=0.27.0" }, { name = "lightphe", specifier = ">=0.0.15,<0.1.0" }, { name = "pandas", specifier = ">=2.0.0" }, From 33bd43c78ab53bc75bd12ef596f0e8848355253a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=89=BE=E5=84=92?= Date: Mon, 10 Nov 2025 11:32:07 +0800 Subject: [PATCH 2/3] lint --- tests/ops/test_jax_cc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/ops/test_jax_cc.py b/tests/ops/test_jax_cc.py index cd2fd5b3..7ee6d062 100644 --- a/tests/ops/test_jax_cc.py +++ b/tests/ops/test_jax_cc.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import annotations + import base64 import jax From e823442141ed1aa40774c78a912d22b1459b6c0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=89=BE=E5=84=92?= Date: Mon, 10 Nov 2025 11:35:16 +0800 Subject: [PATCH 3/3] fix cr --- mplang/kernels/jax_xla.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/mplang/kernels/jax_xla.py b/mplang/kernels/jax_xla.py index fa0eb962..301f1719 100644 --- a/mplang/kernels/jax_xla.py +++ b/mplang/kernels/jax_xla.py @@ -46,7 +46,7 @@ def _jax_exec(pfunc: PFunction, *args: Any) -> Any: try: export_bytes = base64.b64decode(export_text) - except Exception as e: + except ValueError as e: raise ValueError(f"Failed to decode base64 export data: {e}") from e try: @@ -57,21 +57,16 @@ def _jax_exec(pfunc: PFunction, *args: Any) -> Any: # Convert TensorValue arguments to JAX arrays jax_args = [] for i, arg in enumerate(args): + value_to_convert = arg if isinstance(arg, TensorValue): - # Convert TensorValue to JAX array - jax_array = jnp.array(arg.to_numpy()) - jax_args.append(jax_array) - elif isinstance(arg, (jnp.ndarray, np.ndarray)): - # Already a JAX/NumPy array - jax_args.append(jnp.array(arg)) - else: - # Try to convert to JAX array - try: - jax_args.append(jnp.array(arg)) - except Exception as e: - raise ValueError( - f"Cannot convert argument {i} of type {type(arg)} to JAX array: {e}" - ) from e + value_to_convert = arg.to_numpy() + + try: + jax_args.append(jnp.array(value_to_convert)) + except Exception as e: + raise ValueError( + f"Cannot convert argument {i} of type {type(arg)} to JAX array: {e}" + ) from e # Execute the exported function # The normalized function expects a single list argument containing all variables