rocm/jax-community本镜像基于 JAX 官方发布版本构建,集成 AMD ROCm (Radeon Open Compute Platform) 开源 GPU 计算栈,提供开箱即用的 JAX 运行环境,专为在 AMD GPU 硬件上高效运行 JAX 应用程序设计。
核心用途:支持基于 JAX 的机器学习研究、深度学习训练/推理及高性能计算任务,充分利用 AMD GPU 的计算能力加速数值计算和并行处理。
docker run)bashdocker run -it --rm \ --device=/dev/kfd \ --device=/dev/dri \ --group-add video \ --cap-add=SYS_PTRACE \ --security-opt seccomp=unconfined \ rocm-jax:latest \ bash
bashdocker run -it --rm \ --device=/dev/kfd \ --device=/dev/dri \ --group-add video \ --cap-add=SYS_PTRACE \ --security-opt seccomp=unconfined \ -v /本地路径:/workspace \ # 挂载宿主机目录到容器内/workspace -w /workspace \ # 设置工作目录 rocm-jax:latest \ python your_jax_script.py
创建 docker-compose.yml 文件:
yamlversion: '3.8' services: jax-rocm: image: rocm-jax:latest container_name: jax-rocm-container devices: - /dev/kfd:/dev/kfd - /dev/dri:/dev/dri group_add: - video cap_add: - SYS_PTRACE security_opt: - seccomp:unconfined volumes: - ./local_data:/workspace # 挂载本地数据目录 working_dir: /workspace environment: - JAX_PLATFORMS=rocm # 强制使用ROCm后端 - TF_CPP_MIN_LOG_LEVEL=2 # 抑制TensorFlow日志(JAX依赖) command: python train.py # 启动命令(示例:运行训练脚本)
启动服务:
bashdocker-compose up -d
| 环境变量名 | 说明 | 默认值 |
|---|---|---|
JAX_PLATFORMS | 指定 JAX 计算平台,可选值:rocm(GPU)、cpu(CPU)、auto(自动检测) | auto |
JAX_ENABLE_X64 | 是否启用64位浮点数支持(True/False) | False |
TF_CPP_MIN_LOG_LEVEL | TensorFlow 日志级别(0=全部,1=警告,2=错误,3=无输出) | 1 |
HIP_VISIBLE_DEVICES | 指定可见 GPU 设备(如 0,1 表示使用第1、2块GPU) | 全部设备 |
bashdocker pull rocm-jax:latest # 或指定版本标签,如 rocm-jax:jax-v0.4.23-rocm5.7
bash# 示例:运行JAX设备检测脚本 docker run -it --rm \ --device=/dev/kfd --device=/dev/dri \ --group-add video \ rocm-jax:latest \ python -c "import jax; print('GPU设备数量:', jax.device_count()); print('默认设备:', jax.devices()[0])"
若输出类似以下内容,表明 ROCm 后端正常加载:
GPU设备数量: 1 默认设备: rocm:0 (AMD Radeon RX 7900 XTX)
--shm-size=16g 调整共享内存大小,避免 OOM 错误video 和 render 组,以访问 GPU 设备文件rocm-jax:jax-<jax版本>-rocm<rocm版本>(如 jax-v0.4.23-rocm5.7)manifest unknown 错误
TLS 证书验证失败
DNS 解析超时
410 错误:版本过低
402 错误:流量耗尽
身份认证失败错误
429 限流错误
凭证保存错误
来自真实用户的反馈,见证轩辕镜像的优质服务