Skip to content

Commit

Permalink
#13554: Weaken symbols using elf loader (#13863)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-TT authored Oct 18, 2024
1 parent aac9ebe commit 3423b44
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 48 deletions.
23 changes: 12 additions & 11 deletions tt_metal/jit_build/build.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <fstream>
#include <iomanip>
#include <iostream>
#include <span>
#include <sstream>
#include <string>
#include <thread>
Expand All @@ -18,8 +19,9 @@
#include "jit_build/kernel_args.hpp"
#include "tools/profiler/common.hpp"
#include "tools/profiler/profiler_state.hpp"
#include "tt_metal/impl/kernels/kernel.hpp"
#include "tt_metal/impl/dispatch/command_queue_interface.hpp"
#include "tt_metal/impl/kernels/kernel.hpp"
#include "tt_metal/llrt/tt_elffile.hpp"

namespace fs = std::filesystem;

Expand Down Expand Up @@ -521,7 +523,7 @@ void JitBuildState::link(const string& log_file, const string& out_dir) const {
if (!this->is_fw_) {
string weakened_elf_name =
env_.out_firmware_root_ + this->target_name_ + "/" + this->target_name_ + "_weakened.elf";
cmd += " -Xlinker \"--just-symbols=" + weakened_elf_name + "\" ";
cmd += "-Wl,--just-symbols=" + weakened_elf_name + " ";
}

cmd += "-o " + out_dir + this->target_name_ + ".elf";
Expand All @@ -537,16 +539,15 @@ void JitBuildState::link(const string& log_file, const string& out_dir) const {
// strong so to propogate link addresses
void JitBuildState::weaken(const string& log_file, const string& out_dir) const {
ZoneScoped;
string cmd;
cmd = "cd " + out_dir + " && ";
cmd += env_.objcopy_;
cmd += " --wildcard --weaken-symbol \"*\" --weaken-symbol \"!__fw_export_*\" " + this->target_name_ + ".elf " +
this->target_name_ + "_weakened.elf";

log_debug(tt::LogBuildKernels, " objcopy cmd: {}", cmd);
if (!tt::utils::run_command(cmd, log_file, false)) {
build_failure(this->target_name_, "objcopy weaken", cmd, log_file);
}
std::string pathname_in = out_dir + target_name_ + ".elf";
std::string pathname_out = out_dir + target_name_ + "_weakened.elf";

ll_api::ElfFile elf;
elf.ReadImage(pathname_in);
static std::string_view const strong_names[] = {"__fw_export_*"};
elf.WeakenDataSymbols(strong_names);
elf.WriteImage(pathname_out);
}

void JitBuildState::extract_zone_src_locations(const string& log_file) const {
Expand Down
179 changes: 142 additions & 37 deletions tt_metal/llrt/tt_elffile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

#include "tt_elffile.hpp"

#include <algorithm>
#include <array>

#include "common/assert.hpp"
// C
#include <errno.h>
Expand Down Expand Up @@ -45,84 +48,83 @@ using namespace ll_api;

class ElfFile::Impl {
private:
std::span<Elf32_Phdr const> phdrs_;
std::span<Elf32_Shdr const> shdrs_;
std::span<Elf32_Phdr> phdrs_;
std::span<Elf32_Shdr> shdrs_;
std::string const &path_;
ElfFile &owner_;

private:
ElfFile &owner_;
class Weakener;

public:
Impl(ElfFile &owner, std::string const &path) : owner_(owner), path_(path) {}
~Impl() = default;

public:
void LoadImage();
void WeakenDataSymbols(std::span<std::string_view const> strong_names);

private:
Elf32_Ehdr const &GetHeader() const { return *reinterpret_cast<Elf32_Ehdr const *>(GetContents().data()); }
std::span<Elf32_Phdr const> GetPhdrs() const { return phdrs_; }
std::span<Elf32_Shdr const> GetShdrs() const { return shdrs_; }
Elf32_Shdr const &GetShdr(unsigned ix) const { return shdrs_[ix]; }
std::vector<Segment> &GetSegments() const { return owner_.segments_; }
std::span<std::byte> &GetContents() const { return owner_.contents_; }
std::span<std::byte> GetContents(Elf32_Phdr const &phdr) {
[[nodiscard]] auto GetHeader() const -> Elf32_Ehdr const & { return *ByteOffset<Elf32_Ehdr>(GetContents().data()); }
[[nodiscard]] auto GetPhdrs() const -> std::span<Elf32_Phdr const> { return phdrs_; }
[[nodiscard]] auto GetShdrs() const -> std::span<Elf32_Shdr const> { return shdrs_; }
[[nodiscard]] auto GetShdr(unsigned ix) const -> Elf32_Shdr const & { return shdrs_[ix]; }
[[nodiscard]] auto GetSegments() const -> std::vector<Segment> & { return owner_.segments_; }
[[nodiscard]] auto GetContents() const -> std::span<std::byte> & { return owner_.contents_; }
[[nodiscard]] auto GetContents(Elf32_Phdr const &phdr) const -> std::span<std::byte> {
return GetContents().subspan(phdr.p_offset, phdr.p_filesz);
}
std::span<std::byte> GetContents(Elf32_Shdr const &shdr) {
[[nodiscard]] auto GetContents(Elf32_Shdr const &shdr) const -> std::span<std::byte> {
return GetContents().subspan(shdr.sh_offset, shdr.sh_size);
}
char const *GetString(size_t offset, unsigned ix) {
if (ix >= GetShdrs().size())
bad:
return "*bad*";
auto &shdr = GetShdr(ix);
if (shdr.sh_type != SHT_STRTAB)
goto bad;
auto strings = GetContents(GetShdr(ix));
if (offset >= strings.size())
goto bad;
return ByteOffset<char const>(strings.data(), offset);
}
char const *GetName(Elf32_Shdr const &shdr) { return GetString(shdr.sh_name, GetHeader().e_shstrndx); }
std::span<Elf32_Sym> GetSymbols(Elf32_Shdr const &shdr) {
[[nodiscard]] auto GetString(size_t offset, Elf32_Shdr const &shdr) const -> char const * {
return ByteOffset<char const>(GetContents(shdr).data(), offset);
}
[[nodiscard]] auto GetName(Elf32_Shdr const &shdr) const -> char const * {
return GetString(shdr.sh_name, GetShdr(GetHeader().e_shstrndx));
}
[[nodiscard]] auto GetSymbols(Elf32_Shdr const &shdr) const -> std::span<Elf32_Sym> {
auto section = GetContents(shdr);
return std::span(ByteOffset<Elf32_Sym>(section.data()), section.size() / shdr.sh_entsize);
}
char const *GetName(Elf32_Sym const &sym, unsigned lk) { return GetString(sym.st_name, lk); }
std::span<Elf32_Rela> GetRelocations(Elf32_Shdr const &shdr) {
[[nodiscard]] auto GetName(Elf32_Sym const &sym, unsigned link) const -> char const * {
return GetString(sym.st_name, GetShdr(link));
}
[[nodiscard]] auto GetRelocations(Elf32_Shdr const &shdr) const -> std::span<Elf32_Rela> {
auto section = GetContents(shdr);
return std::span(ByteOffset<Elf32_Rela>(section.data()), section.size() / shdr.sh_entsize);
}

static bool IsInSegment(Segment const &segment, Elf32_Shdr const &shdr) {
[[nodiscard]] static bool IsInSegment(Segment const &segment, Elf32_Shdr const &shdr) {
// Remember, Segments use word_t sizes
return shdr.sh_flags & SHF_ALLOC && shdr.sh_addr >= segment.address &&
shdr.sh_addr + shdr.sh_size <=
segment.address + (segment.contents.size() + segment.bss) * sizeof (word_t);
}
bool IsInSegment(unsigned ix, Elf32_Shdr const &shdr) const { return IsInSegment(GetSegments()[ix], shdr); }
bool IsInText(Elf32_Shdr const &shdr) const { return IsInSegment(GetSegments().front(), shdr); };
int GetSegmentIx(Elf32_Shdr const &shdr) const {
[[nodiscard]] bool IsInSegment(unsigned _ix, Elf32_Shdr const &shdr) const {
return IsInSegment(GetSegments()[_ix], shdr);
}
[[nodiscard]] bool IsInText(Elf32_Shdr const &shdr) const { return IsInSegment(GetSegments().front(), shdr); };
[[nodiscard]] int GetSegmentIx(Elf32_Shdr const &shdr) const {
for (unsigned ix = GetSegments().size(); ix--;)
if (IsInSegment(ix, shdr))
return ix;
return -1;
};
bool IsTextSymbol(Elf32_Sym const &symbol) const {
[[nodiscard]] bool IsTextSymbol(Elf32_Sym const &symbol) const {
return symbol.st_shndx < GetShdrs().size() && IsInText(GetShdr(symbol.st_shndx));
}
bool IsDataSymbol(Elf32_Sym const &symbol) const {
[[nodiscard]] bool IsDataSymbol(Elf32_Sym const &symbol) const {
return symbol.st_shndx < GetShdrs().size() && GetSegmentIx(GetShdr(symbol.st_shndx)) > 0;
}

private:
template <typename T = std::byte>
static T *ByteOffset(std::byte *base, size_t offset = 0) {
[[nodiscard]] static T *ByteOffset(std::byte *base, size_t offset = 0) {
return reinterpret_cast<T *>(base + offset);
}
template <typename T = std::byte>
static T const *ByteOffset(std::byte const *base, size_t offset = 0) {
[[nodiscard]] static T const *ByteOffset(std::byte const *base, size_t offset = 0) {
return reinterpret_cast<T const *>(base + offset);
}
};
Expand Down Expand Up @@ -156,6 +158,23 @@ void ElfFile::ReadImage(std::string const &path) {
pimpl_->LoadImage();
}

void ElfFile::WriteImage(std::string const &path) {
// open is an os-defined varadic function, it the API to use.
int file_descriptor = open(
path.c_str(),
O_WRONLY | O_CLOEXEC | O_CREAT | O_TRUNC,
S_IRUSR | S_IWUSR | S_IRGRP | S_IWGRP | S_IROTH | S_IWOTH);
bool failed = file_descriptor < 0;
if (!failed) {
failed = write(file_descriptor, contents_.data(), contents_.size()) != ssize_t(contents_.size());
close(file_descriptor);
}
if (failed)
TT_THROW("{}: cannot map elf file into memory: {}", path, strerror(errno));
}

void ElfFile::WeakenDataSymbols(std::span<std::string_view const> strong) { pimpl_->WeakenDataSymbols(strong); }

void ElfFile::Impl::LoadImage() {
auto &hdr = GetHeader();

Expand All @@ -181,11 +200,11 @@ void ElfFile::Impl::LoadImage() {
if (!hdr.e_phoff || hdr.e_phoff & (sizeof(address_t) - 1) || hdr.e_phentsize != sizeof(Elf32_Phdr) ||
(hdr.e_phoff + hdr.e_phnum * sizeof(Elf32_Phdr) > GetContents().size()))
TT_THROW("{}: PHDRS are missing or malformed", path_);
phdrs_ = std::span(ByteOffset<Elf32_Phdr const>(GetContents().data(), hdr.e_phoff), hdr.e_phnum);
phdrs_ = std::span(ByteOffset<Elf32_Phdr>(GetContents().data(), hdr.e_phoff), hdr.e_phnum);
if (!hdr.e_shoff || hdr.e_shoff & (sizeof(address_t) - 1) || hdr.e_shentsize != sizeof(Elf32_Shdr) ||
(hdr.e_shoff + hdr.e_shnum * sizeof(Elf32_Shdr) > GetContents().size()))
TT_THROW("{}: sections are missing or malformed", path_);
shdrs_ = std::span(ByteOffset<Elf32_Shdr const>(GetContents().data(), hdr.e_shoff), hdr.e_shnum);
shdrs_ = std::span(ByteOffset<Elf32_Shdr>(GetContents().data(), hdr.e_shoff), hdr.e_shnum);
if (!hdr.e_shstrndx || hdr.e_shstrndx >= GetShdrs().size())
TT_THROW("{}: string table is missing or malformed", path_);

Expand Down Expand Up @@ -236,3 +255,89 @@ void ElfFile::Impl::LoadImage() {
std::rotate(segments.begin(), text, std::next(text, 1));
}
}

class ElfFile::Impl::Weakener {
enum { LOCAL, GLOBAL, HWM };

Elf32_Shdr const &shdr_;
std::span<Elf32_Sym> syms_in_;
std::vector<unsigned> remap_;
std::vector<Elf32_Sym> syms_out_[HWM];

public:
Weakener(Elf32_Shdr const &shdr, std::span<Elf32_Sym> symbols) :
shdr_(shdr), syms_in_(symbols.subspan(shdr.sh_info)) {
unsigned reserve = syms_in_.size() - shdr_.sh_info;
remap_.reserve(reserve);
std::ranges::for_each(syms_out_, [=](std::vector<Elf32_Sym> &syms) { syms.reserve(reserve); });
}

void WeakenOrLocalizeSymbols(Impl &impl, std::span<std::string_view const> strong) {
auto name_matches = [](std::string_view name, std::span<std::string_view const> list) {
return std::ranges::any_of(list, [&](std::string_view pattern) {
return pattern.back() == '*' ? name.starts_with(pattern.substr(0, pattern.size() - 1))
: name == pattern;
});
};

// Weaken or hide globals
for (auto &sym : syms_in_) {
auto kind = GLOBAL;
if ((ELF32_ST_BIND(sym.st_info) == STB_GLOBAL || ELF32_ST_BIND(sym.st_info) == STB_WEAK) &&
!name_matches(impl.GetName(sym, shdr_.sh_link), strong)) {
unsigned bind = impl.IsDataSymbol(sym) ? STB_WEAK : STB_LOCAL;
sym.st_info = ELF32_ST_INFO(bind, ELF32_ST_TYPE(sym.st_info));
if (bind == STB_LOCAL)
kind = LOCAL;
}
remap_.push_back(syms_out_[kind].size() ^ (kind == GLOBAL ? ~0U : 0U));
syms_out_[kind].push_back(sym);
}
}

void UpdateRelocations(std::span<Elf32_Rela> relocs) {
// Adjust relocs using remap array.
const unsigned num_locals = shdr_.sh_info;
for (auto &reloc : relocs) {
unsigned sym_ix = ELF32_R_SYM(reloc.r_info);
if (sym_ix < num_locals)
continue;

sym_ix = remap_[sym_ix - num_locals];
if (bool(sym_ix & (~0U ^ (~0U >> 1))))
sym_ix = ~sym_ix + syms_out_[LOCAL].size();
reloc.r_info = ELF32_R_INFO(ELF32_R_TYPE(reloc.r_info), sym_ix + num_locals);
}
}

void RewriteSymbols() {
// Rewrite the symbols
std::copy(syms_out_[LOCAL].begin(), syms_out_[LOCAL].end(), syms_in_.begin());
const_cast<Elf32_Shdr &>(shdr_).sh_info += syms_out_[LOCAL].size();

std::copy(
syms_out_[GLOBAL].begin(),
syms_out_[GLOBAL].end(),
std::next(syms_in_.begin(), ssize_t(syms_out_[LOCAL].size())));
}
};

// Any global symbol matching STRONG is preserved.
// Any global symbol in a data-segment section is weakened
// Any other global symbol is made local
void ElfFile::Impl::WeakenDataSymbols(std::span<std::string_view const> strong) {
for (unsigned ix = GetShdrs().size(); bool(ix--);) {
auto &shdr = GetShdr(ix);
if (shdr.sh_type != SHT_SYMTAB || bool(shdr.sh_flags & SHF_ALLOC))
continue;

Weakener weakener(shdr, GetSymbols(shdr));
weakener.WeakenOrLocalizeSymbols(*this, strong);

for (auto const &relhdr : GetShdrs())
if (relhdr.sh_type == SHT_RELA && relhdr.sh_link == ix)
weakener.UpdateRelocations(GetRelocations(relhdr));

weakener.RewriteSymbols();
}
}
8 changes: 8 additions & 0 deletions tt_metal/llrt/tt_elffile.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ class ElfFile {
// Path must remain live throughout processing.
void ReadImage(std::string const &path);

// Write the (now-processed) elf file.
void WriteImage(std::string const &path);

// Weaken data symbols, remove all others. Keep STRONG_NAMES
// strong (can be non-data symbols). Names can be exact or simple
// globs ending in '*'.
void WeakenDataSymbols(std::span<std::string_view const> strong_names);

private:
class Impl;
// We can't use unique_ptr here, because the above move semantics
Expand Down

0 comments on commit 3423b44

Please sign in to comment.