Associate ML Infrastructure Engineer

OpenKyber LLC
1 month ago

Role details

Contract type
Permanent contract
Employment type
Full-time (> 32 hours)
Working hours
Regular working hours
Languages
English

Job location

Remote

Tech stack

API
Artificial Intelligence
Amazon Web Services (AWS)
Unit Testing
Azure
Cloud Computing
Profiling
Software Debugging
Distributed Systems
Memory Management
Equinox
Network Topologies
InfiniBand
Machine Learning
Object-Oriented Software Development
Software Engineering
Google Cloud Platform
PyTorch
Large Language Models
Kubernetes
Slurm
Asynchronous Programming
Functional Programming

Job description

We are seeking a specialized Machine Learning Engineer with deep expertise in the high-performance AI stack. This role isn't just about "translating" code; it's about re-architecting Large Language Models (LLMs) to thrive in a JAX-native environment, specifically targeting TPU and GPU clusters at scale. You will bridge the gap between high-level PyTorch research implementations and the functional, XLA-optimized world of JAX, ensuring that our models achieve maximum throughput and hardware efficiency.

  1. Core Framework Migration

Structural Porting: Manually migrate complex PyTorch LLM architectures (Transformers, MoE, SSMs) into JAX-based frameworks (Equinox, Flax, or Pax).
State Management: Transition imperative PyTorch state management to JAX's purely functional paradigm, handling PRNGKey management and immutable state updates with precision.
Weight Translation: Develop robust pipelines for checkpoint conversion, ensuring numerical parity between frameworks via rigorous unit testing and error-tolerance checks.
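The state-management shift described above can be illustrated with a minimal sketch. All names here are illustrative, not part of any real codebase: randomness flows through explicit PRNGKeys instead of PyTorch's hidden global RNG state.

```python
import jax
import jax.numpy as jnp

def dropout(key, x, rate=0.1):
    # Randomness is explicit in JAX: the caller supplies a PRNGKey
    # rather than relying on hidden global RNG state as in PyTorch.
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

@jax.jit
def forward(key, params, x):
    # Split the key so each stochastic op gets an independent stream;
    # reusing a key would silently correlate the random draws.
    k1, k2 = jax.random.split(key)
    h = dropout(k1, x @ params["w1"])
    return dropout(k2, h @ params["w2"])

key = jax.random.PRNGKey(0)
params = {"w1": jnp.ones((4, 8)), "w2": jnp.ones((8, 2))}
out = forward(key, params, jnp.ones((3, 4)))
```

Note that `params` is never mutated in place; an "update" returns a new pytree, which is what the immutable-state requirement refers to.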

  2. Advanced Profiling & Numerical Stability

Bottleneck Analysis: Use NVIDIA Nsight and the TensorBoard Profiler to identify XLA compilation overheads, excessive rematerialization, or un-fused kernels.
Numerical Debugging: Implement precision-tracking tools to ensure that BF16 or FP8 training runs remain stable during the transition, preventing gradient divergence.
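The error-tolerance checking this kind of work relies on can be sketched as follows. The helper name and tolerances are hypothetical; the point is comparing a low-precision run against an FP32 reference within explicit bounds rather than expecting bit-exactness.

```python
import numpy as np
import jax.numpy as jnp

def assert_parity(reference, candidate, rtol=2e-2, atol=2e-2):
    # Hypothetical parity check of the sort used during checkpoint
    # conversion: fail loudly if the candidate drifts outside tolerance.
    ref = np.asarray(reference, dtype=np.float32)
    cand = np.asarray(candidate, dtype=np.float32)
    max_abs = float(np.max(np.abs(ref - cand)))
    if not np.allclose(ref, cand, rtol=rtol, atol=atol):
        raise AssertionError(f"parity failed, max abs err {max_abs:.3e}")
    return max_abs

x = jnp.linspace(-3.0, 3.0, 256, dtype=jnp.float32)
ref = jnp.tanh(x)                         # FP32 reference path
lowp = jnp.tanh(x.astype(jnp.bfloat16))   # BF16 path under test
err = assert_parity(ref, lowp.astype(jnp.float32))
```

In a real migration the same pattern is applied layer by layer between the PyTorch and JAX implementations.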

  3. Scaling & Distributed Training

Parallelism Strategies: Implement and optimize Fully Sharded Data Parallelism (FSDP) equivalents in JAX (using pjit or the sharding APIs).
Hybrid Parallelism: Design 3D parallelism strategies (Data, Pipeline, and Tensor) tailored for the interconnect topology (e.g., NVLink or TPU ICI) of the target hardware.
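A minimal sketch of the sharding-API approach mentioned above. The device-count flag is a hypothetical test setup that emulates 8 devices on CPU so the example runs without a cluster; on real TPU/GPU pods the mesh maps onto physical accelerators.

```python
import os
# Emulate 8 devices on CPU (test-only trick; must be set before importing jax).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 2D mesh: 4-way data parallelism x 2-way tensor (model) parallelism.
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), axis_names=("data", "model"))

# FSDP-style layout: shard the weight over both axes, the batch over "data".
w = jax.device_put(jnp.ones((8, 16)), NamedSharding(mesh, P("data", "model")))
x = jax.device_put(jnp.ones((8, 8)), NamedSharding(mesh, P("data", None)))

@jax.jit
def matmul(x, w):
    # XLA derives the required collectives (all-gather / reduce-scatter)
    # from the sharding annotations; no manual communication code needed.
    return x @ w

y = matmul(x, w)
```

The same annotations scale from this toy matmul to full transformer blocks, which is what makes the layout design (rather than the code) the hard part.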

  4. Hardware-Aware Optimization

XLA Mastery: Understand and influence the behavior of the XLA (Accelerated Linear Algebra) compiler. You will optimize HLO (High-Level Optimizer) graphs to minimize "jit-time" and maximize "run-time" efficiency.
Memory Management: Apply optimizations such as Selective Activation Checkpointing and memory-efficient attention (FlashAttention-2 JAX implementations) based on the specific HBM (High Bandwidth Memory) constraints of the hardware.
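Activation checkpointing in JAX is a one-line wrapper, sketched below on a toy MLP block (the shapes and names are placeholders). `jax.checkpoint` drops intermediates in the forward pass and recomputes them during backprop, trading FLOPs for HBM; selective policies can be supplied via its `policy` argument.

```python
import jax
import jax.numpy as jnp

def mlp_block(x, w1, w2):
    return jnp.tanh(x @ w1) @ w2

# Rematerialize this block's activations during the backward pass
# instead of keeping them resident in HBM.
mlp_remat = jax.checkpoint(mlp_block)

def loss(w1, w2, x):
    return jnp.sum(mlp_remat(x, w1, w2) ** 2)

x = jnp.ones((16, 64))
w1 = jnp.ones((64, 256)) * 0.01
w2 = jnp.ones((256, 64)) * 0.01
grads = jax.grad(loss, argnums=(0, 1))(w1, w2, x)
```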

  5. Fine-Tuning & Adaptation

Efficient Fine-Tuning: Port PyTorch-based PEFT methods (LoRA, DoRA) into the JAX stack.
Architectural Evolution: Stay ahead of the curve by adapting JAX implementations for newer primitives such as Mamba/SSMs, Grouped-Query Attention (GQA), and Linear Attention as they emerge in the research space.

Requirements

Familiarity with the following technical stack and tooling:

  1. Core Frameworks & Libraries:
JAX Ecosystem: Expertise in Flax or Equinox (for model definition), Optax (for optimization/schedules), and Orbax (for checkpointing).
PyTorch Ecosystem: Deep knowledge of PyTorch 2.x, including torch.compile, DistributedDataParallel (DDP), and FSDP.
Intermediate Representations: Proficiency in HLO (High-Level Optimizer) and MLIR to understand how JAX code translates to hardware instructions.
Data Loaders: Experience migrating from torch.utils.data to Grain or tf.data for high-throughput JAX pipelines.

  2. Profiling & Observability:
JAX Profiler / TensorBoard: For identifying XLA compilation bottlenecks and tracing device memory traffic.
NVIDIA Nsight Systems: To analyze GPU utilization, SM occupancy, and NVLink traffic.
Perfetto: For deep-dive trace analysis across multi-node TPU/GPU clusters.
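Capturing a trace with the JAX profiler is a small wrapper, sketched below against a throwaway temp directory. The resulting trace can be opened in TensorBoard (with the profile plugin) or Perfetto.

```python
import tempfile
import jax
import jax.numpy as jnp

logdir = tempfile.mkdtemp()  # placeholder; real runs use a shared log dir
with jax.profiler.trace(logdir):
    x = jnp.ones((256, 256))
    # Force a sync so the device work completes inside the traced region.
    y = (x @ x).block_until_ready()
```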

  3. Infrastructure & Hardware:
Accelerator Hardware: Strong understanding of NVIDIA H100/A100 (Hopper/Ampere) architecture and Google TPU (v4/v5p) topology.
Orchestration: Experience with Slurm or Kubernetes (K8s) for managing large-scale training jobs.
Cloud Providers: Proficiency in Google Cloud Platform for TPUs, or AWS/Azure for high-end GPU instances.

Core Skills & Competencies:

  4. Software Engineering Excellence:
Functional Programming: A shift in mindset from OOP (object-oriented) code to pure functions, immutability, and stateless logic.
Asynchronous Programming: Understanding JAX's asynchronous dispatch model and how to avoid "host-sync" bottlenecks.
Testing Rigor: Ability to write property-based tests for numerical stability.
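The asynchronous-dispatch point above can be made concrete with a small timing sketch (sizes are arbitrary). A JAX call returns before the device finishes; anything that pulls a value back to the host (printing, converting to a Python float) forces a sync, which is the classic hidden bottleneck in step loops.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def heavy(x):
    for _ in range(10):
        x = jnp.tanh(x @ x)
    return x

x = jnp.ones((512, 512))
heavy(x).block_until_ready()  # warm up so compile time is excluded

t0 = time.perf_counter()
y = heavy(x)                          # dispatch: returns almost immediately
dispatch_time = time.perf_counter() - t0
y.block_until_ready()                 # explicit host-sync point
total_time = time.perf_counter() - t0
```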

  5. Distributed Systems Knowledge:
Collective Communications: Deep understanding of the All-Reduce, All-Gather, and Reduce-Scatter primitives.
Network Topology: Understanding how rack-level interconnects (e.g., InfiniBand) affect the choice of 3D parallelism strategies.
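In JAX, All-Reduce surfaces as `jax.lax.psum` inside a mapped computation, sketched below. The device-count flag is a hypothetical test setup emulating 4 devices on CPU; on a real cluster each shard lives on its own accelerator.

```python
import os
# Emulate 4 devices on CPU (test-only trick; set before importing jax).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import functools
import jax
import jax.numpy as jnp

# All-Reduce: every device contributes its local value and all devices
# receive the global sum over the named axis.
@functools.partial(jax.pmap, axis_name="devices")
def all_reduce_sum(x):
    return jax.lax.psum(x, axis_name="devices")

local_vals = jnp.arange(4.0)          # one scalar per emulated device
result = all_reduce_sum(local_vals)   # every device holds 0+1+2+3 = 6
```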

  6. Mathematical & AI Domain Expertise (Desirable):
Linear Algebra: Mastery of tensor contractions, Einstein summation (einsum), and matrix decomposition.
Mixed-Precision Training: Expert-level knowledge of Stochastic Rounding, Loss Scaling, and the nuances of BF16 vs. FP8 training.
Architecture Insight: Ability to decompose modern LLM components (KV caches, Rotary Embeddings, Gated Linear Units) into their primitive mathematical operations.
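The einsum fluency asked for above is exactly what decomposing an attention layer looks like. A sketch with toy shapes, writing scaled dot-product attention as two explicit tensor contractions:

```python
import jax
import jax.numpy as jnp

def attention(q, k, v):
    d = q.shape[-1]
    # scores[b, h, i, j] = sum_d q[b, h, i, d] * k[b, h, j, d]
    scores = jnp.einsum("bhid,bhjd->bhij", q, k) / jnp.sqrt(d)
    weights = jax.nn.softmax(scores, axis=-1)
    # out[b, h, i, d] = sum_j weights[b, h, i, j] * v[b, h, j, d]
    return jnp.einsum("bhij,bhjd->bhid", weights, v)

key = jax.random.PRNGKey(0)
q = k = v = jax.random.normal(key, (1, 2, 8, 16))  # (batch, heads, seq, head_dim)
out = attention(q, k, v)
```

Fused kernels like FlashAttention compute the same contractions blockwise; being able to state the primitive form is the starting point for any such port.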

Apply for this position