# Ensure the entries of n sum up to min_tr: the largest entry absorbs the
# difference between min_tr and the current total.
idx = np.argmax(n)
n[idx] += min_tr - np.sum(n)
# For each i, draw n[i] items from traces[i] (np.random.choice samples with
# replacement by default) and concatenate all draws into one array.
trace = np.concatenate([np.random.choice(traces[i], j)
                        for i, j in enumerate(n)])
# Flatten the observed random variables of every model, in model order.
obs = [x for m in models for x in m.observed_RVs]
# Repeat each observed RV by the matching count in n.
# NOTE(review): np.repeat with an array of repeats requires
# len(obs) == len(n), i.e. presumably exactly one observed RV per model —
# confirm against callers.
variables = np.repeat(obs, n)
After Change
# Make the entries of n sum to min_tr by adjusting entry idx.
# NOTE(review): idx is bound before this hunk (np.argmax(n) in the earlier
# version); it is deliberately not rebound below.
n[idx] += min_tr - np.sum(n)
trace = []
for i, j in enumerate(n):
    tr = traces[i]
    len_trace = len(tr)
    nchain = tr.nchains
    # Draw j flat indices uniformly (with replacement) over all
    # nchain * len_trace (chain, point) pairs, then split each flat index
    # into a chain number and a position within that chain.
    # NOTE(review): assumes every chain of tr holds len(tr) points and that
    # tr._straces is indexable by chain numbers 0..nchains-1 — confirm.
    indices = np.random.randint(0, nchain * len_trace, j)
    chain_idx, point_idx = np.divmod(indices, len_trace)
    straces = tr._straces
    # Unpack into named variables instead of reusing `idx`, which would
    # otherwise clobber the index used on the first line of this hunk.
    for chain, point in zip(chain_idx, point_idx):
        trace.append(straces[chain].point(point))
# Flatten the observed random variables of every model, in model order.
obs = [x for m in models for x in m.observed_RVs]
# NOTE(review): np.repeat with an array of repeats requires
# len(obs) == len(n), i.e. presumably one observed RV per model — confirm.
variables = np.repeat(obs, n)