Curriculum Learning
Humans learn better when concepts are introduced in a specific order: simple addition before multiplication, multiplication before calculus. Curriculum Learning applies this principle to machine learning: instead of sampling batches uniformly at random (i.i.d.), we sample from a distribution that changes over time, presenting easy examples first and then gradually harder ones.
The Theory of Curriculum
A curriculum is defined by two components: a Difficulty Scorer $d(x)$, which ranks training examples from easy to hard, and a Pacing Function $\lambda(t)$, which controls how quickly harder examples are introduced.
1. Difficulty Scoring
How do we measure difficulty? Common heuristics include:
- Sentence Length (NLP): Longer sentences are harder.
- Noise Level: Examples with lower signal-to-noise ratio.
- Teacher Uncertainty: A pre-trained model has high entropy/loss on the example.
- Transfer Scoring: Loss from a model trained on a generic dataset.
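As a concrete illustration of the first heuristic, the sentence-length scorer can be a few lines of plain Python; the corpus and function names here are our own, not part of any library:

```python
def length_difficulty(sentences):
    """Score difficulty as token count: longer sentences rank as harder."""
    return [len(s.split()) for s in sentences]


def sort_by_difficulty(sentences):
    """Return the corpus ordered easy-first, as a curriculum expects."""
    scores = length_difficulty(sentences)
    return [s for _, s in sorted(zip(scores, sentences), key=lambda p: p[0])]


corpus = [
    "The quick brown fox jumps over the lazy dog",
    "Hello world",
    "Gradient descent minimizes a loss function over parameters",
]
curriculum = sort_by_difficulty(corpus)  # easy-first ordering
```

The same sort-by-score pattern applies to the other heuristics; only the scoring function changes.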
2. Pacing Functions
The pacing function $\lambda(t)$ determines the fraction of the dataset available at training step $t$.
Linear Pacing: $\lambda(t) = \min\left(1,\ \lambda_0 + (1 - \lambda_0)\,\frac{t}{T}\right)$
Root Pacing (More time on hard examples): $\lambda(t) = \min\left(1,\ \sqrt{\lambda_0^2 + (1 - \lambda_0^2)\,\frac{t}{T}}\right)$
Geometric Pacing (More time on easy examples): $\lambda(t) = \min\left(1,\ \lambda_0^{\,1 - t/T}\right)$
Where $\lambda_0$ is the initial data fraction and $T$ is the number of steps to reach the full dataset.
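The three schedules can be sketched as a single pure function; $\lambda_0 = 0.1$ and the `kind` argument names are our own choices:

```python
import math


def pacing(step, total_steps, lam0=0.1, kind='linear'):
    """Fraction of the (difficulty-sorted) dataset unlocked at `step`."""
    t = min(step / total_steps, 1.0)
    if kind == 'linear':
        lam = lam0 + (1.0 - lam0) * t
    elif kind == 'root':
        lam = math.sqrt(lam0**2 + (1.0 - lam0**2) * t)
    elif kind == 'geometric':
        lam = lam0 ** (1.0 - t)  # grows exponentially from lam0 to 1
    else:
        lam = 1.0
    return min(1.0, lam)
```

Midway through training, root pacing has unlocked more data than linear (hence more steps spent on the full, hard dataset), while geometric has unlocked less (more steps on the easy prefix).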
Implementation: Dynamic Data Sampling
We can implement curriculum learning by dynamically filtering the dataset during training.
```python
import jax
import jax.numpy as jnp


class CurriculumScheduler:
    """Manages the curriculum pacing and data sampling."""

    def __init__(self, initial_fraction=0.1, growth_steps=10_000, function='linear'):
        self.initial_fraction = initial_fraction
        self.growth_steps = growth_steps
        self.function = function

    def get_pacing_rate(self, step):
        """Compute the available data fraction (lambda) at `step`."""
        progress = step / self.growth_steps
        lam0 = self.initial_fraction
        if self.function == 'linear':
            return min(1.0, lam0 + (1.0 - lam0) * progress)
        elif self.function == 'root':
            return min(1.0, (lam0**2 + (1.0 - lam0**2) * progress) ** 0.5)
        return 1.0

    def sample_batch(self, dataset, step, batch_size, rng_key):
        """
        Sample a batch from the 'available' slice of the dataset.
        Assumes `dataset` is an array pre-sorted by difficulty (easy first)!
        """
        pacing_rate = self.get_pacing_rate(step)
        # Determine how many examples are 'unlocked'.
        num_examples = len(dataset)
        max_index = int(pacing_rate * num_examples)
        max_index = max(max_index, batch_size)  # Ensure at least one full batch.
        # Sample indices uniformly (without replacement) from the unlocked prefix.
        indices = jax.random.choice(
            rng_key,
            max_index,
            shape=(batch_size,),
            replace=False,
        )
        return dataset[indices]
```
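The core mechanic, slice by pacing rate, then sample from the unlocked prefix, can be exercised standalone on a toy difficulty-sorted array (values here are illustrative only):

```python
import jax
import jax.numpy as jnp

# Toy dataset, pre-sorted easy-first: the value doubles as the difficulty rank.
dataset = jnp.arange(1000)


def sample_curriculum_batch(dataset, pacing_rate, batch_size, rng_key):
    """Sample uniformly (without replacement) from the unlocked easy-first prefix."""
    max_index = max(int(pacing_rate * len(dataset)), batch_size)
    indices = jax.random.choice(
        rng_key, max_index, shape=(batch_size,), replace=False
    )
    return dataset[indices]


key = jax.random.PRNGKey(0)
early = sample_curriculum_batch(dataset, 0.1, 32, key)  # only easiest 10% visible
late = sample_curriculum_batch(dataset, 1.0, 32, key)   # full dataset visible
```

Early in training, every sampled example comes from the easiest 10% of the data; late in training, the whole range is reachable.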
Advanced: Self-Paced Learning (SPL)
Pre-defining a curriculum is rigid. Self-Paced Learning learns the curriculum jointly with the model parameters $w$.
Objective function:

$$\min_{w,\ v \in \{0,1\}^n} \ \sum_{i=1}^{n} v_i\, \ell_i(w) \;-\; \lambda \sum_{i=1}^{n} v_i$$

Where:
- $v_i \in \{0, 1\}$ indicates if example $i$ is selected.
- $\lambda$ is a regularization term (the "age" of the curriculum).
Optimization alternates between two steps:
- Fix $v$, minimize over $w$: standard SGD training on the selected examples.
- Fix $w$, minimize over $v$: closed-form solution $v_i^* = \mathbb{1}[\ell_i(w) < \lambda]$.
This means the model trains on all examples with loss smaller than $\lambda$. We gradually increase $\lambda$ to include harder (higher-loss) examples.
```python
import jax
import jax.numpy as jnp
import optax  # needed for the per-example cross-entropy loss


def self_paced_mask(losses, lambda_threshold):
    """
    Generate the selection mask v for Self-Paced Learning.
    Selects examples where loss < lambda.
    """
    return losses < lambda_threshold


@jax.jit
def spl_train_step(state, batch, lambda_threshold):
    """Training step with Self-Paced weighting (`state` is a Flax TrainState)."""

    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['image'])
        # Compute per-example losses (no reduction).
        losses = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch['label']
        )
        # Compute the selection vector v. Note: we detach it from the gradient!
        v_mask = jax.lax.stop_gradient(self_paced_mask(losses, lambda_threshold))
        # Weighted loss over selected examples; epsilon avoids division by zero.
        mean_loss = jnp.sum(losses * v_mask) / (jnp.sum(v_mask) + 1e-6)
        return mean_loss, jnp.mean(v_mask)  # Also return the fraction of data used.

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, fraction_used), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, fraction_used
```
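The effect of the closed-form selection rule can be checked on toy per-example losses (the loss values and schedule below are made up for illustration):

```python
import jax.numpy as jnp

# Per-example losses, already ordered easy to hard.
losses = jnp.array([0.2, 0.5, 1.1, 2.3, 4.0])


def num_selected(losses, lam):
    """Number of examples admitted by the SPL rule v_i = 1[l_i < lambda]."""
    return int(jnp.sum(losses < lam))


# Growing lambda over training admits progressively harder examples.
schedule = [0.3, 1.0, 5.0]
admitted = [num_selected(losses, lam) for lam in schedule]
```

Each increase of $\lambda$ widens the admitted set, until the full dataset participates in training.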
Mentorship (Teacher-Student)
Another variation involves a "Teacher" model helping the main model:
- Train a large Teacher model on the dataset.
- Teacher scores difficulty of all examples.
- Student trains using curriculum derived from Teacher's scores.
This is robust because the Teacher's "difficulty" acts as a proxy for the Student's expected error.
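A sketch of the scoring step, assuming the Teacher exposes a per-example loss; `teacher_loss` and the toy examples below are stand-ins, not a real model:

```python
import jax.numpy as jnp


def rank_by_teacher(examples, teacher_loss):
    """Order examples easy-first by the Teacher's per-example loss."""
    scores = jnp.array([teacher_loss(x) for x in examples])
    order = jnp.argsort(scores)  # low teacher loss = easy
    return [examples[int(i)] for i in order]


# Toy stand-in Teacher: difficulty grows with input magnitude.
examples = [3.0, -1.0, 0.5, -7.0]
ordered = rank_by_teacher(examples, teacher_loss=lambda x: abs(x))
```

The resulting easy-first ordering is exactly what the pacing-based sampler above consumes, so the Teacher's scores slot directly into the curriculum pipeline.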
References
- Curriculum Learning (Bengio et al., 2009)
- Self-Paced Learning for Long-Term Tracking (Supancic et al., 2013)
- Automated Curriculum Learning for Neural Networks (Graves et al., 2017)