Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions analog/analog.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def add_lora(
model: Optional[nn.Module] = None,
watch: bool = True,
clear: bool = True,
lora_state: Dict[str, Any] = None,
) -> None:
"""
Adds LoRA for gradient compression.
Expand All @@ -140,6 +141,7 @@ def add_lora(
model=model,
type_filter=self.type_filter,
name_filter=self.name_filter,
lora_state=lora_state,
)

# Clear state and logger
Expand Down Expand Up @@ -319,9 +321,9 @@ def initialize_from_log(self) -> None:
# Load LoRA state
lora_dir = os.path.join(self.log_dir, "lora")
if os.path.exists(lora_dir):
if not is_lora(self.model):
self.add_lora()
lora_state = torch.load(os.path.join(lora_dir, "lora_state_dict.pt"))
if not is_lora(self.model):
self.add_lora(lora_state=lora_state)
for name in lora_state:
assert name in self.model.state_dict(), f"{name} not in model!"
self.model.load_state_dict(lora_state, strict=False)
Expand Down
7 changes: 6 additions & 1 deletion analog/logging/option.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,16 @@ def _sanity_check(self):
)
self._log["grad"] = True

def eval(self):
def eval(self, log="grad"):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of having "grad" as a default value, what do you think about having None as a default value, and when it's None we set it to "grad" with a warning message like:

def eval(self, log=None):
    if log is None:
        get_logger().warning("we automatically set 'log' to 'grad'. if this is not a desired behavior, please explicitly set your 'log' value.")
        log = "grad"

    if isinstance(log, str):
        ...

"""
Enable the evaluation mode. This will turn off saving and updating
statistic.
"""
if isinstance(log, str):
self._log[log] = True
else:
raise ValueError(f"Unsupported log type for eval: {type(log)}")

self.clear(log=False, save=True, statistic=True)

def clear(self, log=True, save=True, statistic=True):
Expand Down
73 changes: 68 additions & 5 deletions analog/lora/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,16 @@

import torch.nn as nn

from analog.constants import FORWARD, BACKWARD
from analog.state import StatisticState
from analog.lora.modules import LoraLinear, LoraConv2d, LoraEmbedding
from analog.lora.utils import (
find_parameter_sharing_group,
_get_submodules,
find_rank_pca_compression,
find_rank_pca_covariance,
pca_rank_by_weight_shape,
)
from analog.lora.utils import find_parameter_sharing_group, _get_submodules
from analog.utils import get_logger, module_check

Expand All @@ -23,17 +31,25 @@ def __init__(self, config: Dict[str, Any], state: StatisticState):

def parse_config(self):
self.init_strategy = self.config.get("init", "random")
self.rank = self.config.get("rank", 64)
self.rank_default = self.config.get("rank", 64)
self.compression_ratio_by_covariance = self.config.get(
"compression_ratio_by_covariance", None
)
self.compression_ratio_by_memory = self.config.get(
"compression_ratio_by_memory", None
)
self.parameter_sharing = self.config.get("parameter_sharing", False)
self.parameter_sharing_groups = self.config.get(
"parameter_sharing_groups", None
)
self._sanity_check()

def add_lora(
self,
model: nn.Module,
type_filter: List[nn.Module],
name_filter: List[str],
lora_state: Dict[str, Any] = None,
):
"""
Add LoRA modules to a model.
Expand Down Expand Up @@ -69,23 +85,70 @@ def add_lora(
lora_cls = LoraEmbedding

psg = find_parameter_sharing_group(name, self.parameter_sharing_groups)

rank_forward = rank_backward = self.rank_default # default rank

if lora_state is not None: # add lora matching the rank of the lora_state
rank_forward, rank_backward = pca_rank_by_weight_shape(
lora_state[name + ".analog_lora_B.weight"].shape, module
)
elif (
self.init_strategy == "pca"
and self.compression_ratio_by_covariance is not None
):
rank_forward = find_rank_pca_covariance(
covariance_state[name][FORWARD],
self.compression_ratio_by_covariance,
)
rank_backward = find_rank_pca_covariance(
covariance_state[name][BACKWARD],
self.compression_ratio_by_covariance,
)
get_logger().info(
f"using adaptive rank_forward = {rank_forward}, rank_backward = {rank_backward} for {name}\n"
)
elif (
self.init_strategy == "pca"
and self.compression_ratio_by_memory is not None
):
rank_forward = rank_backward = find_rank_pca_compression(
module,
self.compression_ratio_by_memory,
)
get_logger().info(
f"using adaptive rank_forward = {rank_forward}, rank_backward = {rank_backward} for {name}\n"
)

if self.parameter_sharing and psg not in shared_modules:
if isinstance(module, nn.Linear):
shared_module = nn.Linear(self.rank, self.rank, bias=False)
shared_module = nn.Linear(rank_forward, rank_backward, bias=False)
elif isinstance(module, nn.Conv1d):
shared_module = nn.Conv1d(
self.rank, self.rank, kernel_size=1, bias=False
rank_forward, rank_backward, kernel_size=1, bias=False
)
elif isinstance(module, nn.Conv2d):
shared_module = nn.Conv2d(
self.rank, self.rank, kernel_size=1, bias=False
rank_forward, rank_backward, kernel_size=1, bias=False
)
shared_modules[psg] = shared_module

lora_module = lora_cls(self.rank, module, shared_modules.get(psg, None))
lora_module = lora_cls(
rank_forward, rank_backward, module, shared_modules.get(psg, None)
)
if self.init_strategy == "pca":
lora_module.pca_init_weight(covariance_state[name])
lora_module.to(device)

parent, target, target_name = _get_submodules(model, name)
setattr(parent, target_name, lora_module)

def _sanity_check(self):
if (
self.init_strategy == "pca"
and self.compression_ratio_by_covariance is not None
and self.compression_ratio_by_memory is not None
):
get_logger().warning(
"compression_ratio_by_covariance and compression_ratio_by_memory are both set. "
+ "compression_ratio_by_covariance will be used."
)
58 changes: 40 additions & 18 deletions analog/lora/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@


class LoraLinear(nn.Linear):
def __init__(self, rank: int, linear: nn.Linear, shared_module: nn.Linear = None):
def __init__(
self,
rank_forward: int,
rank_backward: int,
linear: nn.Linear,
shared_module: nn.Linear = None,
):
"""Transforms a linear layer into a LoraLinear layer.

Args:
Expand All @@ -19,13 +25,14 @@ def __init__(self, rank: int, linear: nn.Linear, shared_module: nn.Linear = None
out_features = linear.out_features

super().__init__(in_features, out_features)
self.rank = min(rank, in_features, out_features)
self.rank_forward = min(rank_forward, in_features)
self.rank_backward = min(rank_backward, out_features)

self.analog_lora_A = nn.Linear(in_features, self.rank, bias=False)
self.analog_lora_A = nn.Linear(in_features, self.rank_forward, bias=False)
self.analog_lora_B = shared_module or nn.Linear(
self.rank, self.rank, bias=False
self.rank_forward, self.rank_backward, bias=False
)
self.analog_lora_C = nn.Linear(self.rank, out_features, bias=False)
self.analog_lora_C = nn.Linear(self.rank_backward, out_features, bias=False)

nn.init.kaiming_uniform_(self.analog_lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.analog_lora_B.weight)
Expand All @@ -49,17 +56,23 @@ def pca_init_weight(self, covariance=None):
(
top_r_singular_vector_forward,
top_r_singular_value_forward,
) = compute_top_k_singular_vectors(covariance[FORWARD], self.rank)
) = compute_top_k_singular_vectors(covariance[FORWARD], self.rank_forward)
(
top_r_singular_vector_backward,
top_r_singular_value_backward,
) = compute_top_k_singular_vectors(covariance[BACKWARD], self.rank)
) = compute_top_k_singular_vectors(covariance[BACKWARD], self.rank_backward)
self.analog_lora_A.weight.data.copy_(top_r_singular_vector_forward.T)
self.analog_lora_C.weight.data.copy_(top_r_singular_vector_backward)


class LoraConv2d(nn.Conv2d):
def __init__(self, rank: int, conv: nn.Conv2d, shared_module: nn.Conv2d = None):
def __init__(
self,
rank_forward: int,
rank_backward: int,
conv: nn.Conv2d,
shared_module: nn.Conv2d = None,
):
"""Transforms a conv2d layer into a LoraConv2d layer.

Args:
Expand All @@ -76,15 +89,23 @@ def __init__(self, rank: int, conv: nn.Conv2d, shared_module: nn.Conv2d = None):
in_channels, out_channels, kernel_size, stride, padding, bias=False
)

self.rank = min(rank, self.in_channels, self.out_channels)
self.rank_forward = min(rank_forward, in_channels)
self.rank_backward = min(rank_backward, out_channels)

self.analog_lora_A = nn.Conv2d(
self.in_channels, self.rank, kernel_size, stride, padding, bias=False
self.in_channels,
self.rank_forward,
kernel_size,
stride,
padding,
bias=False,
)
self.analog_lora_B = shared_module or nn.Conv2d(
self.rank, self.rank, 1, bias=False
self.rank_forward, self.rank_backward, 1, bias=False
)
self.analog_lora_C = nn.Conv2d(
self.rank_backward, self.out_channels, 1, bias=False
)
self.analog_lora_C = nn.Conv2d(self.rank, self.out_channels, 1, bias=False)

nn.init.kaiming_uniform_(self.analog_lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.analog_lora_B.weight)
Expand All @@ -108,11 +129,11 @@ def pca_init_weight(self, covariance):
(
top_r_singular_vector_forward,
top_r_singular_value_forward,
) = compute_top_k_singular_vectors(covariance[FORWARD], self.rank)
) = compute_top_k_singular_vectors(covariance[FORWARD], self.rank_forward)
(
top_r_singular_vector_backward,
top_r_singular_value_backward,
) = compute_top_k_singular_vectors(covariance[BACKWARD], self.rank)
) = compute_top_k_singular_vectors(covariance[BACKWARD], self.rank_backward)
shape_A = self.analog_lora_A.weight.shape
shape_C = self.analog_lora_C.weight.shape
self.analog_lora_A.weight.data.copy_(
Expand All @@ -137,13 +158,14 @@ def __init__(
embedding_dim = embedding.embedding_dim

super().__init__(num_embeddings, embedding_dim)
self.rank = min(rank, num_embeddings, embedding_dim)
self.rank_forward = min(rank, num_embeddings)
self.rank_backward = min(rank, embedding_dim)

self.analog_lora_A = nn.Embedding(num_embeddings, self.rank)
self.analog_lora_A = nn.Embedding(num_embeddings, self.rank_forward)
self.analog_lora_B = shared_module or nn.Linear(
self.rank, self.rank, bias=False
self.rank_forward, self.rank_backward, bias=False
)
self.analog_lora_C = nn.Linear(self.rank, embedding_dim, bias=False)
self.analog_lora_C = nn.Linear(self.rank_backward, embedding_dim, bias=False)

nn.init.kaiming_uniform_(self.analog_lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.analog_lora_B.weight)
Expand Down
56 changes: 56 additions & 0 deletions analog/lora/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,62 @@
from typing import List

import math
import torch
import torch.nn as nn


def find_rank_pca_covariance(matrix, threshold):
    """
    Calculate the minimum principal component analysis (PCA) rank required
    to explain at least the specified fraction (threshold) of the total
    covariance, measured as the sum of the singular values of ``matrix``.

    Args:
        matrix: 2-D tensor whose singular values are analysed.
        threshold: Target fraction of the total singular-value mass that the
            leading components must reach.

    Returns:
        int: The smallest ``k`` such that the first ``k`` singular values
        account for at least ``threshold`` of the total. This is ``0`` when
        ``threshold <= 0`` or when every singular value is zero, and the
        number of singular values when the threshold is never reached.
    """
    # Only the singular values are needed, so skip computing U and Vh.
    singular_values = torch.linalg.svdvals(matrix)
    total = singular_values.sum()
    if total <= 0:
        # Degenerate (empty / all-zero) input: avoid a 0 / 0 division and
        # return rank 0 explicitly instead of relying on a NaN comparison.
        return 0

    # explained_before[k] is the mass explained by the first k components
    # (k = 0 .. n-1), i.e. the value tested before taking component k.
    cumulative = torch.cumsum(singular_values, dim=0)
    explained_before = torch.cat([cumulative.new_zeros(1), cumulative[:-1]])

    # Singular values are non-negative, so the prefix sums are non-decreasing
    # and counting the entries still below the threshold gives the rank.
    return int(((explained_before / total) < threshold).sum().item())


def find_rank_pca_compression(module, ratio):
    """
    Calculate the minimum principal component analysis (PCA) rank required
    to reach threshold compression ratio.

    The rank ``r`` is chosen so that ``r * r`` is at least
    ``ratio * (number of elements in module.weight)``:

    * ``nn.Linear`` / ``nn.Embedding``: ``r * r = m * n * ratio``
    * ``nn.Conv2d``: ``r * r * 1 * 1 = (product of the 4 weight dims) * ratio``

    Args:
        module: An ``nn.Linear``, ``nn.Conv2d`` or ``nn.Embedding`` module.
        ratio: Target ratio between ``r * r`` and the weight's element count.

    Returns:
        int: ``ceil(sqrt(numel(weight) * ratio))``.

    Raises:
        NotImplementedError: If ``module`` is not one of the supported types.
    """
    if not isinstance(module, (nn.Linear, nn.Conv2d, nn.Embedding)):
        raise NotImplementedError(
            f"Unsupported module type for rank computation: {type(module)}"
        )

    # Only the element count matters, so read it from the parameter directly
    # rather than copying the weight to host memory and converting it to
    # NumPy (which also fails for dtypes NumPy cannot represent, e.g. bfloat16).
    num_weights = module.weight.numel()
    return math.ceil(math.sqrt(num_weights * ratio))


def pca_rank_by_weight_shape(shape, module):
    """
    Recover the (forward, backward) LoRA ranks from a saved weight shape.

    Args:
        shape: Shape of a saved weight; its second entry is taken as the
            forward rank and its first entry as the backward rank. For
            ``nn.Conv2d`` modules the shape must have 4 dimensions, otherwise 2.
        module: The original ``nn.Linear``, ``nn.Conv2d`` or ``nn.Embedding``
            module the weight belongs to; it only selects the expected
            dimensionality of ``shape``.

    Returns:
        tuple: ``(rank_forward, rank_backward)``, i.e. ``(shape[1], shape[0])``.

    Raises:
        NotImplementedError: If ``module`` is not one of the supported types.
            Previously this case silently returned ``None``, which made the
            caller's tuple unpacking fail with an unrelated ``TypeError``.
        AssertionError: If ``shape`` does not have the expected number of
            dimensions for the module type.
    """
    if isinstance(module, nn.Conv2d):
        expected_dims = 4
    elif isinstance(module, (nn.Linear, nn.Embedding)):
        expected_dims = 2
    else:
        raise NotImplementedError(
            f"Unsupported module type for rank lookup: {type(module)}"
        )

    assert (
        len(shape) == expected_dims
    ), f"expected a {expected_dims}-D weight shape, got {tuple(shape)}"
    return shape[1], shape[0]


def is_lora(model):
Expand Down
8 changes: 5 additions & 3 deletions examples/cifar_influence/compute_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@
# Gradient & Hessian logging
analog.watch(model)
analog.setup({"log": "grad", "save": "grad", "statistic": "kfac"})

id_gen = DataIDGenerator()
if not args.resume:
id_gen = DataIDGenerator()
for inputs, targets in train_loader:
data_id = id_gen(inputs)
with analog(data_id=data_id):
Expand All @@ -62,7 +61,10 @@

analog.add_analysis({"influence": InfluenceFunction})
query_iter = iter(query_loader)
with analog(log=["grad"]) as al:
test_input, test_target = next(query_iter)
test_id = id_gen(test_input)
analog.eval()
with analog(data_id=test_id) as al:
test_input, test_target = next(query_iter)
test_input, test_target = test_input.to(DEVICE), test_target.to(DEVICE)
model.zero_grad()
Expand Down
Loading