
How do I clamp the magnitude of a PyTorch tensor?

I know I can use torch.clamp to clamp a tensor's values within some min / max, but how can I do this if I want to clamp by the magnitude (absolute value)? Example:

import torch
t = torch.tensor([-5.0, -250, -1, 0.003, 7, 1238])
min_mag = 1 / 10
max_mag = 100
# desired output:
# tensor([  -5.0000, -100.0000,   -1.0000,    0.1000,    7.0000,  100.0000])

Here's one method:

sign = t.sign()
t = t.abs_().clamp_(min_mag, max_mag)
t *= sign

(Note: this uses in-place operations, so the original values of t are overwritten.)
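To make the in-place behaviour concrete, here is a minimal sketch that runs the same steps on the question's example tensor while keeping an untouched copy for comparison. The names `original` and `kept` are illustrative only and do not come from the answer.

```python
import torch

original = torch.tensor([-5.0, -250, -1, 0.003, 7, 1238])
kept = original.clone()  # independent copy; the in-place ops below do not touch it
min_mag, max_mag = 1 / 10, 100

# sign() returns a new tensor, so the signs are saved before abs_() discards them
sign = original.sign()
# abs_() and clamp_() both modify `original` itself
original.abs_().clamp_(min_mag, max_mag)
original *= sign

print(kept)      # the unclamped input
print(original)  # the input clamped by magnitude
```

If the unclamped values are still needed afterwards, cloning first (as above) or using out-of-place operations avoids losing them.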

Another way is to multiply each element's sign by its clamped absolute value. This version is out-of-place and leaves t unchanged:

output = torch.sign(t) * torch.clamp(torch.abs(t), min_mag, max_mag)
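One edge case worth knowing: `torch.sign` returns 0 for an element that is exactly 0, so the sign-times-clamp form leaves zeros at 0 instead of raising them to `min_mag`. Whether that matters depends on the use case. The sketch below wraps the idea in a hypothetical helper, `clamp_magnitude` (not part of PyTorch), and uses `torch.copysign` so that a positive zero is treated as positive. Treating zeros this way is an assumption, not something the question specifies.

```python
import torch

def clamp_magnitude(t, min_mag, max_mag):
    """Clamp |t| to [min_mag, max_mag], keeping each element's sign.

    Hypothetical helper for illustration; not a PyTorch function.
    """
    clamped = torch.clamp(torch.abs(t), min_mag, max_mag)
    # copysign copies the sign bit of t onto the clamped magnitude,
    # so +0.0 maps to +min_mag rather than staying at 0.
    return torch.copysign(clamped, t)

t = torch.tensor([-5.0, -250, -1, 0.003, 7, 1238, 0.0])
min_mag, max_mag = 1 / 10, 100

via_copysign = clamp_magnitude(t, min_mag, max_mag)
via_sign = torch.sign(t) * torch.clamp(torch.abs(t), min_mag, max_mag)

print(via_copysign)  # the trailing zero becomes min_mag
print(via_sign)      # the trailing zero stays 0
```

For inputs that contain no exact zeros, the two forms give the same result.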

