Basic io_callback implementation#1752
Conversation
polvalente
left a comment
There was a problem hiding this comment.
This generally looks good. Super exciting that we're progressing on this
| // Result buffers are aliased to input buffers via output_operand_aliases — | ||
| // we must NOT touch them here. Pass no output buffers to the bridge. | ||
| exla::callback_bridge::Result result = | ||
| exla::callback_bridge::InvokeIoCallback( |
There was a problem hiding this comment.
| exla::callback_bridge::InvokeIoCallback( | |
| exla::callback_bridge::InvokeIOCallback( |
nitpick on naming
I generally keep all letters in an acronym capital (HTTP vs Http, GRPC vs Grpc, etc)
| defp send_io_callback_reply(hooks, io_callbacks, callback_id, args_spec, reply_tag) do | ||
| reply = | ||
| try do | ||
| case Map.fetch(io_callbacks, callback_id) do | ||
| {:ok, {callback_spec, arg_template}} -> | ||
| case decode_callback_args(args_spec, arg_template) do | ||
| {:ok, tensor_args} -> | ||
| case resolve_io_callback_hook(callback_spec, hooks) do | ||
| nil -> | ||
| {:ok, []} | ||
|
|
||
| fun -> | ||
| try do | ||
| fun.(tensor_args) | ||
| {:ok, []} | ||
| rescue | ||
| exception -> | ||
| {:error, {:exception, Exception.message(exception)}} | ||
| catch | ||
| kind, reason -> | ||
| {:error, {kind, format_runtime_callback_reason(reason)}} | ||
| end | ||
| end | ||
|
|
||
| {:error, _} = error -> | ||
| error | ||
| end | ||
|
|
||
| :error -> | ||
| Logger.error( | ||
| "EXLA.Outfeed received io_callback id #{inspect(callback_id)} that is not registered" |
There was a problem hiding this comment.
Can we try to unnest this into a with?
| leaf_count = length(typespecs) | ||
|
|
||
| {callback_id_words, callback_id_size} = term_to_int64_list(callback_id) | ||
|
|
||
| # Each leaf at operand[i+1] (skipping PID at 0) is aliased to result[i]. | ||
| # The position in the list determines which result index is aliased. | ||
| aliases = | ||
| Enum.map(0..(leaf_count - 1)//1, fn i -> | ||
| attr_output_operand_alias(i + 1) | ||
| end) |
There was a problem hiding this comment.
| leaf_count = length(typespecs) | |
| {callback_id_words, callback_id_size} = term_to_int64_list(callback_id) | |
| # Each leaf at operand[i+1] (skipping PID at 0) is aliased to result[i]. | |
| # The position in the list determines which result index is aliased. | |
| aliases = | |
| Enum.map(0..(leaf_count - 1)//1, fn i -> | |
| attr_output_operand_alias(i + 1) | |
| end) | |
| {callback_id_words, callback_id_size} = term_to_int64_list(callback_id) | |
| # Each leaf at operand[i+1] (skipping PID at 0) is aliased to result[i]. | |
| # The position in the list determines which result index is aliased. | |
| aliases = | |
| Enum.with_index(typespecs, fn _typespec, index -> | |
| attr_output_operand_alias(i + 1) | |
| end) |
| # `operand_index` is the 0-based index into the custom_call's operand list. | ||
| # Note: operand 0 is the PID, so the first leaf is at operand_index 1. | ||
| # The output index is implicit: aliases[i] refers to result i. | ||
| # Both tuple-indices lists are empty because we operate on flat tensors. |
There was a problem hiding this comment.
| # Both tuple-indices lists are empty because we operate on flat tensors. | |
| # Both tuple-indices lists are empty because we operate on flat tensor lists. |
| tensor_exprs = Composite.flatten_list([tensor_expr]) | ||
|
|
||
| {arg_values, cache} = | ||
| Enum.map_reduce(tensor_exprs, cache, fn arg, cache -> | ||
| recur_operator(arg, state, cache) |> unwrap_single_tensor!() | ||
| end) |
There was a problem hiding this comment.
I think this can be done through a single Composite.reduce that accumulates both reverse_arg_values and cache
|
|
||
| defp io_callback_depend_pair(acc, hooked) do | ||
| zero = Nx.tensor(0, type: acc.type) | ||
| Nx.add(acc, Nx.multiply(zero, Nx.sum(hooked))) |
There was a problem hiding this comment.
I'm not sure if this is the best solution. Maybe we can add a metadata node or introduce a new expression node that can carry the nodes + attached callback?
@josevalim do you have suggestions?
| backend = Nx.Shared.list_impl!(tensors) | ||
|
|
||
| if backend == Nx.Defn.Expr do | ||
| backend.io_callback(tensor_or_container, callback) |
There was a problem hiding this comment.
| backend.io_callback(tensor_or_container, callback) | |
| Nx.Defn.Expr.io_callback(tensor_or_container, callback) |
| assert inspect(result, safe: false) == """ | ||
| #Nx.Tensor< | ||
| f32 | ||
| \s\s | ||
| Nx.Defn.Expr | ||
| parameter a:0 f32 | ||
| parameter b:1 f32 | ||
| c = multiply a, b f32 | ||
| d = add a, b f32 | ||
| e = token mult: c, add: d tuple2 | ||
| f = attach_token e, d f32 | ||
| g = attach_token e, c f32 | ||
| h = subtract f, g f32 | ||
| >\ | ||
| """ | ||
| assert rendered =~ "io_callback" | ||
| assert rendered =~ "subtract" | ||
| refute rendered =~ "token" | ||
| refute rendered =~ "attach_token" | ||
| end |
There was a problem hiding this comment.
Let's try to keep the same type of verbose assertions
| an explicit dependency edge. On EXLA, independent callbacks are additionally | ||
| serialized by threading the callback-server PID through each `io_callback` | ||
| custom call (operand 0 → result 0), so XLA cannot reorder side effects even | ||
| when their tensor inputs are unrelated. No StableHLO token machinery is | ||
| required at the Nx expression level. |
There was a problem hiding this comment.
We should not discuss StableHLO on user facing APIs, as it is an implementation detail.
I am also not sure the EXLA note is relevant here. We say they are serialised but that means they do not run concurrently, it is not ultimately about order. I think we should rather say: "Ordering between IO callbacks are not guarantee, unless they are chained. For example, if you have this expression, io_callback(a) + io_callback(b), there is no guarantee about order. But this one has: io_callback(a + io_callback(b)).
| > #### Backend transfers {: .warning} | ||
| > | ||
| > When executing inside `Nx.Defn.Evaluator`, do not transfer tensors with | ||
| > `Nx.backend_transfer/2` inside the callback because the values may still be | ||
| > used in the rest of the computation. Use `Nx.backend_copy/2` instead. |
There was a problem hiding this comment.
So maybe we should emit it as a binary reference pointing to device? This way you are forced to copy it (and it solves both the lock and backend transfer issue above).
| token = create_token() | ||
| {token, _} = hook_token(token, fun.(expr), &IO.inspect(&1, opts)) | ||
| attach_token(token, expr) | ||
| Nx.io_callback(expr, fn t -> IO.inspect(fun.(t), opts) end) |
There was a problem hiding this comment.
I am wondering if we should call it io_call for consisatency with runtime_call?
| defp normalize_io_callback_spec(fun) when is_function(fun, 1), do: {:fn, fun} | ||
|
|
||
| defp normalize_io_callback_spec(name) when is_atom(name), do: {:hook, name, nil} | ||
|
|
||
| defp normalize_io_callback_spec({:hook, name, callback}) | ||
| when is_atom(name) and (is_function(callback, 1) or is_nil(callback)), | ||
| do: {:hook, name, callback} |
There was a problem hiding this comment.
Why are we normalizing it here? Shouldn't we normalize in the callers of Expr.io_callback, since the callers already have the exact shape?
| The user **must** reassign the result. Nx-level DCE prunes unreachable nodes | ||
| before any compiler sees `has_side_effect`, so discarding the return value | ||
| silently drops the callback from the graph. |
There was a problem hiding this comment.
This seems to be talking about implementation details of EXLA?
| end | ||
|
|
||
| defp io_callback_spec_inspect({:hook, name, _}), do: inspect({:hook, name, nil}) | ||
| defp io_callback_spec_inspect({:fn, _fun}), do: "&fn" |
There was a problem hiding this comment.
We can inspect the whole fun, no?
|
|
||
| compatible? = | ||
| case mode do | ||
| :io_callback -> true |
There was a problem hiding this comment.
Why don't we need to check for compatibility here?
| {:error, {:exception, Exception.message(exception)}} | ||
| catch | ||
| kind, reason -> | ||
| send(self(), :stop) |
There was a problem hiding this comment.
We are sending stop, but will this effectively stop the execution on the callback too? Can we have tests for this scenario?
| catch | ||
| kind, reason -> | ||
| send(self(), :stop) | ||
| {:error, {kind, format_runtime_callback_reason(reason)}} |
There was a problem hiding this comment.
This should likely be: Exception.format(kind, reason, __STACKTRACE__). We need to update runtime_call accordingly.
| defp reset_token(%{__MODULE__ => outfeed}, token), | ||
| do: %{__MODULE__ => Outfeed.with_token(outfeed, token)} | ||
|
|
||
| defp reset_callback_pid(%{__MODULE__ => outfeed}, callback_pid), |
There was a problem hiding this comment.
I don't think we need this functionality. This is necessary for tokens because we had to use different tokens on each branch of a conditional. But a callback_pid is a static value across all conditionals. Basically, we need a simple helper, which is "get_or_set_callback_pid" which returns the callback_pid (or creates one if it does not exist). We could even use the process dictionary if we want to, and forget this cache craziness.
closes #1672
First iteration of
Nx.io_callback/2: side-effect host callbacks with passthrough inputs/outputs, lowered tostablehlo.custom_call(exla_io_callback). Hooks andprint_valuenow useio_callbackinstead of the outfeed/token hook path.:io_callbackexpr op ({:fn, fun}/{:hook, name, callback}), publicNx.io_callback/2, evaluator, tree/gradhook/3,print_value, andattach_tokenlowered viaio_callback;hook_token+Nx.Defn.Tokenkept only for the shared-token /attach_tokenpatternValue.io_callback/3, compiler lowering (host + CUDA), Outfeed{:exla_io_callback, ...}handler, C++ FFI (io_callback.cc/io_callback_cuda.cc):token/:attach_tokenfor hooks (raises if legacy nodes appear);stablehlo.create_tokenonly when infeeds need itoutput_operand_aliasesper leaf (PID + tensors:output[0]→operand[0],output[i+1]→operand[i+1]; operand0= callback server PID)x = Nx.io_callback(x, ...)); Nx DCE runs before XLA seeshas_side_effectio_callbacks — each custom call returns the PID (operand 0 → result 0) and the compiler chains it into the next call, so XLA cannot reorder independent side effects;if/condbranches reset to a dominating PID (no StableHLO tokens, no evaluator changes)Not in this PR (yet)
Nx.Defn.Token,create_token,hook_token, and theattach_tokenAPI