Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 18, 2024
1 parent bf405f2 commit be97992
Showing 1 changed file with 42 additions and 6 deletions.
48 changes: 42 additions & 6 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,21 +422,49 @@ def log_prob(
if dist.aggregate_probabilities is not None:
aggregate_probabilities_inp = dist.aggregate_probabilities
else:
# TODO: warning
warnings.warn(
f"aggregate_probabilities wasn't defined in the {type(self).__name__} instance. "
f"It couldn't be retrieved from the CompositeDistribution object either. "
f"Currently, the aggregate_probability will be `True` in this case but in a future release "
f"(v0.9) this will change and `aggregate_probabilities` will default to ``False`` such "
f"that log_prob will return a tensordict with the log-prob values. To silence this warning, "
f"pass `aggregate_probabilities` to the {type(self).__name__} constructor, to the distribution kwargs "
f"or to the log-prob method.",
category=DeprecationWarning,
)
aggregate_probabilities_inp = False
else:
aggregate_probabilities_inp = aggregate_probabilities
if inplace is None:
if dist.inplace is not None:
inplace = dist.inplace
else:
# TODO: warning
warnings.warn(
f"inplace wasn't defined in the {type(self).__name__} instance. "
f"It couldn't be retrieved from the CompositeDistribution object either. "
f"Currently, the `inplace` will be `True` in this case but in a future release "
f"(v0.9) this will change and `inplace` will default to ``False`` such "
f"that log_prob will return a new tensordict containing only the log-prob values. To silence this warning, "
f"pass `inplace` to the {type(self).__name__} constructor, to the distribution kwargs "
f"or to the log-prob method.",
category=DeprecationWarning,
)
inplace = True
if include_sum is None:
if dist.include_sum is not None:
include_sum = dist.include_sum
else:
# TODO: warning
warnings.warn(
f"include_sum wasn't defined in the {type(self).__name__} instance. "
f"It couldn't be retrieved from the CompositeDistribution object either. "
f"Currently, the `include_sum` will be `True` in this case but in a future release "
f"(v0.9) this will change and `include_sum` will default to ``False`` such "
f"that log_prob will return a new tensordict containing only the leaf log-prob values. "
f"To silence this warning, "
f"pass `include_sum` to the {type(self).__name__} constructor, to the distribution kwargs "
f"or to the log-prob method.",
category=DeprecationWarning,
)
include_sum = True
lp = dist.log_prob(
tensordict,
Expand All @@ -446,6 +474,7 @@ def log_prob(
)
if is_tensor_collection(lp) and aggregate_probabilities is None:
return lp.get(dist.log_prob_key)
return lp
else:
return dist.log_prob(tensordict.get(self.out_keys[0]))

Expand Down Expand Up @@ -1027,8 +1056,9 @@ def log_prob(
):
"""Returns the log-probability of the input tensordict.
If `return_composite` is ``True`` and the distribution is a :class:`~tensordict.nn.CompositeDistribution`,
this method will return the log-probability of the entire composite distribution.
If `self.return_composite` is ``True`` and the distribution is a :class:`~tensordict.nn.CompositeDistribution`,
or if any of :attr:`aggregate_probabilities`, :attr:`inplace` or :attr:`include_sum` this method will return
the log-probability of the entire composite distribution.
Otherwise, it will only consider the last probabilistic module in the sequence.
Expand Down Expand Up @@ -1069,7 +1099,13 @@ def log_prob(
tensordict_inp = tensordict
if dist is None:
dist = self.get_dist(tensordict_inp)
if self.return_composite and isinstance(dist, CompositeDistribution):
return_composite = (
self.return_composite
or (aggregate_probabilities is not None)
or (inplace is not None)
or (include_sum is not None)
)
if return_composite and isinstance(dist, CompositeDistribution):
# Check the values within the dist - if not set, choose defaults
if aggregate_probabilities is None:
if self.aggregate_probabilities is not None:
Expand Down

0 comments on commit be97992

Please sign in to comment.