diff --git a/src/modules/agents/__init__.py b/src/modules/agents/__init__.py
index a76ea1aa6ae88617d2825d187f39f3d32dbafbcc..82d5319b52a2aaa8381d78a1e3d1073f9b5577ee 100755
--- a/src/modules/agents/__init__.py
+++ b/src/modules/agents/__init__.py
@@ -1,5 +1,7 @@
 REGISTRY = {}
 
+from .rnn_agent import RNNAgent
+REGISTRY['rnn'] = RNNAgent
 
 from .casec_rnn_agent import CASECRNNAgent
 from .casec_rnn_agent import CASECPairRNNAgent
diff --git a/src/modules/agents/rnn_agent.py b/src/modules/agents/rnn_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f80003a6522017b894c21911e8c7ce14983580b
--- /dev/null
+++ b/src/modules/agents/rnn_agent.py
@@ -0,0 +1,26 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class RNNAgent(nn.Module):
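+    """GRU agent: fc1 -> GRUCell -> fc2, mapping observations to per-action Q-values."""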
+    def __init__(self, input_shape, args):
+        super(RNNAgent, self).__init__()
+        self.args = args
+
+        self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
+        self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim)
+        self.fc2 = nn.Linear(args.rnn_hidden_dim, args.n_actions)
+
+    def init_hidden(self):
+        # make hidden states on same device as model
+        return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_()
+
+    def forward(self, inputs, hidden_state):
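+        # encode the observation, then advance the GRU by one step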
+        x = F.relu(self.fc1(inputs))
+        h_in = hidden_state.reshape(-1, self.args.rnn_hidden_dim)
+        h = self.rnn(x, h_in)
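+        # map the updated hidden state to per-action Q-values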
+        q = self.fc2(h)
+        return q, h
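
A minimal usage sketch (not part of the patch) of the new `REGISTRY['rnn']` entry, assuming the repo's `src/` directory is on the import path so the package resolves as `modules.agents`; the `args` fields mirror what `RNNAgent` reads, while the `SimpleNamespace` stand-in for the config and the concrete sizes are illustrative only.

```python
# Illustrative only: a SimpleNamespace stands in for the real config object.
import torch
from types import SimpleNamespace

from modules.agents import REGISTRY  # assumes src/ is on the import path

args = SimpleNamespace(rnn_hidden_dim=64, n_actions=5)  # fields read by RNNAgent
input_shape = 30                                        # hypothetical observation size

agent = REGISTRY['rnn'](input_shape, args)

batch = 8
inputs = torch.randn(batch, input_shape)
hidden = agent.init_hidden().expand(batch, -1)  # broadcast the zero state over the batch

q, hidden = agent(inputs, hidden)
print(q.shape)       # torch.Size([8, 5])
print(hidden.shape)  # torch.Size([8, 64])
```

Allocating the zero state from `fc1.weight` in `init_hidden` keeps the hidden state on the same device and dtype as the model, which is why the sketch expands that single row across the batch instead of constructing a fresh tensor.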