Welcome to Jiyuniverse

[pytorch] squeeze(), unsqueeze() 함수 본문

ML, DL

[pytorch] squeeze(), unsqueeze() 함수

JJiiyun 2025. 2. 7. 17:22
squeeze()

 

크기가 1인 차원을 자동으로 제거하는 함수이다.

x = torch.rand(1, 1, 20, 128)
x = x.squeeze() # [1, 1, 20, 128] -> [20, 128]
x2 = x.squeeze(dim=1) # [1, 1, 20, 128] -> [1, 20, 128]

 

batch_size = 1 일 경우 위의 코드에서 x와 같이 제거될 수 있기 때문에 dimension을 지정해주어야 한다.

 

unsqueeze()

 

주어진 차원(dim)에 길이가 1인 차원을 추가해주는 함수이다. 차원을 하나 늘려서 텐서의 형태를 바꾼다고 생각하자.

 

import torch

x = torch.tensor([10, 20, 30])
print(x.shape)  # torch.Size([3])

# dim=0에 차원 추가
y = x.unsqueeze(dim=0)
print(y.shape)  # torch.Size([1, 3])

# dim=1에 차원 추가
z = x.unsqueeze(dim=1)
print(z.shape)  # torch.Size([3, 1])

 

 

x의 원래 shape는 [3] (1차원 텐서).

x.unsqueeze(dim=0)을 하면 dim=0 방향으로 (길이가 1인) 차원이 추가되어 [1, 3]이 됨

x.unsqueeze(dim=1)을 하면, dim=1 방향에 차원이 추가되어 [3, 1]이 됨

 

 

unsqueeze()는 배치 차원을 추가하거나 브로드캐스팅 시에 주로 사용된다.

 

 

1. 배치 차원 추가

 

배치 차원이 없는 텐서에 배치 차원을 추가한다.

img = torch.rand(3, 256, 256)  # (채널=3, 높이=256, 너비=256)
img.unsqueeze(0)   # shape: [1, 3, 256, 256]

 

 

2. 브로드캐스팅

 

연산 시 차원이 맞지 않을 경우 차원을 맞춰준다.

a = torch.rand(5, 1)
b = torch.rand(5)      # shape: [5]
b = b.unsqueeze(dim=1) # shape: [5, 1]
c = a + b              # shape: [5, 1]