Skip to content

Commit

Permalink
Delete enable_memories code in C++ since that flag is always True and…
Browse files Browse the repository at this point in the history
… cannot be turned off now.

PiperOrigin-RevId: 707272597
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Dec 17, 2024
1 parent 4728062 commit 052f17f
Show file tree
Hide file tree
Showing 9 changed files with 2 additions and 42 deletions.
18 changes: 1 addition & 17 deletions xla/python/jax_jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,6 @@ static std::string OptionalDebugString(
}
}

bool FetchMemoriesFlag() {
auto& global_state = GlobalJitState();
auto& thread_local_state = ThreadLocalJitState();
CHECK(global_state.enable_memories.has_value());
return thread_local_state.enable_memories.value_or(
*global_state.enable_memories);
}

std::string ArgumentSignature::DebugString() const {
auto py_object_formatter = [](std::string* out, const nb::object& o) {
out->append(nb::cast<absl::string_view>(nb::str(o)));
Expand Down Expand Up @@ -224,7 +216,6 @@ std::string CallSignature::DebugString() const {
"device: %s\n"
"default_device: %s\n"
"jax_enable_x64: %d\n"
"jax_enable_memories: %d\n"
"global_extra_jit_context: %s\n"
"thread_local_extra_jit_context: %s\n"
"configs: %s\n",
Expand All @@ -234,7 +225,7 @@ std::string CallSignature::DebugString() const {
absl::StrJoin(dynamic_arg_layouts, ", ", layout_formatter),
absl::StrJoin(committed_args, ",", bool_formatter),
device != nullptr ? device->DebugString() : "nullptr",
OptionalDebugString(default_device), jax_enable_x64, jax_enable_memories,
OptionalDebugString(default_device), jax_enable_x64,
OptionalDebugString(global_extra_jit_context),
OptionalDebugString(thread_local_extra_jit_context),
absl::StrJoin(configs, ", ", py_object_formatter));
Expand All @@ -253,9 +244,6 @@ bool CallSignature::operator==(const CallSignature& other) const {
if (jax_enable_x64 != other.jax_enable_x64) {
return false;
}
if (jax_enable_memories != other.jax_enable_memories) {
return false;
}
if (committed_args != other.committed_args) {
return false;
}
Expand Down Expand Up @@ -387,16 +375,12 @@ void BuildJaxjitSubmodule(nb::module_& m) {
nb::class_<JitState> jit_state_(jitlib, "JitState");
jit_state_.def_rw("disable_jit", &JitState::disable_jit, nb::arg().none());
jit_state_.def_rw("enable_x64", &JitState::enable_x64, nb::arg().none());
jit_state_.def_rw("enable_memories", &JitState::enable_memories,
nb::arg().none());
jit_state_.def_rw("default_device", &JitState::default_device,
nb::arg().none());
jit_state_.def_rw("extra_jit_context", &JitState::extra_jit_context,
nb::arg().none());
jit_state_.def_rw("post_hook", &JitState::post_hook, nb::arg().none());

GetEnableMemories = +[] { return FetchMemoriesFlag(); };

jitlib.def(
"global_state", [&]() { return &GlobalJitState(); },
nb::rv_policy::reference);
Expand Down
2 changes: 0 additions & 2 deletions xla/python/jax_jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ struct JitState {

std::optional<bool> disable_jit;
std::optional<bool> enable_x64;
std::optional<bool> enable_memories;

// Used to manually set the default device jax should use. May be unset even
// in global state, indicating there is no manual override.
Expand Down Expand Up @@ -205,7 +204,6 @@ struct CallSignature {
// This is not the case for PMAP, and is set to `nullptr`.
xla::PjRtDevice* device = nullptr;
bool jax_enable_x64;
bool jax_enable_memories = false;

// For JIT on PJIT, we need to fallback to python whenever default_device
// changes.
Expand Down
1 change: 0 additions & 1 deletion xla/python/pjit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,6 @@ absl::Status PjitFunction::ComputeCallSignature(

signature.default_device = GetDefaultDevice();
signature.jax_enable_x64 = jax_enable_x64;
signature.jax_enable_memories = GetEnableMemories();

auto& dynamic_arg_signatures = signature.dynamic_arg_signatures;
dynamic_arg_signatures.reserve(flat_dynamic_args.size());
Expand Down
2 changes: 1 addition & 1 deletion xla/python/py_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ PyArray PyArray::MakeFromSingleDeviceArray(
auto dtype = IfrtDtypeToDtypeWithTokenCanonicalization(key.dtype).value();
const ifrt::MemoryKind memory_kind = ifrt_array->sharding().memory_kind();
nb::object py_memory_kind =
(jax::GetEnableMemories() && memory_kind.memory_kind().has_value())
(memory_kind.memory_kind().has_value())
? nb::object(nb::str(memory_kind.memory_kind()->data(),
memory_kind.memory_kind()->size()))
: nb::none();
Expand Down
6 changes: 0 additions & 6 deletions xla/python/py_device_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,6 @@ void PyDeviceList::PopulateMemoryKindInfoForDuckTypedDevices() {
}

absl::StatusOr<nb::tuple> PyDeviceList::MemoryKinds() {
if (!GetEnableMemories()) {
return nb::tuple();
}
if (!memory_kind_info_.has_value()) {
PopulateMemoryKindInfo();
}
Expand All @@ -409,9 +406,6 @@ absl::StatusOr<nb::tuple> PyDeviceList::MemoryKinds() {
}

absl::StatusOr<nb::object> PyDeviceList::DefaultMemoryKind() {
if (!GetEnableMemories()) {
return nb::none();
}
if (!memory_kind_info_.has_value()) {
PopulateMemoryKindInfo();
}
Expand Down
11 changes: 0 additions & 11 deletions xla/python/sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,6 @@ namespace jax {

namespace nb = nanobind;

bool (*GetEnableMemories)() = +[] {
static bool fetch_memory_kind_on_executable = [] {
char* v = getenv("JAX_ENABLE_MEMORIES");
if (v == nullptr || *v == '\0') {
return false;
}
return true;
}();
return fetch_memory_kind_on_executable;
};

nb::object CheckAndCanonicalizeMemoryKind(
nb::object memory_kind,
const xla::nb_class_ptr<PyDeviceList>& device_list) {
Expand Down
2 changes: 0 additions & 2 deletions xla/python/sharding.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ class Sharding {
std::optional<int> num_devices_;
};

extern bool (*GetEnableMemories)();

// Checks if the memory kind is valid, and canonicalizes the
// memory kind to default memory on backends that support memories.
nanobind::object CheckAndCanonicalizeMemoryKind(
Expand Down
1 change: 0 additions & 1 deletion xla/python/xla_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
xla_client._xla.jax_jit.set_thread_local_state_initialization_callback(
lambda: None
)
xla_client._xla.jax_jit.global_state().enable_memories = False

bfloat16 = xla_client.bfloat16
# TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
Expand Down
1 change: 0 additions & 1 deletion xla/python/xla_extension/jax_jit.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ Device = xla_extension.Device
class JitState:
disable_jit: Optional[bool]
enable_x64: Optional[bool]
enable_memories: Optional[bool]
default_device: Optional[Any]
extra_jit_context: Optional[Any]
post_hook: Optional[Callable[..., Any]]
Expand Down

0 comments on commit 052f17f

Please sign in to comment.