@@ -115,3 +115,142 @@ have explicit first-class opcodes defined with the `-start` and/or `-done`
115115suffixes (e.g., ` copy-start ` /` copy-done ` ,
116116` collective-permute-start ` /` collective-permute-done ` ). These will continue to
117117use their respective first-class opcodes.
118+
119+ ## Late Binding
120+
121+ In some cases, the operands (inputs) or outputs of an asynchronous
122+ operation are not all available or allocated when the operation starts.
123+ XLA supports * late binding* , which allows operands to be incrementally
124+ bound during ` async-update ` steps, and outputs to be bound during either
125+ ` async-update ` or ` async-done ` steps.
126+
127+ ### Representation in HLO
128+
129+ For a called computation that expects $N$ parameters, we can start the
130+ asynchronous execution with fewer than $N$ operands. The remaining
131+ operands are passed in subsequent ` async-update ` instructions.
132+
133+ * ` async-start ` binds the first $K$ operands ($K < N$).
134+ * ` async-update ` instructions bind the remaining $N - K$ operands.
135+
136+ Operand bindings must happen in left-to-right order. That is, if a
137+ computation expects parameters $P_0, P_1, \dots, P_ {N-1}$, they must be bound in
138+ that order across the async chain.
139+
140+ The ` async-start ` and ` async-update ` shapes reflect the incrementally
141+ bound parameters. Specifically, the first element of the tuple shape
142+ (the operand shapes) grows as more operands are bound.
143+
144+ Output binding is independent of operand binding and can happen at any
145+ step in the async chain (either in an ` async-update ` or at the final
146+ ` async-done ` ).
147+
148+ ### Example with ` kCall `
149+
150+ Consider a called computation ` %foo ` that takes two parameters:
151+
152+ ```
153+ %foo {
154+ %p0 = f32[] parameter(0)
155+ %p1 = f32[] parameter(1)
156+ ROOT %add = f32[] add(%p0, %p1)
157+ }
158+ ```
159+
160+ We can call this computation asynchronously, binding ` %p0 ` at start and
161+ ` %p1 ` at update:
162+
163+ ```
164+ %call-start = ((f32[]), (), s32[]) call-start(%operand0), to_apply=%foo
165+ %call-update = ((f32[], f32[]), f32[], s32[]) call-update(%call-start, %operand1)
166+ %result = f32[] call-done(%call-update)
167+ ```
168+
169+ The parser desugars this into the following HLO:
170+
171+ ```
172+ %async-start = ((f32[]), (), s32[]) async-start(%operand0), calls=%foo
173+ %async-update = ((f32[], f32[]), f32[], s32[]) async-update(%async-start, %operand1)
174+ %result = f32[] async-done(%async-update)
175+ ```
176+
177+ ### Late-Bound Outputs
178+
179+ In addition to operands (inputs), the ** outputs** of an asynchronous
180+ operation can also be bound late. This is useful when the output
181+ buffers are not known or allocated at the start of the operation.
182+
183+ To represent late-bound outputs:
184+ 1 . The ` async-start ` (or ` call-start ` ) instruction is defined with an
185+ empty tuple ` () ` at index 1 of its output shape (the result slot).
186+ 2 . A subsequent ` async-update ` (or ` call-update ` ) instruction
187+ specifies the actual output shape at index 1, replacing the empty
188+ tuple.
189+ 3 . Alternatively, the output can be bound at the end of the chain by
190+ the ` async-done ` (or ` call-done ` ) instruction, which returns the
191+ final output shape. This can be done regardless of whether there are
192+ intermediate ` async-update ` steps in the chain.
193+
194+ #### Example with ` async-update `
195+
196+ ```
197+ // Output is not bound at start (index 1 is ())
198+ %call-start = ((f32[1024]), (), s32[]) call-start(%input_buffer), to_apply=%foo
199+
200+ // Output is bound at update (index 1 becomes (f32[1024]))
201+ %call-update = ((f32[1024]), (f32[1024]), s32[]) call-update(%call-start, %output_buffer)
202+
203+ %result = (f32[1024]) call-done(%call-update)
204+ ```
205+
206+ The parser desugars this into:
207+
208+ ```
209+ %async-start = ((f32[1024]), (), s32[]) async-start(%input_buffer), calls=%foo
210+ %async-update = ((f32[1024]), (f32[1024]), s32[]) async-update(%async-start, %output_buffer)
211+ %result = (f32[1024]) async-done(%async-update)
212+ ```
213+
214+ #### Example with ` async-done ` (without ` async-update ` )
215+
216+ If there are no intermediate update steps, the output can be bound
217+ directly at ` async-done ` :
218+
219+ ```
220+ // Output is not bound at start (index 1 is ())
221+ %call-start = ((f32[1024]), (), s32[]) call-start(%input_buffer), to_apply=%foo
222+
223+ // Output is bound at done
224+ %result = (f32[1024]) call-done(%call-start)
225+ ```
226+
227+ The parser desugars this into:
228+
229+ ```
230+ %async-start = ((f32[1024]), (), s32[]) async-start(%input_buffer), calls=%foo
231+ %result = (f32[1024]) async-done(%async-start)
232+ ```
233+
234+ #### Example with intermediate ` async-update ` and output bound at ` async-done `
235+
236+ If there are intermediate update steps to bind operands, but the output
237+ is still bound at the very end:
238+
239+ ```
240+ // Output is not bound at start, no operands bound
241+ %call-start = ((), (), s32[]) call-start(), to_apply=%foo
242+
243+ // Operands are bound at update, but output remains unbound (index 1 is ())
244+ %call-update = ((f32[], f32[]), (), s32[]) call-update(%call-start, %operand0, %operand1)
245+
246+ // Output is bound at done
247+ %result = f32[] call-done(%call-update)
248+ ```
249+
250+ The parser desugars this into:
251+
252+ ```
253+ %async-start = ((), (), s32[]) async-start(), calls=%foo
254+ %async-update = ((f32[], f32[]), (), s32[]) async-update(%async-start, %operand0, %operand1)
255+ %result = f32[] async-done(%async-update)
256+ ```
0 commit comments