
Conversation


@Deep-unlearning commented Apr 29, 2025

What does this PR do?

This PR adds support for XCodec2, a high-fidelity general neural audio codec used in the Llasa text-to-speech model, to the Transformers library.

This model is composed of 5 components:

  • A Semantic Encoder
  • An Acoustic Encoder
  • A VectorQuantizer
  • A Semantic Decoder
  • An Acoustic Decoder
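
For context, a rough sketch of the intended end-to-end usage once conversion and batching land (the repo id, the decode call, and the 16 kHz rate are assumptions at this stage, not part of the PR yet):

import numpy as np
from transformers import AutoFeatureExtractor, AutoModel

# Placeholder repo id; the converted checkpoint is not on the Hub yet.
model = AutoModel.from_pretrained("hf-audio/xcodec2")
feature_extractor = AutoFeatureExtractor.from_pretrained("hf-audio/xcodec2")

audio_sample = np.zeros(16000, dtype=np.float32)  # 1 s of silence, assuming a 16 kHz model
inputs = feature_extractor(raw_audio=audio_sample, return_tensors="pt")

audio_codes = model.encode(inputs["input_values"], return_dict=False)  # discrete codes
reconstructed = model.decode(audio_codes, return_dict=False)           # waveform back (decode API assumed)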

This is still a draft PR. Work done so far:

  • Adapted the model to Transformers format in modeling_xcodec2.py and modular_xcodec2.py.

Todo

  • Add the checkpoint conversion scripts and push to the hub
  • Support batch inference
  • Write tests
  • Add documentation

Who can review?

cc: @ArthurZucker
cc: @eustlb @Vaibhavs10 for visibility

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker

Feel free to ping me once this is ready!

@Deep-unlearning changed the title from [WiP] Add xcodec model to [WiP] Add xcodec2 model on Jun 3, 2025
@Deep-unlearning marked this pull request as ready for review June 5, 2025 16:03
@ArthurZucker removed the request for review from Rocketknight1 July 7, 2025 12:05
@ArthurZucker removed their request for review July 7, 2025 12:05
@Deep-unlearning requested a review from eustlb July 7, 2025 13:53
@ebezzam added the Audio label Jul 22, 2025
@ebezzam self-requested a review July 24, 2025 15:56
@ebezzam left a comment

@Deep-unlearning this is my first time reviewing a model addition, so sorry if I'm nit-picky! I recently did a deep dive through DAC and EnCodec, so most of my comments are about making things consistent with those models:

  • simplifying the configuration
  • whether we keep nn.Sequential and weight_norm. @eustlb will probably know better for that
  • similar integration tests

class XCodec2IntegrationTest(unittest.TestCase):
    def test_integration(self):
        expected_rmse = 0.07212554663419724
        expected_codes = [

Wrap this in

# fmt: off
expected_codes = [...]
# fmt: on

to prevent make fixup from putting each element on its own line.

audio_codes = model.encode(inputs["input_values"], return_dict=False)
codes = audio_codes.squeeze(0).squeeze(0).tolist()

self.assertEqual(codes, expected_codes)

Check new tests from EnCodec that @eustlb and I iterated on, namely:

  • checking with torch.testing.assert_close (see the sketch after this list)
  • making gist out of the script you used to compute the expected outputs
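
Concretely, that amounts to something like the following (a sketch only; the reference tensors would come from the gisted script, and the variable names are placeholders):

import torch

def check_outputs(audio_codes, decoded_audio, expected_codes, expected_audio):
    # Integer codes are compared exactly by assert_close; the waveform is
    # compared with the default float tolerances instead of a hand-rolled RMSE check.
    torch.testing.assert_close(audio_codes.cpu(), expected_codes)
    torch.testing.assert_close(decoded_audio.cpu(), expected_audio)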



# Copied from transformers.tests.encodec.test_modeling_encodec.compute_rmse
def compute_rmse(arr1, arr2):

Not being used? Also, there's a newer version of this helper.

arr_enc_dec = input_values_enc_dec[0].cpu().numpy()
arr_enc_dec_truncated = arr_enc_dec[:, : arr.shape[1]]
rmse = np.sqrt(((arr - arr_enc_dec_truncated) ** 2).mean())
self.assertTrue(np.abs(rmse - expected_rmse) < 1e-6)

adding a batch test?
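
Something along these lines could work (a sketch; the encode signature and output layout are assumptions based on the snippet above):

import torch

def assert_batching_consistent(model, input_values: torch.Tensor):
    # Encoding a batch should match encoding each sample on its own
    # (equal-length inputs assumed here, so padding does not come into play).
    batched_codes = model.encode(input_values, return_dict=False)
    per_sample = [model.encode(x.unsqueeze(0), return_dict=False) for x in input_values]
    torch.testing.assert_close(batched_codes, torch.cat(per_sample, dim=0))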

return x


class XCodec2DecoderLayer(LlamaDecoderLayer):

Note for @eustlb
This component is meant to replace TransformerBlock from the original implementation

@eustlb left a comment

Thanks for the work, @Deep-unlearning! 🤗
For this first review pass, I focused on reacting to @ebezzam’s comments. I’ll take a broader look once those are addressed.

return weight_norm(nn.Conv1d(*args, **kwargs))


class XCodec2SnakeBeta(nn.Module):

Both are OK IMO; since it's specifically a modified snake here, I would actually keep XCodec2SnakeBeta.
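
For reference, the snake-beta activation the name refers to, roughly as in BigVGAN-style implementations (a sketch; the exact parametrization in this PR may differ):

import torch
from torch import nn

class SnakeBetaSketch(nn.Module):
    def __init__(self, channels: int, alpha_logscale: bool = True):
        super().__init__()
        init = torch.zeros(channels) if alpha_logscale else torch.ones(channels)
        self.alpha = nn.Parameter(init.clone())  # frequency of the periodic term
        self.beta = nn.Parameter(init.clone())   # magnitude of the periodic term
        self.alpha_logscale = alpha_logscale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, time)
        alpha, beta = self.alpha, self.beta
        if self.alpha_logscale:
            alpha, beta = torch.exp(alpha), torch.exp(beta)
        alpha = alpha[None, :, None]
        beta = beta[None, :, None]
        # snake-beta: x + 1/beta * sin^2(alpha * x)
        return x + (1.0 / (beta + 1e-9)) * torch.sin(alpha * x) ** 2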

Comment on lines 556 to 573
class EncoderBlock(nn.Module):
    def __init__(self, dim: int = 16, stride: int = 1, dilations=(1, 3, 9)):
        super().__init__()
        runits = [ResidualUnit(dim // 2, dilation=d) for d in dilations]
        self.block = nn.Sequential(
            *runits,
            Activation1d(activation=XCodec2SnakeBeta(dim // 2, alpha_logscale=True)),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=stride // 2 + stride % 2,
            ),
        )

    def forward(self, x):
        return self.block(x)

If simple and clear, yes!

nn.init.constant_(m.bias, 0)


class XCodec2CodecEncoder_Transformer(nn.Module):

Use naming aligned with Transformers conventions (I'll let you check other modeling files); we don't use underscores in class names.

Comment on lines 1518 to 1532
semantic_model_config = AutoConfig.from_pretrained("facebook/w2v-bert-2.0", output_hidden_states=True)
self.semantic_model = AutoModel.from_config(semantic_model_config)
self.semantic_model.eval()

self.SemanticEncoder_module = XCodec2SemanticEncoder(
config.semantic_hidden_size, config.semantic_hidden_size, config.semantic_hidden_size
)

self.CodecEnc = XCodec2CodecEncoder_Transformer()

self.generator = XCodec2CodecDecoderVocos(config=config)

self.fc_prior = nn.Linear(2048, 2048)
self.fc_post_a = nn.Linear(2048, 1024)
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")

@eustlb should loading other checkpoints with from_pretrained be wrapped in something like a processor? Or is it fine to use inside modeling code?
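
One possible direction (a sketch under assumptions, not necessarily how this PR should resolve it): carry the w2v-bert config inside XCodec2Config and build the semantic tower with from_config, so the modeling code never downloads anything; the w2v-bert feature extraction would then live user-side or in a processor.

from transformers import AutoModel, Wav2Vec2BertConfig

# Hypothetical: XCodec2Config would hold a Wav2Vec2BertConfig sub-config
# (populated by the conversion script), so __init__ only needs
#     self.semantic_model = AutoModel.from_config(config.semantic_model_config)
# with the weights coming from the converted XCodec2 checkpoint itself.
semantic_model_config = Wav2Vec2BertConfig(output_hidden_states=True)
semantic_model = AutoModel.from_config(semantic_model_config)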

@ebezzam requested a review from eustlb September 4, 2025 10:57
@eustlb left a comment

Let's iterate on the attention implementation. A lot of lines seem to come from handling non-causal attention, which should simply be handled by setting self.is_causal = False when inheriting, and passing the attention mask when required.
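
Concretely, the suggestion amounts to something like this (a sketch, assuming the modular inheritance pattern used elsewhere in the PR):

from transformers.models.llama.modeling_llama import LlamaAttention

class XCodec2Attention(LlamaAttention):
    def __init__(self, config, layer_idx: int):
        super().__init__(config, layer_idx)
        # Bidirectional attention for the codec transformer; no other override needed.
        self.is_causal = False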

value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
is_causal=self.is_causal,

From what I understood, this is_causal is the only diff from LlamaAttention. I don't get why it's necessary, though?
Normally, setting self.is_causal = False should be enough (see here)? Or of course I may be missing a specificity here.


I think you're right. I hadn't seen the getattr(module, "is_causal", True) you pointed to, so I had kept Steven's version, since it wasn't clear how/if self.is_causal is used in LlamaAttention given that it isn't passed to attention_interface.

In my opinion, this could be made clearer by having a similar comment in LlamaAttention here.

return x


class Xcodec2DecoderLayer(LlamaDecoderLayer):

I don't get what we're doing here. Let's say we're doing non-causal attention. To get a specific non-causal attention pattern, simply passing the correct attention mask to LlamaAttention is enough. When no attention_mask is provided, simply having self.is_causal = False in LlamaAttention is enough to get non-causal attention; otherwise it means there's a (BIG) bug.


Looking into it. The original author's implementation seems closer to LlamaDecoderLayer, but that leads to some functional issues when passing attention_mask as-is.
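
If the mask format is the issue, one standard way to feed a padding-only (non-causal) mask is to expand it into the additive 4D form the attention layer consumes, e.g. (an illustration, not code from this PR):

import torch

def expand_padding_mask(padding_mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # padding_mask: (batch, seq_len), 1 for real frames, 0 for padding.
    mask = padding_mask[:, None, None, :].to(dtype)  # (batch, 1, 1, seq_len)
    # 0.0 where attention is allowed, a large negative value where it is masked.
    return (1.0 - mask) * torch.finfo(dtype).min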

Comment on lines 307 to 308
self.alpha.requires_grad = alpha_trainable
self.beta.requires_grad = alpha_trainable
@ebezzam commented Sep 5, 2025

@eustlb FYI, the original model has this set to True and it's never set to False at any point (usage). I will remove it, as I don't see why we would want these parameters to be trainable during inference.

github-actions bot commented Sep 5, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, xcodec, xcodec2

@@ -539,7 +539,7 @@ def forward(

>>> inputs = feature_extractor(raw_audio=audio_sample, return_tensors="pt")

>>> outputs = model(**inputs)
>>> outputs = model(inputs["input_values"])

@eustlb DAC, Xcodec, and Xcodec2 don't support model(**inputs), as padding_mask is not an accepted input. Is that fine? Or should padding_mask be added as an input even if it isn't used?
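
One low-cost option (sketched here as an assumption; whether to do it is exactly the open question above) is to accept padding_mask in forward for API parity with EnCodec, even if it is currently unused:

from typing import Optional

import torch
from torch import nn

class ForwardSignatureSketch(nn.Module):
    # Hypothetical signature only: accepting padding_mask (even if unused) lets
    # users call model(**feature_extractor(...)) as they would with EnCodec.
    def forward(self, input_values: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
        return input_values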
