2ae5a7afa6939b8bfd739bc8900c5920702585e2,opt_einsum/contract.py,,contract_path,#,6

Before Change



    path_arg = kwargs.pop("path", "greedy")
    memory_limit = kwargs.pop("memory_limit", None)
    tensordot = kwargs.pop("tensordot", True)

    // Hidden option, only einsum should call this
    einsum_call_arg = kwargs.pop("einsum_call", False)

    // Python side parsing
    input_subscripts, output_subscript, operands = parser.parse_einsum_input(operands)
    subscripts = input_subscripts + "->" + output_subscript

    // Build a few useful lists and sets
    input_list = input_subscripts.split(",")
    input_sets = [set(x) for x in input_list]
    output_set = set(output_subscript)
    indices = set(input_subscripts.replace(",", ""))

    // Get the length of each unique dimension and ensure all dimensions are correct
    dimension_dict = {}
    for tnum, term in enumerate(input_list):
        sh = operands[tnum].shape
        if len(sh) != len(term):
            raise ValueError("Einstein sum subscript %s does not contain the "
                             "correct number of indices for operand %d.",
                             input_subscripts[tnum], tnum)
        for cnum, char in enumerate(term):
            dim = sh[cnum]
            if char in dimension_dict.keys():
                if dimension_dict[char] != dim:
                    raise ValueError("Size of label "%s" for operand %d does "
                                     "not match previous terms.", char, tnum)
            else:
                dimension_dict[char] = dim

    // Compute size of each input array plus the output array
    if memory_limit is None:
        size_list = []
        for term in input_list + [output_subscript]:
            size_list.append(paths.compute_size_by_dict(term, dimension_dict))
        out_size = max(size_list)
        memory_arg = out_size
    else:
        memory_arg = int(memory_limit)

    // Compute naive cost
    // This isn't quite right; need to look into exactly how einsum does this
    naive_cost = paths.compute_size_by_dict(indices, dimension_dict)
    indices_in_input = input_subscripts.replace(",", "")
    mult = max(len(input_list) - 1, 1)
    if (len(indices_in_input) - len(set(indices_in_input))):
        mult *= 2
    naive_cost *= mult

    // Compute path
    if not isinstance(path_arg, str):
        path = path_arg
    elif len(input_list) == 1:
        path = [(0,)]
    elif len(input_list) == 2:
        path = [(0, 1)]
    elif (indices == output_set):
        // If no rank reduction leave it to einsum
        path = [tuple(range(len(input_list)))]
    elif (path_arg in ["greedy", "opportunistic"]):
        // Maximum memory should be at most out_size for this algorithm
        memory_arg = min(memory_arg, out_size)
        path = paths.greedy(input_sets, output_set, dimension_dict, memory_arg)
    elif path_arg == "optimal":
        path = paths.optimal(input_sets, output_set, dimension_dict, memory_arg)
    else:
        raise KeyError("Path name %s not found", path_arg)

    cost_list, scale_list, size_list = [], [], []
    contraction_list = []

    // Build contraction tuple (positions, gemm, einsum_str, remaining)
    for cnum, contract_inds in enumerate(path):
        // Make sure we remove inds from right to left
        contract_inds = tuple(sorted(list(contract_inds), reverse=True))

        contract = paths.find_contraction(contract_inds, input_sets, output_set)
        out_inds, input_sets, idx_removed, idx_contract = contract

        cost = paths.compute_size_by_dict(idx_contract, dimension_dict)
        if idx_removed:
            cost *= 2
        cost_list.append(cost)
        scale_list.append(len(idx_contract))
        size_list.append(paths.compute_size_by_dict(out_inds, dimension_dict))

        tmp_inputs = []
        for x in contract_inds:
            tmp_inputs.append(input_list.pop(x))

        // Last contraction
        if (cnum - len(path)) == -1:
            idx_result = output_subscript
        else:
            sort_result = [(dimension_dict[ind], ind) for ind in out_inds]
            idx_result = "".join([x[1] for x in sorted(sort_result)])

        input_list.append(idx_result)
        einsum_str = ",".join(tmp_inputs) + "->" + idx_result

        if tensordot:
            can_gemm = blas.can_blas(tmp_inputs, idx_result, idx_removed)
            // Don't want to deal with this quite yet
            if can_gemm == "TDOT":
                can_gemm = False
        else:
            can_gemm = False

        contraction = (contract_inds, idx_removed, can_gemm, einsum_str, input_list[:])

After Change



    // Make sure all keywords are valid
    valid_contract_kwargs = ["path", "memory_limit", "einsum_call"]
    unknown_kwargs = [k for (k, v) in kwargs.items() if k not in valid_contract_kwargs]
    if len(unknown_kwargs):
        raise TypeError("einsum_path: Did not understand the following kwargs: %s" % unknown_kwargs)

    path_type = kwargs.pop("path_type", "greedy")
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 5

Instances


Project Name: dgasmith/opt_einsum
Commit Name: 2ae5a7afa6939b8bfd739bc8900c5920702585e2
Time: 2017-02-13
Author: malorian@me.com
File Name: opt_einsum/contract.py
Class Name:
Method Name: contract_path


Project Name: catalyst-team/catalyst
Commit Name: c5c350cf9f9b576cc9de939e4dc308404eb48852
Time: 2019-05-28
Author: scitator@gmail.com
File Name: catalyst/dl/experiments/experiment.py
Class Name: ConfigExperiment
Method Name: get_optimizer_and_model