{"id":50718,"date":"2025-11-11T08:01:38","date_gmt":"2025-11-11T08:01:38","guid":{"rendered":"https:\/\/youzum.net\/a-coding-implementation-to-build-and-train-advanced-architectures-with-residual-connections-self-attention-and-adaptive-optimization-using-jax-flax-and-optax\/"},"modified":"2025-11-11T08:01:38","modified_gmt":"2025-11-11T08:01:38","slug":"a-coding-implementation-to-build-and-train-advanced-architectures-with-residual-connections-self-attention-and-adaptive-optimization-using-jax-flax-and-optax","status":"publish","type":"post","link":"https:\/\/youzum.net\/th\/a-coding-implementation-to-build-and-train-advanced-architectures-with-residual-connections-self-attention-and-adaptive-optimization-using-jax-flax-and-optax\/","title":{"rendered":"A Coding Implementation to Build and Train Advanced Architectures with Residual Connections, Self-Attention, and Adaptive Optimization Using JAX, Flax, and Optax"},"content":{"rendered":"<p>In this tutorial, we explore how to build and train an advanced neural network using JAX, Flax, and Optax in an efficient and modular way. We begin by designing a deep architecture that integrates residual connections and self-attention mechanisms for expressive feature learning. As we progress, we implement sophisticated optimization strategies with learning rate scheduling, gradient clipping, and adaptive weight decay. Throughout the process, we leverage JAX transformations such as jit, grad, and vmap to accelerate computation and ensure smooth training performance across devices. Check out the\u00a0<strong><a href=\"https:\/\/github.com\/Marktechpost\/AI-Tutorial-Codes-Included\/blob\/main\/ML%20Project%20Codes\/advanced_jax_flax_optax_training_pipeline_Marktechpost.ipynb\" target=\"_blank\" rel=\"noreferrer noopener\">FULL CODES here<\/a><\/strong>.<\/p>\n<div class=\"dm-code-snippet dark dm-normal-version default no-background-mobile\">\n<div class=\"control-language\">\n<div class=\"dm-buttons\">\n<div class=\"dm-buttons-left\">\n<div class=\"dm-button-snippet red-button\"><\/div>\n<div class=\"dm-button-snippet orange-button\"><\/div>\n<div class=\"dm-button-snippet green-button\"><\/div>\n<\/div>\n<div class=\"dm-buttons-right\"><a><span class=\"dm-copy-text\">Copy Code<\/span><span class=\"dm-copy-confirmed\">Copied<\/span><span class=\"dm-error-message\">Use a different Browser<\/span><\/a><\/div>\n<\/div>\n<pre class=\"no-line-numbers\"><code class=\"no-wrap language-php\">!pip install jax jaxlib flax optax matplotlib\n\n\nimport jax\nimport jax.numpy as jnp\nfrom jax import random, jit, vmap, grad\nimport flax.linen as nn\nfrom flax.training import train_state\nimport optax\nimport matplotlib.pyplot as plt\nfrom typing import Any, Callable\n\n\nprint(f\"JAX version: {jax.__version__}\")\nprint(f\"Devices: {jax.devices()}\")<\/code><\/pre>\n<\/div>\n<\/div>\n<p>We begin by installing and importing JAX, Flax, and Optax, along with essential utilities for numerical operations and visualization. We check our device setup to ensure that JAX is running efficiently on available hardware. This setup forms the foundation for the entire training pipeline. 
```python
class SelfAttention(nn.Module):
    num_heads: int
    dim: int  # note: unused below; the feature dimension is read from the input shape

    @nn.compact
    def __call__(self, x):
        B, L, D = x.shape
        head_dim = D // self.num_heads
        qkv = nn.Dense(3 * D)(x)
        qkv = qkv.reshape(B, L, 3, self.num_heads, head_dim)
        q, k, v = jnp.split(qkv, 3, axis=2)
        q, k, v = q.squeeze(2), k.squeeze(2), v.squeeze(2)  # each (B, L, num_heads, head_dim)
        # Scaled dot-product attention over sequence positions, per head
        # (b=batch, q/k=positions, h=heads, d=head_dim).
        attn_scores = jnp.einsum('bqhd,bkhd->bhqk', q, k) / jnp.sqrt(head_dim)
        attn_weights = jax.nn.softmax(attn_scores, axis=-1)
        attn_output = jnp.einsum('bhqk,bkhd->bqhd', attn_weights, v)
        attn_output = attn_output.reshape(B, L, D)
        return nn.Dense(D)(attn_output)


class ResidualBlock(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x, training: bool = True):
        residual = x
        x = nn.Conv(self.features, (3, 3), padding='SAME')(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        x = nn.Conv(self.features, (3, 3), padding='SAME')(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        if residual.shape[-1] != self.features:
            # 1x1 convolution to match channel counts on the skip path.
            residual = nn.Conv(self.features, (1, 1))(residual)
        return nn.relu(x + residual)


class AdvancedCNN(nn.Module):
    num_classes: int = 10

    @nn.compact
    def __call__(self, x, training: bool = True):
        x = nn.Conv(32, (3, 3), padding='SAME')(x)
        x = nn.relu(x)
        x = ResidualBlock(64)(x, training)
        x = ResidualBlock(64)(x, training)
        x = nn.max_pool(x, (2, 2), strides=(2, 2))
        x = ResidualBlock(128)(x, training)
        x = ResidualBlock(128)(x, training)
        x = jnp.mean(x, axis=(1, 2))   # global average pooling -> (B, 128)
        x = x[:, None, :]              # add a length-1 sequence axis for attention
        x = SelfAttention(num_heads=4, dim=128)(x)
        x = x.squeeze(1)
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dropout(0.5, deterministic=not training)(x)
        x = nn.Dense(self.num_classes)(x)
        return x
```

We define a deep neural network that combines residual blocks and a self-attention mechanism for enhanced feature learning. We construct the layers modularly, ensuring that the model can capture both spatial and contextual relationships. This design enables the network to generalize effectively across various types of input data.
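Before training, a quick shape and parameter-count probe is a cheap sanity check of the architecture. The following snippet is our addition, not part of the original notebook, and assumes the `AdvancedCNN` class defined above.

```python
# Sanity check (our addition): initialize AdvancedCNN on a dummy batch and
# inspect the output shape and the total parameter count.
rng = random.PRNGKey(0)
model = AdvancedCNN(num_classes=10)
dummy = jnp.ones((2, 32, 32, 3))  # NHWC, matching the synthetic data used later
variables = model.init(rng, dummy, training=False)
logits = model.apply(variables, dummy, training=False)
print(logits.shape)  # (2, 10)
n_params = sum(p.size for p in jax.tree_util.tree_leaves(variables['params']))
print(f"Parameters: {n_params:,}")
```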
```python
class TrainState(train_state.TrainState):
    batch_stats: Any


def create_learning_rate_schedule(base_lr: float = 1e-3, warmup_steps: int = 100, decay_steps: int = 1000) -> optax.Schedule:
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=base_lr, transition_steps=warmup_steps)
    decay_fn = optax.cosine_decay_schedule(init_value=base_lr, decay_steps=decay_steps, alpha=0.1)
    # Linear warmup for the first warmup_steps, then cosine decay.
    return optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])


def create_optimizer(learning_rate_schedule: optax.Schedule) -> optax.GradientTransformation:
    # Clip the global gradient norm before applying the AdamW update.
    return optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(learning_rate=learning_rate_schedule, weight_decay=1e-4))
```

We create a custom training state that tracks model parameters and batch statistics. We also define a learning rate schedule with warmup and cosine decay, paired with an AdamW optimizer that includes gradient clipping and weight decay. This combination ensures stable and adaptive training.
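Because an `optax.Schedule` is just a function from step count to learning rate, the warmup-plus-cosine schedule can be probed directly; a small sketch (our addition):

```python
# Probe the schedule at a few steps to see warmup followed by cosine decay.
schedule = create_learning_rate_schedule(base_lr=1e-3, warmup_steps=100, decay_steps=1000)
for step in [0, 50, 100, 500, 1000]:
    print(f"step {step:4d}: lr = {float(schedule(step)):.6f}")
```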
```python
@jit
def compute_metrics(logits, labels):
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return {'loss': loss, 'accuracy': accuracy}


def create_train_state(rng, model, input_shape, learning_rate_schedule):
    variables = model.init(rng, jnp.ones(input_shape), training=False)
    params = variables['params']
    batch_stats = variables.get('batch_stats', {})
    tx = create_optimizer(learning_rate_schedule)
    return TrainState.create(apply_fn=model.apply, params=params, tx=tx, batch_stats=batch_stats)


@jit
def train_step(state, batch, dropout_rng):
    images, labels = batch

    def loss_fn(params):
        variables = {'params': params, 'batch_stats': state.batch_stats}
        # mutable=['batch_stats'] lets BatchNorm update its running statistics.
        logits, new_model_state = state.apply_fn(variables, images, training=True, mutable=['batch_stats'], rngs={'dropout': dropout_rng})
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
        return loss, (logits, new_model_state)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (logits, new_model_state)), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads, batch_stats=new_model_state['batch_stats'])
    metrics = compute_metrics(logits, labels)
    return state, metrics


@jit
def eval_step(state, batch):
    images, labels = batch
    variables = {'params': state.params, 'batch_stats': state.batch_stats}
    logits = state.apply_fn(variables, images, training=False)
    return compute_metrics(logits, labels)
```

We implement JIT-compiled training and evaluation functions to achieve efficient execution. The training step computes gradients, updates parameters, and maintains batch statistics as it goes. We also define evaluation metrics that help us monitor loss and accuracy throughout the training process.
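Before launching full training, a single-step smoke test is a cheap way to confirm shapes and state threading. This sketch is our addition and simply reuses the functions defined above on a dummy batch.

```python
# Smoke test (our addition): run one jitted train step on a dummy batch.
rng = random.PRNGKey(42)
rng, init_rng, dropout_rng = random.split(rng, 3)
lr_schedule = create_learning_rate_schedule()
state = create_train_state(init_rng, AdvancedCNN(num_classes=10), (1, 32, 32, 3), lr_schedule)
images = jnp.ones((8, 32, 32, 3))
labels = jnp.zeros(8, dtype=jnp.int32)
state, metrics = train_step(state, (images, labels), dropout_rng)
print({k: float(v) for k, v in metrics.items()})  # finite loss, accuracy in [0, 1]
```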
```python
def generate_synthetic_data(rng, num_samples=1000, img_size=32):
    rng_x, rng_y = random.split(rng)
    images = random.normal(rng_x, (num_samples, img_size, img_size, 3))
    labels = random.randint(rng_y, (num_samples,), 0, 10)
    return images, labels


def create_batches(images, labels, batch_size=32):
    num_batches = len(images) // batch_size
    for i in range(num_batches):
        idx = slice(i * batch_size, (i + 1) * batch_size)
        yield images[idx], labels[idx]
```

We generate synthetic data to simulate an image classification task, enabling us to train the model without relying on external datasets. We then batch the data efficiently for iterative updates. This approach allows us to test the full pipeline quickly and verify that all components function correctly.
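As a quick check of the data pipeline (our addition, not in the original notebook), we can materialize one batch and confirm its shapes and label range:

```python
# Verify the generator yields correctly shaped batches (our addition).
rng = random.PRNGKey(1)
images, labels = generate_synthetic_data(rng, num_samples=100)
batch_images, batch_labels = next(create_batches(images, labels, batch_size=32))
print(batch_images.shape, batch_labels.shape)             # (32, 32, 32, 3) (32,)
print(int(batch_labels.min()), int(batch_labels.max()))   # labels lie in [0, 10)
```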
```python
def train_model(num_epochs=5, batch_size=32):
    rng = random.PRNGKey(0)
    rng, data_rng, model_rng = random.split(rng, 3)
    # Separate keys so the test set is drawn independently of the training set.
    train_rng, test_rng = random.split(data_rng)
    train_images, train_labels = generate_synthetic_data(train_rng, num_samples=1000)
    test_images, test_labels = generate_synthetic_data(test_rng, num_samples=200)
    model = AdvancedCNN(num_classes=10)
    lr_schedule = create_learning_rate_schedule(base_lr=1e-3, warmup_steps=50, decay_steps=500)
    state = create_train_state(model_rng, model, (1, 32, 32, 3), lr_schedule)
    history = {'train_loss': [], 'train_acc': [], 'test_acc': []}
    print("Starting training...")
    for epoch in range(num_epochs):
        train_metrics = []
        for batch in create_batches(train_images, train_labels, batch_size):
            rng, dropout_rng = random.split(rng)
            state, metrics = train_step(state, batch, dropout_rng)
            train_metrics.append(metrics)
        train_loss = jnp.mean(jnp.array([m['loss'] for m in train_metrics]))
        train_acc = jnp.mean(jnp.array([m['accuracy'] for m in train_metrics]))
        test_metrics = [eval_step(state, batch) for batch in create_batches(test_images, test_labels, batch_size)]
        test_acc = jnp.mean(jnp.array([m['accuracy'] for m in test_metrics]))
        history['train_loss'].append(float(train_loss))
        history['train_acc'].append(float(train_acc))
        history['test_acc'].append(float(test_acc))
        print(f"Epoch {epoch + 1}/{num_epochs}: Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}")
    return history, state


history, trained_state = train_model(num_epochs=5)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(history['train_loss'], label='Train Loss')
ax1.set_xlabel('Epoch'); ax1.set_ylabel('Loss'); ax1.set_title('Training Loss'); ax1.legend(); ax1.grid(True)
ax2.plot(history['train_acc'], label='Train Accuracy')
ax2.plot(history['test_acc'], label='Test Accuracy')
ax2.set_xlabel('Epoch'); ax2.set_ylabel('Accuracy'); ax2.set_title('Model Accuracy'); ax2.legend(); ax2.grid(True)
plt.tight_layout(); plt.show()

print("\n✅ Tutorial complete! This covers:")
print("- Custom Flax modules (ResNet blocks, Self-Attention)")
print("- Advanced Optax optimizers (AdamW with gradient clipping)")
print("- Learning rate schedules (warmup + cosine decay)")
print("- JAX transformations (@jit for performance)")
print("- Proper state management (batch normalization statistics)")
print("- Complete training pipeline with evaluation")
```

We bring all components together to train the model over several epochs, track performance metrics, and visualize the trends in loss and accuracy. We monitor the model's learning progress and validate its performance on test data.
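With training done, the returned `trained_state` can be reused for inference. Here is a minimal sketch (our addition) of a jitted prediction helper built on the same pattern as `eval_step`:

```python
# Inference helper (our addition): wrap the trained state in a jitted predict function.
@jit
def predict(state, images):
    variables = {'params': state.params, 'batch_stats': state.batch_stats}
    logits = state.apply_fn(variables, images, training=False)
    return jnp.argmax(logits, axis=-1)

rng = random.PRNGKey(7)
new_images, _ = generate_synthetic_data(rng, num_samples=8)
print(predict(trained_state, new_images))  # predicted class per image
```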
In conclusion, we implemented a comprehensive training pipeline using JAX, Flax, and Optax that demonstrates both flexibility and computational efficiency. We observed how custom architectures, advanced optimization strategies, and precise state management come together to form a high-performance deep learning workflow. Through this exercise, we gain a deeper understanding of how to structure scalable experiments in JAX and prepare ourselves to adapt these techniques to real-world machine learning research and production tasks.

---

Check out the [FULL CODES here](https://github.com/Marktechpost/AI-Tutorial-Codes-Included/blob/main/ML%20Project%20Codes/advanced_jax_flax_optax_training_pipeline_Marktechpost.ipynb). Feel free to check out our [GitHub Page for Tutorials, Codes and Notebooks](https://github.com/Marktechpost/AI-Tutorial-Codes-Included). Also, feel free to follow us on [Twitter](https://x.com/intent/follow?screen_name=marktechpost), and don't forget to join our [100k+ ML SubReddit](https://www.reddit.com/r/machinelearningnews/) and subscribe to [our Newsletter](https://www.aidevsignals.com/). Are you on Telegram? [You can now join us on Telegram as well.](https://t.me/machinelearningresearchnews)

The post [A Coding Implementation to Build and Train Advanced Architectures with Residual Connections, Self-Attention, and Adaptive Optimization Using JAX, Flax, and Optax](https://www.marktechpost.com/2025/11/10/a-coding-implementation-to-build-and-train-advanced-architectures-with-residual-connections-self-attention-and-adaptive-optimization-using-jax-flax-and-optax/) appeared first on [MarkTechPost](https://www.marktechpost.com/).