```python
# Assumes module-level imports: numpy as np, torch, and torch.nn as nn
def update(self):
    # Skip the update until the replay buffer holds at least one full batch
    if len(self.memory) < self.batch_size:
        return
    # Sample a random batch of transitions from the replay buffer
    state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(
        self.batch_size)
    # Convert the batch to tensors on the training device
    state_batch = torch.tensor(state_batch, device=self.device, dtype=torch.float)
    action_batch = torch.tensor(action_batch, device=self.device).unsqueeze(1)
    reward_batch = torch.tensor(reward_batch, device=self.device, dtype=torch.float)
    next_state_batch = torch.tensor(next_state_batch, device=self.device, dtype=torch.float)
    done_batch = torch.tensor(np.float32(done_batch), device=self.device)
    # Q(s, a) of the actions actually taken, from the policy network
    q_values = self.policy_net(state_batch).gather(dim=1, index=action_batch)
    # max_a' Q_target(s', a') from the target network; detach so no gradient flows into it
    next_q_values = self.target_net(next_state_batch).max(1)[0].detach()
    # One-step TD target; the (1 - done) factor removes the bootstrap term for terminal states
    expected_q_values = reward_batch + self.gamma * next_q_values * (1 - done_batch)
    # Mean squared TD error between predicted and target Q-values
    loss = nn.MSELoss()(q_values, expected_q_values.unsqueeze(1))
    # Backpropagate and clip gradients element-wise to [-1, 1] for stability
    self.optimizer.zero_grad()
    loss.backward()
    for param in self.policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    self.optimizer.step()
```
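In equation form, the update above fits the policy network's Q-value to the one-step TD target computed from the target network (writing $y_i$ for the target of the $i$-th sampled transition and $B$ for the batch size):

$$
y_i = r_i + \gamma\,(1 - d_i)\,\max_{a'} Q_{\text{target}}(s'_i, a'),
\qquad
\mathcal{L} = \frac{1}{B}\sum_{i=1}^{B}\left(Q_{\text{policy}}(s_i, a_i) - y_i\right)^2 .
$$

Here $d_i$ is the done flag: it zeroes the bootstrap term for terminal transitions, and detaching the target network's output keeps gradients from flowing into it.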