""" Sharpness-Aware Minimization (SAM) optimizer wrapper. Seeks parameters in flatter minima for better OOD generalization. Reference: Foret et al., "Sharpness-Aware Minimization for Efficiently Improving Generalization" (ICLR 2021) """ import torch class SAM(torch.optim.Optimizer): def __init__(self, params, base_optimizer, rho=0.05, **kwargs): defaults = dict(rho=rho, **kwargs) super().__init__(params, defaults) self.base_optimizer = base_optimizer(self.param_groups, **kwargs) @torch.no_grad() def first_step(self): grad_norm = self._grad_norm() for group in self.param_groups: scale = group['rho'] / (grad_norm + 1e-12) for p in group['params']: if p.grad is None: continue e_w = p.grad * scale p.add_(e_w) self.state[p]['e_w'] = e_w @torch.no_grad() def second_step(self): for group in self.param_groups: for p in group['params']: if p.grad is None: continue p.sub_(self.state[p]['e_w']) self.base_optimizer.step() def _grad_norm(self): shared_device = self.param_groups[0]['params'][0].device norm = torch.norm( torch.stack([ p.grad.norm(p=2).to(shared_device) for group in self.param_groups for p in group['params'] if p.grad is not None ]), p=2, ) return norm def step(self, closure=None): raise NotImplementedError("SAM requires manual first_step() and second_step() calls") def zero_grad(self): self.base_optimizer.zero_grad()