Skip to content

Commit 447f577

Browse files
[XLA][HLO][Analysis] Fix dataflow analysis for async ops with excluded threads.
PiperOrigin-RevId: 941872692
1 parent 620497e commit 447f577

1 file changed

Lines changed: 139 additions & 0 deletions

File tree

third_party/xla/docs/async_ops.md

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,142 @@ have explicit first-class opcodes defined with the `-start` and/or `-done`
115115
suffixes (e.g., `copy-start`/`copy-done`,
116116
`collective-permute-start`/`collective-permute-done`). These will continue to
117117
use their respective first-class opcodes.
118+
119+
## Late Binding
120+
121+
In some cases, the operands (inputs) or outputs of an asynchronous
122+
operation are not all available or allocated when the operation starts.
123+
XLA supports *late binding*, which allows operands to be incrementally
124+
bound during `async-update` steps, and outputs to be bound during either
125+
`async-update` or `async-done` steps.
126+
127+
### Representation in HLO
128+
129+
For a called computation that expects $N$ parameters, we can start the
130+
asynchronous execution with fewer than $N$ operands. The remaining
131+
operands are passed in subsequent `async-update` instructions.
132+
133+
* `async-start` binds the first $K$ operands ($K < N$).
134+
* `async-update` instructions bind the remaining $N - K$ operands.
135+
136+
Operand bindings must happen in left-to-right order. That is, if a
137+
computation expects parameters $P_0, P_1, \dots, P_{N-1}$, they must be bound in
138+
that order across the async chain.
139+
140+
The `async-start` and `async-update` shapes reflect the incrementally
141+
bound parameters. Specifically, the first element of the tuple shape
142+
(the operand shapes) grows as more operands are bound.
143+
144+
Output binding is independent of operand binding and can happen at any
145+
step in the async chain (either in an `async-update` or at the final
146+
`async-done`).
147+
148+
### Example with `kCall`
149+
150+
Consider a called computation `%foo` that takes two parameters:
151+
152+
```
153+
%foo {
154+
%p0 = f32[] parameter(0)
155+
%p1 = f32[] parameter(1)
156+
ROOT %add = f32[] add(%p0, %p1)
157+
}
158+
```
159+
160+
We can call this computation asynchronously, binding `%p0` at start and
161+
`%p1` at update:
162+
163+
```
164+
%call-start = ((f32[]), (), s32[]) call-start(%operand0), to_apply=%foo
165+
%call-update = ((f32[], f32[]), f32[], s32[]) call-update(%call-start, %operand1)
166+
%result = f32[] call-done(%call-update)
167+
```
168+
169+
The parser desugars this into the following HLO:
170+
171+
```
172+
%async-start = ((f32[]), (), s32[]) async-start(%operand0), calls=%foo
173+
%async-update = ((f32[], f32[]), f32[], s32[]) async-update(%async-start, %operand1)
174+
%result = f32[] async-done(%async-update)
175+
```
176+
177+
### Late-Bound Outputs
178+
179+
In addition to operands (inputs), the **outputs** of an asynchronous
180+
operation can also be bound late. This is useful when the output
181+
buffers are not known or allocated at the start of the operation.
182+
183+
To represent late-bound outputs:
184+
1. The `async-start` (or `call-start`) instruction is defined with an
185+
empty tuple `()` at index 1 of its output shape (the result slot).
186+
2. A subsequent `async-update` (or `call-update`) instruction
187+
specifies the actual output shape at index 1, replacing the empty
188+
tuple.
189+
3. Alternatively, the output can be bound at the end of the chain by
190+
the `async-done` (or `call-done`) instruction, which returns the
191+
final output shape. This can be done regardless of whether there are
192+
intermediate `async-update` steps in the chain.
193+
194+
#### Example with `async-update`
195+
196+
```
197+
// Output is not bound at start (index 1 is ())
198+
%call-start = ((f32[1024]), (), s32[]) call-start(%input_buffer), to_apply=%foo
199+
200+
// Output is bound at update (index 1 becomes (f32[1024]))
201+
%call-update = ((f32[1024]), (f32[1024]), s32[]) call-update(%call-start, %output_buffer)
202+
203+
%result = (f32[1024]) call-done(%call-update)
204+
```
205+
206+
The parser desugars this into:
207+
208+
```
209+
%async-start = ((f32[1024]), (), s32[]) async-start(%input_buffer), calls=%foo
210+
%async-update = ((f32[1024]), (f32[1024]), s32[]) async-update(%async-start, %output_buffer)
211+
%result = (f32[1024]) async-done(%async-update)
212+
```
213+
214+
#### Example with `async-done` (without `async-update`)
215+
216+
If there are no intermediate update steps, the output can be bound
217+
directly at `async-done`:
218+
219+
```
220+
// Output is not bound at start (index 1 is ())
221+
%call-start = ((f32[1024]), (), s32[]) call-start(%input_buffer), to_apply=%foo
222+
223+
// Output is bound at done
224+
%result = (f32[1024]) call-done(%call-start)
225+
```
226+
227+
The parser desugars this into:
228+
229+
```
230+
%async-start = ((f32[1024]), (), s32[]) async-start(%input_buffer), calls=%foo
231+
%result = (f32[1024]) async-done(%async-start)
232+
```
233+
234+
#### Example with intermediate `async-update` and output bound at `async-done`
235+
236+
If there are intermediate update steps to bind operands, but the output
237+
is still bound at the very end:
238+
239+
```
240+
// Output is not bound at start, no operands bound
241+
%call-start = ((), (), s32[]) call-start(), to_apply=%foo
242+
243+
// Operands are bound at update, but output remains unbound (index 1 is ())
244+
%call-update = ((f32[], f32[]), (), s32[]) call-update(%call-start, %operand0, %operand1)
245+
246+
// Output is bound at done
247+
%result = f32[] call-done(%call-update)
248+
```
249+
250+
The parser desugars this into:
251+
252+
```
253+
%async-start = ((), (), s32[]) async-start(), calls=%foo
254+
%async-update = ((f32[], f32[]), (), s32[]) async-update(%async-start, %operand0, %operand1)
255+
%result = f32[] async-done(%async-update)
256+
```

0 commit comments

Comments
 (0)