base_class = kwargs.pop("base")
steps, downsample, block_args = cls.pop(["num_stages", "downsample", "blocks"], kwargs)
reverse_order = kwargs.pop("reverse_order")
print("REVERSE ORDER", reverse_order)
if base_class is not None:
encoder_outputs = base_class.make_encoder(inputs, name=name, **kwargs)
After Change
base_class = kwargs.pop("base")
steps, downsample, block_args, order = cls.pop(["num_stages", "downsample", "blocks", "order"], kwargs)
order = "".join([item[0] for item in order])
if base_class is not None:
encoder_outputs = base_class.make_encoder(inputs, name=name, **kwargs)
else:
base_block = block_args.get("base")
with tf.variable_scope(name):
x = inputs
encoder_outputs = [x]
for i in range(steps):
with tf.variable_scope("encoder-"+str(i)):
args = {**kwargs, **block_args, **unpack_args(block_args, i, steps)} // enforce priority of keys
downsample_args = {**kwargs, **downsample, **unpack_args(downsample, i, steps)}
if order in ["bd", "bp"]: // block -> downsample
x = base_block(x, name="pre", **args)
if downsample.get("layout") is not None:
x = conv_block(x, name="downsample-{}".format(i), **downsample_args)
elif order in ["db", "pb"]: // downsample -> block
if downsample.get("layout") is not None:
x = conv_block(x, name="downsample-{}".format(i), **downsample_args)
x = base_block(x, name="pre", **args)
else:
raise ValueError("Unknown order, use one of {"bd", "db"}")
encoder_outputs.append(x)
return encoder_outputs
@classmethod