Before Change
ctx.num_bwd_passes = num_bwd_passes
with torch.no_grad():
    x = input_t.detach()  # Makes a detached copy which shares the storage
    output = ctx.fn(x)
    detached_output = output.detach_()  # Detaches the output in-place (intermediate computations can now be discarded)
# Store these tensor nodes for the backward pass
ctx.input_t = [input_t] * num_bwd_passes
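The shared-storage note above is easy to verify in isolation; a quick standalone check (the tensor names are illustrative):

import torch

a = torch.randn(4, requires_grad=True)
b = a.detach()                       # new tensor node, same underlying storage
assert b.data_ptr() == a.data_ptr()  # no data was copied
assert not b.requires_grad           # but b is cut off from the autograd graph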
After Change
ctx.num_bwd_passes = num_bwd_passes
ctx.num_inputs = num_inputs
input_t = inputs_and_weights[:num_inputs]
ctx.input_requires_grad = [element.requires_grad for element in input_t]
with torch.no_grad():
    # Makes detached copies which share the storage
    x = [element.detach() for element in input_t]
    output = ctx.fn(*x)
    if not isinstance(output, tuple):
        output = (output,)
    # Detaches each output in-place (intermediate computations can now be discarded)
    detached_output = tuple(element.detach_() for element in output)
# Store these tensor nodes for the backward pass
ctx.input_t = [input_t] * num_bwd_passes
ctx.output_t = [detached_output] * num_bwd_passes
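For context, here is a minimal sketch of how the multi-input variant might sit inside a complete torch.autograd.Function.forward. The class name CheckpointLike, the lambda passed as fn, and the argument layout are illustrative assumptions, not the project's actual implementation:

import torch

class CheckpointLike(torch.autograd.Function):
    # Hypothetical sketch of the pattern above; a real implementation
    # would also define a matching backward().
    @staticmethod
    def forward(ctx, fn, num_inputs, num_bwd_passes, *inputs_and_weights):
        ctx.fn = fn
        ctx.num_bwd_passes = num_bwd_passes
        ctx.num_inputs = num_inputs
        input_t = inputs_and_weights[:num_inputs]
        ctx.input_requires_grad = [element.requires_grad for element in input_t]
        with torch.no_grad():
            x = [element.detach() for element in input_t]  # detached copies, shared storage
            output = ctx.fn(*x)
            if not isinstance(output, tuple):  # normalize a single output to a tuple
                output = (output,)
            detached_output = tuple(element.detach_() for element in output)
        ctx.input_t = [input_t] * num_bwd_passes  # one entry per permitted backward pass
        ctx.output_t = [detached_output] * num_bwd_passes
        return detached_output

# Usage: a two-input, two-output fn with one permitted backward pass.
a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
out = CheckpointLike.apply(lambda p, q: (p * q, p + q), 2, 1, a, b)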