
Add WACNN and SymmetricalTransFormer (STF, CVPR 2022)#354

Draft
Yiozolm wants to merge 3 commits into InterDigitalInc:master from Yiozolm:pr-stf-wacnn

Conversation


@Yiozolm Yiozolm commented May 3, 2026

Adds WACNN and SymmetricalTransFormer (STF) from R. Zou, C. Song, Z. Zhang, "The Devil Is in the Details: Window-based Attention for Image Compression", CVPR 2022 (arXiv:2203.08450).

Adapted from the official implementation at https://github.com/Googolxx/STF (Apache-2.0).

This is the first installment of the per-model PR series proposed in #353. Pretrained weights are intentionally not bundled — calling pretrained=True raises a clear RuntimeError until weights are hosted on S3 (per the discussion in #353).

Summary

  • New zoo entries "stf" and "stf-wacnn" (compressai.models.SymmetricalTransFormer and compressai.models.WACNN); see the usage sketch after this list.
  • New compressai.layers.attn subpackage with the Swin window-based attention building blocks the two models depend on.
  • New ChannelSliceLatentCodec + SliceEntropyCompressionModel base — designed to be reused by the channel-conditional models in follow-up PRs (CCA, TCM, …).
  • Checkpoint converter in examples/convert_stf_checkpoint.py that loads the published stf_<bpp>_best.pth.tar / cnn_<bpp>_best.pth.tar files from the upstream repo and writes them in compressai layout.
  • timm added to dependencies (DropPath / Mlp / trunc_normal_ are reused throughout the Swin building blocks).
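A minimal usage sketch of the new zoo entries, assuming the registration described above; the exact factory signature (in particular the quality argument) is an assumption and may differ from what finally lands:

```python
import torch

from compressai.zoo import image_models

# "stf" / "stf-wacnn" are the entry names added by this PR; the quality
# argument is an assumption here. Per the PR, pretrained=True raises a
# RuntimeError until weights are hosted on S3.
net = image_models["stf-wacnn"](quality=1, pretrained=False).eval()

x = torch.rand(1, 3, 256, 256)
with torch.no_grad():
    out = net(x)
print(out["x_hat"].shape)
```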

Commits

Three commits, designed to be reviewed independently:

  • feat(layers): add Swin window-based attention building blocks: compressai/layers/attn/{swin,swin_attention,inference,__init__}.py + tiny re-export in layers/__init__.py (+721 LOC)
  • feat(latent_codecs): add ChannelSliceLatentCodec + slice-entropy base: compressai/latent_codecs/channel_slice.py + compressai/models/_bases/{slice_entropy,__init__}.py + re-export (+543 LOC)
  • feat(models): add WACNN and SymmetricalTransFormer (STF) from Zou et al. 2022: compressai/models/stf.py + zoo / converter / smoke tests + timm in pyproject.toml (+1060 LOC)

Total: 16 files, +2324 lines, no modifications to existing logic

License & attribution

compressai/models/stf.py carries a dual-license header noting the upstream source URL and Apache-2.0 license alongside the standard InterDigital BSD 3-Clause Clear License for the modifications. The Swin building blocks in compressai/layers/attn/swin.py / swin_attention.py are slightly reworked (renamed parameters, dynamic-resolution pad_to_window_multiple, etc.) but ultimately derive from the same upstream Apache-2.0 source; I'm happy to add per-file attribution headers there as well if maintainers prefer.

Verified

  • pytest tests/test_models.py tests/test_layers.py tests/test_init.py → 32 passed (3 new TestStf + 29 existing).
  • WACNN.from_state_dict(model.state_dict()) round-trip → x_hat diff = 0.0 (405 keys); see the round-trip sketch after this list.
  • SymmetricalTransFormer.from_state_dict(model.state_dict()) round-trip → x_hat diff = 0.0 (315 keys).
  • convert_upstream_stf_state_dict correctly re-roots module.cc_* / module.gaussian_conditional keys under latent_codec.* so the published Googolxx/STF checkpoints load via from_state_dict.
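A minimal sketch of the round-trip check reported above, assuming the class exposes a default constructor config (the actual test in TestStf uses a small config):

```python
import torch

from compressai.models import WACNN  # SymmetricalTransFormer works the same way

model = WACNN().eval()  # default config; constructor arguments are an assumption
reloaded = WACNN.from_state_dict(model.state_dict()).eval()

x = torch.rand(1, 3, 256, 256)
with torch.no_grad():
    diff = (model(x)["x_hat"] - reloaded(x)["x_hat"]).abs().max().item()
print(diff)  # expected 0.0: eval-mode forward is deterministic
```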

Test plan

  • Forward + state-dict round-trip for both backbones at small config (already in TestStf).
  • Smoke-test examples/convert_stf_checkpoint.py against an upstream cnn_<bpp>_best.pth.tar checkpoint locally (x_hat diff = 0 between original and converted state dict in eval mode).
  • Maintainers: confirm that moving timm into the hard dependencies is acceptable (alternative: keep an [stf] extras group).
  • Maintainers: if you want the Swin layer files to carry their own attribution headers (in addition to models/stf.py), I will add them.

Notes for follow-up PRs (per #353)

The next PR will add CCA + TCM together — both reuse ChannelSliceLatentCodec from this PR, and CCA contributes a CausalContextAdjustmentEntropyModel that TCM can opt into. After that, the remaining license-clear models (InvCompress, MLIC++, HPCM, SAAF, DCAE, GLIC, TIC, TinyLIC, ShiftLIC) follow one or two at a time, each PR layering on top of what's already merged.

Yiozolm added 3 commits May 2, 2026 19:13
Add a new compressai.layers.attn subpackage containing the Swin
attention primitives needed by transformer-based learned image
compression models. This is the shared infrastructure that the
STF / WACNN models in this PR (issue InterDigitalInc#353) build on.

New module layout:
  compressai/layers/attn/swin_attention.py  WindowAttention,
                                            window_partition / reverse,
                                            pad_to_window_multiple,
                                            build_window_attention_mask
  compressai/layers/attn/swin.py            WMSA, SwinBlock, SWAtten,
                                            ConvTransBlock,
                                            WinNoShiftAttention,
                                            PatchMerging, PatchSplit
  compressai/layers/attn/inference.py       infer_swatten_ helpers used
                                            by from_state_dict to recover
                                            window / head / dim from a
                                            checkpoint
  compressai/layers/attn/__init__.py        single re-export surface

The root compressai/layers/__init__.py keeps its existing wildcard
imports unchanged, only appending the new attn re-export so callers
can continue to do "from compressai.layers import WindowAttention".

These primitives are adapted from the official STF repository
(https://github.com/Googolxx/STF, Apache-2.0). Full attribution
headers land alongside the model file in the next commit.
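For reviewers who have not worked with Swin blocks before, here is a minimal sketch of the standard window partition / reverse pattern the primitives above are built around. It is the textbook Swin formulation, not this PR's exact code (the PR additionally pads via pad_to_window_multiple so H and W need not be multiples of the window size):

```python
import torch


def window_partition(x: torch.Tensor, window_size: int) -> torch.Tensor:
    # (B, H, W, C) -> (num_windows * B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)


def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int) -> torch.Tensor:
    # Inverse of window_partition: stitch windows back into a (B, H, W, C) map.
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
```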
Introduce the channel-conditional slice-entropy machinery that the STF
and WACNN models in this PR rely on, so the codec is reusable by other
slice-conditional models (CCA, TCM, ...) added in later PRs.

New surface:

* compressai/latent_codecs/channel_slice.py
  ChannelSliceLatentCodec implements the equal-sized channel slicing
  pattern from Minnen2020 / He2022: every slice is encoded as a single
  RANS string with means / scales transforms cc_mean_transforms and
  cc_scale_transforms, an internal LRP head, and an optional per-slice
  mean / scale support transform. Sibling of the existing
  ChannelGroupsLatentCodec (uneven groups + delegated inner codec).

* compressai/models/_bases/slice_entropy.py
  SliceEntropyCompressionModel collects the recurring "build entropy
  bottleneck for z plus a ChannelSliceLatentCodec for y" logic that
  every slice-based model has to write today, plus its
  from_state_dict-side helpers (infer_num_slices,
  infer_max_support_slices, slice / lrp support channel arithmetic, and
  a make_entropy_transform factory). Subclasses populate g_a / g_s /
  h_a / h_mean_s / h_scale_s, then call self._init_slice_entropy(...).

* compressai/models/_bases/__init__.py
  Re-export surface for the new base + helpers.

* compressai/latent_codecs/__init__.py
  Add ChannelSliceLatentCodec to the public exports.
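For orientation, a compressed sketch of the equal-slice channel-conditional pattern that ChannelSliceLatentCodec packages (Minnen 2020 style). The transform names mirror the commit text above, but the loop itself is an illustrative eval-mode simplification, not the codec's actual code:

```python
import torch


def code_slices_eval(y, latent_means, latent_scales,
                     cc_mean_transforms, cc_scale_transforms, lrp_transforms,
                     num_slices=10, max_support_slices=5):
    # Each equal slice of y is modeled conditioned on the hyperprior output
    # plus up to `max_support_slices` previously reconstructed slices.
    y_slices = y.chunk(num_slices, dim=1)
    y_hat_slices = []
    for i, y_slice in enumerate(y_slices):
        support = y_hat_slices if max_support_slices < 0 else y_hat_slices[:max_support_slices]
        mu = cc_mean_transforms[i](torch.cat([latent_means] + support, dim=1))
        scale = cc_scale_transforms[i](torch.cat([latent_scales] + support, dim=1))
        # `scale` feeds the Gaussian conditional / RANS coder in the real codec;
        # only the reconstruction path is shown here.
        y_hat = torch.round(y_slice - mu) + mu  # eval-mode quantization around the mean
        lrp = lrp_transforms[i](torch.cat([latent_means] + support + [y_hat], dim=1))
        y_hat_slices.append(y_hat + 0.5 * torch.tanh(lrp))  # latent residual prediction
    return torch.cat(y_hat_slices, dim=1)
```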
feat(models): add WACNN and SymmetricalTransFormer (STF) from Zou et al. 2022

Add the WACNN (CNN backbone) and SymmetricalTransFormer (transformer
backbone) models from R. Zou, C. Song, Z. Zhang, "The Devil Is in the
Details: Window-based Attention for Image Compression", CVPR 2022
(https://arxiv.org/abs/2203.08450). Adapted from the official
implementation at https://github.com/Googolxx/STF (Apache-2.0).

What is included:

* compressai/models/stf.py - WACNN, SymmetricalTransFormer, and a
  convert_upstream_stf_state_dict helper that strips the DataParallel
  module. prefix and re-roots cc_mean_transforms /
  cc_scale_transforms / lrp_transforms / gaussian_conditional under
  latent_codec.* so released checkpoints from the upstream repo load
  via WACNN.from_state_dict / SymmetricalTransFormer.from_state_dict.

* compressai/models/__init__.py - export the two new model classes.

* compressai/zoo/image.py and compressai/zoo/__init__.py - register
  "stf" and "stf-wacnn" in image_models with thin pretrained=False
  factory functions; calling pretrained=True raises until weights are
  hosted on S3 by the maintainers (per issue InterDigitalInc#353 discussion).

* examples/convert_stf_checkpoint.py - CLI wrapper around the upstream
  state-dict conversion + an optional smoke test on a synthetic image.

* tests/test_models.py - TestStf class: forward / state_dict round-trip
  for both backbones plus a unit test for the upstream-key conversion
  helper.

* pyproject.toml - add timm to the runtime dependencies (used by the
  Swin building blocks committed in feat(layers): ... earlier in this
  PR for DropPath and Mlp).

Pretrained weights are intentionally not bundled. Smoke testing on a
synthetic 64x64 / 128x128 input passes; state-dict round-trip diff
is 0.0 for both WACNN (405 keys) and STF (315 keys).
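A minimal sketch of the kind of key re-rooting described above; the helper name and exact prefix list here are illustrative, not the PR's convert_upstream_stf_state_dict:

```python
def remap_upstream_keys(state_dict):
    # Strip the DataParallel "module." prefix and move the channel-conditional
    # sub-modules under "latent_codec." so from_state_dict can consume them.
    reroot = ("cc_mean_transforms", "cc_scale_transforms",
              "lrp_transforms", "gaussian_conditional")
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith("module."):
            key = key[len("module."):]
        if key.startswith(reroot):
            key = "latent_codec." + key
        remapped[key] = value
    return remapped
```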
@Yiozolm Yiozolm changed the title Pr stf wacnn Add WACNN and SymmetricalTransFormer (STF, CVPR 2022) May 3, 2026
@Yiozolm Yiozolm marked this pull request as draft May 3, 2026 15:28
