AlloGen / code /utils /sam.py
chq1155's picture
AlloGen public release: Q_theta scorer + PXDesign guidance + Colab demo
ad9572d
"""
Sharpness-Aware Minimization (SAM) optimizer wrapper.
Seeks parameters in flatter minima for better OOD generalization.
Reference: Foret et al., "Sharpness-Aware Minimization for Efficiently Improving Generalization" (ICLR 2021)
"""
import torch
class SAM(torch.optim.Optimizer):
    """Two-step Sharpness-Aware Minimization wrapper around a base optimizer.

    Intended usage per training iteration (driven manually by the caller):

    1. forward/backward at the current weights, then ``first_step()`` to move
       every parameter that has a gradient by ``rho * grad / ||grad||``;
    2. clear gradients, forward/backward again at the perturbed weights, then
       ``second_step()`` to undo the perturbation and let the base optimizer
       apply its update using the gradients currently stored on the
       parameters.

    Args:
        params: Iterable of parameters or parameter-group dicts, as accepted
            by ``torch.optim.Optimizer``.
        base_optimizer: Optimizer *class* (not an instance); it is called as
            ``base_optimizer(param_groups, **kwargs)``.
        rho: Radius of the perturbation applied in ``first_step``.
        **kwargs: Forwarded to ``base_optimizer`` and also stored as defaults
            of this wrapper.
    """

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # Share a single list of groups (and the base defaults) so that groups
        # added to the wrapper later are also seen by the base optimizer.
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self):
        """Perturb parameters along the normalized gradient direction.

        The applied offset is stored in ``self.state[p]['e_w']`` so that
        ``second_step`` can undo it. Parameters without a gradient are left
        untouched.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            # Small epsilon keeps the division finite when the norm is zero.
            scale = group['rho'] / (grad_norm + 1e-12)
            for p in group['params']:
                if p.grad is None:
                    continue
                # The norm lives on the first parameter's device; move the
                # scalar next to this parameter before multiplying.
                e_w = p.grad * scale.to(p.device)
                p.add_(e_w)
                self.state[p]['e_w'] = e_w

    @torch.no_grad()
    def second_step(self):
        """Undo the ``first_step`` perturbation, then run the base optimizer.

        Restoration is keyed on the stored offset rather than on the presence
        of a gradient: a parameter that was perturbed must be restored even if
        its gradient is ``None`` by now, and a parameter that was never
        perturbed has nothing to undo.
        """
        for group in self.param_groups:
            for p in group['params']:
                # .get() avoids creating empty state entries as a side effect.
                p_state = self.state.get(p)
                if not p_state or 'e_w' not in p_state:
                    continue
                # pop() so a repeated call cannot subtract the offset twice.
                p.sub_(p_state.pop('e_w'))
        self.base_optimizer.step()

    def _grad_norm(self):
        """Return the global L2 norm over all available parameter gradients.

        Per-parameter norms are gathered on the device of the first parameter
        of the first group. Returns a zero scalar tensor when no parameter has
        a gradient.
        """
        shared_device = self.param_groups[0]['params'][0].device
        norms = [
            p.grad.norm(p=2).to(shared_device)
            for group in self.param_groups
            for p in group['params']
            if p.grad is not None
        ]
        if not norms:
            # torch.stack rejects an empty list; a zero norm is the natural
            # value when there is nothing to measure.
            return torch.zeros((), device=shared_device)
        return torch.norm(torch.stack(norms), p=2)

    def step(self, closure=None):
        """Not supported: call ``first_step()`` and ``second_step()`` instead.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("SAM requires manual first_step() and second_step() calls")

    def zero_grad(self, set_to_none=None):
        """Clear gradients through the base optimizer.

        Args:
            set_to_none: Forwarded to the base optimizer's ``zero_grad`` when
                given; when omitted, the base optimizer's own default applies.
        """
        if set_to_none is None:
            self.base_optimizer.zero_grad()
        else:
            self.base_optimizer.zero_grad(set_to_none=set_to_none)