if is_layer(layer, "Flatten"):
return Flatten()
if is_layer(layer, "GlobalAveragePooling"):
return GlobalAveragePooling2D()
def to_stub_layer(layer, input_id, output_id):
if is_conv_layer(layer):
After Change
if is_layer(layer, "Dropout"):
return torch.nn.Dropout2d(layer.rate)
if is_layer(layer, "ReLU"):
return torch.nn.ReLU()
if is_layer(layer, "Softmax"):
return torch.nn.Softmax()
if is_layer(layer, "Flatten"):
return TorchFlatten()