Add --tf32 device arg for transparent float32 acceleration (#1066)
- --tf32 0|1 boolean device argument (sets torch.backends.cuda.matmul.allow_tf32),
  enabling transparent float32 acceleration at 10-bit mantissa precision
  (19 bits total). Defaults to true for backward compatibility with torch < 1.12.
  Training may be continued with a different --tf32 setting.

- device.init_device called by train, translate, and score

- allow torch 1.12 in requirements.txt
graehl authored Nov 5, 2022
1 parent c09a168 commit f852fbd
Showing 10 changed files with 58 additions and 19 deletions.
15 changes: 15 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,21 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

+## [3.1.26]
+
+### Added
+
+- `--tf32 0|1` boolean device argument (sets `torch.backends.cuda.matmul.allow_tf32`),
+  enabling transparent float32 acceleration at 10-bit mantissa precision
+  (19 bits total). Defaults to true for backward compatibility with torch < 1.12.
+  Training may be continued with a different `--tf32` setting.
+
+### Changed
+
+- `device.init_device` is now called by train, translate, and score.
+
+- Allow torch 1.12 in requirements.txt.
+
## [3.1.25]

### Changed
2 changes: 1 addition & 1 deletion requirements/requirements.txt
@@ -1,4 +1,4 @@
-torch>=1.10.0,<1.12.0
+torch>=1.10.0,<1.13.0
pyyaml>=5.1
numpy>1.16.0,<2.0.0
sacrebleu>=2.3.0
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

-__version__ = '3.1.25'
+__version__ = '3.1.26'
4 changes: 4 additions & 0 deletions sockeye/arguments.py
@@ -568,6 +568,10 @@ def add_device_args(params):
     device_params.add_argument('--env',
                                help='List of environment variables to be set before importing PyTorch. Separated by '
                                     '",", e.g. --env=OMP_NUM_THREADS=1,PYTORCH_JIT=0 etc.')
+    device_params.add_argument('--tf32',
+                               type=bool_str(),
+                               default=True,
+                               help='Globally enable transparent tf32 acceleration of float32 at the cost of '
+                                    'reducing precision to 10 bits.')


def add_vocab_args(params):
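Note: bool_str() is an existing Sockeye argparse helper and is not part of this diff. A minimal sketch of an equivalent parser, purely illustrative (the real helper may accept different spellings):

import argparse

def bool_str():
    """Argparse type accepting common boolean spellings, e.g. --tf32 0 or --tf32 true.

    Hypothetical stand-in for Sockeye's bool_str() helper.
    """
    def parse(value: str) -> bool:
        lowered = value.lower()
        if lowered in ('1', 'true', 'yes'):
            return True
        if lowered in ('0', 'false', 'no'):
            return False
        raise argparse.ArgumentTypeError(f'invalid boolean value: {value!r}')
    return parse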
4 changes: 3 additions & 1 deletion sockeye/constants.py
@@ -182,7 +182,7 @@
 # Arguments that may differ and still resume training
 ARGS_MAY_DIFFER = ["device_id", "device_ids", "overwrite_output", "use_tensorboard", "quiet", "align_plot_prefix",
                    "sure_align_threshold", "keep_last_params", "seed", "max_updates", "min_updates", "max_num_epochs",
-                   "min_num_epochs", "max_samples", "min_samples", "max_checkpoints", "max_seconds", "local_rank"]
+                   "min_num_epochs", "max_samples", "min_samples", "max_checkpoints", "max_seconds", "local_rank", "tf32"]

# Other argument constants
TRAINING_ARG_SOURCE = "--source"
@@ -285,6 +285,7 @@
 DTYPE_BF16 = 'bfloat16'
 DTYPE_FP16 = 'float16'
 DTYPE_FP32 = 'float32'
+DTYPE_TF32 = 'tf32'
 DTYPE_INT8 = 'int8'
 DTYPE_INT32 = 'int32'
LARGE_POSITIVE_VALUE = 99999999.
@@ -305,6 +306,7 @@
     DTYPE_FP32: LARGE_POSITIVE_VALUE,
     np.float32: LARGE_POSITIVE_VALUE,
     pt.float32: LARGE_POSITIVE_VALUE,
+    # with --tf32, rounds to 0b1011111011 * 1024 * 128 (10 bits precision) = 1.00008e8
 
     # Rounds to 1.0014e+08
     DTYPE_BF16: LARGE_POSITIVE_VALUE,
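As a quick arithmetic check of the rounding noted in the comment above (a sketch, assuming round-to-nearest with TF32's 10-bit mantissa):

# TF32 keeps 10 mantissa bits, so within the binade 2**26 <= x < 2**27
# representable values are spaced 2**(26 - 10) = 2**16 apart.
value = 99999999.0
step = 2 ** (26 - 10)
rounded = round(value / step) * step
print(rounded)                      # 100007936.0, i.e. ~1.00008e8
print(0b1011111011 * 1024 * 128)    # 100007936, matching the comment above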
25 changes: 25 additions & 0 deletions sockeye/device.py
@@ -0,0 +1,25 @@
+import torch
+import argparse
+from typing import Optional
+
+def init_device(args: argparse.Namespace, logger=None, local_rank: Optional[int] = None):
+    """
+    Return the requested torch device, optionally enabling tf32.
+
+    :param args: "Device Parameters" namespace; args.use_cpu will be set if CUDA is not available.
+    :param logger: optional logger (used via logger.info(msg)).
+    :param local_rank: optional LOCAL_RANK int for multi-GPU training.
+    """
+    if not torch.cuda.is_available():
+        if logger is not None:
+            logger.info("CUDA not available, using cpu")
+        args.use_cpu = True
+    device = torch.device('cpu') if args.use_cpu else torch.device('cuda', args.device_id if local_rank is None else local_rank)
+    if not args.use_cpu:
+        # Ensure that GPU operations use the correct device by default
+        torch.cuda.set_device(device)
+        if args.tf32:
+            if logger is not None:
+                logger.info("CUDA: allow tf32 (float32 but with 10 bits precision)")
+            torch.backends.cuda.matmul.allow_tf32 = True
+    return device
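The effect of allow_tf32 can be observed directly. A small sketch, illustrative only (a nonzero difference requires a TF32-capable GPU, i.e. Ampere or newer):

import torch

if torch.cuda.is_available():
    a = torch.randn(2048, 2048, device='cuda')
    b = torch.randn(2048, 2048, device='cuda')

    torch.backends.cuda.matmul.allow_tf32 = False
    exact = a @ b                  # full float32 matmul

    torch.backends.cuda.matmul.allow_tf32 = True
    fast = a @ b                   # may use TF32 tensor cores

    # On TF32-capable hardware the results differ slightly, because
    # TF32 rounds matmul inputs to 10-bit mantissas.
    print((exact - fast).abs().max().item())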
7 changes: 2 additions & 5 deletions sockeye/score.py
@@ -25,6 +25,7 @@
 from . import data_io
 from . import utils
 from .beam_search import CandidateScorer
+from .device import init_device
 from .log import setup_main_logger
 from .model import load_model
 from .output_handler import get_output_handler
@@ -50,11 +51,7 @@ def score(args: argparse.Namespace):

     utils.log_basic_info(args)
 
-    use_cpu = args.use_cpu
-    if not pt.cuda.is_available():
-        logger.info("CUDA not available, using cpu")
-        use_cpu = True
-    device = pt.device('cpu') if use_cpu else pt.device('cuda', args.device_id)
+    device = init_device(args, logger)
     logger.info(f"Scoring device: {device}")

model, source_vocabs, target_vocabs = load_model(args.model, device=device, dtype=args.dtype)
9 changes: 3 additions & 6 deletions sockeye/train.py
@@ -40,6 +40,7 @@
 except ImportError:
     pass
 
+
 from . import arguments
 from . import checkpoint_decoder
 from . import constants as C
@@ -56,6 +57,7 @@
 from . import utils
 from . import vocab
 from .config import Config
+from .device import init_device
 from .log import setup_main_logger
 from .utils import check_condition

@@ -996,12 +998,7 @@ def train(args: argparse.Namespace, custom_metrics_logger: Optional[Callable] =
logger.info("Adjusting maximum length to reserve space for a BOS/EOS marker. New maximum length: (%d, %d)",
max_seq_len_source, max_seq_len_target)

device = torch.device('cpu') if args.use_cpu \
else torch.device('cuda', utils.get_local_rank()) if utils.is_distributed() \
else torch.device('cuda', args.device_id)
if not args.use_cpu:
# Ensure that GPU operations use the correct device by default
torch.cuda.set_device(device)
device = init_device(args, logger, utils.get_local_rank() if utils.is_distributed() else None)
logger.info(f'Training Device: {device}')
utils.seed_rngs(args.seed)

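For reference, the local_rank argument above follows the usual torchrun pattern, in which each distributed worker binds to the GPU matching its local rank. A minimal sketch of that pattern (assumes torchrun or an equivalent launcher sets the LOCAL_RANK environment variable):

import os
import torch

# Each worker launched by torchrun sees its own LOCAL_RANK and should
# bind to the matching GPU before any CUDA work happens.
local_rank = int(os.environ.get('LOCAL_RANK', '0'))
device = torch.device('cuda', local_rank) if torch.cuda.is_available() else torch.device('cpu')
if device.type == 'cuda':
    torch.cuda.set_device(device)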
7 changes: 2 additions & 5 deletions sockeye/translate.py
@@ -24,6 +24,7 @@

 import torch as pt
 
+from .device import init_device
 from sockeye.lexicon import load_restrict_lexicon, RestrictLexicon
 from sockeye.log import setup_main_logger
 from sockeye.model import load_models
@@ -66,11 +67,7 @@ def run_translate(args: argparse.Namespace):
     output_handler = get_output_handler(args.output_type,
                                         args.output)
 
-    use_cpu = args.use_cpu
-    if not pt.cuda.is_available():
-        logger.info("CUDA not available, using cpu")
-        use_cpu = True
-    device = pt.device('cpu') if use_cpu else pt.device('cuda', args.device_id)
+    device = init_device(args, logger)
     logger.info(f"Translate Device: {device}")
models, source_vocabs, target_vocabs = load_models(device=device,
model_folders=args.models,
2 changes: 2 additions & 0 deletions test/unit/test_arguments.py
@@ -105,10 +105,12 @@ def test_quantize_args(test_params, expected_params):
 @pytest.mark.parametrize("test_params, expected_params", [
     ('', dict(device_id=0,
               use_cpu=False,
+              tf32=True,
               env=None)),
     ('--device-id 1 --use-cpu ',
      dict(device_id=1,
           use_cpu=True,
+          tf32=True,
           env=None))
 ])
def test_device_args(test_params, expected_params):
