Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License.
#include "xla/codegen/xtile/codegen/experimental_fusion_emitter.h"

#include <cstdint>
#include <iostream>
#include <optional>
#include <utility>
#include <vector>
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/gpu/ir_emission_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ bool IsMosaicWithNvshmem(const HloInstruction& hlo) {
bool IsMosaicWithMultimem(const HloInstruction& hlo) {
return IsCustomCallToMosaicGpu(hlo) &&
absl::StrContains(hlo.raw_backend_config_string(),
"xla_multimem_parameters");
"multimem_parameters");
}

bool IsCollectiveMosaicGpuInstruction(const HloInstruction& hlo) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,13 @@ StringAttr createFuncOp(NamedComputationOp namedComputationOp,
std::optional<TensorShardingPerValueAttr> inShardings,
std::optional<TensorShardingPerValueAttr> outShardings,
ManualAxesAttr manualAxesAttr) {
auto funcOp = FuncOp::create(
FuncOp funcOp = FuncOp::create(
rewriter, namedComputationOp.getLoc(), namedComputationOp.getName(),
rewriter.getFunctionType(namedComputationOp.getBody().getArgumentTypes(),
namedComputationOp.getResultTypes()),
rewriter.getStringAttr("private"),
/*argAttrs=*/ArrayAttr(), /*resultAttrs=*/ArrayAttr());
funcOp->setAttr(kOriginalFuncName, namedComputationOp.getNameAttr());
if (manualAxesAttr) {
funcOp->setAttr(kManualAxes, manualAxesAttr);
}
Expand Down Expand Up @@ -122,15 +123,11 @@ StringAttr createFuncOpOrGetFromCache(
mlir::IRRewriter& rewriter, SymbolTable& symbolTable,
ManualAxesAttr manualAxesAttr,
std::optional<TensorShardingPerValueAttr> inShardings,
std::optional<TensorShardingPerValueAttr> outShardings,
bool dedupFunctionsFully) {
ComputationKey key = {
namedComputationOp.getName(),
dedupFunctionsFully ? TensorShardingPerValueAttr()
: inShardings.value_or(TensorShardingPerValueAttr()),
dedupFunctionsFully ? TensorShardingPerValueAttr()
: outShardings.value_or(TensorShardingPerValueAttr()),
manualAxesAttr};
std::optional<TensorShardingPerValueAttr> outShardings) {
ComputationKey key = {namedComputationOp.getName(),
inShardings.value_or(TensorShardingPerValueAttr()),
outShardings.value_or(TensorShardingPerValueAttr()),
manualAxesAttr};
if (auto it = funcCache.find(key); it != funcCache.end()) {
return it->second;
}
Expand All @@ -141,64 +138,10 @@ StringAttr createFuncOpOrGetFromCache(
return funcSymName;
}

void exportNamedComputations(ModuleOp moduleOp, SymbolTable& symbolTable,
bool dedupFunctionsFully) {
void exportNamedComputations(ModuleOp moduleOp, SymbolTable& symbolTable) {
mlir::Block& moduleBlock = moduleOp.getRegion().front();
llvm::SmallDenseMap<ComputationKey, StringAttr> funcCache;

if (dedupFunctionsFully) {
using FuncNameKey = std::pair<StringRef, ManualAxesAttr>;
llvm::SmallDenseMap<ComputationKey, int64_t> funcCallSiteCounts;
llvm::SmallDenseMap<FuncNameKey, NamedComputationWithCount>
funcToNamedComputations;
// TODO(enver): Instead of a SmallDenseMap and a separate SmallVector to
// guarantee a deterministic iteration order, consider using
// llvm::MapVector.
// Required to iterate on functions in a deterministic order.
llvm::SmallVector<FuncNameKey> funcNames;
moduleOp.walk([&](NamedComputationOp namedComputationOp) {
ManualAxesAttr manualAxesAttr =
namedComputationOp->getAttrOfType<ManualAxesAttr>(kManualAxes);
ComputationKey key = {namedComputationOp.getName(),
namedComputationOp.getInShardings().value_or(
TensorShardingPerValueAttr()),
namedComputationOp.getOutShardings().value_or(
TensorShardingPerValueAttr()),
manualAxesAttr};
const int64_t callSiteCount = funcCallSiteCounts[key]++;
FuncNameKey funcNameKey =
std::pair(namedComputationOp.getName(), manualAxesAttr);
if (auto [it, inserted] = funcToNamedComputations.try_emplace(
funcNameKey,
NamedComputationWithCount{namedComputationOp, callSiteCount});
!inserted) {
NamedComputationWithCount& cached = it->second;
if (callSiteCount > cached.callSiteCount) {
cached.namedComputationOp = namedComputationOp;
cached.callSiteCount = callSiteCount;
}
} else { // inserted is true.
funcNames.push_back(funcNameKey);
}
});

for (FuncNameKey funcNameKey : funcNames) {
NamedComputationOp namedComputationOp =
funcToNamedComputations.at(funcNameKey).namedComputationOp;
mlir::IRRewriter rewriter(namedComputationOp);
rewriter.setInsertionPointToEnd(&moduleBlock);
ManualAxesAttr manualAxesAttr = funcNameKey.second;
StringAttr funcSymName =
createFuncOp(namedComputationOp, rewriter, symbolTable,
namedComputationOp.getInShardings(),
namedComputationOp.getOutShardings(), manualAxesAttr);
funcCache.try_emplace(
ComputationKey{namedComputationOp.getName(),
TensorShardingPerValueAttr(),
TensorShardingPerValueAttr(), manualAxesAttr},
funcSymName);
}
}

// NOTE: The walk needs to be in post order, which is the default order, to
// account for nested named computations.
Expand All @@ -219,7 +162,7 @@ void exportNamedComputations(ModuleOp moduleOp, SymbolTable& symbolTable,
}
StringAttr funcSymName = createFuncOpOrGetFromCache(
namedComputationOp, funcCache, rewriter, symbolTable, manualAxesAttr,
inShardings, outShardings, dedupFunctionsFully);
inShardings, outShardings);

// Replace the `NamedComputationOp` with a `CallOp`.
rewriter.setInsertionPoint(namedComputationOp);
Expand All @@ -234,14 +177,13 @@ void exportNamedComputations(ModuleOp moduleOp, SymbolTable& symbolTable,
}

FuncOp funcOp = symbolTable.lookup<FuncOp>(funcSymName);
maybeInsertReshardsOnFuncArguments(funcOp, callOp, symbolTable, rewriter);
// Copy the func output shardings to the call op.
if (TensorShardingPerValueAttr funcResultShardings =
getFuncResultShardings(funcOp, symbolTable)) {
getFuncResultShardings(funcOp, symbolTable);
funcResultShardings || outShardings) {
mlir::sdy::setShardings(
callOp, outShardings ? *outShardings
: getFullyClosedLike(funcResultShardings));
insertReshardsOnFuncResults(funcResultShardings, callOp, rewriter);
}
});
}
Expand All @@ -253,23 +195,12 @@ class ExportNamedComputationsPass
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ExportNamedComputationsPass)

explicit ExportNamedComputationsPass(bool dedupFunctionsFully) {
this->dedupFunctionsFully = dedupFunctionsFully;
}

ExportNamedComputationsPass() = default;

explicit ExportNamedComputationsPass(
const ExportNamedComputationsPass& other) {
this->dedupFunctionsFully = other.dedupFunctionsFully;
}

void runOnOperation() final {
ModuleOp moduleOp = getOperation();
SymbolTableCollection symbolTableCollection;

SymbolTable& symbolTable = symbolTableCollection.getSymbolTable(moduleOp);
exportNamedComputations(moduleOp, symbolTable, dedupFunctionsFully);
exportNamedComputations(moduleOp, symbolTable);
}

StringRef getArgument() const override {
Expand All @@ -286,22 +217,12 @@ class ExportNamedComputationsPass
void getDependentDialects(mlir::DialectRegistry& registry) const final {
registry.insert<mlir::sdy::SdyDialect, mlir::mhlo::MhloDialect>();
}

Option<bool> dedupFunctionsFully{
*this, "dedup-functions-fully",
llvm::cl::desc(
"If true, regardless of the input and output shardings of functions, "
"it keeps one callee function for each caller function. The default "
"is false, meaning it will deduplicate only if the input and output "
"shardings are the same."),
llvm::cl::init(false)};
};

} // namespace

std::unique_ptr<mlir::Pass> createExportNamedComputationsPass(
bool dedupFunctionsFully) {
return std::make_unique<ExportNamedComputationsPass>(dedupFunctionsFully);
std::unique_ptr<mlir::Pass> createExportNamedComputationsPass() {
return std::make_unique<ExportNamedComputationsPass>();
}

void registerExportNamedComputationsPass() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,12 @@ namespace sdy {
// and `CallOp` have the same shardings as the original `NamedComputationOp`s
// operands/results.
//
// Deduplicates functions with the same input and output shardings if
// `deduplicateFunctionsFully` is false. Otherwise, it deduplicates functions of
// the same name regardless of their input and output shardings.
// Deduplicates functions with the same input and output shardings.
//
// Based on the deduplication logic as described, if there is a function with
// the same name as the `NamedComputationOp` in the module, the MLIR symbol
// table will change it to `{name}_#`.
std::unique_ptr<mlir::Pass> createExportNamedComputationsPass(
bool dedupFunctionsFully);
std::unique_ptr<mlir::Pass> createExportNamedComputationsPass();

// Register the xla-sdy-export-named-computations pass.
void registerExportNamedComputationsPass();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ void addStablehloExportPipeline(mlir::OpPassManager& pm,
pm.addPass(createExportOpsPass(options.keepHloShardingConstraints));
pm.addPass(createStablehloRoundTripShardMapExportPass(
options.keepHloShardingConstraints));
pm.addPass(createExportNamedComputationsPass(options.dedupFunctionsFully));
pm.addPass(createExportNamedComputationsPass());
// NOTE: It is currently a literal no-op.
pm.addPass(createUnflattenCallGraphPass(options.dedupFunctionsFully));
pm.addPass(mlir::createSymbolDCEPass());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,8 @@ class UnflattenCallGraphPass
callOp, symbolTable, /*ignoreShardings=*/dedupFunctionsFully);
FuncOp funcOp = funcCache[funcCacheKey];
callOp.setCallee(funcOp.getName());
maybeInsertReshardsOnFuncArguments(funcOp, callOp, symbolTable, rewriter);
if (TensorShardingPerValueAttr funcResultShardings =
sdy::getFuncResultShardings(funcOp, symbolTable)) {
insertReshardsOnFuncResults(funcResultShardings, callOp, rewriter);
}
insertReshardsOnFuncArguments(funcOp, callOp, symbolTable, rewriter);
insertReshardsOnFuncResults(funcOp, callOp, symbolTable, rewriter);
});

moduleOp.walk([&](FuncOp funcOp) {
Expand Down
Loading
Loading