Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 41 additions & 25 deletions allo/backend/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def _process_function_streams(
module: Module,
func: func_d.FuncOp,
processed_funcs: set,
all_pe_calls_by_func: dict = None,
):
"""
Process streams and PE calls within a single function.
Expand Down Expand Up @@ -109,7 +110,9 @@ def _process_function_streams(
if callee_name == str(mod_op.sym_name).strip('"'):
pe_call_define_ops[op] = mod_op
# Recursively process the callee function first
_process_function_streams(module, mod_op, processed_funcs)
_process_function_streams(
module, mod_op, processed_funcs, all_pe_calls_by_func
)
break
elif isinstance(op, allo_d.StreamConstructOp):
stream_name = str(op.attributes["name"]).strip('"')
Expand Down Expand Up @@ -765,6 +768,10 @@ def _process_function_streams(
for op in stream_construct_ops.values():
op.operation.erase()

# Accumulate PE calls keyed by function for recursive OMP injection
if all_pe_calls_by_func is not None and pe_call_define_ops:
all_pe_calls_by_func[func_name] = pe_call_define_ops
Comment on lines +772 to +773
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

always not None?


return (
stream_struct_table,
stream_type_table,
Expand All @@ -773,6 +780,32 @@ def _process_function_streams(
)


def _inject_omp_parallel_sections(pe_call_define_ops):
    """Wrap a group of PE `func.call` ops in an OpenMP parallel-sections region.

    Builds the nesting ``omp.parallel > omp.sections > omp.section`` and moves
    each call in *pe_call_define_ops* into its own ``omp.section`` block, so the
    PE calls execute concurrently when the module is run through the simulator.

    Parameters
    ----------
    pe_call_define_ops : dict
        Non-empty mapping whose keys are the ``func.call`` OpViews to wrap
        (values are the callee FuncOps and are not used here). The ops must all
        live in the same block; the parallel region is inserted right before
        the first of them.

    Raises
    ------
    AssertionError
        If *pe_call_define_ops* is empty or a key is not an OpView.
    """
    assert len(pe_call_define_ops) > 0
    # Anchor the whole parallel region where the first PE call currently sits.
    # next(iter(...)) fetches the first key without copying the key list.
    omp_ip = InsertionPoint(beforeOperation=next(iter(pe_call_define_ops)))
    omp_parallel_op = openmp_d.ParallelOp([], [], [], [], ip=omp_ip)
    assert isinstance(omp_parallel_op.region, Region)
    omp_parallel_block = Block.create_at_start(omp_parallel_op.region, [])

    # Add `omp.sections` inside the parallel block, followed by the block's
    # mandatory `omp.terminator` (regions must end with a terminator).
    ip_omp_parallel = InsertionPoint(omp_parallel_block)
    omp_sections_op = openmp_d.SectionsOp([], [], [], [], ip=ip_omp_parallel)
    omp_sections_block = Block.create_at_start(omp_sections_op.region, [])
    openmp_d.TerminatorOp(ip=ip_omp_parallel)

    # Add one `omp.section` per PE call and move the call into it. The call is
    # moved before the section's terminator so the section block stays valid.
    ip_omp_sections = InsertionPoint(omp_sections_block)
    for call_op in pe_call_define_ops:
        assert isinstance(call_op, OpView)
        omp_section_block = Block.create_at_start(
            openmp_d.SectionOp(ip=ip_omp_sections).region, []
        )
        ip_omp_section = InsertionPoint(omp_section_block)
        omp_term_op = openmp_d.TerminatorOp(ip=ip_omp_section)
        call_op.operation.move_before(omp_term_op.operation)
    # Terminator for the `omp.sections` block itself.
    openmp_d.TerminatorOp(ip=ip_omp_sections)


def build_dataflow_simulator(module: Module, top_func_name: str):
with module.context, Location.unknown():
# Declare usleep for spinloop yielding
Expand All @@ -799,12 +832,13 @@ def build_dataflow_simulator(module: Module, top_func_name: str):

# Process all functions with streams recursively, starting from top
processed_funcs: set = set()
all_pe_calls_by_func: dict = {}
func = find_func_in_module(module, top_func_name)
assert isinstance(func.body, Region)

# Recursively process the top function and all its callees
_, _, pe_call_define_ops, _ = _process_function_streams(
module, func, processed_funcs
module, func, processed_funcs, all_pe_calls_by_func
)

# If no PE calls were found in top function, collect them again from the processed functions
Expand All @@ -819,30 +853,12 @@ def build_dataflow_simulator(module: Module, top_func_name: str):
if callee_name == str(mod_op.sym_name).strip('"'):
pe_call_define_ops[op] = mod_op
break
all_pe_calls_by_func[top_func_name] = pe_call_define_ops

# Add the outmost `omp.parallel`
assert len(pe_call_define_ops) > 0
omp_ip = InsertionPoint(beforeOperation=list(pe_call_define_ops.keys())[0])
omp_parallel_op = openmp_d.ParallelOp([], [], [], [], ip=omp_ip)
assert isinstance(omp_parallel_op.region, Region)
omp_parallel_block = Block.create_at_start(omp_parallel_op.region, [])

# Add `omp.sections`
ip_omp_parallel = InsertionPoint(omp_parallel_block)
omp_sections_op = openmp_d.SectionsOp([], [], [], [], ip=ip_omp_parallel)
omp_sections_block = Block.create_at_start(omp_sections_op.region, [])
openmp_d.TerminatorOp(ip=ip_omp_parallel)

# Add `omp.section`s for PE calls
ip_omp_sections = InsertionPoint(omp_sections_block)
for call_op in pe_call_define_ops:
assert isinstance(call_op, OpView)
omp_section_op = openmp_d.SectionOp(ip=ip_omp_sections)
omp_section_block = Block.create_at_start(omp_section_op.region, [])
ip_omp_section = InsertionPoint(omp_section_block)
omp_term_op = openmp_d.TerminatorOp(ip=ip_omp_section)
call_op.operation.move_before(omp_term_op.operation)
openmp_d.TerminatorOp(ip=ip_omp_sections)
# Inject omp.parallel/sections into every function that has PE calls
for func_pe_calls in all_pe_calls_by_func.values():
if func_pe_calls:
_inject_omp_parallel_sections(func_pe_calls)


# This pass is only meant to run on fully lowered MLIR code
Expand Down
33 changes: 26 additions & 7 deletions allo/ir/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2211,10 +2211,20 @@ def build_FunctionDef(ctx: ASTContext, node: ast.FunctionDef):
args_kw = get_kwarg_value(decorator.keywords, "args")
if args_kw is not None:
args = build_stmts(ctx, args_kw)
arg_values = [
ASTTransformer.get_mlir_op_result(ctx, arg)
for arg in args
]
arg_values = []
for arg in args:
res = ASTTransformer.get_mlir_op_result(
ctx, arg
)
# If it's a 0D memref (scalar), load it to get the value
if (
isinstance(res.type, MemRefType)
and len(res.type.shape) == 0
):
op_ = ASTTransformer.build_scalar(ctx, arg)
arg_values.append(op_.result)
Comment on lines +2219 to +2225
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this would pass the scalar argument by value.
Also, I think args should only contain tensors (please correct me if I'm missing something). In that case, we should raise an error in the type inference phase if the user uses a scalar here.

else:
arg_values.append(res)
else:
arg_values = []
# Insert calls
Expand Down Expand Up @@ -2729,13 +2739,22 @@ def build_Call(ctx: ASTContext, node: ast.Call, out_buffer: OpView = None):
new_ctx.func_suffix = inst_suffix

func_op = ASTTransformer.build_FunctionDef(new_ctx, func_def)
func_op.attributes["dataflow"] = UnitAttr.get()
if ctx.top_func is not None:
func_op.operation.move_before(ctx.top_func)
Comment on lines +2743 to +2744
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is also added to fix forward references. A forward reference is valid in MLIR but invalid in the HLS backend. We'd better fix the issue in the HLS backend codegen.
Could you please check whether the existing fix (introduced in #557) already works?


# Now insert the call
# Parse arguments
new_args = build_stmts(ctx, node.args)
arg_values = [
ASTTransformer.get_mlir_op_result(ctx, arg) for arg in new_args
]
arg_values = []
for arg in new_args:
res = ASTTransformer.get_mlir_op_result(ctx, arg)
if isinstance(res.type, MemRefType) and len(res.type.shape) == 0:
op_ = ASTTransformer.build_scalar(ctx, arg)
arg_values.append(op_.result)
Comment on lines +2752 to +2754
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this would pass the scalar argument by value. Is this the intended behavior?

else:
arg_values.append(res)

call_op = func_d.CallOp(
[],
FlatSymbolRefAttr.get(func_def.name),
Expand Down
2 changes: 2 additions & 0 deletions allo/ir/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def __init__(
self.mapping = None
# track the current AST node being visited for error reporting
self.current_node = None
self.global_op_cache = {}

def copy(self):
ctx = ASTContext(
Expand All @@ -147,6 +148,7 @@ def copy(self):
ctx.current_node = self.current_node
if hasattr(self, "func_suffix"):
ctx.func_suffix = self.func_suffix
ctx.global_op_cache = self.global_op_cache
return ctx

def set_ip(self, ip):
Expand Down
24 changes: 19 additions & 5 deletions mlir/lib/Translation/EmitVivadoHLS.cpp
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does #557 still fail to fix the forward reference issue?

Original file line number Diff line number Diff line change
Expand Up @@ -2727,7 +2727,7 @@ allo::hls::VhlsModuleEmitter::emitFunctionSignature(func::FuncOp func) {
os << "void " << func.getName() << "(\n";
addIndent();

// This vector is to record all ports of the function.
// This vector records all ports of the function (args + return operands).
SmallVector<Value, 8> portList;

// Emit input arguments.
Expand All @@ -2754,6 +2754,7 @@ allo::hls::VhlsModuleEmitter::emitFunctionSignature(func::FuncOp func) {
itypes += "x";
}
for (auto &arg : func.getArguments()) {
portList.push_back(arg);
indent();
fixUnsignedType(arg, itypes[argIdx] == 'u');
if (llvm::isa<ShapedType>(arg.getType())) {
Expand Down Expand Up @@ -2781,7 +2782,6 @@ allo::hls::VhlsModuleEmitter::emitFunctionSignature(func::FuncOp func) {
}
}

portList.push_back(arg);
if (argIdx++ != func.getNumArguments() - 1)
os << ",\n";
}
Expand All @@ -2801,6 +2801,7 @@ allo::hls::VhlsModuleEmitter::emitFunctionSignature(func::FuncOp func) {
unsigned idx = 0;
for (auto result : funcReturn.getOperands()) {
if (std::find(args.begin(), args.end(), result) == args.end()) {
portList.push_back(result);
if (func.getArguments().size() > 0)
os << ",\n";
indent();
Expand All @@ -2820,8 +2821,6 @@ allo::hls::VhlsModuleEmitter::emitFunctionSignature(func::FuncOp func) {
else
emitValue(result, /*rank=*/0, /*isPtr=*/true, output_names);
}

portList.push_back(result);
}
idx += 1;
}
Expand Down Expand Up @@ -3178,7 +3177,22 @@ using namespace std;
}
}

// Third pass: emit function definitions and non-stateful globals
// Third pass: emit forward declarations for all functions
for (auto &op : *module.getBody()) {
if (auto func = dyn_cast<func::FuncOp>(op)) {
if (!func->hasAttr("top") && !func.getBlocks().empty()) {
emitFunctionSignature(func);
os << "\n);\n\n";
}
}
}

// Clear nameTable and nameConflictCnt to ensure that Pass 4 can re-emit
// function signatures with full types.
state.nameTable.clear();
state.nameConflictCnt.clear();

// Fourth pass: emit functions and non-stateful globals
for (auto &op : *module.getBody()) {
if (auto func = dyn_cast<func::FuncOp>(op)) {
emitFunction(func);
Expand Down
Loading