Anyone write an attention in ml-agents? i try to write it but it doesn't work

I try to make an attention mechanism into ml-agents trainer, it shows no error, but when I train, the agents just jump, and they can not be better.
here is my code: i make a class of attention in ml-agent trainer (just like the attention they make in MA-POCA)

class attentionmechanism(torch.nn.Module):
    def __init__(self,embedding_size: int):
        super().__init__()
        self.embedding_size: int = embedding_size
       self.q = linear_layer(
            embedding_size,
            embedding_size,
            kernel_init=Initialization.Normal,
            kernel_gain=(0.125 / embedding_size) ** 0.5,
        )
        self.k = linear_layer(
            embedding_size,
            embedding_size,
            kernel_init=Initialization.Normal,
            kernel_gain=(0.125 / embedding_size) ** 0.5,
        )
        self.v = linear_layer(
            embedding_size,
            embedding_size,
            kernel_init=Initialization.Normal,
            kernel_gain=(0.125 / embedding_size) ** 0.5,
        )
        self.embedding_norm = LayerNorm()

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        inp = self.embedding_norm(inp)
        query = self.q(inp)
        key = self.k(inp)
        value = self.v(inp)

        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.embedding_size ** 0.5)
        attention_weights = torch.nn.functional.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, value)

        
        return output, attention_weights

and then I add the attention to networkbody, because ppo trainer and critic all use networkbody:

class NetworkBody(nn.Module):
    def __init__(
        self,
        observation_specs: List[ObservationSpec],
        network_settings: NetworkSettings,
        encoded_act_size: int = 0,
    ):
        super().__init__()
        self.normalize = network_settings.normalize
        self.use_lstm = network_settings.memory is not None
        self.h_size = network_settings.hidden_units
        self.m_size = (
            network_settings.memory.memory_size
            if network_settings.memory is not None
            else 0
        )
        self.observation_encoder = ObservationEncoder(
            observation_specs,
            self.h_size,
            network_settings.vis_encode_type,
            self.normalize,
        )
        self.processors = self.observation_encoder.processors
        total_enc_size = self.observation_encoder.total_enc_size
        total_enc_size += encoded_act_size

        if (
            self.observation_encoder.total_goal_enc_size > 0
            and network_settings.goal_conditioning_type == ConditioningType.HYPER
        ):
            self._body_endoder = ConditionalEncoder(
                total_enc_size,
                self.observation_encoder.total_goal_enc_size,
                self.h_size,
                network_settings.num_layers,
                1,
            )
        else:
            self._body_endoder = LinearEncoder(
                total_enc_size, network_settings.num_layers, self.h_size
            )

        self.attention = AttentionMechanism(self.h_size)  
        if self.use_lstm:
            self.lstm = LSTM(self.h_size, self.m_size)
        else:
            self.lstm = None  # type: ignore

    def update_normalization(self, buffer: AgentBuffer) -> None:
        self.observation_encoder.update_normalization(buffer)

    def copy_normalization(self, other_network: "NetworkBody") -> None:
        self.observation_encoder.copy_normalization(other_network.observation_encoder)

    @property
    def memory_size(self) -> int:
        return self.lstm.memory_size if self.use_lstm else 0

    def forward(
        self,
        inputs: List[torch.Tensor],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        encoded_self = self.observation_encoder(inputs)
        if actions is not None:
            encoded_self = torch.cat([encoded_self, actions], dim=1)
        if isinstance(self._body_endoder, ConditionalEncoder):
            goal = self.observation_encoder.get_goal_encoding(inputs)
            encoding = self._body_endoder(encoded_self, goal)
        else:
            encoding = self._body_endoder(encoded_self)

         # attention
        attention_output, _ = self.attention(encoding)
        if self.use_lstm:
            # Resize to (batch, sequence length, encoding size)
            encoding = encoding.reshape([-1, sequence_length, self.h_size])
            encoding, memories = self.lstm(attention_output, memories)
            encoding = encoding.reshape([-1, self.m_size // 2])
        else:
            encoding = attention_output
        return encoding, memories

but i can not let them learn anything, I check the input,I think maybe this is reason,the q,k,v is [16,256]? anyway,I search but it seems no example that using attention in ml-agents? Does anyone have experience?I can’t find useful resources, most examples are nlp and cv, so if someone have experience,it really helps a lot.