tpuz
Manage GCP TPU & GPU VMs from your terminal. Create, train, debug, recover, teardown — one command.
pip install tpuz
from tpuz import TPU
tpu = TPU("my-tpu", accelerator="v4-8")
tpu.up()
tpu.setup()
tpu.run("python train.py", secrets=["WANDB_API_KEY"], sync="./src")
tpu.logs()
tpu.cost_summary()
tpu.down()
Or one command:
tpuz train my-tpu "python train.py" -a v4-8 --sync=. --recover --teardown
Getting Started
Prerequisites
- gcloud CLI: brew install google-cloud-sdk (macOS) or cloud.google.com/sdk
- Authenticate: gcloud auth login
- Set project: gcloud config set project YOUR_PROJECT
Step 1: Check Setup
from tpuz import TPU
tpu = TPU("test", accelerator="v4-8")
tpu.preflight()
# gcloud account: you@gmail.com
# gcloud project: my-project
Step 2: Create VM
tpu.up() # Standard creation
tpu.up_queued() # Queued Resources (waits for capacity — better for spot)
Step 3: Install Dependencies
tpu.setup() # JAX[TPU] + common deps
tpu.setup(extra_pip="flaxchat") # + custom packages
tpu.setup(python_version="3.11") # Install Python 3.11 if missing
tpu.verify() # Confirm JAX works on all workers
setup() uses uv (10-50x faster than pip) and automatically waits for SSH readiness on fresh VMs — no need to time.sleep() after up().
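Under the hood, waiting for SSH readiness is just a poll-with-deadline loop. A minimal sketch, where the `probe` callable stands in for tpuz's internal SSH connection attempt (this is illustrative, not the actual implementation):

```python
import time

def wait_until_ready(probe, timeout=180, interval=5):
    """Poll probe() until it returns True or `timeout` seconds elapse.

    Illustrative only: `probe` is a stand-in for an SSH connection
    attempt against the freshly created VM.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if probe():
            return True
        time.sleep(interval)
    raise TimeoutError(f"not ready after {timeout}s")
```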
Step 4: Upload Code & Run
tpu.run("python train.py",
sync="./src", # Upload local directory
secrets=["WANDB_API_KEY", "HF_TOKEN"], # From Cloud Secret Manager
)
Step 5: Monitor
tpu.logs() # Stream logs (Ctrl-C to detach)
tpu.logs_all(lines=20) # All workers color-coded
tpu.health_check() # Full dashboard
tpu.training_progress() # Parsed step/loss/lr from logs
tpu.cost_summary() # Running cost
Step 6: Collect & Cleanup
tpu.collect(["model.pkl", "results.json"], local_dir="./outputs")
tpu.cost_summary() # $4.12 (2.0h × $2.06/hr v4-8 spot)
tpu.down()
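The cost string is simple arithmetic over elapsed hours and the hourly rate. A hedged sketch of how such a summary could be formatted (tpuz derives the hours from the VM's creation time; this function only reproduces the example string above):

```python
def cost_summary(hours, hourly_rate, accelerator, spot=True):
    # Illustrative: format "$X.XX (Nh x $R/hr ACCEL tier)" as shown above.
    total = hours * hourly_rate
    tier = "spot" if spot else "on-demand"
    return f"${total:.2f} ({hours:.1f}h \u00d7 ${hourly_rate:.2f}/hr {accelerator} {tier})"
```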
Complete API Reference
Lifecycle
tpu.preflight() # Verify gcloud config
tpu.up() # Create VM (idempotent)
tpu.up_queued(timeout_hours=2) # Queued Resources API
tpu.down() # Delete VM
tpu.down_queued() # Delete QR + VM
tpu.info() # TPUInfo(state, accelerator, ips)
tpu.describe() # Alias for info()
tpu.wait_for_ssh() # Wait until SSH is ready
tpu.setup(extra_pip="pkg") # Install JAX[TPU] + deps
tpu.verify() # Check JAX on all workers
SSH & Files
# SSH — returns stdout string by default
output = tpu.ssh("echo hello")
# Custom timeout (default: ssh_timeout from __init__, default 120s)
output = tpu.ssh("long-command", timeout=600)
# Disable timeout entirely
output = tpu.ssh("very-long-command", timeout=None)
# Set default SSH timeout at init
tpu = TPU("my-tpu", accelerator="v4-8", ssh_timeout=300)
# Structured result with stderr and returncode
result = tpu.ssh("cat /tmp/log", structured=True)
result.stdout # "log content..."
result.stderr # ""
result.returncode # 0
result.ok # True
# Wait for SSH readiness (called automatically by setup())
tpu.wait_for_ssh(timeout=180)
# All workers in parallel (with per-worker retries)
outputs = tpu.ssh_all("hostname", retries=3)
# File transfer
tpu.scp_to("./src", "/home/user/src")
tpu.scp_from("/home/user/model.pkl", "./model.pkl")
# Aliases
tpu.push("./config.json", "/home/user/config.json")
tpu.pull("/home/user/metrics.db", "./metrics.db")
Training
tpu.run("python train.py",
sync="./src", # Upload code
secrets=["WANDB_API_KEY"], # Cloud Secret Manager
env={"BATCH_SIZE": "32"}, # Extra env vars (.env file)
)
tpu.logs() # Stream (Ctrl-C to detach)
tpu.logs_all() # All workers color-coded
tpu.is_running() # Check if alive
tpu.kill() # Stop training
tpu.wait() # Poll for COMPLETE/FAILED sentinel
tpu.collect(["model.pkl"]) # Download artifacts
Git Repo Setup
tpu.clone_repo(
url="https://github.com/org/repo.git",
branch="main",
install=True, # pip install -e . after clone
github_token="ghp_...", # For private repos
)
# Clones/pulls on ALL workers in parallel
GCS Checkpoints
from tpuz import GCS
gcs = GCS("gs://my-bucket")
gcs.upload_checkpoint("./ckpt", "run-01", step=1000)
gcs.download_checkpoint("run-01", step=1000)
gcs.latest_step("run-01") # 5000
gcs.checkpoint_path("run-01") # "gs://my-bucket/checkpoints/run-01"
gcs.list_runs() # ["run-01", "run-02"]
# Auto-resume from latest checkpoint
tpu.run_with_resume("python train.py", gcs=gcs, run_name="run-01")
# Finds step 5000 → appends --resume-from-step=5000
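The resume logic boils down to finding the highest checkpoint step and appending a flag. An illustrative sketch (both helpers below are hypothetical and only mirror the behaviour described above; the real `GCS.latest_step()` lists the bucket):

```python
import re

def latest_step(checkpoint_names):
    """Pick the highest step number from names like 'step_5000'.

    Hypothetical helper: scans an in-memory list instead of GCS.
    """
    steps = [int(m.group(1)) for name in checkpoint_names
             if (m := re.search(r"step[_-]?(\d+)", name))]
    return max(steps, default=None)

def with_resume(command, step):
    # Mirrors the behaviour above: append --resume-from-step=N if found.
    return command if step is None else f"{command} --resume-from-step={step}"
```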
Secrets (Cloud Secret Manager)
Always prefer secrets= over env=. Secrets never leave GCP.
from tpuz import SecretManager
# One-time setup
sm = SecretManager()
sm.create("WANDB_API_KEY", "your-key") # Idempotent: creates or updates
sm.ensure("HF_TOKEN", "hf_...") # Create-or-skip (no update)
sm.grant_tpu_access_all() # Grant VM access via IAM
# Use in training
tpu.run("python train.py", secrets=["WANDB_API_KEY", "HF_TOKEN"])
# Manage
sm.list() # ["WANDB_API_KEY", "HF_TOKEN"]
sm.get("WANDB_API_KEY") # "your-key"
sm.exists("WANDB_API_KEY") # True
sm.delete("OLD_KEY")
Security Comparison
| | env={} | secrets=[] |
|---|---|---|
| Secret on your machine | Yes | No |
| In transit | SCP (encrypted) | Never |
| On VM disk | .env file | Memory only |
| Survives preemption | Must re-send | Automatic |
Full guide: Secrets & Security
Preemption Recovery
# Basic — auto recreate + restart
tpu.watch("python train.py", max_retries=5)
# With Slack notifications
tpu.watch_notify("python train.py",
notify_url="https://hooks.slack.com/...",
max_retries=5)
Flow: Poll every 60s → PREEMPTED → delete → recreate → setup → restart from checkpoint.
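That flow can be sketched as a small loop. All three callables here are stand-ins for tpuz internals (gcloud describe, delete + recreate + setup, and relaunching training); the real `watch()` also handles notifications:

```python
import time

def watch(get_state, recreate, restart, max_retries=5, poll_interval=60):
    """Sketch of the recovery loop described above."""
    retries = 0
    while retries < max_retries:
        state = get_state()
        if state == "PREEMPTED":
            retries += 1
            recreate()   # delete -> recreate -> setup
            restart()    # resume training from the latest checkpoint
        elif state == "DONE":
            return retries
        else:
            time.sleep(poll_interval)
    raise RuntimeError(f"gave up after {max_retries} recoveries")
```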
Debugging
tpu.repl() # Interactive Python on worker 0
tpu.debug("python train.py", port=5678) # VS Code debugpy attach
tpu.tunnel(6006) # TensorBoard: localhost:6006
tpu.tunnel(8888, 9999) # Jupyter: localhost:9999
Health Dashboard
tpu.health_check()
# Health Check for 'my-tpu'
# ==================================================
# Process: running
# Heartbeat: fresh (12s ago)
# Disk: 45% (90/200 GB)
# GPU: 85% utilization
# Training: step 1234/5000 | loss 2.31 | 56,000 tok/s
# ETA: ~35m
# Just training metrics
tpu.training_progress()
# {"step": 1234, "total_steps": 5000, "loss": 2.31, "tok_per_sec": 56000}
# Worker-level health
tpu.health_pretty()
# worker 0 running step 1234 | loss 2.31
# worker 1 running step 1234 | loss 2.31
# worker 2 stopped (no log)
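The parsed metrics come from log lines shaped like the dashboard output above, e.g. `step 1234/5000 | loss 2.31 | 56,000 tok/s`. A sketch of such a parser (the exact log format tpuz expects is an assumption here):

```python
import re

# Assumed log-line shape; the real training_progress() reads the remote log.
LINE = re.compile(r"step (\d+)/(\d+) \| loss ([\d.]+) \| ([\d,]+) tok/s")

def parse_progress(line):
    m = LINE.search(line)
    if not m:
        return None
    step, total, loss, tps = m.groups()
    return {"step": int(step), "total_steps": int(total),
            "loss": float(loss), "tok_per_sec": int(tps.replace(",", ""))}
```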
Cost Tracking
tpu.cost_summary() # "$4.12 (2.0h × $2.06/hr v4-8 spot)"
# Budget enforcement
tpu.set_budget(max_usd=50,
notify_url="https://hooks.slack.com/...")
# Alerts at $40, kills training at $50
# Check before creating
TPU.availability("v4-8", zone="us-central2-b")
# {"available": True, "spot_rate": 2.06, "on_demand_rate": 6.18}
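The alert-then-kill behaviour is plain threshold logic. A sketch, assuming the alert fires at 80% of the budget (the actual fraction tpuz uses is not documented here):

```python
def budget_action(spent, max_usd, alert_fraction=0.8):
    """Warn near the budget, kill at the budget. The 80% alert
    threshold is an assumption for illustration."""
    if spent >= max_usd:
        return "kill"
    if spent / max_usd >= alert_fraction:
        return "alert"
    return "ok"
```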
Scaling
tpu.scale("v4-32") # Delete → recreate with v4-32 → re-setup
# Same code works — JAX handles the mesh
Multi-Zone Failover
tpu = TPU.create_multi_zone("my-tpu", "v4-8",
zones=["us-central2-b", "us-central1-a", "europe-west4-a"])
Scheduled Training
tpu.schedule("python train.py",
start_after="22:00", # Cheaper spot at night
max_cost=10.0, # Kill if exceeds $10
sync="./src")
Run-Once (Docker-like)
tpu.run_once("python train.py",
sync="./src",
collect_files=["model.pkl", "results.json"],
gcs=gcs,
notify_url="https://hooks.slack.com/...")
# up → setup → resume → run → wait → collect → notify → down
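The useful property of a run-once flow is that teardown happens even when a stage fails. A sketch of that guarantee (whether tpuz tears down on failure is an assumption on our part):

```python
def run_once(stages, teardown):
    """Execute each stage in order; always tear down, even on failure."""
    try:
        for stage in stages:
            stage()
    finally:
        teardown()
```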
One-Liner
from tpuz import Launcher
Launcher("my-tpu", "v4-8").train(
command="python train.py",
sync="./src",
setup_pip="flaxchat",
auto_recover=True,
teardown_after=True,
)
Environment Snapshot
tpu.snapshot_env(gcs=gcs) # pip freeze → GCS
tpu.restore_env(gcs=gcs) # Restore after preemption
Profiles
tpu.save_profile("big-run") # Save config
tpu = TPU.from_profile("big-run", "new-tpu") # Reuse
Dry Run
tpu.dry_run("python train.py", sync="./src", secrets=["KEY"])
# Prints all gcloud commands without executing
Config Upload (Multi-Host)
tpu.upload_config({"model": {...}, "training": {...}})
# JSON → worker 0 → internal network copy to all workers
Worker Count
TPU.num_workers_for("v4-8") # 1
TPU.num_workers_for("v4-32") # 4
TPU.num_workers_for("v6e-64") # 8
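These counts are consistent with the pod table further down; an illustrative reimplementation (the chip/host arithmetic is inferred from the examples, not tpuz's actual lookup):

```python
def num_workers_for(accelerator):
    """Inferred rule: v4 names count TensorCores (2 per chip, 4 chips
    per host); v5e/v6e names count chips (8 chips per host). This is a
    reading of the pod table, not tpuz's real implementation."""
    family, size = accelerator.rsplit("-", 1)
    size = int(size)
    chips = size // 2 if family == "v4" else size
    chips_per_host = 4 if family == "v4" else 8
    return max(1, chips // chips_per_host)
```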
GPU VMs
Same API for GCE GPU instances:
from tpuz import GCE
vm = GCE.gpu("my-vm", gpu="a100") # 1x A100 40GB
vm = GCE.gpu("my-vm", gpu="a100x4") # 4x A100
vm = GCE.gpu("my-vm", gpu="h100x8") # 8x H100
vm = GCE.gpu("my-vm", gpu="t4") # 1x T4 (cheapest)
vm.up()
vm.setup() # JAX[CUDA] + deps
vm.run("python train.py", sync="./src")
vm.logs()
vm.stop() # Pause (keeps disk)
vm.start() # Resume
vm.down() # Delete
GPU Shorthands
| Shorthand | GPU | Count | VRAM |
|---|---|---|---|
| t4 | Tesla T4 | 1 | 16 GB |
| l4 | L4 | 1 | 24 GB |
| a100 | A100 40GB | 1 | 40 GB |
| a100x4 | A100 40GB | 4 | 160 GB |
| a100x8 | A100 40GB | 8 | 320 GB |
| h100x8 | H100 80GB | 8 | 640 GB |
Secrets & Security
Three Approaches
1. Cloud Secret Manager (recommended)
from tpuz import SecretManager
sm = SecretManager()
sm.create("WANDB_API_KEY", "key")
sm.grant_tpu_access_all()
tpu.run("cmd", secrets=["WANDB_API_KEY"])
Secrets never leave GCP. VM reads them via its service account.
2. env dict (fallback)
tpu.run("cmd", env={"KEY": "val"})
Written to .env file via SCP. Not in command line or ps aux, but transits your machine.
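Rendering `env={}` into a .env file is a one-liner; a sketch (quoting via shlex is our choice here, tpuz's actual escaping may differ):

```python
import shlex

def render_env_file(env):
    # Turn a dict into KEY=value lines, shell-quoting values so that
    # spaces and special characters survive `source .env`.
    return "".join(f"{k}={shlex.quote(v)}\n" for k, v in env.items())
```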
Setup
# Enable API
gcloud services enable secretmanager.googleapis.com
# Store secret
echo -n "your-key" | gcloud secrets create WANDB_API_KEY --data-file=-
# Grant TPU access
gcloud secrets add-iam-policy-binding WANDB_API_KEY \
--member="serviceAccount:$(gcloud iam service-accounts list --format='value(email)' --filter='Compute Engine')" \
--role="roles/secretmanager.secretAccessor"
Multi-Host (TPU Pods)
Worker count auto-detected. All SSH parallel with retries:
| Accelerator | Chips | Workers | Spot $/hr |
|---|---|---|---|
| v4-8 | 4 | 1 | $2.06 |
| v4-32 | 16 | 4 | $8.24 |
| v5litepod-8 | 8 | 1 | $9.60 |
| v5litepod-64 | 64 | 8 | $76.80 |
| v6e-8 | 8 | 1 | $9.60 |
CLI Reference
# Lifecycle
tpuz up NAME -a v4-8
tpuz down NAME
tpuz status NAME
tpuz list
tpuz preflight
tpuz avail v4-8
tpuz runtimes
# Training
tpuz setup NAME --pip=PKG
tpuz verify NAME
tpuz run NAME "cmd" --sync=.
tpuz wait NAME
tpuz logs NAME
tpuz logs-all NAME
tpuz kill NAME
tpuz collect NAME files...
# Debugging
tpuz repl NAME
tpuz debug NAME "cmd"
tpuz health NAME
tpuz tunnel NAME 6006
tpuz scale NAME v4-32
tpuz cost NAME
# Recovery & All-in-one
tpuz watch NAME "cmd"
tpuz train NAME "cmd" -a v4-8 --sync=. --recover --teardown
tpuz run-once NAME "cmd" --sync=. --collect model.pkl
Best Practices
- Start small, scale up: Develop on v4-8, scale to v4-32 when ready
- Always use GCS: Preemption = lost training without checkpoint persistence
- Use Cloud Secrets: Never pass API keys via env={}
- Verify before training: tpu.verify() catches silent multi-host failures
- Use notifications: watch_notify() sends Slack on preemption/completion
- Budget your runs: set_budget(50) prevents surprise bills
- Collect before teardown: tpu.collect() then tpu.down()
- Use Queued Resources: up_queued() is more reliable than up() for spot
Acknowledgments
Cloud TPU resources provided by Google’s TPU Research Cloud (TRC) program.