Skip to content

Commit

Permalink
fix(examples): Fix all stream_io.open_stream to include S3 auth
Browse files Browse the repository at this point in the history
This commit ensures that all invocations of `stream_io.open_stream`
use the provided credentials, so the code runs independently
of any local `.s3cfg` configuration.
  • Loading branch information
sangstar committed Jan 31, 2024
1 parent 700a41a commit bf23565
Showing 1 changed file with 28 additions and 21 deletions.
49 changes: 28 additions & 21 deletions examples/hf_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@
# S3 endpoints: reads go through the accelerated object gateway, writes
# through the standard one.
default_s3_read_endpoint = "accel-object.ord1.coreweave.com"
default_s3_write_endpoint = "object.ord1.coreweave.com"

# Positional-argument bundles for stream_io.open_stream, unpacked at each
# call site as *s3_read_credentials / *s3_write_credentials so every stream
# carries explicit auth instead of relying on a local .s3cfg.
# NOTE(review): s3_access_key_id / s3_secret_access_key are defined above
# this excerpt — presumably read from the environment; confirm against the
# full file.
s3_read_credentials = (
s3_access_key_id,
s3_secret_access_key,
default_s3_read_endpoint,
)

s3_write_credentials = (
s3_access_key_id,
s3_secret_access_key,
default_s3_write_endpoint,
)

# Module-level logger for this script; INFO-level progress messages are
# emitted throughout serialization and loading.
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
Expand Down Expand Up @@ -68,9 +80,7 @@ def check_file_exists(
with stream_io.open_stream(
file,
"rb",
s3_access_key_id,
s3_secret_access_key,
default_s3_read_endpoint,
*s3_read_credentials,
) as f:
f.read(1)
return True
Expand Down Expand Up @@ -119,11 +129,7 @@ def serialize_model(
if (not config_file_exists) or force:
logger.info(f"Writing config to {config_path}")
with stream_io.open_stream(
config_path,
"wb+",
s3_access_key_id,
s3_secret_access_key,
default_s3_write_endpoint,
config_path, "wb+", *s3_write_credentials
) as f:
if hasattr(config, "to_dict"):
f.write(bytes(json.dumps(config.to_dict()), "utf-8"))
Expand All @@ -133,11 +139,7 @@ def serialize_model(
if (not weights_file_exists) or force:
logger.info(f"Writing tensors to {dir_prefix}.tensors")
with stream_io.open_stream(
f"{dir_prefix}.tensors",
"wb+",
s3_access_key_id,
s3_secret_access_key,
default_s3_write_endpoint,
f"{dir_prefix}.tensors", "wb+", *s3_write_credentials
) as f:
ts = TensorSerializer(f)
ts.write_module(model)
Expand Down Expand Up @@ -182,11 +184,7 @@ def load_model(
logger.info(f"Loading {tensors_uri}, {ram_usage}")

tensor_stream = stream_io.open_stream(
tensors_uri,
"rb",
s3_access_key_id,
s3_secret_access_key,
default_s3_read_endpoint,
tensors_uri, "rb", *s3_read_credentials
)

tensor_deserializer = TensorDeserializer(
Expand All @@ -199,7 +197,11 @@ def load_model(
temp_config_path = os.path.join(temp_dir, "config.json")
with open(temp_config_path, "wb") as temp_config:
logger.info(f"Loading {config_uri}, {ram_usage}")
temp_config.write(stream_io.open_stream(config_uri).read())
temp_config.write(
stream_io.open_stream(
config_uri, "rb", *s3_read_credentials
).read()
)
config = config_class.from_pretrained(temp_dir)
config.gradient_checkpointing = True
except ValueError:
Expand All @@ -212,7 +214,9 @@ def load_model(
else:
try:
config = json.loads(
stream_io.open_stream(config_uri).read().decode("utf-8")
stream_io.open_stream(config_uri, "rb", *s3_read_credentials)
.read()
.decode("utf-8")
)
except ValueError:
with open(config_uri, "r") as f:
Expand Down Expand Up @@ -385,7 +389,10 @@ def main():
parser.add_argument(
"--force",
action="store_true",
help="Force upload serialized tensors to output_prefix even if they already exist",
help=(
"Force upload serialized tensors to output_prefix even if they"
" already exist"
),
)
args = parser.parse_args()

Expand Down

0 comments on commit bf23565

Please sign in to comment.