# Populate the graph `dg` with one edge per set entry of `binary_mask` and
# record a layout position for every node index.
# NOTE(review): the indentation of this fragment was lost; the position
# assignment is taken to run once per node `i` -- confirm against the original.
for i in range(total_nodes):
    # Position is the value returned by number_to_type_layer, reversed and
    # scaled by 2 (presumably to space the nodes apart in the drawing).
    pos[i] = 2. * np.array(number_to_type_layer(i, n_types))[::-1]
    for j in range(total_nodes):
        if binary_mask[i, j] != 1:
            continue  # no connection from i to j
        dg.add_edge(i, j)
# After Change
# Draw the graph `dg` at the positions in `pos` on a new 12 x 12 figure.
plt.figure(figsize=(12, 12))

# One colour value per entry of `nodes`; entries absent from `val_map`
# fall back to 0.25.
# NOTE(review): assumes iterating `nodes` matches the graph's node order,
# since the values are applied positionally -- confirm.
values = [val_map.get(name, 0.25) for name in nodes]

nx.draw(
    dg,
    pos,
    cmap=plt.get_cmap("jet"),
    node_color=values,
    node_size=7000,
    alpha=0.3,
)

# `nodes` is passed as the label argument.
# NOTE(review): networkx expects a node -> label mapping here -- verify.
nx.draw_networkx_labels(dg, pos, labels=nodes, font_size=18)