# Fall back to the default variable list for this trace when none was given.
if varnames is None:
    varnames = get_default_varnames(trace.varnames, include_transformed)

# Compute the statistic for the whole trace, then keep only the selected
# variables (these were two statements fused onto one line).
R = gelman_rubin(trace)
R = {v: R[v] for v in varnames}

ax.set_title(title)
# Set x range
ax.set_xlim(0.9, 2.1)
# X axis labels; the last tick reads "2+" because plotted values are capped
# at 2 in the loop below.
# NOTE(review): passing labels positionally to set_xticks/set_yticks relies on
# a Matplotlib version whose second parameter is `labels` -- confirm.
ax.set_xticks((1.0, 1.5, 2.0), ("1", "1.5", "2+"))
# One unlabeled y tick per entry in `labels`, at y = -1, -2, ...
ax.set_yticks([-(row + 1) for row in range(len(labels))], "")

# `i` is the row offset of the current variable; points are drawn at y = -i,
# -(i + 1), ... for the k elements of that variable.
# NOTE(review): `i` is not advanced anywhere in the visible lines, so every
# variable starts at y = -1 -- confirm an increment (e.g. by k) follows.
i = 1
for varname in varnames:
    # Use one draw from the first chain only to learn how many scalar
    # elements the variable has.
    chain = trace.chains[0]
    value = trace.get_values(varname, chains=[chain])[0]
    k = np.size(value)
    if k > 1:
        # Multi-element variable: one point per element, each capped at 2.
        ax.plot([min(r, 2) for r in R[varname]],
                [-(j + i) for j in range(k)], "bo", markersize=4)
    else:
        # Scalar variable: a single point, capped at 2.
        ax.plot(min(R[varname], 2), -i, "bo", markersize=4)
After Change
# Body of the per-variable plotting step; `varname`, `i`, `trace` and `ax`
# come from the enclosing code, which is outside this fragment.
# Use one draw from the first chain only to learn how many scalar elements
# the variable has.
chain = trace.chains[0]
value = trace.get_values(varname, chains=[chain])[0]
k = np.size(value)
# The statistic is now computed for just this variable.
R = gelman_rubin(trace, varnames=[varname])
if k > 1:
    # Multi-element variable: pull the "rhat" values out via dict2pd and
    # draw one point per element at y = -i, -(i + 1), ..., each capped at 2.
    # NOTE(review): assumes dict2pd(...).values yields exactly k entries in
    # element order -- confirm against dict2pd.
    Rval = dict2pd(R, "rhat").values
    ax.plot([min(r, 2) for r in Rval],
            [-(j + i) for j in range(k)], "bo", markersize=4)
else:
    # Scalar variable: a single point at y = -i, capped at 2.
    ax.plot(min(R[varname], 2), -i, "bo", markersize=4)