Skip to content

Commit

Permalink
updated.
Browse files Browse the repository at this point in the history
  • Loading branch information
tanyuqian committed Nov 5, 2023
1 parent 905bc82 commit 72e5bd5
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 29 deletions.
98 changes: 96 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

Redco is a lightweight and user-friendly tool designed to automate distributed training and inference for large models while simplifying the ML pipeline development process without necessitating MLSys expertise from users.

Check out our [Tech Report Here](https://arxiv.org/pdf/2310.16355.pdf) for details!

* Redco allows for the simple implementation of distributed training and inference, eliminating the need for additional coding efforts or complex configurations, but still exhibits efficiency comparable to the most advanced model parallel tools.
* Redco enables customization of arbitrary ML pipelines within three functions, eliminating repetitive ans boilerplate coding, such as multi-host related processing, etc. We demonstrate that this mechanism is widely applicable to various ML algorithms

Expand All @@ -19,8 +21,8 @@ Redco is a lightweight and user-friendly tool designed to automate distributed t

#### Install Jax & Flax
```
pip install --upgrade jax[cuda11_pip]==0.4.13 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade flax==0.7.0
pip install --upgrade jax[cuda11_pip]==0.4.13 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
Jax version (`==0.4.13`) and Flax version (`==0.7.0`) can be flexible, as long as they match your CUDA/CUDNN version.
Besides, the Flax modeling in the HuggingFace implementation sometimes doesn't support recent Jax & Flax versions.
Expand All @@ -31,7 +33,12 @@ If you are using TPU/CPU/AMD/Apple, see [here](https://github.com/google/jax#ins
```
pip install redco
```

For the most up-to-date version:
```
git clone https://github.com/tanyuqian/redco.git
cd redco
pip install -e .
```


### Examples
Expand All @@ -58,6 +65,93 @@ The table below shows runnable model LLM finetuning on different kinds of server

Go to [example/language_modeling](examples%2Flanguage_modeling) and [examples/text_to_text](examples%2Ftext_to_text) to try them out!


### Quick Tutorial

Below is a template code to customize an arbitrary distributed training pipeline with redco.

* No need to be a jax expert: `numpy` is pretty enough
* No need MLSys knowledge: only specify a number `n_model_shards` to split your model
* ONLY NEED to focus on your algorithm design!

After checking out our [text classification example (glue_main.py)](examples/classification_regression/glue_main.py), you'll be an expert of redco!

```python
def collate_fn(examples, ...):
# from raw examples to model inputs, e.g., tokenization
return {'input_ids': input_ids, 'labels': labels}


def loss_fn(train_rng, state, params, batch, is_training, ...):
# from model inputs defined in collate_fn, run the model and get the loss, e.g., cross_entropy
logits = model(input_ids=batch['input_ids'], params=params)
loss = cross_entropy(logits, batch['labels'])
return loss


def pred_fn(pred_rng, params, batch, model, gen_kwargs):
# from model inputs defined in collate_fn, run the model and get predictions, e.g., beam search
batch_preds = model.generate(input_ids=batch['input_ids'],params=params)
return batch_preds


def output_fn(batch_preds, tokenizer):
# (optional) post process of output tensors, e.g., decode output_ids to text
return tokenizer.batch_decode(batch_preds)


def eval_metric_fn(examples, preds):
# (optional) given test examples and predictions, calculate evaluation metrics, e.g., Rouge-L
return rouge_scorer.compute(
predictions=preds,
references=[example['target'] for example in examples],
rouge_types=['rouge1', 'rouge2', 'rougeL'])

# define seed, workdir, tensorboard, wandb, multi-host env, etc.
deployer = redco.Deployer(
jax_seed=jax_seed, # randomness control
n_model_shards=n_model_shards, # how many pieces to split the model (the only number needed for model parallelism)
workdir=workdir, run_tensorboard=True, run_wandb=True, # logging utils
host0_address='111.222.333.444', n_processes=2 # setup multi-host env
)

train_examples, valid_examples = load_dataset(...) # load dataset into python-list
model, params = FlaxModel() # a model defined in flax, e.g., transformers.FlaxT5ForConditionalGeneration()
optimizer = optax.adam(lr=0.001) # a optimizer defined in optax

# define redco.Trainer
trainer = redco.Trainer(
deployer=deployer,
collate_fn=collate_fn,
loss_fn=loss_fn,
params=params,
optimizer=optimizer,
params_sharding_rules=deployer.get_sharding_rules(params) # automatically generated model parallelism
)

# define redco.Predictor for prediction and evaluation during training
predictor = trainer.get_default_predictor(
pred_fn=pred_fn, output_fn=output_fn)

# pass in your training config and run the training
trainer.fit(
train_examples=train_examples,
per_device_batch_size=per_device_batch_size,
n_epochs=n_epochs,
eval_examples=valid_examples,
eval_per_device_batch_size=eval_per_device_batch_size,
eval_loss=True, # if compute loss on eval_examples after each epoch
eval_predictor=predictor, # run prediction on eval_examples after each epoch
eval_metric_fn=eval_metric_fn, # eval_metric_fn above
eval_sanity_check=True,
save_every_ckpt=False,
save_last_ckpt=True,
save_argmin_ckpt_by_metrics=None,
save_argmax_ckpt_by_metrics=['rouge-L'], # save the model with the best rouge-L score defined in eval_metric_fn
save_opt_states=True)
```


## Acknowledgement


Expand Down
9 changes: 9 additions & 0 deletions redco/deployers/deployer.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@ def load_params(self, filepath):
def load_opt_state(self, ckpt_dir, desc, target):
if self._mesh is None:
filepath = f'{ckpt_dir}/opt_state_{desc}.msgpack'
self.log_info(f'Skip loading opt_state (No file {filepath})')
if not os.path.exists(filepath):
return None

opt_state = msgpack_restore(open(filepath, 'rb').read())
opt_state = from_state_dict(target=target, state=opt_state)
opt_state = replicate(opt_state)
Expand All @@ -255,9 +259,14 @@ def load_opt_state(self, ckpt_dir, desc, target):
ckpt_process_idx = jax.process_index() % n_processes_per_model
filepath = (f'{ckpt_dir}/opt_state_{desc}'
f'_process_{ckpt_process_idx}.msgpack')
if not os.path.exists(filepath):
self.log_info(f'Skip loading opt_state (No file {filepath})')
return None

opt_state = msgpack_restore(open(filepath, 'rb').read())
opt_state = from_state_dict(target=target, state=opt_state)

self.log_info(f'opt_state loaded from {filepath}.')
return opt_state

def save_params(self,
Expand Down
83 changes: 56 additions & 27 deletions redco/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,18 @@ def __init__(self,
optimizer=self._optimizer,
step=self._last_ckpt_info['last_step'])

self._state = self._state.replace(
opt_state=self._deployer.load_opt_state(
ckpt_dir=f'{self.workdir}/ckpts',
desc=desc,
target=self._state.opt_state))
opt_state = self._deployer.load_opt_state(
ckpt_dir=f'{self.workdir}/ckpts',
desc=desc,
target=self._state.opt_state)
if opt_state is not None:
self._state = self._state.replace(opt_state=opt_state)

self._deployer._rng = jnp.load(
f'{self.workdir}/ckpts/rng_{desc}.npy')

self._deployer.log_info(
'detect last_ckpt \"{last_desc}\",'
'loaded last_ckpt \"{last_desc}\",'
' last_step={last_step},'
' last_epoch_idx={last_epoch_idx}'.format(
**self._last_ckpt_info))
Expand Down Expand Up @@ -217,10 +218,12 @@ def fit(self,
eval_loss=True,
eval_predictor=None,
eval_metric_fn=None,
eval_sanity_check=True,
save_every_ckpt=False,
save_last_ckpt=False,
save_argmin_ckpt_by_metrics=None,
save_argmax_ckpt_by_metrics=None,
save_last_ckpt=False):
save_opt_states=True):
if save_argmax_ckpt_by_metrics is None:
save_argmax_ckpt_by_metrics = []
if save_argmin_ckpt_by_metrics is None:
Expand All @@ -237,6 +240,38 @@ def fit(self,
f'{self.workdir}/max_metrics.json'))
self._deployer.log_info(max_metrics, title='Detected max_metrics')

if eval_sanity_check:
rng_backup = self._deployer._rng
_, eval_global_batch_size = self._deployer.process_batch_size(
per_device_batch_size=eval_per_device_batch_size)

if eval_loss:
self.eval_loss(
examples=eval_examples[:eval_global_batch_size],
per_device_batch_size=eval_per_device_batch_size,
desc=f'Sanity check')
self._deployer.log_info(
'Sanity check (for evaluation loss) passed.')

if eval_predictor is not None:
preds = eval_predictor.predict(
examples=eval_examples[:eval_global_batch_size],
params=self.params,
params_meshed=(self._deployer.mesh is not None),
per_device_batch_size=eval_per_device_batch_size,
desc=f'Sanity check')
self._deployer.log_info(
'Sanity check (for prediction) passed.')

if eval_metric_fn is not None:
eval_metric_fn(
examples=eval_examples[:eval_global_batch_size],
preds=preds)
self._deployer.log_info(
'Sanity check (for prediction evaluation) passed.')

self._deployer._rng = rng_backup

for epoch_idx in range(
self._last_ckpt_info['last_epoch_idx'] + 1, n_epochs):
if isinstance(train_examples, list):
Expand Down Expand Up @@ -295,6 +330,11 @@ def fit(self,
for key, value in eval_metrics.items()
}, step=self.step)

save_ckpt_kwargs = {
'epoch_idx': epoch_idx,
'ckpt_dir': f'{self.workdir}/ckpts',
'save_opt_state': save_opt_states}

for key in save_argmin_ckpt_by_metrics:
assert self.workdir is not None
if eval_metrics[key] < min_metrics.get(key, float('inf')):
Expand All @@ -306,10 +346,7 @@ def fit(self,
json.dump(max_metrics, open(
f'{self.workdir}/max_metrics.json', 'w'))

self.save_ckpt(
desc=f'min_{key}',
epoch_idx=epoch_idx,
ckpt_dir=f'{self.workdir}/ckpts')
self.save_ckpt(desc=f'min_{key}', **save_ckpt_kwargs)

for key in save_argmax_ckpt_by_metrics:
assert self.workdir is not None
Expand All @@ -322,23 +359,14 @@ def fit(self,
json.dump(max_metrics, open(
f'{self.workdir}/max_metrics.json', 'w'))

self.save_ckpt(
desc=f'max_{key}',
epoch_idx=epoch_idx,
ckpt_dir=f'{self.workdir}/ckpts')
self.save_ckpt(desc=f'max_{key}', **save_ckpt_kwargs)

if save_every_ckpt:
self.save_ckpt(
desc=f'epoch_{epoch_idx}',
epoch_idx=epoch_idx,
ckpt_dir=f'{self.workdir}/ckpts')
self.save_ckpt(desc=f'epoch_{epoch_idx}', **save_ckpt_kwargs)
elif save_last_ckpt:
self.save_ckpt(
desc=f'last',
epoch_idx=epoch_idx,
ckpt_dir=f'{self.workdir}/ckpts')
self.save_ckpt(desc=f'last', **save_ckpt_kwargs)

def save_ckpt(self, epoch_idx, desc, ckpt_dir):
def save_ckpt(self, epoch_idx, desc, ckpt_dir, save_opt_state):
if jax.process_index() == 0:
os.makedirs(ckpt_dir, exist_ok=True)

Expand All @@ -347,17 +375,18 @@ def save_ckpt(self, epoch_idx, desc, ckpt_dir):
ckpt_dir=ckpt_dir,
desc=desc,
params_sharding_rules=self._params_sharding_rules)
self._deployer.save_opt_state(
opt_state=self._state.opt_state, ckpt_dir=ckpt_dir, desc=desc)
self._deployer.save_rng(ckpt_dir=ckpt_dir, desc=desc)

if save_opt_state:
self._deployer.save_opt_state(
opt_state=self._state.opt_state, ckpt_dir=ckpt_dir, desc=desc)

if jax.process_index() == 0:
last_ckpt_info = {
'last_desc': desc,
'last_step': self.step,
'last_epoch_idx': epoch_idx
}

json.dump(last_ckpt_info, open(
f'{ckpt_dir}/last_ckpt_info.json', 'w'), indent=4)
self._deployer.log_info(f'{ckpt_dir}/last_ckpt_info.json updated.')
Expand Down

0 comments on commit 72e5bd5

Please sign in to comment.