JAX:有望取代Tensorflow,谷歌出品的又一超高性能机器学习框架

2023-06-01 0 604

序言

在机器学习架构各方面,JAX是两个信息时代——虽然Tensorflow的竞争者换句话讲早已在2018年后早已很完整,但直至前段时间JAX才已经开始在更广为的机器学习科学研究街道社区中赢得诱惑力。

JAX究竟是甚么?依照JAX非官方如是说:

JAX是NumPy在CPU、GPU和TPU上的版,具备高效能机器学习科学研究的强悍手动二阶(automatic differentiation)潜能。

接下去,他们会具体内容重新认识JAX。

JAX:有望取代Tensorflow,谷歌出品的又一超高性能机器学习框架

此基础如是说

就像下面说的,JAX是高能全力支持的numpy和绝大部分scipy机能,暗含许多通用型机器学习操作方式的便捷表达式。

他们举个范例

import jax import jax.numpy as np def gpu_backed_hidden_layer(x): return jax.nn.relu(np.dot(W, x) + b)

您能获得numpy精心安排的API,它从2006年就已经开始采用了,具备Tensorflow和PyTorch等当代ML辅助工具的操控性特点。

JAX还包括透过jax.scipy来全力支持相当大一小部分scipy工程项目:

from jax.scipy.linalg importsvd singular_vectors, singular_values = svd(x)

虽然有高能全力支持的numpy + scipy版早已十分管用,但JAX除了许多其它的妙招。具体来说让他们看一看JAX对手动二阶的广为全力支持。

自动二阶·Autograd

Autograd是两个用于在numpy和原生python代码上高效计算梯度的库。Autograd恰好也是JAX的前身。虽然最初的autograd存储库不再被积极开发,但是在autograd上工作的绝大部分核心团队早已已经开始全职从事JAX工程项目。

就像autograd, JAX允许对两个python表达式的输出求导,只需调用grad:

from jax import grad def hidden_layer(x): return jax.nn.relu(np.dot(W, x) + b) grad_hidden_layer = grad(hidden_layer)

您还能透过本机的python控制结构进行区分——而不需要采用tf.cond:

def absolute_value(x) if x >= 0: return x else: return -x grad_absolute_value = grad(absolute_value) from jax.nn import tanh # grads all the way down print(grad(grad(grad(tanh)))(1.0))

默认情况下,grad 为您提供了逆向模式梯度——这是计算梯度最常用的模式,它依赖于缓存激活来提高向后传递的效率。反模式差分是计算参数更新最有效的方法。但是,特别是在实现依赖于高阶派生的优化方法时,它并不总是最佳选择。JAX透过jacfwd和jacrev为逆向模式手动差分和正向模式手动差分提供了一流的全力支持:

from jax importjacfwd, jacrev hessian_fn = jacfwd(jacrev(fn))

除了grad、jacfwd和jacrev之外,JAX还提供了许多实用程序,用于计算表达式的线性逼近、定义自定义梯度操作方式,和作为其手动二阶全力支持的一小部分。

加速线性代数·XLA

XLA (Accelerated Linear Algebra)是两个特定域的线性代数代码编译器,它是JAX将python和numpy表达式转换成高能全力支持的操作方式的此基础。

除了允许JAX将python + numpy代码转换为能在高能上运行的操作方式之外(就像他们在第两个示例中看到的那样),XLA全力支持还允许JAX将多个操作方式融合到两个内核中。它在计算图中寻找节点簇,这些节点簇能被重写以减少计算或中间变量的存储。Tensorflow关于XLA的文档采用以下示例来解释问题能从XLA编译中受益的实例类型。

def unoptimized_fn(x, y, z): return np.sum(x + y * z)

在没有XLA的情况下运行,这将作为3个独立的内核运行——两个乘法、两个加法和两个加法减法。采用XLA运行时,这变成了两个负责所有这三个各方面的内核,不需要存储中间变量,从而节省了时间和内存。

向量化和并行性

虽然Autograd和XLA构成了JAX库的核心,但是除了两个JAX表达式脱颖而出。你能采用jax.vmap和jax.pmap用于向量化和基于spmd(单程序多数据)并行的pmap。

为了说明vmap的优点,他们将返回到他们的简单稠密层的示例,它操作方式两个由向量x表示的示例。

# convention to distinguish between # jax.numpy and numpy import numpy as onp def hidden_layer(x): returnjax.nn.relu(np.dot(W, x + b) print(hidden_layer(np.random.randn(128)).shape) # (128,)

您能采用任何接受单个输入的表达式,并允许它采用JAX .vmap接受一批输入:

batch_hidden_layer = vmap(hidden_layer) print(batch_hidden_layer(onp.random.randn(32, 128)).shape)# (32, 128)

它的美妙之处在于,它意味着你或多或少地忽略了模型表达式中的批处理维数,并且在你构造模型的时候,在你的头脑中少了两个张量维数。

如果您有几个输入都应该向量化,或者您想沿着轴向量化而不是沿着轴0,您能采用in_axes参数来指定。

batch_hidden_layer= vmap(hidden_layer, in_axes=(0,))

JAX用于SPMD paralellism的实用程序,遵循十分类似的API。如果你有一台4-gpu机器和4个范例,你能采用pmap在每个设备上运行两个范例。

# first dimension must align with number of XLA-enabled devices spmd_hidden_layer = pmap(hidden_layer)

和往常一样,你能随心所欲地编写表达式:

# hypothetical setup for high-throughput inference outputs = pmap(vmap(hidden_layer))(onp.random.randn(4, 32, 128)) print(outputs.shape) # (4, 32, 128)

为甚么是JAX?

JAX不是因为它都比现有的机器学习架构更加干净,或者因为它是比Tensorflow PyTorch更好地设计的东西,而是因为它能让他们更容易尝试更多的想法和探索更广为的空间。

相关文章

发表评论
暂无评论
官方客服团队

为您解决烦忧 - 24小时在线 专业服务