// END INIT
// Chunk elements in the args
args = [[_.chunk(group_size) for _ in args_] for args_ in args] // arg_name, model_name, group_name
args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] // group_name, arg_name, model_name
for t in range(self.seq_length + group_size - 1):
After Change
// END INIT
// Chunk elements in the args
args = [[_.chunk(group_size) if _ is not None else [None]*group_size for _ in args_] for args_ in args] // arg_name, model_name, group_name
args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] // group_name, arg_name, model_name
for t in range(self.seq_length + group_size - 1):