MPT Patch Utilities

Context: MPT (MosaicML Pretrained Transformer) models use ALiBi or rotary position embeddings. This patch fixes rotary position-cache invalidation and attention-mask expansion for variable-length sequences in a custom MPT block.

# NOTE(review): orphaned fragment — this appears to be a displaced copy of a
# line from inside patch_attention_mask() below (whose body is truncated at
# "batch = attention_mask"). At module level this statement would raise
# NameError; confirm against the original source and remove it.
batch = attention_mask.size(0)

def _update_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype): if seq_len == self._cached_seq_len: return inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) freqs = torch.einsum("i,j->ij", t, inv_freq) emb = torch.cat((freqs, freqs), dim=-1) self._cached_cos = emb.cos().to(dtype) self._cached_sin = emb.sin().to(dtype) self._cached_seq_len = seq_len patch mpt

def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return (cos, sin) tables covering the first ``seq_len`` positions,
    built on x's device/dtype via the cache helper."""
    self._update_cache(seq_len, x.device, x.dtype)
    return self._cached_cos[:seq_len], self._cached_sin[:seq_len]


# ----------------------------------------------------------------------
# 2. Patch Attention Mask Expansion (for cross-attention)
# ----------------------------------------------------------------------
def patch_attention_mask(
    attention_mask: torch.Tensor,
    query_length: int,
    key_length: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Expand mask from (batch, 1, key_len) or (batch, seq_len) to
    (batch, 1, query_len, key_len) for MPT attention.

    Returns None when no mask is given; a mask that is already 4D is
    assumed correct and only cast to ``dtype``.
    """
    if attention_mask is None:
        return None
    # If already 4D, assume correct
    if attention_mask.dim() == 4:
        return attention_mask.to(dtype)
    batch = attention_mask.size(0)
    # NOTE(review): the original source was truncated at this point; the
    # expansion below implements the contract stated in the docstring —
    # confirm against the upstream implementation.
    if attention_mask.dim() == 2:
        # (batch, key_len) -> (batch, 1, 1, key_len)
        expanded = attention_mask[:, None, None, :]
    elif attention_mask.dim() == 3:
        # (batch, 1, key_len) -> (batch, 1, 1, key_len)
        expanded = attention_mask[:, :, None, :]
    else:
        raise ValueError(f"Unsupported attention_mask rank: {attention_mask.dim()}")
    # Broadcast the singleton query axis to query_length; expand() is
    # zero-copy (a view), the copy happens only at the dtype cast.
    return expanded.expand(batch, 1, query_length, key_length).to(dtype)

def apply_mpt_patches(model) -> None:
    """Install the patched attention-mask expansion on *model*, in place.

    NOTE(review): the original source was mangled here and the enclosing
    ``def`` line was lost; this signature is reconstructed from the body's
    use of ``model`` and the commented usage example below — confirm.
    """
    # Monkey-patch attention mask expansion function if model has it
    if hasattr(model, "_expand_attention_mask"):
        model._expand_attention_mask = patch_attention_mask
        print("[PATCH] Replaced _expand_attention_mask")


# ----------------------------------------------------------------------
# Usage example
# ----------------------------------------------------------------------
if __name__ == "__main__":
    # Assume you have an MPT model loaded:
    # from transformers import AutoModel
    # model = AutoModel.from_pretrained("mosaicml/mpt-7b", trust_remote_code=True)
    # apply_mpt_patches(model)

    # Test rotary cache fix
    rotary = PatchedRotaryEmbedding(dim=64, max_seq_len=512)
    x = torch.randn(1, 10, 64)
    cos1, sin1 = rotary(x, seq_len=10)
    cos2, sin2 = rotary(x, seq_len=20)  # seq_len changes -> recalc cache
    assert cos1.shape[0] == 10
    assert cos2.shape[0] == 20
    print("Rotary cache patch: OK")