Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mplang/kernels/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def _ensure_impl_imported() -> None:
from mplang.kernels import basic as _impl_basic # noqa: F401
from mplang.kernels import crypto as _impl_crypto # noqa: F401
from mplang.kernels import fhe as _impl_fhe # noqa: F401
from mplang.kernels import jax_xla as _impl_jax_xla # noqa: F401
from mplang.kernels import mock_tee as _impl_tee # noqa: F401
from mplang.kernels import phe as _impl_phe # noqa: F401
from mplang.kernels import spu as _impl_spu # noqa: F401
Expand Down Expand Up @@ -99,6 +100,7 @@ def _ensure_impl_imported() -> None:
"spu.run_pphlo": "spu.run_pphlo",
# stablehlo
"mlir.stablehlo": "mlir.stablehlo",
"jax.exec": "jax.exec",
# sql
# generic SQL op; backend-specific kernel id for duckdb
"sql.run": "duckdb.run_sql",
Expand Down
91 changes: 91 additions & 0 deletions mplang/kernels/jax_xla.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2025 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import base64
from typing import Any

import jax.export as jax_export
import jax.numpy as jnp
import numpy as np

from mplang.core import PFunction
from mplang.kernels.base import kernel_def
from mplang.kernels.value import TensorValue


@kernel_def("jax.exec")
def _jax_exec(pfunc: PFunction, *args: Any) -> Any:
"""Execute a JAX exported function.

Args:
pfunc: PFunction containing serialized JAX export data
*args: Input arguments for the function execution

Returns:
The result of executing the JAX function with the provided arguments
"""
if pfunc.fn_type != "jax.exec":
raise ValueError(f"jax exec kernel received wrong fn_type: {pfunc.fn_type}")

export_text = pfunc.fn_text
if export_text is None:
raise ValueError("jax exec kernel missing fn_text")

try:
export_bytes = base64.b64decode(export_text)
except ValueError as e:
raise ValueError(f"Failed to decode base64 export data: {e}") from e

try:
exported = jax_export.deserialize(bytearray(export_bytes))
except Exception as e:
raise ValueError(f"Failed to deserialize JAX export: {e}") from e

# Convert TensorValue arguments to JAX arrays
jax_args = []
for i, arg in enumerate(args):
value_to_convert = arg
if isinstance(arg, TensorValue):
value_to_convert = arg.to_numpy()

try:
jax_args.append(jnp.array(value_to_convert))
except Exception as e:
raise ValueError(
f"Cannot convert argument {i} of type {type(arg)} to JAX array: {e}"
) from e

# Execute the exported function
# The normalized function expects a single list argument containing all variables
try:
result = exported.call(jax_args)
except Exception as e:
raise RuntimeError(f"Failed to execute JAX exported function: {e}") from e

# Convert result back to TensorValue if it's a JAX array
if isinstance(result, (jnp.ndarray, np.ndarray)):
return TensorValue(np.array(result))
elif isinstance(result, (tuple, list)):
# Handle multiple outputs
converted_result = []
for item in result:
if isinstance(item, (jnp.ndarray, np.ndarray)):
converted_result.append(TensorValue(np.array(item)))
else:
converted_result.append(item)
return type(result)(converted_result)
else:
return result
Comment thread
oeqqwq marked this conversation as resolved.
65 changes: 64 additions & 1 deletion mplang/ops/jax_cc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@

from __future__ import annotations

import base64
from collections.abc import Callable
from typing import Any

import jax
import jax.numpy as jnp
from jax import export
from jax.tree_util import PyTreeDef, tree_flatten

from mplang.core import MPObject, PFunction, TensorType, get_fn_name
Expand All @@ -28,6 +30,62 @@
# Enable 64-bit precision for JAX to match tensor types
jax.config.update("jax_enable_x64", True)

USE_JAX_EXPORT = False


def jax_export(
is_variable: Callable[[Any], bool], flat_fn: Any, *args: Any, **kwargs: Any
) -> tuple[PFunction, list, PyTreeDef]:
"""Compile JAX function to JAX export format for remote execution.

Args:
is_variable: Predicate function to classify parameters as variables vs. constants.
Returns True for parameters that should be treated as PFunction inputs.
flat_fn: JAX function to be compiled into export format
*args: Positional arguments passed to the function during compilation
**kwargs: Keyword arguments passed to the function during compilation
Returns:
tuple[PFunction, list, PyTreeDef]: Compilation artifacts containing:
- PFunction: Serialized function with embedded export data and type metadata
- list: Extracted variable parameters (those satisfying is_variable predicate).
Non-variable parameters are captured as compile-time constants within
the PFunction body, while variables become runtime input parameters.
- PyTreeDef: Tree structure template for reconstructing nested output values
"""
# Flatten (args, kwargs) and capture immediates using the moved logic from primitive.py
normalized_fn, in_vars = normalize_fn(flat_fn, args, kwargs, is_variable)

# Convert TensorType in_vars to ShapeDtypeStruct for JAX tracing
jax_params = [
jax.ShapeDtypeStruct(arg.shape, jnp.dtype(arg.dtype.name)) for arg in in_vars
]

# Standard JAX serialization pipeline: jit → trace → lower → export
jitted_fn = jax.jit(normalized_fn)
traced = jitted_fn.trace(jax_params)
lowered = traced.lower()

# Get JAX export representation - the portable format
exported = export.export(jitted_fn)(jax_params)

# Get output info and tree structure for result reconstruction after remote execution
out_info_flat, out_tree = tree_flatten(lowered.out_info)
out_info_flat = [TensorType.from_obj(info) for info in out_info_flat]

# This format tells JaxRT how to handle the compiled result
pfn_kwargs: dict[str, Any] = {
"fn_type": "jax.exec",
"ins_info": tuple(TensorType.from_obj(x) for x in in_vars),
"outs_info": tuple(out_info_flat),
"fn_name": get_fn_name(flat_fn),
"fn_text": base64.b64encode(exported.serialize()).decode(
"utf-8"
), # Serialized export data, serializable for transmission
}

pfn = PFunction(**pfn_kwargs)
return pfn, in_vars, out_tree


def jax2stablehlo(
is_variable: Callable[[Any], bool], flat_fn: Any, *args: Any, **kwargs: Any
Expand Down Expand Up @@ -171,7 +229,12 @@ def trace(
def is_variable(arg: Any) -> bool:
return isinstance(arg, MPObject)

pfunc, in_vars, out_tree = jax2stablehlo(is_variable, jax_fn, *args, **kwargs)
if USE_JAX_EXPORT:
pfunc, in_vars, out_tree = jax_export(is_variable, jax_fn, *args, **kwargs)
else:
pfunc, in_vars, out_tree = jax2stablehlo(
is_variable, jax_fn, *args, **kwargs
)
return pfunc, in_vars, out_tree


Expand Down
11 changes: 3 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,15 @@ dependencies = [
"fastapi",
"uvicorn[standard]",
"sqlglot>=23.0.0",
"absl-py>=2.3.1",
"flatbuffers>=25.2.10",
]

[project.scripts]
mplang-cli = "mplang.runtime.cli:main"

[dependency-groups]
dev = [
"pytest",
"pytest-cov",
"pytest-asyncio",
"ruff",
"mypy",
"pre-commit",
]
dev = ["pytest", "pytest-cov", "pytest-asyncio", "ruff", "mypy", "pre-commit"]

examples = [
"scikit-learn",
Expand Down
Loading
Loading