// Get length of each unique dimension and ensure all dimensions are correct
dimension_dict = {}
bcast = set()
for tnum, term in enumerate(input_list):
sh = operands[tnum].shape
if len(sh) != len(term):
After Change
# Build a few useful lists and sets
# input_list: one subscript term per operand; input_sets: the labels of each
# term; indices: every distinct label that appears in any input term.
input_list = input_subscripts.split(",")
input_sets = list(map(set, input_list))
input_shps = [operand.shape for operand in operands]
output_set = set(output_subscript)
indices = set("".join(input_list))
# Get length of each unique dimension and ensure all dimensions are correct.
# dimension_dict maps each subscript label to its size.
# Raises ValueError when a term's label count differs from the operand's
# number of dimensions, or when a label's size conflicts with an earlier one.
dimension_dict = {}
for tnum, term in enumerate(input_list):
    sh = input_shps[tnum]
    if len(sh) != len(term):
        # Report the offending term itself: indexing the comma-joined
        # subscript string by operand number would give a single character.
        raise ValueError("Einstein sum subscript %s does not contain the "
                         "correct number of indices for operand %d." % (term, tnum))
    for cnum, char in enumerate(term):
        dim = sh[cnum]
        if char in dimension_dict:
            # For broadcasting cases we always want the largest dim size
            if dimension_dict[char] == 1:
                dimension_dict[char] = dim
            elif dim not in (1, dimension_dict[char]):
                # Argument order follows the message: this operand's size
                # first, then the size recorded from previous terms.
                raise ValueError("Size of label '%s' for operand %d (%d) "
                                 "does not match previous terms (%d)."
                                 % (char, tnum, dim, dimension_dict[char]))
        else:
            dimension_dict[char] = dim
# Compute size of each input array plus the output array
size_list = [
    helpers.compute_size_by_dict(term, dimension_dict)
    for term in input_list + [output_subscript]
]
out_size = max(size_list)

# Resolve the memory argument handed to the path finders:
#   None -> the largest of the sizes computed above
#   -1   -> a very large sentinel (int(1e20))
#   < 1  -> rejected with ValueError
#   else -> the limit truncated to int
if memory_limit is None:
    memory_arg = out_size
elif memory_limit == -1:
    memory_arg = int(1e20)
elif memory_limit < 1:
    raise ValueError("Memory limit must be larger than 0, or -1")
else:
    memory_arg = int(memory_limit)
# Compute naive cost
# NOTE: this isn't quite right, need to look into exactly how einsum does this.
# A label shared by more than one operand makes the summed per-term label
# counts exceed the number of distinct labels.
inner_product = sum(map(len, input_sets)) > len(indices)
naive_cost = helpers.flop_count(indices, inner_product, len(input_list), dimension_dict)
# Compute the path
if not isinstance(path_type, str):
    # A non-string path_type is used directly as the contraction path.
    path = path_type
elif len(input_list) in (1, 2) or indices == output_set:
    # Nothing to be optimized for one or two operands, and with no rank
    # reduction the work is left to einsum: contract every operand at once.
    path = [tuple(range(len(input_list)))]
elif path_type in ("greedy", "opportunistic"):
    path = paths.greedy(input_sets, output_set, dimension_dict, memory_arg)
elif path_type == "optimal":
    path = paths.optimal(input_sets, output_set, dimension_dict, memory_arg)
else:
    raise KeyError("Path name %s not found" % path_type)
cost_list = []
scale_list = []
size_list = []
contraction_list = []
// Build contraction tuple (positions, gemm, einsum_str, remaining)
for cnum, contract_inds in enumerate(path):
// Make sure we remove inds from right to left
contract_inds = tuple(sorted(list(contract_inds), reverse=True))
contract_tuple = helpers.find_contraction(contract_inds, input_sets, output_set)
out_inds, input_sets, idx_removed, idx_contract = contract_tuple
// Compute cost, scale, and size
cost = helpers.flop_count(idx_contract, idx_removed, len(contract_inds), dimension_dict)
cost_list.append(cost)
scale_list.append(len(idx_contract))
size_list.append(helpers.compute_size_by_dict(out_inds, dimension_dict))
tmp_inputs = [input_list.pop(x) for x in contract_inds]
tmp_shapes = [input_shps.pop(x) for x in contract_inds]
if use_blas:
do_blas = blas.can_blas(tmp_inputs, out_inds, idx_removed, tmp_shapes)