diff --git a/src/modules/agents/__init__.py b/src/modules/agents/__init__.py
index a76ea1aa6ae88617d2825d187f39f3d32dbafbcc..82d5319b52a2aaa8381d78a1e3d1073f9b5577ee 100755
--- a/src/modules/agents/__init__.py
+++ b/src/modules/agents/__init__.py
@@ -1,5 +1,7 @@
 REGISTRY = {}
 
+from .rnn_agent import RNNAgent
+REGISTRY['rnn'] = RNNAgent
 from .casec_rnn_agent import CASECRNNAgent
 from .casec_rnn_agent import CASECPairRNNAgent
 
diff --git a/src/modules/agents/rnn_agent.py b/src/modules/agents/rnn_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f80003a6522017b894c21911e8c7ce14983580b
--- /dev/null
+++ b/src/modules/agents/rnn_agent.py
@@ -0,0 +1,23 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class RNNAgent(nn.Module):
+    def __init__(self, input_shape, args):
+        super(RNNAgent, self).__init__()
+        self.args = args
+
+        self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
+        self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim)
+        self.fc2 = nn.Linear(args.rnn_hidden_dim, args.n_actions)
+
+    def init_hidden(self):
+        # make hidden states on same device as model
+        return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_()
+
+    def forward(self, inputs, hidden_state):
+        x = F.relu(self.fc1(inputs))
+        h_in = hidden_state.reshape(-1, self.args.rnn_hidden_dim)
+        h = self.rnn(x, h_in)
+        q = self.fc2(h)
+        return q, h
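
Minimal usage sketch (not part of the diff) showing how the newly registered 'rnn' agent could be exercised. It assumes src/ is on the Python path; the attribute names rnn_hidden_dim and n_actions come from the module above, while the SimpleNamespace args, batch size, and observation size below are illustrative assumptions, not values taken from the repository.

    from types import SimpleNamespace

    import torch

    from modules.agents import REGISTRY as agent_REGISTRY

    # Assumed hyperparameters and shapes, purely for illustration.
    args = SimpleNamespace(rnn_hidden_dim=64, n_actions=5)
    input_shape = 30   # assumed per-agent observation size
    batch_size = 8     # assumed batch size

    # REGISTRY['rnn'] maps to RNNAgent, so this constructs RNNAgent(input_shape, args).
    agent = agent_REGISTRY['rnn'](input_shape, args)

    # init_hidden() returns a (1, rnn_hidden_dim) zero tensor; expand it to the batch.
    hidden = agent.init_hidden().expand(batch_size, -1)

    obs = torch.randn(batch_size, input_shape)
    q_values, hidden = agent(obs, hidden)

    print(q_values.shape)  # torch.Size([8, 5])  -> one Q-value per action
    print(hidden.shape)    # torch.Size([8, 64]) -> carried to the next timestep

The returned hidden state is fed back into the next forward call, which is how the GRUCell carries per-agent memory across an episode.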