Commit 07d92e31 authored by Andre Biedenkapp's avatar Andre Biedenkapp
Browse files

MAINT directly log the played policy

parent e7a421b2
Pipeline #17848 passed with stages
in 47 minutes and 1 second
......@@ -117,6 +117,7 @@ class SAT4JEnvSelHeur(Env):
self._inst_pointer = 0
self.__reward_type = reward_type
self.too_simple_inst = []
self._last_msg = None
@staticmethod
def _save_div(a, b):
......@@ -221,6 +222,7 @@ class SAT4JEnvSelHeur(Env):
self.close()
raise Exception('Connection unexpected closed')
self.conn.sendall(json.dumps(msg).encode('utf-8') + "\n".encode('utf-8'))
self._last_msg = msg
s, r, d = self._process_data()
info = {}
if d:
......@@ -238,6 +240,7 @@ class SAT4JEnvSelHeur(Env):
Initialize SAT4J
:return:
"""
self._last_msg = None
self.done = False
self._prev_state = None
self.__step = 0
......
......@@ -180,9 +180,9 @@ class DQN:
"""
print('Begin Evaluation')
if not only_checkpoint:
eval_s, eval_r, eval_d = self.eval(eval_eps, max_env_time_steps)
eval_s, eval_r, eval_d, eval_pols = self.eval(eval_eps, max_env_time_steps)
else:
eval_s, eval_r, eval_d = [-1], [-1], [-1]
eval_s, eval_r, eval_d, eval_pols = [-1], [-1], [-1], {}
checkpoint_path = os.path.join(out_dir, 'checkpoints', f'{total_steps:05d}')
if not os.path.exists(checkpoint_path):
os.makedirs(checkpoint_path)
......@@ -198,7 +198,8 @@ class DQN:
eval_eps=eval_eps,
eval_insts=self._eval_env.instances,
reward_per_insts=eval_r,
steps_per_insts=eval_s
steps_per_insts=eval_s,
policies_per_insts=eval_pols
)
if only_checkpoint:
eval_stats['checkpoint_path'] = checkpoint_path
......@@ -211,9 +212,9 @@ class DQN:
# Do the same thing on the training data if required
if self._train_eval_env is not None:
if not only_checkpoint:
eval_s, eval_r, eval_d = self.eval(eval_eps, max_env_time_steps, train_set=True)
eval_s, eval_r, eval_d, eval_pols = self.eval(eval_eps, max_env_time_steps, train_set=True)
else:
eval_s, eval_r, eval_d = [-1], [-1], [-1]
eval_s, eval_r, eval_d, eval_pols = [-1], [-1], [-1], {}
checkpoint_path = os.path.join(out_dir, 'checkpoints', total_steps)
eval_stats = dict(
elapsed_time=time.time() - start_time,
......@@ -226,7 +227,8 @@ class DQN:
eval_eps=eval_eps,
eval_insts=self._train_eval_env.instances,
reward_per_insts=eval_r,
steps_per_insts=eval_s
steps_per_insts=eval_s,
policies_per_insts=eval_pols
)
if only_checkpoint:
eval_stats['checkpoint_path'] = checkpoint_path
......@@ -317,12 +319,18 @@ class DQN:
:returns (steps per episode), (reward per episode), (decisions per episode)
"""
steps, rewards, decisions = [], [], []
policies = {}
this_env = self._eval_env if not train_set else self._train_eval_env
with torch.no_grad():
for e in range(episodes):
ed, es, er = 0, 0, 0
s = this_env.reset()
if this_env.instance not in policies:
policies[this_env.instance] = [[]]
else:
policies[this_env.instance].append([])
policies[this_env.instance][-1].append(this_env._last_msg)
for _ in count():
a = self.get_action(s, 0)
if self._facts is not None:
......@@ -332,6 +340,7 @@ class DQN:
ed += 1
ns, r, d, _ = this_env.step(env_a)
policies[this_env.instance][-1].append(this_env._last_msg)
er += r
es += 1
if es >= max_env_time_steps or d:
......@@ -341,7 +350,7 @@ class DQN:
rewards.append(er)
decisions.append(ed)
return steps, rewards, decisions
return steps, rewards, decisions, policies
def save_model(self, path):
torch.save(self._q.state_dict(), os.path.join(path, 'Q'))
......@@ -475,14 +484,15 @@ if __name__ == "__main__":
else:
print(f'Validating {data["checkpoint_path"]}')
agent.load(data['checkpoint_path'])
eval_s, eval_r, eval_d = agent.eval(num_eval_episodes, max_env_time_steps,
train_set=args.validate_type == 'train')
eval_s, eval_r, eval_d, eval_p = agent.eval(num_eval_episodes, max_env_time_steps,
train_set=args.validate_type == 'train')
data['avg_num_steps_per_eval_ep'] = float(np.mean(eval_s))
data['avg_num_decs_per_eval_ep'] = float(np.mean(eval_d))
data['avg_rew_per_eval_ep'] = float(np.mean(eval_r))
data['std_rew_per_eval_ep'] = float(np.mean(eval_r))
data['reward_per_insts'] = eval_r
data['steps_per_insts'] = eval_s
data['policies_per_insts'] = eval_p
if args.validate_type != 'train':
insts_looked_at = []
counter = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment