Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,18 @@ Usage: caption-flow export [OPTIONS]
Export caption data to various formats.

Options:
--format [jsonl|json|csv|txt|huggingface_hub|all] Export format (default: jsonl)
--format [jsonl|json|csv|txt|parquet|webshart|lance|huggingface_hub|all] Export format (default: jsonl)
```

* **jsonl**: creates a JSON Lines file at the specified `--output` path
* **csv**: exports CSV-compatible data columns to the `--output` path; the exported metadata is incomplete
* **json**: creates a `.json` file for each sample inside the `--output` subdirectory containing **complete** metadata; useful for webdatasets
* **txt**: creates a `.txt` file for each sample inside the `--output` subdirectory containing ONLY captions
* **webshart**: updates an **existing per-shard metadata `.json` file** by writing captions under the plural `captions` key. For this format, pass `--output` as the path to the existing shard metadata JSON file when exporting one shard. If you export multiple shards, pass `--output` as a directory containing one existing `{shard_name}.json` file per shard.
* **huggingface_hub**: creates a dataset on Hugging Face Hub, optionally marked `--private` and/or `--nsfw` where necessary
* **all**: creates all export formats in a specified `--output` directory
* **all**: creates the directory/file-generating export formats in a specified `--output` directory. Prefer a directory here; `webshart` is a special case that expects existing per-shard metadata `.json` files rather than creating new metadata files.

> Note: `--output` paths ending in `.json` are treated specially for `webshart`. Use a directory for normal multi-format exports, and an existing shard metadata JSON file only when intentionally updating a `webshart` shard.

---

Expand Down
12 changes: 11 additions & 1 deletion src/caption_flow/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,7 +1498,17 @@ async def _run_export_process(
@click.option(
"--format",
type=click.Choice(
["jsonl", "json", "csv", "txt", "parquet", "lance", "huggingface_hub", "all"],
[
"jsonl",
"json",
"csv",
"txt",
"parquet",
"webshart",
"lance",
"huggingface_hub",
"all",
],
case_sensitive=False,
),
default="jsonl",
Expand Down
41 changes: 33 additions & 8 deletions src/caption_flow/processors/webdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def _create_units_background(self) -> None:
continue

shard_name = shard_info["name"]
shard_files = shard_info["num_files"]
shard_files = shard_info.get("num_samples", shard_info["num_files"])

# Check if we need to move to next shard
if current_file_idx >= shard_files:
Expand Down Expand Up @@ -681,17 +681,22 @@ def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict
# Use webshart to process unprocessed ranges
for start_idx, end_idx in unprocessed_ranges:
try:
# Jump to shard and starting position
if shard_idx is not None:
self.loader.shard(shard_idx=shard_idx, cursor_idx=start_idx)
else:
# Try to find shard by name
self.loader.shard(filename=shard_name, cursor_idx=start_idx)
use_sample_loader = shard_idx is not None and hasattr(self.loader, "load_sample")
if not use_sample_loader:
# Fallback for older webshart versions. Seek once per contiguous range,
# then advance with next_with_cache_wait for each item.
if shard_idx is not None:
self.loader.shard(shard_idx=shard_idx, cursor_idx=start_idx)
else:
self.loader.shard(filename=shard_name, cursor_idx=start_idx)

# Iterate through the range
for idx in range(start_idx, end_idx + 1):
try:
entry = webshart.next_with_cache_wait(self.loader)
if use_sample_loader:
entry = self.loader.load_sample(shard_idx, idx)
else:
entry = webshart.next_with_cache_wait(self.loader)

# Decode image
image = None
Expand Down Expand Up @@ -723,18 +728,38 @@ def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict
shard_id=shard_name, chunk_id=str(chunk_index), sample_id=str(idx)
)
job_id_str = job_id.get_sample_str()
entry_metadata = getattr(entry, "metadata", {}) or {}
if not isinstance(entry_metadata, dict):
entry_metadata = {}
filtered_entry_metadata = {
k: v
for k, v in entry_metadata.items()
if not k.startswith("_")
and k
not in {
"path",
"offset",
"size",
"width",
"height",
"aspect",
"json_path",
}
}

yield {
"image": image,
"image_data": entry.data,
"item_key": Path(entry.path).stem,
"item_index": idx,
"metadata": {
**filtered_entry_metadata,
"_item_index": idx,
"_chunk_relative_index": idx - unit.data["start_index"],
"_job_id": job_id_str,
"_filename": entry.path,
"_file_size": entry.size,
"_json_path": entry_metadata.get("json_path"),
"_processed_indices": processed_indices,
},
"job_id": job_id_str,
Expand Down
108 changes: 107 additions & 1 deletion src/caption_flow/storage/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def export_shard(
Args:
----
shard_name: Name of the shard to export
format: Export format ('jsonl', 'json', 'csv', 'parquet', 'txt')
format: Export format ('jsonl', 'json', 'csv', 'parquet', 'txt', 'webshart')
output_path: Output file or directory path
columns: Specific columns to export
limit: Maximum number of rows to export
Expand Down Expand Up @@ -80,6 +80,20 @@ async def export_shard(
)
else:
output_file = output_path / f"{shard_name}.{format}"
elif format == "webshart":
# webshart metadata must be shard-specific to avoid multiple shards
# reusing the same JSON file when called repeatedly.
if output_path.suffix.lower() == ".json" and not output_path.is_dir():
if output_path.stem == shard_name:
output_file = output_path
else:
raise ExportError(
"Invalid webshart output path "
f"'{output_path}': explicit JSON output files must be named "
f"'{shard_name}.json' for shard '{shard_name}'."
)
else:
output_file = output_path / f"{shard_name}.json"
else:
# Directory-based formats
output_file = output_path / shard_name
Expand All @@ -99,6 +113,12 @@ async def export_shard(
kwargs.get("filename_column", "filename"),
kwargs.get("export_column", "captions"),
)
elif format == "webshart":
return exporter.to_webshart_metadata(
output_file,
kwargs.get("filename_column", "filename"),
kwargs.get("export_column", "captions"),
)
else:
raise ValueError(f"Unsupported format: {format}")

Expand Down Expand Up @@ -629,3 +649,89 @@ def to_txt(

logger.info(f"Created {files_created} text files in: {output_dir}")
return files_created

def _normalize_webshart_captions(self, value: Any, export_column: str) -> List[str]:
"""Normalize a caption export value to webshart's plural captions list."""
if isinstance(value, str):
return [value] if value else []
if isinstance(value, (list, tuple)):
captions = []
for item in value:
if item is None:
continue
text = item if isinstance(item, str) else str(self._serialize_value(item))
if text:
captions.append(text)
return captions
raise ExportError(
f"Column '{export_column}' must contain a string or list of strings "
"for webshart metadata export"
)

def to_webshart_metadata(
    self,
    metadata_path: Union[str, Path],
    filename_column: str = "filename",
    export_column: str = "captions",
) -> int:
    """Write captions from the export contents into an existing webshart metadata JSON file.

    Rows are keyed by the filename returned from ``self._get_filename_from_row``,
    falling back to the row's ``item_key``. Rows without a usable key, or whose
    export column is missing or normalizes to no captions, are skipped and
    counted in a warning. When two rows resolve to the same key, the later
    row's captions replace the earlier ones.

    Args:
    ----
        metadata_path: Path to the metadata JSON file to update. It is only
            validated once there is at least one caption to write.
        filename_column: Column passed to ``self._get_filename_from_row``.
        export_column: Column holding the caption value(s).

    Returns:
    -------
        ``0`` when no row produced captions; otherwise the value returned by
        ``webshart.write_captions_to_metadata`` (logged as the number of
        updated captions — presumably an int count; confirm against webshart).

    Raises:
    ------
        ExportError: If the export column is unknown, the metadata path does
            not exist or is not a file, ``webshart`` cannot be imported, or the
            installed ``webshart`` lacks ``write_captions_to_metadata``.

    """
    contents = self.contents
    if export_column not in contents.columns and export_column not in contents.output_fields:
        raise ExportError(f"Column '{export_column}' not found in data")

    sample_captions = {}
    missing_filename = 0
    missing_content = 0

    for record in contents.rows:
        sample_name = self._get_filename_from_row(record, filename_column) or record.get(
            "item_key"
        )
        if not sample_name:
            missing_filename += 1
            continue

        raw_value = record.get(export_column)
        if raw_value is None:
            normalized = []
        else:
            normalized = self._normalize_webshart_captions(raw_value, export_column)

        # A missing value and a value that normalizes to nothing count the same.
        if not normalized:
            missing_content += 1
            continue

        sample_captions[str(sample_name)] = normalized

    if missing_filename:
        logger.warning(f"Skipped {missing_filename} rows with no extractable filename")
    if missing_content:
        logger.warning(f"Skipped {missing_content} rows with no {export_column} content")

    # Nothing to write: return before touching the filesystem or importing webshart.
    if not sample_captions:
        return 0

    metadata_path = Path(metadata_path)
    if not metadata_path.exists():
        raise ExportError(
            f"Webshart metadata file does not exist: {metadata_path}. "
            "Expected an existing metadata JSON file to update."
        )
    if not metadata_path.is_file():
        raise ExportError(
            f"Webshart metadata path is not a file: {metadata_path}. "
            "Expected an existing metadata JSON file to update."
        )

    # webshart is imported lazily so it is only needed for this export format.
    try:
        import webshart
    except ImportError as exc:
        raise ExportError("webshart is required for webshart metadata export") from exc

    if hasattr(webshart, "write_captions_to_metadata"):
        updated = webshart.write_captions_to_metadata(metadata_path, sample_captions)
        logger.info(f"Updated {updated} captions in webshart metadata: {metadata_path}")
        return updated

    raise ExportError(
        "Installed webshart does not support write_captions_to_metadata; "
        "upgrade webshart to export captions into metadata listings."
    )
Loading
Loading