Skip to content

Commit

Permalink
updated.
Browse files Browse the repository at this point in the history
  • Loading branch information
tanyuqian committed Aug 21, 2024
1 parent 89df0ca commit 692996e
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 7 deletions.
4 changes: 2 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

**Red Coast** (redco) is a lightweight and user-friendly tool designed to automate distributed training and inference for large models while simplifying the ML pipeline development process without necessitating MLSys expertise from users.

RedCoast supports *Large Models* + *Complex Algorithms*, in a *lightweight* and *user-friendly* manner:
RedCoast supports *Large Models* + *Complex Algorithms*, in a *lightweight* and *user-friendly* way:

* Large Models beyond Transformers, e.g, [Stable Diffusion](https://github.com/tanyuqian/redco/tree/master/examples/text_to_image), etc.
* Complex algorithms beyond cross entropy, e.g., [Meta Learning](https://github.com/tanyuqian/redco/tree/master/examples/meta_learning), etc.
* Complex algorithms beyond cross entropy, e.g., [Meta Learning](https://github.com/tanyuqian/redco/tree/master/examples/meta_learning), [DP Training](https://github.com/tanyuqian/redco/tree/master/examples/differential_private_training), etc.

With RedCoast, to define a ML pipeline, only three functions are needed:

Expand Down
6 changes: 3 additions & 3 deletions docs/mnist.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
This is a trivial MNIST example with RedCoast. Runnable by
This is a trivial MNIST example with RedCoast (`pip install redco==0.4.22`). Runnable by
```
python main.py
```
Expand Down Expand Up @@ -47,14 +47,14 @@ def collate_fn(examples):


# Loss function converting model inputs to a scalar loss
def loss_fn(train_rng, state, params, batch, is_training):
def loss_fn(rng, state, params, batch, is_training):
logits = state.apply_fn({'params': params}, batch['images'])
return optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['labels']).mean()


# Predict function converting model inputs to the model outputs
def pred_fn(pred_rng, params, batch, model):
def pred_fn(rng, params, batch, model):
accs = model.apply({'params': params}, batch['images']).argmax(axis=-1)
return {'acc': accs}

Expand Down
2 changes: 1 addition & 1 deletion redco/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = '0.4.21'
__version__ = '0.4.22'

from .deployers import *
from .trainers import *
Expand Down
3 changes: 3 additions & 0 deletions redco/deployers/deployer.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,9 @@ def gen_rng(self):
return new_rng

def gen_model_step_rng(self):
"""Get a new random number generator key for distributed model step and
update the random state.
"""
rng = self.gen_rng()
if self.mesh is None:
rng = jax.random.split(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

setup(
name="redco",
version="0.4.21",
version="0.4.22",
author="Bowen Tan",
packages=find_packages(),
install_requires=['jax', 'flax', 'optax'],
Expand Down

0 comments on commit 692996e

Please sign in to comment.