Skip to content

Commit

Permalink
Wrong line
Browse files Browse the repository at this point in the history
  • Loading branch information
gokulavasan committed Mar 26, 2024
1 parent 50adbeb commit 606ab23
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchdata/datapipes/iter/util/randomsplitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,15 @@ def __init__(
self._rng = random.Random(self._seed)
self._lengths: List[int] = []

def draw(self) -> T:
def draw(self) -> T: # type: ignore
selected_key = self._rng.choices(self.keys, self.weights)[0]
index = self.key_to_index[selected_key]
self.weights[index] -= 1
self.remaining_length -= 1
if self.weights[index] < 0:
self.weights[index] = 0
self.weights = self.normalize_weights(self.weights, self.remaining_length)
return selected_key # type: ignore
return selected_key

@staticmethod
def normalize_weights(weights: List[float], total_length: int) -> List[float]:
Expand Down

0 comments on commit 606ab23

Please sign in to comment.