I am trying to add an attention mechanism to the ml-agents trainer. It raises no errors, but when I train, the agents just jump and never get any better.
Here is my code. I wrote an attention class in the ml-agents trainer (similar to the attention used in MA-POCA):
class AttentionMechanism(torch.nn.Module):
    """Single-head scaled dot-product self-attention over a set of embeddings.

    The input is normalized with ``LayerNorm`` and projected to queries, keys
    and values by three ``embedding_size -> embedding_size`` linear layers.

    Attention is always computed *within* one sample, never across the batch
    dimension. A 2-D input of shape ``(batch, embedding_size)`` is therefore
    treated as ``batch`` independent sequences holding a single token each.
    In that case the softmax is taken over one element, every weight is 1 and
    the output reduces to ``v(norm(inp))``. To obtain non-trivial attention,
    pass a 3-D tensor ``(batch, num_entities, embedding_size)``.
    """

    def __init__(self, embedding_size: int):
        """
        :param embedding_size: Size of the last dimension of the input; also
            the output size of the query/key/value projections.
        """
        super().__init__()
        self.embedding_size: int = embedding_size
        # The same small gain is shared by the three projections.
        kernel_gain = (0.125 / embedding_size) ** 0.5
        self.q = linear_layer(
            embedding_size,
            embedding_size,
            kernel_init=Initialization.Normal,
            kernel_gain=kernel_gain,
        )
        self.k = linear_layer(
            embedding_size,
            embedding_size,
            kernel_init=Initialization.Normal,
            kernel_gain=kernel_gain,
        )
        self.v = linear_layer(
            embedding_size,
            embedding_size,
            kernel_init=Initialization.Normal,
            kernel_gain=kernel_gain,
        )
        self.embedding_norm = LayerNorm()

    def forward(self, inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply self-attention to ``inp``.

        :param inp: Tensor of shape ``(batch, embedding_size)`` or
            ``(batch, num_entities, embedding_size)``.
        :return: ``(output, attention_weights)``. ``output`` has the same
            shape as ``inp``. ``attention_weights`` has shape
            ``(batch, num_entities, num_entities)``; for a 2-D input
            ``num_entities`` is 1.
        """
        inp = self.embedding_norm(inp)

        # Without this, a (batch, embedding) input would yield a
        # (batch, batch) score matrix, i.e. every sample would attend to the
        # other, unrelated samples that happen to share its batch.
        single_token = inp.dim() == 2
        if single_token:
            inp = inp.unsqueeze(1)

        query = self.q(inp)
        key = self.k(inp)
        value = self.v(inp)

        scores = torch.matmul(query, key.transpose(-2, -1)) / (
            self.embedding_size ** 0.5
        )
        attention_weights = torch.nn.functional.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, value)

        if single_token:
            # Restore the caller's (batch, embedding) layout.
            output = output.squeeze(1)
        return output, attention_weights


# Backward-compatible alias for the original lower-case class name.
attentionmechanism = AttentionMechanism
Then I added the attention to NetworkBody, because both the PPO trainer and the critic use NetworkBody:
class NetworkBody(nn.Module):
    """Observation encoder, body encoder, attention and optional LSTM.

    ``forward`` encodes the observations (optionally concatenated with
    actions), runs the body encoder, applies the attention module to the
    result and, when memory is configured, feeds that into an LSTM.
    """

    def __init__(
        self,
        observation_specs: List[ObservationSpec],
        network_settings: NetworkSettings,
        encoded_act_size: int = 0,
    ):
        """
        :param observation_specs: Specs handed to the ``ObservationEncoder``.
        :param network_settings: Supplies ``normalize``, ``hidden_units``,
            ``num_layers``, ``vis_encode_type``, ``goal_conditioning_type``
            and the optional ``memory`` settings.
        :param encoded_act_size: Extra width added to the encoder input for
            actions concatenated in ``forward``.
        """
        super().__init__()
        self.normalize = network_settings.normalize
        self.use_lstm = network_settings.memory is not None
        self.h_size = network_settings.hidden_units
        self.m_size = (
            network_settings.memory.memory_size
            if network_settings.memory is not None
            else 0
        )
        self.observation_encoder = ObservationEncoder(
            observation_specs,
            self.h_size,
            network_settings.vis_encode_type,
            self.normalize,
        )
        self.processors = self.observation_encoder.processors
        total_enc_size = self.observation_encoder.total_enc_size
        total_enc_size += encoded_act_size

        # Use the goal-conditioned encoder only when goal observations exist
        # and HYPER conditioning was requested.
        if (
            self.observation_encoder.total_goal_enc_size > 0
            and network_settings.goal_conditioning_type == ConditioningType.HYPER
        ):
            self._body_endoder = ConditionalEncoder(
                total_enc_size,
                self.observation_encoder.total_goal_enc_size,
                self.h_size,
                network_settings.num_layers,
                1,
            )
        else:
            self._body_endoder = LinearEncoder(
                total_enc_size, network_settings.num_layers, self.h_size
            )

        # The attention class is defined under the name `attentionmechanism`;
        # the previous `AttentionMechanism` spelling was an undefined name.
        self.attention = attentionmechanism(self.h_size)

        if self.use_lstm:
            self.lstm = LSTM(self.h_size, self.m_size)
        else:
            self.lstm = None  # type: ignore

    def update_normalization(self, buffer: AgentBuffer) -> None:
        """Forward ``buffer`` to the observation encoder's normalization update."""
        self.observation_encoder.update_normalization(buffer)

    def copy_normalization(self, other_network: "NetworkBody") -> None:
        """Copy normalization state from ``other_network``'s observation encoder."""
        self.observation_encoder.copy_normalization(other_network.observation_encoder)

    @property
    def memory_size(self) -> int:
        """Memory size reported by the LSTM, or 0 when no LSTM is used."""
        return self.lstm.memory_size if self.use_lstm else 0

    def forward(
        self,
        inputs: List[torch.Tensor],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode ``inputs`` and return ``(encoding, memories)``.

        :param inputs: Observation tensors passed to the observation encoder.
        :param actions: Optional tensor concatenated to the encoded
            observations along dim 1.
        :param memories: Recurrent state passed to the LSTM; returned
            unchanged when no LSTM is configured.
        :param sequence_length: Number of steps per sequence, used to shape
            the LSTM input.
        :return: The final encoding and the (possibly updated) memories.
        """
        encoded_self = self.observation_encoder(inputs)
        if actions is not None:
            encoded_self = torch.cat([encoded_self, actions], dim=1)
        if isinstance(self._body_endoder, ConditionalEncoder):
            goal = self.observation_encoder.get_goal_encoding(inputs)
            encoding = self._body_endoder(encoded_self, goal)
        else:
            encoding = self._body_endoder(encoded_self)

        # Attention is applied to the flat per-step encoding, before any
        # reshape into sequences.
        attention_output, _ = self.attention(encoding)

        if self.use_lstm:
            # Resize to (batch, sequence length, encoding size). The tensor
            # handed to the LSTM must be the reshaped one; previously the
            # reshape was applied to `encoding` and then thrown away.
            attention_output = attention_output.reshape(
                [-1, sequence_length, self.h_size]
            )
            encoding, memories = self.lstm(attention_output, memories)
            encoding = encoding.reshape([-1, self.m_size // 2])
        else:
            encoding = attention_output
        return encoding, memories
But I cannot get them to learn anything. I checked the input, and I think this may be the reason: q, k and v have shape [16, 256]? In any case, I searched, but there seems to be no example of using attention in ml-agents. Does anyone have experience with this? I can't find useful resources; most examples are for NLP and CV, so if someone has experience, it would really help a lot.