path_arg = kwargs.pop("path", "greedy")
memory_limit = kwargs.pop("memory_limit", None)
tensordot = kwargs.pop("tensordot", True)
# Hidden option, only einsum should call this
einsum_call_arg = kwargs.pop("einsum_call", False)
# Python side parsing
input_subscripts, output_subscript, operands = parser.parse_einsum_input(operands)
subscripts = input_subscripts + "->" + output_subscript
# Build a few useful lists and sets
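# (input_sets/output_set are the label sets of each operand and of the result;
# `indices` is the union of every label appearing on the input side)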
input_list = input_subscripts.split(",")
input_sets = [set(x) for x in input_list]
output_set = set(output_subscript)
indices = set(input_subscripts.replace(",", ""))
# Get the length of each unique dimension and ensure all dimensions are consistent
dimension_dict = {}
for tnum, term in enumerate(input_list):
sh = operands[tnum].shape
if len(sh) != len(term):
raise ValueError("Einstein sum subscript %s does not contain the "
"correct number of indices for operand %d.",
input_subscripts[tnum], tnum)
for cnum, char in enumerate(term):
dim = sh[cnum]
if char in dimension_dict:
if dimension_dict[char] != dim:
raise ValueError("Size of label "%s" for operand %d does "
"not match previous terms.", char, tnum)
else:
dimension_dict[char] = dim
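# dimension_dict now maps every subscript label to a single, consistent size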
# Compute the size of each input array plus the output array; out_size is
# needed later even when an explicit memory_limit is given
size_list = [paths.compute_size_by_dict(term, dimension_dict)
             for term in input_list + [output_subscript]]
out_size = max(size_list)
if memory_limit is None:
    memory_arg = out_size
else:
    memory_arg = int(memory_limit)
# Compute the naive cost
# This isn't quite right; need to look into exactly how einsum does this
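# The estimate treats the expression as a single einsum call: the product of
# every unique dimension size, times one multiplication per additional
# operand, doubled when a repeated input label implies an inner summation.
# For example, "ij,jk->ik" with all sizes 10 gives 10*10*10 * 1 * 2 = 2000.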
naive_cost = paths.compute_size_by_dict(indices, dimension_dict)
indices_in_input = input_subscripts.replace(",", "")
mult = max(len(input_list) - 1, 1)
if (len(indices_in_input) - len(set(indices_in_input))):
mult *= 2
naive_cost *= mult
# Compute the path
if not isinstance(path_arg, str):
path = path_arg
elif len(input_list) == 1:
path = [(0,)]
elif len(input_list) == 2:
path = [(0, 1)]
elif (indices == output_set):
# If there is no rank reduction, leave it to einsum
path = [tuple(range(len(input_list)))]
elif (path_arg in ["greedy", "opportunistic"]):
# Maximum memory should be at most out_size for this algorithm
memory_arg = min(memory_arg, out_size)
path = paths.greedy(input_sets, output_set, dimension_dict, memory_arg)
elif path_arg == "optimal":
path = paths.optimal(input_sets, output_set, dimension_dict, memory_arg)
else:
raise KeyError("Path name %s not found", path_arg)
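# `path` is now a list of tuples; each tuple holds the positions of the
# operands to contract at that step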
cost_list, scale_list, size_list = [], [], []
contraction_list = []
# Build contraction tuple (positions, gemm, einsum_str, remaining)
for cnum, contract_inds in enumerate(path):
# Make sure we remove inds from right to left
contract_inds = tuple(sorted(list(contract_inds), reverse=True))
contract = paths.find_contraction(contract_inds, input_sets, output_set)
out_inds, input_sets, idx_removed, idx_contract = contract
cost = paths.compute_size_by_dict(idx_contract, dimension_dict)
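# Contractions that sum over removed indices need a multiply and an add per
# element, so the flop estimate below is doubled for them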
if idx_removed:
cost *= 2
cost_list.append(cost)
scale_list.append(len(idx_contract))
size_list.append(paths.compute_size_by_dict(out_inds, dimension_dict))
tmp_inputs = []
for x in contract_inds:
tmp_inputs.append(input_list.pop(x))
# Last contraction
if (cnum - len(path)) == -1:
idx_result = output_subscript
else:
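# Order the intermediate's indices by dimension size, breaking ties by label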
sort_result = [(dimension_dict[ind], ind) for ind in out_inds]
idx_result = "".join([x[1] for x in sorted(sort_result)])
input_list.append(idx_result)
einsum_str = ",".join(tmp_inputs) + "->" + idx_result
if tensordot:
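# Ask the BLAS helper whether this pairwise contraction maps onto a GEMM;
# tensordot-style ("TDOT") results are opted out of just below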
can_gemm = blas.can_blas(tmp_inputs, idx_result, idx_removed)
# Don't want to deal with this quite yet
if can_gemm == "TDOT":
can_gemm = False
else:
can_gemm = False
contraction = (contract_inds, idx_removed, can_gemm, einsum_str, input_list[:])
contraction_list.append(contraction)
After Change
# Make sure all keywords are valid
valid_contract_kwargs = ["path", "memory_limit", "einsum_call"]
unknown_kwargs = [k for (k, v) in kwargs.items() if k not in valid_contract_kwargs]
if len(unknown_kwargs):
raise TypeError("einsum_path: Did not understand the following kwargs: %s" % unknown_kwargs)
path_type = kwargs.pop("path", "greedy")