if fp32_ops is not None:
assert isinstance(fp32_ops, list), "fp32_ops should be a list of strs"
else:
fp32_ops = lists.symbol.FP32_FUNCS
if conditional_fp32_ops is not None:
assert isinstance(conditional_fp32_ops, list), "conditional_fp32_ops should be a list"
else:
conditional_fp32_ops = lists.symbol.CONDITIONAL_FP32_FUNCS
original_conditional_op_names = []
conditional_op_names = []
param_names = []
param_vals = []
indptr = [0]
for conditional_fp32_op in conditional_fp32_ops:
assert isinstance(conditional_fp32_op[0], str) and isinstance(conditional_fp32_op[1], str) \
and isinstance(conditional_fp32_op[2], list), "conditional_fp32_ops should be a list of " \
"(str, str, list of str)"
param_vals += conditional_fp32_op[2]
indptr.append(len(param_vals))
param_names.append(conditional_fp32_op[1])
conditional_op_names.append(conditional_fp32_op[0])
if excluded_sym_names is not None:
assert isinstance(excluded_sym_names, list), "excluded_sym_names should be a list of strs"
else:
excluded_sym_names = []
for original_conditional_fp32_op in lists.symbol.CONDITIONAL_FP32_FUNCS:
original_conditional_op_names.append(original_conditional_fp32_op[0])
# Op lists must not intersect: an op may appear in at most one list
common_ops = set(target_dtype_ops) & set(fp32_ops)
assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
"Common ops in target_dtype_ops and fp32_ops {}".format(common_ops)
common_ops = set(target_dtype_ops) & set(conditional_op_names)
assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
"Common ops in target_dtype_ops and conditional_fp32_ops {}".format(common_ops)
common_ops = set(conditional_op_names) & set(fp32_ops)
assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
"Common ops in fp32_ops and conditional_fp32_ops {}".format(common_ops)
combined_ops = set(target_dtype_ops + fp32_ops + conditional_op_names)
all_fp16_fp32_ops = set(lists.symbol.FP16_FUNCS + lists.symbol.FP32_FUNCS
+ lists.symbol.FP16_FP32_FUNCS + original_conditional_op_names)
illegal_ops = combined_ops - all_fp16_fp32_ops
assert not illegal_ops, """Can only choose ops from one of the three lists
for fp16_ops and fp32_ops
1. amp.list_fp16_ops()
2. amp.list_fp32_ops()
3. amp.list_fp16_fp32_ops()
4. amp.list_conditional_fp32_ops()
Op %s not in any of them""" % (illegal_ops)
widest_dtype_ops = lists.symbol.WIDEST_TYPE_CASTS
target_dtype = _DTYPE_NP_TO_MX[np.dtype(target_dtype).type]
# Prepare a data_names list based on list_inputs if it's not provided
# Add all names in list for the nodes in the symbol which don't have