"""Recurrent agent network (fully connected -> GRU cell -> fully connected).

Provenance: added as ``src/modules/agents/rnn_agent.py`` by commit
3c5f987eec17dd56168a87a56e32f8880e397365 ("update rnn agent", sicer
<mansicer@qq.com>, 2021-11-14).  The same commit registers the class in
``src/modules/agents/__init__.py``::

    from .rnn_agent import RNNAgent
    REGISTRY['rnn'] = RNNAgent
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class RNNAgent(nn.Module):
    """Map per-step inputs and a recurrent state to action values.

    The network is ``Linear -> ReLU -> GRUCell -> Linear``.  The hidden state
    is not stored on the module; callers pass it into :meth:`forward` and
    receive the updated state back.
    """

    def __init__(self, input_shape: int, args) -> None:
        """Build the three layers.

        Args:
            input_shape: Size of the last dimension of the tensors passed to
                :meth:`forward` (used as ``in_features`` of the first layer).
            args: Configuration object; this class reads
                ``args.rnn_hidden_dim`` (width of the hidden layers and of the
                recurrent state) and ``args.n_actions`` (output width).
        """
        super().__init__()
        self.args = args

        self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
        self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim)
        self.fc2 = nn.Linear(args.rnn_hidden_dim, args.n_actions)

    def init_hidden(self) -> torch.Tensor:
        """Return an all-zero hidden state of shape ``(1, rnn_hidden_dim)``.

        The tensor is created from ``fc1.weight`` so that it shares the
        model's dtype and device.  It has a single row; presumably the caller
        expands it to the batch size actually used -- confirm at the call
        site.
        """
        # ``new_zeros`` is the documented replacement for the legacy
        # ``Tensor.new(...).zero_()`` idiom and yields the same tensor.
        return self.fc1.weight.new_zeros((1, self.args.rnn_hidden_dim))

    def forward(self, inputs: torch.Tensor, hidden_state: torch.Tensor):
        """Run one recurrent step.

        Args:
            inputs: Tensor whose last dimension equals ``input_shape``.
                NOTE(review): ``nn.GRUCell`` is fed the projected inputs
                directly, so this is expected to be 2-D ``(N, input_shape)``
                -- confirm against the caller.
            hidden_state: Previous hidden state.  It is flattened to
                ``(-1, rnn_hidden_dim)`` before use, so any leading shape is
                accepted as long as the resulting row count matches
                ``inputs``.

        Returns:
            A tuple ``(q, h)``: ``q`` is the output of the final linear layer
            (last dimension ``n_actions``) and ``h`` is the new hidden state
            (last dimension ``rnn_hidden_dim``).
        """
        x = F.relu(self.fc1(inputs))
        # Collapse any leading dimensions so the state is 2-D for the GRU cell.
        h_in = hidden_state.reshape(-1, self.args.rnn_hidden_dim)
        h = self.rnn(x, h_in)
        q = self.fc2(h)
        return q, h