Implementación de la equivarianza de rotación: Creación de un CNN equivalente en grupo desde cero

Las redes neuronales convolucionales (CNN) son altamente eficaces, ya que son capaces de detectar características en una imagen sin importar su posición. Sin embargo, no son completamente indiferentes a todos los tipos de movimientos. Mover hacia arriba o hacia abajo, a la izquierda o a la derecha está bien, pero la rotación alrededor de un eje representa un desafío debido a la naturaleza de la convolución: se aplica primero por filas y luego por columnas (o viceversa). Para lograr una detección exitosa de características incluso cuando la imagen está rotada, es necesario extender la convolución a una operación que sea equivariante ante la rotación. Una operación que es equivariante respecto a un tipo específico de acción no solo registrará la característica desplazada en sí misma, sino que también será consciente de la acción específica que dio lugar a su aparición en esa ubicación.

Este es el segunda artículo de una serie que explora las CNN equivalentes en grupo (GCNN). La primera fue una introducción general sobre por qué son importantes y cómo funcionan. En esa ocasión, presentamos el elemento clave, el grupo de simetría, que especifica qué tipos de transformaciones deben tratarse de manera equivariante. Si aún no lo ha hecho, le recomiendo que lea ese artículo antes de continuar aquí, ya que usaré la terminología y los conceptos introducidos previamente.

En esta ocasión, vamos a codificar un GCNN simple desde cero. Tanto el código como la explicación se basan en un cuaderno proporcionado como parte del programa de 2022 del Curso de Aprendizaje Profundo de la Universidad de Ámsterdam. Siempre es importante destacar la disponibilidad de este tipo de recursos educativos de alta calidad.

A continuación, mi objetivo es explicar de manera general el razonamiento detrás de la arquitectura resultante, construida a partir de módulos más pequeños, cada uno con un propósito específico. Por esta razón, no incluiré todo el código aquí; en su lugar, haré uso del paquete gcnn. Sus métodos están bien comentados, por lo que si desea ver más detalles, no dude en revisar el código directamente.

En la actualidad, gcnn implementa un grupo de simetría: (C_4), que se utiliza como ejemplo continuo a lo largo de esta publicación. Sin embargo, es fácilmente extensible y utiliza jerarquías de clases en todas partes.

Paso 1: el grupo de simetría (C_4)

Para codificar un GCNN, lo primero que debemos proporcionar es una implementación del grupo de simetría que queremos utilizar. Aquí está el grupo de simetría (C_4), que consta de cuatro elementos que realizan rotaciones de 90 grados.

Podemos pedirle a gcnn que cree el grupo por nosotros e inspeccionar sus elementos.

# remotes::install_github("skeydan/gcnn")
library(gcnn)
library(torch)

C_4 <- CyclicGroup(order = 4)
elems <- C_4$elements()
elems
torch_tensor
 0.0000
 1.5708
 3.1416
 4.7124
( CPUFloatType{4} )

Los elementos están representados por sus respectivos ángulos de rotación: (0), (frac{pi}{2}), (Pi) y (frac{3 pi}{2}).

Los grupos contienen la identidad y pueden construir la inversa de un elemento:

C_4$identity

g1 <- elems(2)
C_4$inverse(g1)
torch_tensor
 0
( CPUFloatType{1} )

torch_tensor
4.71239
( CPUFloatType{} )

Lo más relevante aquí son los elementos del grupo y sus acciones. En términos de implementación, debemos distinguir entre las acciones entre los propios elementos del grupo y la acción en el espacio vectorial (mathbb{R}^2), donde residen nuestras imágenes de entrada. La primera parte es bastante sencilla: se puede implementar simplemente sumando ángulos. De hecho, esto es lo que hace gcnn cuando le pedimos que aplique la acción izquierda del elemento g1 en el elemento g2:

g2 <- elems(3)

C_4$left_action_on_H(torch_tensor(g1)$unsqueeze(1), torch_tensor(g2)$unsqueeze(1))
torch_tensor
 4.7124
( CPUFloatType{1,1} )

¿Qué significa el uso de unsqueeze()? Dado que (C_4) está destinado a ser utilizado en una red neuronal, la función left_action_on_H() opera con lotes de elementos en lugar de tensores escalares.

Las cosas se complican un poco más cuando se trata de la acción del grupo en (mathbb{R}^2). Aquí es donde necesitamos el concepto de una representación del grupo. Este es un tema complejo que no abordaremos aquí en detalle. En nuestro contexto actual, funciona de la siguiente manera: tenemos una señal de entrada, un tensor con el que deseamos operar de alguna manera (esa operación será convolución, como veremos más adelante). Para representar esa operación de manera equivariante al grupo, primero aplicamos la inversa de la acción grupal a la entrada mediante la representación. Una vez hecho esto, procedemos con la operación como si nada hubiera pasado.

Para ilustrar con un ejemplo concreto, supongamos que la operación es una convolución. Imagine que…

Un corredor, parado al pie de un sendero de montaña, listo para comenzar a correr cuesta arriba, desea que registremos su altura. Tenemos la opción de medir su altura en la base y luego permitirle subir. La validez de nuestra medición será la misma tanto en la cima de la montaña como en la base. Otra opción sería ser corteses y no hacerlo esperar, pidiéndole que baje una vez esté arriba para medir su altura al regresar. En ambos casos, el resultado será el mismo: la altura del cuerpo permanece invariable (o más bien constante) ante la acción de correr hacia arriba o hacia abajo. (Aunque la altura pueda considerarse una medida bastante aburrida, algo más interesante como la frecuencia cardíaca no habría funcionado tan bien en este ejemplo).

Centrémonos en la implementación: las acciones grupales están codificadas en forma de matrices. Hay una matriz correspondiente a cada elemento del grupo. Para (C_4), esta matriz, denominada estándar, tiene una representación como matriz de rotación: [ begin{bmatrix} cos(theta) & -sin(theta) sin(theta) & cos(theta) end{bmatrix} ] Dentro de gcnn, la función que aplica dicha matriz se llama left_action_on_R2(). Al igual que su contraparte, está diseñada para trabajar con lotes, tanto de elementos del grupo como de vectores en (mathbb{R}^2). En términos técnicos, rota la cuadrícula en la que se define la imagen y luego remuestrea la imagen dentro de esa cuadrícula.

Para ilustrar esto de forma más concreta,…El método mencionado tiene el siguiente aspecto. Aquí se presenta una imagen de una cabra. Una cabra sentada cómodamente en un prado. En primer lugar, se invoca C_4$left_action_on_R2() para girar la cuadrícula. "`r # El tamaño de la cuadrícula es (2, 1024, 1024), para una imagen 2D de 1024 x 1024. img_grid_R2 <- torch::torch_stack(torch::torch_meshgrid( list( torch::torch_linspace(-1, 1, dim(img)(2)), torch::torch_linspace(-1, 1, dim(img)(3)) ) )) # Se transforma la cuadrícula de imagen con la representación matricial de algún elemento de grupo. transformed_grid <- C_4$left_action_on_R2(C_4$inverse(g1)$unsqueeze(1), img_grid_R2) "` En segundo lugar, se vuelve a muestrear la imagen en la cuadrícula transformada. La cabra ahora está mirando hacia arriba en la imagen. "`r transformed_img <- torch::nnf_grid_sample( img$unsqueeze(1), transformed_grid, align_corners = TRUE, mode = "bilinear", padding_mode = "zeros" ) transformed_img(1,..)$permute(c(2, 3, 1)) |> as.array() |> as.raster() |> plot() "` La misma cabra, girada 90 grados hacia arriba. ### Paso 2: la convolución de elevación Queremo… Hacer uso de tecnologías existentes y eficaces. Utilizar la funcionalidad de `torch` tanto como sea posible. Específicamente, se desea emplear `nn_conv2d()`. No obstante, lo que se necesita es un núcleo de convolución que sea equivalente no solo a la traslación, sino también a la acción de (C_4). Esto se puede lograr teniendo un núcleo para cada rotación posible.

La implementación de esta idea es lo que hace `LiftingConvolution`. El concepto es el mismo que antes: primero se rota la cuadrícula y luego se vuelve a muestrear la matriz de pesos del núcleo en la cuadrícula transformada.

¿Por qué se le llama esto una convolución de elevación? El núcleo de convolución convencional opera en (mathbb{R}^2); mientras que nuestra versión extendida opera en combinaciones de (mathbb{R}^2) y (C_4). Matemáticamente, ha sido elevado hacia un producto semidirecto (mathbb{R}^2rveces C_4).

lifting_conv <- LiftingConvolution(
    group = CyclicGroup(order = 4),
    kernel_size = 5,
    in_channels = 3,
    out_channels = 8
  )

x <- torch::torch_randn(c(2, 3, 32, 32))
y <- lifting_conv(x)
y$shape
(1)  2  8  4 28 28

Debido a que `LiftingConvolution` utiliza internamente una dimensión adicional para realizar el producto de traslaciones y rotaciones, el resultado no es de cuatro, sino de cinco dimensiones.

Paso 3: convoluciones grupales

Ahora que nos encontramos en el "espacio extendido por grupos", podemos encadenar múltiples capas donde tanto la entrada como la salida son capas de convolución de grupo. Por ejemplo:

group_conv <- GroupConvolution(
  group = CyclicGroup(order = 4),
    kernel_size = 5,
    in_channels = 8,
    out_channels = 16
)

z <- group_conv(y)
z$shape
(1)  2 16  4 24 24

Solo queda empaquetar todo esto. Eso es lo que hace `gcnn::GroupEquivariantCNN()`.

Paso 4: CNN equivalente al grupo

Se puede llamar a `GroupEquivariantCNN()` de la misma manera.

cnn <- GroupEquivariantCNN(
    group = CyclicGroup(order = 4),
    kernel_size = 5,
    in_channels = 1,
    out_channels = 1,
    num_hidden = 2, # número de convoluciones de grupo
    hidden_channels = 16 # número de canales por capa de convolución de grupo
)

img <- torch::torch_randn(c(4, 1, 32, 32))
cnn(img)$shape
(1) 4 1

A simple vista, esta `GroupEquivariantCNN` parece cualquier CNN convencional… excepto por el argumento `group`.

Ahora, al inspeccionar su salida, vemos que la dimensión adicional ha desaparecido. Esto se debe a que después de una secuencia de capas de convolución de grupo a grupo, el módulo proyecta una representación que, para cada elemento del lote, conserva solo los canales. Por lo tanto, promedia no solo las ubicaciones (como solemos hacer) sino también la dimensión del grupo. Una capa lineal final proporcionará la salida del clasificador solicitado (de dimensión `out_channels`).

Y así queda completada la arquitectura. Es hora de una prueba en un entorno real (más o menos).

¡Dígitos girados!

El objetivo es entrenar dos convnets, una CNN "normal" y una equivalente de grupo, en el conjunto de entrenamiento MNIST estándar. Luego, ambas se evalúan en un conjunto de prueba aumentado en el que cada imagen se gira aleatoriamente mediante una rotación continua en un rango de 0 a 360 grados. No se espera que `GroupEquivariantCNN` sea "perfecta" dado el grupo de simetría (C_4). Concretamente, con (C_4), la equivariancia se limita a solo cuatro posiciones. Sin embargo, se espera que funcione considerablemente mejor que la arquitectura estándar equivalente.

En primer lugar, se preparan los datos, en particular, el conjunto de prueba aumentado.


dir <- "/tmp/mnist"

train_ds <- torchvision::mnist_dataset(
  dir,
  download = TRUE,
  transform = torchvision::transform_to_tensor
)

test_ds <- torchvision::mnist_dataset(
  dir,
  train = FALSE,
  transform = function(x) 
    x |>
      torchvision::transform_to_tensor() |>
      torchvision::transform_random_rotation(
        degrees = c(0, 360),
        resample = 2,
        fill = 0
      )
  )

train_dl <- dataloader(train_ds, batch_size = 128, shuffle = TRUE)
test_dl <- dataloader(test_ds, batch_size = 128)
    

¿Cómo se visualiza?


test_images <- coro::collect(
  test_dl, 1
)((1))$x(1:32, 1, , ) |> as.array()

par(mfrow = c(4, 8), mar = rep(0, 4), mai = rep(0, 4))
test_images |>
  purrr::array_tree(1) |>
  purrr::map(as.raster) |>
  purrr::iwalk(~ {
    plot(.x)
  })
    

32 dígitos, rotados aleatoriamente.

En primer lugar, se define y entrena una CNN convencional. Es muy similar a `GroupEquivariantCNN()` en cuanto a arquitectura, pero con el doble de canales ocultos para mantener una capacidad general comparable.


default_cnn <- nn_module(
   "default_cnn",
   initialize = function(kernel_size, in_channels, out_channels, num_hidden, hidden_channels) {
     self$conv1 <- torch::nn_conv2d(in_channels, hidden_channels, kernel_size)self.$convs <- torch::nn_module_list()
     for (i in 1:num_hidden) {
       self$convs$append(torch::nn_conv2d(hidden_channels, hidden_channels, kernel_size))
     }
     self$avg_pool <- torch::nn_adaptive_avg_pool2d(1)
     self$final_linear <- torch::nn_linear(hidden_channels, out_channels)
   },
   forward = function(x) >
       torch::torch_squeeze() 
 )

fitted <- default_cnn |>
    luz::setup(
      loss = torch::nn_cross_entropy_loss(),
      optimizer = torch::optim_adam,
      metrics = list(
        luz::luz_metric_accuracy()
      )
    ) |>
    luz::set_hparams(
      kernel_size = 5,
      in_channels = 1,
      out_channels = 10,
      num_hidden = 4,
      hidden_channels = 32
    ) %>
    luz::set_opt_hparams(lr = 1e-2, weight_decay = 1e-4) |>
    luz::fit(train_dl, epochs = 10, valid_data = test_dl) 
Entrenamiento: Pérdida: 0.0498 - Precisión: 0.9843
Validación: Pérdida: 3.2445 - Precisión: 0.4479

Es previsible que la precisión en el conjunto de prueba no sea tan alta.

A continuación, se entrena la versión equivalente en grupo.

fitted <- GroupEquivariantCNN |>
  luz::setup(
    loss = torch::nn_cross_entropy_loss(),
    optimizer = torch::optim_adam,
    metrics = list(
      luz::luz_metric_accuracy()
    )
  ) |>
  luz::set_hparams(
    group = CyclicGroup(order = 4),
    kernel_size = 5,
    in_channels = 1,
    out_channels = 10,
    num_hidden = 4,
    hidden_channels = 16
  ) |>
```

En este texto reescrito se han conservado las etiquetas y la estructura HTML original, pero el contenido ha sido reformulado para facilitar su comprensión. 
)
Resultados del entrenamiento: Pérdida: 0.1102 - Precisión: 0.9667
Resultados de validación: Pérdida: 0.4969 - Precisión: 0.8549

En cuanto a la red neuronal convolucional equivalente en grupo, se observa que las métricas en los conjuntos de entrenamiento y prueba son bastante similares. ¡Este es un resultado alentador! Para concluir la exploración de hoy, recordemos un principio fundamental destacado en una publicación de un nivel superior.

Un desafío

Al analizar el conjunto de datos de prueba ampliado, en particular las imágenes de los dígitos visualizados, surge una observación interesante. En la fila dos, columna cuatro, se identifica un dígito que aparentemente debería ser un 9, pero es probable que en realidad sea un 6 invertido. (Este tipo de error, donde un 6 se confunde con un 9, es común y humano). Entonces, surge la interrogante: ¿podría esto ser un problema? ¿Quizás la red simplemente necesita aprender estas sutilezas, similar a un humano?

Desde mi perspectiva, todo depende del contexto: cuál es el propósito real y cómo se pretende utilizar la aplicación. En el caso de la clasificación de dígitos, no parece haber una justificación válida para que aparezca un dígito invertido; por lo tanto, la invarianza total a la rotación sería contraproducente. En resumen, llegamos al mismo principio nuclear que los proponentes del aprendizaje automático justo y equitativo nos recuerdan constantemente:

¡Siempre considera cómo se empleará la aplicación en la práctica!

No obstante, en nuestro caso, hay otro aspecto técnico a tener en cuenta. La función gcnn::GroupEquivariantCNN() es un contenedor simple, ya que todas sus capas comparten el mismo grupo de simetría. Teóricamente, esto no sería imprescindible. Con una mayor elaboración en la codificación, podríamos emplear diferentes grupos según la posición de cada capa en la jerarquía de detección de características.

Ahora, permíteme explicarte por qué seleccioné la imagen de la cabra. La cabra se observa a través de una valla de color rojo y blanco, que forma un patrón de cuadrados (o algún otro tipo de bordes, si se prefiere) ligeramente girado debido al ángulo de visión. En esta situación, un tipo de invarianza a la rotación como la codificada por el grupo (C_4) tiene mucho sentido. Sin embargo, preferiríamos que la cabra no esté mirando hacia arriba, como indico con la acción de (C_4). Por lo tanto, en una tarea de clasificación de imágenes de la vida real, sería más adecuado utilizar capas más flexibles en los estratos inferiores de la red y capas más especializadas en los niveles superiores.

sourceCode r

¿Nos apoyarás hoy?

Creemos que todos merecen entender el mundo en el que viven. Este conocimiento ayuda a crear mejores ciudadanos, vecinos, amigos y custodios de nuestro planeta. Producir periodismo explicativo y profundamente investigado requiere recursos. Puedes apoyar esta misión haciendo una donación económica a Gelipsis hoy. ¿Te sumarás a nosotros?

Suscríbete para recibir nuestro boletín:

Recent Articles

Related Stories

DEJA UN COMENTARIO

Por favor ingrese su comentario!
Por favor ingrese su nombre aquí