feat(examples): Add S3 auth by passing env variables
This commit allows S3 authentication by passing in the environment
variables `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, and
`AWS_ENDPOINT_URL`. If `AWS_ENDPOINT_URL` is not set in the
environment, `default_s3_read_endpoint` (used for `rb` mode) and
`default_s3_write_endpoint` (used for `wb` mode) fall back to the
CoreWeave defaults.
sangstar committed Jan 31, 2024
1 parent 71e7da3 commit 9126d09
Showing 1 changed file with 35 additions and 6 deletions.
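
For context, a minimal sketch of how a caller might supply these variables before running the example script; the credential values are placeholders, and the endpoint override is optional (when it is absent, the defaults named in the diff below apply):

    import os

    # Placeholder credentials; in practice these would come from the shell
    # environment or a secrets manager rather than being hardcoded.
    os.environ["AWS_ACCESS_KEY_ID"] = "AKIA-EXAMPLE"
    os.environ["AWS_SECRET_ACCESS_KEY"] = "example-secret"
    # Optional: a single URL that overrides both default endpoints
    os.environ["AWS_ENDPOINT_URL"] = "object.ord1.coreweave.com"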
41 changes: 35 additions & 6 deletions examples/hf_serialization.py
@@ -28,6 +28,16 @@

from tensorizer import TensorDeserializer, TensorSerializer, stream_io, utils

+s3_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID")
+s3_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
+
+if os.environ.get("AWS_ENDPOINT_URL"):
+    default_s3_read_endpoint = os.environ.get("AWS_ENDPOINT_URL")
+    default_s3_write_endpoint = os.environ.get("AWS_ENDPOINT_URL")
+else:
+    default_s3_read_endpoint = "accel-object.ord1.coreweave.com"
+    default_s3_write_endpoint = "object.ord1.coreweave.com"
+
# Setup logger
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
@@ -54,7 +64,13 @@ def check_file_exists(
    if os.path.exists(file):
        return True
    else:
-        with stream_io.open_stream(file, "rb") as f:
+        with stream_io.open_stream(
+            file,
+            "rb",
+            s3_access_key_id,
+            s3_secret_access_key,
+            default_s3_read_endpoint,
+        ) as f:
            if f.read(1) == b"":
                return False
            else:
@@ -101,18 +117,31 @@ def serialize_model(
    config_path = f"{dir_prefix}-config.json"
    if (not config_file_exists) or force:
        logger.info(f"Writing config to {config_path}")
-        with stream_io.open_stream(config_path, "wb") as f:
+        with stream_io.open_stream(
+            config_path,
+            "wb",
+            s3_access_key_id,
+            s3_secret_access_key,
+            default_s3_write_endpoint,
+        ) as f:
            if hasattr(config, "to_dict"):
                f.write(bytes(json.dumps(config.to_dict()), "utf-8"))
            elif isinstance(config, dict):
                f.write(bytes(json.dumps(config), "utf-8"))
-            f.close() ## Remove after PR
+            f.close() ## Remove after PR merged

    if (not weights_file_exists) or force:
        logger.info(f"Writing tensors to {dir_prefix}.tensors")
-        ts = TensorSerializer(f"{dir_prefix}.tensors")
-        ts.write_module(model)
-        ts.close()
+        with stream_io.open_stream(
+            f"{dir_prefix}.tensors",
+            "wb",
+            s3_access_key_id,
+            s3_secret_access_key,
+            default_s3_write_endpoint,
+        ) as f:
+            ts = TensorSerializer(f)
+            ts.write_module(model)
+            ts.close()


def load_model(
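
Taken together, here is a minimal end-to-end sketch of the pattern this commit introduces, assuming the environment variables above are set. The positional argument order for stream_io.open_stream (URI, mode, access key ID, secret key, endpoint) mirrors the calls in the diff; the bucket URI and the stand-in torch module are illustrative assumptions, not part of the example script:

    import os

    import torch
    from tensorizer import TensorSerializer, stream_io

    s3_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID")
    s3_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
    # Same fallback behavior as the module-level code in the diff
    read_endpoint = os.environ.get(
        "AWS_ENDPOINT_URL", "accel-object.ord1.coreweave.com"
    )
    write_endpoint = os.environ.get(
        "AWS_ENDPOINT_URL", "object.ord1.coreweave.com"
    )

    model = torch.nn.Linear(4, 4)  # stand-in for a real Hugging Face model

    # Serialize the module straight to object storage over the write endpoint
    with stream_io.open_stream(
        "s3://my-bucket/example.tensors",  # hypothetical bucket and key
        "wb",
        s3_access_key_id,
        s3_secret_access_key,
        write_endpoint,
    ) as f:
        ts = TensorSerializer(f)
        ts.write_module(model)
        ts.close()

    # Probe the object by reading one byte back, mirroring check_file_exists
    with stream_io.open_stream(
        "s3://my-bucket/example.tensors",
        "rb",
        s3_access_key_id,
        s3_secret_access_key,
        read_endpoint,
    ) as f:
        assert f.read(1) != b""

Calling ts.close() inside the with block follows the diff; the context manager also closes the underlying stream on exit.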
