Keras 3.0發佈!TF/PyTorch/Jax無縫混合使用,作者:歡迎來到多框架機器學習

2023-12-1 78 12/1

改變遊戲規則

Keras 3.0正式發佈,被譽爲改變了機器學習遊戲規則:

不僅支持TensorFlow、PyTorch、Jax
三大框架作爲後端
,還能在它們之間
無縫切換
,甚至
混合使用

Keras 3.0发布!TF/PyTorch/Jax无缝混合使用,作者:欢迎来到多框架机器学习

Keras之父
弗朗索瓦·喬萊
認爲,這樣至少可以獲得4大好處:

  • 始終讓模型獲得最佳性能:

JAX通常在GPU、CPU各種PU上都最快,但不使用XLA(加速線性代數)的Tensorflow在GPU上偶爾更快。

Keras 3.0能夠動態爲模型提供最佳性能的後端,而無需更改代碼,保證以最高效率運行。

  • 解鎖多個生態系統

任何Keras 3模型都可以作爲PyTorch模塊實例化,可以導出爲TF的SavedModel,或者可以實例化爲無狀態的 JAX 函數。

這意味着可以將Keras 3模型與PyTorch生態的包,TensorFlow中的部署工具或生產工具,以及JAX大規模TPU訓練基礎設施一起使用,獲得機器學習世界所提供的一切。

  • 在開源社區擴大影響力

如果使用純TensorFlow或PyTorch實現一個開源模型,都只有大約一半的人能使用。

但如果使用Keras 3,任何人無論偏好哪個框架,(即使不是 Keras 用戶)都能立刻使用。在不增加開發成本的情況下,使影響力翻倍。

  • 使用任何來源的數據管道

無論使用哪個後端,Keras 3 都能與tf.data.Dataset對象、PyTorch DataLoader對象、NumPy 數組、Pandas數據框兼容。

這意味着可以在PyTorch DataLoader上訓練Keras 3 + TensorFlow模型,或在 tf.data.Dataset上訓練Keras 3 + PyTorch模型。

Keras 3.0发布!TF/PyTorch/Jax无缝混合使用,作者:欢迎来到多框架机器学习

不少人都對這一進展表示祝賀,項目參與者、谷歌高級工程師
Aakash Kumar Nain
認爲:

Keras 3再次展示了心智模型的重要性。開發API 是一方面,而開發一個擁有出色心智模型的API則完全是另一個層次的工程實踐。

Keras 3.0发布!TF/PyTorch/Jax无缝混合使用,作者:欢迎来到多框架机器学习

也有開發者表示:

很高興能夠通過熟悉的Keras API獲得框架可選性,讓簡單的用例變得容易,複雜的用例也成爲可能。

Keras 3.0发布!TF/PyTorch/Jax无缝混合使用,作者:欢迎来到多框架机器学习

歡迎來到多框架機器學習

Keras 3.0發佈公告中開篇寫到,歡迎來到多框架機器學習。

Keras 3.0发布!TF/PyTorch/Jax无缝混合使用,作者:欢迎来到多框架机器学习

具體來說,
Keras 3.0完全重寫了框架API
,並使其可用於TensorFlow、JAX和PyTorch。

任何僅使用內置層的Keras模型都將立即與所有支持的後端配合使用。

Keras 3.0发布!TF/PyTorch/Jax无缝混合使用,作者:欢迎来到多框架机器学习

使用Keras 3可以
創建在任何框架中都能以相同方式工作的組件
,允許訪問跨所有後端運行的keras.ops命名空間。

只要僅使用keras.ops中的ops,自定義層、損失、指標和優化器等就可以使用相同的代碼與JAX、PyTorch和TensorFlow配合使用。這意味着只需維護一個組件實現,就可以在所有框架中使用完全相同的數值。

Keras 3.0发布!TF/PyTorch/Jax无缝混合使用,作者:欢迎来到多框架机器学习

除此之外,還發布了
用於大規模數據並行和模型並行的新分佈式API
,爲多設備模型分片問題提供Keras風格的解決方案。

爲此設計的API使模型定義、訓練邏輯和分片配置完全獨立,這意味可以像在單個設備上運行一樣編寫代碼,然後在訓練任意模型時將任意分片配置添加到任意模型中。

Keras 3.0发布!TF/PyTorch/Jax无缝混合使用,作者:欢迎来到多框架机器学习

不過新的分佈式API目前僅適用於JAX後端,TensorFlow和PyTorch支持即將推出。

Keras 3.0发布!TF/PyTorch/Jax无缝混合使用,作者:欢迎来到多框架机器学习

爲適配JAX,還發布了
用於層、模型、指標和優化器的新無狀態API
,添加了相關方法。

Keras 3.0发布!TF/PyTorch/Jax无缝混合使用,作者:欢迎来到多框架机器学习

這些方法沒有任何副作用,它們將目標對象的狀態變量的當前值作爲輸入,並返回更新值作爲其輸出的一部分。

用戶不用自己實現這些方法,只要實現了有狀態版本,它們就會自動可用。

如果
從Keras 2遷移到3
,使用tf.keras開發的代碼通常可以按原樣在Keras 3中使用Tensorflow後端運行。有限數量的不兼容之處也給出了遷移指南。

在舊版Keras 2中開發的預訓練模型通常也可以在Keras 3中使用TensorFlow後端開箱即用。

如果舊版模型僅使用了Keras內置層,那麼也可以在Keras 3中使用JAX和PyTorch後端開箱即用。

也有人敲警鐘

在迫不及待嘗試新版本的開發社區氛圍中,Cohere機器學習總監
尼爾斯·賴默斯
提出“真心希望歷史不要重演”,也獲得不少關注。

Keras 3.0发布!TF/PyTorch/Jax无缝混合使用,作者:欢迎来到多框架机器学习

Reimers認爲,Keras最初從支持單個後端(Theano)開始,陸續添加了Tensorflow、MXNet和CNTK等多後端。

這引發了一系列問題:

  • 某些功能只在特定後端可用
  • 各個後端的計算結果存在不一致:在一個後端上運行正常的代碼,在另一個後端可能產生不同結果
  • 對於開源軟件開發者來說體驗糟糕:你剛完成了一個自定義的 Keras層想要分享?你是否願意爲其他後端重新實現和優化它呢?
  • 調試問題:代碼在一個後端上表現完美,但在另一個後端的最新版本上卻頻繁出錯…

隨着時間推移,這些問題愈發嚴重:某些模塊只能在 Theano 上運行良好,某些只適用於Tensorflow,還有一些模塊可以在MXNet上進行推理,但無法訓練…

因此,2019年Keras轉向單一後端(Tensorflow),是保障這一偉大項目繼續存在的關鍵之舉。

我希望這一次的多後端能有更好的表現,但這無疑仍是一個挑戰。

您是否需要等到像FlashAttention v2這樣的重要特性在JAX、TensorFlow和PyTorch 上都可用後,才能在 Keras 中使用它?還是說您只能在某些特定後端中使用它?

對於未來,我們還面臨着許多未解決的挑戰。

參考鏈接:

[1]https://keras.io/keras_3/

[2]https://x.com/sampathweb/status/1729556960314339534

[3]https://twitter.com/Nils_Reimers/status/1729612017340657993