Add WACNN and SymmetricalTransFormer (STF, CVPR 2022) #354
Draft
Yiozolm wants to merge 3 commits into InterDigitalInc:master from
Conversation
feat(layers): add Swin window-based attention building blocks

Add a new compressai.layers.attn subpackage containing the Swin attention primitives needed by transformer-based learned image compression models. This is the shared infrastructure that the STF / WACNN models in this PR (issue InterDigitalInc#353) build on.

New module layout:

* compressai/layers/attn/swin_attention.py: WindowAttention, window_partition / reverse, pad_to_window_multiple, build_window_attention_mask
* compressai/layers/attn/swin.py: WMSA, SwinBlock, SWAtten, ConvTransBlock, WinNoShiftAttention, PatchMerging, PatchSplit
* compressai/layers/attn/inference.py: infer_swatten_ helpers used by from_state_dict to recover window / head / dim sizes from a checkpoint
* compressai/layers/attn/__init__.py: single re-export surface

The root compressai/layers/__init__.py keeps its existing wildcard imports unchanged, only appending the new attn re-export so callers can continue to write "from compressai.layers import WindowAttention". These primitives are adapted from the official STF repository (https://github.com/Googolxx/STF, Apache-2.0); full attribution headers land alongside the model file in the next commit. A sketch of the core windowing round-trip follows below.
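For orientation, a minimal sketch of the windowing round-trip these primitives are built around, assuming the usual Swin signatures; the actual helpers in compressai/layers/attn/swin_attention.py may differ in detail:

```python
# Illustrative only: standard Swin-style windowing, not necessarily the exact
# code in this PR.
import torch

def window_partition(x: torch.Tensor, window_size: int) -> torch.Tensor:
    # (B, H, W, C) -> (num_windows * B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int) -> torch.Tensor:
    # Inverse of window_partition: reassemble windows into the (B, H, W, C) map.
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

# Round-trip sanity check on a dummy feature map.
x = torch.randn(2, 8, 8, 32)
assert torch.equal(window_reverse(window_partition(x, 4), 4, 8, 8), x)
```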
feat(latent_codecs): add ChannelSliceLatentCodec + slice-entropy base

Introduce the channel-conditional slice-entropy machinery that the STF and WACNN models in this PR rely on, so the codec is reusable by other slice-conditional models (CCA, TCM, ...) added in later PRs.

New surface:

* compressai/latent_codecs/channel_slice.py: ChannelSliceLatentCodec implements the equal-sized channel-slicing pattern from Minnen 2020 / He 2022: every slice is encoded as a single rANS string, with mean / scale transforms cc_mean_transforms and cc_scale_transforms, an internal LRP head, and an optional per-slice mean / scale support transform. It is a sibling of the existing ChannelGroupsLatentCodec (uneven groups + delegated inner codec). A sketch of the slice loop follows below.
* compressai/models/_bases/slice_entropy.py: SliceEntropyCompressionModel collects the recurring "build an entropy bottleneck for z plus a ChannelSliceLatentCodec for y" logic that every slice-based model currently has to write by hand, plus its from_state_dict-side helpers (infer_num_slices, infer_max_support_slices, slice / LRP support-channel arithmetic, and a make_entropy_transform factory). Subclasses populate g_a / g_s / h_a / h_mean_s / h_scale_s, then call self._init_slice_entropy(...).
* compressai/models/_bases/__init__.py: re-export surface for the new base + helpers.
* compressai/latent_codecs/__init__.py: add ChannelSliceLatentCodec to the public exports.
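For reference, a hypothetical sketch of the conditioning loop such a codec runs, following the Minnen 2020 pattern described above; the ModuleList names mirror this PR, everything else (signature, quantize hook) is illustrative:

```python
# Hypothetical sketch of the equal-sized channel-slice loop; not the actual
# ChannelSliceLatentCodec code. cc_mean_transforms / cc_scale_transforms /
# lrp_transforms are per-slice transform ModuleLists as described above.
import torch

def slice_forward(y, latent_means, latent_scales, *, num_slices,
                  max_support_slices, cc_mean_transforms, cc_scale_transforms,
                  lrp_transforms, quantize):
    y_hat_slices = []
    for i, y_slice in enumerate(y.chunk(num_slices, dim=1)):
        # Condition on the hyperprior output plus up to max_support_slices
        # previously decoded slices (all of them if max_support_slices < 0).
        support = (y_hat_slices if max_support_slices < 0
                   else y_hat_slices[:max_support_slices])
        mean = cc_mean_transforms[i](torch.cat([latent_means] + support, dim=1))
        scale = cc_scale_transforms[i](torch.cat([latent_scales] + support, dim=1))
        y_hat = quantize(y_slice, mean, scale)
        # Latent residual prediction (LRP): a small head predicts and adds back
        # part of the quantization error of this slice.
        lrp_in = torch.cat([latent_means] + support + [y_hat], dim=1)
        y_hat = y_hat + 0.5 * torch.tanh(lrp_transforms[i](lrp_in))
        y_hat_slices.append(y_hat)
    return torch.cat(y_hat_slices, dim=1)
```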
feat(models): add WACNN and SymmetricalTransFormer (STF) from Zou et al. 2022

Add the WACNN (CNN backbone) and SymmetricalTransFormer (transformer backbone) models from R. Zou, C. Song, Z. Zhang, "The Devil Is in the Details: Window-based Attention for Image Compression", CVPR 2022 (https://arxiv.org/abs/2203.08450). Adapted from the official implementation at https://github.com/Googolxx/STF (Apache-2.0).

What is included:

* compressai/models/stf.py: WACNN, SymmetricalTransFormer, and a convert_upstream_stf_state_dict helper that strips the DataParallel "module." prefix and re-roots cc_mean_transforms / cc_scale_transforms / lrp_transforms / gaussian_conditional under latent_codec.*, so released checkpoints from the upstream repo load via WACNN.from_state_dict / SymmetricalTransFormer.from_state_dict (see the sketch after this message).
* compressai/models/__init__.py: export the two new model classes.
* compressai/zoo/image.py and compressai/zoo/__init__.py: register "stf" and "stf-wacnn" in image_models with thin pretrained=False factory functions; calling with pretrained=True raises until weights are hosted on S3 by the maintainers (per the issue InterDigitalInc#353 discussion).
* examples/convert_stf_checkpoint.py: CLI wrapper around the upstream state-dict conversion, plus an optional smoke test on a synthetic image.
* tests/test_models.py: TestStf class with forward / state-dict round-trip tests for both backbones, plus a unit test for the upstream-key conversion helper.
* pyproject.toml: add timm to the runtime dependencies (used by the Swin building blocks committed earlier in this PR for DropPath and Mlp).

Pretrained weights are intentionally not bundled. Smoke testing on synthetic 64x64 / 128x128 inputs passes; the state-dict round-trip diff is 0.0 for both WACNN (405 keys) and STF (315 keys).
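The conversion amounts to a key remapping along these lines (a sketch; the exact prefix list and edge cases in convert_upstream_stf_state_dict may differ):

```python
# Sketch of the upstream-key conversion described above; the prefix tuple is
# an assumption based on the keys named in this commit message.
def convert_keys(state_dict: dict) -> dict:
    rerooted = ("cc_mean_transforms", "cc_scale_transforms",
                "lrp_transforms", "gaussian_conditional")
    out = {}
    for key, value in state_dict.items():
        key = key.removeprefix("module.")   # strip the DataParallel prefix
        if key.startswith(rerooted):        # re-root entropy-side modules
            key = f"latent_codec.{key}"
        out[key] = value
    return out
```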
Adds WACNN and SymmetricalTransFormer (STF) from R. Zou, C. Song, Z. Zhang, "The Devil Is in the Details: Window-based Attention for Image Compression", CVPR 2022 (arXiv:2203.08450).
Adapted from the official implementation at https://github.com/Googolxx/STF (Apache-2.0).
This is the first installment of the per-model PR series proposed in #353. Pretrained weights are intentionally not bundled: calling pretrained=True raises a clear RuntimeError until weights are hosted on S3 (per the discussion in #353; a sketch of the intended factory behavior follows the summary list).

Summary

* "stf" and "stf-wacnn" (compressai.models.SymmetricalTransFormer and compressai.models.WACNN).
* compressai.layers.attn subpackage with the Swin window-based attention building blocks the two models depend on.
* ChannelSliceLatentCodec + SliceEntropyCompressionModel base, designed to be reused by the channel-conditional models in follow-up PRs (CCA, TCM, ...).
* examples/convert_stf_checkpoint.py that loads the published stf_<bpp>_best.pth.tar / cnn_<bpp>_best.pth.tar files from the upstream repo and writes them in compressai layout.
* timm added to dependencies (DropPath / Mlp / trunc_normal_ are reused throughout the Swin building blocks).
Commits

Three commits, designed to be reviewed independently:
1. feat(layers): add Swin window-based attention building blocks
   compressai/layers/attn/{swin,swin_attention,inference,__init__}.py + tiny re-export in layers/__init__.py
2. feat(latent_codecs): add ChannelSliceLatentCodec + slice-entropy base
   compressai/latent_codecs/channel_slice.py + compressai/models/_bases/{slice_entropy,__init__}.py + re-export
3. feat(models): add WACNN and SymmetricalTransFormer (STF) from Zou et al. 2022
   compressai/models/stf.py + zoo / converter / smoke tests + timm in pyproject.toml

License & attribution
Both models are implemented in compressai/models/stf.py, which carries a dual-license header noting the upstream source URL and Apache-2.0 license alongside the standard InterDigital BSD 3-Clause Clear License for the modifications. The Swin building blocks in compressai/layers/attn/swin.py / swin_attention.py are slightly reworked (renamed parameters, dynamic-resolution pad_to_window_multiple, etc.) but ultimately derive from the same upstream Apache-2.0 source; happy to add per-file attribution headers there as well if maintainers prefer.
Verified

* pytest tests/test_models.py tests/test_layers.py tests/test_init.py → 32 passed (3 new TestStf + 29 existing).
* WACNN.from_state_dict(model.state_dict()) round-trip → x_hat diff = 0.0 (405 keys).
* SymmetricalTransFormer.from_state_dict(model.state_dict()) round-trip → x_hat diff = 0.0 (315 keys). A sketch of this round-trip check follows the list.
* convert_upstream_stf_state_dict correctly re-roots module.cc_* / module.gaussian_conditional keys under latent_codec.* so the published Googolxx/STF checkpoints load via from_state_dict.
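The round-trip check is essentially the following (a sketch; the actual test lives in tests/test_models.py):

```python
# Illustrative eval-mode round-trip check, not the exact test code.
import torch
from compressai.models import WACNN  # added by this PR

torch.manual_seed(0)
x = torch.rand(1, 3, 128, 128)
net_a = WACNN().eval()
net_b = WACNN.from_state_dict(net_a.state_dict()).eval()
with torch.no_grad():
    diff = (net_a(x)["x_hat"] - net_b(x)["x_hat"]).abs().max().item()
assert diff == 0.0  # identical weights => identical eval-mode reconstructions
```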
Test plan

* Automated: forward and state-dict round-trip tests for both backbones (TestStf).
* Ran examples/convert_stf_checkpoint.py against an upstream cnn_<bpp>_best.pth.tar checkpoint locally (x_hat diff = 0 between original and converted state dict in eval mode).
* Open question for maintainers: confirm that adding timm to dependencies is acceptable (alternative: keep an [stf] extras group).
* If per-file attribution headers beyond models/stf.py are preferred, I will add them.

Notes for follow-up PRs (per #353)
The next PR will add CCA + TCM together; both reuse ChannelSliceLatentCodec from this PR, and CCA contributes a CausalContextAdjustmentEntropyModel that TCM can opt into. After that, the remaining license-clear models (InvCompress, MLIC++, HPCM, SAAF, DCAE, GLIC, TIC, TinyLIC, ShiftLIC) follow one or two at a time, each PR layering on top of what's already merged.