class ImageProcessor:
    """Per-mask channel means plus a neighbour-adjusted ("compensated") variant.

    ``flatmasks`` is a 2-D label image: ``0`` marks background and a pixel with
    value ``k`` (``k >= 1``) belongs to mask ``k``.  Row ``k - 1`` of every
    result array refers to mask ``k``.  Labels are expected to be non-negative
    integers.
    """

    def __init__(self, flatmasks):
        self.flatmasks = flatmasks

    def update_adjacency_value(self, adjacency_matrix, original, neighbor):
        """Record one pixel/neighbour pair and report whether it is a border.

        Args:
            adjacency_matrix: Square array updated in place; entry
                ``[original - 1, neighbor - 1]`` is incremented when both
                labels are non-zero and different.
            original: Label of the pixel being examined.
            neighbor: Label of one of its 4-connected neighbours.

        Returns:
            ``True`` when ``original`` is a mask label (non-zero) and the
            neighbour carries a different label (background included),
            otherwise ``False``.
        """
        if original == 0 or original == neighbor:
            return False
        if neighbor != 0:
            adjacency_matrix[int(original - 1), int(neighbor - 1)] += 1
        return True

    def update_adjacency_matrix(self, plane_mask_flattened, width, height, adjacency_matrix, index):
        """Update ``adjacency_matrix`` for the pixel at flat position ``index``.

        This is the per-pixel (scalar) form of the counting rule; the
        vectorised ``_count_label_adjacency`` applies the same rule to the
        whole label image at once.

        Args:
            plane_mask_flattened: Row-major flattened label image.
            width: Number of columns of the un-flattened label image.
            height: Unused; kept so existing callers keep working.
            adjacency_matrix: Square array updated in place.
            index: Flat index of the pixel to process.
        """
        column = index % width
        origin_mask = plane_mask_flattened[index]
        left, right, up, down = False, False, False, False

        # Only look at neighbours that exist inside the image.
        if column != 0:
            left = self.update_adjacency_value(adjacency_matrix, origin_mask, plane_mask_flattened[index - 1])
        if column != width - 1:
            right = self.update_adjacency_value(adjacency_matrix, origin_mask, plane_mask_flattened[index + 1])
        if index >= width:
            up = self.update_adjacency_value(adjacency_matrix, origin_mask, plane_mask_flattened[index - width])
        if index <= len(plane_mask_flattened) - 1 - width:
            down = self.update_adjacency_value(adjacency_matrix, origin_mask, plane_mask_flattened[index + width])

        # The diagonal counts the border pixels of each mask (once per pixel).
        if left or right or up or down:
            adjacency_matrix[int(origin_mask - 1), int(origin_mask - 1)] += 1

    def _count_label_adjacency(self, labels, n_masks):
        """Return the raw border-count matrix for a 2-D signed-integer label image.

        Produces the same counts as calling ``update_adjacency_matrix`` for
        every pixel, but compares whole shifted slices instead of looping.
        """
        adjacency = np.zeros((n_masks, n_masks))
        on_border = np.zeros(labels.shape, dtype=bool)

        everything, skip_first, skip_last = slice(None), slice(1, None), slice(None, -1)
        # (index of the centre pixels, index of their neighbour) per direction.
        directions = (
            ((everything, skip_first), (everything, skip_last)),  # left neighbour
            ((everything, skip_last), (everything, skip_first)),  # right neighbour
            ((skip_first, everything), (skip_last, everything)),  # upper neighbour
            ((skip_last, everything), (skip_first, everything)),  # lower neighbour
        )
        for centre_index, neighbour_index in directions:
            centre = labels[centre_index]
            neighbour = labels[neighbour_index]
            differs = (centre != 0) & (centre != neighbour)
            on_border[centre_index] |= differs
            touching = differs & (neighbour != 0)
            # ufunc.at accumulates repeated index pairs instead of overwriting.
            np.add.at(adjacency, (centre[touching] - 1, neighbour[touching] - 1), 1)

        border_rows = labels[on_border] - 1
        np.add.at(adjacency, (border_rows, border_rows), 1)
        return adjacency

    def compute_channel_means_sums_compensated(self, image):
        """Compute per-mask channel means and their neighbour-adjusted values.

        The adjusted values ``x`` are the least-squares solution of
        ``A @ x = means`` clipped at zero, where ``A`` has ones on the diagonal
        and ``A[i, j]`` is the number of mask-``i`` pixel edges that face mask
        ``j`` divided by twice the number of border pixels of mask ``i``.

        Args:
            image: Array of shape ``(height, width, n_channels)``.  Pixels are
                matched to ``flatmasks`` by row-major flat index.

        Returns:
            Tuple ``(compensated_means, means, counts)`` where the first two
            have shape ``(n_masks, n_channels)`` and ``counts`` has shape
            ``(n_masks,)`` (pixels per mask).  ``n_masks`` is the largest label
            in ``flatmasks``; labels that do not occur yield all-zero rows.
            For a mask image without any label all three arrays are empty.

        Raises:
            ValueError: If ``flatmasks`` is not 2-D.

        Note:
            A dense ``n_masks x n_masks`` matrix is allocated.
        """
        height, width, n_channels = image.shape
        labels = np.asarray(self.flatmasks)
        if labels.ndim != 2:
            raise ValueError("flatmasks must be a 2-D label image")
        # Signed labels so that ``label - 1`` can never wrap around for
        # unsigned mask dtypes.
        labels = labels.astype(np.int64, copy=False)

        # Size by the largest label rather than the number of distinct values:
        # that stays correct when background is absent or labels have gaps.
        n_masks = max(int(labels.max()), 0) if labels.size else 0
        if n_masks == 0:
            empty = np.zeros((0, n_channels))
            return empty, empty.copy(), np.zeros(0)

        squashed_image = np.reshape(image, (height * width, n_channels))
        flat_labels = labels.ravel()
        foreground = np.flatnonzero(flat_labels)
        rows = flat_labels[foreground] - 1

        channel_sums = np.zeros((n_masks, n_channels))
        np.add.at(channel_sums, rows, squashed_image[foreground])
        channel_counts = np.zeros(n_masks)
        np.add.at(channel_counts, rows, 1)

        adjacency_matrix = self._count_label_adjacency(labels, n_masks)
        # Normalise every row by twice its border-pixel count (at least 1 to
        # avoid dividing by zero), then force a unit diagonal.
        row_scale = np.maximum(adjacency_matrix.diagonal(), 1) * 2
        adjacency_matrix /= row_scale[:, None]
        np.fill_diagonal(adjacency_matrix, 1)

        counts_column = channel_counts[:, None]
        means = np.true_divide(
            channel_sums,
            counts_column,
            out=np.zeros_like(channel_sums, dtype='float'),
            where=counts_column != 0,
        )
        solution = np.linalg.lstsq(adjacency_matrix, means, rcond=None)[0]
        compensated_means = np.maximum(solution, 0)

        return compensated_means, means, channel_counts
import torch
class ImageProcessor:
    """Per-mask channel means plus a neighbour-adjusted ("compensated") variant.

    Same computation as the NumPy-only version, except that the final
    least-squares system is solved with ``torch.linalg.lstsq``.

    ``flatmasks`` is a 2-D label image: ``0`` marks background and a pixel with
    value ``k`` (``k >= 1``) belongs to mask ``k``.  Row ``k - 1`` of every
    result array refers to mask ``k``.  Labels are expected to be non-negative
    integers.
    """

    def __init__(self, flatmasks):
        self.flatmasks = flatmasks

    def update_adjacency_value(self, adjacency_matrix, original, neighbor):
        """Record one pixel/neighbour pair and report whether it is a border.

        Args:
            adjacency_matrix: Square array updated in place; entry
                ``[original - 1, neighbor - 1]`` is incremented when both
                labels are non-zero and different.
            original: Label of the pixel being examined.
            neighbor: Label of one of its 4-connected neighbours.

        Returns:
            ``True`` when ``original`` is a mask label (non-zero) and the
            neighbour carries a different label (background included),
            otherwise ``False``.
        """
        if original == 0 or original == neighbor:
            return False
        if neighbor != 0:
            adjacency_matrix[int(original - 1), int(neighbor - 1)] += 1
        return True

    def update_adjacency_matrix(self, plane_mask_flattened, width, height, adjacency_matrix, index):
        """Update ``adjacency_matrix`` for the pixel at flat position ``index``.

        This is the per-pixel (scalar) form of the counting rule; the
        vectorised ``_count_label_adjacency`` applies the same rule to the
        whole label image at once.

        Args:
            plane_mask_flattened: Row-major flattened label image.
            width: Number of columns of the un-flattened label image.
            height: Unused; kept so existing callers keep working.
            adjacency_matrix: Square array updated in place.
            index: Flat index of the pixel to process.
        """
        column = index % width
        origin_mask = plane_mask_flattened[index]
        left, right, up, down = False, False, False, False

        # Only look at neighbours that exist inside the image.
        if column != 0:
            left = self.update_adjacency_value(adjacency_matrix, origin_mask, plane_mask_flattened[index - 1])
        if column != width - 1:
            right = self.update_adjacency_value(adjacency_matrix, origin_mask, plane_mask_flattened[index + 1])
        if index >= width:
            up = self.update_adjacency_value(adjacency_matrix, origin_mask, plane_mask_flattened[index - width])
        if index <= len(plane_mask_flattened) - 1 - width:
            down = self.update_adjacency_value(adjacency_matrix, origin_mask, plane_mask_flattened[index + width])

        # The diagonal counts the border pixels of each mask (once per pixel).
        if left or right or up or down:
            adjacency_matrix[int(origin_mask - 1), int(origin_mask - 1)] += 1

    def _count_label_adjacency(self, labels, n_masks):
        """Return the raw border-count matrix for a 2-D signed-integer label image.

        Produces the same counts as calling ``update_adjacency_matrix`` for
        every pixel, but compares whole shifted slices instead of looping.
        """
        adjacency = np.zeros((n_masks, n_masks))
        on_border = np.zeros(labels.shape, dtype=bool)

        everything, skip_first, skip_last = slice(None), slice(1, None), slice(None, -1)
        # (index of the centre pixels, index of their neighbour) per direction.
        directions = (
            ((everything, skip_first), (everything, skip_last)),  # left neighbour
            ((everything, skip_last), (everything, skip_first)),  # right neighbour
            ((skip_first, everything), (skip_last, everything)),  # upper neighbour
            ((skip_last, everything), (skip_first, everything)),  # lower neighbour
        )
        for centre_index, neighbour_index in directions:
            centre = labels[centre_index]
            neighbour = labels[neighbour_index]
            differs = (centre != 0) & (centre != neighbour)
            on_border[centre_index] |= differs
            touching = differs & (neighbour != 0)
            # ufunc.at accumulates repeated index pairs instead of overwriting.
            np.add.at(adjacency, (centre[touching] - 1, neighbour[touching] - 1), 1)

        border_rows = labels[on_border] - 1
        np.add.at(adjacency, (border_rows, border_rows), 1)
        return adjacency

    def compute_channel_means_sums_compensated(self, image):
        """Compute per-mask channel means and their neighbour-adjusted values.

        The adjusted values ``x`` are the least-squares solution (via
        ``torch.linalg.lstsq``) of ``A @ x = means`` clipped at zero, where
        ``A`` has ones on the diagonal and ``A[i, j]`` is the number of
        mask-``i`` pixel edges that face mask ``j`` divided by twice the number
        of border pixels of mask ``i``.

        Args:
            image: Array of shape ``(height, width, n_channels)``.  Pixels are
                matched to ``flatmasks`` by row-major flat index.

        Returns:
            Tuple ``(compensated_means, means, counts)`` where the first two
            have shape ``(n_masks, n_channels)`` and ``counts`` has shape
            ``(n_masks,)`` (pixels per mask).  ``n_masks`` is the largest label
            in ``flatmasks``; labels that do not occur yield all-zero rows.
            For a mask image without any label all three arrays are empty.

        Raises:
            ValueError: If ``flatmasks`` is not 2-D.

        Note:
            A dense ``n_masks x n_masks`` matrix is allocated.
        """
        height, width, n_channels = image.shape
        labels = np.asarray(self.flatmasks)
        if labels.ndim != 2:
            raise ValueError("flatmasks must be a 2-D label image")
        # Signed labels so that ``label - 1`` can never wrap around for
        # unsigned mask dtypes.
        labels = labels.astype(np.int64, copy=False)

        # Size by the largest label rather than the number of distinct values:
        # that stays correct when background is absent or labels have gaps.
        n_masks = max(int(labels.max()), 0) if labels.size else 0
        if n_masks == 0:
            empty = np.zeros((0, n_channels))
            return empty, empty.copy(), np.zeros(0)

        squashed_image = np.reshape(image, (height * width, n_channels))
        flat_labels = labels.ravel()
        foreground = np.flatnonzero(flat_labels)
        rows = flat_labels[foreground] - 1

        channel_sums = np.zeros((n_masks, n_channels))
        np.add.at(channel_sums, rows, squashed_image[foreground])
        channel_counts = np.zeros(n_masks)
        np.add.at(channel_counts, rows, 1)

        adjacency_matrix = self._count_label_adjacency(labels, n_masks)
        # Normalise every row by twice its border-pixel count (at least 1 to
        # avoid dividing by zero), then force a unit diagonal.
        row_scale = np.maximum(adjacency_matrix.diagonal(), 1) * 2
        adjacency_matrix /= row_scale[:, None]
        np.fill_diagonal(adjacency_matrix, 1)

        counts_column = channel_counts[:, None]
        means = np.true_divide(
            channel_sums,
            counts_column,
            out=np.zeros_like(channel_sums, dtype='float'),
            where=counts_column != 0,
        )

        # Solve the least-squares problem with PyTorch; from_numpy shares the
        # NumPy buffers, and the solution is handed back as a NumPy array.
        solution = torch.linalg.lstsq(
            torch.from_numpy(adjacency_matrix),
            torch.from_numpy(means),
        ).solution.numpy()
        compensated_means = np.maximum(solution, 0)

        return compensated_means, means, channel_counts
# load pickle file
import pickle 

# NOTE: pickle.load can execute arbitrary code embedded in the file, so only
# unpickle files that come from a trusted source.
with open("/home/timkempchen/Downloads/seg_output_tonsil2.pickle", 'rb') as f:
    seg_output = pickle.load(f)
# Show the top-level keys of the loaded object (the recorded output lists
# 'img', 'masks' and 'image_dict').
seg_output.keys()
dict_keys(['img', 'masks', 'image_dict'])
# Pull the per-channel image mapping out of the loaded output and remember
# the channel names in mapping order.
images = seg_output['image_dict']
channelnames = list(images.keys())

# squeeze() drops length-1 axes from the mask array
# (presumably leaving a 2-D label image - TODO confirm).
masks = seg_output['masks'].squeeze()
import numpy as np

# `images` is expected to map channel name -> 2-D numpy array and `masks` to
# be the 2-D numpy label array.

# Collect the channel planes in mapping order and stack them along a new last
# axis, giving one (height, width, n_channels) array.
image_list = list(images.values())
image = np.stack(image_list, axis=-1)

# Run the compensation on the stacked image.
processor = ImageProcessor(masks)
(compensated_means,
 means,
 channel_counts) = processor.compute_channel_means_sums_compensated(image)
# open df
import pandas as pd

# Load the previously exported result table (the recorded output below shows
# 23331 rows x 69 columns, including a `label` column).
df = pd.read_csv("/home/timkempchen/Downloads/tonsil2_mesmer_result.csv")
# Display the table (notebook cell output follows).
df
Unnamed: 0 DAPI FoxP3 HLA-DR CD103 CHGA EGFR CD206 GFAP PD-1 ... GATA3 x y eccentricity perimeter convex_area area axis_major_length axis_minor_length label
0 1 105.993197 1.340136 0.557823 16.442177 8.278912 6.183673 3.306122 1.068027 13.020408 ... 11.074830 4.986395 1472.238095 0.603485 44.142136 154.0 147.0 15.439633 12.311169 1
1 2 123.677686 0.619835 0.830579 17.223140 17.194215 5.975207 3.623967 0.942149 15.983471 ... 11.909091 5.359504 1322.851240 0.853893 63.248737 267.0 242.0 24.688741 12.849201 2
2 3 107.203125 1.281250 0.671875 17.925781 9.699219 6.589844 3.566406 1.023438 13.890625 ... 11.816406 5.710938 1506.226562 0.766017 61.798990 268.0 256.0 22.855322 14.691870 3
3 4 49.660959 0.136986 0.006849 39.623288 25.102740 2.797945 0.989726 0.801370 8.965753 ... 6.722603 8.544521 641.938356 0.645764 63.112698 306.0 292.0 22.077563 16.857044 4
4 5 148.702532 1.310127 1.563291 18.183544 33.227848 7.981013 5.082278 0.943038 21.196203 ... 16.829114 9.006329 1303.702532 0.766800 48.420310 172.0 158.0 17.888621 11.482448 5
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
23326 23327 50.512346 0.956790 1.401235 23.746914 18.395062 24.481481 2.932099 0.901235 12.617284 ... 6.654321 2523.888889 1384.086420 0.761539 48.627417 176.0 162.0 17.973486 11.648970 23327
23327 23328 160.015686 1.176471 3.137255 19.737255 34.219608 11.309804 4.721569 0.976471 19.717647 ... 21.035294 2522.047059 1438.568627 0.722578 62.077164 274.0 255.0 21.866422 15.116015 23328
23328 23329 56.734177 0.759494 0.860759 16.835443 17.898734 16.405063 2.063291 0.924051 10.569620 ... 6.506329 2523.924051 1349.759494 0.707599 32.970563 86.0 79.0 12.134167 8.574169 23329
23329 23330 86.326531 0.806122 1.316327 18.612245 19.306122 16.571429 2.561224 1.010204 13.540816 ... 9.806122 2525.234694 1420.387755 0.760933 35.556349 103.0 98.0 13.929985 9.038195 23330
23330 23331 11.522727 0.147727 0.000000 1.170455 1.238636 1.397727 0.136364 0.886364 4.488636 ... 1.659091 2535.886364 231.329545 0.472365 31.556349 89.0 88.0 11.256538 9.921544 23331

23331 rows × 69 columns

# Display the compensated per-mask channel means (notebook cell output follows).
compensated_means
array([[73.4003194 ,  0.96314494,  0.        , ..., 17.3796773 ,
         6.1359789 ,  4.10462704],
       [62.18477628,  0.        ,  0.        , ...,  1.98561595,
         3.43656937,  2.04505872],
       [95.65718851,  1.13134308,  0.        , ..., 22.91852613,
        10.15428826,  9.49503184],
       ...,
       [23.74777761,  0.40885537,  1.12656697, ...,  2.91814228,
         2.94357561, 30.74130252],
       [41.32538188,  0.15023363,  0.5936194 , ...,  2.48362393,
         2.98513237, 47.74424314],
       [11.52272727,  0.14772727,  0.        , ...,  2.59090909,
         1.65909091,  1.88636364]])
# Channel names in mapping order, followed by how many there are.
keys = list(images)
len(keys)
60
# Display the table before the channel columns are overwritten
# (notebook cell output follows).
df
Unnamed: 0 DAPI FoxP3 HLA-DR CD103 CHGA EGFR CD206 GFAP PD-1 ... GATA3 x y eccentricity perimeter convex_area area axis_major_length axis_minor_length label
0 1 105.993197 1.340136 0.557823 16.442177 8.278912 6.183673 3.306122 1.068027 13.020408 ... 11.074830 4.986395 1472.238095 0.603485 44.142136 154.0 147.0 15.439633 12.311169 1
1 2 123.677686 0.619835 0.830579 17.223140 17.194215 5.975207 3.623967 0.942149 15.983471 ... 11.909091 5.359504 1322.851240 0.853893 63.248737 267.0 242.0 24.688741 12.849201 2
2 3 107.203125 1.281250 0.671875 17.925781 9.699219 6.589844 3.566406 1.023438 13.890625 ... 11.816406 5.710938 1506.226562 0.766017 61.798990 268.0 256.0 22.855322 14.691870 3
3 4 49.660959 0.136986 0.006849 39.623288 25.102740 2.797945 0.989726 0.801370 8.965753 ... 6.722603 8.544521 641.938356 0.645764 63.112698 306.0 292.0 22.077563 16.857044 4
4 5 148.702532 1.310127 1.563291 18.183544 33.227848 7.981013 5.082278 0.943038 21.196203 ... 16.829114 9.006329 1303.702532 0.766800 48.420310 172.0 158.0 17.888621 11.482448 5
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
23326 23327 50.512346 0.956790 1.401235 23.746914 18.395062 24.481481 2.932099 0.901235 12.617284 ... 6.654321 2523.888889 1384.086420 0.761539 48.627417 176.0 162.0 17.973486 11.648970 23327
23327 23328 160.015686 1.176471 3.137255 19.737255 34.219608 11.309804 4.721569 0.976471 19.717647 ... 21.035294 2522.047059 1438.568627 0.722578 62.077164 274.0 255.0 21.866422 15.116015 23328
23328 23329 56.734177 0.759494 0.860759 16.835443 17.898734 16.405063 2.063291 0.924051 10.569620 ... 6.506329 2523.924051 1349.759494 0.707599 32.970563 86.0 79.0 12.134167 8.574169 23329
23329 23330 86.326531 0.806122 1.316327 18.612245 19.306122 16.571429 2.561224 1.010204 13.540816 ... 9.806122 2525.234694 1420.387755 0.760933 35.556349 103.0 98.0 13.929985 9.038195 23330
23330 23331 11.522727 0.147727 0.000000 1.170455 1.238636 1.397727 0.136364 0.886364 4.488636 ... 1.659091 2535.886364 231.329545 0.472365 31.556349 89.0 88.0 11.256538 9.921544 23331

23331 rows × 69 columns

# Channel names in mapping order; `image` was stacked in this same order, so
# column i of `compensated_means` belongs to keys[i].
keys = list(images.keys())

# Write one DataFrame column per channel: existing columns are overwritten
# with the compensated means and unknown names are appended as new columns.
# NOTE(review): this assumes row r of `df` describes mask label r + 1 -
# confirm against df['label'] before relying on the result.
for i, channel_name in enumerate(keys):
    df[channel_name] = compensated_means[:, i]
# Display the updated table (notebook cell output follows).
df
Unnamed: 0 DAPI FoxP3 HLA-DR CD103 CHGA EGFR CD206 GFAP PD-1 ... x y eccentricity perimeter convex_area area axis_major_length axis_minor_length label segmentation_channel
0 1 73.400319 0.963145 0.000000 10.701006 0.000000 4.174474 1.792112 0.841941 8.013893 ... 4.986395 1472.238095 0.603485 44.142136 154.0 147.0 15.439633 12.311169 1 4.104627
1 2 62.184776 0.000000 0.000000 8.109533 1.652147 2.626056 1.173389 0.613547 7.322614 ... 5.359504 1322.851240 0.853893 63.248737 267.0 242.0 24.688741 12.849201 2 2.045059
2 3 95.657189 1.131343 0.000000 15.147517 5.491188 5.509342 3.137370 0.875055 11.599501 ... 5.710938 1506.226562 0.766017 61.798990 268.0 256.0 22.855322 14.691870 3 9.495032
3 4 49.660959 0.136986 0.006849 39.623288 25.102740 2.797945 0.989726 0.801370 8.965753 ... 8.544521 641.938356 0.645764 63.112698 306.0 292.0 22.077563 16.857044 4 17.931507
4 5 114.610899 1.152759 0.888520 13.454905 26.300720 5.784891 3.961348 0.703600 16.377783 ... 9.006329 1303.702532 0.766800 48.420310 172.0 158.0 17.888621 11.482448 5 25.419915
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
23326 23327 42.196426 0.832806 1.086360 21.067722 15.171849 22.810752 2.518938 0.782610 10.996428 ... 2523.888889 1384.086420 0.761539 48.627417 176.0 162.0 17.973486 11.648970 23327 39.302940
23327 23328 131.577919 0.681751 2.114098 14.009404 27.114338 7.196972 3.672131 0.731182 14.524425 ... 2522.047059 1438.568627 0.722578 62.077164 274.0 255.0 21.866422 15.116015 23328 21.348579
23328 23329 23.747778 0.408855 1.126567 10.593432 12.509655 11.727330 0.997654 0.661402 6.655545 ... 2523.924051 1349.759494 0.707599 32.970563 86.0 79.0 12.134167 8.574169 23329 30.741303
23329 23330 41.325382 0.150234 0.593619 12.793476 10.459755 14.349209 1.055391 0.741730 3.383721 ... 2525.234694 1420.387755 0.760933 35.556349 103.0 98.0 13.929985 9.038195 23330 47.744243
23330 23331 11.522727 0.147727 0.000000 1.170455 1.238636 1.397727 0.136364 0.886364 4.488636 ... 2535.886364 231.329545 0.472365 31.556349 89.0 88.0 11.256538 9.921544 23331 1.886364

23331 rows × 70 columns

# Save the updated table to CSV; index=False leaves the DataFrame's row index
# out of the written file.
df.to_csv("/home/timkempchen/Downloads/tonsil2_mesmer_result_compensated.csv", index=False)