src.actions.embeddings module
Embeddings Actions Module
This module provides actions for loading, manipulating, and processing neural network embeddings within the SNN2 framework. It contains utilities for:
- Loading embeddings from pickle files
- Computing centroids from embedding tensors
- Loading sample data for embedding analysis
- Converting loaded data to float32 TensorFlow tensors
The module is designed to work with TensorFlow tensors and PickleHandler objects. It supports embedding-based analysis workflows, such as computing a representative centroid for a group of embeddings.
Functions
- load_embeddings : function
Load embeddings from pickle files and convert to TensorFlow tensors.
- compute_centroids : function
Compute centroid vectors from embedding tensors using mean reduction.
- load_samples : function
Load sample data from pickle files for embedding analysis.
Notes
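All functions in this module use the @action decorator for consistent action tracking and logging within the SNN2 framework. Tensor handling differs per function: load_embeddings casts to float32 and squeezes size-1 dimensions, compute_centroids reduces over the sample axis and re-adds a batch axis, and load_samples returns the loaded data unchanged.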
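The SNN2 @action decorator itself is not documented on this page. As a rough illustration of what "action tracking and logging" can look like, the sketch below uses a hypothetical stand-in decorator. The names `action_sketch` and `ACTION_LOG` are invented for this example and are not SNN2's API. The decorator logs each call, records it in a list, and preserves the wrapped function's metadata.

```python
import functools
import logging
from typing import Any, Callable, Dict, List

logger = logging.getLogger("actions_sketch")

# Hypothetical call registry; SNN2's real tracking mechanism may differ.
ACTION_LOG: List[Dict[str, Any]] = []


def action_sketch(func: Callable) -> Callable:
    """Hypothetical stand-in for an @action-style decorator."""

    @functools.wraps(func)  # keep __name__ and __doc__ of the wrapped function
    def wrapper(*args, **kwargs):
        logger.debug("Running action %s", func.__name__)
        result = func(*args, **kwargs)
        # Record which action ran and which keyword arguments it received.
        ACTION_LOG.append({"action": func.__name__, "kwargs": sorted(kwargs)})
        return result

    return wrapper


@action_sketch
def double(x: int = 0) -> int:
    """Toy action used only to exercise the decorator."""
    return 2 * x


print(double(x=21))     # 42
print(double.__name__)  # double
print(ACTION_LOG)       # [{'action': 'double', 'kwargs': ['x']}]
```

Because `functools.wraps` copies the function's name and docstring, documentation tools still see the original function rather than the wrapper.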
Examples
Basic embedding loading and centroid computation:
>>> embeddings = load_embeddings(emb_path="path/to/embeddings.pkl", pkl=pickle_handler)
>>> centroids = compute_centroids(embeddings)
>>> samples = load_samples(samples_path="path/to/samples.pkl", pkl=pickle_handler)
See also
SNN2.src.io.pickleHandler : PickleHandler class for file operations
SNN2.src.decorators.decorators : Action decorator for function tracking
- src.actions.embeddings.compute_centroids(emb_tf: Tensor) → Tensor
Compute centroid vectors from embedding tensors using mean reduction.
This function calculates the centroid (mean) of embedding vectors along the first axis (typically the batch/sample dimension) and expands the result to maintain tensor rank consistency.
- Parameters:
emb_tf (tf.Tensor) – Input tensor containing embedding vectors with shape (n_samples, embedding_dim). Each row represents an individual embedding vector.
- Returns:
Centroid tensor with shape (1, embedding_dim) representing the mean of all input embedding vectors. The first dimension is expanded to maintain consistency with batch operations.
- Return type:
tf.Tensor
Notes
The centroid computation uses tf.reduce_mean along axis=0 to average across all samples while preserving the embedding dimension. The result is expanded using tf.expand_dims to add a batch dimension of size 1.
This is commonly used in clustering algorithms, prototype learning, and similarity computations where a representative vector is needed for a group of embeddings.
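The reduction described above corresponds to `tf.expand_dims(tf.reduce_mean(emb_tf, axis=0), axis=0)` in TensorFlow. The sketch below mirrors those two steps in NumPy so that it runs without TensorFlow. `compute_centroids_np` is a hypothetical name, and `np.mean` and `np.expand_dims` stand in for the TF calls. The sketch also shows why the `(1, embedding_dim)` shape is convenient: it broadcasts directly against the `(n_samples, embedding_dim)` input.

```python
import numpy as np


def compute_centroids_np(emb: np.ndarray) -> np.ndarray:
    """NumPy sketch of compute_centroids: mean over axis 0, then re-add a batch axis."""
    centroid = np.mean(emb, axis=0)          # (embedding_dim,)
    return np.expand_dims(centroid, axis=0)  # (1, embedding_dim)


emb = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
centroid = compute_centroids_np(emb)
print(centroid.shape)  # (1, 2)
print(centroid)        # [[3. 4.]]

# Same result as a single reduction that keeps the reduced axis.
assert np.allclose(centroid, emb.mean(axis=0, keepdims=True))

# The (1, D) centroid broadcasts against (N, D): squared distance of each sample to it.
sq_dist = np.sum((emb - centroid) ** 2, axis=1)
print(sq_dist)  # [8. 0. 8.]
```

`tf.reduce_mean(emb_tf, axis=0, keepdims=True)` would produce the same shape in one call. The two-step form simply matches the description in the Notes.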
Examples
>>> embeddings = tf.random.normal([100, 64])  # 100 samples, 64-dim embeddings
>>> centroid = compute_centroids(embeddings)
>>> centroid.shape
TensorShape([1, 64])
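>>> # Use the centroid for a cosine-similarity computation
>>> emb_n = tf.math.l2_normalize(embeddings, axis=1)
>>> cen_n = tf.math.l2_normalize(centroid, axis=1)
>>> similarities = tf.matmul(emb_n, cen_n, transpose_b=True)  # shape (100, 1)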
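Cosine similarity to a centroid is the dot product of L2-normalized vectors. The NumPy sketch below computes one score per embedding, matching the `l2_normalize` + `matmul(..., transpose_b=True)` pattern in the TensorFlow example. `cosine_to_centroid` is a hypothetical helper, and the small `eps` guards against division by zero for all-zero rows.

```python
import numpy as np


def cosine_to_centroid(emb: np.ndarray, centroid: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Cosine similarity of each row of emb (N, D) with a (1, D) centroid -> (N, 1)."""
    emb_n = emb / np.maximum(np.linalg.norm(emb, axis=1, keepdims=True), eps)
    cen_n = centroid / np.maximum(np.linalg.norm(centroid, axis=1, keepdims=True), eps)
    return emb_n @ cen_n.T  # (N, D) @ (D, 1) -> (N, 1)


emb = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
centroid = emb.mean(axis=0, keepdims=True)  # [[2/3, 2/3]]
sims = cosine_to_centroid(emb, centroid)
print(sims.shape)  # (3, 1)
```

The third embedding points in the same direction as the centroid, so its score is 1.0. The two axis-aligned embeddings each score about 0.707.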
- src.actions.embeddings.load_embeddings(emb_path: str = None, pkl: PickleHandler = None, **kwargs) → Tensor
Load embeddings from pickle files and convert to TensorFlow tensors.
This function loads embedding data from a pickle file using the provided PickleHandler, converts the data to a TensorFlow tensor with float32 dtype, and applies a squeeze operation to remove dimensions of size 1.
- Parameters:
emb_path (str, optional) – Path to the pickle file containing the embeddings data. Must be a valid file path accessible by the PickleHandler.
pkl (PickleHandler, optional) – PickleHandler instance used for loading the pickle file. Must be properly initialized with appropriate configuration.
**kwargs (dict) – Additional keyword arguments forwarded to the PickleHandler.load() method. The accepted keys depend on that method.
- Returns:
TensorFlow tensor containing the loaded embeddings with float32 dtype. Dimensions of size 1 are removed using tf.squeeze().
- Return type:
tf.Tensor
- Raises:
ValueError – If emb_path is None or if pkl (PickleHandler) is None.
Notes
The function automatically converts loaded data to float32 dtype for compatibility with TensorFlow operations. The squeeze operation removes dimensions of size 1, which is useful for standardizing tensor shapes.
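The sketch below reproduces the documented steps of load_embeddings with NumPy in place of TensorFlow: argument validation, loading, the float32 cast, and the squeeze. The exact PickleHandler constructor and `load()` signature are not shown on this page, so `FilePickleHandler` is a hypothetical stand-in that reads a pickle from a plain file path. `load_embeddings_np` is likewise a hypothetical name. `np.asarray(..., dtype=np.float32)` and `np.squeeze` stand in for the TF conversion and `tf.squeeze`.

```python
import os
import pickle
import tempfile
from typing import Any, Optional

import numpy as np


class FilePickleHandler:
    """Hypothetical stand-in for PickleHandler: reads a pickle from a file path."""

    def load(self, path: str, **kwargs: Any) -> Any:
        # kwargs are accepted but ignored in this stand-in.
        with open(path, "rb") as fh:
            return pickle.load(fh)


def load_embeddings_np(emb_path: Optional[str] = None,
                       pkl: Optional[FilePickleHandler] = None,
                       **kwargs: Any) -> np.ndarray:
    """NumPy sketch of load_embeddings: validate, load, cast to float32, squeeze."""
    if emb_path is None or pkl is None:
        raise ValueError("emb_path and pkl must both be provided")
    data = pkl.load(emb_path, **kwargs)
    # np.squeeze without an axis drops every size-1 dimension.
    return np.squeeze(np.asarray(data, dtype=np.float32))


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "emb.pkl")
    with open(path, "wb") as fh:
        # Store float64 data with a spurious size-1 axis: shape (N, 1, D) = (4, 1, 8).
        pickle.dump(np.zeros((4, 1, 8), dtype=np.float64), fh)
    emb = load_embeddings_np(emb_path=path, pkl=FilePickleHandler())

print(emb.shape, emb.dtype)  # (4, 8) float32
```

A squeeze without an `axis` argument also collapses the sample axis when the file holds a single embedding: a `(1, 1, 8)` array comes back as `(8,)`. Averaging that over axis 0 in compute_centroids would then give a scalar instead of one centroid vector. This page does not say whether the real function passes an `axis` to `tf.squeeze`.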
Examples
>>> pkl_handler = PickleHandler(io_handler, "appendix", logger)
>>> embeddings = load_embeddings(
...     emb_path="embeddings/model_embeddings.pkl",
...     pkl=pkl_handler
... )
>>> embeddings.shape
TensorShape([1000, 128])
- src.actions.embeddings.load_samples(samples_path: str = None, pkl: PickleHandler = None, **kwargs) → List[Tensor]
Load sample data from pickle files for embedding analysis.
This function loads sample data (typically embedding vectors or related tensors) from pickle files using the provided PickleHandler. The loaded data is returned as a list of TensorFlow tensors without additional processing.
- Parameters:
samples_path (str, optional) – Path to the pickle file containing the sample data. Must be a valid file path accessible by the PickleHandler.
pkl (PickleHandler, optional) – PickleHandler instance used for loading the pickle file. Must be properly initialized with appropriate configuration.
**kwargs (dict) – Additional keyword arguments forwarded to the PickleHandler.load() method. The accepted keys depend on that method.
- Returns:
List containing the loaded sample data as TensorFlow tensors. The structure and content depend on what was originally saved in the pickle file.
- Return type:
List[tf.Tensor]
- Raises:
ValueError – If samples_path is None or if pkl (PickleHandler) is None.
Notes
Unlike load_embeddings, this function doesn’t perform automatic type conversion or shape manipulation. The loaded data is returned as-is, allowing for flexible handling of different sample data formats.
This function is commonly used for loading reference samples, test data, or pre-computed embeddings that need to be compared or analyzed.
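To contrast with load_embeddings, the sketch below shows the documented behaviour of load_samples: validate the arguments, then return whatever was unpickled, with no dtype cast or squeeze. `FilePickleHandler` and `load_samples_np` are hypothetical stand-ins, and NumPy arrays play the role of the stored tensors. Samples of different shapes and their original float64 dtype come back untouched.

```python
import os
import pickle
import tempfile
from typing import Any, List, Optional

import numpy as np


class FilePickleHandler:
    """Hypothetical stand-in for PickleHandler: reads a pickle from a file path."""

    def load(self, path: str, **kwargs: Any) -> Any:
        with open(path, "rb") as fh:
            return pickle.load(fh)


def load_samples_np(samples_path: Optional[str] = None,
                    pkl: Optional[FilePickleHandler] = None,
                    **kwargs: Any) -> List[Any]:
    """Sketch of load_samples: validate arguments, then return the unpickled object as-is."""
    if samples_path is None or pkl is None:
        raise ValueError("samples_path and pkl must both be provided")
    return pkl.load(samples_path, **kwargs)  # no dtype cast, no squeeze


# Two samples with different shapes, stored as float64.
original = [np.ones((3, 8), dtype=np.float64), np.ones((1, 8), dtype=np.float64)]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "samples.pkl")
    with open(path, "wb") as fh:
        pickle.dump(original, fh)
    samples = load_samples_np(samples_path=path, pkl=FilePickleHandler())

print(len(samples), [s.shape for s in samples], samples[0].dtype)
# 2 [(3, 8), (1, 8)] float64
```

Note that the `(1, 8)` sample keeps its leading size-1 axis here. Loading the same array through load_embeddings would squeeze it to `(8,)`.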
Examples
>>> pkl_handler = PickleHandler(io_handler, "appendix", logger)
>>> samples = load_samples(
...     samples_path="samples/reference_samples.pkl",
...     pkl=pkl_handler
... )
>>> print(f"Loaded {len(samples)} sample tensors")
Loaded 50 sample tensors
>>> # Access individual samples
>>> first_sample = samples[0]
>>> print(f"First sample shape: {first_sample.shape}")
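Because load_samples returns a list, a common follow-up is to stack the samples into one batch before computing a centroid. The NumPy sketch below does this with `np.stack`, which plays the role `tf.stack` would in TensorFlow. It checks first that all samples share one shape, since stacking requires it. The five synthetic 4-dimensional samples are invented for illustration.

```python
import numpy as np

# Synthetic stand-ins for a list returned by load_samples: five 4-dim vectors.
samples = [np.full(4, float(i)) for i in range(5)]

# Stacking requires every sample to have the same shape.
shapes = {s.shape for s in samples}
if len(shapes) != 1:
    raise ValueError(f"samples have mixed shapes: {shapes}")

batch = np.stack(samples, axis=0).astype(np.float32)   # (n_samples, dim)
centroid = np.expand_dims(batch.mean(axis=0), axis=0)  # same steps as compute_centroids
print(batch.shape, centroid.shape)  # (5, 4) (1, 4)
print(centroid)                     # [[2. 2. 2. 2.]]
```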