JAXは、Googleが公開する数値計算ライブラリです。NumPyに近い配列APIに自動微分とJITコンパイルを載せ、研究コードをGPU・TPU上で速く・微分可能に実行する——CUDAやPyTorchのような「学習フレームワーク全体」ではなく、配列プログラムの実行基盤——本記事はAPIの暗記より、スタックのどこに位置するかに焦点を当てます。
JAXの三つの力
| 機能 | 何ができるか | 試験向けの一言 |
|---|---|---|
| 配列API | NumPy風にテンソル演算を記述 | 科学計算・DLの土台 |
| 自動微分 | grad等で関数の勾配を自動計算 | 誤差逆伝播の実装基盤 |
| JIT | 関数をコンパイルして高速実行 | GPU・TPUでの推論・学習加速 |
試験では「JAX=モデル名」ではなく、数値計算+微分+高速化のライブラリ——とラベル付けするのが安全です。
ソフトウェアスタックでの位置
CUDAの記事と同様、層の分離が得点の近道です。
| 層 | 例 | 役割 |
|---|---|---|
| ハードウェア | GPU、TPU | 物理チップ(G-202) |
| 実行基盤 | CUDA、XLA | 並列演算の実行環境 |
| 数値計算ライブラリ | JAX | 配列・微分・JIT |
| 高水準フレームワーク | PyTorch、TensorFlow、Flax | モデル定義・学習ループ |
JAXは真ん中の「微分可能な配列プログラム」層。その上にFlaxなどを載せてTransformerを組み立てる、という使い方が研究で多いです。
自動微分と誤差逆伝播
深層学習の学習は、損失をパラメータで微分した勾配で重みを更新します(G-171、TF-072)。
JAXの自動微分は、Pythonで書いた関数を解析し、連鎖律に沿った勾配計算を自動生成します——手で逆伝播式を書く代わりに、grad(loss_fn)のように宣言する発想です(TF-391:誤差逆伝播は人手修正ではない)。
勾配ブースティング(G-063)の「勾配」とは別物——JAXの文脈ではニューラルネットのパラメータ微分を指します。
JITで速くする理由
JIT(Just-In-Time compilation)は、Python関数を一度コンパイルし、以降の呼び出しを最適化された機械コードとしてGPU・TPU上で走らせます。
- 素のPython — ループが遅く、デバイス転送が多い
- JIT後 — 演算が融合・並列化され、学習・推論が速くなる
- 効果の限界 — 計算を速くするだけで、データ品質や精度は保証しない(G-202)
モデル軽量化(量子化・蒸留など)とは別の実行効率化——G-406の「推論の高速化」は複数の手段の組み合わせであり、JAXはそのうち実装・実行基盤の選択肢の一つです。
試験で押さえるポイント
- 定義 — 自動微分とJITを備えた数値計算ライブラリ
- 開発元 — Google(研究・TPU文脈で出やすい)
- 対比 — PyTorch/TF=フレームワーク、CUDA=GPU基盤、JAX=配列+微分層
- すり替え回避 — モデル名・GPUチップ・勾配ブースティングではない
すり替えに注意
| 誤った説明 | 正しい理解 |
|---|---|
| JAX=PyTorch | 数値計算ライブラリ vs 学習フレームワーク |
| JAX=CUDA | 配列・微分のAPI vs GPU実行基盤 |
| JAX=GPU | ソフトウェア vs ハードウェア |
| JAX=LLM | 実行基盤 vs 言語モデル |
| JAX=勾配ブースティング | 自動微分の勾配 vs アンサンブル手法(G-063) |
よくある質問
JAXは何をするライブラリですか?
配列演算を記述するPythonライブラリで、関数に自動微分(grad等)を適用して勾配を求めたり、JITコンパイルでGPU・TPU上の実行を高速化したりできます。深層学習の研究や科学計算で、微分可能なプログラムを効率よく書くための基盤として使われます。
JAXとPyTorchは同じですか?
同じではありません。PyTorchはモデル定義・学習ループまで含む深層学習フレームワークです。JAXは数値計算と自動微分・JITに特化したライブラリで、上にFlaxやHaikuなどのニューラルネット用ラッパーを載せて使うことが多いです。層が異なります。
JAXとCUDAは同じですか?
同じではありません。CUDAはNVIDIA GPU上で並列計算を動かすためのソフトウェア基盤です。JAXはPythonから配列演算と自動微分を書くためのライブラリであり、GPU・TPUでの実行はXLAなどのコンパイル基盤を通じて行われます。ハード基盤と数値計算ライブラリは別物です。