class BaseLayerWithLoRA(nn.Module):
    """Shared interface for layers that carry LoRA adapter weights.

    Except for ``set_mapping`` (which stores the Punica wrapper) and
    ``can_replace_layer`` (which raises ``NotImplementedError``), every hook
    on this base class is an empty placeholder that returns ``None``.
    Concrete LoRA layers are meant to override the hooks they need.
    """

    @overload
    def slice_lora_a(
        self, lora_a: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]: ...

    @overload
    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: ...

    def slice_lora_a(
        self, lora_a: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Slice the LoRA A weights when splitting for tensor parallelism.

        Accepts either a single tensor or a list whose entries may be
        ``None``. The overloads declared just before this method pair each
        input form with the matching output form. The base implementation
        performs no slicing and returns ``None``.
        """

    @overload
    def slice_lora_b(
        self, lora_b: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]: ...

    @overload
    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: ...

    def slice_lora_b(
        self, lora_b: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Slice the LoRA B weights when splitting for tensor parallelism.

        Mirrors ``slice_lora_a``: accepts a single tensor or a list whose
        entries may be ``None``. The base implementation performs no slicing
        and returns ``None``.
        """

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initialize the LoRA matrices for this layer.

        Args:
            max_loras: Upper bound on the number of LoRAs to allocate for
                (presumably the number of adapter slots; confirm with callers).
            lora_config: LoRA configuration used to size the weights.
            model_config: Optional model configuration; defaults to ``None``.

        The base implementation allocates nothing.
        """

    def reset_lora(self, index: int):
        """Reset the LoRA weights stored at ``index`` back to zero.

        The base implementation does nothing and returns ``None``.
        """

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrite the LoRA A/B tensors stored at ``index``.

        The base implementation does nothing and returns ``None``.
        """

    def set_mapping(self, punica_wrapper):
        """Attach ``punica_wrapper`` to this layer as ``self.punica_wrapper``."""
        # The annotation sits on an attribute target inside a function, so
        # Python does not evaluate it at runtime (PEP 526).
        self.punica_wrapper: PunicaWrapperBase = punica_wrapper

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Report whether ``source_layer`` can be replaced by this LoRA layer.

        Args:
            source_layer: The module being considered for replacement.
            lora_config: LoRA configuration in effect.
            packed_modules_list: List describing packed modules for the layer.
            model_config: Optional model configuration; defaults to ``None``.

        Returns:
            ``True`` if this LoRA layer can replace ``source_layer``.

        Raises:
            NotImplementedError: Always; subclasses must supply the check.
        """
        raise NotImplementedError()