What if We Treated Speech Like Language?

Traditional text-to-speech systems use specialized neural architectures designed specifically for audio generation. But what if we approached speech synthesis the same way we approach language modeling?

That's the core idea behind Orpheus-3B—a 3 billion parameter language model adapted for speech synthesis. Instead of treating TTS as a unique problem requiring bespoke architectures, Orpheus treats it as sequence-to-sequence translation: text tokens in, audio tokens out.

This tutorial follows Unsloth AI's Orpheus-3B approach, adapted for Welsh using the BU-TTS Welsh-English dataset. You'll see how their optimized training framework makes LLM-based TTS accessible for low-resource languages like Welsh.

Why Orpheus-3B?

The LLM-Based TTS Revolution

For years, TTS systems relied on architectures like Tacotron, VITS, and WaveNet—neural networks specifically designed to convert text to speech. These work well, but they require building and training complex, specialized pipelines.

Language models like GPT, Claude, and Llama have shown that transformer architectures can handle virtually any sequence-to-sequence task. Orpheus-3B applies this insight to speech synthesis, and in this tutorial I fine-tune it with Unsloth AI's optimized training framework.

The key innovation? Treating audio as discrete tokens, just like words.

Enter SNAC: Audio as Tokens

SNAC (Multi-Scale Neural Audio Codec) is a neural audio codec that compresses audio waveforms into discrete token sequences. Think of it as tokenization for sound—the same way "Hello, world!" becomes [15496, 11, 1917, 0] in text, audio waveforms become sequences of integer codes.

This lets us feed audio into language models the same way we feed text. The model learns to predict audio token sequences given text token sequences—standard next-token prediction, just with a twist.
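
As a quick illustration of the idea, here is a minimal round-trip sketch using the snac library and the 24kHz checkpoint that appears later in this tutorial (illustrative only; shapes and device placement may need adjusting for your setup):

import torch
from snac import SNAC

# Load the 24kHz SNAC codec
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()

# One second of dummy audio: shape (batch, channels, samples)
waveform = torch.randn(1, 1, 24000)

with torch.inference_mode():
    codes = snac_model.encode(waveform)        # list of integer code tensors
    reconstruction = snac_model.decode(codes)  # back to a waveform

print([c.shape for c in codes])  # hierarchical code streams at increasing rates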

Benefits for Welsh:

  • Better prosody and naturalness (LLMs excel at long-range dependencies)
  • Handles code-switching naturally (Welsh speakers often mix Welsh and English)
  • Transfer learning from massive pre-trained models
  • Parameter-efficient fine-tuning with LoRA

The Dataset

I used the BU-TTS dataset I detailed in a previous post—approximately 10 hours of bilingual Welsh-English recordings from 4 speakers (2 male, 2 female, representing north and south Welsh accents).

The dataset is publicly available on Hugging Face, making this tutorial fully reproducible.

Why is this dataset size sufficient when LLMs typically need massive amounts of data? Because we're fine-tuning, not training from scratch. Orpheus-3B starts with English TTS capabilities learned from much larger datasets. We're teaching it Welsh, leveraging what it already knows about speech synthesis.

Technical Approach

LoRA: Parameter-Efficient Fine-Tuning

Training a 3B parameter model from scratch would require enormous computational resources. Unsloth's approach uses LoRA (Low-Rank Adaptation)—a technique that freezes most of the model and only trains small "adapter" layers.

LoRA works by inserting trainable rank-decomposition matrices alongside the model's attention and feed-forward projection layers. Instead of updating all 3 billion parameters, we train only a much smaller set of adapter parameters (a toy sketch of the idea follows the list below). This makes fine-tuning:

  • Faster: Less computation per training step
  • Cheaper: Fits on consumer GPUs
  • More stable: Less risk of catastrophic forgetting
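
To make the rank-decomposition idea concrete, here is a toy sketch of what a LoRA-adapted linear layer computes (illustrative only, not Unsloth's or PEFT's actual implementation; the dimensions loosely mirror the configuration used later):

import torch

d, r = 4096, 128               # hidden size, LoRA rank
alpha = 128                    # scaling factor (lora_alpha)

W = torch.randn(d, d)          # frozen pre-trained weight
A = torch.randn(r, d) * 0.01   # trainable low-rank factor (r x d)
B = torch.zeros(d, r)          # trainable low-rank factor (d x r), starts at zero

def lora_forward(x):
    # Effective weight is W + (alpha / r) * B @ A, never materialised explicitly
    return x @ W.T + (alpha / r) * (x @ A.T) @ B.T

x = torch.randn(1, d)
print(lora_forward(x).shape)   # same output shape as the base layer; only A and B train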

SNAC Tokenization

Before training, we need to convert our audio files into tokens. SNAC operates at 24kHz and uses a hierarchical encoding scheme:

  • 7 tokens per coarse frame: one coarse token, two medium tokens, and four fine tokens (illustrated in the sketch after this list)
  • Interleaved encoding: Balances quality and compression
  • Discrete codes: Audio becomes integer sequences the LLM can process
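
Concretely, if an utterance has N coarse frames, the three codebooks hold N, 2N, and 4N codes respectively, and interleaving them yields 7 tokens per coarse frame. A rough sketch of the layout (this mirrors the loop shown in Step 5 below):

def interleave_frame(coarse, medium, fine, i):
    """Collect the 7 codes for coarse frame i (coarse : medium : fine = 1 : 2 : 4)."""
    return [
        coarse[i],
        medium[2 * i], fine[4 * i], fine[4 * i + 1],
        medium[2 * i + 1], fine[4 * i + 2], fine[4 * i + 3],
    ]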

Code Walkthrough

Here's how Unsloth's implementation works, with adaptations for Welsh. The code follows their notebook structure—you can adapt this approach for any language with a speech corpus.

Step 1: Load the Model

First, we load the pre-trained Orpheus-3B model using Unsloth's FastLanguageModel wrapper:

from unsloth import FastLanguageModel
import os

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/orpheus-3b-0.1-ft",
    max_seq_length=4096,
    dtype=None,  # Auto-detect best dtype for your GPU
    load_in_4bit=True,  # Use 4-bit quantization to save memory
    token=os.environ["HF_TOKEN"],
)

The load_in_4bit parameter is crucial if you're training on consumer hardware. It reduces memory usage significantly with minimal quality loss.
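
As rough back-of-the-envelope arithmetic (weights only; activations, the KV cache, LoRA adapters, and quantization overhead all add to this), 4-bit storage cuts a 3B-parameter model's weights from about 6 GB to about 1.5 GB:

params = 3e9
print(f"16-bit weights: ~{params * 2 / 1e9:.1f} GB")    # 2 bytes per parameter
print(f" 4-bit weights: ~{params * 0.5 / 1e9:.1f} GB")  # 0.5 bytes per parameter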

Step 2: Configure LoRA

Next, we wrap the model with LoRA adapters:

model = FastLanguageModel.get_peft_model(
    model,
    r=128,  # LoRA rank (higher = more parameters, better quality)
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=128,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

Key parameters:

  • r=128: The rank of the LoRA matrices. Higher ranks mean more trainable parameters and potentially better quality, but also more memory usage.
  • target_modules: Which parts of the transformer to adapt. We're targeting all the projection layers in the attention mechanism and feed-forward network.
  • lora_alpha: Scaling factor for the LoRA updates.
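
To see how much of the model is actually training after wrapping it with LoRA, you can count parameters directly (PEFT-style wrappers also expose a print_trainable_parameters() helper, where available):

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} of {total:,} ({100 * trainable / total:.2f}%)")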

Step 3: Prepare the Dataset

Load the BU-TTS dataset from Hugging Face and prepare it for training:

from datasets import load_dataset, Audio
from snac import SNAC

# Load the public BU-TTS dataset
dataset = load_dataset("techiaith/bu-tts-cy-en", split="train")

# Rename columns to match Orpheus format
dataset = dataset.rename_column("speaker", "source")
dataset = dataset.rename_column("sentence", "text")

# Resample audio to 24kHz (SNAC's expected rate)
dataset = dataset.cast_column("audio", Audio(sampling_rate=24000))

# Load SNAC model for audio tokenization
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
snac_model = snac_model.to("cuda")

Step 4: Text Preprocessing

Welsh has special characters (ŵ, ŷ, etc.) that need handling:

def cleanup_text(inputs):
    """Normalize Welsh-specific characters"""
    replacements = [
        ("–", ""),  # Remove em dashes
        ("ŵ", "w"),  # Normalize circumflex w
        ("Ŵ", "w"),
        ("ŷ", "y"),  # Normalize circumflex y
        ("'", "'"),  # Normalize apostrophes
    ]
    for src, dst in replacements:
        inputs["text"] = inputs["text"].replace(src, dst)
    return inputs

dataset = dataset.map(cleanup_text)

Why normalize? The base model was trained primarily on English text, which doesn't include Welsh diacritics. Normalizing to ASCII equivalents keeps the text within characters the tokenizer handles well, rather than letting each diacritic fall back to rare byte-level pieces.
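
A quick sanity check is to compare how the tokenizer splits a word with and without the diacritic; the exact output depends on the tokenizer, but diacritic forms typically break into more pieces:

for word in ["tŷ", "ty", "dŵr", "dwr"]:  # "house" and "water", with and without circumflex
    print(word, tokenizer.tokenize(word))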

Step 5: Audio Tokenization

This is where SNAC converts audio waveforms into discrete tokens:

import torch
import torchaudio.transforms as T

def tokenise_audio(waveform, sample_rate, snac_model):
    """Convert audio waveform to SNAC tokens"""
    waveform = torch.from_numpy(waveform).unsqueeze(0)
    waveform = waveform.to(dtype=torch.float32)

    # Resample if needed
    if sample_rate != 24000:
        resample_transform = T.Resample(orig_freq=sample_rate, new_freq=24000)
        waveform = resample_transform(waveform)

    waveform = waveform.unsqueeze(0).to("cuda")

    # Generate SNAC codes (7 tokens per timestep)
    with torch.inference_mode():
        codes = snac_model.encode(waveform)

    # Interleave the hierarchical codes
    all_codes = []
    for i in range(codes[0].shape[1]):
        all_codes.append(codes[0][0][i].item() + 128266)  # Coarse code
        all_codes.append(codes[1][0][2 * i].item() + 128266 + 4096)  # Fine 1
        all_codes.append(codes[2][0][4 * i].item() + 128266 + (2 * 4096))  # Fine 2
        all_codes.append(codes[2][0][(4 * i) + 1].item() + 128266 + (3 * 4096))  # Fine 3
        all_codes.append(codes[1][0][(2 * i) + 1].item() + 128266 + (4 * 4096))  # Fine 4
        all_codes.append(codes[2][0][(4 * i) + 2].item() + 128266 + (5 * 4096))  # Fine 5
        all_codes.append(codes[2][0][(4 * i) + 3].item() + 128266 + (6 * 4096))  # Fine 6

    return all_codes

def add_codes(example, sample_rate, snac_model):
    """Add SNAC codes to dataset example"""
    try:
        audio_array = example["audio"]["array"]
        example["codes_list"] = tokenise_audio(audio_array, sample_rate, snac_model)
    except Exception as e:
        print(f"Skipping row due to error: {e}")
        example["codes_list"] = None
    return example

# Apply tokenization to entire dataset
dataset = dataset.map(
    lambda example: add_codes(example, 24000, snac_model),
    remove_columns=["audio"],
)

# Filter out any failed tokenizations
dataset = dataset.filter(lambda x: x["codes_list"] is not None)
dataset = dataset.filter(lambda x: len(x["codes_list"]) > 0)

The offsets (128266, +4096, etc.) map SNAC codes into the model's vocabulary space, ensuring they don't collide with text tokens.
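
One way to picture the mapping (a sketch inferred from the offsets above, not an official vocabulary specification): each of the seven interleave positions gets its own block of 4096 IDs starting at 128266, so a single token ID encodes both the position within the frame and the SNAC code.

AUDIO_BASE = 128266   # first vocabulary ID reserved for audio codes
CODEBOOK_SIZE = 4096  # number of codes per SNAC codebook

def to_token_id(position_in_frame, snac_code):
    """Map a SNAC code at interleave position 0-6 into the model's vocabulary."""
    return AUDIO_BASE + position_in_frame * CODEBOOK_SIZE + snac_code

def from_token_id(token_id):
    """Recover (position_in_frame, snac_code) from an audio token ID."""
    return divmod(token_id - AUDIO_BASE, CODEBOOK_SIZE)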

Step 6: Format for Training

Create the final input format Orpheus expects:

def create_input_ids(example, tokenizer):
    """Format text and audio codes for training"""
    # Include speaker ID if present
    text_prompt = f"{example['source']}: {example['text']}" if "source" in example else example["text"]

    # Tokenize text
    text_ids = tokenizer.encode(text_prompt, add_special_tokens=True)
    text_ids.append(128009)  # End of text token

    # Build full sequence: <human> text </human> <ai> <speech> audio_codes </speech> </ai>
    input_ids = (
        [128259]  # Start of human
        + text_ids
        + [128260]  # End of human
        + [128261]  # Start of AI
        + [128257]  # Start of speech
        + example["codes_list"]
        + [128258]  # End of speech
        + [128262]  # End of AI
    )

    return {
        "input_ids": input_ids,
        "labels": input_ids,  # For causal LM, labels = inputs
        "attention_mask": [1] * len(input_ids),
    }

dataset = dataset.map(
    lambda example: create_input_ids(example, tokenizer),
    remove_columns=["text", "codes_list", "source"],
)

Step 7: Train

Now we're ready to train:

from transformers import TrainingArguments, Trainer
from unsloth import is_bfloat16_supported

trainer = Trainer(
    model=model,
    train_dataset=dataset,
    args=TrainingArguments(
        per_device_train_batch_size=1,  # Small batch for memory efficiency
        gradient_accumulation_steps=128,  # Effective batch size = 128
        warmup_steps=500,
        num_train_epochs=10,
        learning_rate=2e-4,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        report_to=["tensorboard"],
        optim="adamw_8bit",  # Memory-efficient optimizer
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs",
        push_to_hub=True,
        hub_model_id="your-username/orpheus3b-cy-en",
    ),
)

trainer.train()

Training tips:

  • gradient_accumulation_steps=128: Simulates larger batch sizes without the memory cost
  • adamw_8bit: 8-bit quantized optimizer saves GPU memory
  • Expect ~3-5 days on a single consumer GPU (RTX 3090/4090)

Step 8: Save and Push

After training, save the model and push to Hugging Face:

# Save LoRA adapters
model.save_pretrained("./lora_model")
tokenizer.save_pretrained("./lora_model")

# Push to Hugging Face Hub
model.push_to_hub("your-username/orpheus3b-cy-en", token=os.environ["HF_TOKEN"])
tokenizer.push_to_hub("your-username/orpheus3b-cy-en", token=os.environ["HF_TOKEN"])

# Optionally merge and push full 16-bit model
model.push_to_hub_merged(
    "your-username/orpheus3b-cy-en-16-bit",
    tokenizer,
    save_method="merged_16bit",
    token=os.environ["HF_TOKEN"],
)

Inference: Generating Speech

Now for the fun part—making your model talk! Here's how to generate Welsh speech:

import torch
from unsloth import FastLanguageModel
from snac import SNAC
import soundfile as sf

# Load trained model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="str20tbl/orpheus3b-cy-en-16-bit",
    max_seq_length=4096,
    dtype=None,
    load_in_4bit=False,
)
FastLanguageModel.for_inference(model)  # Enable faster inference mode

# Load SNAC for decoding
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to("cuda")

# Prepare your text
prompt = "Mae yna wall yn y ffor mae nhw'n deud *wall*"

# Encode with special tokens
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
modified_input_ids = torch.cat([
    torch.tensor([[128259]]),  # Start of human
    input_ids,
    torch.tensor([[128009, 128260]]),  # End of text, end of human
], dim=1).to("cuda")

# Generate audio tokens
generated_ids = model.generate(
    input_ids=modified_input_ids,
    max_new_tokens=4096,
    do_sample=True,
    temperature=0.6,  # Controls randomness (lower = more consistent)
    top_p=0.95,  # Nucleus sampling threshold
    repetition_penalty=1.1,  # Discourage repetition
    eos_token_id=128258,  # End of speech token
)

# Extract audio tokens (after start-of-speech marker)
speech_start = (generated_ids[0] == 128257).nonzero(as_tuple=True)[0][-1].item() + 1
audio_tokens = generated_ids[0][speech_start:]  # Everything after <speech>
audio_tokens = audio_tokens[audio_tokens != 128258]  # Drop the end-of-speech token

# Convert back to SNAC codes (undo the offset mapping and reverse the
# training-time interleaving)
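# (Sketch: mirrors the Step 5 interleaving of one coarse, two medium and four
# fine codes per frame, each in its own 4096-ID block; adjust if your layout differs.)
audio_tokens = audio_tokens[: (len(audio_tokens) // 7) * 7]  # keep whole frames only
layer_0, layer_1, layer_2 = [], [], []
for i in range(0, len(audio_tokens), 7):
    frame = [t.item() - 128266 for t in audio_tokens[i:i + 7]]
    layer_0.append(frame[0])
    layer_1.append(frame[1] - 4096)
    layer_2.append(frame[2] - 2 * 4096)
    layer_2.append(frame[3] - 3 * 4096)
    layer_1.append(frame[4] - 4 * 4096)
    layer_2.append(frame[5] - 5 * 4096)
    layer_2.append(frame[6] - 6 * 4096)

codes = [
    torch.tensor(layer_0).unsqueeze(0).to("cuda"),
    torch.tensor(layer_1).unsqueeze(0).to("cuda"),
    torch.tensor(layer_2).unsqueeze(0).to("cuda"),
]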

# Decode with SNAC
with torch.inference_mode():
    audio_waveform = snac_model.decode(codes)

# Save to file
sf.write("output.wav", audio_waveform.squeeze().cpu().numpy(), 24000)

Code-switching example:

prompts = [
    "Dwi'n hoffi Python programming",  # "I like Python programming"
    "Mae Google Translate yn dda iawn",  # "Google Translate is very good"
    "Bydd hi'n braf yn Barcelona heddiw",  # "It'll be nice in Barcelona today"
]

The model handles Welsh-English mixing naturally because we trained on bilingual data with speaker IDs preserved.

Results

Using Unsloth's training approach, I trained both 16-bit and 4-bit versions on the Welsh dataset:

Quality:

  • Noticeably more natural prosody compared to traditional VITS models
  • Handles code-switching smoothly
  • Better intonation on questions and exclamations
  • Occasional artifacts on very short phrases (works best with full sentences)

Inference speed:

  • 16-bit: ~2-3x real-time on RTX 4090
  • 4-bit: ~4-5x real-time on RTX 4090
  • Much slower than VITS, but quality improvements often worth it

Challenges & Learnings

Welsh Character Handling

Welsh uses diacritics (ŵ, ŷ, â, etc.) that aren't in the base model's vocabulary. For this adaptation, I normalized them to ASCII equivalents (ŵ→w) which worked well enough. Future work could expand the tokenizer vocabulary to include these natively.

Memory Management

Even with LoRA and 4-bit quantization, a 3B model pushes consumer GPU limits:

  • Batch size of 1 was necessary
  • Gradient accumulation (steps=128) simulated larger batches
  • 8-bit AdamW optimizer saved crucial memory

Duplicate Frame Removal

SNAC encoding sometimes yields runs of identical frames, which come through as a stutter in the generated speech. The Unsloth notebook includes filtering logic to detect and remove consecutive duplicate frames:

def remove_duplicate_frames(codes_list):
    """Remove stuttering from duplicate SNAC frames"""
    result = codes_list[:7]  # Keep first frame
    for i in range(7, len(codes_list), 7):
        current_first = codes_list[i]
        previous_first = result[-7]
        if current_first != previous_first:
            result.extend(codes_list[i:i+7])
    return result
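
For example, with made-up codes (frames are 7 tokens long; the second frame repeats the first frame's coarse code, so it is dropped):

frames = [10, 1, 2, 3, 4, 5, 6,   # frame 1: kept
          10, 9, 9, 9, 9, 9, 9,   # frame 2: same coarse code (10), dropped
          11, 1, 2, 3, 4, 5, 6]   # frame 3: kept
print(len(remove_duplicate_frames(frames)))  # 14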

4-bit vs 16-bit

The 4-bit model is noticeably faster and uses less memory, but has subtle quality degradation:

  • Slightly less natural prosody
  • More pronunciation artifacts
  • Still quite usable for many applications

For production use, I'd recommend the 16-bit model if resources allow.

Try It Yourself

Want to experiment with Welsh TTS or train on your own language?

To adapt this for your language:

  1. Prepare a dataset (even 5-10 hours can work with transfer learning)
  2. Follow the code walkthrough above
  3. Adjust text preprocessing for your language's character set
  4. Fine-tune for 5-10 epochs
  5. Experiment with temperature and sampling parameters

Acknowledgments

This work is based on Unsloth AI's Orpheus-3B notebook. The LoRA configuration, SNAC tokenization approach, and training methodology all come from their implementation. I've adapted their approach for Welsh using the BU-TTS dataset, with language-specific preprocessing for Welsh characters.

Special thanks to the Unsloth team for making state-of-the-art TTS training accessible and efficient.

Closing Thoughts

LLM-based TTS represents a shift in how we approach speech synthesis for low-resource languages. Instead of building specialized architectures from scratch, we can leverage massive pre-trained models and adapt them with relatively small datasets.

For Welsh—and other lesser-resourced languages—this is huge. It means we don't need tens of thousands of hours of audio to build competitive TTS systems. We can start with what we have (in this case, the ~10 hours from BU-TTS) and still achieve natural-sounding results.

This work builds directly on the BU-TTS foundation. Creating that dataset during a pandemic was challenging, but it's enabled this kind of modern TTS research. Open datasets compound—each new technique can build on them without redoing the collection work.

What's next?

  • Multi-speaker control (choose different voices)
  • Emotion and style transfer
  • Longer context (currently limited to 4096 tokens)
  • Better handling of Welsh diacritics
  • Streaming inference for real-time applications

Welsh language technology is advancing rapidly. Tools like this make it easier for Welsh speakers to interact with technology in their native language—code-switching and all.


Have questions about the implementation? Training your own model? Share your experiences in the comments or reach out—I'd love to hear how you adapt this for other languages.