# printed out if another exception happens.
# NB(jerry): added a flush to mitigate this
print(msg, file=sys.stderr)
if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
    for device_idx in range(torch.cuda.device_count()):
        print(torch.cuda.memory_summary(device=device_idx),
              file=sys.stderr)
sys.stderr.flush()
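The hasattr guard above is presumably there because torch.cuda.memory_summary is missing from older PyTorch releases. A standalone version of the guarded call could look like the sketch below; the device index 0 and the abbreviated flag are arbitrary choices for illustration.

import torch

# Print an allocator report only when CUDA is present and the API is available.
if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
    print(torch.cuda.memory_summary(device=0, abbreviated=True))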
After Change
if "out of memory" in str(e):
self._log_oom(e)
if raise_oom:
raise e
print("| WARNING: attempting to recover from OOM in forward/backward pass",
file=sys.stderr)
ooms += 1
self.zero_grad()
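For context, the following is a minimal, self-contained sketch of the recovery pattern this branch implements: catch a CUDA out-of-memory RuntimeError, warn on stderr, drop any partial gradients, optionally release cached memory, and signal the caller to skip the batch. The function and parameter names here (train_step, model, optimizer, batch, criterion) are assumptions for illustration, not the project's API.

import sys
import torch

def train_step(model, optimizer, batch, criterion, raise_oom=False):
    try:
        optimizer.zero_grad()
        loss = criterion(model(batch["input"]), batch["target"])
        loss.backward()
        optimizer.step()
        return loss.item()
    except RuntimeError as e:
        # Only OOM errors are recoverable here; anything else propagates.
        if "out of memory" not in str(e):
            raise
        if raise_oom:
            raise
        print("| WARNING: attempting to recover from OOM in forward/backward pass",
              file=sys.stderr)
        sys.stderr.flush()
        optimizer.zero_grad()              # discard partial gradients from the failed step
        if torch.cuda.is_available():
            torch.cuda.empty_cache()       # return cached blocks to the allocator
        return None                        # tell the caller to skip this batch

Resetting gradients before continuing matters because a failure partway through backward can leave stale, partially accumulated gradients that would otherwise be applied on the next successful step.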