-
Notifications
You must be signed in to change notification settings - Fork 493
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
A graph of TiledHloInstruction represents an HLO graph with associated concrete tiles sizes. In the following changes I'll add code to build the graph from SymbolicTiledHloInstruction and use the tiled graph for Cost Model and Triton codegen. PiperOrigin-RevId: 621903701
- Loading branch information
1 parent
570a4d8
commit 19422af
Showing
4 changed files
with
429 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include "xla/service/gpu/model/tiled_hlo_instruction.h" | ||
|
||
#include <cstddef> | ||
#include <cstdint> | ||
#include <memory> | ||
#include <sstream> | ||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "absl/hash/hash.h" | ||
#include "absl/memory/memory.h" | ||
#include "absl/status/status.h" | ||
#include "absl/status/statusor.h" | ||
#include "absl/strings/str_cat.h" | ||
#include "absl/strings/str_join.h" | ||
#include "xla/hlo/ir/hlo_instruction.h" | ||
#include "xla/service/gpu/model/indexing_map.h" | ||
#include "xla/util.h" | ||
|
||
namespace xla { | ||
namespace gpu { | ||
|
||
size_t TiledHloInstruction::PtrHash::operator()( | ||
const TiledHloInstruction* tiled_hlo) const { | ||
return absl::HashOf(*tiled_hlo); | ||
} | ||
|
||
bool TiledHloInstruction::PtrEqual::operator()( | ||
const TiledHloInstruction* lhs, const TiledHloInstruction* rhs) const { | ||
return *lhs == *rhs; | ||
} | ||
|
||
bool operator==(const TiledHloInstruction& lhs, | ||
const TiledHloInstruction& rhs) { | ||
return lhs.hlo() == rhs.hlo() && lhs.tile_sizes() == rhs.tile_sizes() && | ||
lhs.tile_strides() == rhs.tile_strides() && | ||
lhs.block_id_to_tile_offsets_indexing() == | ||
rhs.block_id_to_tile_offsets_indexing(); | ||
} | ||
|
||
bool operator!=(const TiledHloInstruction& lhs, | ||
const TiledHloInstruction& rhs) { | ||
return !(lhs == rhs); | ||
} | ||
|
||
/*static*/ | ||
absl::StatusOr<std::unique_ptr<TiledHloInstruction>> | ||
TiledHloInstruction::Create(const HloInstruction* hlo, | ||
std::vector<int64_t> tile_sizes, | ||
std::vector<int64_t> tile_strides, | ||
IndexingMap block_id_to_tile_offsets_indexing) { | ||
int rank = hlo->shape().rank(); | ||
|
||
if (tile_sizes.size() != rank) { | ||
return absl::InvalidArgumentError( | ||
absl::StrCat("Number of tile sizes must be equal to the rank of the " | ||
"hlo shape. tile_sizes = ", | ||
tile_sizes.size(), ", hlo = ", hlo->ToString())); | ||
} | ||
|
||
if (tile_strides.size() != rank) { | ||
return absl::InvalidArgumentError( | ||
absl::StrCat("Number of tile strides must be equal to the rank of the " | ||
"hlo shape. tile_sizes = ", | ||
tile_strides.size(), ", hlo = ", hlo->ToString())); | ||
} | ||
|
||
if (block_id_to_tile_offsets_indexing.GetDimensionCount() != 1 || | ||
block_id_to_tile_offsets_indexing.GetSymbolCount() != 0) { | ||
return absl::InvalidArgumentError(absl::StrCat( | ||
"block_id_to_tile_offsets_indexing must have 1 dim and 0 symbols. " | ||
"block_id_to_tile_offsets_indexing = ", | ||
block_id_to_tile_offsets_indexing.ToString())); | ||
} | ||
|
||
if (block_id_to_tile_offsets_indexing.GetAffineMap().getNumResults() != | ||
rank) { | ||
return absl::InvalidArgumentError(absl::StrCat( | ||
"block_id_to_tile_offsets_indexing must have the same number of " | ||
"results as the rank of the hlo shape. " | ||
"block_id_to_tile_offsets_indexing = ", | ||
block_id_to_tile_offsets_indexing.ToString(), | ||
", hlo = ", hlo->ToString())); | ||
} | ||
|
||
return absl::WrapUnique(new TiledHloInstruction( | ||
hlo, std::move(tile_sizes), std::move(tile_strides), | ||
std::move(block_id_to_tile_offsets_indexing))); | ||
} | ||
|
||
std::string TiledHloInstruction::ToString() const { | ||
std::stringstream ss; | ||
ss << "hlo: " << hlo_->ToString() << "\n"; | ||
ss << "tile_sizes: {" << absl::StrJoin(tile_sizes_, ", ") << "}\n"; | ||
ss << "tile_strides: {" << absl::StrJoin(tile_strides_, ", ") << "}\n"; | ||
ss << "block_id_to_tile_offsets_indexing: " | ||
<< block_id_to_tile_offsets_indexing_; | ||
return ss.str(); | ||
} | ||
|
||
} // namespace gpu | ||
} // namespace xla |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#ifndef XLA_SERVICE_GPU_MODEL_TILED_HLO_INSTRUCTION_H_ | ||
#define XLA_SERVICE_GPU_MODEL_TILED_HLO_INSTRUCTION_H_ | ||
|
||
#include <cstddef> | ||
#include <cstdint> | ||
#include <memory> | ||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "absl/status/statusor.h" | ||
#include "xla/hlo/ir/hlo_instruction.h" | ||
#include "xla/service/gpu/model/indexing_map.h" | ||
|
||
namespace xla { | ||
namespace gpu { | ||
|
||
// A wrapper around HloInstruction that represents a tiled HLO instruction. | ||
// | ||
// The class contains information required to emit this instruction in | ||
// block-level codegen. Tile sizes and strides are constants and do not depend | ||
// on the block id. Tile offsets are computed using an indexing map of form: | ||
// `(block_id) -> (tile_offset0, tile_offset1, ...)`. | ||
class TiledHloInstruction { | ||
public: | ||
// PtrHash and PtrEqual are helper classes to use in hash maps and sets that | ||
// compare values behind the pointers. For example, | ||
// absl::flat_hash_set<TiledHloInstruction*, PtrHash, PtrEqual> hlo_set; | ||
struct PtrHash { | ||
size_t operator()(const TiledHloInstruction* tiled_hlo) const; | ||
}; | ||
|
||
struct PtrEqual { | ||
bool operator()(const TiledHloInstruction* lhs, | ||
const TiledHloInstruction* rhs) const; | ||
}; | ||
|
||
// Creates an instance of TiledHloInstruction. Returns an error if any of the | ||
// following preconditions is not met: | ||
// * Number of tile sizes, strides should match HLO shape rank. | ||
// * Number of result of `block_id_to_tile_offsets_indexing` should match HLO | ||
// shape rank. | ||
// * `block_id_to_tile_offsets_indexing` should have only 1 dimension and 0 | ||
// symbols. | ||
static absl::StatusOr<std::unique_ptr<TiledHloInstruction>> Create( | ||
const HloInstruction* hlo, std::vector<int64_t> tile_sizes, | ||
std::vector<int64_t> tile_strides, | ||
IndexingMap block_id_to_tile_offsets_indexing); | ||
|
||
// Returns the original HLO instruction. | ||
const HloInstruction* hlo() const { return hlo_; } | ||
|
||
// Returns the tile sizes. The number of tile sizes is equal to the rank of | ||
// the output shape. | ||
const std::vector<int64_t>& tile_sizes() const { return tile_sizes_; } | ||
|
||
// Returns the tile strides. The number of tile strides is equal to the rank | ||
// of the output shape. | ||
const std::vector<int64_t>& tile_strides() const { return tile_strides_; } | ||
|
||
// Returns the indexing map from block_id to tile offsets. The map has a form | ||
// of `(block_id) -> (tile_offset0, tile_offset1, ...)`. The number of tile | ||
// offsets is equal to the rank of the output shape. | ||
const IndexingMap& block_id_to_tile_offsets_indexing() const { | ||
return block_id_to_tile_offsets_indexing_; | ||
} | ||
|
||
const TiledHloInstruction* operand(int64_t operand_id) const { | ||
return operands_[operand_id]; | ||
} | ||
|
||
const std::vector<TiledHloInstruction*>& operands() const { | ||
return operands_; | ||
} | ||
|
||
void AppendOperand(TiledHloInstruction* operand) { | ||
operands_.push_back(operand); | ||
} | ||
|
||
std::string ToString() const; | ||
|
||
private: | ||
TiledHloInstruction(const HloInstruction* hlo, | ||
std::vector<int64_t> tile_sizes, | ||
std::vector<int64_t> tile_strides, | ||
IndexingMap block_id_to_tile_offsets_indexing) | ||
: hlo_(hlo), | ||
tile_sizes_(std::move(tile_sizes)), | ||
tile_strides_(std::move(tile_strides)), | ||
block_id_to_tile_offsets_indexing_( | ||
std::move(block_id_to_tile_offsets_indexing)) {} | ||
|
||
// Pointer to the original HLO instruction. | ||
const HloInstruction* hlo_; | ||
|
||
// Tile sizes and strides. | ||
std::vector<int64_t> tile_sizes_; | ||
std::vector<int64_t> tile_strides_; | ||
|
||
// Indexing map from block_id to tile offsets. | ||
IndexingMap block_id_to_tile_offsets_indexing_; | ||
|
||
// Operands of the instruction in the tiled computation graph. | ||
std::vector<TiledHloInstruction*> operands_; | ||
}; | ||
|
||
bool operator==(const TiledHloInstruction& lhs, const TiledHloInstruction& rhs); | ||
bool operator!=(const TiledHloInstruction& lhs, const TiledHloInstruction& rhs); | ||
|
||
template <typename H> | ||
H AbslHashValue(H h, const TiledHloInstruction& tiled_hlo_instruction) { | ||
return H::combine(std::move(h), tiled_hlo_instruction.hlo(), | ||
tiled_hlo_instruction.tile_sizes(), | ||
tiled_hlo_instruction.tile_strides(), | ||
tiled_hlo_instruction.block_id_to_tile_offsets_indexing()); | ||
} | ||
|
||
} // namespace gpu | ||
} // namespace xla | ||
|
||
#endif // XLA_SERVICE_GPU_MODEL_TILED_HLO_INSTRUCTION_H_ |
Oops, something went wrong.