def __call__(self, module, inputs):
    """Callable hook for the forward pass.

    Recomputes the attribute named ``self.name`` on ``self.module`` with
    ``self.compute_weight(self.module)`` whenever the hook is not in
    evaluated mode (``self.evaluated`` is falsy) or the attribute is
    currently ``None``. Otherwise the existing value is left in place.

    Args:
        module: the module the hook fired on. Not read here; presumably the
            same object as ``self.module`` (the module the hook was
            registered on) — confirm against the registration site.
        inputs: the forward-pass inputs (unused).

    Returns:
        None.
    """
    weight = getattr(self.module, self.name)
    # Skip recomputation only when evaluated mode is on AND a weight exists.
    if not self.evaluated or weight is None:
        setattr(self.module, self.name, self.compute_weight(self.module))
# After Change
def __call__(self, module, inputs):
    """Callable hook for the forward pass.

    Resolves which module object and attribute name ``self.name`` refers to,
    relative to ``module``, via ``Reparameterization.get_module_and_name``.
    It then recomputes that attribute with
    ``self.compute_weight(target_module, attr_name)`` whenever the hook is not
    in evaluated mode (``self.evaluated`` is falsy) or the attribute is
    currently ``None``.

    Args:
        module: the module the hook fired on; used as the root from which
            ``self.name`` is resolved.
        inputs: the forward-pass inputs (unused).

    Returns:
        None.
    """
    # NOTE(review): get_module_and_name presumably splits a dotted
    # ``self.name`` into (owning submodule, leaf attribute name) — confirm.
    target_module, attr_name = Reparameterization.get_module_and_name(module, self.name)
    current = getattr(target_module, attr_name)
    # Skip recomputation only when evaluated mode is on AND a weight exists.
    if not self.evaluated or current is None:
        setattr(target_module, attr_name, self.compute_weight(target_module, attr_name))