O que é Google JAX? Tudo o que você precisa saber

Google JAX ou J apenas A depois de Ex ecution é uma estrutura desenvolvida pelo Google para acelerar as tarefas de aprendizado de máquina.

Você pode considerá-la uma biblioteca para Python, que ajuda na execução mais rápida de tarefas, computação científica, transformações de função, aprendizado profundo, redes neurais e muito mais.

Sobre o Google JAX

O pacote de computação mais fundamental em Python é o pacote NumPy, que possui todas as funções como agregações, operações vetoriais, álgebra linear, matriz n-dimensional e manipulações de matrizes e muitas outras funções avançadas.

E se pudéssemos acelerar ainda mais os cálculos realizados usando o NumPy – particularmente para grandes conjuntos de dados?

Temos algo que poderia funcionar igualmente bem em diferentes tipos de processadores, como GPU ou TPU, sem nenhuma alteração de código?

Que tal se o sistema pudesse executar transformações de funções combináveis ​​de forma automática e mais eficiente?

O Google JAX é uma biblioteca (ou estrutura, como diz a Wikipedia) que faz exatamente isso e talvez muito mais. Ele foi desenvolvido para otimizar o desempenho e executar com eficiência tarefas de aprendizado de máquina (ML) e aprendizado profundo. O Google JAX fornece os seguintes recursos de transformação que o tornam exclusivo de outras bibliotecas de ML e ajudam na computação científica avançada para aprendizado profundo e redes neurais:

  • Diferenciação automática
  • Vetorização automática
  • Paralelização automática
  • Compilação Just-in-time (JIT)
O que e Google JAX Tudo o que voce precisaRecursos exclusivos do Google JAX

Todas as transformações usam XLA (Accelerated Linear Algebra) para maior desempenho e otimização de memória. O XLA é um mecanismo de compilador de otimização específico de domínio que executa álgebra linear e acelera modelos do TensorFlow. Usar o XLA em cima do seu código Python não requer alterações significativas no código!

Vamos explorar em detalhes cada um desses recursos.

Recursos do Google JAX

O Google JAX vem com importantes funções de transformação combináveis ​​para melhorar o desempenho e realizar tarefas de aprendizado profundo com mais eficiência. Por exemplo, diferenciação automática para obter o gradiente de uma função e encontrar derivadas de qualquer ordem. Da mesma forma, paralelização automática e JIT para executar várias tarefas paralelamente. Essas transformações são essenciais para aplicações como robótica, jogos e até pesquisas.

A função de transformação combinável é um puro função que transforma um conjunto de dados em outra forma. Elas são chamadas de combináveis ​​porque são independentes (ou seja, essas funções não têm dependências com o restante do programa) e não têm estado (ou seja, a mesma entrada sempre resultará na mesma saída).

Y(x) = T: (f(x))

Na equação acima, f(x) é a função original na qual uma transformação é aplicada. Y(x) é a função resultante após a aplicação da transformação.

Por exemplo, se você tem uma função chamada ‘total_bill_amt’ e deseja o resultado como uma transformação de função, pode simplesmente usar a transformação que deseja, digamos gradiente (grad):

grad_total_bill = grad(total_bill_amt)

Ao transformar funções numéricas usando funções como grad(), podemos obter facilmente suas derivadas de ordem superior, que podemos usar extensivamente em algoritmos de otimização de aprendizado profundo, como gradiente descendente, tornando os algoritmos mais rápidos e eficientes. Da mesma forma, usando jit(), podemos compilar programas Python just-in-time (preguiçosamente).

#1. Diferenciação automática

O Python usa a função autograd para diferenciar automaticamente o código NumPy e o código Python nativo. JAX usa uma versão modificada de autograd (ou seja, grad) e combina XLA (Accelerated Linear Algebra) para realizar diferenciação automática e encontrar derivados de qualquer ordem para GPU (Graphic Processing Units) e TPU (Tensor Processing Units).)

Nota rápida sobre TPU, GPU e CPU: CPU ou Unidade Central de Processamento gerencia todas as operações no computador. A GPU é um processador adicional que aumenta o poder de computação e executa operações de ponta. A TPU é uma unidade poderosa desenvolvida especificamente para cargas de trabalho complexas e pesadas, como IA e algoritmos de aprendizado profundo.

Na mesma linha da função autograd, que pode diferenciar por meio de loops, recursões, ramificações e assim por diante, JAX usa a função grad() para gradientes de modo reverso (backpropagation). Além disso, podemos diferenciar uma função em qualquer ordem usando grad:

grad(grad(grad(sin θ))) (1.0)

Diferenciação automática de ordem superior

Como mencionamos anteriormente, grad é bastante útil para encontrar as derivadas parciais de uma função. Podemos usar uma derivada parcial para calcular a descida do gradiente de uma função de custo em relação aos parâmetros da rede neural no aprendizado profundo para minimizar as perdas.

Calculando a derivada parcial

Suponha que uma função tenha múltiplas variáveis, x, y e z. Encontrar a derivada de uma variável mantendo as outras variáveis ​​constantes é chamado de derivada parcial. Suponha que temos uma função,

f(x,y,z) = x + 2y + z2

Exemplo para mostrar a derivada parcial

A derivada parcial de x será ∂f/∂x, que nos diz como uma função muda para uma variável quando outras são constantes. Se fizermos isso manualmente, devemos escrever um programa para diferenciar, aplicá-lo para cada variável e, em seguida, calcular a descida do gradiente. Isso se tornaria um assunto complexo e demorado para múltiplas variáveis.

A diferenciação automática divide a função em um conjunto de operações elementares, como +, -, *, / ou sin, cos, tan, exp, etc., e então aplica a regra da cadeia para calcular a derivada. Podemos fazer isso no modo direto e reverso.

Isso é não isto! Todos esses cálculos acontecem muito rápido (bem, pense em um milhão de cálculos semelhantes aos anteriores e no tempo que pode levar!). XLA cuida da velocidade e desempenho.

#2. Álgebra Linear Acelerada

Vamos pegar a equação anterior. Sem o XLA, a computação levará três (ou mais) kernels, onde cada kernel executará uma tarefa menor. Por exemplo,

Kernel k1 –> x * 2y (multiplicação)

k2 –> x * 2y + z (adição)

k3 -> Redução

Se a mesma tarefa for executada pelo XLA, um único kernel cuidará de todas as operações intermediárias, fundindo-as. Os resultados intermediários das operações elementares são transmitidos em vez de armazenados na memória, economizando memória e aumentando a velocidade.

#3. Compilação Just-in-time

JAX usa internamente o compilador XLA para aumentar a velocidade de execução. O XLA pode aumentar a velocidade da CPU, GPU e TPU. Tudo isso é possível usando a execução do código JIT. Para usar isso, podemos usar o jit via import:

  from jax import jit
def my_function(x):
	…………some lines of code
my_function_jit = jit(my_function)

Outra maneira é decorar o jit sobre a definição da função:

  @jit
def my_function(x):
	…………some lines of code

Esse código é muito mais rápido porque a transformação retornará a versão compilada do código para o chamador em vez de usar o interpretador Python. Isso é particularmente útil para entradas de vetores, como arrays e matrizes.

O mesmo vale para todas as funções existentes do python. Por exemplo, funções do pacote NumPy. Nesse caso, devemos importar jax.numpy como jnp em vez de NumPy:

  import jax
import jax.numpy as jnp

x = jnp.array(((1,2,3,4), (5,6,7,8)))

Depois de fazer isso, o objeto principal da matriz JAX chamado DeviceArray substitui a matriz NumPy padrão. DeviceArray é preguiçoso – os valores são mantidos no acelerador até serem necessários. Isso também significa que o programa JAX não espera que os resultados retornem ao programa chamador (Python), seguindo assim um despacho assíncrono.

#4. Vetorização automática (vmap)

Em um mundo típico de aprendizado de máquina, temos conjuntos de dados com um milhão ou mais pontos de dados. Muito provavelmente, realizaríamos alguns cálculos ou manipulações em cada um ou na maioria desses pontos de dados – o que é uma tarefa que consome muito tempo e memória! Por exemplo, se você quiser encontrar o quadrado de cada um dos pontos de dados no conjunto de dados, a primeira coisa em que pensará é criar um loop e obter o quadrado um por um – argh!

Se criarmos esses pontos como vetores, poderíamos fazer todos os quadrados de uma só vez realizando manipulações de vetores ou matrizes nos pontos de dados com nosso NumPy favorito. E se o seu programa pudesse fazer isso automaticamente – você pode pedir mais alguma coisa? Isso é exatamente o que o JAX faz! Ele pode vetorizar automaticamente todos os seus pontos de dados para que você possa executar facilmente qualquer operação neles – tornando seus algoritmos muito mais rápidos e eficientes.

JAX usa a função vmap para autovetorização. Considere a seguinte matriz:

  x = jnp.array((1,2,3,4,5,6,7,8,9,10))
y = jnp.square(x)

Fazendo exatamente o que foi dito acima, o método square será executado para cada ponto no array. Mas se você fizer o seguinte:

  vmap(jnp.square(x))

O método quadrado será executado apenas uma vez porque os pontos de dados agora são vetorizados automaticamente usando o método vmap antes de executar a função, e o loop é empurrado para o nível elementar de operação – resultando em uma multiplicação de matriz em vez de multiplicação escalar, proporcionando melhor desempenho.

#5. Programação SPMD (pmap)

SPMD – ou S inglês P programa M múltiplo D a programação ata é essencial em contextos de aprendizado profundo – você frequentemente aplica as mesmas funções em diferentes conjuntos de dados que residem em várias GPUs ou TPUs. O JAX tem uma função chamada pump, que permite a programação paralela em várias GPUs ou qualquer acelerador. Como o JIT, os programas que usam pmap serão compilados pelo XLA e executados simultaneamente nos sistemas. Essa paralelização automática funciona para cálculos diretos e reversos.

1676153787 77 O que e Google JAX Tudo o que voce precisaComo funciona o pmap

Também podemos aplicar múltiplas transformações de uma só vez em qualquer ordem em qualquer função como:

pmap(vmap(jit(grad (f(x)))))

Múltiplas transformações combináveis

Limitações do Google JAX

Os desenvolvedores do Google JAX pensaram bem em acelerar aprendizagem profunda algoritmos ao introduzir todas essas transformações impressionantes. As funções e pacotes de computação científica estão nas linhas do NumPy, então você não precisa se preocupar com a curva de aprendizado. No entanto, JAX tem as seguintes limitações:

  • O Google JAX ainda está nos estágios iniciais de desenvolvimento e, embora seu objetivo principal seja a otimização de desempenho, ele não oferece muitos benefícios para a computação da CPU. O NumPy parece ter um desempenho melhor e o uso do JAX pode apenas aumentar a sobrecarga.
  • O JAX ainda está em pesquisa ou em estágio inicial e precisa de mais ajustes para atingir os padrões de infraestrutura de frameworks como o TensorFlow, que são mais estabelecidos e têm mais modelos predefinidos, projetos de código aberto e material de aprendizagem.
  • A partir de agora, JAX não suporta Windows Sistema operacional – você precisaria de uma máquina virtual para fazê-lo funcionar.
  • JAX funciona apenas em funções puras – aquelas que não têm nenhum efeito colateral. Para funções com efeitos colaterais, JAX pode não ser uma boa opção.

Como instalar o JAX em seu ambiente Python

Se você tiver a configuração do python em seu sistema e quiser executar o JAX em sua máquina local (CPU), use os seguintes comandos:

  pip install --upgrade pip
pip install --upgrade "jax(cpu)"

Se você deseja executar o Google JAX em uma GPU ou TPU, siga as instruções fornecidas em GitHub JAX página. Para configurar o Python, visite o downloads oficiais do python página.

Conclusão

O Google JAX é ótimo para escrever algoritmos eficientes de aprendizado profundo, robótica e pesquisa. Apesar das limitações, ele é usado extensivamente com outros frameworks como Haiku, Flax e muitos outros. Você poderá apreciar o que o JAX faz ao executar programas e ver as diferenças de tempo na execução do código com e sem JAX. Você pode começar lendo o documentação oficial do Google JAX que é bastante abrangente.