Skip to content
Snippets Groups Projects
Commit 3c5f987e authored by sicer's avatar sicer
Browse files

update rnn agent

parent b8ae8428
No related branches found
No related tags found
No related merge requests found
Pipeline #9646 failed
REGISTRY = {}
from .rnn_agent import RNNAgent
REGISTRY['rnn'] = RNNAgent
from .casec_rnn_agent import CASECRNNAgent
from .casec_rnn_agent import CASECPairRNNAgent
......
import torch.nn as nn
import torch.nn.functional as F
class RNNAgent(nn.Module):
def __init__(self, input_shape, args):
super(RNNAgent, self).__init__()
self.args = args
self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim)
self.fc2 = nn.Linear(args.rnn_hidden_dim, args.n_actions)
def init_hidden(self):
# make hidden states on same device as model
return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_()
def forward(self, inputs, hidden_state):
x = F.relu(self.fc1(inputs))
h_in = hidden_state.reshape(-1, self.args.rnn_hidden_dim)
h = self.rnn(x, h_in)
q = self.fc2(h)
return q, h
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment