}, minibatch.count)))[policy_id]
for k, v in batch_fetches.get(LEARNER_STATS_KEY, {}).items():
iter_extra_fetches[k].append(v)
logger.debug("{} {}".format(i, averaged(iter_extra_fetches)))
fetches[policy_id] = averaged(iter_extra_fetches)
return fetches
After Change
if isinstance(samples, SampleBatch):
samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples}, samples.count)
fetches = defaultdict(dict)
for policy_id in policies.keys():
if policy_id not in samples.policy_batches:
continue