Skip to content

Basic io_callback implementation#1752

Closed
Chapaman wants to merge 11 commits into
elixir-nx:mainfrom
Chapaman:io_callback-implementation
Closed

Basic io_callback implementation#1752
Chapaman wants to merge 11 commits into
elixir-nx:mainfrom
Chapaman:io_callback-implementation

Conversation

@Chapaman

@Chapaman Chapaman commented May 21, 2026

Copy link
Copy Markdown
Contributor

closes #1672

First iteration of Nx.io_callback/2: side-effect host callbacks with passthrough inputs/outputs, lowered to stablehlo.custom_call (exla_io_callback). Hooks and print_value now use io_callback instead of the outfeed/token hook path.

  • Nx: new :io_callback expr op ({:fn, fun} / {:hook, name, callback}), public Nx.io_callback/2, evaluator, tree/grad
  • Hooks: hook/3, print_value, and attach_token lowered via io_callback; hook_token + Nx.Defn.Token kept only for the shared-token / attach_token pattern
  • EXLA: Value.io_callback/3, compiler lowering (host + CUDA), Outfeed {:exla_io_callback, ...} handler, C++ FFI (io_callback.cc / io_callback_cuda.cc)
  • Hook outfeed removed: EXLA no longer lowers :token / :attach_token for hooks (raises if legacy nodes appear); stablehlo.create_token only when infeeds need it
  • Zero-copy: output_operand_aliases per leaf (PID + tensors: output[0]operand[0], output[i+1]operand[i+1]; operand 0 = callback server PID)
  • Graph retention: reassign the result (x = Nx.io_callback(x, ...)); Nx DCE runs before XLA sees has_side_effect
  • Ordered execution: PID tensor threading between io_callbacks — each custom call returns the PID (operand 0 → result 0) and the compiler chains it into the next call, so XLA cannot reorder independent side effects; if/cond branches reset to a dominating PID (no StableHLO tokens, no evaluator changes)

Not in this PR (yet)

  • Removing Nx.Defn.Token, create_token, hook_token, and the attach_token API

@polvalente polvalente left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This generally looks good. Super exciting that we're progressing on this

// Result buffers are aliased to input buffers via output_operand_aliases —
// we must NOT touch them here. Pass no output buffers to the bridge.
exla::callback_bridge::Result result =
exla::callback_bridge::InvokeIoCallback(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
exla::callback_bridge::InvokeIoCallback(
exla::callback_bridge::InvokeIOCallback(

nitpick on naming
I generally keep all letters in an acronym capital (HTTP vs Http, GRPC vs Grpc, etc)

Comment on lines +486 to +516
defp send_io_callback_reply(hooks, io_callbacks, callback_id, args_spec, reply_tag) do
reply =
try do
case Map.fetch(io_callbacks, callback_id) do
{:ok, {callback_spec, arg_template}} ->
case decode_callback_args(args_spec, arg_template) do
{:ok, tensor_args} ->
case resolve_io_callback_hook(callback_spec, hooks) do
nil ->
{:ok, []}

fun ->
try do
fun.(tensor_args)
{:ok, []}
rescue
exception ->
{:error, {:exception, Exception.message(exception)}}
catch
kind, reason ->
{:error, {kind, format_runtime_callback_reason(reason)}}
end
end

{:error, _} = error ->
error
end

:error ->
Logger.error(
"EXLA.Outfeed received io_callback id #{inspect(callback_id)} that is not registered"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we try to unnest this into a with?

Comment thread exla/lib/exla/mlir/value.ex Outdated
Comment on lines +848 to +857
leaf_count = length(typespecs)

{callback_id_words, callback_id_size} = term_to_int64_list(callback_id)

# Each leaf at operand[i+1] (skipping PID at 0) is aliased to result[i].
# The position in the list determines which result index is aliased.
aliases =
Enum.map(0..(leaf_count - 1)//1, fn i ->
attr_output_operand_alias(i + 1)
end)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
leaf_count = length(typespecs)
{callback_id_words, callback_id_size} = term_to_int64_list(callback_id)
# Each leaf at operand[i+1] (skipping PID at 0) is aliased to result[i].
# The position in the list determines which result index is aliased.
aliases =
Enum.map(0..(leaf_count - 1)//1, fn i ->
attr_output_operand_alias(i + 1)
end)
{callback_id_words, callback_id_size} = term_to_int64_list(callback_id)
# Each leaf at operand[i+1] (skipping PID at 0) is aliased to result[i].
# The position in the list determines which result index is aliased.
aliases =
Enum.with_index(typespecs, fn _typespec, index ->
attr_output_operand_alias(i + 1)
end)

Comment thread exla/lib/exla/mlir/value.ex Outdated
# `operand_index` is the 0-based index into the custom_call's operand list.
# Note: operand 0 is the PID, so the first leaf is at operand_index 1.
# The output index is implicit: aliases[i] refers to result i.
# Both tuple-indices lists are empty because we operate on flat tensors.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Both tuple-indices lists are empty because we operate on flat tensors.
# Both tuple-indices lists are empty because we operate on flat tensor lists.

Comment thread exla/lib/exla/defn.ex Outdated
Comment on lines +859 to +864
tensor_exprs = Composite.flatten_list([tensor_expr])

{arg_values, cache} =
Enum.map_reduce(tensor_exprs, cache, fn arg, cache ->
recur_operator(arg, state, cache) |> unwrap_single_tensor!()
end)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be done through a single Composite.reduce that accumulates both reverse_arg_values and cache

Comment thread nx/lib/nx/defn/expr.ex

defp io_callback_depend_pair(acc, hooked) do
zero = Nx.tensor(0, type: acc.type)
Nx.add(acc, Nx.multiply(zero, Nx.sum(hooked)))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if this is the best solution. Maybe we can add a metadata node or introduce a new expression node that can carry the nodes + attached callback?

@josevalim do you have suggestions?

Comment thread nx/lib/nx.ex
backend = Nx.Shared.list_impl!(tensors)

if backend == Nx.Defn.Expr do
backend.io_callback(tensor_or_container, callback)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
backend.io_callback(tensor_or_container, callback)
Nx.Defn.Expr.io_callback(tensor_or_container, callback)

Comment on lines -431 to 436
assert inspect(result, safe: false) == """
#Nx.Tensor<
f32
\s\s
Nx.Defn.Expr
parameter a:0 f32
parameter b:1 f32
c = multiply a, b f32
d = add a, b f32
e = token mult: c, add: d tuple2
f = attach_token e, d f32
g = attach_token e, c f32
h = subtract f, g f32
>\
"""
assert rendered =~ "io_callback"
assert rendered =~ "subtract"
refute rendered =~ "token"
refute rendered =~ "attach_token"
end

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's try to keep the same type of verbose assertions

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

Comment thread nx/lib/nx.ex
Comment on lines +2240 to +2244
an explicit dependency edge. On EXLA, independent callbacks are additionally
serialized by threading the callback-server PID through each `io_callback`
custom call (operand 0 → result 0), so XLA cannot reorder side effects even
when their tensor inputs are unrelated. No StableHLO token machinery is
required at the Nx expression level.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not discuss StableHLO on user facing APIs, as it is an implementation detail.

I am also not sure the EXLA note is relevant here. We say they are serialised but that means they do not run concurrently, it is not ultimately about order. I think we should rather say: "Ordering between IO callbacks are not guarantee, unless they are chained. For example, if you have this expression, io_callback(a) + io_callback(b), there is no guarantee about order. But this one has: io_callback(a + io_callback(b)).

Comment thread nx/lib/nx.ex
Comment on lines +2252 to +2256
> #### Backend transfers {: .warning}
>
> When executing inside `Nx.Defn.Evaluator`, do not transfer tensors with
> `Nx.backend_transfer/2` inside the callback because the values may still be
> used in the rest of the computation. Use `Nx.backend_copy/2` instead.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So maybe we should emit it as a binary reference pointing to device? This way you are forced to copy it (and it solves both the lock and backend transfer issue above).

Comment thread nx/lib/nx/defn/kernel.ex
token = create_token()
{token, _} = hook_token(token, fun.(expr), &IO.inspect(&1, opts))
attach_token(token, expr)
Nx.io_callback(expr, fn t -> IO.inspect(fun.(t), opts) end)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering if we should call it io_call for consisatency with runtime_call?

Comment thread nx/lib/nx/defn/expr.ex
Comment on lines +1537 to +1543
defp normalize_io_callback_spec(fun) when is_function(fun, 1), do: {:fn, fun}

defp normalize_io_callback_spec(name) when is_atom(name), do: {:hook, name, nil}

defp normalize_io_callback_spec({:hook, name, callback})
when is_atom(name) and (is_function(callback, 1) or is_nil(callback)),
do: {:hook, name, callback}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we normalizing it here? Shouldn't we normalize in the callers of Expr.io_callback, since the callers already have the exact shape?

Comment thread nx/lib/nx/defn/expr.ex
Comment on lines +1473 to +1475
The user **must** reassign the result. Nx-level DCE prunes unreachable nodes
before any compiler sees `has_side_effect`, so discarding the return value
silently drops the callback from the graph.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be talking about implementation details of EXLA?

Comment thread nx/lib/nx/defn/expr.ex
end

defp io_callback_spec_inspect({:hook, name, _}), do: inspect({:hook, name, nil})
defp io_callback_spec_inspect({:fn, _fun}), do: "&fn"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can inspect the whole fun, no?


compatible? =
case mode do
:io_callback -> true

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we need to check for compatibility here?

{:error, {:exception, Exception.message(exception)}}
catch
kind, reason ->
send(self(), :stop)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are sending stop, but will this effectively stop the execution on the callback too? Can we have tests for this scenario?

catch
kind, reason ->
send(self(), :stop)
{:error, {kind, format_runtime_callback_reason(reason)}}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should likely be: Exception.format(kind, reason, __STACKTRACE__). We need to update runtime_call accordingly.

Comment thread exla/lib/exla/defn.ex
defp reset_token(%{__MODULE__ => outfeed}, token),
do: %{__MODULE__ => Outfeed.with_token(outfeed, token)}

defp reset_callback_pid(%{__MODULE__ => outfeed}, callback_pid),

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need this functionality. This is necessary for tokens because we had to use different tokens on each branch of a conditional. But a callback_pid is a static value across all conditionals. Basically, we need a simple helper, which is "get_or_set_callback_pid" which returns the callback_pid (or creates one if it does not exist). We could even use the process dictionary if we want to, and forget this cache craziness.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Introduce Nx.io_callback as a replacement for token-based hooks and EXLA's outfeed

4 participants