Best Practices

Training Workflow

1. Start Small, Scale Up

from tpuz import TPU

# Develop on v4-8 (single host, cheap)
dev = TPU("dev", accelerator="v4-8", preemptible=True)
dev.up()
dev.setup(extra_pip="your-package")
dev.repl()  # Interactive development

# Test a short run
dev.run("python train.py --steps=10", sync="./src")
dev.logs()

# When ready, scale up
dev.scale("v4-32")  # Now 4 workers
dev.run("python train.py --steps=50000", sync="./src")

2. Always Use GCS for Checkpoints

Without GCS, a preemption means lost training progress. With GCS:

from tpuz import TPU, GCS

gcs = GCS("gs://my-bucket")
tpu = TPU("train", accelerator="v4-8")
tpu.up()
tpu.setup()

# Auto-resumes from latest checkpoint
tpu.run_with_resume("python train.py", gcs=gcs, run_name="run-01", sync="./src")
tpu.watch("python train.py")  # Or: watch and auto-recover on preemption
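Under the hood, resuming comes down to finding the newest checkpoint in the bucket. A minimal sketch of that selection logic (the `step_<N>` path pattern is an assumption for illustration, not necessarily what tpuz writes):

```python
import re

def latest_checkpoint(paths):
    """Pick the checkpoint with the highest step number from a GCS
    listing. Assumes each path contains a 'step_<N>' component."""
    best, best_step = None, -1
    for path in paths:
        match = re.search(r"step_(\d+)", path)
        if match and int(match.group(1)) > best_step:
            best, best_step = path, int(match.group(1))
    return best

latest_checkpoint([
    "gs://my-bucket/run-01/step_100/",
    "gs://my-bucket/run-01/step_1200/",
    "gs://my-bucket/run-01/step_900/",
])
# -> "gs://my-bucket/run-01/step_1200/"
```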

3. Use Cloud Secrets, Not env={}

import os

# GOOD: secrets never leave GCP
tpu.run("python train.py", secrets=["WANDB_API_KEY", "HF_TOKEN"])

# OK for quick tests only (the key passes through your local environment)
tpu.run("python train.py", env={"WANDB_API_KEY": os.environ["WANDB_API_KEY"]})

See secrets.md for the full setup guide.

4. Verify Before Training

tpu.setup(extra_pip="flaxchat")
tpu.verify()  # Confirms JAX works on all workers
# worker 0: 4 devices
# worker 1: 4 devices
# ...
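The per-worker count above follows from how v4 slices are named: the trailing number counts TensorCores, and each chip (two cores) shows up as one JAX device, so each v4 host reports 4 devices. A quick sanity check of the expected total, assuming v4-style naming:

```python
def expected_devices(accelerator):
    """Total JAX devices for a v4 slice: the trailing number counts
    TensorCores, and each chip (2 cores) is one device."""
    cores = int(accelerator.split("-")[1])
    return cores // 2

expected_devices("v4-8")   # -> 4  (single host)
expected_devices("v4-32")  # -> 16 (4 workers x 4 devices)
```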

5. Use Notifications for Long Runs

tpu.watch_notify("python train.py",
    notify_url="https://hooks.slack.com/services/...",
    max_retries=5)
# You'll get Slack messages on: preemption, recovery, completion, failure
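A Slack incoming webhook is just an HTTP POST with a JSON body, so a notifier only has to build a payload per event. A minimal sketch (the event wording is illustrative, not tpuz's actual messages):

```python
import json

def slack_payload(tpu_name, event):
    """JSON body for POSTing to a Slack incoming-webhook URL."""
    return json.dumps({"text": f"[{tpu_name}] {event}"})

slack_payload("train", "preempted, retry 1/5")
# -> '{"text": "[train] preempted, retry 1/5"}'
```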

6. Budget with Scheduled Training

tpu.schedule("python train.py",
    start_after="22:00",    # Spot prices are lower at night
    max_cost=50.0,          # Auto-kill at $50
    sync="./src")
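The max_cost cap translates directly into a wall-clock budget at the spot rate. At the ~$2.06/hr v4-8 spot price quoted below, a $50 cap buys roughly a day of training:

```python
def hours_within_budget(max_cost, hourly_rate):
    """Wall-clock hours a cost cap buys at a given hourly rate."""
    return max_cost / hourly_rate

round(hours_within_budget(50.0, 2.06), 1)  # -> 24.3 hours of v4-8 spot
```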

7. Collect Artifacts Before Teardown

tpu.collect(["model.pkl", "results.json", "report.md"], local_dir="./outputs")
tpu.cost_summary()  # Know what it cost
tpu.down()

Multi-Host Best Practices

Same Code, Different Scale

JAX SPMD means your code is identical on all workers. The only difference is the mesh size:

# This code works on v4-8 (1 worker) AND v4-32 (4 workers)
mesh = jax.sharding.Mesh(jax.devices(), ('data',))
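The same idea, identical code with the per-worker slice chosen by process index, can be illustrated without a TPU (plain Python standing in for JAX's data sharding):

```python
def local_shard(global_batch, process_index, process_count):
    """One worker's contiguous slice of the global batch, mirroring
    how a data-parallel mesh shards the leading axis across hosts."""
    per_host = len(global_batch) // process_count
    start = process_index * per_host
    return global_batch[start:start + per_host]

# v4-32 pod: 4 workers, global batch of 8 -> 2 examples per worker
local_shard(list(range(8)), 1, 4)  # -> [2, 3]
```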

Debug on Single Host First

Multi-host debugging is painful. Always verify on single host:

tpuz repl my-tpu              # Interactive Python
tpuz debug my-tpu "python train.py"  # VS Code debugger

Monitor All Workers

tpuz health my-tpu   # Dashboard view
tpuz logs-all my-tpu  # Color-coded per-worker logs

If a Worker Dies

tpuz health my-tpu
#   worker 0   running     step 1234
#   worker 1   stopped     (no log)    <-- problem!
#   worker 2   running     step 1234

# Option 1: Restart training (relies on checkpoint)
tpuz kill my-tpu
tpuz run my-tpu "python train.py --resume"

# Option 2: Recreate the whole pod
tpuz scale my-tpu v4-32  # Deletes and recreates

Cost Optimization

Use Preemptible/Spot VMs

Leave preemptible=True (the default). Spot prices are roughly 3x lower than on-demand:

Accelerator    On-demand     Spot
v4-8           ~$6.18/hr     ~$2.06/hr
v5litepod-8    ~$28.80/hr    ~$9.60/hr

Use Queued Resources for Reliability

tpu.up_queued()  # Waits for capacity instead of failing

Check Before Creating

TPU.availability("v4-8", zone="us-central2-b")
# {"available": True, "spot_rate": 2.06}

Multi-Zone Failover

tpu = TPU.create_multi_zone("my-tpu", "v4-8",
    zones=["us-central2-b", "us-central1-a", "europe-west4-a"])

Track Costs

tpu.cost_summary()
# $12.36 (6.0h × $2.06/hr v4-8 spot)

File Organization

Recommended project structure for TPU training:

my-project/
├── src/
│   ├── model.py
│   ├── train.py        # Entry point
│   └── data.py
├── configs/
│   └── v4-8.yaml
├── outputs/            # tpu.collect() downloads here
│   ├── model.pkl
│   └── results.json
└── launch.py           # tpuz orchestration script

# launch.py
import os

from tpuz import TPU, GCS

tpu = TPU("train", "v4-8")
gcs = GCS("gs://my-bucket")

tpu.run_once(
    "python src/train.py --config configs/v4-8.yaml",
    sync=".",
    secrets=["WANDB_API_KEY"],
    collect_files=["outputs/model.pkl", "outputs/results.json"],
    gcs=gcs,
    notify_url=os.environ.get("SLACK_WEBHOOK"),
)