Add GPU Acceleration for tabular model using pytorch MPS (Metal Performance Shaders) for macOS #62

@radurogojanumai

Description

Currently, mostlyai-engine does not use GPU acceleration on macOS. CUDA-based acceleration is available on Linux, but macOS users are limited to CPU-bound execution. To improve performance for them, we should add support for PyTorch’s Metal Performance Shaders (MPS) backend, which enables GPU acceleration on Apple Silicon Macs (e.g. M1/M2/M3 and later).

Proposed Solution

  1. Enable PyTorch MPS backend detection
  • Check whether the installed PyTorch build can use MPS on the current machine (torch.backends.mps.is_available()).

  • If MPS is available, ensure models and tensors are moved to the MPS device.

  2. Update Installation & Dependencies
  • Ensure torch>=1.13 is installed. The MPS backend first shipped as a prototype in PyTorch 1.12 and was promoted to beta in 1.13.

  • Add documentation for macOS users on installing PyTorch with MPS support.

  3. Modify Training & Inference Pipelines
  • Adapt existing PyTorch calls to dynamically select the best available backend (mps, cuda, or cpu).

  • Ensure compatibility with QLoRA and bitsandbytes (fallback to CPU if MPS does not support certain operations).

  4. Performance Benchmarking & Validation
  • Compare training/inference speeds using MPS vs. CPU.

  • Identify any limitations or unsupported operations within MPS that may require fallbacks.
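
As a rough illustration of steps 1 and 3, backend selection could be factored into a small helper. This is only a sketch, not code from mostlyai-engine: the names pick_device and detect_device are hypothetical, and the priority order (cuda, then mps, then cpu) is an assumption, since the issue does not state which backend should win when more than one is available.

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Return the preferred device name given what the runtime reports.

    Assumed priority: CUDA first, then MPS, then CPU.
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def detect_device() -> str:
    """Ask PyTorch which backends are usable and pick one of them."""
    # Imported lazily so pick_device can be tested without PyTorch installed.
    import torch

    # torch.backends.mps is missing on older PyTorch builds, so guard the lookup.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ok = mps_backend is not None and mps_backend.is_available()
    return pick_device(torch.cuda.is_available(), mps_ok)
```

Call sites would then use something like model.to(torch.device(detect_device())) and move each batch to the same device. For operations that MPS does not implement, PyTorch offers the PYTORCH_ENABLE_MPS_FALLBACK=1 environment variable, which routes them to the CPU. Whether that is enough for the QLoRA and bitsandbytes paths would still need to be verified.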

Questions

  1. Should we introduce an [mps] extra for macOS users to explicitly enable MPS-related dependencies?
    Answer: We would want to keep a simple set of extras (e.g. [gpu] for both Linux + CUDA and Darwin + MPS).
  2. How well does bitsandbytes integrate with Darwin?
    Answer: We'll ensure the required version has the necessary wheels.

Acceptance Criteria

  • Mac users can utilize MPS acceleration via PyTorch without modifying code manually.
  • Performance improvements over CPU-only execution are verified.
  • No breaking changes for Linux users.
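
To check the second criterion, a small timing helper along these lines might be enough. It is a sketch under assumptions: time_step and its parameters are hypothetical names, and the optional sync callback exists because GPU work is queued asynchronously, so the device has to be synchronized before the clock is read.

```python
import statistics
import time
from typing import Callable, Optional


def time_step(
    step: Callable[[], object],
    sync: Optional[Callable[[], None]] = None,
    warmup: int = 3,
    iters: int = 10,
) -> float:
    """Return the median wall-clock seconds of `step` over `iters` timed runs."""
    # Warm-up runs are not timed; they absorb one-off costs such as kernel compilation.
    for _ in range(warmup):
        step()
    if sync is not None:
        sync()

    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        step()
        if sync is not None:
            sync()  # wait for queued device work before stopping the clock
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)
```

On an MPS device, sync could be torch.mps.synchronize, which exists in recent PyTorch releases but may not be present at the minimum version discussed above. On CPU it can be left as None. Running the same training or generation step through both devices then gives directly comparable medians.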
