Google JAX ou J apenas A depois de Ex ecution é uma estrutura desenvolvida pelo Google para acelerar as tarefas de aprendizado de máquina.
Você pode considerá-la uma biblioteca para Python, que ajuda na execução mais rápida de tarefas, computação científica, transformações de função, aprendizado profundo, redes neurais e muito mais.
Sobre o Google JAX
O pacote de computação mais fundamental em Python é o pacote NumPy, que possui todas as funções como agregações, operações vetoriais, álgebra linear, matriz n-dimensional e manipulações de matrizes e muitas outras funções avançadas.
E se pudéssemos acelerar ainda mais os cálculos realizados usando o NumPy – particularmente para grandes conjuntos de dados?
Temos algo que poderia funcionar igualmente bem em diferentes tipos de processadores, como GPU ou TPU, sem nenhuma alteração de código?
Que tal se o sistema pudesse executar transformações de funções combináveis de forma automática e mais eficiente?
O Google JAX é uma biblioteca (ou estrutura, como diz a Wikipedia) que faz exatamente isso e talvez muito mais. Ele foi desenvolvido para otimizar o desempenho e executar com eficiência tarefas de aprendizado de máquina (ML) e aprendizado profundo. O Google JAX fornece os seguintes recursos de transformação que o tornam exclusivo de outras bibliotecas de ML e ajudam na computação científica avançada para aprendizado profundo e redes neurais:
Diferenciação automática
Vetorização automática
Paralelização automática
Compilação Just-in-time (JIT)
Todas as transformações usam XLA (Accelerated Linear Algebra) para maior desempenho e otimização de memória. O XLA é um mecanismo de compilador de otimização específico de domínio que executa álgebra linear e acelera modelos do TensorFlow. Usar o XLA em cima do seu código Python não requer alterações significativas no código!
Vamos explorar em detalhes cada um desses recursos.
Recursos do Google JAX
O Google JAX vem com importantes funções de transformação combináveis para melhorar o desempenho e realizar tarefas de aprendizado profundo com mais eficiência. Por exemplo, diferenciação automática para obter o gradiente de uma função e encontrar derivadas de qualquer ordem. Da mesma forma, paralelização automática e JIT para executar várias tarefas paralelamente. Essas transformações são essenciais para aplicações como robótica, jogos e até pesquisas.
A função de transformação combinável é um puro função que transforma um conjunto de dados em outra forma. Elas são chamadas de combináveis porque são independentes (ou seja, essas funções não têm dependências com o restante do programa) e não têm estado (ou seja, a mesma entrada sempre resultará na mesma saída).
Y(x) = T: (f(x))
Na equação acima, f(x) é a função original na qual uma transformação é aplicada. Y(x) é a função resultante após a aplicação da transformação.
Por exemplo, se você tem uma função chamada ‘total_bill_amt’ e deseja o resultado como uma transformação de função, pode simplesmente usar a transformação que deseja, digamos gradiente (grad):
grad_total_bill = grad(total_bill_amt)
Ao transformar funções numéricas usando funções como grad(), podemos obter facilmente suas derivadas de ordem superior, que podemos usar extensivamente em algoritmos de otimização de aprendizado profundo, como gradiente descendente, tornando os algoritmos mais rápidos e eficientes. Da mesma forma, usando jit(), podemos compilar programas Python just-in-time (preguiçosamente).
#1. Diferenciação automática
O Python usa a função autograd para diferenciar automaticamente o código NumPy e o código Python nativo. JAX usa uma versão modificada de autograd (ou seja, grad) e combina XLA (Accelerated Linear Algebra) para realizar diferenciação automática e encontrar derivados de qualquer ordem para GPU (Graphic Processing Units) e TPU (Tensor Processing Units).)
Nota rápida sobre TPU, GPU e CPU: CPU ou Unidade Central de Processamento gerencia todas as operações no computador. A GPU é um processador adicional que aumenta o poder de computação e executa operações de ponta. A TPU é uma unidade poderosa desenvolvida especificamente para cargas de trabalho complexas e pesadas, como IA e algoritmos de aprendizado profundo.
Na mesma linha da função autograd, que pode diferenciar por meio de loops, recursões, ramificações e assim por diante, JAX usa a função grad() para gradientes de modo reverso (backpropagation). Além disso, podemos diferenciar uma função em qualquer ordem usando grad:
grad(grad(grad(sin θ))) (1.0)
Diferenciação automática de ordem superior
Como mencionamos anteriormente, grad é bastante útil para encontrar as derivadas parciais de uma função. Podemos usar uma derivada parcial para calcular a descida do gradiente de uma função de custo em relação aos parâmetros da rede neural no aprendizado profundo para minimizar as perdas.
Calculando a derivada parcial
Suponha que uma função tenha múltiplas variáveis, x, y e z. Encontrar a derivada de uma variável mantendo as outras variáveis constantes é chamado de derivada parcial. Suponha que temos uma função,
f(x,y,z) = x + 2y + z2
Exemplo para mostrar a derivada parcial
A derivada parcial de x será ∂f/∂x, que nos diz como uma função muda para uma variável quando outras são constantes. Se fizermos isso manualmente, devemos escrever um programa para diferenciar, aplicá-lo para cada variável e, em seguida, calcular a descida do gradiente. Isso se tornaria um assunto complexo e demorado para múltiplas variáveis.
A diferenciação automática divide a função em um conjunto de operações elementares, como +, -, *, / ou sin, cos, tan, exp, etc., e então aplica a regra da cadeia para calcular a derivada. Podemos fazer isso no modo direto e reverso.
Para fornecer as melhores experiências, usamos tecnologias como cookies para armazenar e/ou aceder a informações do dispositivo. Consentir com essas tecnologias nos permitirá processar dados, como comportamento de navegação ou IDs exclusivos neste site. Não consentir ou retirar o consentimento pode afetar adversamente certos recursos e funções.
Funcional
Sempre ativo
O armazenamento ou acesso técnico é estritamente necessário para o fim legítimo de permitir a utilização de um determinado serviço expressamente solicitado pelo assinante ou utilizador, ou para o fim exclusivo de efetuar a transmissão de uma comunicação numa rede de comunicações eletrónicas.
Preferências
O armazenamento ou acesso técnico é necessário para o propósito legítimo de armazenamento de preferências não solicitadas pelo assinante ou usuário.
Estatísticas
O armazenamento técnico ou acesso que é usado exclusivamente para fins estatísticos.O armazenamento técnico ou acesso que é usado exclusivamente para fins estatísticos anónimos. Sem uma intimação, conformidade voluntária por parte de seu Provedor de Serviços de Internet ou registos adicionais de terceiros, as informações armazenadas ou recuperadas apenas para esse fim geralmente não podem ser usadas para identificá-lo.
Marketing
O armazenamento ou acesso técnico é necessário para criar perfis de usuário para enviar publicidade ou para rastrear o usuário em um site ou em vários sites para fins de marketing semelhantes.