Associate ML Infrastructure Engineer
Job description
We are seeking a specialized Machine Learning Engineer with deep expertise in the high-performance AI stack. This role isn't just about "translating" code; it's about re-architecting Large Language Models (LLMs) to thrive in a JAX-native environment, specifically targeting TPU and GPU clusters at scale. You will bridge the gap between high-level PyTorch research implementations and the functional, XLA-optimized world of JAX, ensuring that our models achieve maximum throughput and hardware efficiency.
- Core Framework Migration:
Structural Porting: Manually migrate complex PyTorch LLM architectures (Transformers, MoE, SSMs) into JAX-based frameworks (Equinox, Flax, or Pax). State Management: Transition imperative PyTorch state management to JAX's purely functional paradigm, handling PRNGKey management and immutable state updates with precision. Weight Translation: Develop robust pipelines for checkpoint conversion, ensuring numerical parity between frameworks via rigorous unit testing and error-tolerance checks.
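To illustrate the functional shift described above, here is a minimal sketch (toy shapes and a plain dict pytree, not a production model) of explicit PRNGKey splitting and an immutable parameter update in JAX:

```python
import jax
import jax.numpy as jnp

# PyTorch mutates module state in place; JAX threads immutable state
# explicitly. Randomness is explicit too: every stochastic op consumes a key.
key = jax.random.PRNGKey(0)
key, w_key, b_key = jax.random.split(key, 3)  # split keys, never reuse them

params = {
    "w": jax.random.normal(w_key, (4, 2)),
    "b": jax.random.normal(b_key, (2,)),
}

def loss_fn(params, x):
    return jnp.sum(x @ params["w"] + params["b"]) ** 2

x = jnp.ones((3, 4))
grads = jax.grad(loss_fn)(params, x)

# An "update" returns a new pytree; the old params are untouched.
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```

Note that nothing is mutated: `params` still holds the original weights after the step, which is exactly what makes checkpoint-conversion and parity-testing pipelines straightforward to reason about.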
- Advanced Profiling & Numerical Stability:
Bottleneck Analysis: Use NVIDIA Nsight and the TensorBoard Profiler to identify XLA compilation overheads, excessive rematerialization, or un-fused kernels. Numerical Debugging: Implement precision-tracking tools to ensure that BF16 or FP8 training runs remain stable during the transition, preventing gradient divergence.
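A precision-tracking check of the kind described above can be sketched as follows (the layer, tolerance, and shapes are illustrative, not a prescribed tool):

```python
import jax.numpy as jnp

def layer(x):
    # Toy cast-sensitive layer: the weight follows the input dtype.
    w = jnp.full((8, 8), 0.1, dtype=x.dtype)
    return jnp.tanh(x @ w)

def check_parity(f, x, tol=1e-2):
    """Compare a low-precision forward pass against an fp32 reference."""
    ref = f(x.astype(jnp.float32))
    low = f(x.astype(jnp.bfloat16)).astype(jnp.float32)
    err = float(jnp.max(jnp.abs(ref - low)))
    return err, err < tol

x = jnp.linspace(-1.0, 1.0, 32).reshape(4, 8)
err, ok = check_parity(layer, x)
```

The same pattern scales up: run each ported block in both precisions, record the max absolute error, and fail the conversion if it exceeds a per-layer tolerance.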
- Scaling & Distributed Training:
Parallelism Strategies: Implement and optimize Fully Sharded Data Parallelism (FSDP) equivalents in JAX (using pjit or the sharding APIs). Hybrid Parallelism: Design 3D parallelism strategies (Data, Pipeline, and Tensor) tailored to the interconnect topology (e.g., NVLink or TPU ICI) of the target hardware.
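As a minimal sketch of the sharding APIs mentioned above (a single CPU device also forms a valid one-device mesh, so this runs anywhere; the axis name and shapes are illustrative):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D device mesh over whatever devices are present.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Shard the batch dimension across the "data" axis -- the modern
# counterpart of pjit's in_axis_resources.
x = jnp.arange(16.0).reshape(8, 2)
x_sharded = jax.device_put(x, NamedSharding(mesh, P("data", None)))

@jax.jit
def step(x):
    return (x * 2.0).sum(axis=1)

out = step(x_sharded)
```

Under jit, XLA propagates the input sharding through the computation, which is the basic mechanism FSDP-style layouts build on.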
- Hardware-Aware Optimization:
XLA Mastery: Understand and influence the behavior of the XLA (Accelerated Linear Algebra) compiler. You will optimize HLO (High-Level Optimizer) graphs to minimize jit-time and maximize run-time efficiency. Memory Management: Apply optimizations such as Selective Activation Checkpointing and memory-efficient attention (FlashAttention-2 JAX implementations) based on the specific HBM (High Bandwidth Memory) constraints of the hardware.
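Selective activation checkpointing maps directly onto jax.checkpoint (a.k.a. jax.remat); a toy sketch of wrapping only chosen blocks:

```python
import jax
import jax.numpy as jnp

def block(x):
    # An "expensive" layer whose activations we'd rather recompute than store.
    return jnp.tanh(x) * jnp.exp(-x * x)

# jax.checkpoint drops this block's intermediates during the forward pass
# and recomputes them in the backward pass, trading FLOPs for HBM.
# Selective checkpointing means applying this only to selected blocks.
block_remat = jax.checkpoint(block)

def loss(x):
    y = x
    for _ in range(3):
        y = block_remat(y)
    return jnp.sum(y)

g = jax.grad(loss)(jnp.ones((4,)))
```

Gradients are bitwise-equivalent to the un-rematerialized version; only the memory/compute trade-off changes.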
- Fine-Tuning & Adaptation:
Efficient Fine-Tuning: Port PyTorch-based PEFT methods (LoRA, DoRA) into the JAX stack. Architectural Evolution: Stay ahead of the curve by adapting JAX implementations for newer primitives such as Mamba/SSMs, Grouped-Query Attention (GQA), and Linear Attention as they emerge in the research space.
Requirements
- Experience with unit testing.
- Familiarity with the following technical stack and tooling:
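A LoRA port in functional JAX reduces to a small algebraic sketch (rank, scaling, and shapes below are illustrative):

```python
import jax
import jax.numpy as jnp

# LoRA in functional JAX: the frozen base weight stays untouched; only the
# low-rank factors (a, b) are trained.
def lora_linear(x, w_frozen, a, b, alpha=16.0):
    rank = a.shape[1]
    return x @ (w_frozen + (alpha / rank) * (a @ b))

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
d_in, d_out, r = 8, 8, 2
w = jax.random.normal(k1, (d_in, d_out))
a = jax.random.normal(k2, (d_in, r)) * 0.01
b = jnp.zeros((r, d_out))  # zero init => adapter starts as a zero delta

x = jnp.ones((3, d_in))
y = lora_linear(x, w, a, b)
```

Because `b` is zero-initialized, the adapted layer starts out exactly equal to the frozen base layer, matching the usual PEFT convention.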
- Core Frameworks & Libraries:
JAX Ecosystem: Expertise in Flax or Equinox (for model definition), Optax (for optimization/schedules), and Orbax (for checkpointing). PyTorch Ecosystem: Deep knowledge of PyTorch 2.x, including torch.compile, DistributedDataParallel (DDP), and FSDP. Intermediate Representations: Proficiency in HLO (High-Level Optimizer) and MLIR to understand how JAX code translates to hardware instructions. Data Loaders: Experience migrating from torch.utils.data to Grain or tf.data for high-throughput JAX pipelines.
- Profiling & Observability:
JAX Profiler / TensorBoard: For identifying XLA compilation bottlenecks and tracing device memory traffic. NVIDIA Nsight Systems: To analyze GPU utilization, SM occupancy, and NVLink traffic. Perfetto: For deep-dive trace analysis across multi-node TPU/GPU clusters.
- Infrastructure & Hardware:
Accelerator Hardware: Strong understanding of NVIDIA H100/A100 (Hopper/Ampere) architecture and Google TPU (v4/v5p) topology. Orchestration: Experience with Slurm or Kubernetes (K8s) for managing large-scale training jobs. Cloud Providers: Proficiency in Google Cloud Platform (GCP) for TPUs, or AWS/Azure for high-end GPU instances.
Core Skills & Competencies:
- Software Engineering Excellence:
Functional Programming: A shift in mindset from OOP (Object-Oriented Programming) to pure functions, immutability, and stateless logic. Asynchronous Programming: Understanding JAX's asynchronous dispatch model and how to avoid "host-sync" bottlenecks. Testing Rigor: Ability to write property-based tests for numerical stability.
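A property-based test for numerical stability can be sketched without any framework by asserting invariants over random draws (the softmax example and tolerances here are illustrative):

```python
import jax
import jax.numpy as jnp

def stable_softmax(x):
    # Subtracting the max keeps exp() in range -- a standard stability trick.
    z = x - jnp.max(x, axis=-1, keepdims=True)
    e = jnp.exp(z)
    return e / jnp.sum(e, axis=-1, keepdims=True)

def check_softmax_properties(trials=50):
    """Minimal property-based test: random inputs, invariant assertions."""
    key = jax.random.PRNGKey(0)
    for _ in range(trials):
        key, sub = jax.random.split(key)
        # Large magnitudes, where a naive softmax would overflow in fp32.
        x = jax.random.normal(sub, (8,)) * 100.0
        p = stable_softmax(x)
        assert jnp.all(jnp.isfinite(p))                     # no NaN/inf
        assert jnp.allclose(jnp.sum(p), 1.0, atol=1e-5)     # sums to 1
        # Shift invariance: softmax(x + c) == softmax(x)
        assert jnp.allclose(stable_softmax(x + 123.0), p, atol=1e-4)
    return True
```

Libraries such as Hypothesis generalize this pattern, but the core idea is the same: state invariants, then hammer them with adversarial random inputs.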
- Distributed Systems Knowledge:
Collective Communications: Deep understanding of All-Reduce, All-Gather, and Reduce-Scatter primitives. Network Topology: Understanding how rack-level interconnects (e.g., InfiniBand) affect the choice of 3D parallelism strategies.
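In JAX, All-Reduce surfaces as jax.lax.psum over a named axis; a runnable sketch (a single CPU device also works, each device contributing one shard):

```python
import jax
import jax.numpy as jnp

# All-Reduce: psum sums a value across all devices on a named axis.
n = jax.local_device_count()
x = jnp.arange(float(n * 4)).reshape(n, 4)  # one row per device

summed = jax.pmap(lambda s: jax.lax.psum(s, axis_name="i"), axis_name="i")(x)
```

After the collective, every device holds the same reduced value (here, the column sums across device shards), which is exactly the semantics of All-Reduce.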
- Mathematical & AI Domain Expertise (Desirable):
Linear Algebra: Mastery of tensor contractions, Einstein summation (einsum), and matrix decomposition. Mixed-Precision Training: Expert-level knowledge of Stochastic Rounding, Loss Scaling, and the nuances of BF16 vs. FP8 training. Architecture Insight: Ability to decompose modern LLM components (KV Caches, Rotary Embeddings, Gated Linear Units) into their primitive mathematical operations.
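As an example of decomposing an LLM component into primitive operations, scaled dot-product attention written as two einsum contractions (toy shapes; b=batch, h=heads, q/k=query/key positions, d=head_dim):

```python
import jax
import jax.numpy as jnp

# The primitive-level view needed when re-deriving or fusing kernels:
# attention is two tensor contractions around a row-wise softmax.
def attention(q, k, v):
    d = q.shape[-1]
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(d)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v)

q = k = v = jnp.ones((1, 2, 3, 4))  # illustrative shapes
out = attention(q, k, v)
```

Writing the contraction indices out explicitly is what makes variants like GQA (fewer k/v heads) or memory-efficient attention easy to derive from the same primitives.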