2dde998f8ffcc4c4c5520e7402302547278c6e91,opt_einsum/contract.py,,contract_path,#,14

Before Change



    // Get length of each unique dimension and ensure all dimensions are correct
    dimension_dict = {}
    bcast = set()
    for tnum, term in enumerate(input_list):
        sh = operands[tnum].shape

        if len(sh) != len(term):

After Change


    // Build a few useful lists and sets
    input_list = input_subscripts.split(",")
    input_sets = [set(x) for x in input_list]
    input_shps = [x.shape for x in operands]
    output_set = set(output_subscript)
    indices = set(input_subscripts.replace(",", ""))

    // Get length of each unique dimension and ensure all dimensions are correct
    dimension_dict = {}
    for tnum, term in enumerate(input_list):
        sh = input_shps[tnum]

        if len(sh) != len(term):
            raise ValueError("Einstein sum subscript %s does not contain the "
                             "correct number of indices for operand %d." % (input_subscripts[tnum], tnum))
        for cnum, char in enumerate(term):
            dim = sh[cnum]

            if char in dimension_dict:
                // For broadcasting cases we always want the largest dim size
                if dimension_dict[char] == 1:
                    dimension_dict[char] = dim
                elif dim not in (1, dimension_dict[char]):
                    raise ValueError("Size of label "%s" for operand %d (%d) "
                                     "does not match previous terms (%d)." % (char, tnum, dimension_dict[char], dim))
            else:
                dimension_dict[char] = dim

    // Compute size of each input array plus the output array
    size_list = [helpers.compute_size_by_dict(term, dimension_dict) for term in input_list + [output_subscript]]
    out_size = max(size_list)

    if memory_limit is None:
        memory_arg = out_size
    else:
        if memory_limit < 1:
            if memory_limit == -1:
                memory_arg = int(1e20)
            else:
                raise ValueError("Memory limit must be larger than 0, or -1")
        else:
            memory_arg = int(memory_limit)

    // Compute naive cost
    // This isn't quite right; need to look into exactly how einsum does this
    // indices_in_input = input_subscripts.replace(",", "")
    // inne
    inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0
    naive_cost = helpers.flop_count(indices, inner_product, len(input_list), dimension_dict)

    // Compute the path
    if not isinstance(path_type, str):
        path = path_type
    elif len(input_list) == 1:
        // Nothing to be optimized
        path = [(0, )]
    elif len(input_list) == 2:
        // Nothing to be optimized
        path = [(0, 1)]
    elif indices == output_set:
        // If no rank reduction leave it to einsum
        path = [tuple(range(len(input_list)))]
    elif path_type in ["greedy", "opportunistic"]:
        path = paths.greedy(input_sets, output_set, dimension_dict, memory_arg)
    elif path_type == "optimal":
        path = paths.optimal(input_sets, output_set, dimension_dict, memory_arg)
    else:
        raise KeyError("Path name %s not found" % path_type)

    cost_list = []
    scale_list = []
    size_list = []
    contraction_list = []

    // Build contraction tuple (positions, gemm, einsum_str, remaining)
    for cnum, contract_inds in enumerate(path):
        // Make sure we remove inds from right to left
        contract_inds = tuple(sorted(list(contract_inds), reverse=True))

        contract_tuple = helpers.find_contraction(contract_inds, input_sets, output_set)
        out_inds, input_sets, idx_removed, idx_contract = contract_tuple

        // Compute cost, scale, and size
        cost = helpers.flop_count(idx_contract, idx_removed, len(contract_inds), dimension_dict)
        cost_list.append(cost)
        scale_list.append(len(idx_contract))
        size_list.append(helpers.compute_size_by_dict(out_inds, dimension_dict))

        tmp_inputs = [input_list.pop(x) for x in contract_inds]
        tmp_shapes = [input_shps.pop(x) for x in contract_inds]

        if use_blas:
            do_blas = blas.can_blas(tmp_inputs, out_inds, idx_removed, tmp_shapes)
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 4

Non-data size: 4

Instances


Project Name: dgasmith/opt_einsum
Commit Name: 2dde998f8ffcc4c4c5520e7402302547278c6e91
Time: 2018-08-08
Author: johnniemcgray@gmail.com
File Name: opt_einsum/contract.py
Class Name:
Method Name: contract_path


Project Name: keras-team/keras
Commit Name: 57612707c1434ed3b18f9cad5cf2e6dd8b7b3b7d
Time: 2015-11-02
Author: roller@cs.utexas.edu
File Name: keras/layers/core.py
Class Name: Merge
Method Name: __init__


Project Name: ClimbsRocks/auto_ml
Commit Name: 9b9e491bb00be66b732d2f44b3e4375206940e61
Time: 2016-10-19
Author: climbsbytes@gmail.com
File Name: auto_ml/predictor.py
Class Name: Predictor
Method Name: make_sub_x_and_y_test


Project Name: ClimbsRocks/auto_ml
Commit Name: 33851dea0f6c75ca1e685037393ea7160506b53a
Time: 2016-10-08
Author: climbsbytes@gmail.com
File Name: auto_ml/predictor.py
Class Name: Predictor
Method Name: _prepare_for_training