In [1]:
Copied!
import numpy as np
import numpy as np
In [2]:
Copied!
np.random.seed(8947501)
data = np.random.uniform(low=1, high=10_000, size=(10_000, 64, 64))
np.random.seed(8947501)
data = np.random.uniform(low=1, high=10_000, size=(10_000, 64, 64))
In [3]:
Copied!
mean_true = np.mean(data, axis=0)
print(mean_true)
mean_true = np.mean(data, axis=0)
print(mean_true)
[[4979.06187109 5042.03419628 5007.30759842 ... 5027.53444031 5027.39917602 4988.98437983] [5004.40211044 4970.12877011 4974.80460791 ... 4987.42125333 5020.60978713 4974.83743075] [5015.70418641 4964.27107736 4989.80859755 ... 4977.72149064 5010.71360439 5037.07001349] ... [4959.73941289 4994.51878591 4988.9707429 ... 5038.30511538 5065.64332467 4979.81483695] [5013.08498048 4979.70225719 4998.08322856 ... 5004.07146257 5059.87903943 5005.13015452] [5027.95742434 5004.89616586 5021.97812961 ... 5015.61308805 5004.63068546 4987.04909058]]
In [4]:
Copied!
from raygent.task import Task
# -----------------------------------------------------------------------------
# Step 1: Define a Task that computes partial means.
# -----------------------------------------------------------------------------
class MeanTask(Task):
"""
A task that computes the element-wise partial mean of a batch of 2D NumPy arrays.
This task uses the batch processing method to compute the mean for all items in
the input list and returns a tuple containing:
(partial_mean, count)
where partial_mean is the element-wise mean computed over the batch, and count is
the number of observations in the batch.
"""
def process_items(self, items, **kwargs):
# Convert the list of 2D arrays into a single 3D NumPy array.
arr = np.array(items, dtype=np.float64)
# Compute the element-wise mean over the first axis (i.e. across all observations).
partial_mean = np.mean(arr, axis=0)
# The count is the number of observations processed in this batch.
count = arr.shape[0]
return (partial_mean, count)
from raygent.task import Task
# -----------------------------------------------------------------------------
# Step 1: Define a Task that computes partial means.
# -----------------------------------------------------------------------------
class MeanTask(Task):
"""
A task that computes the element-wise partial mean of a batch of 2D NumPy arrays.
This task uses the batch processing method to compute the mean for all items in
the input list and returns a tuple containing:
(partial_mean, count)
where partial_mean is the element-wise mean computed over the batch, and count is
the number of observations in the batch.
"""
def process_items(self, items, **kwargs):
# Convert the list of 2D arrays into a single 3D NumPy array.
arr = np.array(items, dtype=np.float64)
# Compute the element-wise mean over the first axis (i.e. across all observations).
partial_mean = np.mean(arr, axis=0)
# The count is the number of observations processed in this batch.
count = arr.shape[0]
return (partial_mean, count)
In [5]:
Copied!
from raygent.manager import TaskManager
from raygent.results import OnlineMeanResultHandler
result_handler = OnlineMeanResultHandler()
manager = TaskManager(
task_class=MeanTask,
result_handler=result_handler,
use_ray=True,
n_cores=8,
n_cores_worker=1,
)
manager.submit_tasks(data, chunk_size=50, at_once=True)
stats = manager.get_results()
mean_parallel = stats["mean"]
print("Element-wise Mean:")
print(stats["mean"])
from raygent.manager import TaskManager
from raygent.results import OnlineMeanResultHandler
result_handler = OnlineMeanResultHandler()
manager = TaskManager(
task_class=MeanTask,
result_handler=result_handler,
use_ray=True,
n_cores=8,
n_cores_worker=1,
)
manager.submit_tasks(data, chunk_size=50, at_once=True)
stats = manager.get_results()
mean_parallel = stats["mean"]
print("Element-wise Mean:")
print(stats["mean"])
2025-03-08 22:40:46,298 INFO worker.py:1841 -- Started a local Ray instance.
Element-wise Mean: [[4979.06187109 5042.03419628 5007.30759842 ... 5027.53444031 5027.39917602 4988.98437983] [5004.40211044 4970.12877011 4974.80460791 ... 4987.42125333 5020.60978713 4974.83743075] [5015.70418641 4964.27107736 4989.80859755 ... 4977.72149064 5010.71360439 5037.07001349] ... [4959.73941289 4994.51878591 4988.9707429 ... 5038.30511538 5065.64332467 4979.81483695] [5013.08498048 4979.70225719 4998.08322856 ... 5004.07146257 5059.87903943 5005.13015452] [5027.95742434 5004.89616586 5021.97812961 ... 5015.61308805 5004.63068546 4987.04909058]]
In [7]:
Copied!
print("Parallel versus NumPy error")
print(np.mean(mean_parallel - mean_true))
print("Parallel versus NumPy error")
print(np.mean(mean_parallel - mean_true))
Parallel versus NumPy error 5.773159728050814e-13