Fix skipping composition blocks in not applicable layers (#665)
Fixes #664.

Changes in this PR:
- Avoid throwing `NotImplementedError` in an unsupported composition block if none
  of the child adapters are part of the respective layer.
- Pass along the name of the last invoked adapter module in the LoRA & bottleneck
  states to make sure the "last" adapter actually exists in the respective layer
  (see the sketch below).
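
A minimal, self-contained sketch of the pattern behind both changes (the `Layer`, `State`, and `compose_block` names below are illustrative stand-ins, not the library's actual classes): a layer returns its state untouched when a composition block only references adapters it does not hold, and the state records which adapter was actually applied last, so the final post-forward lookup never targets an adapter that is missing from that layer.

```python
from typing import Dict, List, NamedTuple, Optional


class State(NamedTuple):
    hidden_states: List[float]  # stand-in for a torch.Tensor
    last: Optional[str] = None  # name of the last adapter applied in this layer


class Layer:
    def __init__(self, adapter_modules: Dict[str, float]):
        self.adapter_modules = adapter_modules  # adapters present in this layer

    def compose_single(self, name: str, state: State) -> State:
        scale = self.adapter_modules[name]
        hidden = [h * scale for h in state.hidden_states]
        return state._replace(hidden_states=hidden, last=name)

    def compose_block(self, setup: List[str], state: State) -> State:
        # Before the fix a layer raised an error for blocks it cannot handle;
        # now it is a no-op when none of the block's adapters live in this layer.
        if set(self.adapter_modules).isdisjoint(setup):
            return state
        for name in setup:
            if name in self.adapter_modules:
                state = self.compose_single(name, state)
        return state

    def forward(self, hidden_states: List[float], setup: List[str]) -> List[float]:
        state = self.compose_block(setup, State(hidden_states))
        # Post-process with state.last rather than setup[-1]: the adapter named
        # last in the setup may not exist in this layer at all.
        if state.last is not None:
            print(f"post-forward through {state.last!r}")
        return state.hidden_states


layer = Layer({"a": 2.0})                     # this layer only holds adapter "a"
print(layer.forward([1.0, 1.0], ["a", "b"]))  # "b" lives in another layer and is skipped
```

In the actual diff, the same idea appears as the `isdisjoint` guards in `adapter_layer_base.py` and the new `last` field threaded through `BottleneckState` and `LoRAState`.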
calpt authored Apr 15, 2024
1 parent 502eb91 commit 2652d27
Showing 3 changed files with 33 additions and 15 deletions.
12 changes: 7 additions & 5 deletions src/adapters/methods/adapter_layer_base.py
@@ -312,11 +312,7 @@ def compose_stack(self, adapter_setup: Stack, state: NamedTuple, lvl: int = 0) -
state = self.pre_block(adapter_stack_layer, state)
state = self.compose_single(adapter_stack_layer, state, lvl=lvl + 1)
else:
raise ValueError(
"Invalid adapter setup: {} is not a valid adapter name or composition block.".format(
adapter_stack_layer.__class__.__name__
)
)
pass

return state

@@ -325,6 +321,9 @@ def compose_fuse(self, adapter_setup: Fuse, state: NamedTuple, lvl: int = 0):
For fusing multiple adapters using adapter fusion. NOTE: This method has no default implementation.
"""
# Fuse is currently only applicable to bottleneck adapters, thus don't provide a default implementation
# If the adapter setup does not contain any of the adapter modules, return without doing anything
if set(self.adapter_modules.keys()).isdisjoint(adapter_setup.flatten()):
return state
raise NotImplementedError()

def compose_split(self, adapter_setup: Split, state: NamedTuple, lvl: int = 0):
@@ -333,6 +332,9 @@ def compose_split(self, adapter_setup: Split, state: NamedTuple, lvl: int = 0):
implementation.
"""
# Split is currently only applicable to bottleneck adapters, thus don't provide a default implementation
# If the adapter setup does not contain any of the adapter modules, return without doing anything
if set(self.adapter_modules.keys()).isdisjoint(adapter_setup.flatten()):
return state
raise NotImplementedError()

def compose_batch_split(self, adapter_setup: BatchSplit, state: NamedTuple, lvl: int = 0):
22 changes: 16 additions & 6 deletions src/adapters/methods/bottleneck.py
@@ -30,13 +30,15 @@ class BottleneckState(NamedTuple):
layer_norm (torch.nn.Module, optional): The Transformer layer norm module.
bottleneck_up (torch.Tensor, optional):
The up-projected bottleneck MLP output. This is only for Fuse compositions.
last (str, optional): Name of the last adapter applied in the composition.
"""

hidden_states: torch.Tensor
input_tensor: torch.Tensor
adapter_residual: torch.Tensor
layer_norm: Optional[torch.nn.Module]
bottleneck_up: Optional[torch.Tensor] = None
last: Optional[str] = None


class BottleneckLayer(ComposableAdapterLayerBase, nn.Module):
@@ -193,6 +195,7 @@ def vslice(self, state: BottleneckState, slice_obj: slice) -> BottleneckState:
state.adapter_residual[slice_obj],
state.layer_norm,
state.bottleneck_up[slice_obj] if state.bottleneck_up is not None else None,
state.last,
)

def pad_and_concat(self, states: List[BottleneckState]) -> BottleneckState:
@@ -204,6 +207,7 @@ def pad_and_concat(self, states: List[BottleneckState]) -> BottleneckState:
torch.cat([state.bottleneck_up for state in states], dim=0)
if states[0].bottleneck_up is not None
else None,
states[-1].last,
)

def repeat(self, state: BottleneckState, channels: int) -> BottleneckState:
@@ -213,6 +217,7 @@ def repeat(self, state: BottleneckState, channels: int) -> BottleneckState:
state.adapter_residual.repeat(channels, 1, 1),
state.layer_norm,
state.bottleneck_up.repeat(channels, 1, 1) if state.bottleneck_up is not None else None,
state.last,
)

def mean(self, states: List[BottleneckState], weights: torch.Tensor) -> BottleneckState:
@@ -222,6 +227,7 @@ def mean(self, states: List[BottleneckState], weights: torch.Tensor) -> Bottlene
states[0].adapter_residual,
states[0].layer_norm,
states[0].bottleneck_up,
states[-1].last,
)

def compose_single(self, adapter_setup: str, state: BottleneckState, lvl: int = 0) -> BottleneckState:
@@ -235,7 +241,7 @@ def compose_single(self, adapter_setup: str, state: BottleneckState, lvl: int =
hidden_states, up = layer_output[0], layer_output[2]
self._store_gating_score(adapter_setup, layer_output[-1])

return BottleneckState(hidden_states, state.input_tensor, state.adapter_residual, state.layer_norm, up)
return state._replace(hidden_states=hidden_states, bottleneck_up=up, last=adapter_setup)

def compose_fuse(self, adapter_setup: Fuse, state: BottleneckState, lvl: int = 0):
"""
@@ -245,7 +251,8 @@ def compose_fuse(self, adapter_setup: Fuse, state: BottleneckState, lvl: int = 0

# config of _last_ fused adapter is significant
fusion_config = self.adapters_config.get_fusion(adapter_setup.name)
last_adapter = self.adapters[adapter_setup.last()]
last = adapter_setup.last()
last_adapter = self.adapters[last]
hidden_states, query, residual = last_adapter.pre_forward(
state.hidden_states, state.input_tensor, state.layer_norm, fusion_config=fusion_config
)
@@ -281,7 +288,7 @@ def compose_fuse(self, adapter_setup: Fuse, state: BottleneckState, lvl: int = 0
else:
hidden_states = fusion_output

return state._replace(hidden_states=hidden_states)
return state._replace(hidden_states=hidden_states, last=last)

def compose_split(self, adapter_setup: Split, state: BottleneckState, lvl: int = 0):
"""
@@ -297,6 +304,7 @@ def compose_split(self, adapter_setup: Split, state: BottleneckState, lvl: int =
state = self.pre_block(adapter_setup, state)

children_states = []
last = None
for i, child in enumerate(adapter_setup):
batch_idx = (
sum(adapter_setup.splits[:i]),
@@ -314,14 +322,16 @@ def compose_split(self, adapter_setup: Split, state: BottleneckState, lvl: int =
composition_func = self._get_compose_func(type(child))
child_state = composition_func(child, child_state, lvl=lvl + 1)
children_states.append(child_state)
last = child_state.last or last
elif child in self.adapter_modules:
child_state = self.compose_single(child, child_state, lvl=lvl + 1)
children_states.append(child_state)
last = child_state.last or last
else:
pass

hidden_states = torch.cat([child.hidden_states for child in children_states], dim=1)
return state._replace(hidden_states=hidden_states)
return state._replace(hidden_states=hidden_states, last=last)

def bottleneck_layer_forward(self, hidden_states, residual_input, layer_norm):
"""Forward pass through the adapter layer.
@@ -346,9 +356,9 @@ def bottleneck_layer_forward(self, hidden_states, residual_input, layer_norm):

state = BottleneckState(hidden_states, residual_input, residual_input, layer_norm)
state = self.compose(adapter_setup, state)
hidden_states, residual_input, _, _, _ = state
hidden_states, residual_input, _, _, _, last = state

last_adapter = self.adapters[adapter_setup.last()]
last_adapter = self.adapters[last]
hidden_states = last_adapter.post_forward(hidden_states, input_hidden_states, residual_input, layer_norm)

elif layer_norm:
14 changes: 10 additions & 4 deletions src/adapters/methods/lora.py
@@ -276,11 +276,13 @@ class LoRAState(NamedTuple):
The hidden states of the adaptation module. These can be None before passing through the first LoRA/ IA3
module.
layer_output (torch.Tensor): The output states of the original layer without adaptation.
last (str, optional): Name of the last adapter applied in the composition.
"""

layer_input: torch.Tensor
hidden_states: Optional[torch.Tensor]
layer_output: torch.Tensor
last: Optional[str]


class LoRALinear(LoRALayer, ComposableAdapterLayerBase, nn.Linear):
@@ -395,20 +397,23 @@ def vslice(self, state: LoRAState, slice_obj: slice) -> LoRAState:
state.layer_input[slice_obj],
state.hidden_states[slice_obj] if state.hidden_states is not None else None,
state.layer_output[slice_obj],
state.last,
)

def pad_and_concat(self, states: List[LoRAState]) -> LoRAState:
return LoRAState(
torch.cat([s.layer_input for s in states], dim=0),
torch.cat([s.hidden_states for s in states], dim=0) if states[0].hidden_states is not None else None,
torch.cat([s.layer_output for s in states], dim=0),
states[-1].last,
)

def repeat(self, state: LoRAState, channels: int) -> LoRAState:
return LoRAState(
state.layer_input.repeat(channels, 1, 1),
state.hidden_states.repeat(channels, 1, 1) if state.hidden_states is not None else None,
state.layer_output.repeat(channels, 1, 1),
state.last,
)

def mean(self, states: List[LoRAState], weights: torch.Tensor) -> LoRAState:
@@ -418,6 +423,7 @@ def mean(self, states: List[LoRAState], weights: torch.Tensor) -> LoRAState:
if states[0].hidden_states is not None
else None,
states[0].layer_output,
states[-1].last,
)

def compose_single(self, adapter_setup: str, state: LoRAState, lvl: int = 0) -> LoRAState:
@@ -426,7 +432,7 @@ def compose_single(self, adapter_setup: str, state: LoRAState, lvl: int = 0) ->
if gate is not None:
self._store_gating_score(adapter_setup, gate)

return state._replace(hidden_states=hidden_states)
return state._replace(hidden_states=hidden_states, last=adapter_setup)

def forward(self, input_states: torch.Tensor):
weight = torch.transpose(self.weight, -2, -1) if self.fan_in_fan_out else self.weight
@@ -436,11 +442,11 @@ def forward(self, input_states: torch.Tensor):
if not self.merged:
adapter_setup = self.get_active_setup()
if adapter_setup is not None:
state = LoRAState(input_states, None, layer_output)
state = LoRAState(input_states, None, layer_output, None)
state = self.compose(adapter_setup, state)
_, hidden_states, layer_output = state
_, hidden_states, layer_output, last = state

last_lora = self.loras[adapter_setup.last()]
last_lora = self.loras[last]
layer_output = last_lora.com(
layer_output, hidden_states, scaling=1.0
) # scaling already applied in compose
