Skip to content

Commit

Permalink
updated.
Browse files Browse the repository at this point in the history
  • Loading branch information
tanyuqian committed Nov 8, 2023
1 parent 72e5bd5 commit f36488c
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

Redco is a lightweight and user-friendly tool designed to automate distributed training and inference for large models while simplifying the ML pipeline development process without necessitating MLSys expertise from users.

Check out our [Tech Report Here](https://arxiv.org/pdf/2310.16355.pdf) for details!
Check out our [Tech Report Here](https://arxiv.org/pdf/2310.16355.pdf) for details! There is also a [quick tutorial](#quick-tutorial) below.

* Redco allows for the simple implementation of distributed training and inference, eliminating the need for additional coding efforts or complex configurations, but still exhibits efficiency comparable to the most advanced model parallel tools.
* Redco enables customization of arbitrary ML pipelines within three functions, eliminating repetitive ans boilerplate coding, such as multi-host related processing, etc. We demonstrate that this mechanism is widely applicable to various ML algorithms
Expand All @@ -19,16 +19,6 @@ Check out our [Tech Report Here](https://arxiv.org/pdf/2310.16355.pdf) for detai

### Installation

#### Install Jax & Flax
```
pip install --upgrade flax==0.7.0
pip install --upgrade jax[cuda11_pip]==0.4.13 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
Jax version (`==0.4.13`) and Flax version (`==0.7.0`) can be flexible, as long as they match your CUDA/CUDNN version.
Besides, the Flax modeling in the HuggingFace implementation sometimes doesn't support recent Jax & Flax versions.

If you are using TPU/CPU/AMD/Apple, see [here](https://github.com/google/jax#installation) for corresponding installation commands.

#### Install Redco
```
pip install redco
Expand All @@ -40,6 +30,19 @@ cd redco
pip install -e .
```

#### Adjust Jax & Flax versions
The command above would automatically install cpu version of jax, so the version of Jax need to be adjusted based on your device.

For example,
```
pip install --upgrade flax==0.7.0
pip install --upgrade jax[cuda11_pip]==0.4.13 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
Jax version (`==0.4.13`) and Flax version (`==0.7.0`) can be flexible, as long as they match your CUDA/CUDNN/NCCL version.
Besides, the Flax modeling in the HuggingFace implementation sometimes doesn't support the most recent Jax & Flax versions.

If you are using TPU/CPU/AMD/Apple, see [here](https://github.com/google/jax#installation) for corresponding installation commands.


### Examples

Expand Down Expand Up @@ -117,7 +120,7 @@ deployer = redco.Deployer(

train_examples, valid_examples = load_dataset(...) # load dataset into python-list
model, params = FlaxModel() # a model defined in flax, e.g., transformers.FlaxT5ForConditionalGeneration()
optimizer = optax.adam(lr=0.001) # a optimizer defined in optax
optimizer = adam(lr=0.001) # a optimizer defined in optax

# define redco.Trainer
trainer = redco.Trainer(
Expand Down

0 comments on commit f36488c

Please sign in to comment.