)
if self._mask.shape != w.shape[:self._mask.shape.ndims]:
raise base.IncompatibleShapeError(
"Invalid mask shape: {}. Weight shape: {}".format(
self._mask.shape, w.shape
)
)
// TF broadcasting is a bit fragile.
// Expand the shape of self._mask by one dim at a time to the right
// until the rank matches `weight_shape`.
After Change
match on shape.
w = self._w
w_shape = w.get_shape()
mask_shape = self._mask.get_shape()
if mask_shape.ndims > w_shape.ndims:
raise base.IncompatibleShapeError(
"Invalid mask shape: {}. Max shape: {}".format(
mask_shape.ndims, len(self._data_format)
)
)
if mask_shape != w_shape[:mask_shape.ndims]:
raise base.IncompatibleShapeError(
"Invalid mask shape: {}. Weight shape: {}".format(
mask_shape, w_shape