if pae not in [pe.OUTP, pe.OFMP, pe.BATP]:
ofmap_pdims[pe.BATP] = [x * y for x, y in zip(ofmap_pdims[pe.BATP],
ofmap_pdims[pae])]
ofmap_pdims[pae] = [1, 1]
ofmap_part = PartitionScheme(order=ofmap_order, pdims=ofmap_pdims)
assert all(od <= omrd for od, omrd in zip(ofmap_part.dim(), dim_omr)), \
"Partition ofmap: ofmap partitioning {} is invalid within " \
After Change
// Ofmap dimension > computation dimension. Extend.
ext = od // pd
// Apply the extension to the top level.
top_pae = next(pae for pae in ofmap_order if pae in ofmap_paes)
ofmap_pdims[top_pae][di] *= ext
else:
// Computation dimension >= ofmap dimension, shrink.
// Go from bottom to top. Keep bottom (first) levels unchanged, and