Fix L1-norm case in weights_norm
`norm_fn=np.abs` would compute the L1 norm as `np.sqrt(np.sum(np.abs(x)))`, which is incorrect: the outer sqrt does not belong. `norm_fn=np.abs` now works correctly; the L2-norm case always worked correctly.

For the L2 norm, set `norm_fn = (np.sqrt, np.square) = (outer_fn, inner_fn)`, which computes `outer_fn(sum(inner_fn(x)))`. Note that `norm_fn=np.square` will **no longer compute the L2 norm correctly**.
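
For illustration, a minimal standalone sketch of the new convention (not code from this commit; `compute_norm` below is a hypothetical stand-in for the internal `_compute_norm` helper):

```python
import numpy as np

np.random.seed(0)
w = np.random.randn(4, 3)  # toy weight matrix

def compute_norm(w, norm_fn, axis=-1):
    # Reduce over every axis except `axis`. A (outer_fn, inner_fn) tuple yields
    # outer_fn(sum(inner_fn(w))); a single function yields sum(norm_fn(w)),
    # with no outer transform applied.
    axis = axis if axis != -1 else w.ndim - 1
    reduction_axes = tuple(ax for ax in range(w.ndim) if ax != axis)
    if isinstance(norm_fn, (tuple, list)):
        outer_fn, inner_fn = norm_fn
        return outer_fn(np.sum(inner_fn(w), axis=reduction_axes))
    return np.sum(norm_fn(w), axis=reduction_axes)

l2 = compute_norm(w, (np.sqrt, np.square))  # per-column L2 norms
l1 = compute_norm(w, np.abs)                # per-column L1 norms

assert np.allclose(l2, np.linalg.norm(w, axis=0))
assert np.allclose(l1, np.abs(w).sum(axis=0))
```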

Pardon the mishap.
OverLordGoldDragon authored Jun 10, 2020
1 parent 76e440a commit 41a114b
Showing 3 changed files with 14 additions and 7 deletions.
2 changes: 1 addition & 1 deletion see_rnn/__init__.py
@@ -8,4 +8,4 @@
 from .inspect_gen import *
 from .inspect_rnn import *

-__version__ = '1.14.3'
+__version__ = '1.14.4'
17 changes: 12 additions & 5 deletions see_rnn/inspect_gen.py
@@ -504,7 +504,8 @@ def _get_txt(perc, data, name):


 def weights_norm(model, _id, _dict=None, stat_fns=(np.max, np.mean),
-                 norm_fn=np.square, omit_names=None, axis=-1, verbose=0):
+                 norm_fn=(np.sqrt, np.square), omit_names=None, axis=-1,
+                 verbose=0):
     """Retrieves model layer weight matrix norms, as specified by `norm_fn`.
     Arguments:
         model: keras.Model/tf.keras.Model.
@@ -526,9 +527,11 @@ def weights_norm(model, _id, _dict=None, stat_fns=(np.max, np.mean),
         _dict: dict/None. If None, returns new dict. If dict, appends to it.
         stat_fns: functions list/tuple. Aggregate statistic to compute from
             normed weights.
-        norm_fn: function. Norm transform to apply to weights. Ex:
-            - np.square (l2 norm)
-            - np.abs (l1 norm)
+        norm_fn: inner function / (outer function, inner function). Norm
+            transform to apply to weights. Ex:
+            - (np.sqrt, np.square) (l2 norm)
+            - np.abs (l1 norm)
+            Computed as: `outer_fn(sum(inner_fn(x) for x in data))`.
         omit_names: str/str list. List of names (can be substring) of weights
             to omit from fetching.
         axis: int. Axis w.r.t. which compute the norm (collapsing all others).
@@ -584,7 +587,11 @@ def _compute_norm(w, norm_fn, axis=-1):
         axis = axis if axis != -1 else len(w.shape) - 1
         reduction_axes = tuple([ax for ax in range(len(w.shape))
                                 if ax != axis])
-        return np.sqrt(np.sum(norm_fn(w), axis=reduction_axes))
+        if isinstance(norm_fn, (tuple, list)):
+            outer_fn, inner_fn = norm_fn
+            return outer_fn(np.sum(inner_fn(w), axis=reduction_axes))
+        else:
+            return np.sum(norm_fn(w), axis=reduction_axes)

     def _append(stats_all, l2_stats, w_idx, l_name):
         if len(stats_all[l_name]) < w_idx + 1:
2 changes: 1 addition & 1 deletion tests/test_all.py
@@ -199,7 +199,7 @@ def test_misc(): # test miscellaneous functionalities
     model.train_on_batch(x, y, sw)

     weights_norm(model, 'gru', omit_names='bias', verbose=1)
-    weights_norm(model, ['gru', 1, (1, 1)])
+    weights_norm(model, ['gru', 1, (1, 1)], norm_fn=np.abs)
     stats = weights_norm(model, 'gru')
     weights_norm(model, 'gru', _dict=stats)

